// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "context" "crypto/md5" "database/sql" "errors" "fmt" "net/url" "reflect" "sort" "strings" "ariga.io/atlas/sql/migrate" "ariga.io/atlas/sql/schema" "ariga.io/atlas/sql/sqlclient" "ariga.io/atlas/sql/sqltool" "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" ) // Atlas atlas migration engine. type Atlas struct { atDriver migrate.Driver sqlDialect sqlDialect schema string // schema to use indent string // plan indentation errNoPlan bool // no plan error enabled universalID bool // global unique ids dropColumns bool // drop deleted columns dropIndexes bool // drop deleted indexes withForeignKeys bool // with foreign keys mode Mode hooks []Hook // hooks to apply before creation diffHooks []DiffHook // diff hooks to run when diffing current and desired diffOptions []schema.DiffOption // diff options to pass to the diff engine applyHook []ApplyHook // apply hooks to run when applying the plan skip ChangeKind // what changes to skip and not apply dir migrate.Dir // the migration directory to read from fmt migrate.Formatter // how to format the plan into migration files driver dialect.Driver // driver passed in when not using an atlas URL url *url.URL // url of database connection dialect string // Ent dialect to use when generating migration files types []string // pre-existing pk range allocation for global unique id } // Diff compares the state read from a database connection or migration directory with the state defined by the Ent // schema. Changes will be written to new migration files. func Diff(ctx context.Context, u, name string, tables []*Table, opts ...MigrateOption) (err error) { m, err := NewMigrateURL(u, opts...) if err != nil { return err } return m.NamedDiff(ctx, name, tables...) } // NewMigrate creates a new Atlas form the given dialect.Driver. func NewMigrate(drv dialect.Driver, opts ...MigrateOption) (*Atlas, error) { a := &Atlas{driver: drv, withForeignKeys: true, mode: ModeInspect} for _, opt := range opts { opt(a) } a.dialect = a.driver.Dialect() if err := a.init(); err != nil { return nil, err } return a, nil } // NewMigrateURL create a new Atlas from the given url. func NewMigrateURL(u string, opts ...MigrateOption) (*Atlas, error) { parsed, err := url.Parse(u) if err != nil { return nil, err } a := &Atlas{url: parsed, withForeignKeys: true, mode: ModeInspect} for _, opt := range opts { opt(a) } if a.dialect == "" { a.dialect = parsed.Scheme } if err := a.init(); err != nil { return nil, err } return a, nil } // Create creates all schema resources in the database. It works in an "append-only" // mode, which means, it only creates tables, appends columns to tables or modifies column types. // // Column can be modified by turning into a NULL from NOT NULL, or having a type conversion not // resulting data altering. From example, changing varchar(255) to varchar(120) is invalid, but // changing varchar(120) to varchar(255) is valid. For more info, see the convert function below. func (a *Atlas) Create(ctx context.Context, tables ...*Table) (err error) { a.setupTables(tables) var creator Creator = CreateFunc(a.create) for i := len(a.hooks) - 1; i >= 0; i-- { creator = a.hooks[i](creator) } return creator.Create(ctx, tables...) } // Diff compares the state read from the connected database with the state defined by Ent. // Changes will be written to migration files by the configured Planner. func (a *Atlas) Diff(ctx context.Context, tables ...*Table) error { return a.NamedDiff(ctx, "changes", tables...) } // NamedDiff compares the state read from the connected database with the state defined by Ent. // Changes will be written to migration files by the configured Planner. func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) error { if a.dir == nil { return errors.New("no migration directory given") } opts := []migrate.PlannerOption{migrate.PlanFormat(a.fmt)} // Validate the migration directory before proceeding. if err := migrate.Validate(a.dir); err != nil { return fmt.Errorf("validating migration directory: %w", err) } a.setupTables(tables) // Set up connections. if a.driver != nil { var err error a.sqlDialect, err = a.entDialect(ctx, a.driver) if err != nil { return err } a.atDriver, err = a.sqlDialect.atOpen(a.sqlDialect) if err != nil { return err } } else { c, err := sqlclient.OpenURL(ctx, a.url) if err != nil { return err } defer c.Close() a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB)) if err != nil { return err } a.atDriver = c.Driver } defer func() { a.sqlDialect = nil a.atDriver = nil }() if err := a.sqlDialect.init(ctx); err != nil { return err } if a.universalID { tables = append(tables, NewTypesTable()) } var ( err error plan *migrate.Plan ) switch a.mode { case ModeInspect: plan, err = a.planInspect(ctx, a.sqlDialect, name, tables) case ModeReplay: plan, err = a.planReplay(ctx, name, tables) default: return fmt.Errorf("unknown migration mode: %q", a.mode) } switch { case err != nil: return err case len(plan.Changes) == 0: if a.errNoPlan { return migrate.ErrNoPlan } return nil default: return migrate.NewPlanner(nil, a.dir, opts...).WritePlan(plan) } } func (a *Atlas) cleanSchema(ctx context.Context, name string, err0 error) (err error) { defer func() { if err0 != nil { err = errors.Join(err, err0) } }() s, err := a.atDriver.InspectSchema(ctx, name, nil) if err != nil { return err } drop := make([]schema.Change, len(s.Tables)) for i, t := range s.Tables { drop[i] = &schema.DropTable{T: t, Extra: []schema.Clause{&schema.IfExists{}}} } return a.atDriver.ApplyChanges(ctx, drop) } // VerifyTableRange ensures, that the defined autoincrement starting value is set for each table as defined by the // TypTable. This is necessary for MySQL versions < 8.0. In those versions the defined starting value for AUTOINCREMENT // columns was stored in memory, and when a server restarts happens and there are no rows yet in a table, the defined // starting value is lost, which will result in incorrect behavior when working with global unique ids. Calling this // method on service start ensures the information are correct and are set again, if they aren't. For MySQL versions > 8 // calling this method is only required once after the upgrade. func (a *Atlas) VerifyTableRange(ctx context.Context, tables []*Table) error { if a.driver != nil { var err error a.sqlDialect, err = a.entDialect(ctx, a.driver) if err != nil { return err } } else { c, err := sqlclient.OpenURL(ctx, a.url) if err != nil { return err } defer c.Close() a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB)) if err != nil { return err } } defer func() { a.sqlDialect = nil }() vr, ok := a.sqlDialect.(verifyRanger) if !ok { return nil } types, err := a.loadTypes(ctx, a.sqlDialect) if err != nil { // In most cases this means the table does not exist, which in turn // indicates the user does not use global unique ids. return err } for _, t := range tables { id := indexOf(types, t.Name) if id == -1 { continue } if err := vr.verifyRange(ctx, a.sqlDialect, t, int64(id<<32)); err != nil { return err } } return nil } type ( // Differ is the interface that wraps the Diff method. Differ interface { // Diff returns a list of changes that construct a migration plan. Diff(current, desired *schema.Schema) ([]schema.Change, error) } // The DiffFunc type is an adapter to allow the use of ordinary function as Differ. // If f is a function with the appropriate signature, DiffFunc(f) is a Differ that calls f. DiffFunc func(current, desired *schema.Schema) ([]schema.Change, error) // DiffHook defines the "diff middleware". A function that gets a Differ and returns a Differ. DiffHook func(Differ) Differ ) // Diff calls f(current, desired). func (f DiffFunc) Diff(current, desired *schema.Schema) ([]schema.Change, error) { return f(current, desired) } // WithDiffHook adds a list of DiffHook to the schema migration. // // schema.WithDiffHook(func(next schema.Differ) schema.Differ { // return schema.DiffFunc(func(current, desired *atlas.Schema) ([]atlas.Change, error) { // // Code before standard diff. // changes, err := next.Diff(current, desired) // if err != nil { // return nil, err // } // // After diff, you can filter // // changes or return new ones. // return changes, nil // }) // }) func WithDiffHook(hooks ...DiffHook) MigrateOption { return func(a *Atlas) { a.diffHooks = append(a.diffHooks, hooks...) } } // WithDiffOptions adds a list of options to pass to the diff engine. func WithDiffOptions(opts ...schema.DiffOption) MigrateOption { return func(a *Atlas) { a.diffOptions = append(a.diffOptions, opts...) } } // WithSkipChanges allows skipping/filtering list of changes // returned by the Differ before executing migration planning. // // SkipChanges(schema.DropTable|schema.DropColumn) func WithSkipChanges(skip ChangeKind) MigrateOption { return func(a *Atlas) { a.skip = skip } } // A ChangeKind denotes the kind of schema change. type ChangeKind uint // List of change types. const ( NoChange ChangeKind = 0 AddSchema ChangeKind = 1 << (iota - 1) ModifySchema DropSchema AddTable ModifyTable DropTable AddColumn ModifyColumn DropColumn AddIndex ModifyIndex DropIndex AddForeignKey ModifyForeignKey DropForeignKey AddCheck ModifyCheck DropCheck ) // Is reports whether c is match the given change kind. func (k ChangeKind) Is(c ChangeKind) bool { return k == c || k&c != 0 } // filterChanges is a DiffHook for filtering changes before plan. func filterChanges(skip ChangeKind) DiffHook { return func(next Differ) Differ { return DiffFunc(func(current, desired *schema.Schema) ([]schema.Change, error) { var f func([]schema.Change) []schema.Change f = func(changes []schema.Change) (keep []schema.Change) { var k ChangeKind for _, c := range changes { switch c := c.(type) { case *schema.AddSchema: k = AddSchema case *schema.ModifySchema: k = ModifySchema if !skip.Is(k) { c.Changes = f(c.Changes) } case *schema.DropSchema: k = DropSchema case *schema.AddTable: k = AddTable case *schema.ModifyTable: k = ModifyTable if !skip.Is(k) { c.Changes = f(c.Changes) } case *schema.DropTable: k = DropTable case *schema.AddColumn: k = AddColumn case *schema.ModifyColumn: k = ModifyColumn case *schema.DropColumn: k = DropColumn case *schema.AddIndex: k = AddIndex case *schema.ModifyIndex: k = ModifyIndex case *schema.DropIndex: k = DropIndex case *schema.AddForeignKey: k = AddIndex case *schema.ModifyForeignKey: k = ModifyForeignKey case *schema.DropForeignKey: k = DropForeignKey case *schema.AddCheck: k = AddCheck case *schema.ModifyCheck: k = ModifyCheck case *schema.DropCheck: k = DropCheck } if !skip.Is(k) { keep = append(keep, c) } } return } changes, err := next.Diff(current, desired) if err != nil { return nil, err } return f(changes), nil }) } } func withoutForeignKeys(next Differ) Differ { return DiffFunc(func(current, desired *schema.Schema) ([]schema.Change, error) { changes, err := next.Diff(current, desired) if err != nil { return nil, err } for _, c := range changes { switch c := c.(type) { case *schema.AddTable: c.T.ForeignKeys = nil case *schema.ModifyTable: c.T.ForeignKeys = nil filtered := make([]schema.Change, 0, len(c.Changes)) for _, change := range c.Changes { switch change.(type) { case *schema.AddForeignKey, *schema.DropForeignKey, *schema.ModifyForeignKey: continue default: filtered = append(filtered, change) } } c.Changes = filtered } } return changes, nil }) } type ( // Applier is the interface that wraps the Apply method. Applier interface { // Apply applies the given migrate.Plan on the database. Apply(context.Context, dialect.ExecQuerier, *migrate.Plan) error } // The ApplyFunc type is an adapter to allow the use of ordinary function as Applier. // If f is a function with the appropriate signature, ApplyFunc(f) is an Applier that calls f. ApplyFunc func(context.Context, dialect.ExecQuerier, *migrate.Plan) error // ApplyHook defines the "migration applying middleware". A function that gets an Applier and returns an Applier. ApplyHook func(Applier) Applier ) // Apply calls f(ctx, tables...). func (f ApplyFunc) Apply(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error { return f(ctx, conn, plan) } // WithApplyHook adds a list of ApplyHook to the schema migration. // // schema.WithApplyHook(func(next schema.Applier) schema.Applier { // return schema.ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error { // // Example to hook into the apply process, or implement // // a custom applier. // // // // for _, c := range plan.Changes { // // fmt.Printf("%s: %s", c.Comment, c.Cmd) // // } // // // return next.Apply(ctx, conn, plan) // }) // }) func WithApplyHook(hooks ...ApplyHook) MigrateOption { return func(a *Atlas) { a.applyHook = append(a.applyHook, hooks...) } } // WithDir sets the atlas migration directory to use to store migration files. func WithDir(dir migrate.Dir) MigrateOption { return func(a *Atlas) { a.dir = dir } } // WithFormatter sets atlas formatter to use to write changes to migration files. func WithFormatter(fmt migrate.Formatter) MigrateOption { return func(a *Atlas) { a.fmt = fmt } } // WithDialect configures the Ent dialect to use when migrating for an Atlas supported dialect flavor. // As an example, Ent can work with TiDB in MySQL dialect and Atlas can handle TiDB migrations. func WithDialect(d string) MigrateOption { return func(a *Atlas) { a.dialect = d } } // WithMigrationMode instructs atlas how to compute the current state of the schema. This can be done by either // replaying (ModeReplay) the migration directory on the connected database, or by inspecting (ModeInspect) the // connection. Currently, ModeReplay is opt-in, and ModeInspect is the default. In future versions, ModeReplay will // become the default behavior. This option has no effect when using online migrations. func WithMigrationMode(mode Mode) MigrateOption { return func(a *Atlas) { a.mode = mode } } // Mode to compute the current state. type Mode uint const ( // ModeReplay computes the current state by replaying the migration directory on the connected database. ModeReplay = iota // ModeInspect computes the current state by inspecting the connected database. ModeInspect ) // StateReader returns an atlas migrate.StateReader returning the state as described by the Ent table slice. func (a *Atlas) StateReader(tables ...*Table) migrate.StateReaderFunc { return func(ctx context.Context) (*schema.Realm, error) { if a.sqlDialect == nil { drv, err := a.entDialect(ctx, a.driver) if err != nil { return nil, err } a.sqlDialect = drv } a.setupTables(tables) ts, err := a.tables(tables) if err != nil { return nil, err } vs, err := a.views(tables) if err != nil { return nil, err } return &schema.Realm{Schemas: []*schema.Schema{{Tables: ts, Views: vs}}}, nil } } // atBuilder must be implemented by the different drivers in // order to convert a dialect/sql/schema to atlas/sql/schema. type atBuilder interface { atOpen(dialect.ExecQuerier) (migrate.Driver, error) atTable(*Table, *schema.Table) supportsDefault(*Column) bool atTypeC(*Column, *schema.Column) error atUniqueC(*Table, *Column, *schema.Table, *schema.Column) atIncrementC(*schema.Table, *schema.Column) atIncrementT(*schema.Table, int64) atIndex(*Index, *schema.Table, *schema.Index) error atTypeRangeSQL(t ...string) string } // init initializes the configuration object based on the options passed in. func (a *Atlas) init() error { skip := DropIndex | DropColumn if a.skip != NoChange { skip = a.skip } if a.dropIndexes { skip &= ^DropIndex } if a.dropColumns { skip &= ^DropColumn } if skip != NoChange { a.diffHooks = append(a.diffHooks, filterChanges(skip)) } if !a.withForeignKeys { a.diffHooks = append(a.diffHooks, withoutForeignKeys) } if a.dir != nil && a.fmt == nil { switch a.dir.(type) { case *sqltool.GooseDir: a.fmt = sqltool.GooseFormatter case *sqltool.DBMateDir: a.fmt = sqltool.DBMateFormatter case *sqltool.FlywayDir: a.fmt = sqltool.FlywayFormatter case *sqltool.LiquibaseDir: a.fmt = sqltool.LiquibaseFormatter default: // migrate.LocalDir, sqltool.GolangMigrateDir and custom ones a.fmt = sqltool.GolangMigrateFormatter } } // ModeReplay requires a migration directory. if a.mode == ModeReplay && a.dir == nil { return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires versioned migrations: WithDir()") } return nil } // create is the Atlas engine based online migration. func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) { if a.universalID { tables = append(tables, NewTypesTable()) } if a.driver != nil { a.sqlDialect, err = a.entDialect(ctx, a.driver) if err != nil { return err } } else { c, err := sqlclient.OpenURL(ctx, a.url) if err != nil { return err } defer c.Close() a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB)) if err != nil { return err } } defer func() { a.sqlDialect = nil }() if err := a.sqlDialect.init(ctx); err != nil { return err } a.atDriver, err = a.sqlDialect.atOpen(a.sqlDialect) if err != nil { return err } defer func() { a.atDriver = nil }() plan, err := a.planInspect(ctx, a.sqlDialect, "changes", tables) if err != nil { return fmt.Errorf("sql/schema: %w", err) } if len(plan.Changes) == 0 { return nil } // Open a transaction for backwards compatibility, // even if the migration is not transactional. tx, err := a.sqlDialect.Tx(ctx) if err != nil { return err } a.atDriver, err = a.sqlDialect.atOpen(tx) if err != nil { return err } // Apply plan (changes). var applier Applier = ApplyFunc(func(ctx context.Context, tx dialect.ExecQuerier, plan *migrate.Plan) error { for _, c := range plan.Changes { if err := tx.Exec(ctx, c.Cmd, c.Args, nil); err != nil { if c.Comment != "" { err = fmt.Errorf("%s: %w", c.Comment, err) } return err } } return nil }) for i := len(a.applyHook) - 1; i >= 0; i-- { applier = a.applyHook[i](applier) } if err = applier.Apply(ctx, tx, plan); err != nil { return errors.Join(fmt.Errorf("sql/schema: %w", err), tx.Rollback()) } return tx.Commit() } // planInspect creates the current state by inspecting the connected database, computing the current state of the Ent schema // and proceeds to diff the changes to create a migration plan. func (a *Atlas) planInspect(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) { current, err := a.atDriver.InspectSchema(ctx, a.schema, &schema.InspectOptions{ Tables: func() (t []string) { for i := range tables { t = append(t, tables[i].Name) } return t }(), // Ent supports table-level inspection only. Mode: schema.InspectSchemas | schema.InspectTables, }) if err != nil { return nil, err } var types []string if a.universalID { types, err = a.loadTypes(ctx, conn) if err != nil && !errors.Is(err, errTypeTableNotFound) { return nil, err } a.types = types } realm, err := a.StateReader(tables...).ReadState(ctx) if err != nil { return nil, err } desired := realm.Schemas[0] desired.Name, desired.Attrs = current.Name, current.Attrs return a.diff(ctx, name, current, desired, a.types[len(types):]) } func (a *Atlas) planReplay(ctx context.Context, name string, tables []*Table) (*migrate.Plan, error) { // We consider a database clean if there are no tables in the connected schema. s, err := a.atDriver.InspectSchema(ctx, a.schema, nil) if err != nil { return nil, err } if len(s.Tables) > 0 { return nil, &migrate.NotCleanError{Reason: fmt.Sprintf("found table %q", s.Tables[0].Name)} } // Replay the migration directory on the database. ex, err := migrate.NewExecutor(a.atDriver, a.dir, &migrate.NopRevisionReadWriter{}) if err != nil { return nil, err } if err := ex.ExecuteN(ctx, 0); err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) { return nil, a.cleanSchema(ctx, a.schema, err) } // Inspect the current schema (migration directory). current, err := a.atDriver.InspectSchema(ctx, a.schema, nil) if err != nil { return nil, a.cleanSchema(ctx, a.schema, err) } var types []string if a.universalID { if types, err = a.loadTypes(ctx, a.sqlDialect); err != nil && !errors.Is(err, errTypeTableNotFound) { return nil, a.cleanSchema(ctx, a.schema, err) } a.types = types } if err := a.cleanSchema(ctx, a.schema, nil); err != nil { return nil, fmt.Errorf("clean schemas after migration replaying: %w", err) } desired, err := a.tables(tables) if err != nil { return nil, err } // In case of replay mode, normalize the desired state (i.e. ent/schema). if nr, ok := a.atDriver.(schema.Normalizer); ok { ns, err := nr.NormalizeSchema(ctx, schema.New(current.Name).AddTables(desired...)) if err != nil { return nil, err } if len(ns.Tables) != len(desired) { return nil, fmt.Errorf("unexpected number of tables after normalization: %d != %d", len(ns.Tables), len(desired)) } // Ensure all tables exist in the normalized format and the order is preserved. for i, t := range desired { d, ok := ns.Table(t.Name) if !ok { return nil, fmt.Errorf("table %q not found after normalization", t.Name) } desired[i] = d } } return a.diff(ctx, name, current, &schema.Schema{Name: current.Name, Attrs: current.Attrs, Tables: desired}, a.types[len(types):], // For BC reason, we omit the schema qualifier from the migration scripts, // but that is currently limiting versioned migration to a single schema. func(opts *migrate.PlanOptions) { var noQualifier string opts.SchemaQualifier = &noQualifier }, ) } func (a *Atlas) diff(ctx context.Context, name string, current, desired *schema.Schema, newTypes []string, opts ...migrate.PlanOption) (*migrate.Plan, error) { changes, err := (&diffDriver{a.atDriver, a.diffHooks}).SchemaDiff(current, desired, a.diffOptions...) if err != nil { return nil, err } filtered := make([]schema.Change, 0, len(changes)) for _, c := range changes { switch c.(type) { // Select only table creation and modification. The reason we may encounter this, even though specific tables // are passed to Inspect, is if the MySQL system variable 'lower_case_table_names' is set to 1. In such a case, // the given tables will be returned from inspection because MySQL compares case-insensitive, but they won't // match when compare them in code. case *schema.AddTable, *schema.ModifyTable: filtered = append(filtered, c) } } if a.indent != "" { opts = append(opts, func(opts *migrate.PlanOptions) { opts.Indent = a.indent }) } plan, err := a.atDriver.PlanChanges(ctx, name, filtered, opts...) if err != nil { return nil, err } if len(newTypes) > 0 { plan.Changes = append(plan.Changes, &migrate.Change{ Cmd: a.sqlDialect.atTypeRangeSQL(newTypes...), Comment: fmt.Sprintf("add pk ranges for %s tables", strings.Join(newTypes, ",")), }) } return plan, nil } var errTypeTableNotFound = errors.New("ent_type table not found") // loadTypes loads the currently saved range allocations from the TypeTable. func (a *Atlas) loadTypes(ctx context.Context, conn dialect.ExecQuerier) ([]string, error) { // Fetch pre-existing type allocations. exists, err := a.sqlDialect.tableExist(ctx, conn, TypeTable) if err != nil { return nil, err } if !exists { return nil, errTypeTableNotFound } rows := &entsql.Rows{} query, args := entsql.Dialect(a.dialect). Select("type").From(entsql.Table(TypeTable)).OrderBy(entsql.Asc("id")).Query() if err := conn.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("query types table: %w", err) } defer rows.Close() var types []string if err := entsql.ScanSlice(rows, &types); err != nil { return nil, err } return types, nil } type db struct{ dialect.ExecQuerier } func (d *db) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { rows := &entsql.Rows{} if err := d.ExecQuerier.Query(ctx, query, args, rows); err != nil { return nil, err } return rows.ColumnScanner.(*sql.Rows), nil } func (d *db) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { var r sql.Result if err := d.ExecQuerier.Exec(ctx, query, args, &r); err != nil { return nil, err } return r, nil } // tables converts an Ent table slice to an atlas table slice func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) { var ( byT = make(map[*Table]*schema.Table) ts = make([]*schema.Table, 0, len(tables)) ) for _, et := range tables { if et.View { continue } at := schema.NewTable(et.Name) if et.Comment != "" { at.SetComment(et.Comment) } a.sqlDialect.atTable(et, at) // universalID is the old implementation of the global unique id, relying on a table in the database. // The new implementation is based on annotations attached to the schema. Only one can be enabled. switch { case a.universalID && et.Annotation != nil && et.Annotation.IncrementStart != nil: return nil, errors.New("universal id and increment start annotation are mutually exclusive") case a.universalID && et.Name != TypeTable && len(et.PrimaryKey) == 1: r, err := a.pkRange(et) if err != nil { return nil, err } a.sqlDialect.atIncrementT(at, r) case et.Annotation != nil && et.Annotation.IncrementStart != nil: a.sqlDialect.atIncrementT(at, int64(*et.Annotation.IncrementStart)) } if err := a.aColumns(et, at); err != nil { return nil, err } if err := a.aIndexes(et, at); err != nil { return nil, err } ts = append(ts, at) byT[et] = at } for _, t1 := range tables { if t1.View { continue } t2 := byT[t1] for _, fk1 := range t1.ForeignKeys { fk2 := schema.NewForeignKey(fk1.Symbol). SetTable(t2). SetOnUpdate(schema.ReferenceOption(fk1.OnUpdate)). SetOnDelete(schema.ReferenceOption(fk1.OnDelete)) for _, c1 := range fk1.Columns { c2, ok := t2.Column(c1.Name) if !ok { return nil, fmt.Errorf("unexpected fk %q column: %q", fk1.Symbol, c1.Name) } fk2.AddColumns(c2) } var refT *schema.Table for _, t2 := range ts { if t2.Name == fk1.RefTable.Name { refT = t2 break } } if refT == nil { return nil, fmt.Errorf("unexpected fk %q ref-table: %q", fk1.Symbol, fk1.RefTable.Name) } fk2.SetRefTable(refT) for _, c1 := range fk1.RefColumns { c2, ok := refT.Column(c1.Name) if !ok { return nil, fmt.Errorf("unexpected fk %q ref-column: %q", fk1.Symbol, c1.Name) } fk2.AddRefColumns(c2) } t2.AddForeignKeys(fk2) } } return ts, nil } // tables converts an Ent table slice to an atlas table slice func (a *Atlas) views(tables []*Table) ([]*schema.View, error) { vs := make([]*schema.View, 0, len(tables)) for _, et := range tables { // Not a view, or the view defined externally. if !et.View || et.Annotation == nil || (et.Annotation.ViewAs == "" && et.Annotation.ViewFor[a.dialect] == "") { continue } def := et.Annotation.ViewFor[a.dialect] if def == "" { def = et.Annotation.ViewAs } av := schema.NewView(et.Name, def) if et.Comment != "" { av.SetComment(et.Comment) } if err := a.aVColumns(et, av); err != nil { return nil, err } vs = append(vs, av) } return vs, nil } func (a *Atlas) aColumns(et *Table, at *schema.Table) error { for _, c1 := range et.Columns { c2 := schema.NewColumn(c1.Name). SetNull(c1.Nullable) if c1.Collation != "" { c2.SetCollation(c1.Collation) } if c1.Comment != "" { c2.SetComment(c1.Comment) } if err := a.sqlDialect.atTypeC(c1, c2); err != nil { return err } if err := a.atDefault(c1, c2); err != nil { return err } if c1.Unique && (len(et.PrimaryKey) != 1 || et.PrimaryKey[0] != c1) { a.sqlDialect.atUniqueC(et, c1, at, c2) } if c1.Increment { a.sqlDialect.atIncrementC(at, c2) } at.AddColumns(c2) } return nil } func (a *Atlas) aVColumns(et *Table, at *schema.View) error { for _, c1 := range et.Columns { c2 := schema.NewColumn(c1.Name). SetNull(c1.Nullable) if c1.Collation != "" { c2.SetCollation(c1.Collation) } if c1.Comment != "" { c2.SetComment(c1.Comment) } if err := a.sqlDialect.atTypeC(c1, c2); err != nil { return err } if err := a.atDefault(c1, c2); err != nil { return err } at.AddColumns(c2) } return nil } func (a *Atlas) atDefault(c1 *Column, c2 *schema.Column) error { if c1.Default == nil || !a.sqlDialect.supportsDefault(c1) { return nil } switch x := c1.Default.(type) { case Expr: if len(x) > 1 && (x[0] != '(' || x[len(x)-1] != ')') { x = "(" + x + ")" } c2.SetDefault(&schema.RawExpr{X: string(x)}) case map[string]Expr: d, ok := x[a.sqlDialect.Dialect()] if !ok { return nil } if len(d) > 1 && (d[0] != '(' || d[len(d)-1] != ')') { d = "(" + d + ")" } c2.SetDefault(&schema.RawExpr{X: string(d)}) default: switch { case c1.Type == field.TypeJSON: s, ok := c1.Default.(string) if !ok { return fmt.Errorf("invalid default value for JSON column %q: %v", c1.Name, c1.Default) } c2.SetDefault(&schema.Literal{V: strings.ReplaceAll(s, "'", "''")}) default: // Keep backwards compatibility with the old default value format. x := fmt.Sprint(c1.Default) if v, ok := c1.Default.(string); ok && c1.Type != field.TypeUUID && c1.Type != field.TypeTime { // Escape single quote by replacing each with 2. x = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''")) } c2.SetDefault(&schema.RawExpr{X: x}) } } return nil } func (a *Atlas) aIndexes(et *Table, at *schema.Table) error { // Primary-key index. pk := make([]*schema.Column, 0, len(et.PrimaryKey)) for _, c1 := range et.PrimaryKey { c2, ok := at.Column(c1.Name) if !ok { return fmt.Errorf("unexpected primary-key column: %q", c1.Name) } pk = append(pk, c2) } // CreateFunc might clear the primary keys. if len(pk) > 0 { at.SetPrimaryKey(schema.NewPrimaryKey(pk...)) } // Rest of indexes. for _, idx1 := range et.Indexes { idx2 := schema.NewIndex(idx1.Name). SetUnique(idx1.Unique) if err := a.sqlDialect.atIndex(idx1, at, idx2); err != nil { return err } desc := descIndexes(idx1) for _, p := range idx2.Parts { p.Desc = desc[p.C.Name] } at.AddIndexes(idx2) } return nil } // setupTables ensures the table is configured properly, like table columns // are linked to their indexes, and PKs columns are defined. func (a *Atlas) setupTables(tables []*Table) { for _, t := range tables { if t.columns == nil { t.columns = make(map[string]*Column, len(t.Columns)) } for _, c := range t.Columns { t.columns[c.Name] = c } for _, idx := range t.Indexes { idx.Name = a.symbol(idx.Name) for _, c := range idx.Columns { c.indexes.append(idx) } } for _, pk := range t.PrimaryKey { c := t.columns[pk.Name] c.Key = PrimaryKey pk.Key = PrimaryKey } for _, fk := range t.ForeignKeys { fk.Symbol = a.symbol(fk.Symbol) for i := range fk.Columns { fk.Columns[i].foreign = fk } } } } // symbol makes sure the symbol length is not longer than the maxlength in the dialect. func (a *Atlas) symbol(name string) string { size := 64 if a.dialect == dialect.Postgres { size = 63 } if len(name) <= size { return name } return fmt.Sprintf("%s_%x", name[:size-33], md5.Sum([]byte(name))) } // entDialect returns the Ent dialect as configured by the dialect option. func (a *Atlas) entDialect(ctx context.Context, drv dialect.Driver) (sqlDialect, error) { var d sqlDialect switch a.dialect { case dialect.MySQL: d = &MySQL{Driver: drv} case dialect.SQLite: d = &SQLite{Driver: drv, WithForeignKeys: a.withForeignKeys} case dialect.Postgres: d = &Postgres{Driver: drv} default: return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect) } if err := d.init(ctx); err != nil { return nil, err } return d, nil } func (a *Atlas) pkRange(et *Table) (int64, error) { idx := indexOf(a.types, et.Name) // If the table re-created, re-use its range from // the past. Otherwise, allocate a new id-range. if idx == -1 { if len(a.types) > MaxTypes { return 0, fmt.Errorf("max number of types exceeded: %d", MaxTypes) } idx = len(a.types) a.types = append(a.types, et.Name) } return int64(idx << 32), nil } func setAtChecks(et *Table, at *schema.Table) { if check := et.Annotation.Check; check != "" { at.AddChecks(&schema.Check{ Expr: check, }) } if checks := et.Annotation.Checks; len(et.Annotation.Checks) > 0 { names := make([]string, 0, len(checks)) for name := range checks { names = append(names, name) } sort.Strings(names) for _, name := range names { at.AddChecks(&schema.Check{ Name: name, Expr: checks[name], }) } } } // descIndexes returns a map holding the DESC mapping if exist. func descIndexes(idx *Index) map[string]bool { descs := make(map[string]bool) if idx.Annotation == nil { return descs } // If DESC (without a column) was defined on the // annotation, map it to the single column index. if idx.Annotation.Desc && len(idx.Columns) == 1 { descs[idx.Columns[0].Name] = idx.Annotation.Desc } for column, desc := range idx.Annotation.DescColumns { descs[column] = desc } return descs } // driver decorates the atlas migrate.Driver and adds "diff hooking" and functionality. type diffDriver struct { migrate.Driver hooks []DiffHook // hooks to apply } // RealmDiff creates the diff between two realms. Since Ent does not care about Realms, // not even schema changes, calling this method raises an error. func (r *diffDriver) RealmDiff(_, _ *schema.Realm, _ ...schema.DiffOption) ([]schema.Change, error) { return nil, errors.New("sqlDialect does not support working with realms") } // SchemaDiff creates the diff between two schemas, but includes "diff hooks". func (r *diffDriver) SchemaDiff(from, to *schema.Schema, opts ...schema.DiffOption) ([]schema.Change, error) { var d Differ = DiffFunc(func(current, desired *schema.Schema) ([]schema.Change, error) { return r.Driver.SchemaDiff(current, desired, opts...) }) for i := len(r.hooks) - 1; i >= 0; i-- { d = r.hooks[i](d) } return d.Diff(from, to) } // removeAttr is a temporary patch due to compiler errors we get by using the generic // schema.RemoveAttr function (:1: internal compiler error: panic: ...). // Can be removed in Go 1.20. See: https://github.com/golang/go/issues/54302. func removeAttr(attrs []schema.Attr, t reflect.Type) []schema.Attr { f := make([]schema.Attr, 0, len(attrs)) for _, a := range attrs { if reflect.TypeOf(a) != t { f = append(f, a) } } return f }