From 4e05f767170ef352810a38731829be494b3ae08a Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Tue, 14 Feb 2023 15:06:36 +0200 Subject: [PATCH] dialect/sql: minor changes to allow using Migrate externally (#3316) --- dialect/sql/schema/atlas.go | 39 +++++++++++++++++------------------ dialect/sql/schema/migrate.go | 11 +++++++--- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/dialect/sql/schema/atlas.go b/dialect/sql/schema/atlas.go index 61def2606..d092d7cdf 100644 --- a/dialect/sql/schema/atlas.go +++ b/dialect/sql/schema/atlas.go @@ -141,7 +141,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er // Set up connections. if a.driver != nil { var err error - a.sqlDialect, err = a.entDialect(a.driver) + a.sqlDialect, err = a.entDialect(ctx, a.driver) if err != nil { return err } @@ -155,7 +155,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er return err } defer c.Close() - a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB)) + a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB)) if err != nil { return err } @@ -169,10 +169,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er return err } if a.universalID { - tables = append(tables, NewTable(TypeTable). - AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}). - AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}), - ) + tables = append(tables, NewTypesTable()) } var ( err error @@ -222,7 +219,7 @@ func (a *Atlas) cleanSchema(ctx context.Context, name string, err0 error) (err e func (a *Atlas) VerifyTableRange(ctx context.Context, tables []*Table) error { if a.driver != nil { var err error - a.sqlDialect, err = a.entDialect(a.driver) + a.sqlDialect, err = a.entDialect(ctx, a.driver) if err != nil { return err } @@ -232,7 +229,7 @@ func (a *Atlas) VerifyTableRange(ctx context.Context, tables []*Table) error { return err } defer c.Close() - a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB)) + a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB)) if err != nil { return err } @@ -550,9 +547,9 @@ const ( // StateReader returns an atlas migrate.StateReader returning the state as described by the Ent table slice. func (a *Atlas) StateReader(tables ...*Table) migrate.StateReaderFunc { - return func(context.Context) (*schema.Realm, error) { + return func(ctx context.Context) (*schema.Realm, error) { if a.sqlDialect == nil { - drv, err := a.entDialect(a.driver) + drv, err := a.entDialect(ctx, a.driver) if err != nil { return nil, err } @@ -628,13 +625,10 @@ func (a *Atlas) init() error { // create is the Atlas engine based online migration. func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) { if a.universalID { - tables = append(tables, NewTable(TypeTable). - AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}). - AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}), - ) + tables = append(tables, NewTypesTable()) } if a.driver != nil { - a.sqlDialect, err = a.entDialect(a.driver) + a.sqlDialect, err = a.entDialect(ctx, a.driver) if err != nil { return err } @@ -644,7 +638,7 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) { return err } defer c.Close() - a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB)) + a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB)) if err != nil { return err } @@ -1062,17 +1056,22 @@ func (a *Atlas) symbol(name string) string { } // entDialect returns the Ent dialect as configured by the dialect option. -func (a *Atlas) entDialect(drv dialect.Driver) (sqlDialect, error) { +func (a *Atlas) entDialect(ctx context.Context, drv dialect.Driver) (sqlDialect, error) { + var d sqlDialect switch a.dialect { case dialect.MySQL: - return &MySQL{Driver: drv}, nil + d = &MySQL{Driver: drv} case dialect.SQLite: - return &SQLite{Driver: drv, WithForeignKeys: a.withForeignKeys}, nil + d = &SQLite{Driver: drv, WithForeignKeys: a.withForeignKeys} case dialect.Postgres: - return &Postgres{Driver: drv}, nil + d = &Postgres{Driver: drv} default: return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect) } + if err := d.init(ctx); err != nil { + return nil, err + } + return d, nil } func (a *Atlas) pkRange(et *Table) (int64, error) { diff --git a/dialect/sql/schema/migrate.go b/dialect/sql/schema/migrate.go index a6f0d1285..2fe512324 100644 --- a/dialect/sql/schema/migrate.go +++ b/dialect/sql/schema/migrate.go @@ -23,6 +23,13 @@ const ( MaxTypes = math.MaxUint16 ) +// NewTypesTable returns a new table for holding the global-id information. +func NewTypesTable() *Table { + return NewTable(TypeTable). + AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}). + AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}) +} + // MigrateOption allows configuring Atlas using functional arguments. type MigrateOption func(*Atlas) @@ -494,9 +501,7 @@ func (m *Migrate) types(ctx context.Context, tx dialect.ExecQuerier) error { return err } if !exists { - t := NewTable(TypeTable). - AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}). - AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}) + t := NewTypesTable() query, args := m.tBuilder(t).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("create types table: %w", err)