dialect/sql/schema: support alternate schema for drivers (#1167)

This commit is contained in:
Ariel Mashraki
2021-01-13 14:21:03 +02:00
committed by GitHub
parent d4d10d3977
commit 601a4ee50d

View File

@@ -20,6 +20,7 @@ import (
// MySQL is a MySQL migration driver.
type MySQL struct {
dialect.Driver
schema string
version string
}
@@ -47,7 +48,7 @@ func (d *MySQL) init(ctx context.Context, tx dialect.Tx) error {
func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
Where(sql.And(
sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")),
sql.EQ("TABLE_SCHEMA", d.tableSchema()),
sql.EQ("TABLE_NAME", name),
)).Query()
return exist(ctx, tx, query, args...)
@@ -56,7 +57,7 @@ func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (boo
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLE_CONSTRAINTS").Schema("INFORMATION_SCHEMA")).
Where(sql.And(
sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")),
sql.EQ("TABLE_SCHEMA", d.tableSchema()),
sql.EQ("CONSTRAINT_TYPE", "FOREIGN KEY"),
sql.EQ("CONSTRAINT_NAME", name),
)).Query()
@@ -69,7 +70,7 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
query, args := sql.Select("column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name").
From(sql.Table("COLUMNS").Schema("INFORMATION_SCHEMA")).
Where(sql.And(
sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")),
sql.EQ("TABLE_SCHEMA", d.tableSchema()),
sql.EQ("TABLE_NAME", name)),
).Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
@@ -116,7 +117,7 @@ func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, name string) ([]*Ind
query, args := sql.Select("index_name", "column_name", "non_unique", "seq_in_index").
From(sql.Table("STATISTICS").Schema("INFORMATION_SCHEMA")).
Where(sql.And(
sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")),
sql.EQ("TABLE_SCHEMA", d.tableSchema()),
sql.EQ("TABLE_NAME", name),
)).
OrderBy("index_name", "seq_in_index").
@@ -144,7 +145,7 @@ func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, t *Table, expect
query, args := sql.Select("AUTO_INCREMENT").
From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
Where(sql.And(
sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")),
sql.EQ("TABLE_SCHEMA", d.tableSchema()),
sql.EQ("TABLE_NAME", t.Name),
)).
Query()
@@ -319,7 +320,7 @@ func (d *MySQL) prepare(ctx context.Context, tx dialect.Tx, change *changes, tab
// If both the index and the column need to be dropped, the foreign-key
// constraint that is associated with them need to be dropped as well.
case ok:
names, err := fkNames(ctx, tx, table, col.Name)
names, err := d.fkNames(ctx, tx, table, col.Name)
if err != nil {
return err
}
@@ -335,7 +336,7 @@ func (d *MySQL) prepare(ctx context.Context, tx dialect.Tx, change *changes, tab
break Switch
}
}
names, err := fkNames(ctx, tx, table, col.Name)
names, err := d.fkNames(ctx, tx, table, col.Name)
if err != nil {
return err
}
@@ -513,6 +514,9 @@ func (d *MySQL) renameIndex(t *Table, old, new *Index) sql.Querier {
// tableSchema returns the query for getting the table schema.
func (d *MySQL) tableSchema() sql.Querier {
if d.schema != "" {
return sql.Expr("?", d.schema)
}
return sql.Raw("(SELECT DATABASE())")
}
@@ -553,7 +557,7 @@ func (d *MySQL) normalizeJSON(ctx context.Context, tx dialect.Tx, t *Table) erro
query, args := sql.Select("CONSTRAINT_NAME", "CHECK_CLAUSE").
From(sql.Table("CHECK_CONSTRAINTS").Schema("INFORMATION_SCHEMA")).
Where(sql.And(
sql.EQ("CONSTRAINT_SCHEMA", sql.Raw("(SELECT DATABASE())")),
sql.EQ("CONSTRAINT_SCHEMA", d.tableSchema()),
sql.EQ("TABLE_NAME", t.Name),
sql.InValues("CONSTRAINT_NAME", names...),
)).
@@ -614,14 +618,14 @@ func parseColumn(typ string) (parts []string, size int64, unsigned bool, err err
}
// fkNames returns the foreign-key names of a column.
func fkNames(ctx context.Context, tx dialect.Tx, table, column string) ([]string, error) {
func (d *MySQL) fkNames(ctx context.Context, tx dialect.Tx, table, column string) ([]string, error) {
query, args := sql.Select("CONSTRAINT_NAME").From(sql.Table("KEY_COLUMN_USAGE").Schema("INFORMATION_SCHEMA")).
Where(sql.And(
sql.EQ("TABLE_NAME", table),
sql.EQ("COLUMN_NAME", column),
// NULL for unique and primary-key constraints.
sql.NotNull("POSITION_IN_UNIQUE_CONSTRAINT"),
sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")),
sql.EQ("TABLE_SCHEMA", d.tableSchema()),
)).
Query()
var (