mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/schema: schema match from predicate
This commit is contained in:
committed by
Ariel Mashraki
parent
7e904f0e1c
commit
cc8da8fbf7
@@ -545,8 +545,8 @@ func (m *Migrate) fkColumn(ctx context.Context, tx dialect.Tx, fk *ForeignKey) (
|
||||
On(t1.C("constraint_name"), t2.C("constraint_name")).
|
||||
Where(sql.And(
|
||||
sql.EQ(t2.C("constraint_type"), sql.Raw("'FOREIGN KEY'")),
|
||||
sql.EQ(t2.C("table_schema"), m.sqlDialect.(fkRenamer).tableSchema()),
|
||||
sql.EQ(t1.C("table_schema"), m.sqlDialect.(fkRenamer).tableSchema()),
|
||||
m.sqlDialect.(fkRenamer).matchSchema(t2.C("table_schema")),
|
||||
m.sqlDialect.(fkRenamer).matchSchema(t1.C("table_schema")),
|
||||
sql.EQ(t2.C("constraint_name"), fk.Symbol),
|
||||
)).
|
||||
Query()
|
||||
@@ -656,7 +656,7 @@ type preparer interface {
|
||||
// fkRenamer is used by the fixture migration (to solve #285),
|
||||
// and it's implemented by the different dialects for renaming FKs.
|
||||
type fkRenamer interface {
|
||||
tableSchema() sql.Querier
|
||||
matchSchema(...string) *sql.Predicate
|
||||
isImplicitIndex(*Index, *Column) bool
|
||||
renameIndex(*Table, *Index, *Index) sql.Querier
|
||||
renameColumn(*Table, *Column, *Column) sql.Querier
|
||||
|
||||
@@ -48,7 +48,7 @@ func (d *MySQL) init(ctx context.Context, tx dialect.Tx) error {
|
||||
func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.And(
|
||||
sql.EQ("TABLE_SCHEMA", d.tableSchema()),
|
||||
d.matchSchema(),
|
||||
sql.EQ("TABLE_NAME", name),
|
||||
)).Query()
|
||||
return exist(ctx, tx, query, args...)
|
||||
@@ -57,7 +57,7 @@ func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (boo
|
||||
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLE_CONSTRAINTS").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.And(
|
||||
sql.EQ("TABLE_SCHEMA", d.tableSchema()),
|
||||
d.matchSchema(),
|
||||
sql.EQ("CONSTRAINT_TYPE", "FOREIGN KEY"),
|
||||
sql.EQ("CONSTRAINT_NAME", name),
|
||||
)).Query()
|
||||
@@ -70,7 +70,7 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
|
||||
query, args := sql.Select("column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name").
|
||||
From(sql.Table("COLUMNS").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.And(
|
||||
sql.EQ("TABLE_SCHEMA", d.tableSchema()),
|
||||
d.matchSchema(),
|
||||
sql.EQ("TABLE_NAME", name)),
|
||||
).Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
@@ -117,7 +117,7 @@ func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, name string) ([]*Ind
|
||||
query, args := sql.Select("index_name", "column_name", "non_unique", "seq_in_index").
|
||||
From(sql.Table("STATISTICS").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.And(
|
||||
sql.EQ("TABLE_SCHEMA", d.tableSchema()),
|
||||
d.matchSchema(),
|
||||
sql.EQ("TABLE_NAME", name),
|
||||
)).
|
||||
OrderBy("index_name", "seq_in_index").
|
||||
@@ -145,7 +145,7 @@ func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, t *Table, expect
|
||||
query, args := sql.Select("AUTO_INCREMENT").
|
||||
From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.And(
|
||||
sql.EQ("TABLE_SCHEMA", d.tableSchema()),
|
||||
d.matchSchema(),
|
||||
sql.EQ("TABLE_NAME", t.Name),
|
||||
)).
|
||||
Query()
|
||||
@@ -512,19 +512,23 @@ func (d *MySQL) renameIndex(t *Table, old, new *Index) sql.Querier {
|
||||
return q.DropIndex(old.Name).AddIndex(new.Builder(t.Name))
|
||||
}
|
||||
|
||||
// tableSchema returns the query for getting the table schema.
|
||||
func (d *MySQL) tableSchema() sql.Querier {
|
||||
if d.schema != "" {
|
||||
return sql.Expr("?", d.schema)
|
||||
// matchSchema returns the predicate for matching table schema.
|
||||
func (d *MySQL) matchSchema(columns ...string) *sql.Predicate {
|
||||
column := "TABLE_SCHEMA"
|
||||
if len(columns) > 0 {
|
||||
column = columns[0]
|
||||
}
|
||||
return sql.Raw("(SELECT DATABASE())")
|
||||
if d.schema != "" {
|
||||
return sql.EQ(column, d.schema)
|
||||
}
|
||||
return sql.EQ(column, sql.Raw("(SELECT DATABASE())"))
|
||||
}
|
||||
|
||||
// tables returns the query for getting the in the schema.
|
||||
func (d *MySQL) tables() sql.Querier {
|
||||
return sql.Select("TABLE_NAME").
|
||||
From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.EQ("TABLE_SCHEMA", d.tableSchema()))
|
||||
Where(d.matchSchema())
|
||||
}
|
||||
|
||||
// alterColumns returns the queries for applying the columns change-set.
|
||||
@@ -564,7 +568,7 @@ func (d *MySQL) normalizeJSON(ctx context.Context, tx dialect.Tx, t *Table) erro
|
||||
query, args := sql.Select("CONSTRAINT_NAME", "CHECK_CLAUSE").
|
||||
From(sql.Table("CHECK_CONSTRAINTS").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.And(
|
||||
sql.EQ("CONSTRAINT_SCHEMA", d.tableSchema()),
|
||||
d.matchSchema("CONSTRAINT_SCHEMA"),
|
||||
sql.EQ("TABLE_NAME", t.Name),
|
||||
sql.InValues("CONSTRAINT_NAME", names...),
|
||||
)).
|
||||
@@ -632,7 +636,7 @@ func (d *MySQL) fkNames(ctx context.Context, tx dialect.Tx, table, column string
|
||||
sql.EQ("COLUMN_NAME", column),
|
||||
// NULL for unique and primary-key constraints.
|
||||
sql.NotNull("POSITION_IN_UNIQUE_CONSTRAINT"),
|
||||
sql.EQ("TABLE_SCHEMA", d.tableSchema()),
|
||||
d.matchSchema(),
|
||||
)).
|
||||
Query()
|
||||
var (
|
||||
|
||||
@@ -54,7 +54,7 @@ func (d *Postgres) tableExist(ctx context.Context, tx dialect.Tx, name string) (
|
||||
query, args := sql.Dialect(dialect.Postgres).
|
||||
Select(sql.Count("*")).From(sql.Table("tables").Schema("information_schema")).
|
||||
Where(sql.And(
|
||||
sql.EQ("table_schema", d.tableSchema()),
|
||||
d.matchSchema(),
|
||||
sql.EQ("table_name", name),
|
||||
)).Query()
|
||||
return exist(ctx, tx, query, args...)
|
||||
@@ -65,7 +65,7 @@ func (d *Postgres) fkExist(ctx context.Context, tx dialect.Tx, name string) (boo
|
||||
query, args := sql.Dialect(dialect.Postgres).
|
||||
Select(sql.Count("*")).From(sql.Table("table_constraints").Schema("information_schema")).
|
||||
Where(sql.And(
|
||||
sql.EQ("table_schema", d.tableSchema()),
|
||||
d.matchSchema(),
|
||||
sql.EQ("constraint_type", "FOREIGN KEY"),
|
||||
sql.EQ("constraint_name", name),
|
||||
)).Query()
|
||||
@@ -91,7 +91,7 @@ func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Tabl
|
||||
Select("column_name", "data_type", "is_nullable", "column_default", "udt_name").
|
||||
From(sql.Table("columns").Schema("information_schema")).
|
||||
Where(sql.And(
|
||||
sql.EQ("table_schema", d.tableSchema()),
|
||||
d.matchSchema(),
|
||||
sql.EQ("table_name", name),
|
||||
)).Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
@@ -173,8 +173,10 @@ ORDER BY index_name, seq_in_index;
|
||||
|
||||
// indexesQuery returns the query (and its placeholders) for getting table indexes.
|
||||
func (d *Postgres) indexesQuery(table string) (string, []interface{}) {
|
||||
expr, args := d.tableSchema().Query()
|
||||
return fmt.Sprintf(indexesQuery, expr, table), args
|
||||
if d.schema != "" {
|
||||
return fmt.Sprintf(indexesQuery, "$1", table), []interface{}{d.schema}
|
||||
}
|
||||
return fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", table), nil
|
||||
}
|
||||
|
||||
func (d *Postgres) indexes(ctx context.Context, tx dialect.Tx, table string) (Indexes, error) {
|
||||
@@ -405,7 +407,7 @@ func (d *Postgres) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, tab
|
||||
query, args := sql.Dialect(dialect.Postgres).
|
||||
Select(sql.Count("*")).From(sql.Table("table_constraints").Schema("information_schema")).
|
||||
Where(sql.And(
|
||||
sql.EQ("table_schema", d.tableSchema()),
|
||||
d.matchSchema(),
|
||||
sql.EQ("constraint_type", "UNIQUE"),
|
||||
sql.EQ("constraint_name", name),
|
||||
)).
|
||||
@@ -444,12 +446,16 @@ func (d *Postgres) renameIndex(t *Table, old, new *Index) sql.Querier {
|
||||
return sql.Dialect(dialect.Postgres).AlterIndex(old.realname).Rename(new.Name)
|
||||
}
|
||||
|
||||
// tableSchema returns the query for getting the table schema.
|
||||
func (d *Postgres) tableSchema() sql.Querier {
|
||||
if d.schema != "" {
|
||||
return sql.Expr("?", d.schema)
|
||||
// matchSchema returns the predicate for matching table schema.
|
||||
func (d *Postgres) matchSchema(columns ...string) *sql.Predicate {
|
||||
column := "table_schema"
|
||||
if len(columns) > 0 {
|
||||
column = columns[0]
|
||||
}
|
||||
return sql.Raw("CURRENT_SCHEMA()")
|
||||
if d.schema != "" {
|
||||
return sql.EQ(column, d.schema)
|
||||
}
|
||||
return sql.EQ(column, sql.Raw("CURRENT_SCHEMA()"))
|
||||
}
|
||||
|
||||
// tables returns the query for getting the in the schema.
|
||||
@@ -457,7 +463,7 @@ func (d *Postgres) tables() sql.Querier {
|
||||
return sql.Dialect(dialect.Postgres).
|
||||
Select("table_name").
|
||||
From(sql.Table("tables").Schema("information_schema")).
|
||||
Where(sql.EQ("table_schema", d.tableSchema()))
|
||||
Where(d.matchSchema())
|
||||
}
|
||||
|
||||
// alterColumns returns the queries for applying the columns change-set.
|
||||
|
||||
Reference in New Issue
Block a user