mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql: minor changes to allow using Migrate externally (#3316)
This commit is contained in:
@@ -141,7 +141,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
|
||||
// Set up connections.
|
||||
if a.driver != nil {
|
||||
var err error
|
||||
a.sqlDialect, err = a.entDialect(a.driver)
|
||||
a.sqlDialect, err = a.entDialect(ctx, a.driver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -155,7 +155,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
|
||||
return err
|
||||
}
|
||||
defer c.Close()
|
||||
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
|
||||
a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -169,10 +169,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
|
||||
return err
|
||||
}
|
||||
if a.universalID {
|
||||
tables = append(tables, NewTable(TypeTable).
|
||||
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
|
||||
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}),
|
||||
)
|
||||
tables = append(tables, NewTypesTable())
|
||||
}
|
||||
var (
|
||||
err error
|
||||
@@ -222,7 +219,7 @@ func (a *Atlas) cleanSchema(ctx context.Context, name string, err0 error) (err e
|
||||
func (a *Atlas) VerifyTableRange(ctx context.Context, tables []*Table) error {
|
||||
if a.driver != nil {
|
||||
var err error
|
||||
a.sqlDialect, err = a.entDialect(a.driver)
|
||||
a.sqlDialect, err = a.entDialect(ctx, a.driver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -232,7 +229,7 @@ func (a *Atlas) VerifyTableRange(ctx context.Context, tables []*Table) error {
|
||||
return err
|
||||
}
|
||||
defer c.Close()
|
||||
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
|
||||
a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -550,9 +547,9 @@ const (
|
||||
|
||||
// StateReader returns an atlas migrate.StateReader returning the state as described by the Ent table slice.
|
||||
func (a *Atlas) StateReader(tables ...*Table) migrate.StateReaderFunc {
|
||||
return func(context.Context) (*schema.Realm, error) {
|
||||
return func(ctx context.Context) (*schema.Realm, error) {
|
||||
if a.sqlDialect == nil {
|
||||
drv, err := a.entDialect(a.driver)
|
||||
drv, err := a.entDialect(ctx, a.driver)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -628,13 +625,10 @@ func (a *Atlas) init() error {
|
||||
// create is the Atlas engine based online migration.
|
||||
func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
|
||||
if a.universalID {
|
||||
tables = append(tables, NewTable(TypeTable).
|
||||
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
|
||||
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}),
|
||||
)
|
||||
tables = append(tables, NewTypesTable())
|
||||
}
|
||||
if a.driver != nil {
|
||||
a.sqlDialect, err = a.entDialect(a.driver)
|
||||
a.sqlDialect, err = a.entDialect(ctx, a.driver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -644,7 +638,7 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
|
||||
return err
|
||||
}
|
||||
defer c.Close()
|
||||
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
|
||||
a.sqlDialect, err = a.entDialect(ctx, entsql.OpenDB(a.dialect, c.DB))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1062,17 +1056,22 @@ func (a *Atlas) symbol(name string) string {
|
||||
}
|
||||
|
||||
// entDialect returns the Ent dialect as configured by the dialect option.
|
||||
func (a *Atlas) entDialect(drv dialect.Driver) (sqlDialect, error) {
|
||||
func (a *Atlas) entDialect(ctx context.Context, drv dialect.Driver) (sqlDialect, error) {
|
||||
var d sqlDialect
|
||||
switch a.dialect {
|
||||
case dialect.MySQL:
|
||||
return &MySQL{Driver: drv}, nil
|
||||
d = &MySQL{Driver: drv}
|
||||
case dialect.SQLite:
|
||||
return &SQLite{Driver: drv, WithForeignKeys: a.withForeignKeys}, nil
|
||||
d = &SQLite{Driver: drv, WithForeignKeys: a.withForeignKeys}
|
||||
case dialect.Postgres:
|
||||
return &Postgres{Driver: drv}, nil
|
||||
d = &Postgres{Driver: drv}
|
||||
default:
|
||||
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect)
|
||||
}
|
||||
if err := d.init(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (a *Atlas) pkRange(et *Table) (int64, error) {
|
||||
|
||||
@@ -23,6 +23,13 @@ const (
|
||||
MaxTypes = math.MaxUint16
|
||||
)
|
||||
|
||||
// NewTypesTable returns a new table for holding the global-id information.
|
||||
func NewTypesTable() *Table {
|
||||
return NewTable(TypeTable).
|
||||
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
|
||||
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true})
|
||||
}
|
||||
|
||||
// MigrateOption allows configuring Atlas using functional arguments.
|
||||
type MigrateOption func(*Atlas)
|
||||
|
||||
@@ -494,9 +501,7 @@ func (m *Migrate) types(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
t := NewTable(TypeTable).
|
||||
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
|
||||
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true})
|
||||
t := NewTypesTable()
|
||||
query, args := m.tBuilder(t).Query()
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("create types table: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user