dialect/sql: minor changes to allow using Migrate externally (#3316)

This commit is contained in:
Ariel Mashraki
2023-02-14 15:06:36 +02:00
committed by GitHub
parent 4c87e262a6
commit 4e05f76717
2 changed files with 27 additions and 23 deletions

View File

@@ -141,7 +141,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
// Set up connections.
if a.driver != nil {
var err error
a.sqlDialect, err = a.entDialect(a.driver)
a.sqlDialect, err = a.entDialect(ctx, a.driver)
if err != nil {
return err
}
@@ -155,7 +155,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
return err
}
defer c.Close()
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB))
if err != nil {
return err
}
@@ -169,10 +169,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
return err
}
if a.universalID {
tables = append(tables, NewTable(TypeTable).
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}),
)
tables = append(tables, NewTypesTable())
}
var (
err error
@@ -222,7 +219,7 @@ func (a *Atlas) cleanSchema(ctx context.Context, name string, err0 error) (err e
func (a *Atlas) VerifyTableRange(ctx context.Context, tables []*Table) error {
if a.driver != nil {
var err error
a.sqlDialect, err = a.entDialect(a.driver)
a.sqlDialect, err = a.entDialect(ctx, a.driver)
if err != nil {
return err
}
@@ -232,7 +229,7 @@ func (a *Atlas) VerifyTableRange(ctx context.Context, tables []*Table) error {
return err
}
defer c.Close()
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB))
if err != nil {
return err
}
@@ -550,9 +547,9 @@ const (
// StateReader returns an atlas migrate.StateReader returning the state as described by the Ent table slice.
func (a *Atlas) StateReader(tables ...*Table) migrate.StateReaderFunc {
return func(context.Context) (*schema.Realm, error) {
return func(ctx context.Context) (*schema.Realm, error) {
if a.sqlDialect == nil {
drv, err := a.entDialect(a.driver)
drv, err := a.entDialect(ctx, a.driver)
if err != nil {
return nil, err
}
@@ -628,13 +625,10 @@ func (a *Atlas) init() error {
// create is the Atlas engine based online migration.
func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
if a.universalID {
tables = append(tables, NewTable(TypeTable).
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}),
)
tables = append(tables, NewTypesTable())
}
if a.driver != nil {
a.sqlDialect, err = a.entDialect(a.driver)
a.sqlDialect, err = a.entDialect(ctx, a.driver)
if err != nil {
return err
}
@@ -644,7 +638,7 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
return err
}
defer c.Close()
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB))
if err != nil {
return err
}
@@ -1062,17 +1056,22 @@ func (a *Atlas) symbol(name string) string {
}
// entDialect returns the Ent dialect as configured by the dialect option.
func (a *Atlas) entDialect(drv dialect.Driver) (sqlDialect, error) {
func (a *Atlas) entDialect(ctx context.Context, drv dialect.Driver) (sqlDialect, error) {
var d sqlDialect
switch a.dialect {
case dialect.MySQL:
return &MySQL{Driver: drv}, nil
d = &MySQL{Driver: drv}
case dialect.SQLite:
return &SQLite{Driver: drv, WithForeignKeys: a.withForeignKeys}, nil
d = &SQLite{Driver: drv, WithForeignKeys: a.withForeignKeys}
case dialect.Postgres:
return &Postgres{Driver: drv}, nil
d = &Postgres{Driver: drv}
default:
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect)
}
if err := d.init(ctx); err != nil {
return nil, err
}
return d, nil
}
func (a *Atlas) pkRange(et *Table) (int64, error) {

View File

@@ -23,6 +23,13 @@ const (
MaxTypes = math.MaxUint16
)
// NewTypesTable returns a new table for holding the global-id information.
func NewTypesTable() *Table {
return NewTable(TypeTable).
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true})
}
// MigrateOption allows configuring Atlas using functional arguments.
type MigrateOption func(*Atlas)
@@ -494,9 +501,7 @@ func (m *Migrate) types(ctx context.Context, tx dialect.ExecQuerier) error {
return err
}
if !exists {
t := NewTable(TypeTable).
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true})
t := NewTypesTable()
query, args := m.tBuilder(t).Query()
if err := tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("create types table: %w", err)