mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/schema: normalize ent/schema (desired state) on replay mode (#3100)
This commit is contained in:
@@ -174,50 +174,18 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
|
||||
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}),
|
||||
)
|
||||
}
|
||||
var (
|
||||
err error
|
||||
plan *migrate.Plan
|
||||
)
|
||||
switch a.mode {
|
||||
case ModeInspect:
|
||||
// Do nothing here, simply inspect later on.
|
||||
plan, err = a.planInspect(ctx, a.sqlDialect, name, tables)
|
||||
case ModeReplay:
|
||||
// We consider a database clean if there are no tables in the connected schema.
|
||||
s, err := a.atDriver.InspectSchema(ctx, "", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(s.Tables) > 0 {
|
||||
return migrate.NotCleanError{Reason: fmt.Sprintf("found table %q", s.Tables[0].Name)}
|
||||
}
|
||||
// Clean up once done.
|
||||
defer func() {
|
||||
// We clean a database by dropping all tables inside the connected schema.
|
||||
s, err = a.atDriver.InspectSchema(ctx, "", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tbls := make([]schema.Change, len(s.Tables))
|
||||
for i, t := range s.Tables {
|
||||
tbls[i] = &schema.DropTable{T: t}
|
||||
}
|
||||
if err2 := a.atDriver.ApplyChanges(ctx, tbls); err2 != nil {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %w", err2, err)
|
||||
return
|
||||
}
|
||||
err = err2
|
||||
return
|
||||
}
|
||||
}()
|
||||
// Replay the migration directory on the database.
|
||||
ex, err := migrate.NewExecutor(a.atDriver, a.dir, &migrate.NopRevisionReadWriter{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ex.ExecuteN(ctx, 0); err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) {
|
||||
return err
|
||||
}
|
||||
plan, err = a.planReplay(ctx, name, tables)
|
||||
default:
|
||||
return fmt.Errorf("unknown migration mode: %q", a.mode)
|
||||
}
|
||||
plan, err := a.plan(ctx, a.sqlDialect, name, tables)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -228,6 +196,23 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
|
||||
return migrate.NewPlanner(nil, a.dir, opts...).WritePlan(plan)
|
||||
}
|
||||
|
||||
func (a *Atlas) cleanSchema(ctx context.Context, name string, err0 error) (err error) {
|
||||
defer func() {
|
||||
if err0 != nil {
|
||||
err = fmt.Errorf("%v: %w", err0, err)
|
||||
}
|
||||
}()
|
||||
s, err := a.atDriver.InspectSchema(ctx, name, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
drop := make([]schema.Change, len(s.Tables))
|
||||
for i, t := range s.Tables {
|
||||
drop[i] = &schema.DropTable{T: t}
|
||||
}
|
||||
return a.atDriver.ApplyChanges(ctx, drop)
|
||||
}
|
||||
|
||||
// VerifyTableRange ensures, that the defined autoincrement starting value is set for each table as defined by the
|
||||
// TypTable. This is necessary for MySQL versions < 8.0. In those versions the defined starting value for AUTOINCREMENT
|
||||
// columns was stored in memory, and when a server restarts happens and there are no rows yet in a table, the defined
|
||||
@@ -673,7 +658,7 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
|
||||
}
|
||||
defer func() { a.atDriver = nil }()
|
||||
if err := func() error {
|
||||
plan, err := a.plan(ctx, tx, "changes", tables)
|
||||
plan, err := a.planInspect(ctx, tx, "changes", tables)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -703,10 +688,9 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// plan creates the current state by inspecting the connected database, computing the current state of the Ent schema
|
||||
// planInspect creates the current state by inspecting the connected database, computing the current state of the Ent schema
|
||||
// and proceeds to diff the changes to create a migration plan.
|
||||
// before diffing.
|
||||
func (a *Atlas) plan(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) {
|
||||
func (a *Atlas) planInspect(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) {
|
||||
current, err := a.atDriver.InspectSchema(ctx, "", &schema.InspectOptions{
|
||||
Tables: func() (t []string) {
|
||||
for i := range tables {
|
||||
@@ -726,26 +710,89 @@ func (a *Atlas) plan(ctx context.Context, conn dialect.ExecQuerier, name string,
|
||||
}
|
||||
a.types = types
|
||||
}
|
||||
desired, err := a.StateReader(tables...).ReadState(ctx)
|
||||
realm, err := a.StateReader(tables...).ReadState(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Diff changes.
|
||||
changes, err := (&diffDriver{a.atDriver, a.diffHooks}).SchemaDiff(current, &schema.Schema{
|
||||
Name: current.Name,
|
||||
Attrs: current.Attrs,
|
||||
Tables: desired.Schemas[0].Tables,
|
||||
})
|
||||
desired := realm.Schemas[0]
|
||||
desired.Name, desired.Attrs = current.Name, current.Attrs
|
||||
return a.diff(ctx, name, current, desired, a.types[len(types):])
|
||||
}
|
||||
|
||||
func (a *Atlas) planReplay(ctx context.Context, name string, tables []*Table) (*migrate.Plan, error) {
|
||||
// We consider a database clean if there are no tables in the connected schema.
|
||||
s, err := a.atDriver.InspectSchema(ctx, "", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Plan changes.
|
||||
plan, err := a.atDriver.PlanChanges(ctx, name, changes)
|
||||
if len(s.Tables) > 0 {
|
||||
return nil, migrate.NotCleanError{Reason: fmt.Sprintf("found table %q", s.Tables[0].Name)}
|
||||
}
|
||||
// Replay the migration directory on the database.
|
||||
ex, err := migrate.NewExecutor(a.atDriver, a.dir, &migrate.NopRevisionReadWriter{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ex.ExecuteN(ctx, 0); err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) {
|
||||
return nil, a.cleanSchema(ctx, "", err)
|
||||
}
|
||||
// Inspect the current schema (migration directory).
|
||||
current, err := a.atDriver.InspectSchema(ctx, "", nil)
|
||||
if err != nil {
|
||||
return nil, a.cleanSchema(ctx, "", err)
|
||||
}
|
||||
var types []string
|
||||
if a.universalID {
|
||||
if types, err = a.loadTypes(ctx, a.sqlDialect); err != nil && !errors.Is(err, errTypeTableNotFound) {
|
||||
return nil, a.cleanSchema(ctx, "", err)
|
||||
}
|
||||
a.types = types
|
||||
}
|
||||
if err := a.cleanSchema(ctx, "", nil); err != nil {
|
||||
return nil, fmt.Errorf("clean schemas after migration replaying: %w", err)
|
||||
}
|
||||
desired, err := a.tables(tables)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// In case of replay mode, normalize the desired state (i.e. ent/schema).
|
||||
if nr, ok := a.atDriver.(schema.Normalizer); ok {
|
||||
ns, err := nr.NormalizeSchema(ctx, schema.New(current.Name).AddTables(desired...))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(ns.Tables) != len(desired) {
|
||||
return nil, fmt.Errorf("unexpected number of tables after normalization: %d != %d", len(ns.Tables), len(desired))
|
||||
}
|
||||
// Ensure all tables exist in the normalized format and the order is preserved.
|
||||
for i, t := range desired {
|
||||
d, ok := ns.Table(t.Name)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("table %q not found after normalization", t.Name)
|
||||
}
|
||||
desired[i] = d
|
||||
}
|
||||
}
|
||||
return a.diff(ctx, name, current,
|
||||
&schema.Schema{Name: current.Name, Attrs: current.Attrs, Tables: desired}, a.types[len(types):],
|
||||
// For BC reason, we omit the schema qualifier from the migration scripts,
|
||||
// but that is currently limiting versioned migration to a single schema.
|
||||
func(opts *migrate.PlanOptions) {
|
||||
var noQualifier string
|
||||
opts.SchemaQualifier = &noQualifier
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (a *Atlas) diff(ctx context.Context, name string, current, desired *schema.Schema, newTypes []string, opts ...migrate.PlanOption) (*migrate.Plan, error) {
|
||||
changes, err := (&diffDriver{a.atDriver, a.diffHooks}).SchemaDiff(current, desired)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plan, err := a.atDriver.PlanChanges(ctx, name, changes, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Insert new types.
|
||||
newTypes := a.types[len(types):]
|
||||
if len(newTypes) > 0 {
|
||||
plan.Changes = append(plan.Changes, &migrate.Change{
|
||||
Cmd: a.sqlDialect.atTypeRangeSQL(newTypes...),
|
||||
@@ -893,7 +940,7 @@ func (a *Atlas) atDefault(c1 *Column, c2 *schema.Column) error {
|
||||
}
|
||||
c2.SetDefault(&schema.RawExpr{X: string(x)})
|
||||
case map[string]Expr:
|
||||
d, ok := x[a.driver.Dialect()]
|
||||
d, ok := x[a.sqlDialect.Dialect()]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user