dialect/sql/schema: setrange on custom column name of pks (#333)

This commit is contained in:
Ariel Mashraki
2020-02-09 09:41:26 +02:00
committed by GitHub
parent 48d33fde9d
commit 26440c2bc9
12 changed files with 60 additions and 26 deletions

View File

@@ -423,7 +423,7 @@ func (m *Migrate) verify(ctx context.Context, tx dialect.Tx, t *Table) error {
if id == -1 {
return nil
}
return vr.verifyRange(ctx, tx, t.Name, id<<32)
return vr.verifyRange(ctx, tx, t, id<<32)
}
// types loads the type list from the database.
@@ -470,7 +470,7 @@ func (m *Migrate) allocPKRange(ctx context.Context, tx dialect.Tx, t *Table) err
m.typeRanges = append(m.typeRanges, t.Name)
}
// set the id offset for table.
return m.setRange(ctx, tx, t.Name, id<<32)
return m.setRange(ctx, tx, t, id<<32)
}
// fkColumn returns the column name of a foreign-key.
@@ -576,7 +576,7 @@ type sqlDialect interface {
table(context.Context, dialect.Tx, string) (*Table, error)
tableExist(context.Context, dialect.Tx, string) (bool, error)
fkExist(context.Context, dialect.Tx, string) (bool, error)
setRange(context.Context, dialect.Tx, string, int) error
setRange(context.Context, dialect.Tx, *Table, int) error
dropIndex(context.Context, dialect.Tx, *Index, string) error
// table, column and index builder per dialect.
cType(*Column) string
@@ -601,5 +601,5 @@ type fkRenamer interface {
// verifyRanger wraps the method for verifying global-id range correctness.
type verifyRanger interface {
verifyRange(context.Context, dialect.Tx, string, int) error
verifyRange(context.Context, dialect.Tx, *Table, int) error
}

View File

@@ -108,18 +108,18 @@ func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, name string) ([]*Ind
return idx, nil
}
func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, name string, value int) error {
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", name, value), []interface{}{}, nil)
func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error {
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", t.Name, value), []interface{}{}, nil)
}
func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, name string, expected int) error {
func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, t *Table, expected int) error {
if expected == 0 {
return nil
}
rows := &sql.Rows{}
query, args := sql.Select("AUTO_INCREMENT").
From(sql.Table("INFORMATION_SCHEMA.TABLES").Unquote()).
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", t.Name)).
Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return fmt.Errorf("mysql: query auto_increment %v", err)
@@ -137,7 +137,7 @@ func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, name string, exp
// because MySQL (< 8.0) stores the auto-increment counter in main memory
// (not persistent), and the value is reset on restart (if table is empty).
if actual.Int64 == 0 {
return d.setRange(ctx, tx, name, expected)
return d.setRange(ctx, tx, t, expected)
}
return nil
}

View File

@@ -62,11 +62,15 @@ func (d *Postgres) fkExist(ctx context.Context, tx dialect.Tx, name string) (boo
}
// setRange sets restart the identity column to the given offset. Used by the universal-id option.
func (d *Postgres) setRange(ctx context.Context, tx dialect.Tx, name string, value int) error {
func (d *Postgres) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error {
if value == 0 {
value = 1 // RESTART value cannot be < 1.
}
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN id RESTART WITH %d", name, value), []interface{}{}, nil)
pk := "id"
if len(t.PrimaryKey) == 1 {
pk = t.PrimaryKey[0].Name
}
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s RESTART WITH %d", t.Name, pk, value), []interface{}{}, nil)
}
// table loads the current table description from the database.

View File

@@ -45,19 +45,19 @@ func (d *SQLite) tableExist(ctx context.Context, tx dialect.Tx, name string) (bo
// whenever a table that contains an AUTOINCREMENT column is created. However, it populates to it a rows (for tables)
// only after the first insertion. Therefore, we check. If a record (for the given table) already exists in the "sqlite_sequence"
// table, we updated it. Otherwise, we insert a new value.
func (d *SQLite) setRange(ctx context.Context, tx dialect.Tx, name string, value int) error {
func (d *SQLite) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error {
query, args := sql.Select().Count().
From(sql.Table("sqlite_sequence")).
Where(sql.EQ("name", name)).
Where(sql.EQ("name", t.Name)).
Query()
exists, err := exist(ctx, tx, query, args...)
switch {
case err != nil:
return err
case exists:
query, args = sql.Update("sqlite_sequence").Set("seq", value).Where(sql.EQ("name", name)).Query()
query, args = sql.Update("sqlite_sequence").Set("seq", value).Where(sql.EQ("name", t.Name)).Query()
default: // !exists
query, args = sql.Insert("sqlite_sequence").Columns("name", "seq").Values(name, value).Query()
query, args = sql.Insert("sqlite_sequence").Columns("name", "seq").Values(t.Name, value).Query()
}
return tx.Exec(ctx, query, args, nil)
}