dialect/sql/schema: initial work for incremental migration (#428)

This is a WIP PR and should be ignored this moment.
It's based on PR #221 created by Erik Hollensbe (He should
get his credit for his work before we land this).
This commit is contained in:
Ariel Mashraki
2020-04-12 19:12:33 +03:00
committed by GitHub
parent 8effe6dfeb
commit 2208b243db
11 changed files with 421 additions and 52 deletions

View File

@@ -209,21 +209,14 @@ func (m *Migrate) apply(ctx context.Context, tx dialect.Tx, table string, change
}
}
}
b := sql.Dialect(m.Dialect()).AlterTable(table)
for _, c := range change.column.add {
b.AddColumn(m.addColumn(c))
}
for _, c := range change.column.modify {
b.ModifyColumns(m.alterColumn(c)...)
}
var drop []*Column
if m.dropColumns {
for _, c := range change.column.drop {
b.DropColumn(sql.Dialect(m.Dialect()).Column(c.Name))
}
drop = change.column.drop
}
queries := m.alterColumns(table, change.column.add, change.column.modify, drop)
// If there's actual action to execute on ALTER TABLE.
if len(b.Queries) != 0 {
query, args := b.Query()
for i := range queries {
query, args := queries[i].Query()
if err := tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("alter table %q: %v", table, err)
}
@@ -337,11 +330,11 @@ func (m *Migrate) changeSet(curr, new *Table) (*changes, error) {
}
// Drop indexes.
for _, idx1 := range curr.Indexes {
_, ok1 := new.fk(idx1.Name)
_, ok2 := new.index(idx1.Name)
for _, idx := range curr.Indexes {
_, ok1 := new.fk(idx.Name)
_, ok2 := new.index(idx.Name)
if !ok1 && !ok2 {
change.index.drop.append(idx1)
change.index.drop.append(idx)
}
}
return change, nil
@@ -532,6 +525,9 @@ func (m *Migrate) setupTable(t *Table) {
}
for _, fk := range t.ForeignKeys {
fk.Symbol = m.symbol(fk.Symbol)
for i := range fk.Columns {
fk.Columns[i].foreign = fk
}
}
}
@@ -590,9 +586,8 @@ type sqlDialect interface {
// table, column and index builder per dialect.
cType(*Column) string
tBuilder(*Table) *sql.TableBuilder
addColumn(*Column) *sql.ColumnBuilder
alterColumn(*Column) []*sql.ColumnBuilder
addIndex(*Index, string) *sql.IndexBuilder
alterColumns(table string, add, modify, drop []*Column) sql.Queries
}
type preparer interface {

View File

@@ -243,11 +243,6 @@ func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder {
return b
}
// alterColumn returns the DSL query for modifying the given column.
func (d *MySQL) alterColumn(c *Column) []*sql.ColumnBuilder {
return []*sql.ColumnBuilder{d.addColumn(c)}
}
// addIndex returns the querying for adding an index to MySQL.
func (d *MySQL) addIndex(i *Index, table string) *sql.IndexBuilder {
return i.Builder(table)
@@ -465,6 +460,24 @@ func (d *MySQL) tableSchema() sql.Querier {
return sql.Raw("(SELECT DATABASE())")
}
// alterColumns returns the queries for applying the columns change-set.
func (d *MySQL) alterColumns(table string, add, modify, drop []*Column) sql.Queries {
b := sql.Dialect(dialect.MySQL).AlterTable(table)
for _, c := range add {
b.AddColumn(d.addColumn(c))
}
for _, c := range modify {
b.ModifyColumn(d.addColumn(c))
}
for _, c := range drop {
b.DropColumn(sql.Dialect(dialect.MySQL).Column(c.Name))
}
if len(b.Queries) == 0 {
return nil
}
return sql.Queries{b}
}
// parseColumn returns column parts, size and signedness by mysql type
func parseColumn(typ string) (parts []string, size int64, unsigned bool, err error) {
switch parts = strings.FieldsFunc(typ, func(r rune) bool {

View File

@@ -388,7 +388,7 @@ func TestMySQL_Create(t *testing.T) {
},
},
{
name: "add bool column with default value to table",
name: "add bool column with default value",
tables: []*Table{
{
Name: "users",
@@ -420,7 +420,7 @@ func TestMySQL_Create(t *testing.T) {
},
},
{
name: "add string column with default value to table",
name: "add string column with default value",
tables: []*Table{
{
Name: "users",
@@ -452,7 +452,7 @@ func TestMySQL_Create(t *testing.T) {
},
},
{
name: "add column with unsupported default value to table",
name: "add column with unsupported default value",
tables: []*Table{
{
Name: "users",
@@ -484,7 +484,7 @@ func TestMySQL_Create(t *testing.T) {
},
},
{
name: "drop column to table",
name: "drop columns",
tables: []*Table{
{
Name: "users",

View File

@@ -382,3 +382,21 @@ func (d *Postgres) renameIndex(t *Table, old, new *Index) sql.Querier {
func (d *Postgres) tableSchema() sql.Querier {
return sql.Raw("(CURRENT_SCHEMA())")
}
// alterColumns returns the queries for applying the columns change-set.
func (d *Postgres) alterColumns(table string, add, modify, drop []*Column) sql.Queries {
b := sql.Dialect(dialect.Postgres).AlterTable(table)
for _, c := range add {
b.AddColumn(d.addColumn(c))
}
for _, c := range modify {
b.ModifyColumns(d.alterColumn(c)...)
}
for _, c := range drop {
b.DropColumn(sql.Dialect(dialect.Postgres).Column(c.Name))
}
if len(b.Queries) == 0 {
return nil
}
return sql.Queries{b}
}

View File

@@ -98,7 +98,6 @@ func (t *Table) column(name string) (*Column, bool) {
}
// index returns a table index by its name.
// faster than map lookup for most cases.
func (t *Table) index(name string) (*Index, bool) {
for _, idx := range t.Indexes {
if idx.Name == name {
@@ -150,6 +149,7 @@ type Column struct {
Default interface{} // default value.
Enums []string // enum values.
indexes Indexes // linked indexes.
foreign *ForeignKey // linked foreign-key.
}
// UniqueKey returns boolean indicates if this column is a unique key.
@@ -186,7 +186,7 @@ func (c Column) FloatType() bool { return c.Type == field.TypeFloat32 || c.Type
// ScanDefault scans the default value string to its interface type.
func (c *Column) ScanDefault(value string) (err error) {
switch {
case value == Null: // ignore.
case strings.ToUpper(value) == Null: // ignore.
case c.IntType():
v := &sql.NullInt64{}
if err := v.Scan(value); err != nil {

View File

@@ -92,19 +92,15 @@ func (*SQLite) cType(c *Column) (t string) {
switch c.Type {
case field.TypeBool:
t = "bool"
case field.TypeInt8, field.TypeUint8, field.TypeInt, field.TypeInt16, field.TypeInt32, field.TypeUint, field.TypeUint16, field.TypeUint32:
case field.TypeInt8, field.TypeUint8, field.TypeInt16, field.TypeUint16, field.TypeInt32,
field.TypeUint32, field.TypeUint, field.TypeInt, field.TypeInt64, field.TypeUint64:
t = "integer"
case field.TypeInt64, field.TypeUint64:
t = "bigint"
case field.TypeBytes:
t = "blob"
case field.TypeString, field.TypeEnum:
size := c.Size
if size == 0 {
size = DefaultStringLen
}
// sqlite has no size limit on varchar.
t = fmt.Sprintf("varchar(%d)", size)
// SQLite does not impose any length restrictions on
// the length of strings, BLOBs or numeric values.
t = fmt.Sprintf("varchar(%d)", DefaultStringLen)
case field.TypeFloat32, field.TypeFloat64:
t = "real"
case field.TypeTime:
@@ -131,11 +127,6 @@ func (d *SQLite) addColumn(c *Column) *sql.ColumnBuilder {
return b
}
// alterColumn returns the DSL query for modifying the given column.
func (d *SQLite) alterColumn(c *Column) []*sql.ColumnBuilder {
return []*sql.ColumnBuilder{d.addColumn(c)}
}
// addIndex returns the querying for adding an index to SQLite.
func (d *SQLite) addIndex(i *Index, table string) *sql.IndexBuilder {
return i.Builder(table)
@@ -151,6 +142,146 @@ func (d *SQLite) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table
func (d *SQLite) fkExist(context.Context, dialect.Tx, string) (bool, error) { return true, nil }
// table returns always error to indicate that SQLite dialect doesn't support incremental migration.
func (d *SQLite) table(context.Context, dialect.Tx, string) (*Table, error) {
return nil, fmt.Errorf("sqlite dialect does not support incremental migration")
func (d *SQLite) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
rows := &sql.Rows{}
query, args := sql.Select("name", "type", "notnull", "dflt_value", "pk").
From(sql.Table(fmt.Sprintf("pragma_table_info('%s')", name)).Unquote()).
OrderBy("pk").
Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("sqlite: reading table description %v", err)
}
// Call Close in cases of failures (Close is idempotent).
defer rows.Close()
t := NewTable(name)
for rows.Next() {
c := &Column{}
if err := d.scanColumn(c, rows); err != nil {
return nil, fmt.Errorf("sqlite: %v", err)
}
if c.PrimaryKey() {
t.PrimaryKey = append(t.PrimaryKey, c)
}
t.AddColumn(c)
}
if err := rows.Close(); err != nil {
return nil, fmt.Errorf("sqlite: closing rows %v", err)
}
indexes, err := d.indexes(ctx, tx, name)
if err != nil {
return nil, err
}
// Add and link indexes to table columns.
for _, idx := range indexes {
t.AddIndex(idx.Name, idx.Unique, idx.columns)
}
return t, nil
}
// table loads the table indexes from the database.
func (d *SQLite) indexes(ctx context.Context, tx dialect.Tx, name string) (Indexes, error) {
rows := &sql.Rows{}
query, args := sql.Select("name", "unique").
From(sql.Table(fmt.Sprintf("pragma_index_list('%s')", name)).Unquote()).
Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("reading table indexes %v", err)
}
defer rows.Close()
var idx Indexes
for rows.Next() {
i := &Index{}
if err := rows.Scan(&i.Name, &i.Unique); err != nil {
return nil, fmt.Errorf("scanning index description %v", err)
}
idx = append(idx, i)
}
if err := rows.Close(); err != nil {
return nil, fmt.Errorf("closing rows %v", err)
}
for i := range idx {
columns, err := d.indexColumns(ctx, tx, idx[i].Name)
if err != nil {
return nil, err
}
idx[i].columns = columns
}
return idx, nil
}
// indexColumns loads index columns from index info.
func (d *SQLite) indexColumns(ctx context.Context, tx dialect.Tx, name string) ([]string, error) {
rows := &sql.Rows{}
query, args := sql.Select("name").
From(sql.Table(fmt.Sprintf("pragma_index_info('%s')", name)).Unquote()).
OrderBy("seqno").
Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("reading table indexes %v", err)
}
defer rows.Close()
var names []string
if err := sql.ScanSlice(rows, &names); err != nil {
return nil, err
}
return names, nil
}
// scanColumn scans the column information from SQLite column description.
func (d *SQLite) scanColumn(c *Column, rows *sql.Rows) error {
var (
pk sql.NullInt64
notnull sql.NullInt64
defaults sql.NullString
)
if err := rows.Scan(&c.Name, &c.typ, &notnull, &defaults, &pk); err != nil {
return fmt.Errorf("scanning column description: %v", err)
}
c.Nullable = notnull.Int64 == 0
if pk.Int64 > 0 {
c.Key = PrimaryKey
}
parts, _, _, err := parseColumn(c.typ)
if err != nil {
return err
}
switch parts[0] {
case "bool", "boolean":
c.Type = field.TypeBool
case "blob":
c.Type = field.TypeBytes
case "integer":
// All integer types have the same "type affinity".
c.Type = field.TypeInt
case "real", "float", "double":
c.Type = field.TypeFloat64
case "datetime":
c.Type = field.TypeTime
case "json":
c.Type = field.TypeJSON
case "uuid":
c.Type = field.TypeUUID
case "varchar", "text":
c.Size = DefaultStringLen
c.Type = field.TypeString
}
if defaults.Valid {
return c.ScanDefault(defaults.String)
}
return nil
}
// alterColumns returns the queries for applying the columns change-set.
func (d *SQLite) alterColumns(table string, add, _, _ []*Column) sql.Queries {
queries := make(sql.Queries, 0, len(add))
for i := range add {
c := d.addColumn(add[i])
if fk := add[i].foreign; fk != nil {
c.Constraint(fk.DSL())
}
queries = append(queries, sql.Dialect(dialect.SQLite).AlterTable(table).AddColumn(c))
}
// Modifying and dropping columns is not supported and disabled until we
// will support https://www.sqlite.org/lang_altertable.html#otheralter
return queries
}

View File

@@ -6,6 +6,8 @@ package schema
import (
"context"
"fmt"
"math"
"testing"
"github.com/facebookincubator/ent/dialect/sql"
@@ -119,6 +121,188 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectCommit()
},
},
{
name: "add column to table",
tables: []*Table{
{
Name: "users",
Columns: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "name", Type: field.TypeString, Nullable: true},
{Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32},
{Name: "uuid", Type: field.TypeUUID, Nullable: true},
{Name: "age", Type: field.TypeInt, Default: 0},
},
PrimaryKey: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
},
},
},
before: func(mock sqliteMock) {
mock.start()
mock.tableExists("users", true)
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("name", "varchar(255)", 0, nil, 0).
AddRow("text", "text", 0, "NULL", 0).
AddRow("uuid", "uuid", 0, "Null", 0).
AddRow("id", "integer", 1, "NULL", 1))
mock.ExpectQuery(escape("SELECT `name`, `unique` FROM pragma_index_list('users')")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "unique"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
},
{
name: "datetime and timestamp",
tables: []*Table{
{
Name: "users",
Columns: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "created_at", Type: field.TypeTime, Nullable: true},
{Name: "updated_at", Type: field.TypeTime, Nullable: true},
},
PrimaryKey: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
},
},
},
before: func(mock sqliteMock) {
mock.start()
mock.tableExists("users", true)
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("created_at", "datetime", 0, nil, 0).
AddRow("id", "integer", 1, "NULL", 1))
mock.ExpectQuery(escape("SELECT `name`, `unique` FROM pragma_index_list('users')")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "unique"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `updated_at` datetime NULL")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
},
{
name: "add blob columns",
tables: []*Table{
{
Name: "blobs",
Columns: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "old_tiny", Type: field.TypeBytes, Size: 100},
{Name: "old_blob", Type: field.TypeBytes, Size: 1e3},
{Name: "old_medium", Type: field.TypeBytes, Size: 1e5},
{Name: "old_long", Type: field.TypeBytes, Size: 1e8},
{Name: "new_tiny", Type: field.TypeBytes, Size: 100},
{Name: "new_blob", Type: field.TypeBytes, Size: 1e3},
{Name: "new_medium", Type: field.TypeBytes, Size: 1e5},
{Name: "new_long", Type: field.TypeBytes, Size: 1e8},
},
PrimaryKey: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
},
},
},
before: func(mock sqliteMock) {
mock.start()
mock.tableExists("blobs", true)
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('blobs') ORDER BY `pk`")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("old_tiny", "blob", 1, nil, 0).
AddRow("old_blob", "blob", 1, nil, 0).
AddRow("old_medium", "blob", 1, nil, 0).
AddRow("old_long", "blob", 1, nil, 0).
AddRow("id", "integer", 1, "NULL", 1))
mock.ExpectQuery(escape("SELECT `name`, `unique` FROM pragma_index_list('blobs')")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "unique"}))
for _, c := range []string{"tiny", "blob", "medium", "long"} {
mock.ExpectExec(escape(fmt.Sprintf("ALTER TABLE `blobs` ADD COLUMN `new_%s` blob NOT NULL", c))).
WillReturnResult(sqlmock.NewResult(0, 1))
}
mock.ExpectCommit()
},
},
{
name: "add columns with default values",
tables: []*Table{
{
Name: "users",
Columns: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "name", Type: field.TypeString, Default: "unknown"},
{Name: "active", Type: field.TypeBool, Default: false},
},
PrimaryKey: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
},
},
},
before: func(mock sqliteMock) {
mock.start()
mock.tableExists("users", true)
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("id", "integer", 1, "NULL", 1))
mock.ExpectQuery(escape("SELECT `name`, `unique` FROM pragma_index_list('users')")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "unique"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `name` varchar(255) NOT NULL DEFAULT 'unknown'")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `active` bool NOT NULL DEFAULT false")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
},
{
name: "add edge to table",
tables: func() []*Table {
var (
c1 = []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "name", Type: field.TypeString, Nullable: true},
{Name: "spouse_id", Type: field.TypeInt, Nullable: true},
}
t1 = &Table{
Name: "users",
Columns: c1,
PrimaryKey: c1[0:1],
ForeignKeys: []*ForeignKey{
{
Symbol: "user_spouse",
Columns: c1[2:],
RefColumns: c1[0:1],
OnDelete: Cascade,
},
},
}
)
t1.ForeignKeys[0].RefTable = t1
return []*Table{t1}
}(),
before: func(mock sqliteMock) {
mock.start()
mock.tableExists("users", true)
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("name", "varchar(255)", 1, "NULL", 0).
AddRow("id", "integer", 1, "NULL", 1))
mock.ExpectQuery(escape("SELECT `name`, `unique` FROM pragma_index_list('users')")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "unique"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` integer NULL CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE CASCADE")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
},
{
name: "universal id for all tables",
tables: []*Table{