diff --git a/dialect/sql/schema/migrate.go b/dialect/sql/schema/migrate.go index 95230ba3d..d4e51dd79 100644 --- a/dialect/sql/schema/migrate.go +++ b/dialect/sql/schema/migrate.go @@ -118,6 +118,9 @@ func (m *Migrate) create(ctx context.Context, tx dialect.Tx, tables ...*Table) e if err != nil { return err } + if err := m.verify(ctx, tx, curr); err != nil { + return err + } if err := m.fixture(ctx, tx, curr, t); err != nil { return err } @@ -410,6 +413,19 @@ func (m *Migrate) fixture(ctx context.Context, tx dialect.Tx, curr, new *Table) return nil } +// verify verifies that the auto-increment counter is correct for table with universal-id support. +func (m *Migrate) verify(ctx context.Context, tx dialect.Tx, t *Table) error { + vr, ok := m.sqlDialect.(verifyRanger) + if !ok || !m.universalID { + return nil + } + id := indexOf(m.typeRanges, t.Name) + if id == -1 { + return nil + } + return vr.verifyRange(ctx, tx, t.Name, id<<32) +} + // types loads the type list from the database. // If the table does not create, it will create one. func (m *Migrate) types(ctx context.Context, tx dialect.Tx) error { @@ -438,15 +454,9 @@ func (m *Migrate) types(ctx context.Context, tx dialect.Tx) error { } func (m *Migrate) allocPKRange(ctx context.Context, tx dialect.Tx, t *Table) error { - id := -1 - // if the table re-created, re-use its range from the past. - for i, name := range m.typeRanges { - if name == t.Name { - id = i - break - } - } - // allocate a new id-range. + id := indexOf(m.typeRanges, t.Name) + // if the table re-created, re-use its range from + // the past. otherwise, allocate a new id-range. if id == -1 { if len(m.typeRanges) > MaxTypes { return fmt.Errorf("max number of types exceeded: %d", MaxTypes) @@ -551,6 +561,15 @@ func exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{} return n > 0, nil } +func indexOf(a []string, s string) int { + for i := range a { + if a[i] == s { + return i + } + } + return -1 +} + type sqlDialect interface { dialect.Driver init(context.Context, dialect.Tx) error @@ -579,3 +598,8 @@ type fkRenamer interface { renameIndex(*Table, *Index, *Index) sql.Querier renameColumn(*Table, *Column, *Column) sql.Querier } + +// verifyRanger wraps the method for verifying global-id range correctness. +type verifyRanger interface { + verifyRange(context.Context, dialect.Tx, string, int) error +} diff --git a/dialect/sql/schema/mysql.go b/dialect/sql/schema/mysql.go index 776c06bfb..86898db89 100644 --- a/dialect/sql/schema/mysql.go +++ b/dialect/sql/schema/mysql.go @@ -62,7 +62,7 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("mysql: reading table description %v", err) } - // call `Close` in cases of failures (`Close` is idempotent). + // call Close in cases of failures (Close is idempotent). defer rows.Close() t := NewTable(name) for rows.Next() { @@ -112,6 +112,36 @@ func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, name string, value return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", name, value), []interface{}{}, nil) } +func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, name string, expected int) error { + if expected == 0 { + return nil + } + rows := &sql.Rows{} + query, args := sql.Select("AUTO_INCREMENT"). + From(sql.Table("INFORMATION_SCHEMA.TABLES").Unquote()). + Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)). + Query() + if err := tx.Query(ctx, query, args, rows); err != nil { + return fmt.Errorf("mysql: query auto_increment %v", err) + } + // call Close in cases of failures (Close is idempotent). + defer rows.Close() + actual := &sql.NullInt64{} + if err := sql.ScanOne(rows, actual); err != nil { + return fmt.Errorf("mysql: scan auto_increment %v", err) + } + if err := rows.Close(); err != nil { + return err + } + // Table is empty and auto-increment is not configured. This can happen + // because MySQL (< 8.0) stores the auto-increment counter in main memory + // (not persistent), and the value is reset on restart (if table is empty). + if actual.Int64 == 0 { + return d.setRange(ctx, tx, name, expected) + } + return nil +} + // tBuilder returns the MySQL DSL query for table creation. func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder { b := sql.CreateTable(t.Name).IfNotExists() diff --git a/dialect/sql/schema/mysql_test.go b/dialect/sql/schema/mysql_test.go index cb1fddf63..886449898 100644 --- a/dialect/sql/schema/mysql_test.go +++ b/dialect/sql/schema/mysql_test.go @@ -1040,6 +1040,47 @@ func TestMySQL_Create(t *testing.T) { mock.ExpectCommit() }, }, + { + name: "universal id mismatch with ent_types", + tables: []*Table{ + NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), + }, + options: []MigrateOption{WithGlobalUniqueID(true)}, + before: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")). + WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23")) + mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). + WithArgs("ent_types"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + // query ent_types table. + mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). + WillReturnRows(sqlmock.NewRows([]string{"type"}). + AddRow("deleted"). + AddRow("users")) + mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). + WithArgs("users"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + // users table has no changes. + mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). + WithArgs("users"). + WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). + AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "")) + mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM INFORMATION_SCHEMA.STATISTICS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). + WithArgs("users"). + WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). + AddRow("PRIMARY", "id", "0", "1")) + // query the auto-increment value. + mock.ExpectQuery(escape("SELECT `AUTO_INCREMENT` FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). + WithArgs("users"). + WillReturnRows(sqlmock.NewRows([]string{"AUTO_INCREMENT"}). + AddRow(0)) + // restore the auto-increment counter. + mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 4294967296")). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {