Files
ent/dialect/sql/schema/mysql.go
Ariel Mashraki 0e605681b2 quote table name on DESCRIBE statement (#7)
Summary:
Pull Request resolved: https://github.com/facebookincubator/ent/pull/7

avoid conflicting with reserved words (e.g. "groups" became reserved word in MySQL 8.x)

Reviewed By: alexsn

Differential Revision: D15957279

fbshipit-source-id: b6944752c7e5f6ec37119b11cc0ab8a0d7d3a7ff
2019-06-23 01:51:28 -07:00

206 lines
5.9 KiB
Go

package schema
import (
"context"
"crypto/md5"
"fmt"
"sort"
"fbc/ent/dialect"
"fbc/ent/dialect/sql"
)
// MySQL is a mysql migration driver.
type MySQL struct {
dialect.Driver
}
// Create creates all schema resources in the database. It works in an "append-only"
// mode, which means, it won't delete or change any existing resource in the database.
func (d *MySQL) Create(ctx context.Context, tables ...*Table) error {
tx, err := d.Tx(ctx)
if err != nil {
return err
}
if err := d.create(ctx, tx, tables...); err != nil {
return rollback(tx, fmt.Errorf("dialect/mysql: %v", err))
}
return tx.Commit()
}
func (d *MySQL) create(ctx context.Context, tx dialect.Tx, tables ...*Table) error {
for _, t := range tables {
switch exist, err := d.tableExist(ctx, tx, t.Name); {
case err != nil:
return err
case exist:
curr, err := d.table(ctx, tx, t.Name)
if err != nil {
return err
}
changes, err := changeSet(curr, t)
if err != nil {
return err
}
if len(changes.Columns) > 0 {
b := sql.AlterTable(curr.Name)
for _, c := range changes.Columns {
b.AddColumn(c.DSL())
}
query, args := b.Query()
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
return fmt.Errorf("alter table %q: %v", t.Name, err)
}
}
if len(changes.Indexes) > 0 {
panic("missing implementation")
}
default: // !exist
query, args := t.DSL().Query()
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
return fmt.Errorf("create table %q: %v", t.Name, err)
}
}
}
// create foreign keys after tables were created/altered,
// because circular foreign-key constraints are possible.
for _, t := range tables {
if len(t.ForeignKeys) == 0 {
continue
}
fks := make([]*ForeignKey, 0, len(t.ForeignKeys))
for _, fk := range t.ForeignKeys {
fk.Symbol = symbol(fk.Symbol)
exist, err := d.fkExist(ctx, tx, fk.Symbol)
if err != nil {
return err
}
if !exist {
fks = append(fks, fk)
}
}
if len(fks) == 0 {
continue
}
b := sql.AlterTable(t.Name)
for _, fk := range fks {
b.AddForeignKey(fk.DSL())
}
query, args := b.Query()
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
return fmt.Errorf("create foreign keys for %q: %v", t.Name, err)
}
}
return nil
}
func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
return d.exist(
ctx,
tx,
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = (SELECT DATABASE()) AND TABLE_NAME = ?",
name,
)
}
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
return d.exist(
ctx,
tx,
`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE TABLE_SCHEMA=(SELECT DATABASE()) AND CONSTRAINT_TYPE="FOREIGN KEY" AND CONSTRAINT_NAME = ?`,
name,
)
}
func (d *MySQL) exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) {
rows := &sql.Rows{}
if err := tx.Query(ctx, query, args, rows); err != nil {
return false, fmt.Errorf("dialect/mysql: reading schema information %v", err)
}
defer rows.Close()
if !rows.Next() {
return false, fmt.Errorf("dialect/mysql: no rows returned")
}
var n int
if err := rows.Scan(&n); err != nil {
return false, fmt.Errorf("dialect/mysql: scanning count")
}
return n > 0, nil
}
// table loads the current table description from the database.
func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
rows := &sql.Rows{}
query, args := sql.Describe(name).Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("dialect/mysql: reading table description %v", err)
}
defer rows.Close()
t := &Table{Name: name}
for rows.Next() {
c := &Column{}
if err := c.ScanMySQL(rows); err != nil {
return nil, fmt.Errorf("dialect/mysql: %v", err)
}
if c.PrimaryKey() {
t.PrimaryKey = append(t.PrimaryKey, c)
}
t.Columns = append(t.Columns, c)
}
return t, nil
}
// changeSet returns a dummy table represents the change set that need
// to be applied on the table. it fails if one of the changes is invalid.
func changeSet(curr, new *Table) (*Table, error) {
changes := &Table{}
// pks.
if len(curr.PrimaryKey) != len(new.PrimaryKey) {
return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
}
sort.Slice(new.PrimaryKey, func(i, j int) bool { return new.PrimaryKey[i].Name < new.PrimaryKey[j].Name })
sort.Slice(curr.PrimaryKey, func(i, j int) bool { return curr.PrimaryKey[i].Name < curr.PrimaryKey[j].Name })
for i := range curr.PrimaryKey {
if curr.PrimaryKey[i].Name != new.PrimaryKey[i].Name {
return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
}
}
// columns.
for _, c1 := range new.Columns {
switch c2, ok := curr.column(c1.Name); {
case !ok:
changes.Columns = append(changes.Columns, c1)
case c1.Type != c2.Type:
return nil, fmt.Errorf("changing column type for %q is invalid (%s != %s)", c1.Name, c1.Type, c2.Type)
case c1.Unique != c2.Unique:
return nil, fmt.Errorf("changing column cardinality for %q is invalid", c1.Name)
}
}
// indexes.
for _, idx1 := range new.Indexes {
switch idx2, ok := curr.index(idx1.Name); {
case !ok:
changes.Indexes = append(changes.Indexes, idx1)
case idx1.Unique != idx2.Unique:
return nil, fmt.Errorf("changing index %q uniqness is invalid", idx1.Name)
}
}
return changes, nil
}
// symbol makes sure the symbol length is not longer than the maxlength in MySQL standard (64).
func symbol(name string) string {
if len(name) <= 64 {
return name
}
return fmt.Sprintf("%s_%x", name[:31], md5.Sum([]byte(name)))
}
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
func rollback(tx dialect.Tx, err error) error {
if rerr := tx.Rollback(); rerr != nil {
err = fmt.Errorf("%s: %v", err.Error(), rerr)
}
return err
}