mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/schema: verify and fix mysql auto-increment on reset (#329)
This commit is contained in:
@@ -118,6 +118,9 @@ func (m *Migrate) create(ctx context.Context, tx dialect.Tx, tables ...*Table) e
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.verify(ctx, tx, curr); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.fixture(ctx, tx, curr, t); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -410,6 +413,19 @@ func (m *Migrate) fixture(ctx context.Context, tx dialect.Tx, curr, new *Table)
|
||||
return nil
|
||||
}
|
||||
|
||||
// verify verifies that the auto-increment counter is correct for table with universal-id support.
|
||||
func (m *Migrate) verify(ctx context.Context, tx dialect.Tx, t *Table) error {
|
||||
vr, ok := m.sqlDialect.(verifyRanger)
|
||||
if !ok || !m.universalID {
|
||||
return nil
|
||||
}
|
||||
id := indexOf(m.typeRanges, t.Name)
|
||||
if id == -1 {
|
||||
return nil
|
||||
}
|
||||
return vr.verifyRange(ctx, tx, t.Name, id<<32)
|
||||
}
|
||||
|
||||
// types loads the type list from the database.
|
||||
// If the table does not create, it will create one.
|
||||
func (m *Migrate) types(ctx context.Context, tx dialect.Tx) error {
|
||||
@@ -438,15 +454,9 @@ func (m *Migrate) types(ctx context.Context, tx dialect.Tx) error {
|
||||
}
|
||||
|
||||
func (m *Migrate) allocPKRange(ctx context.Context, tx dialect.Tx, t *Table) error {
|
||||
id := -1
|
||||
// if the table re-created, re-use its range from the past.
|
||||
for i, name := range m.typeRanges {
|
||||
if name == t.Name {
|
||||
id = i
|
||||
break
|
||||
}
|
||||
}
|
||||
// allocate a new id-range.
|
||||
id := indexOf(m.typeRanges, t.Name)
|
||||
// if the table re-created, re-use its range from
|
||||
// the past. otherwise, allocate a new id-range.
|
||||
if id == -1 {
|
||||
if len(m.typeRanges) > MaxTypes {
|
||||
return fmt.Errorf("max number of types exceeded: %d", MaxTypes)
|
||||
@@ -551,6 +561,15 @@ func exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}
|
||||
return n > 0, nil
|
||||
}
|
||||
|
||||
func indexOf(a []string, s string) int {
|
||||
for i := range a {
|
||||
if a[i] == s {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
type sqlDialect interface {
|
||||
dialect.Driver
|
||||
init(context.Context, dialect.Tx) error
|
||||
@@ -579,3 +598,8 @@ type fkRenamer interface {
|
||||
renameIndex(*Table, *Index, *Index) sql.Querier
|
||||
renameColumn(*Table, *Column, *Column) sql.Querier
|
||||
}
|
||||
|
||||
// verifyRanger wraps the method for verifying global-id range correctness.
|
||||
type verifyRanger interface {
|
||||
verifyRange(context.Context, dialect.Tx, string, int) error
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("mysql: reading table description %v", err)
|
||||
}
|
||||
// call `Close` in cases of failures (`Close` is idempotent).
|
||||
// call Close in cases of failures (Close is idempotent).
|
||||
defer rows.Close()
|
||||
t := NewTable(name)
|
||||
for rows.Next() {
|
||||
@@ -112,6 +112,36 @@ func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, name string, value
|
||||
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", name, value), []interface{}{}, nil)
|
||||
}
|
||||
|
||||
func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, name string, expected int) error {
|
||||
if expected == 0 {
|
||||
return nil
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Select("AUTO_INCREMENT").
|
||||
From(sql.Table("INFORMATION_SCHEMA.TABLES").Unquote()).
|
||||
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).
|
||||
Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return fmt.Errorf("mysql: query auto_increment %v", err)
|
||||
}
|
||||
// call Close in cases of failures (Close is idempotent).
|
||||
defer rows.Close()
|
||||
actual := &sql.NullInt64{}
|
||||
if err := sql.ScanOne(rows, actual); err != nil {
|
||||
return fmt.Errorf("mysql: scan auto_increment %v", err)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Table is empty and auto-increment is not configured. This can happen
|
||||
// because MySQL (< 8.0) stores the auto-increment counter in main memory
|
||||
// (not persistent), and the value is reset on restart (if table is empty).
|
||||
if actual.Int64 == 0 {
|
||||
return d.setRange(ctx, tx, name, expected)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// tBuilder returns the MySQL DSL query for table creation.
|
||||
func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder {
|
||||
b := sql.CreateTable(t.Name).IfNotExists()
|
||||
|
||||
@@ -1040,6 +1040,47 @@ func TestMySQL_Create(t *testing.T) {
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "universal id mismatch with ent_types",
|
||||
tables: []*Table{
|
||||
NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
},
|
||||
options: []MigrateOption{WithGlobalUniqueID(true)},
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("ent_types").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
// query ent_types table.
|
||||
mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"type"}).
|
||||
AddRow("deleted").
|
||||
AddRow("users"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
// users table has no changes.
|
||||
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
|
||||
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""))
|
||||
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM INFORMATION_SCHEMA.STATISTICS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}).
|
||||
AddRow("PRIMARY", "id", "0", "1"))
|
||||
// query the auto-increment value.
|
||||
mock.ExpectQuery(escape("SELECT `AUTO_INCREMENT` FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"AUTO_INCREMENT"}).
|
||||
AddRow(0))
|
||||
// restore the auto-increment counter.
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 4294967296")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user