dialect/sql/schema: setup tables before running migrate diff (#2703)

Keep the same API as schema.Create
This commit is contained in:
Ariel Mashraki
2022-06-30 09:55:40 +03:00
committed by GitHub
parent 6793d74da7
commit 8416fb502d
2 changed files with 28 additions and 25 deletions

View File

@@ -153,9 +153,7 @@ func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) {
// Note that SQLite dialect does not support (this moment) the "append-only" mode describe above,
// since it's used only for testing.
func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
for _, t := range tables {
m.setupTable(t)
}
m.setupTables(tables)
var creator Creator = CreateFunc(m.create)
if m.atlas.enabled {
creator = CreateFunc(m.atCreate)
@@ -212,6 +210,7 @@ func (m *Migrate) NamedDiff(ctx context.Context, name string, tables ...*Table)
m.dbTypeRanges = nil
}()
}
m.setupTables(tables)
plan, err := m.atDiff(ctx, m, name, tables...)
if err != nil {
return err
@@ -642,28 +641,30 @@ func (m *Migrate) fkColumn(ctx context.Context, tx dialect.Tx, fk *ForeignKey) (
// setup ensures the table is configured properly, like table columns
// are linked to their indexes, and PKs columns are defined.
func (m *Migrate) setupTable(t *Table) {
if t.columns == nil {
t.columns = make(map[string]*Column, len(t.Columns))
}
for _, c := range t.Columns {
t.columns[c.Name] = c
}
for _, idx := range t.Indexes {
idx.Name = m.symbol(idx.Name)
for _, c := range idx.Columns {
c.indexes.append(idx)
func (m *Migrate) setupTables(tables []*Table) {
for _, t := range tables {
if t.columns == nil {
t.columns = make(map[string]*Column, len(t.Columns))
}
}
for _, pk := range t.PrimaryKey {
c := t.columns[pk.Name]
c.Key = PrimaryKey
pk.Key = PrimaryKey
}
for _, fk := range t.ForeignKeys {
fk.Symbol = m.symbol(fk.Symbol)
for i := range fk.Columns {
fk.Columns[i].foreign = fk
for _, c := range t.Columns {
t.columns[c.Name] = c
}
for _, idx := range t.Indexes {
idx.Name = m.symbol(idx.Name)
for _, c := range idx.Columns {
c.indexes.append(idx)
}
}
for _, pk := range t.PrimaryKey {
c := t.columns[pk.Name]
c.Key = PrimaryKey
pk.Key = PrimaryKey
}
for _, fk := range t.ForeignKeys {
fk.Symbol = m.symbol(fk.Symbol)
for i := range fk.Columns {
fk.Columns[i].foreign = fk
}
}
}
}