diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index a249c3e1e..e72d1215e 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -21,6 +21,7 @@ import ( "github.com/facebookincubator/ent/dialect" "github.com/facebookincubator/ent/entc/integration/ent" + "github.com/facebookincubator/ent/entc/integration/ent/enttest" "github.com/facebookincubator/ent/entc/integration/ent/file" "github.com/facebookincubator/ent/entc/integration/ent/group" "github.com/facebookincubator/ent/entc/integration/ent/groupinfo" @@ -30,17 +31,15 @@ import ( "github.com/facebookincubator/ent/entc/integration/ent/user" "github.com/stretchr/testify/mock" - "github.com/go-sql-driver/mysql" + _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/require" ) func TestSQLite(t *testing.T) { - client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") - require.NoError(t, err) + client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&cache=shared&_fk=1", opts) defer client.Close() - require.NoError(t, client.Schema.Create(context.Background(), migrate.WithDropColumn(true), migrate.WithDropIndex(true))) for _, tt := range tests { name := runtime.FuncForPC(reflect.ValueOf(tt).Pointer()).Name() t.Run(name[strings.LastIndex(name, ".")+1:], func(t *testing.T) { @@ -55,13 +54,8 @@ func TestMySQL(t *testing.T) { for version, port := range map[string]int{"56": 3306, "57": 3307, "8": 3308} { addr := net.JoinHostPort("localhost", strconv.Itoa(port)) t.Run(version, func(t *testing.T) { - client, err := ent.Open("mysql", (&mysql.Config{ - User: "root", Passwd: "pass", Net: "tcp", Addr: addr, - DBName: "test", ParseTime: true, AllowNativePasswords: true, - }).FormatDSN()) - require.NoError(t, err) + client := enttest.Open(t, dialect.MySQL, fmt.Sprintf("root:pass@tcp(%s)/test?parseTime=True", addr), opts) defer client.Close() - require.NoError(t, client.Schema.Create(context.Background(), migrate.WithDropColumn(true), migrate.WithDropIndex(true))) for _, tt := range tests { name := runtime.FuncForPC(reflect.ValueOf(tt).Pointer()).Name() t.Run(name[strings.LastIndex(name, ".")+1:], func(t *testing.T) { @@ -76,10 +70,8 @@ func TestMySQL(t *testing.T) { func TestPostgres(t *testing.T) { for version, port := range map[string]int{"10": 5430, "11": 5431, "12": 5432} { t.Run(version, func(t *testing.T) { - client, err := ent.Open(dialect.Postgres, fmt.Sprintf("host=localhost port=%d user=postgres dbname=test password=pass sslmode=disable", port)) - require.NoError(t, err) + client := enttest.Open(t, dialect.Postgres, fmt.Sprintf("host=localhost port=%d user=postgres dbname=test password=pass sslmode=disable", port), opts) defer client.Close() - require.NoError(t, client.Schema.Create(context.Background(), migrate.WithDropColumn(true), migrate.WithDropIndex(true))) for _, tt := range tests { name := runtime.FuncForPC(reflect.ValueOf(tt).Pointer()).Name() t.Run(name[strings.LastIndex(name, ".")+1:], func(t *testing.T) { @@ -91,34 +83,39 @@ func TestPostgres(t *testing.T) { } } -// tests for all drivers to run. -var tests = [...]func(*testing.T, *ent.Client){ - Tx, - Indexes, - Types, - Clone, - Sanity, - Paging, - Select, - Delete, - Relation, - Predicate, - AddValues, - ClearFields, - UniqueConstraint, - O2OTwoTypes, - O2OSameType, - O2OSelfRef, - O2MTwoTypes, - O2MSameType, - M2MSelfRef, - M2MSameType, - M2MTwoTypes, - DefaultValue, - ImmutableValue, - Sensitive, - EagerLoading, -} +var ( + opts = enttest.WithMigrateOptions( + migrate.WithDropIndex(true), + migrate.WithDropColumn(true), + ) + tests = [...]func(*testing.T, *ent.Client){ + Tx, + Indexes, + Types, + Clone, + Sanity, + Paging, + Select, + Delete, + Relation, + Predicate, + AddValues, + ClearFields, + UniqueConstraint, + O2OTwoTypes, + O2OSameType, + O2OSelfRef, + O2MTwoTypes, + O2MSameType, + M2MSelfRef, + M2MSameType, + M2MTwoTypes, + DefaultValue, + ImmutableValue, + Sensitive, + EagerLoading, + } +) func Sanity(t *testing.T, client *ent.Client) { require := require.New(t)