mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
imporve sql migration (#3)
Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/3 add an append-only mode to the migration Reviewed By: alexsn Differential Revision: D15845370 fbshipit-source-id: f22ae1866d4bb9250bf2d1c6cba476d574a3f45d
This commit is contained in:
committed by
Facebook Github Bot
parent
267e3c15bd
commit
4f31aa6cfe
@@ -274,7 +274,7 @@ func AlterTable(name string) *TableAlter { return &TableAlter{b: Builder{}, name
|
||||
|
||||
// AddColumn appends the `ADD COLUMN` clause to the given `ALTER TABLE` statement.
|
||||
func (t *TableAlter) AddColumn(c *ColumnBuilder) *TableAlter {
|
||||
t.nodes = append(t.nodes, &Wrapper{"ADD %s", c})
|
||||
t.nodes = append(t.nodes, &Wrapper{"ADD COLUMN %s", c})
|
||||
return t
|
||||
}
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ func TestBuilder(t *testing.T) {
|
||||
Reference(Reference().Table("groups").Columns("id")).
|
||||
OnDelete("CASCADE"),
|
||||
),
|
||||
wantQuery: "ALTER TABLE `users` ADD `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`) ON DELETE CASCADE",
|
||||
wantQuery: "ALTER TABLE `users` ADD COLUMN `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`) ON DELETE CASCADE",
|
||||
},
|
||||
{
|
||||
input: AlterTable("users").
|
||||
@@ -58,13 +58,13 @@ func TestBuilder(t *testing.T) {
|
||||
AddForeignKey(ForeignKey().Columns("group_id").
|
||||
Reference(Reference().Table("groups").Columns("id")),
|
||||
),
|
||||
wantQuery: "ALTER TABLE `users` ADD `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`)",
|
||||
wantQuery: "ALTER TABLE `users` ADD COLUMN `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`)",
|
||||
},
|
||||
{
|
||||
input: AlterTable("users").
|
||||
AddColumn(Column("age").Type("int")).
|
||||
AddColumn(Column("name").Type("varchar(255)")),
|
||||
wantQuery: "ALTER TABLE `users` ADD `age` int, ADD `name` varchar(255)",
|
||||
wantQuery: "ALTER TABLE `users` ADD COLUMN `age` int, ADD COLUMN `name` varchar(255)",
|
||||
},
|
||||
{
|
||||
input: AlterTable("users").
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"fbc/ent/dialect"
|
||||
"fbc/ent/dialect/sql"
|
||||
@@ -14,26 +15,48 @@ type MySQL struct {
|
||||
dialect.Driver
|
||||
}
|
||||
|
||||
// Create creates all tables resources in the database.
|
||||
// Create creates all schema resources in the database. It works in an "append-only"
|
||||
// mode, which means, it won't delete or change any existing resource in the database.
|
||||
func (d *MySQL) Create(ctx context.Context, tables ...*Table) error {
|
||||
tx, err := d.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, t := range tables {
|
||||
exist, err := d.tableExist(ctx, tx, t.Name)
|
||||
if err != nil {
|
||||
switch exist, err := d.tableExist(ctx, tx, t.Name); {
|
||||
case err != nil:
|
||||
return rollback(tx, err)
|
||||
}
|
||||
if exist {
|
||||
continue
|
||||
}
|
||||
query, args := t.DSL().Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return rollback(tx, fmt.Errorf("sql/mysql: create table %q: %v", t.Name, err))
|
||||
case exist:
|
||||
curr, err := d.table(ctx, tx, t.Name)
|
||||
if err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
changes, err := changeSet(curr, t)
|
||||
if err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
if len(changes.Columns) > 0 {
|
||||
b := sql.AlterTable(curr.Name)
|
||||
for _, c := range changes.Columns {
|
||||
b.AddColumn(c.DSL())
|
||||
}
|
||||
query, args := b.Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return rollback(tx, fmt.Errorf("sql/mysql: alter table %q: %v", t.Name, err))
|
||||
}
|
||||
}
|
||||
if len(changes.Indexes) > 0 {
|
||||
panic("missing implementation")
|
||||
}
|
||||
default: // !exist
|
||||
query, args := t.DSL().Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return rollback(tx, fmt.Errorf("sql/mysql: create table %q: %v", t.Name, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
// create foreign keys after table was created, because circular foreign-key constraints are possible.
|
||||
// create foreign keys after tables were created/altered,
|
||||
// because circular foreign-key constraints are possible.
|
||||
for _, t := range tables {
|
||||
if len(t.ForeignKeys) == 0 {
|
||||
continue
|
||||
@@ -98,6 +121,65 @@ func (d *MySQL) exist(ctx context.Context, tx dialect.Tx, query string, args ...
|
||||
return n > 0, nil
|
||||
}
|
||||
|
||||
// table loads the current table description from the database.
|
||||
func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, "DESCRIBE "+name, []interface{}{}, rows); err != nil {
|
||||
return nil, fmt.Errorf("dialect/mysql: reading table description %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
t := &Table{Name: name}
|
||||
for rows.Next() {
|
||||
c := &Column{}
|
||||
if err := c.ScanMySQL(rows); err != nil {
|
||||
return nil, fmt.Errorf("dialect/mysql: %v", err)
|
||||
}
|
||||
if c.PrimaryKey() {
|
||||
t.PrimaryKey = append(t.PrimaryKey, c)
|
||||
}
|
||||
t.Columns = append(t.Columns, c)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// changeSet returns a dummy table represents the change set that need
|
||||
// to be applied on the table. it fails if one of the changes is invalid.
|
||||
func changeSet(curr, new *Table) (*Table, error) {
|
||||
changes := &Table{}
|
||||
// pks.
|
||||
if len(curr.PrimaryKey) != len(new.PrimaryKey) {
|
||||
return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
|
||||
}
|
||||
sort.Slice(new.PrimaryKey, func(i, j int) bool { return new.PrimaryKey[i].Name < new.PrimaryKey[j].Name })
|
||||
sort.Slice(curr.PrimaryKey, func(i, j int) bool { return curr.PrimaryKey[i].Name < curr.PrimaryKey[j].Name })
|
||||
for i := range curr.PrimaryKey {
|
||||
if curr.PrimaryKey[i].Name != new.PrimaryKey[i].Name {
|
||||
return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
|
||||
}
|
||||
}
|
||||
// columns.
|
||||
for _, c1 := range new.Columns {
|
||||
switch c2, ok := curr.column(c1.Name); {
|
||||
case !ok:
|
||||
changes.Columns = append(changes.Columns, c1)
|
||||
case c1.Type != c2.Type:
|
||||
return nil, fmt.Errorf("changing column type for %q is invalid", c1.Name)
|
||||
case c1.Unique != c2.Unique:
|
||||
return nil, fmt.Errorf("changing column cardinality for %q is invalid", c1.Name)
|
||||
}
|
||||
}
|
||||
// indexes.
|
||||
for _, idx1 := range new.Indexes {
|
||||
switch idx2, ok := curr.index(idx1.Name); {
|
||||
case !ok:
|
||||
changes.Indexes = append(changes.Indexes, idx1)
|
||||
case idx1.Unique != idx2.Unique:
|
||||
return nil, fmt.Errorf("changing index %q uniqness is invalid", idx1.Name)
|
||||
}
|
||||
}
|
||||
return changes, nil
|
||||
}
|
||||
|
||||
// symbol makes sure the symbol length is not longer than the maxlength in MySQL standard (64).
|
||||
func symbol(name string) string {
|
||||
if len(name) <= 64 {
|
||||
|
||||
@@ -2,6 +2,7 @@ package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"fbc/ent/dialect/sql"
|
||||
@@ -70,17 +71,40 @@ func (t *Table) SQLite() *sql.TableBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// column returns a table column by its name.
|
||||
// faster than map lookup for most cases.
|
||||
func (t *Table) column(name string) (*Column, bool) {
|
||||
for _, c := range t.Columns {
|
||||
if c.Name == name {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// index returns a table index by its name.
|
||||
// faster than map lookup for most cases.
|
||||
func (t *Table) index(name string) (*Index, bool) {
|
||||
for _, idx := range t.Indexes {
|
||||
if idx.Name == name {
|
||||
return idx, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Column schema definition for SQL dialects.
|
||||
type Column struct {
|
||||
Name string // column name.
|
||||
Type field.Type // column type.
|
||||
typ string // row column type (used for Rows.Scan).
|
||||
Attr string // extra attributes.
|
||||
Default string // default value.
|
||||
Nullable *bool // null or not null attribute.
|
||||
Size int // max size parameter for string, blob, etc.
|
||||
Key string // key definition (PRI, UNI or MUL).
|
||||
Unique bool // column with unique constraint.
|
||||
Increment bool // auto increment attribute.
|
||||
Nullable *bool // null or not null attribute.
|
||||
Default string // default value.
|
||||
}
|
||||
|
||||
// UniqueKey returns boolean indicates if this column is a unique key.
|
||||
@@ -182,6 +206,52 @@ func (c *Column) SQLiteType() (t string) {
|
||||
return t
|
||||
}
|
||||
|
||||
// ScanMySQL scans the information from MySQL column description.
|
||||
func (c *Column) ScanMySQL(rows *sql.Rows) error {
|
||||
var (
|
||||
nullable sql.NullString
|
||||
defaults sql.NullString
|
||||
)
|
||||
if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr); err != nil {
|
||||
return fmt.Errorf("scanning column description: %v", err)
|
||||
}
|
||||
c.Unique = c.UniqueKey()
|
||||
c.Default = defaults.String
|
||||
if nullable.Valid {
|
||||
null := nullable.String == "YES"
|
||||
c.Nullable = &null
|
||||
}
|
||||
switch parts := strings.FieldsFunc(c.typ, func(r rune) bool {
|
||||
return r == '(' || r == ')' || r == ' '
|
||||
}); parts[0] {
|
||||
case "int":
|
||||
c.Type = field.TypeInt
|
||||
case "timestamp":
|
||||
c.Type = field.TypeTime
|
||||
case "tinyint":
|
||||
size, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting varchar size to int: %v", err)
|
||||
}
|
||||
switch {
|
||||
case size == 1:
|
||||
c.Type = field.TypeBool
|
||||
case len(parts) == 3: // tinyint(3) unsigned.
|
||||
c.Type = field.TypeUint8
|
||||
default:
|
||||
c.Type = field.TypeInt8
|
||||
}
|
||||
case "varchar":
|
||||
c.Type = field.TypeString
|
||||
size, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting varchar size to int: %v", err)
|
||||
}
|
||||
c.Size = size
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// unique adds the `UNIQUE` attribute if the column is a unique type.
|
||||
// it is exist in a different function to share the common declaration
|
||||
// between the two dialects.
|
||||
@@ -257,10 +327,11 @@ func (r ReferenceOption) ConstName() string {
|
||||
|
||||
// Index definition for table index.
|
||||
type Index struct {
|
||||
Key string // key name.
|
||||
Column string // column name.
|
||||
Name string
|
||||
Unique bool
|
||||
Columns []*Column
|
||||
}
|
||||
|
||||
// Primary indicates if this index is a primary key.
|
||||
// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects.
|
||||
func (i *Index) Primary() bool { return i.Key == "PRIMARY" }
|
||||
func (i *Index) Primary() bool { return i.Name == "PRIMARY" }
|
||||
|
||||
Reference in New Issue
Block a user