entc/gen: change name format for edge fks (#286)

* entc/gen: change name format for edge fks

* dialect/sql/schema: add fixture support for mysql and postgres

* sql/dialect: merge fkcolumn queries to one for the 2 dialects
This commit is contained in:
Ariel Mashraki
2020-02-03 15:41:55 +02:00
committed by GitHub
parent c70f1017e3
commit b4255998bf
113 changed files with 1082 additions and 431 deletions

View File

@@ -238,6 +238,12 @@ func (t *TableAlter) ModifyColumn(c *ColumnBuilder) *TableAlter {
return t
}
// RenameColumn appends the `RENAME COLUMN` clause to the given `ALTER TABLE` statement.
func (t *TableAlter) RenameColumn(old, new string) *TableAlter {
t.Queries = append(t.Queries, Raw(fmt.Sprintf("RENAME COLUMN %s TO %s", t.Quote(old), t.Quote(new))))
return t
}
// ModifyColumns calls ModifyColumn with each of the given builders.
func (t *TableAlter) ModifyColumns(cs ...*ColumnBuilder) *TableAlter {
for _, c := range cs {

View File

@@ -19,24 +19,32 @@ type ColumnScanner interface {
Columns() ([]string, error)
}
// ScanInt64 scans and returns an int64 from the rows columns.
func ScanInt64(rows ColumnScanner) (int64, error) {
// ScanOne scans one row to the given value. It fails if the rows holds more than 1 row.
func ScanOne(rows ColumnScanner, v interface{}) error {
columns, err := rows.Columns()
if err != nil {
return 0, fmt.Errorf("sql/scan: failed getting column names: %v", err)
return fmt.Errorf("sql/scan: failed getting column names: %v", err)
}
if n := len(columns); n != 1 {
return 0, fmt.Errorf("sql/scan: unexpected number of columns: %d", n)
return fmt.Errorf("sql/scan: unexpected number of columns: %d", n)
}
if !rows.Next() {
return 0, sql.ErrNoRows
return sql.ErrNoRows
}
var n int64
if err := rows.Scan(&n); err != nil {
return 0, err
if err := rows.Scan(v); err != nil {
return err
}
if rows.Next() {
return 0, fmt.Errorf("sql/scan: expect exactly one row in result set")
return fmt.Errorf("sql/scan: expect exactly one row in result set")
}
return nil
}
// ScanInt64 scans and returns an int64 from the rows columns.
func ScanInt64(rows ColumnScanner) (int64, error) {
var n int64
if err := ScanOne(rows, &n); err != nil {
return 0, err
}
return n, nil
}
@@ -50,6 +58,15 @@ func ScanInt(rows ColumnScanner) (int, error) {
return int(n), nil
}
// ScanString scans and returns a string from the rows columns.
func ScanString(rows ColumnScanner) (string, error) {
var s string
if err := ScanOne(rows, &s); err != nil {
return "", err
}
return s, nil
}
// ScanSlice scans the given ColumnScanner (basically, sql.Row or sql.Rows) into the given slice.
func ScanSlice(rows ColumnScanner, v interface{}) error {
columns, err := rows.Columns()

View File

@@ -153,6 +153,26 @@ func TestScanInt64(t *testing.T) {
require.EqualValues(t, 10, n)
}
func TestScanOne(t *testing.T) {
mock := sqlmock.NewRows([]string{"name"}).
AddRow("10").
AddRow("20")
err := ScanOne(toRows(mock), new(string))
require.Error(t, err, "multiple lines")
mock = sqlmock.NewRows([]string{"name"}).
AddRow("10")
err = ScanOne(toRows(mock), "")
require.Error(t, err, "not a pointer")
mock = sqlmock.NewRows([]string{"name"}).
AddRow("10")
var s string
err = ScanOne(toRows(mock), &s)
require.NoError(t, err)
require.Equal(t, "10", s)
}
func TestInterface(t *testing.T) {
mock := sqlmock.NewRows([]string{"age"}).
AddRow("10").

View File

@@ -118,6 +118,9 @@ func (m *Migrate) create(ctx context.Context, tx dialect.Tx, tables ...*Table) e
if err != nil {
return err
}
if err := m.fixture(ctx, tx, curr, t); err != nil {
return err
}
change, err := m.changeSet(curr, t)
if err != nil {
return err
@@ -206,7 +209,7 @@ func (m *Migrate) apply(ctx context.Context, tx dialect.Tx, table string, change
b.DropColumn(sql.Dialect(m.Dialect()).Column(c.Name))
}
}
// if there's actual action to execute on ALTER TABLE.
// If there's actual action to execute on ALTER TABLE.
if len(b.Queries) != 0 {
query, args := b.Query()
if err := tx.Exec(ctx, query, args, nil); err != nil {
@@ -332,6 +335,81 @@ func (m *Migrate) changeSet(curr, new *Table) (*changes, error) {
return change, nil
}
// fixture is a special migration code for renaming foreign-key columns (issue-#285).
func (m *Migrate) fixture(ctx context.Context, tx dialect.Tx, curr, new *Table) error {
d, ok := m.sqlDialect.(fkRenamer)
if !ok {
return nil
}
rename := make(map[string]*Index)
for _, fk := range new.ForeignKeys {
ok, err := m.fkExist(ctx, tx, fk.Symbol)
if err != nil {
return fmt.Errorf("checking foreign-key existence %q: %v", fk.Symbol, err)
}
if !ok {
continue
}
column, err := m.fkColumn(ctx, tx, fk)
if err != nil {
return err
}
newcol := fk.Columns[0]
if column == newcol.Name {
continue
}
query, args := d.renameColumn(curr, &Column{Name: column}, newcol).Query()
if err := tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("rename column %q: %v", column, err)
}
prev, ok := curr.column(column)
if !ok {
continue
}
// Find all indexes that ~maybe need to be renamed.
for _, idx := range prev.indexes {
switch _, ok := new.index(idx.Name); {
// Ignore indexes that exist in the schema, PKs.
case ok || idx.primary:
// Index that was created implicitly for a unique
// column needs to be renamed to the column name.
case d.isImplicitIndex(idx, prev):
idx2 := &Index{Name: newcol.Name, Unique: true, Columns: []*Column{newcol}}
query, args := d.renameIndex(curr, idx, idx2).Query()
if err := tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("rename index %q: %v", prev.Name, err)
}
idx.Name = idx2.Name
default:
rename[idx.Name] = idx
}
}
// Update the name of the loaded column, so `changeSet` won't create it.
prev.Name = newcol.Name
}
// Go over the indexes that need to be renamed
// and find their ~identical in the new schema.
for _, idx := range rename {
Find:
// Find its ~identical in the new schema, and rename it
// if it doesn't exist.
for _, idx2 := range new.Indexes {
if _, ok := curr.index(idx2.Name); ok {
continue
}
if idx.sameAs(idx2) {
query, args := d.renameIndex(curr, idx, idx2).Query()
if err := tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("rename index %q: %v", idx.Name, err)
}
idx.Name = idx2.Name
break Find
}
}
}
return nil
}
// types loads the type list from the database.
// If the table does not create, it will create one.
func (m *Migrate) types(ctx context.Context, tx dialect.Tx) error {
@@ -385,6 +463,34 @@ func (m *Migrate) allocPKRange(ctx context.Context, tx dialect.Tx, t *Table) err
return m.setRange(ctx, tx, t.Name, id<<32)
}
// fkColumn returns the column name of a foreign-key.
func (m *Migrate) fkColumn(ctx context.Context, tx dialect.Tx, fk *ForeignKey) (string, error) {
t1 := sql.Table("INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS t1").Unquote().As("t1")
t2 := sql.Table("INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t2").Unquote().As("t2")
query, args := sql.Dialect(m.Dialect()).
Select("column_name").
From(t1).
Join(t2).
On(t1.C("constraint_name"), t2.C("constraint_name")).
Where(sql.And(
sql.EQ(t2.C("constraint_type"), sql.Raw("'FOREIGN KEY'")),
sql.EQ(t2.C("table_schema"), m.sqlDialect.(fkRenamer).tableSchema()),
sql.EQ(t1.C("table_schema"), m.sqlDialect.(fkRenamer).tableSchema()),
sql.EQ(t2.C("constraint_name"), fk.Symbol),
)).
Query()
rows := &sql.Rows{}
if err := tx.Query(ctx, query, args, rows); err != nil {
return "", fmt.Errorf("reading foreign-key %q column: %v", fk.Symbol, err)
}
defer rows.Close()
column, err := sql.ScanString(rows)
if err != nil {
return "", fmt.Errorf("scanning foreign-key %q column: %v", fk.Symbol, err)
}
return column, nil
}
// setup ensures the table is configured properly, like table columns
// are linked to their indexes, and PKs columns are defined.
func (m *Migrate) setupTable(t *Table) {
@@ -395,6 +501,7 @@ func (m *Migrate) setupTable(t *Table) {
t.columns[c.Name] = c
}
for _, idx := range t.Indexes {
idx.Name = m.symbol(idx.Name)
for _, c := range idx.Columns {
c.indexes.append(idx)
}
@@ -463,3 +570,12 @@ type sqlDialect interface {
type preparer interface {
prepare(context.Context, dialect.Tx, *changes, string) error
}
// fkRenamer is used by the fixture migration (to solve #285),
// and it's implemented by the different dialects for renaming FKs.
type fkRenamer interface {
tableSchema() sql.Querier
isImplicitIndex(*Index, *Column) bool
renameIndex(*Table, *Index, *Index) sql.Querier
renameColumn(*Table, *Column, *Column) sql.Querier
}

View File

@@ -415,6 +415,40 @@ func (d *MySQL) scanIndexes(rows *sql.Rows) (Indexes, error) {
return i, nil
}
// isImplicitIndex reports if the index was created implicitly for the unique column.
func (d *MySQL) isImplicitIndex(idx *Index, col *Column) bool {
// We execute `CHANGE COLUMN` on older versions of MySQL (<8.0), which
// auto create the new index. The old one, will be dropped in `changeSet`.
if compareVersions(d.version, "8.0.0") >= 0 {
return idx.Name == col.Name && col.Unique
}
return false
}
// renameColumn returns the statement for renaming a column in
// MySQL based on its version.
func (d *MySQL) renameColumn(t *Table, old, new *Column) sql.Querier {
q := sql.AlterTable(t.Name)
if compareVersions(d.version, "8.0.0") >= 0 {
return q.RenameColumn(old.Name, new.Name)
}
return q.ChangeColumn(old.Name, d.addColumn(new))
}
// renameIndex returns the statement for renaming an index.
func (d *MySQL) renameIndex(t *Table, old, new *Index) sql.Querier {
q := sql.AlterTable(t.Name)
if compareVersions(d.version, "5.7.0") >= 0 {
return q.RenameIndex(old.Name, new.Name)
}
return q.DropIndex(old.Name).AddIndex(new.Builder(t.Name))
}
// tableSchema returns the query for getting the table schema.
func (d *MySQL) tableSchema() sql.Querier {
return sql.Raw("(SELECT DATABASE())")
}
// fkNames returns the foreign-key names of a column.
func fkNames(ctx context.Context, tx dialect.Tx, table, column string) ([]string, error) {
query, args := sql.Select("CONSTRAINT_NAME").From(sql.Table("INFORMATION_SCHEMA.KEY_COLUMN_USAGE").Unquote()).

View File

@@ -753,6 +753,13 @@ func TestMySQL_Create(t *testing.T) {
AddRow("PRIMARY", "id", "0", "1").
AddRow("old_index", "old", "0", "1").
AddRow("parent_id", "parent_id", "0", "1"))
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")).
WithArgs("FOREIGN KEY", "parent_id").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
mock.ExpectQuery(escape("SELECT `column_name` FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS t1 JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t2 ON `t1`.`constraint_name` = `t2`.`constraint_name` WHERE (`t2`.`constraint_type` = 'FOREIGN KEY') AND (`t2`.`table_schema` = (SELECT DATABASE())) AND (`t1`.`table_schema` = (SELECT DATABASE())) AND (`t2`.`constraint_name` = ?)")).
WithArgs("parent_id").
WillReturnRows(sqlmock.NewRows([]string{"COLUMN_NAME"}).
AddRow("parent_id"))
// drop the unique index.
mock.ExpectExec(escape("DROP INDEX `old_index` ON `users`")).
WillReturnResult(sqlmock.NewResult(0, 1))
@@ -895,9 +902,13 @@ func TestMySQL_Create(t *testing.T) {
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}).
AddRow("PRIMARY", "id", "0", "1"))
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")).
WithArgs("FOREIGN KEY", "user_spouse_____________________390ed76f91d3c57cd3516e7690f621dc").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` bigint")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")).
WithArgs("FOREIGN KEY", "user_spouse_____________________390ed76f91d3c57cd3516e7690f621dc").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec("ALTER TABLE `users` ADD CONSTRAINT `.{64}` FOREIGN KEY\\(`spouse_id`\\) REFERENCES `users`\\(`id`\\) ON DELETE CASCADE").
WillReturnResult(sqlmock.NewResult(0, 1))

View File

@@ -118,10 +118,9 @@ func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Tabl
}
c.Key = UniqueKey
c.Unique = true
c.indexes.append(idx)
fallthrough
default:
t.AddIndex(idx.Name, idx.Unique, idx.columns)
t.addIndex(idx)
}
}
return t, nil
@@ -172,12 +171,12 @@ func (d *Postgres) indexes(ctx context.Context, tx dialect.Tx, table string) (In
}
// If the index is prefixed with the table, it's probably was
// added by `addIndex` (and not entc) and it should be trimmed.
name = strings.TrimPrefix(name, table+"_")
idx, ok := names[name]
short := strings.TrimPrefix(name, table+"_")
idx, ok := names[short]
if !ok {
idx = &Index{Name: name, Unique: unique, primary: primary}
idx = &Index{Name: short, Unique: unique, primary: primary, realname: name}
idxs = append(idxs, idx)
names[name] = idx
names[short] = idx
}
idx.columns = append(idx.columns, column)
}
@@ -351,3 +350,31 @@ func (d *Postgres) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, tab
}
return tx.Exec(ctx, query, args, nil)
}
// isImplicitIndex reports if the index was created implicitly for the unique column.
func (d *Postgres) isImplicitIndex(idx *Index, col *Column) bool {
return strings.TrimSuffix(idx.Name, "_key") == col.Name && col.Unique
}
// renameColumn returns the statement for renaming a column.
func (d *Postgres) renameColumn(t *Table, old, new *Column) sql.Querier {
return sql.Dialect(dialect.Postgres).
AlterTable(t.Name).
RenameColumn(old.Name, new.Name)
}
// renameIndex returns the statement for renaming an index.
func (d *Postgres) renameIndex(t *Table, old, new *Index) sql.Querier {
if sfx := "_key"; strings.HasSuffix(old.Name, sfx) && !strings.HasSuffix(new.Name, sfx) {
new.Name += sfx
}
if pfx := t.Name + "_"; strings.HasPrefix(old.realname, pfx) && !strings.HasPrefix(new.Name, pfx) {
new.Name = pfx + new.Name
}
return sql.Dialect(dialect.Postgres).AlterIndex(old.realname).Rename(new.Name)
}
// tableSchema returns the query for getting the table schema.
func (d *Postgres) tableSchema() sql.Querier {
return sql.Raw("(CURRENT_SCHEMA())")
}

View File

@@ -564,6 +564,11 @@ func TestPostgres_Create(t *testing.T) {
mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))).
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}).
AddRow("users_pkey", "id", "t", "t", 0))
mock.ExpectQuery(escape(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE "table_schema" = CURRENT_SCHEMA() AND "constraint_type" = $1 AND "constraint_name" = $2`)).
WithArgs("FOREIGN KEY", "user_spouse____________________390ed76f91d3c57cd3516e7690f621dc").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "spouse_id" bigint NULL`)).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectQuery(escape(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE "table_schema" = CURRENT_SCHEMA() AND "constraint_type" = $1 AND "constraint_name" = $2`)).

View File

@@ -65,16 +65,20 @@ func (t *Table) AddColumn(c *Column) *Table {
// AddIndex creates and adds a new index to the table from the given options.
func (t *Table) AddIndex(name string, unique bool, columns []string) *Table {
idx := &Index{
return t.addIndex(&Index{
Name: name,
Unique: unique,
columns: columns,
Columns: make([]*Column, 0, len(columns)),
}
for _, name := range columns {
})
}
// AddIndex creates and adds a new index to the table from the given options.
func (t *Table) addIndex(idx *Index) *Table {
for _, name := range idx.columns {
c, ok := t.columns[name]
if ok {
c.indexes = append(c.indexes, idx)
c.indexes.append(idx)
idx.Columns = append(idx.Columns, c)
}
}
@@ -107,11 +111,13 @@ func (t *Table) index(name string) (*Index, bool) {
}
}
// If it is an "implicit index" (unique constraint on
// table creation) and it didn't load on table scanning.
// table creation) and it wasn't loaded in table scanning.
c, ok := t.column(name)
if !ok {
// Postgres naming convention for unique constraint.
c, ok = t.column(strings.TrimSuffix(name, "_key"))
// Postgres naming convention for unique constraint (<table>_<column>_key).
name = strings.TrimPrefix(name, t.Name+"_")
name = strings.TrimSuffix(name, "_key")
c, ok = t.column(name)
}
if ok && c.Unique {
return &Index{Name: name, Unique: c.Unique, Columns: []*Column{c}, columns: []string{c.Name}}, true
@@ -142,8 +148,8 @@ type Column struct {
Increment bool // auto increment attribute.
Nullable bool // null or not null attribute.
Default interface{} // default value.
indexes Indexes // linked indexes.
Enums []string // enum values.
indexes Indexes // linked indexes.
}
// UniqueKey returns boolean indicates if this column is a unique key.
@@ -338,11 +344,12 @@ func (r ReferenceOption) ConstName() string {
// Index definition for table index.
type Index struct {
Name string // index name.
Unique bool // uniqueness.
Columns []*Column // actual table columns.
columns []string // columns loaded from query scan.
primary bool // primary key index.
Name string // index name.
Unique bool // uniqueness.
Columns []*Column // actual table columns.
columns []string // columns loaded from query scan.
primary bool // primary key index.
realname string // real name in the database (Postgres only).
}
// Builder returns the query builder for index creation. The DSL is identical in all dialects.
@@ -363,6 +370,20 @@ func (i *Index) DropBuilder(table string) *sql.DropIndexBuilder {
return idx
}
// sameAs reports if the index has the same properties
// as the given index (except the name).
func (i *Index) sameAs(idx *Index) bool {
if i.Unique != idx.Unique || len(i.Columns) != len(idx.Columns) {
return false
}
for j, c := range i.Columns {
if c.Name != idx.Columns[j].Name {
return false
}
}
return true
}
// Indexes used for scanning all sql.Rows into a list of indexes, because
// multiple sql rows can represent the same index (multi-columns indexes).
type Indexes []*Index