dialect/sql/schema: setup tables before running migrate diff (#2703)

Keep the same API as schema.Create
This commit is contained in:
Ariel Mashraki
2022-06-30 09:55:40 +03:00
committed by GitHub
parent 6793d74da7
commit 8416fb502d
2 changed files with 28 additions and 25 deletions

View File

@@ -153,9 +153,7 @@ func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) {
// Note that SQLite dialect does not support (this moment) the "append-only" mode describe above,
// since it's used only for testing.
func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
for _, t := range tables {
m.setupTable(t)
}
m.setupTables(tables)
var creator Creator = CreateFunc(m.create)
if m.atlas.enabled {
creator = CreateFunc(m.atCreate)
@@ -212,6 +210,7 @@ func (m *Migrate) NamedDiff(ctx context.Context, name string, tables ...*Table)
m.dbTypeRanges = nil
}()
}
m.setupTables(tables)
plan, err := m.atDiff(ctx, m, name, tables...)
if err != nil {
return err
@@ -642,28 +641,30 @@ func (m *Migrate) fkColumn(ctx context.Context, tx dialect.Tx, fk *ForeignKey) (
// setup ensures the table is configured properly, like table columns
// are linked to their indexes, and PKs columns are defined.
func (m *Migrate) setupTable(t *Table) {
if t.columns == nil {
t.columns = make(map[string]*Column, len(t.Columns))
}
for _, c := range t.Columns {
t.columns[c.Name] = c
}
for _, idx := range t.Indexes {
idx.Name = m.symbol(idx.Name)
for _, c := range idx.Columns {
c.indexes.append(idx)
func (m *Migrate) setupTables(tables []*Table) {
for _, t := range tables {
if t.columns == nil {
t.columns = make(map[string]*Column, len(t.Columns))
}
}
for _, pk := range t.PrimaryKey {
c := t.columns[pk.Name]
c.Key = PrimaryKey
pk.Key = PrimaryKey
}
for _, fk := range t.ForeignKeys {
fk.Symbol = m.symbol(fk.Symbol)
for i := range fk.Columns {
fk.Columns[i].foreign = fk
for _, c := range t.Columns {
t.columns[c.Name] = c
}
for _, idx := range t.Indexes {
idx.Name = m.symbol(idx.Name)
for _, c := range idx.Columns {
c.indexes.append(idx)
}
}
for _, pk := range t.PrimaryKey {
c := t.columns[pk.Name]
c.Key = PrimaryKey
pk.Key = PrimaryKey
}
for _, fk := range t.ForeignKeys {
fk.Symbol = m.symbol(fk.Symbol)
for i := range fk.Columns {
fk.Columns[i].foreign = fk
}
}
}
}

View File

@@ -128,13 +128,15 @@ func TestMigrate_Diff(t *testing.T) {
require.IsType(t, &dirTypeStore{}, m.typeStore)
require.NoError(t, m.Diff(context.Background(),
&Table{Name: "users", Columns: idCol, PrimaryKey: idCol},
&Table{Name: "groups", Columns: idCol, PrimaryKey: idCol},
&Table{Name: "groups", Columns: idCol, PrimaryKey: idCol, Indexes: []*Index{{Name: "short", Columns: idCol}, {Name: "long_" + strings.Repeat("_", 60), Columns: idCol}}},
))
requireFileEqual(t, filepath.Join(p, ".ent_types"), atlasDirective+"users,groups")
changesSQL := strings.Join([]string{
"CREATE TABLE `users` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT);",
"CREATE TABLE `groups` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT);",
fmt.Sprintf("INSERT INTO sqlite_sequence (name, seq) VALUES (\"groups\", %d);", 1<<32),
"CREATE INDEX `short` ON `groups` (`id`);",
"CREATE INDEX `long____________________________1cb2e7e47a309191385af4ad320875b1` ON `groups` (`id`);",
"CREATE TABLE `ent_types` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `type` text NOT NULL);",
"CREATE UNIQUE INDEX `ent_types_type_key` ON `ent_types` (`type`);",
"INSERT INTO `ent_types` (`type`) VALUES ('users'), ('groups');", "",