From 0fc310e6001e01016af51814aaa2d7dedfed2a0f Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Sun, 24 Nov 2019 05:18:50 -0800 Subject: [PATCH] sql/dialect/schema: ignore foreign-keys in index dropping Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/188 Reviewed By: alexsn Differential Revision: D18676877 fbshipit-source-id: 0babe457edadfa46dcbc7c7478d4468c48f84361 --- dialect/sql/schema/migrate.go | 31 ++++++++++++++++++-- dialect/sql/schema/mysql_test.go | 50 ++++++++++++++++++++++++++++++++ dialect/sql/schema/schema.go | 32 +++++++------------- 3 files changed, 89 insertions(+), 24 deletions(-) diff --git a/dialect/sql/schema/migrate.go b/dialect/sql/schema/migrate.go index 2673a79ba..99333cfc4 100644 --- a/dialect/sql/schema/migrate.go +++ b/dialect/sql/schema/migrate.go @@ -109,7 +109,7 @@ func (m *Migrate) Create(ctx context.Context, tables ...*Table) error { func (m *Migrate) create(ctx context.Context, tx dialect.Tx, tables ...*Table) error { for _, t := range tables { - t.setup() + m.setupTable(t) switch exist, err := m.tableExist(ctx, tx, t.Name); { case err != nil: return err @@ -154,7 +154,6 @@ func (m *Migrate) create(ctx context.Context, tx dialect.Tx, tables ...*Table) e } fks := make([]*ForeignKey, 0, len(t.ForeignKeys)) for _, fk := range t.ForeignKeys { - fk.Symbol = m.symbol(fk.Symbol) exist, err := m.fkExist(ctx, tx, fk.Symbol) if err != nil { return err @@ -310,7 +309,9 @@ func (m *Migrate) changeSet(curr, new *Table) (*changes, error) { // drop indexes. for _, idx1 := range curr.Indexes { - if _, ok := new.index(idx1.Name); !ok { + _, ok1 := new.fk(idx1.Name) + _, ok2 := new.index(idx1.Name) + if !ok1 && !ok2 { change.index.drop.append(idx1) } } @@ -370,6 +371,30 @@ func (m *Migrate) allocPKRange(ctx context.Context, tx dialect.Tx, t *Table) err return m.setRange(ctx, tx, t.Name, id<<32) } +// setup ensures the table is configured properly, like table columns +// are linked to their indexes, and PKs columns are defined. +func (m *Migrate) setupTable(t *Table) { + if t.columns == nil { + t.columns = make(map[string]*Column, len(t.Columns)) + } + for _, c := range t.Columns { + t.columns[c.Name] = c + } + for _, idx := range t.Indexes { + for _, c := range idx.Columns { + c.indexes.append(idx) + } + } + for _, pk := range t.PrimaryKey { + c := t.columns[pk.Name] + c.Key = PrimaryKey + pk.Key = PrimaryKey + } + for _, fk := range t.ForeignKeys { + fk.Symbol = m.symbol(fk.Symbol) + } +} + // symbol makes sure the symbol length is not longer than the maxlength in the dialect. func (m *Migrate) symbol(name string) string { size := 64 diff --git a/dialect/sql/schema/mysql_test.go b/dialect/sql/schema/mysql_test.go index b37b8f788..67af43358 100644 --- a/dialect/sql/schema/mysql_test.go +++ b/dialect/sql/schema/mysql_test.go @@ -708,6 +708,56 @@ func TestMySQL_Create(t *testing.T) { mock.ExpectCommit() }, }, + { + name: "ignore foreign keys on index dropping", + tables: []*Table{ + { + Name: "users", + Columns: []*Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "parent_id", Type: field.TypeInt, Nullable: true}, + }, + PrimaryKey: []*Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + }, + ForeignKeys: []*ForeignKey{ + { + Symbol: "parent_id", + Columns: []*Column{ + {Name: "parent_id", Type: field.TypeInt, Nullable: true}, + }, + }, + }, + }, + }, + options: []MigrateOption{WithDropIndex(true)}, + before: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")). + WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23")) + mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). + WithArgs("users"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). + WithArgs("users"). + WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). + AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). + AddRow("parent_id", "bigint(20)", "YES", "NULL", "NULL", "", "", "")) + mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM INFORMATION_SCHEMA.STATISTICS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). + WithArgs("users"). + WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). + AddRow("PRIMARY", "id", "0", "1"). + AddRow("old_index", "old", "0", "1"). + AddRow("parent_id", "parent_id", "0", "1")) + // drop the unique index. + mock.ExpectExec(escape("DROP INDEX `old_index` ON `users`")). + WillReturnResult(sqlmock.NewResult(0, 1)) + // foreign key already exist. + mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + mock.ExpectCommit() + }, + }, { name: "add edge to table", tables: func() []*Table { diff --git a/dialect/sql/schema/schema.go b/dialect/sql/schema/schema.go index b195c3f54..4c76122f6 100644 --- a/dialect/sql/schema/schema.go +++ b/dialect/sql/schema/schema.go @@ -82,27 +82,6 @@ func (t *Table) AddIndex(name string, unique bool, columns []string) *Table { return t } -// setup ensures the table is configured properly, like table columns -// are linked to their indexes, and PKs columns are defined. -func (t *Table) setup() { - if t.columns == nil { - t.columns = make(map[string]*Column, len(t.Columns)) - } - for _, c := range t.Columns { - t.columns[c.Name] = c - } - for _, idx := range t.Indexes { - for _, c := range idx.Columns { - c.indexes.append(idx) - } - } - for _, pk := range t.PrimaryKey { - c := t.columns[pk.Name] - c.Key = PrimaryKey - pk.Key = PrimaryKey - } -} - // column returns a table column by its name. // faster than map lookup for most cases. func (t *Table) column(name string) (*Column, bool) { @@ -135,6 +114,17 @@ func (t *Table) index(name string) (*Index, bool) { return nil, false } +// fk returns a table foreign-key by its symbol. +// faster than map lookup for most cases. +func (t *Table) fk(symbol string) (*ForeignKey, bool) { + for _, fk := range t.ForeignKeys { + if fk.Symbol == symbol { + return fk, true + } + } + return nil, false +} + // Column schema definition for SQL dialects. type Column struct { Name string // column name.