mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/schema: setrange on custom column name of pks (#333)
This commit is contained in:
@@ -423,7 +423,7 @@ func (m *Migrate) verify(ctx context.Context, tx dialect.Tx, t *Table) error {
|
||||
if id == -1 {
|
||||
return nil
|
||||
}
|
||||
return vr.verifyRange(ctx, tx, t.Name, id<<32)
|
||||
return vr.verifyRange(ctx, tx, t, id<<32)
|
||||
}
|
||||
|
||||
// types loads the type list from the database.
|
||||
@@ -470,7 +470,7 @@ func (m *Migrate) allocPKRange(ctx context.Context, tx dialect.Tx, t *Table) err
|
||||
m.typeRanges = append(m.typeRanges, t.Name)
|
||||
}
|
||||
// set the id offset for table.
|
||||
return m.setRange(ctx, tx, t.Name, id<<32)
|
||||
return m.setRange(ctx, tx, t, id<<32)
|
||||
}
|
||||
|
||||
// fkColumn returns the column name of a foreign-key.
|
||||
@@ -576,7 +576,7 @@ type sqlDialect interface {
|
||||
table(context.Context, dialect.Tx, string) (*Table, error)
|
||||
tableExist(context.Context, dialect.Tx, string) (bool, error)
|
||||
fkExist(context.Context, dialect.Tx, string) (bool, error)
|
||||
setRange(context.Context, dialect.Tx, string, int) error
|
||||
setRange(context.Context, dialect.Tx, *Table, int) error
|
||||
dropIndex(context.Context, dialect.Tx, *Index, string) error
|
||||
// table, column and index builder per dialect.
|
||||
cType(*Column) string
|
||||
@@ -601,5 +601,5 @@ type fkRenamer interface {
|
||||
|
||||
// verifyRanger wraps the method for verifying global-id range correctness.
|
||||
type verifyRanger interface {
|
||||
verifyRange(context.Context, dialect.Tx, string, int) error
|
||||
verifyRange(context.Context, dialect.Tx, *Table, int) error
|
||||
}
|
||||
|
||||
@@ -108,18 +108,18 @@ func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, name string) ([]*Ind
|
||||
return idx, nil
|
||||
}
|
||||
|
||||
func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, name string, value int) error {
|
||||
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", name, value), []interface{}{}, nil)
|
||||
func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error {
|
||||
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", t.Name, value), []interface{}{}, nil)
|
||||
}
|
||||
|
||||
func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, name string, expected int) error {
|
||||
func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, t *Table, expected int) error {
|
||||
if expected == 0 {
|
||||
return nil
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Select("AUTO_INCREMENT").
|
||||
From(sql.Table("INFORMATION_SCHEMA.TABLES").Unquote()).
|
||||
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).
|
||||
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", t.Name)).
|
||||
Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return fmt.Errorf("mysql: query auto_increment %v", err)
|
||||
@@ -137,7 +137,7 @@ func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, name string, exp
|
||||
// because MySQL (< 8.0) stores the auto-increment counter in main memory
|
||||
// (not persistent), and the value is reset on restart (if table is empty).
|
||||
if actual.Int64 == 0 {
|
||||
return d.setRange(ctx, tx, name, expected)
|
||||
return d.setRange(ctx, tx, t, expected)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -62,11 +62,15 @@ func (d *Postgres) fkExist(ctx context.Context, tx dialect.Tx, name string) (boo
|
||||
}
|
||||
|
||||
// setRange sets restart the identity column to the given offset. Used by the universal-id option.
|
||||
func (d *Postgres) setRange(ctx context.Context, tx dialect.Tx, name string, value int) error {
|
||||
func (d *Postgres) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error {
|
||||
if value == 0 {
|
||||
value = 1 // RESTART value cannot be < 1.
|
||||
}
|
||||
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN id RESTART WITH %d", name, value), []interface{}{}, nil)
|
||||
pk := "id"
|
||||
if len(t.PrimaryKey) == 1 {
|
||||
pk = t.PrimaryKey[0].Name
|
||||
}
|
||||
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s RESTART WITH %d", t.Name, pk, value), []interface{}{}, nil)
|
||||
}
|
||||
|
||||
// table loads the current table description from the database.
|
||||
|
||||
@@ -45,19 +45,19 @@ func (d *SQLite) tableExist(ctx context.Context, tx dialect.Tx, name string) (bo
|
||||
// whenever a table that contains an AUTOINCREMENT column is created. However, it populates to it a rows (for tables)
|
||||
// only after the first insertion. Therefore, we check. If a record (for the given table) already exists in the "sqlite_sequence"
|
||||
// table, we updated it. Otherwise, we insert a new value.
|
||||
func (d *SQLite) setRange(ctx context.Context, tx dialect.Tx, name string, value int) error {
|
||||
func (d *SQLite) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error {
|
||||
query, args := sql.Select().Count().
|
||||
From(sql.Table("sqlite_sequence")).
|
||||
Where(sql.EQ("name", name)).
|
||||
Where(sql.EQ("name", t.Name)).
|
||||
Query()
|
||||
exists, err := exist(ctx, tx, query, args...)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
case exists:
|
||||
query, args = sql.Update("sqlite_sequence").Set("seq", value).Where(sql.EQ("name", name)).Query()
|
||||
query, args = sql.Update("sqlite_sequence").Set("seq", value).Where(sql.EQ("name", t.Name)).Query()
|
||||
default: // !exists
|
||||
query, args = sql.Insert("sqlite_sequence").Columns("name", "seq").Values(name, value).Query()
|
||||
query, args = sql.Insert("sqlite_sequence").Columns("name", "seq").Values(t.Name, value).Query()
|
||||
}
|
||||
return tx.Exec(ctx, query, args, nil)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user