diff --git a/dialect/sql/schema/atlas.go b/dialect/sql/schema/atlas.go index d423a869d..a83e87c95 100644 --- a/dialect/sql/schema/atlas.go +++ b/dialect/sql/schema/atlas.go @@ -174,50 +174,18 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}), ) } + var ( + err error + plan *migrate.Plan + ) switch a.mode { case ModeInspect: - // Do nothing here, simply inspect later on. + plan, err = a.planInspect(ctx, a.sqlDialect, name, tables) case ModeReplay: - // We consider a database clean if there are no tables in the connected schema. - s, err := a.atDriver.InspectSchema(ctx, "", nil) - if err != nil { - return err - } - if len(s.Tables) > 0 { - return migrate.NotCleanError{Reason: fmt.Sprintf("found table %q", s.Tables[0].Name)} - } - // Clean up once done. - defer func() { - // We clean a database by dropping all tables inside the connected schema. - s, err = a.atDriver.InspectSchema(ctx, "", nil) - if err != nil { - return - } - tbls := make([]schema.Change, len(s.Tables)) - for i, t := range s.Tables { - tbls[i] = &schema.DropTable{T: t} - } - if err2 := a.atDriver.ApplyChanges(ctx, tbls); err2 != nil { - if err != nil { - err = fmt.Errorf("%v: %w", err2, err) - return - } - err = err2 - return - } - }() - // Replay the migration directory on the database. - ex, err := migrate.NewExecutor(a.atDriver, a.dir, &migrate.NopRevisionReadWriter{}) - if err != nil { - return err - } - if err := ex.ExecuteN(ctx, 0); err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) { - return err - } + plan, err = a.planReplay(ctx, name, tables) default: return fmt.Errorf("unknown migration mode: %q", a.mode) } - plan, err := a.plan(ctx, a.sqlDialect, name, tables) if err != nil { return err } @@ -228,6 +196,23 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er return migrate.NewPlanner(nil, a.dir, opts...).WritePlan(plan) } +func (a *Atlas) cleanSchema(ctx context.Context, name string, err0 error) (err error) { + defer func() { + if err0 != nil { + err = fmt.Errorf("%v: %w", err0, err) + } + }() + s, err := a.atDriver.InspectSchema(ctx, name, nil) + if err != nil { + return err + } + drop := make([]schema.Change, len(s.Tables)) + for i, t := range s.Tables { + drop[i] = &schema.DropTable{T: t} + } + return a.atDriver.ApplyChanges(ctx, drop) +} + // VerifyTableRange ensures, that the defined autoincrement starting value is set for each table as defined by the // TypTable. This is necessary for MySQL versions < 8.0. In those versions the defined starting value for AUTOINCREMENT // columns was stored in memory, and when a server restarts happens and there are no rows yet in a table, the defined @@ -673,7 +658,7 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) { } defer func() { a.atDriver = nil }() if err := func() error { - plan, err := a.plan(ctx, tx, "changes", tables) + plan, err := a.planInspect(ctx, tx, "changes", tables) if err != nil { return err } @@ -703,10 +688,9 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) { return tx.Commit() } -// plan creates the current state by inspecting the connected database, computing the current state of the Ent schema +// planInspect creates the current state by inspecting the connected database, computing the current state of the Ent schema // and proceeds to diff the changes to create a migration plan. -// before diffing. -func (a *Atlas) plan(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) { +func (a *Atlas) planInspect(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) { current, err := a.atDriver.InspectSchema(ctx, "", &schema.InspectOptions{ Tables: func() (t []string) { for i := range tables { @@ -726,26 +710,89 @@ func (a *Atlas) plan(ctx context.Context, conn dialect.ExecQuerier, name string, } a.types = types } - desired, err := a.StateReader(tables...).ReadState(ctx) + realm, err := a.StateReader(tables...).ReadState(ctx) if err != nil { return nil, err } - // Diff changes. - changes, err := (&diffDriver{a.atDriver, a.diffHooks}).SchemaDiff(current, &schema.Schema{ - Name: current.Name, - Attrs: current.Attrs, - Tables: desired.Schemas[0].Tables, - }) + desired := realm.Schemas[0] + desired.Name, desired.Attrs = current.Name, current.Attrs + return a.diff(ctx, name, current, desired, a.types[len(types):]) +} + +func (a *Atlas) planReplay(ctx context.Context, name string, tables []*Table) (*migrate.Plan, error) { + // We consider a database clean if there are no tables in the connected schema. + s, err := a.atDriver.InspectSchema(ctx, "", nil) if err != nil { return nil, err } - // Plan changes. - plan, err := a.atDriver.PlanChanges(ctx, name, changes) + if len(s.Tables) > 0 { + return nil, migrate.NotCleanError{Reason: fmt.Sprintf("found table %q", s.Tables[0].Name)} + } + // Replay the migration directory on the database. + ex, err := migrate.NewExecutor(a.atDriver, a.dir, &migrate.NopRevisionReadWriter{}) + if err != nil { + return nil, err + } + if err := ex.ExecuteN(ctx, 0); err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) { + return nil, a.cleanSchema(ctx, "", err) + } + // Inspect the current schema (migration directory). + current, err := a.atDriver.InspectSchema(ctx, "", nil) + if err != nil { + return nil, a.cleanSchema(ctx, "", err) + } + var types []string + if a.universalID { + if types, err = a.loadTypes(ctx, a.sqlDialect); err != nil && !errors.Is(err, errTypeTableNotFound) { + return nil, a.cleanSchema(ctx, "", err) + } + a.types = types + } + if err := a.cleanSchema(ctx, "", nil); err != nil { + return nil, fmt.Errorf("clean schemas after migration replaying: %w", err) + } + desired, err := a.tables(tables) + if err != nil { + return nil, err + } + // In case of replay mode, normalize the desired state (i.e. ent/schema). + if nr, ok := a.atDriver.(schema.Normalizer); ok { + ns, err := nr.NormalizeSchema(ctx, schema.New(current.Name).AddTables(desired...)) + if err != nil { + return nil, err + } + if len(ns.Tables) != len(desired) { + return nil, fmt.Errorf("unexpected number of tables after normalization: %d != %d", len(ns.Tables), len(desired)) + } + // Ensure all tables exist in the normalized format and the order is preserved. + for i, t := range desired { + d, ok := ns.Table(t.Name) + if !ok { + return nil, fmt.Errorf("table %q not found after normalization", t.Name) + } + desired[i] = d + } + } + return a.diff(ctx, name, current, + &schema.Schema{Name: current.Name, Attrs: current.Attrs, Tables: desired}, a.types[len(types):], + // For BC reason, we omit the schema qualifier from the migration scripts, + // but that is currently limiting versioned migration to a single schema. + func(opts *migrate.PlanOptions) { + var noQualifier string + opts.SchemaQualifier = &noQualifier + }, + ) +} + +func (a *Atlas) diff(ctx context.Context, name string, current, desired *schema.Schema, newTypes []string, opts ...migrate.PlanOption) (*migrate.Plan, error) { + changes, err := (&diffDriver{a.atDriver, a.diffHooks}).SchemaDiff(current, desired) + if err != nil { + return nil, err + } + plan, err := a.atDriver.PlanChanges(ctx, name, changes, opts...) if err != nil { return nil, err } - // Insert new types. - newTypes := a.types[len(types):] if len(newTypes) > 0 { plan.Changes = append(plan.Changes, &migrate.Change{ Cmd: a.sqlDialect.atTypeRangeSQL(newTypes...), @@ -893,7 +940,7 @@ func (a *Atlas) atDefault(c1 *Column, c2 *schema.Column) error { } c2.SetDefault(&schema.RawExpr{X: string(x)}) case map[string]Expr: - d, ok := x[a.driver.Dialect()] + d, ok := x[a.sqlDialect.Dialect()] if !ok { return nil } diff --git a/entc/integration/migrate/entv2/generate.go b/entc/integration/migrate/entv2/generate.go index ada459103..8c7950cd6 100644 --- a/entc/integration/migrate/entv2/generate.go +++ b/entc/integration/migrate/entv2/generate.go @@ -4,4 +4,4 @@ package entv2 -//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by ent, DO NOT EDIT." ./schema +//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/versioned-migration --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by ent, DO NOT EDIT." ./schema diff --git a/entc/integration/migrate/entv2/migrate/migrate.go b/entc/integration/migrate/entv2/migrate/migrate.go index 0a632cc3d..9b2b43c5e 100644 --- a/entc/integration/migrate/entv2/migrate/migrate.go +++ b/entc/integration/migrate/entv2/migrate/migrate.go @@ -58,6 +58,38 @@ func Create(ctx context.Context, s *Schema, tables []*schema.Table, opts ...sche return migrate.Create(ctx, tables...) } +// Diff compares the state read from a database connection or migration directory with +// the state defined by the Ent schema. Changes will be written to new migration files. +func Diff(ctx context.Context, url string, opts ...schema.MigrateOption) error { + return NamedDiff(ctx, url, "changes", opts...) +} + +// NamedDiff compares the state read from a database connection or migration directory with +// the state defined by the Ent schema. Changes will be written to new named migration files. +func NamedDiff(ctx context.Context, url, name string, opts ...schema.MigrateOption) error { + return schema.Diff(ctx, url, name, Tables, opts...) +} + +// Diff creates a migration file containing the statements to resolve the diff +// between the Ent schema and the connected database. +func (s *Schema) Diff(ctx context.Context, opts ...schema.MigrateOption) error { + migrate, err := schema.NewMigrate(s.drv, opts...) + if err != nil { + return fmt.Errorf("ent/migrate: %w", err) + } + return migrate.Diff(ctx, Tables...) +} + +// NamedDiff creates a named migration file containing the statements to resolve the diff +// between the Ent schema and the connected database. +func (s *Schema) NamedDiff(ctx context.Context, name string, opts ...schema.MigrateOption) error { + migrate, err := schema.NewMigrate(s.drv, opts...) + if err != nil { + return fmt.Errorf("ent/migrate: %w", err) + } + return migrate.NamedDiff(ctx, name, Tables...) +} + // WriteTo writes the schema changes to w instead of running them against the database. // // if err := client.Schema.WriteTo(context.Background(), os.Stdout); err != nil { diff --git a/entc/integration/migrate/migrate_test.go b/entc/integration/migrate/migrate_test.go index 7daa867f1..7fbb155bd 100644 --- a/entc/integration/migrate/migrate_test.go +++ b/entc/integration/migrate/migrate_test.go @@ -82,10 +82,9 @@ func TestMySQL(t *testing.T) { vdrv, err := sql.Open("mysql", fmt.Sprintf("root:pass@tcp(localhost:%d)/versioned_migrate?parseTime=True", port)) require.NoError(t, err, "connecting to versioned migrate database") defer vdrv.Close() - Versioned(t, vdrv, - fmt.Sprintf("mysql://root:pass@localhost:%d/versioned_migrate_dev?parseTime=True", port), - versioned.NewClient(versioned.Driver(vdrv)), - ) + devURL := fmt.Sprintf("mysql://root:pass@localhost:%d/versioned_migrate_dev?parseTime=True", port) + Versioned(t, vdrv, devURL, versioned.NewClient(versioned.Driver(vdrv))) + ConsistentVersioned(t, devURL) }) } } @@ -98,17 +97,17 @@ func TestPostgres(t *testing.T) { require.NoError(t, err) defer root.Close() ctx := context.Background() - err = root.Exec(ctx, "DROP DATABASE IF EXISTS migrate", []any{}, new(sql.Result)) + err = root.Exec(ctx, "DROP DATABASE IF EXISTS migrate", []any{}, nil) require.NoError(t, err) - err = root.Exec(ctx, "CREATE DATABASE migrate", []any{}, new(sql.Result)) + err = root.Exec(ctx, "CREATE DATABASE migrate", []any{}, nil) require.NoError(t, err, "creating database") - defer root.Exec(ctx, "DROP DATABASE migrate", []any{}, new(sql.Result)) + defer root.Exec(ctx, "DROP DATABASE migrate", []any{}, nil) drv, err := sql.Open(dialect.Postgres, dsn+" dbname=migrate") require.NoError(t, err, "connecting to migrate database") defer drv.Close() - err = drv.Exec(ctx, "CREATE TYPE customtype as range (subtype = time)", []any{}, new(sql.Result)) + err = drv.Exec(ctx, "CREATE TYPE customtype as range (subtype = time)", []any{}, nil) require.NoError(t, err, "creating custom type") clientv1 := entv1.NewClient(entv1.Driver(drv)) @@ -147,16 +146,21 @@ func TestPostgres(t *testing.T) { vdrv, err := sql.Open(dialect.Postgres, dsn+" dbname=versioned_migrate") require.NoError(t, err, "connecting to versioned migrate database") defer vdrv.Close() - require.NoError(t, root.Exec(ctx, "DROP DATABASE IF EXISTS versioned_migrate", []any{}, new(sql.Result))) + require.NoError(t, root.Exec(ctx, "DROP DATABASE IF EXISTS versioned_migrate", []any{}, nil)) require.NoError(t, root.Exec(ctx, "CREATE DATABASE versioned_migrate", []any{}, new(sql.Result))) defer root.Exec(ctx, "DROP DATABASE versioned_migrate", []any{}, new(sql.Result)) - require.NoError(t, root.Exec(ctx, "DROP DATABASE IF EXISTS versioned_migrate_dev", []any{}, new(sql.Result))) + require.NoError(t, root.Exec(ctx, "DROP DATABASE IF EXISTS versioned_migrate_dev", []any{}, nil)) require.NoError(t, root.Exec(ctx, "CREATE DATABASE versioned_migrate_dev", []any{}, new(sql.Result))) defer root.Exec(ctx, "DROP DATABASE versioned_migrate_dev", []any{}, new(sql.Result)) - Versioned(t, vdrv, - fmt.Sprintf("postgres://postgres:pass@localhost:%d/versioned_migrate_dev?sslmode=disable&search_path=public", port), - versioned.NewClient(versioned.Driver(vdrv)), - ) + devURL := fmt.Sprintf("postgres://postgres:pass@localhost:%d/versioned_migrate_dev?sslmode=disable&search_path=public", port) + Versioned(t, vdrv, devURL, versioned.NewClient(versioned.Driver(vdrv))) + // Create the necessary custom types for the versioned schema. + dev, err := sql.Open(dialect.Postgres, dsn+" dbname=versioned_migrate_dev") + require.NoError(t, err, "connecting to versioned_migrate_dev database") + defer dev.Close() + err = dev.Exec(ctx, "CREATE TYPE customtype as range (subtype = time)", []any{}, nil) + require.NoError(t, err, "creating custom type on dev database") + ConsistentVersioned(t, devURL) }) } } @@ -165,7 +169,6 @@ func TestSQLite(t *testing.T) { drv, err := sql.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") require.NoError(t, err) defer drv.Close() - ctx := context.Background() client := entv2.NewClient(entv2.Driver(drv)) require.NoError( @@ -310,6 +313,32 @@ func TestStorageKey(t *testing.T) { require.Equal(t, "user_friend_id2", migratev2.FriendsTable.ForeignKeys[1].Symbol) } +func ConsistentVersioned(t *testing.T, devURL string) { + p := t.TempDir() + ctx := context.Background() + dir, err := migrate.NewLocalDir(p) + require.NoError(t, err) + opts := []schema.MigrateOption{ + schema.WithDir(dir), // provide migration directory + schema.WithMigrationMode(schema.ModeReplay), // provide migration mode + schema.WithDialect(strings.Split(devURL, "://")[0]), // Ent dialect to use + schema.WithFormatter(migrate.DefaultFormatter), // Default Atlas formatter + } + // Run diff should generate a single SQL file containing the diff. + err = migratev2.NamedDiff(ctx, devURL, "first", opts...) + require.NoError(t, err) + files, err := dir.Files() + require.NoError(t, err) + require.Len(t, files, 1) + require.NotEmpty(t, files[0].Bytes()) + // Re-run diff should not generate any new files. + err = migratev2.NamedDiff(ctx, devURL, "second", opts...) + require.NoError(t, err) + files, err = dir.Files() + require.NoError(t, err) + require.Len(t, files, 1) +} + func Versioned(t *testing.T, drv sql.ExecQuerier, devURL string, client *versioned.Client) { ctx := context.Background() diff --git a/examples/migration/ent/migrate/schema.go b/examples/migration/ent/migrate/schema.go index 8b772af30..db54d23dd 100644 --- a/examples/migration/ent/migrate/schema.go +++ b/examples/migration/ent/migrate/schema.go @@ -33,7 +33,7 @@ var ( func init() { UsersTable.Annotation = &entsql.Annotation{ - Check: "(`age` > 0)", + Check: "age > 0", } UsersTable.Annotation.Checks = map[string]string{ "name_not_empty": "name <> ''", diff --git a/examples/migration/ent/schema/user.go b/examples/migration/ent/schema/user.go index eb9daae3b..123b4ad12 100644 --- a/examples/migration/ent/schema/user.go +++ b/examples/migration/ent/schema/user.go @@ -23,9 +23,10 @@ func (User) Fields() []ent.Field { // Annotations of the User. func (User) Annotations() []schema.Annotation { return []schema.Annotation{ - // Unnamed check constraints should be identical to their definition in the - // database (i.e. normalized). See: https://atlasgo.io/concepts/dev-database. - entsql.Check("(`age` > 0)"), + // In case schema.ModeInspect is used without a dev-database, unnamed check constraints + // should be normalized (i.e. identical to their definition in the database). In this + // case, it is entsql.Check("(`age` > 0)"). See: https://atlasgo.io/concepts/dev-database. + entsql.Check("age > 0"), // Named check constraints are compared by their name. // Thus, the definition does not need to be normalized. diff --git a/go.mod b/go.mod index 25ebcd24f..d95dad2bd 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module entgo.io/ent go 1.19 require ( - ariga.io/atlas v0.8.2-0.20221113160047-09851f798b12 + ariga.io/atlas v0.8.3-0.20221116151337-9e4e9cbf3baf github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/go-openapi/inflect v0.19.0 github.com/go-sql-driver/mysql v1.6.0 diff --git a/go.sum b/go.sum index c56ec3ab5..f30a01fd9 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ ariga.io/atlas v0.8.2-0.20221108073928-ba5d4f596240 h1:Skxqk163AiuhtDEmAfF2/dvaD ariga.io/atlas v0.8.2-0.20221108073928-ba5d4f596240/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE= ariga.io/atlas v0.8.2-0.20221113160047-09851f798b12 h1:ocnzGNr1PmOM/UYxlHmhYj/tb7/z4rkN/0v0/qc8aVU= ariga.io/atlas v0.8.2-0.20221113160047-09851f798b12/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE= +ariga.io/atlas v0.8.3-0.20221116151337-9e4e9cbf3baf h1:tq28xcfFAtxk75ej1IwK+yIbRYC0fqNZkHljcVbYrOs= +ariga.io/atlas v0.8.3-0.20221116151337-9e4e9cbf3baf/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=