From 4f31aa6cfecb5343115939b3524c3069025b3c0e Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Mon, 17 Jun 2019 04:09:15 -0700 Subject: [PATCH] imporve sql migration (#3) Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/3 add an append-only mode to the migration Reviewed By: alexsn Differential Revision: D15845370 fbshipit-source-id: f22ae1866d4bb9250bf2d1c6cba476d574a3f45d --- dialect/sql/builder.go | 2 +- dialect/sql/builder_test.go | 6 +- dialect/sql/schema/mysql.go | 104 +++++++++++++++++++++++++--- dialect/sql/schema/schema.go | 81 ++++++++++++++++++++-- entc/integration/ent/schema/user.go | 4 +- 5 files changed, 174 insertions(+), 23 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 1e1c78ea2..b6296d384 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -274,7 +274,7 @@ func AlterTable(name string) *TableAlter { return &TableAlter{b: Builder{}, name // AddColumn appends the `ADD COLUMN` clause to the given `ALTER TABLE` statement. func (t *TableAlter) AddColumn(c *ColumnBuilder) *TableAlter { - t.nodes = append(t.nodes, &Wrapper{"ADD %s", c}) + t.nodes = append(t.nodes, &Wrapper{"ADD COLUMN %s", c}) return t } diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index cbd803911..ef02be266 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -50,7 +50,7 @@ func TestBuilder(t *testing.T) { Reference(Reference().Table("groups").Columns("id")). OnDelete("CASCADE"), ), - wantQuery: "ALTER TABLE `users` ADD `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`) ON DELETE CASCADE", + wantQuery: "ALTER TABLE `users` ADD COLUMN `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`) ON DELETE CASCADE", }, { input: AlterTable("users"). @@ -58,13 +58,13 @@ func TestBuilder(t *testing.T) { AddForeignKey(ForeignKey().Columns("group_id"). Reference(Reference().Table("groups").Columns("id")), ), - wantQuery: "ALTER TABLE `users` ADD `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`)", + wantQuery: "ALTER TABLE `users` ADD COLUMN `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`)", }, { input: AlterTable("users"). AddColumn(Column("age").Type("int")). AddColumn(Column("name").Type("varchar(255)")), - wantQuery: "ALTER TABLE `users` ADD `age` int, ADD `name` varchar(255)", + wantQuery: "ALTER TABLE `users` ADD COLUMN `age` int, ADD COLUMN `name` varchar(255)", }, { input: AlterTable("users"). diff --git a/dialect/sql/schema/mysql.go b/dialect/sql/schema/mysql.go index 584ecf442..5b0c72ca5 100644 --- a/dialect/sql/schema/mysql.go +++ b/dialect/sql/schema/mysql.go @@ -4,6 +4,7 @@ import ( "context" "crypto/md5" "fmt" + "sort" "fbc/ent/dialect" "fbc/ent/dialect/sql" @@ -14,26 +15,48 @@ type MySQL struct { dialect.Driver } -// Create creates all tables resources in the database. +// Create creates all schema resources in the database. It works in an "append-only" +// mode, which means, it won't delete or change any existing resource in the database. func (d *MySQL) Create(ctx context.Context, tables ...*Table) error { tx, err := d.Tx(ctx) if err != nil { return err } for _, t := range tables { - exist, err := d.tableExist(ctx, tx, t.Name) - if err != nil { + switch exist, err := d.tableExist(ctx, tx, t.Name); { + case err != nil: return rollback(tx, err) - } - if exist { - continue - } - query, args := t.DSL().Query() - if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil { - return rollback(tx, fmt.Errorf("sql/mysql: create table %q: %v", t.Name, err)) + case exist: + curr, err := d.table(ctx, tx, t.Name) + if err != nil { + return rollback(tx, err) + } + changes, err := changeSet(curr, t) + if err != nil { + return rollback(tx, err) + } + if len(changes.Columns) > 0 { + b := sql.AlterTable(curr.Name) + for _, c := range changes.Columns { + b.AddColumn(c.DSL()) + } + query, args := b.Query() + if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil { + return rollback(tx, fmt.Errorf("sql/mysql: alter table %q: %v", t.Name, err)) + } + } + if len(changes.Indexes) > 0 { + panic("missing implementation") + } + default: // !exist + query, args := t.DSL().Query() + if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil { + return rollback(tx, fmt.Errorf("sql/mysql: create table %q: %v", t.Name, err)) + } } } - // create foreign keys after table was created, because circular foreign-key constraints are possible. + // create foreign keys after tables were created/altered, + // because circular foreign-key constraints are possible. for _, t := range tables { if len(t.ForeignKeys) == 0 { continue @@ -98,6 +121,65 @@ func (d *MySQL) exist(ctx context.Context, tx dialect.Tx, query string, args ... return n > 0, nil } +// table loads the current table description from the database. +func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) { + rows := &sql.Rows{} + if err := tx.Query(ctx, "DESCRIBE "+name, []interface{}{}, rows); err != nil { + return nil, fmt.Errorf("dialect/mysql: reading table description %v", err) + } + defer rows.Close() + t := &Table{Name: name} + for rows.Next() { + c := &Column{} + if err := c.ScanMySQL(rows); err != nil { + return nil, fmt.Errorf("dialect/mysql: %v", err) + } + if c.PrimaryKey() { + t.PrimaryKey = append(t.PrimaryKey, c) + } + t.Columns = append(t.Columns, c) + } + return t, nil +} + +// changeSet returns a dummy table represents the change set that need +// to be applied on the table. it fails if one of the changes is invalid. +func changeSet(curr, new *Table) (*Table, error) { + changes := &Table{} + // pks. + if len(curr.PrimaryKey) != len(new.PrimaryKey) { + return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name) + } + sort.Slice(new.PrimaryKey, func(i, j int) bool { return new.PrimaryKey[i].Name < new.PrimaryKey[j].Name }) + sort.Slice(curr.PrimaryKey, func(i, j int) bool { return curr.PrimaryKey[i].Name < curr.PrimaryKey[j].Name }) + for i := range curr.PrimaryKey { + if curr.PrimaryKey[i].Name != new.PrimaryKey[i].Name { + return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name) + } + } + // columns. + for _, c1 := range new.Columns { + switch c2, ok := curr.column(c1.Name); { + case !ok: + changes.Columns = append(changes.Columns, c1) + case c1.Type != c2.Type: + return nil, fmt.Errorf("changing column type for %q is invalid", c1.Name) + case c1.Unique != c2.Unique: + return nil, fmt.Errorf("changing column cardinality for %q is invalid", c1.Name) + } + } + // indexes. + for _, idx1 := range new.Indexes { + switch idx2, ok := curr.index(idx1.Name); { + case !ok: + changes.Indexes = append(changes.Indexes, idx1) + case idx1.Unique != idx2.Unique: + return nil, fmt.Errorf("changing index %q uniqness is invalid", idx1.Name) + } + } + return changes, nil +} + // symbol makes sure the symbol length is not longer than the maxlength in MySQL standard (64). func symbol(name string) string { if len(name) <= 64 { diff --git a/dialect/sql/schema/schema.go b/dialect/sql/schema/schema.go index 74a8689a6..fbef56f37 100644 --- a/dialect/sql/schema/schema.go +++ b/dialect/sql/schema/schema.go @@ -2,6 +2,7 @@ package schema import ( "fmt" + "strconv" "strings" "fbc/ent/dialect/sql" @@ -70,17 +71,40 @@ func (t *Table) SQLite() *sql.TableBuilder { return b } +// column returns a table column by its name. +// faster than map lookup for most cases. +func (t *Table) column(name string) (*Column, bool) { + for _, c := range t.Columns { + if c.Name == name { + return c, true + } + } + return nil, false +} + +// index returns a table index by its name. +// faster than map lookup for most cases. +func (t *Table) index(name string) (*Index, bool) { + for _, idx := range t.Indexes { + if idx.Name == name { + return idx, true + } + } + return nil, false +} + // Column schema definition for SQL dialects. type Column struct { Name string // column name. Type field.Type // column type. + typ string // row column type (used for Rows.Scan). Attr string // extra attributes. - Default string // default value. - Nullable *bool // null or not null attribute. Size int // max size parameter for string, blob, etc. Key string // key definition (PRI, UNI or MUL). Unique bool // column with unique constraint. Increment bool // auto increment attribute. + Nullable *bool // null or not null attribute. + Default string // default value. } // UniqueKey returns boolean indicates if this column is a unique key. @@ -182,6 +206,52 @@ func (c *Column) SQLiteType() (t string) { return t } +// ScanMySQL scans the information from MySQL column description. +func (c *Column) ScanMySQL(rows *sql.Rows) error { + var ( + nullable sql.NullString + defaults sql.NullString + ) + if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr); err != nil { + return fmt.Errorf("scanning column description: %v", err) + } + c.Unique = c.UniqueKey() + c.Default = defaults.String + if nullable.Valid { + null := nullable.String == "YES" + c.Nullable = &null + } + switch parts := strings.FieldsFunc(c.typ, func(r rune) bool { + return r == '(' || r == ')' || r == ' ' + }); parts[0] { + case "int": + c.Type = field.TypeInt + case "timestamp": + c.Type = field.TypeTime + case "tinyint": + size, err := strconv.Atoi(parts[1]) + if err != nil { + return fmt.Errorf("converting varchar size to int: %v", err) + } + switch { + case size == 1: + c.Type = field.TypeBool + case len(parts) == 3: // tinyint(3) unsigned. + c.Type = field.TypeUint8 + default: + c.Type = field.TypeInt8 + } + case "varchar": + c.Type = field.TypeString + size, err := strconv.Atoi(parts[1]) + if err != nil { + return fmt.Errorf("converting varchar size to int: %v", err) + } + c.Size = size + } + return nil +} + // unique adds the `UNIQUE` attribute if the column is a unique type. // it is exist in a different function to share the common declaration // between the two dialects. @@ -257,10 +327,11 @@ func (r ReferenceOption) ConstName() string { // Index definition for table index. type Index struct { - Key string // key name. - Column string // column name. + Name string + Unique bool + Columns []*Column } // Primary indicates if this index is a primary key. // Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects. -func (i *Index) Primary() bool { return i.Key == "PRIMARY" } +func (i *Index) Primary() bool { return i.Name == "PRIMARY" } diff --git a/entc/integration/ent/schema/user.go b/entc/integration/ent/schema/user.go index 7759eb7d0..73a2ff7d2 100644 --- a/entc/integration/ent/schema/user.go +++ b/entc/integration/ent/schema/user.go @@ -32,9 +32,7 @@ func (User) Fields() []ent.Field { // Edges of the user. func (User) Edges() []ent.Edge { return []ent.Edge{ - edge.To("card", Card.Type). - Comment("O2O edge"). - Unique(), + edge.To("card", Card.Type).Comment("O2O edge").Unique(), edge.To("pets", Pet.Type), edge.To("files", File.Type), edge.To("groups", Group.Type),