dialect/sql: minor changes to allow using Migrate externally (#3316)

This commit is contained in:
Ariel Mashraki
2023-02-14 15:06:36 +02:00
committed by GitHub
parent 4c87e262a6
commit 4e05f76717
2 changed files with 27 additions and 23 deletions

View File

@@ -141,7 +141,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
// Set up connections.
if a.driver != nil {
var err error
a.sqlDialect, err = a.entDialect(a.driver)
a.sqlDialect, err = a.entDialect(ctx, a.driver)
if err != nil {
return err
}
@@ -155,7 +155,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
return err
}
defer c.Close()
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB))
if err != nil {
return err
}
@@ -169,10 +169,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
return err
}
if a.universalID {
tables = append(tables, NewTable(TypeTable).
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}),
)
tables = append(tables, NewTypesTable())
}
var (
err error
@@ -222,7 +219,7 @@ func (a *Atlas) cleanSchema(ctx context.Context, name string, err0 error) (err e
func (a *Atlas) VerifyTableRange(ctx context.Context, tables []*Table) error {
if a.driver != nil {
var err error
a.sqlDialect, err = a.entDialect(a.driver)
a.sqlDialect, err = a.entDialect(ctx, a.driver)
if err != nil {
return err
}
@@ -232,7 +229,7 @@ func (a *Atlas) VerifyTableRange(ctx context.Context, tables []*Table) error {
return err
}
defer c.Close()
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB))
if err != nil {
return err
}
@@ -550,9 +547,9 @@ const (
// StateReader returns an atlas migrate.StateReader returning the state as described by the Ent table slice.
func (a *Atlas) StateReader(tables ...*Table) migrate.StateReaderFunc {
return func(context.Context) (*schema.Realm, error) {
return func(ctx context.Context) (*schema.Realm, error) {
if a.sqlDialect == nil {
drv, err := a.entDialect(a.driver)
drv, err := a.entDialect(ctx, a.driver)
if err != nil {
return nil, err
}
@@ -628,13 +625,10 @@ func (a *Atlas) init() error {
// create is the Atlas engine based online migration.
func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
if a.universalID {
tables = append(tables, NewTable(TypeTable).
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}),
)
tables = append(tables, NewTypesTable())
}
if a.driver != nil {
a.sqlDialect, err = a.entDialect(a.driver)
a.sqlDialect, err = a.entDialect(ctx, a.driver)
if err != nil {
return err
}
@@ -644,7 +638,7 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
return err
}
defer c.Close()
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB))
if err != nil {
return err
}
@@ -1062,17 +1056,22 @@ func (a *Atlas) symbol(name string) string {
}
// entDialect returns the Ent dialect as configured by the dialect option.
func (a *Atlas) entDialect(drv dialect.Driver) (sqlDialect, error) {
func (a *Atlas) entDialect(ctx context.Context, drv dialect.Driver) (sqlDialect, error) {
var d sqlDialect
switch a.dialect {
case dialect.MySQL:
return &MySQL{Driver: drv}, nil
d = &MySQL{Driver: drv}
case dialect.SQLite:
return &SQLite{Driver: drv, WithForeignKeys: a.withForeignKeys}, nil
d = &SQLite{Driver: drv, WithForeignKeys: a.withForeignKeys}
case dialect.Postgres:
return &Postgres{Driver: drv}, nil
d = &Postgres{Driver: drv}
default:
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect)
}
if err := d.init(ctx); err != nil {
return nil, err
}
return d, nil
}
func (a *Atlas) pkRange(et *Table) (int64, error) {