dialect/sql/schema: normalize ent/schema (desired state) on replay mode (#3100)

This commit is contained in:
Ariel Mashraki
2022-11-16 20:01:36 +02:00
committed by GitHub
parent cd60f84853
commit 5954fa8b15
8 changed files with 186 additions and 75 deletions

View File

@@ -174,50 +174,18 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}),
)
}
var (
err error
plan *migrate.Plan
)
switch a.mode {
case ModeInspect:
// Do nothing here, simply inspect later on.
plan, err = a.planInspect(ctx, a.sqlDialect, name, tables)
case ModeReplay:
// We consider a database clean if there are no tables in the connected schema.
s, err := a.atDriver.InspectSchema(ctx, "", nil)
if err != nil {
return err
}
if len(s.Tables) > 0 {
return migrate.NotCleanError{Reason: fmt.Sprintf("found table %q", s.Tables[0].Name)}
}
// Clean up once done.
defer func() {
// We clean a database by dropping all tables inside the connected schema.
s, err = a.atDriver.InspectSchema(ctx, "", nil)
if err != nil {
return
}
tbls := make([]schema.Change, len(s.Tables))
for i, t := range s.Tables {
tbls[i] = &schema.DropTable{T: t}
}
if err2 := a.atDriver.ApplyChanges(ctx, tbls); err2 != nil {
if err != nil {
err = fmt.Errorf("%v: %w", err2, err)
return
}
err = err2
return
}
}()
// Replay the migration directory on the database.
ex, err := migrate.NewExecutor(a.atDriver, a.dir, &migrate.NopRevisionReadWriter{})
if err != nil {
return err
}
if err := ex.ExecuteN(ctx, 0); err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) {
return err
}
plan, err = a.planReplay(ctx, name, tables)
default:
return fmt.Errorf("unknown migration mode: %q", a.mode)
}
plan, err := a.plan(ctx, a.sqlDialect, name, tables)
if err != nil {
return err
}
@@ -228,6 +196,23 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
return migrate.NewPlanner(nil, a.dir, opts...).WritePlan(plan)
}
func (a *Atlas) cleanSchema(ctx context.Context, name string, err0 error) (err error) {
defer func() {
if err0 != nil {
err = fmt.Errorf("%v: %w", err0, err)
}
}()
s, err := a.atDriver.InspectSchema(ctx, name, nil)
if err != nil {
return err
}
drop := make([]schema.Change, len(s.Tables))
for i, t := range s.Tables {
drop[i] = &schema.DropTable{T: t}
}
return a.atDriver.ApplyChanges(ctx, drop)
}
// VerifyTableRange ensures, that the defined autoincrement starting value is set for each table as defined by the
// TypTable. This is necessary for MySQL versions < 8.0. In those versions the defined starting value for AUTOINCREMENT
// columns was stored in memory, and when a server restarts happens and there are no rows yet in a table, the defined
@@ -673,7 +658,7 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
}
defer func() { a.atDriver = nil }()
if err := func() error {
plan, err := a.plan(ctx, tx, "changes", tables)
plan, err := a.planInspect(ctx, tx, "changes", tables)
if err != nil {
return err
}
@@ -703,10 +688,9 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
return tx.Commit()
}
// plan creates the current state by inspecting the connected database, computing the current state of the Ent schema
// planInspect creates the current state by inspecting the connected database, computing the current state of the Ent schema
// and proceeds to diff the changes to create a migration plan.
// before diffing.
func (a *Atlas) plan(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) {
func (a *Atlas) planInspect(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) {
current, err := a.atDriver.InspectSchema(ctx, "", &schema.InspectOptions{
Tables: func() (t []string) {
for i := range tables {
@@ -726,26 +710,89 @@ func (a *Atlas) plan(ctx context.Context, conn dialect.ExecQuerier, name string,
}
a.types = types
}
desired, err := a.StateReader(tables...).ReadState(ctx)
realm, err := a.StateReader(tables...).ReadState(ctx)
if err != nil {
return nil, err
}
// Diff changes.
changes, err := (&diffDriver{a.atDriver, a.diffHooks}).SchemaDiff(current, &schema.Schema{
Name: current.Name,
Attrs: current.Attrs,
Tables: desired.Schemas[0].Tables,
})
desired := realm.Schemas[0]
desired.Name, desired.Attrs = current.Name, current.Attrs
return a.diff(ctx, name, current, desired, a.types[len(types):])
}
func (a *Atlas) planReplay(ctx context.Context, name string, tables []*Table) (*migrate.Plan, error) {
// We consider a database clean if there are no tables in the connected schema.
s, err := a.atDriver.InspectSchema(ctx, "", nil)
if err != nil {
return nil, err
}
// Plan changes.
plan, err := a.atDriver.PlanChanges(ctx, name, changes)
if len(s.Tables) > 0 {
return nil, migrate.NotCleanError{Reason: fmt.Sprintf("found table %q", s.Tables[0].Name)}
}
// Replay the migration directory on the database.
ex, err := migrate.NewExecutor(a.atDriver, a.dir, &migrate.NopRevisionReadWriter{})
if err != nil {
return nil, err
}
if err := ex.ExecuteN(ctx, 0); err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) {
return nil, a.cleanSchema(ctx, "", err)
}
// Inspect the current schema (migration directory).
current, err := a.atDriver.InspectSchema(ctx, "", nil)
if err != nil {
return nil, a.cleanSchema(ctx, "", err)
}
var types []string
if a.universalID {
if types, err = a.loadTypes(ctx, a.sqlDialect); err != nil && !errors.Is(err, errTypeTableNotFound) {
return nil, a.cleanSchema(ctx, "", err)
}
a.types = types
}
if err := a.cleanSchema(ctx, "", nil); err != nil {
return nil, fmt.Errorf("clean schemas after migration replaying: %w", err)
}
desired, err := a.tables(tables)
if err != nil {
return nil, err
}
// In case of replay mode, normalize the desired state (i.e. ent/schema).
if nr, ok := a.atDriver.(schema.Normalizer); ok {
ns, err := nr.NormalizeSchema(ctx, schema.New(current.Name).AddTables(desired...))
if err != nil {
return nil, err
}
if len(ns.Tables) != len(desired) {
return nil, fmt.Errorf("unexpected number of tables after normalization: %d != %d", len(ns.Tables), len(desired))
}
// Ensure all tables exist in the normalized format and the order is preserved.
for i, t := range desired {
d, ok := ns.Table(t.Name)
if !ok {
return nil, fmt.Errorf("table %q not found after normalization", t.Name)
}
desired[i] = d
}
}
return a.diff(ctx, name, current,
&schema.Schema{Name: current.Name, Attrs: current.Attrs, Tables: desired}, a.types[len(types):],
// For BC reason, we omit the schema qualifier from the migration scripts,
// but that is currently limiting versioned migration to a single schema.
func(opts *migrate.PlanOptions) {
var noQualifier string
opts.SchemaQualifier = &noQualifier
},
)
}
func (a *Atlas) diff(ctx context.Context, name string, current, desired *schema.Schema, newTypes []string, opts ...migrate.PlanOption) (*migrate.Plan, error) {
changes, err := (&diffDriver{a.atDriver, a.diffHooks}).SchemaDiff(current, desired)
if err != nil {
return nil, err
}
plan, err := a.atDriver.PlanChanges(ctx, name, changes, opts...)
if err != nil {
return nil, err
}
// Insert new types.
newTypes := a.types[len(types):]
if len(newTypes) > 0 {
plan.Changes = append(plan.Changes, &migrate.Change{
Cmd: a.sqlDialect.atTypeRangeSQL(newTypes...),
@@ -893,7 +940,7 @@ func (a *Atlas) atDefault(c1 *Column, c2 *schema.Column) error {
}
c2.SetDefault(&schema.RawExpr{X: string(x)})
case map[string]Expr:
d, ok := x[a.driver.Dialect()]
d, ok := x[a.sqlDialect.Dialect()]
if !ok {
return nil
}