// Mirror of https://github.com/ent/ent.git (synced 2026-05-22 09:31:45 +03:00).
// Summary: Adds support for Hebrew characters.
// Reviewed By: alexsn
// Differential Revision: D16068537
// fbshipit-source-id: 4e934da5ea97c9e804317f746556ab1d51faebcc
// 234 lines, 7.3 KiB, Go.
package schema
|
|
|
|
import (
	"context"
	"crypto/md5"
	"fmt"
	"sort"
	"unicode/utf8"

	"fbc/ent/dialect"
	"fbc/ent/dialect/sql"
)
|
|
|
|
// MySQL is a mysql migration driver.
|
|
type MySQL struct {
|
|
dialect.Driver
|
|
}
|
|
|
|
// Create migrates the database towards the given tables inside a single
// transaction. It only adds resources: missing tables, columns and foreign
// keys are created, and existing columns whose charset or collation differ
// are re-declared. Nothing is dropped.
//
// On failure the transaction is rolled back and the error is prefixed with
// "dialect/mysql: ". On success the transaction is committed.
func (d *MySQL) Create(ctx context.Context, tables ...*Table) error {
	tx, err := d.Tx(ctx)
	if err != nil {
		return err
	}
	if err = d.create(ctx, tx, tables...); err != nil {
		wrapped := fmt.Errorf("dialect/mysql: %v", err)
		return rollback(tx, wrapped)
	}
	return tx.Commit()
}
|
|
|
|
// create runs the migration statements inside tx. It works in two passes:
//
//  1. Every table that does not exist is created. Every table that does exist
//     is diffed with changeSet, and new or modified columns are applied with a
//     single ALTER TABLE statement.
//  2. Missing foreign keys are added to all tables. This happens after all
//     tables exist, because foreign-key constraints may be circular.
//
// Index migration for existing tables is not implemented. It is reported as
// an error rather than a panic, so the caller (Create) can roll back the
// transaction instead of leaving it open.
func (d *MySQL) create(ctx context.Context, tx dialect.Tx, tables ...*Table) error {
	version, err := d.version(ctx, tx)
	if err != nil {
		return err
	}
	for _, t := range tables {
		switch exist, err := d.tableExist(ctx, tx, t.Name); {
		case err != nil:
			return err
		case exist:
			curr, err := d.table(ctx, tx, t.Name)
			if err != nil {
				return err
			}
			change, err := changeSet(curr, t)
			if err != nil {
				return err
			}
			// Fail before issuing any DDL for this table. MySQL commits DDL
			// statements implicitly, so a rollback could not undo the ALTER below.
			// NOTE(review): d.table does not load indexes, so every index declared
			// on an existing table currently lands in change.indexes. Confirm intent.
			if len(change.indexes) > 0 {
				return fmt.Errorf("alter table %q: adding or changing indexes is not supported", t.Name)
			}
			if len(change.add) != 0 || len(change.modify) != 0 {
				b := sql.AlterTable(curr.Name)
				for _, c := range change.add {
					b.AddColumn(c.MySQL(version))
				}
				for _, c := range change.modify {
					b.ModifyColumn(c.MySQL(version))
				}
				query, args := b.Query()
				if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
					return fmt.Errorf("alter table %q: %v", t.Name, err)
				}
			}
		default: // !exist
			query, args := t.MySQL(version).Query()
			if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
				return fmt.Errorf("create table %q: %v", t.Name, err)
			}
		}
	}
	// create foreign keys after tables were created/altered,
	// because circular foreign-key constraints are possible.
	for _, t := range tables {
		if len(t.ForeignKeys) == 0 {
			continue
		}
		fks := make([]*ForeignKey, 0, len(t.ForeignKeys))
		for _, fk := range t.ForeignKeys {
			// Shorten over-long constraint names before looking them up, so the
			// existence check uses the same name that would be created.
			fk.Symbol = symbol(fk.Symbol)
			exist, err := d.fkExist(ctx, tx, fk.Symbol)
			if err != nil {
				return err
			}
			if !exist {
				fks = append(fks, fk)
			}
		}
		if len(fks) == 0 {
			continue
		}
		b := sql.AlterTable(t.Name)
		for _, fk := range fks {
			b.AddForeignKey(fk.DSL())
		}
		query, args := b.Query()
		if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
			return fmt.Errorf("create foreign keys for %q: %v", t.Name, err)
		}
	}
	return nil
}
|
|
|
|
func (d *MySQL) version(ctx context.Context, tx dialect.Tx) (string, error) {
|
|
rows := &sql.Rows{}
|
|
if err := tx.Query(ctx, "SHOW VARIABLES LIKE 'version'", []interface{}{}, rows); err != nil {
|
|
return "", fmt.Errorf("dialect/mysql: querying mysql version %v", err)
|
|
}
|
|
defer rows.Close()
|
|
if !rows.Next() {
|
|
return "", fmt.Errorf("dialect/mysql: version variable was not found")
|
|
}
|
|
version := make([]string, 2)
|
|
if err := rows.Scan(&version[0], &version[1]); err != nil {
|
|
return "", fmt.Errorf("dialect/mysql: scanning mysql version: %v", err)
|
|
}
|
|
return version[1], nil
|
|
}
|
|
|
|
// tableExist reports whether a table with the given name exists in the
// database selected by the current connection (as returned by DATABASE()).
func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
	source := sql.Table("INFORMATION_SCHEMA.TABLES").Unquote()
	inCurrentDB := sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())"))
	query, args := sql.Select(sql.Count("*")).
		From(source).
		Where(inCurrentDB.And().EQ("TABLE_NAME", name)).
		Query()
	return d.exist(ctx, tx, query, args...)
}
|
|
|
|
// fkExist reports whether a FOREIGN KEY constraint with the given name exists
// in the database selected by the current connection (as returned by DATABASE()).
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
	source := sql.Table("INFORMATION_SCHEMA.TABLE_CONSTRAINTS").Unquote()
	inCurrentDB := sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())"))
	predicate := inCurrentDB.
		And().EQ("CONSTRAINT_TYPE", "FOREIGN KEY").
		And().EQ("CONSTRAINT_NAME", name)
	query, args := sql.Select(sql.Count("*")).
		From(source).
		Where(predicate).
		Query()
	return d.exist(ctx, tx, query, args...)
}
|
|
|
|
// exist runs a COUNT query, such as the ones built by tableExist and fkExist,
// and reports whether the count in the first column of the first row is
// greater than zero.
//
// It returns an error if the query fails, returns no rows, or the first
// column cannot be scanned into an int. The scan error is now included in the
// message; previously it was dropped.
func (d *MySQL) exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) {
	rows := &sql.Rows{}
	if err := tx.Query(ctx, query, args, rows); err != nil {
		return false, fmt.Errorf("dialect/mysql: reading schema information %v", err)
	}
	defer rows.Close()
	if !rows.Next() {
		return false, fmt.Errorf("dialect/mysql: no rows returned")
	}
	var n int
	if err := rows.Scan(&n); err != nil {
		return false, fmt.Errorf("dialect/mysql: scanning count: %v", err)
	}
	return n > 0, nil
}
|
|
|
|
// table reads the description of the named table in the current database
// from INFORMATION_SCHEMA.COLUMNS. Each row is decoded with Column.ScanMySQL.
// Only Name, Columns and PrimaryKey are populated on the result. Indexes and
// foreign keys are not loaded.
func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
	rows := &sql.Rows{}
	query, args := sql.Select("column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name").
		From(sql.Table("INFORMATION_SCHEMA.COLUMNS").Unquote()).
		Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).
		Query()
	if err := tx.Query(ctx, query, args, rows); err != nil {
		return nil, fmt.Errorf("dialect/mysql: reading table description %v", err)
	}
	defer rows.Close()
	desc := &Table{Name: name}
	for rows.Next() {
		col := &Column{}
		if err := col.ScanMySQL(rows); err != nil {
			return nil, fmt.Errorf("dialect/mysql: %v", err)
		}
		// Primary-key columns are tracked both in PrimaryKey and in Columns.
		if col.PrimaryKey() {
			desc.PrimaryKey = append(desc.PrimaryKey, col)
		}
		desc.Columns = append(desc.Columns, col)
	}
	return desc, nil
}
|
|
|
|
// changes to apply on existing table.
|
|
type changes struct {
|
|
add []*Column
|
|
modify []*Column
|
|
indexes []*Index
|
|
}
|
|
|
|
// changeSet compares curr, the table as loaded from the database, with new,
// the desired schema table. It returns the changes needed to bring curr in
// line with new. Neither table is modified.
//
// It returns an error for changes that cannot be applied in append-only mode:
// a different set of primary-key columns, a column type change, a column
// uniqueness change, or an index uniqueness change.
//
// A column is scheduled for modification when new declares a charset or a
// collation that differs from the one in the database.
func changeSet(curr, new *Table) (*changes, error) {
	change := &changes{}
	// pks.
	if len(curr.PrimaryKey) != len(new.PrimaryKey) {
		return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
	}
	// Compare primary-key column names regardless of order. Sort copies of the
	// names, so the callers' PrimaryKey slices are not reordered as a side effect.
	currPK := make([]string, len(curr.PrimaryKey))
	newPK := make([]string, len(new.PrimaryKey))
	for i := range curr.PrimaryKey {
		currPK[i] = curr.PrimaryKey[i].Name
		newPK[i] = new.PrimaryKey[i].Name
	}
	sort.Strings(currPK)
	sort.Strings(newPK)
	for i := range currPK {
		if currPK[i] != newPK[i] {
			return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
		}
	}
	// columns.
	for _, c1 := range new.Columns {
		switch c2, ok := curr.column(c1.Name); {
		case !ok:
			change.add = append(change.add, c1)
		case c1.Type != c2.Type:
			return nil, fmt.Errorf("changing column type for %q is invalid (%s != %s)", c1.Name, c1.Type, c2.Type)
		case c1.Unique != c2.Unique:
			return nil, fmt.Errorf("changing column cardinality for %q is invalid", c1.Name)
		// An empty Charset/Collation on the desired column means "not specified",
		// which is never treated as a change.
		case c1.Charset != "" && c1.Charset != c2.Charset || c1.Collation != "" && c1.Collation != c2.Collation:
			change.modify = append(change.modify, c1)
		}
	}
	// indexes.
	for _, idx1 := range new.Indexes {
		switch idx2, ok := curr.index(idx1.Name); {
		case !ok:
			change.indexes = append(change.indexes, idx1)
		case idx1.Unique != idx2.Unique:
			return nil, fmt.Errorf("changing index %q uniqness is invalid", idx1.Name)
		}
	}
	return change, nil
}
|
|
|
|
// symbol shortens name so it fits MySQL's 64-character identifier limit.
// The check counts bytes, which is stricter than counting characters.
//
// Names of at most 64 bytes are returned unchanged. Longer names are replaced
// by a prefix of at most 31 bytes, then "_", then the 32-character hex MD5 of
// the full name. The result is at most 64 bytes, and distinct long names stay
// distinct. The prefix is cut on a UTF-8 rune boundary, so multi-byte
// characters (e.g. Hebrew) are never split into an invalid byte sequence.
func symbol(name string) string {
	if len(name) <= 64 {
		return name
	}
	// name[cut] is the first byte after the prefix. If it is a UTF-8
	// continuation byte, the prefix would end mid-rune, so back off.
	cut := 31
	for cut > 0 && !utf8.RuneStart(name[cut]) {
		cut--
	}
	return fmt.Sprintf("%s_%x", name[:cut], md5.Sum([]byte(name)))
}
|
|
|
|
// rollback rolls back tx and returns err. If the rollback itself fails, the
// returned error has the form "<err>: <rollback error>".
func rollback(tx dialect.Tx, err error) error {
	rerr := tx.Rollback()
	if rerr == nil {
		return err
	}
	return fmt.Errorf("%s: %v", err.Error(), rerr)
}
|