diff --git a/dialect/sql/schema/atlas.go b/dialect/sql/schema/atlas.go index b98b14a63..6ff3ded52 100644 --- a/dialect/sql/schema/atlas.go +++ b/dialect/sql/schema/atlas.go @@ -29,10 +29,6 @@ type Atlas struct { atDriver migrate.Driver sqlDialect sqlDialect - legacy bool // if the legacy migration engine instead of Atlas should be used - withFixture bool // deprecated: with fks rename fixture - sum bool // deprecated: sum file generation will be required - indent string // plan indentation errNoPlan bool // no plan error enabled universalID bool // global unique ids @@ -67,7 +63,7 @@ func Diff(ctx context.Context, u, name string, tables []*Table, opts ...MigrateO // NewMigrate creates a new Atlas form the given dialect.Driver. func NewMigrate(drv dialect.Driver, opts ...MigrateOption) (*Atlas, error) { - a := &Atlas{driver: drv, withForeignKeys: true, mode: ModeInspect, sum: true} + a := &Atlas{driver: drv, withForeignKeys: true, mode: ModeInspect} for _, opt := range opts { opt(a) } @@ -84,7 +80,7 @@ func NewMigrateURL(u string, opts ...MigrateOption) (*Atlas, error) { if err != nil { return nil, err } - a := &Atlas{url: parsed, withForeignKeys: true, mode: ModeInspect, sum: true} + a := &Atlas{url: parsed, withForeignKeys: true, mode: ModeInspect} for _, opt := range opts { opt(a) } @@ -106,13 +102,6 @@ func NewMigrateURL(u string, opts ...MigrateOption) (*Atlas, error) { func (a *Atlas) Create(ctx context.Context, tables ...*Table) (err error) { a.setupTables(tables) var creator Creator = CreateFunc(a.create) - if a.legacy { - m, err := a.legacyMigrate() - if err != nil { - return err - } - creator = CreateFunc(m.create) - } for i := len(a.hooks) - 1; i >= 0; i-- { creator = a.hooks[i](creator) } @@ -132,13 +121,9 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er return errors.New("no migration directory given") } opts := []migrate.PlannerOption{migrate.WithFormatter(a.fmt)} - if a.sum { - // Validate the migration directory before proceeding. - if err := migrate.Validate(a.dir); err != nil { - return fmt.Errorf("validating migration directory: %w", err) - } - } else { - opts = append(opts, migrate.DisableChecksum()) + // Validate the migration directory before proceeding. + if err := migrate.Validate(a.dir); err != nil { + return fmt.Errorf("validating migration directory: %w", err) } a.setupTables(tables) // Set up connections. @@ -488,18 +473,6 @@ func WithApplyHook(hooks ...ApplyHook) MigrateOption { } } -// WithAtlas is an opt-out option for v0.11 indicating the migration -// should be executed using the deprecated legacy engine. -// Note, in future versions, this option is going to be removed -// and the Atlas (https://atlasgo.io) based migration engine should be used. -// -// Deprecated: The legacy engine will be removed. -func WithAtlas(b bool) MigrateOption { - return func(a *Atlas) { - a.legacy = !b - } -} - // WithDir sets the atlas migration directory to use to store migration files. func WithDir(dir migrate.Dir) MigrateOption { return func(a *Atlas) { @@ -522,22 +495,6 @@ func WithDialect(d string) MigrateOption { } } -// WithSumFile instructs atlas to generate a migration directory integrity sum file. -// -// Deprecated: generating the sum file is now opt-out. This method will be removed in future versions. -func WithSumFile() MigrateOption { - return func(a *Atlas) {} -} - -// DisableChecksum instructs atlas to skip migration directory integrity sum file generation. -// -// Deprecated: generating the sum file will no longer be optional in future versions. -func DisableChecksum() MigrateOption { - return func(a *Atlas) { - a.sum = false - } -} - // WithMigrationMode instructs atlas how to compute the current state of the schema. This can be done by either // replaying (ModeReplay) the migration directory on the connected database, or by inspecting (ModeInspect) the // connection. Currently, ModeReplay is opt-in, and ModeInspect is the default. In future versions, ModeReplay will @@ -626,15 +583,9 @@ func (a *Atlas) init() error { a.fmt = sqltool.GolangMigrateFormatter } } - if a.mode == ModeReplay { - // ModeReplay requires a migration directory. - if a.dir == nil { - return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires versioned migrations: WithDir()") - } - // ModeReplay requires sum file generation. - if !a.sum { - return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires migration directory integrity file") - } + // ModeReplay requires a migration directory. + if a.mode == ModeReplay && a.dir == nil { + return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires versioned migrations: WithDir()") } return nil } @@ -1238,32 +1189,6 @@ func (r *diffDriver) SchemaDiff(from, to *schema.Schema, opts ...schema.DiffOpti return d.Diff(from, to) } -// legacyMigrate returns a configured legacy migration engine (before Atlas) to keep backwards compatibility. -// -// Deprecated: Will be removed alongside legacy migration support. -func (a *Atlas) legacyMigrate() (*Migrate, error) { - m := &Migrate{ - universalID: a.universalID, - dropColumns: a.dropColumns, - dropIndexes: a.dropIndexes, - withFixture: a.withFixture, - withForeignKeys: a.withForeignKeys, - hooks: a.hooks, - atlas: a, - } - switch a.dialect { - case dialect.MySQL: - m.sqlDialect = &MySQL{Driver: a.driver} - case dialect.SQLite: - m.sqlDialect = &SQLite{Driver: a.driver, WithForeignKeys: a.withForeignKeys} - case dialect.Postgres: - m.sqlDialect = &Postgres{Driver: a.driver} - default: - return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect) - } - return m, nil -} - // removeAttr is a temporary patch due to compiler errors we get by using the generic // schema.RemoveAttr function (:1: internal compiler error: panic: ...). // Can be removed in Go 1.20. See: https://github.com/golang/go/issues/54302. diff --git a/dialect/sql/schema/inspect.go b/dialect/sql/schema/inspect.go deleted file mode 100644 index 95d529150..000000000 --- a/dialect/sql/schema/inspect.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2019-present Facebook Inc. All rights reserved. -// This source code is licensed under the Apache 2.0 license found -// in the LICENSE file in the root directory of this source tree. - -package schema - -import ( - "context" - "fmt" - - "entgo.io/ent/dialect" - "entgo.io/ent/dialect/sql" -) - -// InspectOption allows for managing schema configuration using functional options. -type InspectOption func(inspect *Inspector) - -// WithSchema provides a schema (named-database) for reading the tables from. -func WithSchema(schema string) InspectOption { - return func(m *Inspector) { - m.schema = schema - } -} - -// An Inspector provides methods for inspecting database tables. -type Inspector struct { - sqlDialect - schema string -} - -// NewInspect returns an inspector for the given SQL driver. -func NewInspect(d dialect.Driver, opts ...InspectOption) (*Inspector, error) { - i := &Inspector{} - for _, opt := range opts { - opt(i) - } - switch d.Dialect() { - case dialect.MySQL: - i.sqlDialect = &MySQL{Driver: d, schema: i.schema} - case dialect.SQLite: - i.sqlDialect = &SQLite{Driver: d} - case dialect.Postgres: - i.sqlDialect = &Postgres{Driver: d, schema: i.schema} - default: - return nil, fmt.Errorf("sql/schema: unsupported dialect %q", d.Dialect()) - } - return i, nil -} - -// Tables returns the tables in the schema. -func (i *Inspector) Tables(ctx context.Context) ([]*Table, error) { - names, err := i.tables(ctx) - if err != nil { - return nil, err - } - tx := dialect.NopTx(i.sqlDialect) - tables := make([]*Table, 0, len(names)) - for _, name := range names { - t, err := i.table(ctx, tx, name) - if err != nil { - return nil, err - } - tables = append(tables, t) - } - - fki, ok := i.sqlDialect.(interface { - foreignKeys(context.Context, dialect.Tx, []*Table) error - }) - if ok { - if err := fki.foreignKeys(ctx, tx, tables); err != nil { - return nil, err - } - } - return tables, nil -} - -func (i *Inspector) tables(ctx context.Context) ([]string, error) { - t, ok := i.sqlDialect.(interface{ tables() sql.Querier }) - if !ok { - return nil, fmt.Errorf("sql/schema: %q driver does not support inspection", i.Dialect()) - } - query, args := t.tables().Query() - var ( - names []string - rows = &sql.Rows{} - ) - if err := i.Query(ctx, query, args, rows); err != nil { - return nil, fmt.Errorf("%q driver: reading table names %w", i.Dialect(), err) - } - defer rows.Close() - if err := sql.ScanSlice(rows, &names); err != nil { - return nil, err - } - return names, nil -} diff --git a/dialect/sql/schema/inspect_test.go b/dialect/sql/schema/inspect_test.go deleted file mode 100644 index ff5e6eee9..000000000 --- a/dialect/sql/schema/inspect_test.go +++ /dev/null @@ -1,333 +0,0 @@ -// Copyright 2019-present Facebook Inc. All rights reserved. -// This source code is licensed under the Apache 2.0 license found -// in the LICENSE file in the root directory of this source tree. - -package schema - -import ( - "context" - "fmt" - "math" - "path" - "testing" - - "entgo.io/ent/dialect" - "entgo.io/ent/dialect/sql" - "entgo.io/ent/schema/field" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/require" -) - -func TestInspector_Tables(t *testing.T) { - tests := []struct { - name string - options []InspectOption - before map[string]func(mysqlMock) - tables func(drv string) []*Table - wantErr bool - }{ - { - name: "default schema", - before: map[string]func(mysqlMock){ - dialect.MySQL: func(mock mysqlMock) { - mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE())")). - WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"})) - }, - dialect.SQLite: func(mock mysqlMock) { - mock.ExpectQuery(escape("SELECT `name` FROM `sqlite_schema` WHERE `type` = ?")). - WithArgs("table"). - WillReturnRows(sqlmock.NewRows([]string{"name"})) - }, - dialect.Postgres: func(mock mysqlMock) { - mock.ExpectQuery(escape(`SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = CURRENT_SCHEMA()`)). - WillReturnRows(sqlmock.NewRows([]string{"name"})) - }, - }, - tables: func(drv string) []*Table { - return nil - }, - }, - { - name: "custom schema", - options: []InspectOption{WithSchema("public")}, - before: map[string]func(mysqlMock){ - dialect.MySQL: func(mock mysqlMock) { - mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = ?")). - WithArgs("public"). - WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}). - AddRow("users"). - AddRow("pets"). - AddRow("groups"). - AddRow("user_groups")) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")). - WithArgs("public", "users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). - AddRow("text", "longtext", "YES", "YES", "NULL", "", "", "", nil, nil). - AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin", nil, nil). - AddRow("price", "decimal(6, 4)", "NO", "YES", "NULL", "", "", "", "6", "4"). - AddRow("bank_id", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("public", "users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")). - WithArgs("public", "pets"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). - AddRow("user_pets", "bigint(20)", "YES", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("public", "pets"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")). - WithArgs("public", "groups"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("public", "groups"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")). - WithArgs("public", "user_groups"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("user_id", "bigint(20)", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("group_id", "bigint(20)", "NO", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("public", "user_groups"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"})) - }, - dialect.SQLite: func(mock mysqlMock) { - mock.ExpectQuery(escape("SELECT `name` FROM `sqlite_schema` WHERE `type` = ?")). - WithArgs("table"). - WillReturnRows(sqlmock.NewRows([]string{"name"}). - AddRow("users"). - AddRow("pets"). - AddRow("groups"). - AddRow("user_groups")) - mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). - WithArgs(). - WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). - AddRow("id", "integer", 1, "NULL", 1). - AddRow("name", "varchar(255)", 0, "NULL", 0). - AddRow("text", "text", 0, "NULL", 0). - AddRow("uuid", "uuid", 0, "NULL", 0). - AddRow("price", "real", 1, "NULL", 0). - AddRow("bank_id", "varchar(255)", 1, "NULL", 0)) - mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). - WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"})) - mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('pets') ORDER BY `pk`")). - WithArgs(). - WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). - AddRow("id", "integer", 1, "NULL", 1). - AddRow("name", "varchar(255)", 0, "NULL", 0). - AddRow("user_pets", "integer", 0, "NULL", 0)) - mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('pets')")). - WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"})) - mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('groups') ORDER BY `pk`")). - WithArgs(). - WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). - AddRow("id", "integer", 1, "NULL", 1). - AddRow("name", "varchar(255)", 1, "NULL", 0)) - mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('groups')")). - WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"})) - mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('user_groups') ORDER BY `pk`")). - WithArgs(). - WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). - AddRow("user_id", "integer", 1, "NULL", 0). - AddRow("group_id", "integer", 1, "NULL", 0)) - mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('user_groups')")). - WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"})) - }, - dialect.Postgres: func(mock mysqlMock) { - mock.ExpectQuery(escape(`SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = $1`)). - WithArgs("public"). - WillReturnRows(sqlmock.NewRows([]string{"name"}). - AddRow("users"). - AddRow("pets"). - AddRow("groups"). - AddRow("user_groups")) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)). - WithArgs("public", "users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil). - AddRow("text", "text", "YES", "NULL", "text", nil, nil, nil). - AddRow("uuid", "uuid", "YES", "NULL", "uuid", nil, nil, nil). - AddRow("price", "numeric", "NO", "NULL", "numeric", "6", "4", nil). - AddRow("bank_id", "character", "NO", "NULL", "bpchar", nil, nil, 20)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "users"))). - WithArgs("public"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)). - WithArgs("public", "pets"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil). - AddRow("user_pets", "bigint", "YES", "NULL", "int8", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "pets"))). - WithArgs("public"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("pets_pkey", "id", "t", "t", 0)) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)). - WithArgs("public", "groups"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character", "NO", "NULL", "bpchar", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "groups"))). - WithArgs("public"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("groups_pkey", "id", "t", "t", 0)) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)). - WithArgs("public", "user_groups"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("user_id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("group_id", "bigint", "NO", "NULL", "int8", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "user_groups"))). - WithArgs("public"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"})) - mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "users"))). - WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"})) - mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "pets"))). - WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"}). - AddRow("public", "pet_users_pets", "pets", "user_pets", "public", "users", "id")) - mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "groups"))). - WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"})) - mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "user_groups"))). - WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"}). - AddRow("public", "user_groups_group_id", "user_groups", "group_id", "public", "groups", "id"). - AddRow("public", "user_groups_user_id", "user_groups", "user_id", "public", "users", "id")) - }, - }, - tables: func(drv string) []*Table { - var ( - c1 = []*Column{ - {Name: "id", Type: field.TypeInt64, Increment: true}, - {Name: "name", Type: field.TypeString, Size: 255, Nullable: true}, - {Name: "text", Type: field.TypeString, Size: math.MaxInt32, Nullable: true}, - {Name: "uuid", Type: field.TypeUUID, Nullable: true}, - {Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{ - dialect.MySQL: "decimal(6,4)", - dialect.Postgres: "numeric(6,4)", - }}, - {Name: "bank_id", Type: field.TypeString, SchemaType: map[string]string{ - dialect.Postgres: "varchar(20)", - }}, - } - t1 = &Table{ - Name: "users", - Columns: c1, - PrimaryKey: c1[0:1], - } - c2 = []*Column{ - {Name: "id", Type: field.TypeInt64, Increment: true}, - {Name: "name", Type: field.TypeString, Size: 255, Nullable: true}, - {Name: "user_pets", Type: field.TypeInt64, Nullable: true}, - } - t2 = &Table{ - Name: "pets", - Columns: c2, - PrimaryKey: c2[0:1], - } - c3 = []*Column{ - {Name: "id", Type: field.TypeInt64, Increment: true}, - {Name: "name", Type: field.TypeString}, - } - t3 = &Table{ - Name: "groups", - Columns: c3, - PrimaryKey: c3[0:1], - } - c4 = []*Column{ - {Name: "user_id", Type: field.TypeInt64}, - {Name: "group_id", Type: field.TypeInt64}, - } - t4 = &Table{ - Name: "user_groups", - Columns: c4, - } - ) - - // Only postgres currently supports foreign key inspection - if drv == dialect.Postgres { - t2.ForeignKeys = []*ForeignKey{ - { - Symbol: "pet_users_pets", - Columns: []*Column{c2[2]}, - RefTable: t1, - RefColumns: []*Column{c1[0]}, - }, - } - t4.ForeignKeys = []*ForeignKey{ - { - Symbol: "user_groups_group_id", - Columns: []*Column{c4[1]}, - RefTable: t3, - RefColumns: []*Column{c3[0]}, - }, - { - Symbol: "user_groups_user_id", - Columns: []*Column{c4[0]}, - RefTable: t1, - RefColumns: []*Column{c1[0]}, - }, - } - } - - return []*Table{t1, t2, t3, t4} - }, - }, - } - for _, tt := range tests { - for drv := range tt.before { - t.Run(path.Join(drv, tt.name), func(t *testing.T) { - db, mock, err := sqlmock.New() - require.NoError(t, err) - tt.before[drv](mysqlMock{mock}) - inspect, err := NewInspect(sql.OpenDB(drv, db), tt.options...) - require.NoError(t, err) - tables, err := inspect.Tables(context.Background()) - require.Equal(t, tt.wantErr, err != nil, err) - tablesMatch(t, drv, tables, tt.tables(drv)) - }) - } - } -} - -func tablesMatch(t *testing.T, drv string, got, expected []*Table) { - require.Equal(t, len(expected), len(got)) - for i := range got { - columnsMatch(t, drv, got[i].Columns, expected[i].Columns) - columnsMatch(t, drv, got[i].PrimaryKey, expected[i].PrimaryKey) - foreignKeysMatch(t, drv, got[i].ForeignKeys, expected[i].ForeignKeys) - } -} - -func columnsMatch(t *testing.T, drv string, got, expected []*Column) { - require.Equal(t, len(expected), len(got)) - for i := range got { - c1, c2 := got[i], expected[i] - require.Equal(t, c2.Name, c1.Name) - require.Equal(t, c2.Nullable, c1.Nullable) - require.True(t, c1.Type == c2.Type || c1.ConvertibleTo(c2), "mismatched types: %s - %s", c1.Type, c2.Type) - if c2.SchemaType[drv] != "" { - require.Equal(t, c2.SchemaType[drv], c1.SchemaType[drv]) - } - } -} - -func foreignKeysMatch(t *testing.T, drv string, expected []*ForeignKey, got []*ForeignKey) { - require.Equal(t, len(expected), len(got)) - for i := range got { - fk1, fk2 := got[i], expected[i] - require.Equal(t, fk2.Symbol, fk1.Symbol) - require.Equal(t, fk2.RefTable.Name, fk1.RefTable.Name) - columnsMatch(t, drv, fk1.Columns, fk2.Columns) - columnsMatch(t, drv, fk1.RefColumns, fk2.RefColumns) - } -} diff --git a/dialect/sql/schema/migrate.go b/dialect/sql/schema/migrate.go index 014eb98b3..fdd31bd9d 100644 --- a/dialect/sql/schema/migrate.go +++ b/dialect/sql/schema/migrate.go @@ -73,17 +73,6 @@ func WithDropIndex(b bool) MigrateOption { } } -// WithFixture sets the foreign-key renaming option to the migration when upgrading -// sqlDialect from v0.1.0 (issue-#285). Defaults to false. -// -// Deprecated: This option is no longer needed with the Atlas based -// migration engine, which now is the default. -func WithFixture(b bool) MigrateOption { - return func(a *Atlas) { - a.withFixture = b - } -} - // WithForeignKeys enables creating foreign-key in ddl. Defaults to true. func WithForeignKeys(b bool) MigrateOption { return func(a *Atlas) { @@ -127,480 +116,6 @@ func (f CreateFunc) Create(ctx context.Context, tables ...*Table) error { return f(ctx, tables...) } -// Migrate runs the migration logic for the SQL dialects. -// -// Deprecated: Use the new Atlas struct instead. -type Migrate struct { - sqlDialect - atlas *Atlas // Atlas this Migrate is based on - - universalID bool // global unique ids - dropColumns bool // drop deleted columns - dropIndexes bool // drop deleted indexes - withFixture bool // with fks rename fixture - withForeignKeys bool // with foreign keys - typeRanges []string // types order by their range - hooks []Hook // hooks to apply before creation -} - -// Create creates all schema resources in the database. It works in an "append-only" -// mode, which means, it only creates tables, appends columns to tables or modifies column types. -// -// Column can be modified by turning into a NULL from NOT NULL, or having a type conversion not -// resulting data altering. From example, changing varchar(255) to varchar(120) is invalid, but -// changing varchar(120) to varchar(255) is valid. For more info, see the convert function below. -// -// Note that SQLite dialect does not support (this moment) the "append-only" mode describe above, -// since it's used only for testing. -func (m *Migrate) Create(ctx context.Context, tables ...*Table) error { - m.setupTables(tables) - var creator Creator = CreateFunc(m.create) - for i := len(m.hooks) - 1; i >= 0; i-- { - creator = m.hooks[i](creator) - } - return creator.Create(ctx, tables...) -} - -func (m *Migrate) create(ctx context.Context, tables ...*Table) error { - if err := m.init(ctx); err != nil { - return err - } - tx, err := m.Tx(ctx) - if err != nil { - return err - } - if m.universalID { - if err := m.types(ctx, tx); err != nil { - return rollback(tx, err) - } - } - if err := m.txCreate(ctx, tx, tables...); err != nil { - return rollback(tx, err) - } - return tx.Commit() -} - -func (m *Migrate) txCreate(ctx context.Context, tx dialect.Tx, tables ...*Table) error { - for _, t := range tables { - switch exist, err := m.tableExist(ctx, tx, t.Name); { - case err != nil: - return err - case exist: - curr, err := m.table(ctx, tx, t.Name) - if err != nil { - return err - } - if err := m.verify(ctx, tx, curr); err != nil { - return err - } - if err := m.fixture(ctx, tx, curr, t); err != nil { - return err - } - change, err := m.changeSet(curr, t) - if err != nil { - return fmt.Errorf("creating changeset for %q: %w", t.Name, err) - } - if err := m.apply(ctx, tx, t.Name, change); err != nil { - return err - } - default: // !exist - query, args := m.tBuilder(t).Query() - if err := tx.Exec(ctx, query, args, nil); err != nil { - return fmt.Errorf("create table %q: %w", t.Name, err) - } - // If global unique identifier is enabled, and it's not - // a relation table, allocate a range for the table pk. - if m.universalID && len(t.PrimaryKey) == 1 { - if err := m.allocPKRange(ctx, tx, t); err != nil { - return err - } - } - // indexes. - for _, idx := range t.Indexes { - query, args := m.addIndex(idx, t.Name).Query() - if err := tx.Exec(ctx, query, args, nil); err != nil { - return fmt.Errorf("create index %q: %w", idx.Name, err) - } - } - } - } - if !m.withForeignKeys { - return nil - } - // Create foreign keys after tables were created/altered, - // because circular foreign-key constraints are possible. - for _, t := range tables { - if len(t.ForeignKeys) == 0 { - continue - } - fks := make([]*ForeignKey, 0, len(t.ForeignKeys)) - for _, fk := range t.ForeignKeys { - exist, err := m.fkExist(ctx, tx, fk.Symbol) - if err != nil { - return err - } - if !exist { - fks = append(fks, fk) - } - } - if len(fks) == 0 { - continue - } - b := sql.Dialect(m.Dialect()).AlterTable(t.Name) - for _, fk := range fks { - b.AddForeignKey(fk.DSL()) - } - query, args := b.Query() - if err := tx.Exec(ctx, query, args, nil); err != nil { - return fmt.Errorf("create foreign keys for %q: %w", t.Name, err) - } - } - return nil -} - -// apply changes on the given table. -func (m *Migrate) apply(ctx context.Context, tx dialect.Tx, table string, change *changes) error { - // Constraints should be dropped before dropping columns, because if a column - // is a part of multi-column constraints (like, unique index), ALTER TABLE - // might fail if the intermediate state violates the constraints. - if m.dropIndexes { - if pr, ok := m.sqlDialect.(preparer); ok { - if err := pr.prepare(ctx, tx, change, table); err != nil { - return err - } - } - for _, idx := range change.index.drop { - if err := m.dropIndex(ctx, tx, idx, table); err != nil { - return fmt.Errorf("drop index of table %q: %w", table, err) - } - } - } - var drop []*Column - if m.dropColumns { - drop = change.column.drop - } - queries := m.alterColumns(table, change.column.add, change.column.modify, drop) - // If there's actual action to execute on ALTER TABLE. - for i := range queries { - query, args := queries[i].Query() - if err := tx.Exec(ctx, query, args, nil); err != nil { - return fmt.Errorf("alter table %q: %w", table, err) - } - } - for _, idx := range change.index.add { - query, args := m.addIndex(idx, table).Query() - if err := tx.Exec(ctx, query, args, nil); err != nil { - return fmt.Errorf("create index %q: %w", table, err) - } - } - return nil -} - -// changes to apply on existing table. -type changes struct { - // column changes. - column struct { - add []*Column - drop []*Column - modify []*Column - } - // index changes. - index struct { - add Indexes - drop Indexes - } -} - -// dropColumn returns the dropped column by name (if any). -func (c *changes) dropColumn(name string) (*Column, bool) { - for _, col := range c.column.drop { - if col.Name == name { - return col, true - } - } - return nil, false -} - -// changeSet returns a changes object to be applied on existing table. -// It fails if one of the changes is invalid. -func (m *Migrate) changeSet(curr, new *Table) (*changes, error) { - change := &changes{} - // pks. - if len(curr.PrimaryKey) != len(new.PrimaryKey) { - return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name) - } - for i := range curr.PrimaryKey { - if curr.PrimaryKey[i].Name != new.PrimaryKey[i].Name { - return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name) - } - } - // Add or modify columns. - for _, c1 := range new.Columns { - // Ignore primary keys. - if c1.PrimaryKey() { - continue - } - switch c2, ok := curr.column(c1.Name); { - case !ok: - change.column.add = append(change.column.add, c1) - case !c2.Type.Valid(): - return nil, fmt.Errorf("invalid type %q for column %q", c2.typ, c2.Name) - // Modify a non-unique column to unique. - case c1.Unique && !c2.Unique: - // Make sure the table does not have unique index for this column - // before adding it to the changeset, because there are 2 ways to - // configure uniqueness on sqlDialect.Field (using the Unique modifier or - // adding rule on the Indexes option). - if idx, ok := curr.index(c1.Name); !ok || !idx.Unique { - change.index.add.append(&Index{ - Name: c1.Name, - Unique: true, - Columns: []*Column{c1}, - columns: []string{c1.Name}, - }) - } - // Modify a unique column to non-unique. - case !c1.Unique && c2.Unique: - // If the uniqueness was defined on the Indexes option, - // or was moved from the Unique modifier to the Indexes. - if idx, ok := new.index(c1.Name); ok && idx.Unique { - continue - } - idx, ok := curr.index(c2.Name) - if !ok { - return nil, fmt.Errorf("missing index to drop for unique column %q", c2.Name) - } - change.index.drop.append(idx) - // Extending column types. - case m.needsConversion(c2, c1): - if !c2.ConvertibleTo(c1) { - return nil, fmt.Errorf("changing column type for %q is invalid (%s != %s)", c1.Name, m.cType(c1), m.cType(c2)) - } - fallthrough - // Change nullability of a column. - case c1.Nullable != c2.Nullable: - change.column.modify = append(change.column.modify, c1) - // Change default value. - case c1.Default != nil && c2.Default == nil: - change.column.modify = append(change.column.modify, c1) - } - } - // Drop columns. - for _, c1 := range curr.Columns { - // If a column was dropped, multi-columns indexes that are associated with this column will - // no longer behave the same. Therefore, these indexes should be dropped too. There's no need - // to do it explicitly (here), because entc will remove them from the schema specification, - // and they will be dropped in the block below. - if _, ok := new.column(c1.Name); !ok { - change.column.drop = append(change.column.drop, c1) - } - } - // Add or modify indexes. - for _, idx1 := range new.Indexes { - switch idx2, ok := curr.index(idx1.Name); { - case !ok: - change.index.add.append(idx1) - // Changing index cardinality require drop and create. - case idx1.Unique != idx2.Unique: - change.index.drop.append(idx2) - change.index.add.append(idx1) - default: - im, ok := m.sqlDialect.(interface{ indexModified(old, new *Index) bool }) - // If the dialect supports comparing indexes. - if ok && im.indexModified(idx2, idx1) { - change.index.drop.append(idx2) - change.index.add.append(idx1) - } - } - } - // Drop indexes. - for _, idx := range curr.Indexes { - if _, isFK := new.fk(idx.Name); !isFK && !new.hasIndex(idx.Name, idx.realname) { - change.index.drop.append(idx) - } - } - return change, nil -} - -// fixture is a special migration code for renaming foreign-key columns (issue-#285). -func (m *Migrate) fixture(ctx context.Context, tx dialect.Tx, curr, new *Table) error { - d, ok := m.sqlDialect.(fkRenamer) - if !m.withFixture || !m.withForeignKeys || !ok { - return nil - } - rename := make(map[string]*Index) - for _, fk := range new.ForeignKeys { - ok, err := m.fkExist(ctx, tx, fk.Symbol) - if err != nil { - return fmt.Errorf("checking foreign-key existence %q: %w", fk.Symbol, err) - } - if !ok { - continue - } - column, err := m.fkColumn(ctx, tx, fk) - if err != nil { - return err - } - newcol := fk.Columns[0] - if column == newcol.Name { - continue - } - query, args := d.renameColumn(curr, &Column{Name: column}, newcol).Query() - if err := tx.Exec(ctx, query, args, nil); err != nil { - return fmt.Errorf("rename column %q: %w", column, err) - } - prev, ok := curr.column(column) - if !ok { - continue - } - // Find all indexes that ~maybe need to be renamed. - for _, idx := range prev.indexes { - switch _, ok := new.index(idx.Name); { - // Ignore indexes that exist in the schema, PKs. - case ok || idx.primary: - // Index that was created implicitly for a unique - // column needs to be renamed to the column name. - case d.isImplicitIndex(idx, prev): - idx2 := &Index{Name: newcol.Name, Unique: true, Columns: []*Column{newcol}} - query, args := d.renameIndex(curr, idx, idx2).Query() - if err := tx.Exec(ctx, query, args, nil); err != nil { - return fmt.Errorf("rename index %q: %w", prev.Name, err) - } - idx.Name = idx2.Name - default: - rename[idx.Name] = idx - } - } - // Update the name of the loaded column, so `changeSet` won't create it. - prev.Name = newcol.Name - } - // Go over the indexes that need to be renamed - // and find their ~identical in the new schema. - for _, idx := range rename { - Find: - // Find its ~identical in the new schema, and rename it - // if it doesn't exist. - for _, idx2 := range new.Indexes { - if _, ok := curr.index(idx2.Name); ok { - continue - } - if idx.sameAs(idx2) { - query, args := d.renameIndex(curr, idx, idx2).Query() - if err := tx.Exec(ctx, query, args, nil); err != nil { - return fmt.Errorf("rename index %q: %w", idx.Name, err) - } - idx.Name = idx2.Name - break Find - } - } - } - return nil -} - -// verify that the auto-increment counter is correct for table with universal-id support. -func (m *Migrate) verify(ctx context.Context, tx dialect.Tx, t *Table) error { - vr, ok := m.sqlDialect.(verifyRanger) - if !ok || !m.universalID { - return nil - } - id := indexOf(m.typeRanges, t.Name) - if id == -1 { - return nil - } - return vr.verifyRange(ctx, tx, t, int64(id<<32)) -} - -// types loads the type list from the type store. It will create the types table, if it does not exist yet. -func (m *Migrate) types(ctx context.Context, tx dialect.ExecQuerier) error { - exists, err := m.tableExist(ctx, tx, TypeTable) - if err != nil { - return err - } - if !exists { - t := NewTypesTable() - query, args := m.tBuilder(t).Query() - if err := tx.Exec(ctx, query, args, nil); err != nil { - return fmt.Errorf("create types table: %w", err) - } - return nil - } - rows := &sql.Rows{} - query, args := sql.Dialect(m.Dialect()). - Select("type").From(sql.Table(TypeTable)).OrderBy(sql.Asc("id")).Query() - if err := tx.Query(ctx, query, args, rows); err != nil { - return fmt.Errorf("query types table: %w", err) - } - defer rows.Close() - return sql.ScanSlice(rows, &m.typeRanges) -} - -func (m *Migrate) allocPKRange(ctx context.Context, conn dialect.ExecQuerier, t *Table) error { - r, err := m.pkRange(ctx, conn, t) - if err != nil { - return err - } - return m.setRange(ctx, conn, t, r) -} - -func (m *Migrate) pkRange(ctx context.Context, conn dialect.ExecQuerier, t *Table) (int64, error) { - id := indexOf(m.typeRanges, t.Name) - // If the table re-created, re-use its range from - // the past. Otherwise, allocate a new id-range. - if id == -1 { - if len(m.typeRanges) > MaxTypes { - return 0, fmt.Errorf("max number of types exceeded: %d", MaxTypes) - } - query, args := sql.Dialect(m.Dialect()).Insert(TypeTable).Columns("type").Values(t.Name).Query() - if err := conn.Exec(ctx, query, args, nil); err != nil { - return 0, fmt.Errorf("insert into ent_types: %w", err) - } - id = len(m.typeRanges) - m.typeRanges = append(m.typeRanges, t.Name) - } - return int64(id << 32), nil -} - -// fkColumn returns the column name of a foreign-key. -func (m *Migrate) fkColumn(ctx context.Context, tx dialect.Tx, fk *ForeignKey) (string, error) { - t1 := sql.Table("INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS t1").Unquote().As("t1") - t2 := sql.Table("INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t2").Unquote().As("t2") - query, args := sql.Dialect(m.Dialect()). - Select("column_name"). - From(t1). - Join(t2). - On(t1.C("constraint_name"), t2.C("constraint_name")). - Where(sql.And( - sql.EQ(t2.C("constraint_type"), sql.Raw("'FOREIGN KEY'")), - m.sqlDialect.(fkRenamer).matchSchema(t2.C("table_schema")), - m.sqlDialect.(fkRenamer).matchSchema(t1.C("table_schema")), - sql.EQ(t2.C("constraint_name"), fk.Symbol), - )). - Query() - rows := &sql.Rows{} - if err := tx.Query(ctx, query, args, rows); err != nil { - return "", fmt.Errorf("reading foreign-key %q column: %w", fk.Symbol, err) - } - defer rows.Close() - column, err := sql.ScanString(rows) - if err != nil { - return "", fmt.Errorf("scanning foreign-key %q column: %w", fk.Symbol, err) - } - return column, nil -} - -// setup ensures the table is configured properly, like table columns -// are linked to their indexes, and PKs columns are defined. -func (m *Migrate) setupTables(tables []*Table) { m.atlas.setupTables(tables) } - -// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred. -func rollback(tx dialect.Tx, err error) error { - err = fmt.Errorf("sql/schema: %w", err) - if rerr := tx.Rollback(); rerr != nil { - err = fmt.Errorf("%w: %v", err, rerr) - } - return err -} - // exist checks if the given COUNT query returns a value >= 1. func exist(ctx context.Context, conn dialect.ExecQuerier, query string, args ...any) (bool, error) { rows := &sql.Rows{} @@ -628,30 +143,7 @@ type sqlDialect interface { atBuilder dialect.Driver init(context.Context) error - table(context.Context, dialect.Tx, string) (*Table, error) tableExist(context.Context, dialect.ExecQuerier, string) (bool, error) - fkExist(context.Context, dialect.Tx, string) (bool, error) - setRange(context.Context, dialect.ExecQuerier, *Table, int64) error - dropIndex(context.Context, dialect.Tx, *Index, string) error - // table, column and index builder per dialect. - cType(*Column) string - tBuilder(*Table) *sql.TableBuilder - addIndex(*Index, string) *sql.IndexBuilder - alterColumns(table string, add, modify, drop []*Column) sql.Queries - needsConversion(*Column, *Column) bool -} - -type preparer interface { - prepare(context.Context, dialect.Tx, *changes, string) error -} - -// fkRenamer is used by the fixture migration (to solve #285), -// and it's implemented by the different dialects for renaming FKs. -type fkRenamer interface { - matchSchema(...string) *sql.Predicate - isImplicitIndex(*Index, *Column) bool - renameIndex(*Table, *Index, *Index) sql.Querier - renameColumn(*Table, *Column, *Column) sql.Querier } // verifyRanger wraps the method for verifying global-id range correctness. diff --git a/dialect/sql/schema/migrate_test.go b/dialect/sql/schema/migrate_test.go index 1a2bd11ac..c35a43d13 100644 --- a/dialect/sql/schema/migrate_test.go +++ b/dialect/sql/schema/migrate_test.go @@ -28,53 +28,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestMigrateHookOmitTable(t *testing.T) { - db, mk, err := sqlmock.New() - require.NoError(t, err) - - tables := []*Table{{Name: "users"}, {Name: "pets"}} - mock := mysqlMock{mk} - mock.start("5.7.23") - mock.tableExists("pets", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - - m, err := NewMigrate(sql.OpenDB("mysql", db), WithHooks(func(next Creator) Creator { - return CreateFunc(func(ctx context.Context, tables ...*Table) error { - return next.Create(ctx, tables[1]) - }) - }), WithAtlas(false)) - require.NoError(t, err) - err = m.Create(context.Background(), tables...) - require.NoError(t, err) -} - -func TestMigrateHookAddTable(t *testing.T) { - db, mk, err := sqlmock.New() - require.NoError(t, err) - - tables := []*Table{{Name: "users"}} - mock := mysqlMock{mk} - mock.start("5.7.23") - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("pets", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - - m, err := NewMigrate(sql.OpenDB("mysql", db), WithHooks(func(next Creator) Creator { - return CreateFunc(func(ctx context.Context, tables ...*Table) error { - return next.Create(ctx, tables[0], &Table{Name: "pets"}) - }) - }), WithAtlas(false)) - require.NoError(t, err) - err = m.Create(context.Background(), tables...) - require.NoError(t, err) -} - func TestMigrate_Formatter(t *testing.T) { db, _, err := sqlmock.New() require.NoError(t, err) diff --git a/dialect/sql/schema/mysql.go b/dialect/sql/schema/mysql.go index e2ff21f79..e4e946407 100644 --- a/dialect/sql/schema/mysql.go +++ b/dialect/sql/schema/mysql.go @@ -13,7 +13,6 @@ import ( "strings" "entgo.io/ent/dialect" - "entgo.io/ent/dialect/entsql" "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" @@ -22,7 +21,7 @@ import ( "ariga.io/atlas/sql/schema" ) -// MySQL is a MySQL migration driver. +// MySQL adapter for Atlas migration engine. type MySQL struct { dialect.Driver schema string @@ -59,532 +58,6 @@ func (d *MySQL) tableExist(ctx context.Context, conn dialect.ExecQuerier, name s return exist(ctx, conn, query, args...) } -func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) { - query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLE_CONSTRAINTS").Schema("INFORMATION_SCHEMA")). - Where(sql.And( - d.matchSchema(), - sql.EQ("CONSTRAINT_TYPE", "FOREIGN KEY"), - sql.EQ("CONSTRAINT_NAME", name), - )).Query() - return exist(ctx, tx, query, args...) -} - -// table loads the current table description from the database. -func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) { - rows := &sql.Rows{} - query, args := sql.Select( - "column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", - "numeric_precision", "numeric_scale", - ). - From(sql.Table("COLUMNS").Schema("INFORMATION_SCHEMA")). - Where(sql.And( - d.matchSchema(), - sql.EQ("TABLE_NAME", name)), - ).Query() - if err := tx.Query(ctx, query, args, rows); err != nil { - return nil, fmt.Errorf("mysql: reading table description %w", err) - } - // Call Close in cases of failures (Close is idempotent). - defer rows.Close() - t := NewTable(name) - for rows.Next() { - c := &Column{} - if err := d.scanColumn(c, rows); err != nil { - return nil, fmt.Errorf("mysql: %w", err) - } - t.AddColumn(c) - } - if err := rows.Err(); err != nil { - return nil, err - } - if err := rows.Close(); err != nil { - return nil, fmt.Errorf("mysql: closing rows %w", err) - } - indexes, err := d.indexes(ctx, tx, t) - if err != nil { - return nil, err - } - // Add and link indexes to table columns. - for _, idx := range indexes { - t.addIndex(idx) - } - if _, ok := d.mariadb(); ok { - if err := d.normalizeJSON(ctx, tx, t); err != nil { - return nil, err - } - } - return t, nil -} - -// table loads the table indexes from the database. -func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, t *Table) ([]*Index, error) { - rows := &sql.Rows{} - query, args := sql.Select("index_name", "column_name", "sub_part", "non_unique", "seq_in_index"). - From(sql.Table("STATISTICS").Schema("INFORMATION_SCHEMA")). - Where(sql.And( - d.matchSchema(), - sql.EQ("TABLE_NAME", t.Name), - )). - OrderBy("index_name", "seq_in_index"). - Query() - if err := tx.Query(ctx, query, args, rows); err != nil { - return nil, fmt.Errorf("mysql: reading index description %w", err) - } - defer rows.Close() - idx, err := d.scanIndexes(rows, t) - if err != nil { - return nil, fmt.Errorf("mysql: %w", err) - } - return idx, nil -} - -func (d *MySQL) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error { - return conn.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", t.Name, value), []any{}, nil) -} - -func (d *MySQL) verifyRange(ctx context.Context, tx dialect.ExecQuerier, t *Table, expected int64) error { - if expected == 0 { - return nil - } - rows := &sql.Rows{} - query, args := sql.Select("AUTO_INCREMENT"). - From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")). - Where(sql.And( - d.matchSchema(), - sql.EQ("TABLE_NAME", t.Name), - )). - Query() - if err := tx.Query(ctx, query, args, rows); err != nil { - return fmt.Errorf("mysql: query auto_increment %w", err) - } - // Call Close in cases of failures (Close is idempotent). - defer rows.Close() - actual := &sql.NullInt64{} - if err := sql.ScanOne(rows, actual); err != nil { - return fmt.Errorf("mysql: scan auto_increment %w", err) - } - if err := rows.Close(); err != nil { - return err - } - // Table is empty and auto-increment is not configured. This can happen - // because MySQL (< 8.0) stores the auto-increment counter in main memory - // (not persistent), and the value is reset on restart (if table is empty). - if actual.Int64 <= 1 { - return d.setRange(ctx, tx, t, expected) - } - return nil -} - -// tBuilder returns the MySQL DSL query for table creation. -func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder { - b := sql.CreateTable(t.Name).IfNotExists() - for _, c := range t.Columns { - b.Column(d.addColumn(c)) - } - for _, pk := range t.PrimaryKey { - b.PrimaryKey(pk.Name) - } - // Charset and collation config on MySQL table. - // These options can be overridden by the entsql annotation. - b.Charset("utf8mb4").Collate("utf8mb4_bin") - if t.Annotation != nil { - if charset := t.Annotation.Charset; charset != "" { - b.Charset(charset) - } - if collate := t.Annotation.Collation; collate != "" { - b.Collate(collate) - } - if opts := t.Annotation.Options; opts != "" { - b.Options(opts) - } - addChecks(b, t.Annotation) - } - return b -} - -// cType returns the MySQL string type for the given column. -func (d *MySQL) cType(c *Column) (t string) { - if c.SchemaType != nil && c.SchemaType[dialect.MySQL] != "" { - // MySQL returns the column type lower cased. - return strings.ToLower(c.SchemaType[dialect.MySQL]) - } - switch c.Type { - case field.TypeBool: - t = "boolean" - case field.TypeInt8: - t = "tinyint" - case field.TypeUint8: - t = "tinyint unsigned" - case field.TypeInt16: - t = "smallint" - case field.TypeUint16: - t = "smallint unsigned" - case field.TypeInt32: - t = "int" - case field.TypeUint32: - t = "int unsigned" - case field.TypeInt, field.TypeInt64: - t = "bigint" - case field.TypeUint, field.TypeUint64: - t = "bigint unsigned" - case field.TypeBytes: - size := int64(math.MaxUint16) - if c.Size > 0 { - size = c.Size - } - switch { - case size <= math.MaxUint8: - t = "tinyblob" - case size <= math.MaxUint16: - t = "blob" - case size < 1<<24: - t = "mediumblob" - case size <= math.MaxUint32: - t = "longblob" - } - case field.TypeJSON: - t = "json" - if compareVersions(d.version, "5.7.8") == -1 { - t = "longblob" - } - case field.TypeString: - size := c.Size - if size == 0 { - size = d.defaultSize(c) - } - switch { - case c.typ == "tinytext", c.typ == "text": - t = c.typ - case size <= math.MaxUint16: - t = fmt.Sprintf("varchar(%d)", size) - case size == 1<<24-1: - t = "mediumtext" - default: - t = "longtext" - } - case field.TypeFloat32, field.TypeFloat64: - t = c.scanTypeOr("double") - case field.TypeTime: - t = c.scanTypeOr("timestamp") - // In MariaDB or in MySQL < v8.0.2, the TIMESTAMP column has both `DEFAULT CURRENT_TIMESTAMP` - // and `ON UPDATE CURRENT_TIMESTAMP` if neither is specified explicitly. this behavior is - // suppressed if the column is defined with a `DEFAULT` clause or with the `NULL` attribute. - if _, maria := d.mariadb(); maria || compareVersions(d.version, "8.0.2") == -1 && c.Default == nil { - c.Nullable = c.Attr == "" - } - case field.TypeEnum: - values := make([]string, len(c.Enums)) - for i, e := range c.Enums { - values[i] = fmt.Sprintf("'%s'", e) - } - t = fmt.Sprintf("enum(%s)", strings.Join(values, ", ")) - case field.TypeUUID: - t = "char(36) binary" - if d.supportsUUID() { - t = "uuid" - } - case field.TypeOther: - t = c.typ - default: - panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name)) - } - return t -} - -// addColumn returns the DSL query for adding the given column to a table. -// The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable]. -func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder { - b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr) - c.unique(b) - if c.Increment { - b.Attr("AUTO_INCREMENT") - } - c.nullable(b) - c.defaultValue(b) - if c.Collation != "" { - b.Attr("COLLATE " + c.Collation) - } - if c.Type == field.TypeJSON { - // Manually add a `CHECK` clause for older versions of MariaDB for validating the - // JSON documents. This constraint is automatically included from version 10.4.3. - if version, ok := d.mariadb(); ok && compareVersions(version, "10.4.3") == -1 { - b.Check(func(b *sql.Builder) { - b.WriteString("JSON_VALID(").Ident(c.Name).WriteByte(')') - }) - } - } - return b -} - -// addIndex returns the querying for adding an index to MySQL. -func (d *MySQL) addIndex(i *Index, table string) *sql.IndexBuilder { - idx := sql.CreateIndex(i.Name).Table(table) - if i.Unique { - idx.Unique() - } - parts := indexParts(i) - for _, c := range i.Columns { - part, ok := parts[c.Name] - if !ok || part == 0 { - idx.Column(c.Name) - } else { - idx.Column(fmt.Sprintf("%s(%d)", idx.Builder.Quote(c.Name), part)) - } - } - return idx -} - -// dropIndex drops a MySQL index. -func (d *MySQL) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error { - query, args := idx.DropBuilder(table).Query() - return tx.Exec(ctx, query, args, nil) -} - -// prepare runs preparation work that needs to be done to apply the change-set. -func (d *MySQL) prepare(ctx context.Context, tx dialect.Tx, change *changes, table string) error { - for _, idx := range change.index.drop { - switch n := len(idx.columns); { - case n == 0: - return fmt.Errorf("index %q has no columns", idx.Name) - case n > 1: - continue // not a foreign-key index. - } - var qr sql.Querier - Switch: - switch col, ok := change.dropColumn(idx.columns[0]); { - // If both the index and the column need to be dropped, the foreign-key - // constraint that is associated with them need to be dropped as well. - case ok: - names, err := d.fkNames(ctx, tx, table, col.Name) - if err != nil { - return err - } - if len(names) == 1 { - qr = sql.AlterTable(table).DropForeignKey(names[0]) - } - // If the uniqueness was dropped from a foreign-key column, - // create a "simple index" if no other index exist for it. - case !ok && idx.Unique && len(idx.Columns) > 0: - col := idx.Columns[0] - for _, idx2 := range col.indexes { - if idx2 != idx && len(idx2.columns) == 1 { - break Switch - } - } - names, err := d.fkNames(ctx, tx, table, col.Name) - if err != nil { - return err - } - if len(names) == 1 { - qr = sql.CreateIndex(names[0]).Table(table).Columns(col.Name) - } - } - if qr != nil { - query, args := qr.Query() - if err := tx.Exec(ctx, query, args, nil); err != nil { - return err - } - } - } - return nil -} - -// scanColumn scans the column information from MySQL column description. -func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error { - var ( - nullable sql.NullString - defaults sql.NullString - numericPrecision sql.NullInt64 - numericScale sql.NullInt64 - ) - if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &sql.NullString{}, &sql.NullString{}, &numericPrecision, &numericScale); err != nil { - return fmt.Errorf("scanning column description: %w", err) - } - c.Unique = c.UniqueKey() - if nullable.Valid { - c.Nullable = nullable.String == "YES" - } - if c.typ == "" { - return fmt.Errorf("missing type information for column %q", c.Name) - } - parts, size, unsigned, err := parseColumn(c.typ) - if err != nil { - return err - } - switch parts[0] { - case "mediumint", "int": - c.Type = field.TypeInt32 - if unsigned { - c.Type = field.TypeUint32 - } - case "smallint": - c.Type = field.TypeInt16 - if unsigned { - c.Type = field.TypeUint16 - } - case "bigint": - c.Type = field.TypeInt64 - if unsigned { - c.Type = field.TypeUint64 - } - case "tinyint": - switch { - case size == 1: - c.Type = field.TypeBool - case unsigned: - c.Type = field.TypeUint8 - default: - c.Type = field.TypeInt8 - } - case "double", "float": - c.Type = field.TypeFloat64 - case "numeric", "decimal": - c.Type = field.TypeFloat64 - // If precision is specified then we should take that into account. - if numericPrecision.Valid { - schemaType := fmt.Sprintf("%s(%d,%d)", parts[0], numericPrecision.Int64, numericScale.Int64) - c.SchemaType = map[string]string{dialect.MySQL: schemaType} - } - case "time", "timestamp", "date", "datetime": - c.Type = field.TypeTime - // The mapping from schema defaults to database - // defaults is not supported for TypeTime fields. - defaults = sql.NullString{} - case "tinyblob": - c.Size = math.MaxUint8 - c.Type = field.TypeBytes - case "blob": - c.Size = math.MaxUint16 - c.Type = field.TypeBytes - case "mediumblob": - c.Size = 1<<24 - 1 - c.Type = field.TypeBytes - case "longblob": - c.Size = math.MaxUint32 - c.Type = field.TypeBytes - case "binary", "varbinary": - c.Type = field.TypeBytes - c.Size = size - case "varchar": - c.Type = field.TypeString - c.Size = size - case "text": - c.Size = math.MaxUint16 - c.Type = field.TypeString - case "mediumtext": - c.Size = 1<<24 - 1 - c.Type = field.TypeString - case "longtext": - c.Size = math.MaxInt32 - c.Type = field.TypeString - case "json": - c.Type = field.TypeJSON - case "enum": - c.Type = field.TypeEnum - // Parse the enum values according to the MySQL format. - // github.com/mysql/mysql-server/blob/8.0/sql/field.cc#Field_enum::sql_type - values := strings.TrimSuffix(strings.TrimPrefix(c.typ, "enum("), ")") - if values == "" { - return fmt.Errorf("mysql: unexpected enum type: %q", c.typ) - } - parts := strings.Split(values, "','") - for i := range parts { - c.Enums = append(c.Enums, strings.Trim(parts[i], "'")) - } - case "char": - c.Type = field.TypeOther - // UUID field has length of 36 characters (32 alphanumeric characters and 4 hyphens). - if size == 36 { - c.Type = field.TypeUUID - } - case "point", "geometry", "linestring", "polygon": - c.Type = field.TypeOther - default: - return fmt.Errorf("unknown column type %q for version %q", parts[0], d.version) - } - if defaults.Valid { - return c.ScanDefault(defaults.String) - } - return nil -} - -// scanIndexes scans sql.Rows into an Indexes list. The query for returning the rows, -// should return the following 5 columns: INDEX_NAME, COLUMN_NAME, SUB_PART, NON_UNIQUE, -// SEQ_IN_INDEX. SEQ_IN_INDEX specifies the position of the column in the index columns. -func (d *MySQL) scanIndexes(rows *sql.Rows, t *Table) (Indexes, error) { - var ( - i Indexes - names = make(map[string]*Index) - ) - for rows.Next() { - var ( - name string - column string - nonuniq bool - seqindex int - subpart sql.NullInt64 - ) - if err := rows.Scan(&name, &column, &subpart, &nonuniq, &seqindex); err != nil { - return nil, fmt.Errorf("scanning index description: %w", err) - } - // Skip primary keys. - if name == "PRIMARY" { - c, ok := t.column(column) - if !ok { - return nil, fmt.Errorf("missing primary-key column: %q", column) - } - t.PrimaryKey = append(t.PrimaryKey, c) - continue - } - idx, ok := names[name] - if !ok { - idx = &Index{Name: name, Unique: !nonuniq, Annotation: &entsql.IndexAnnotation{}} - i = append(i, idx) - names[name] = idx - } - idx.columns = append(idx.columns, column) - if subpart.Int64 > 0 { - if idx.Annotation.PrefixColumns == nil { - idx.Annotation.PrefixColumns = make(map[string]uint) - } - idx.Annotation.PrefixColumns[column] = uint(subpart.Int64) - } - } - if err := rows.Err(); err != nil { - return nil, err - } - return i, nil -} - -// isImplicitIndex reports if the index was created implicitly for the unique column. -func (d *MySQL) isImplicitIndex(idx *Index, col *Column) bool { - // We execute `CHANGE COLUMN` on older versions of MySQL (<8.0), which - // auto create the new index. The old one, will be dropped in `changeSet`. - if compareVersions(d.version, "8.0.0") >= 0 { - return idx.Name == col.Name && col.Unique - } - return false -} - -// renameColumn returns the statement for renaming a column in -// MySQL based on its version. -func (d *MySQL) renameColumn(t *Table, old, new *Column) sql.Querier { - q := sql.AlterTable(t.Name) - if compareVersions(d.version, "8.0.0") >= 0 { - return q.RenameColumn(old.Name, new.Name) - } - return q.ChangeColumn(old.Name, d.addColumn(new)) -} - -// renameIndex returns the statement for renaming an index. -func (d *MySQL) renameIndex(t *Table, old, new *Index) sql.Querier { - q := sql.AlterTable(t.Name) - if compareVersions(d.version, "5.7.0") >= 0 { - return q.RenameIndex(old.Name, new.Name) - } - return q.DropIndex(old.Name).AddIndex(new.Builder(t.Name)) -} - // matchSchema returns the predicate for matching table schema. func (d *MySQL) matchSchema(columns ...string) *sql.Predicate { column := "TABLE_SCHEMA" @@ -597,196 +70,6 @@ func (d *MySQL) matchSchema(columns ...string) *sql.Predicate { return sql.EQ(column, sql.Raw("(SELECT DATABASE())")) } -// tables returns the query for getting the in the schema. -func (d *MySQL) tables() sql.Querier { - return sql.Select("TABLE_NAME"). - From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")). - Where(d.matchSchema()) -} - -// alterColumns returns the queries for applying the columns change-set. -func (d *MySQL) alterColumns(table string, add, modify, drop []*Column) sql.Queries { - b := sql.Dialect(dialect.MySQL).AlterTable(table) - for _, c := range add { - b.AddColumn(d.addColumn(c)) - } - for _, c := range modify { - b.ModifyColumn(d.addColumn(c)) - } - for _, c := range drop { - b.DropColumn(sql.Dialect(dialect.MySQL).Column(c.Name)) - } - if len(b.Queries) == 0 { - return nil - } - return sql.Queries{b} -} - -// normalizeJSON normalize MariaDB longtext columns to type JSON. -func (d *MySQL) normalizeJSON(ctx context.Context, tx dialect.Tx, t *Table) error { - columns := make(map[string]*Column) - for _, c := range t.Columns { - if c.typ == "longtext" { - columns[c.Name] = c - } - } - if len(columns) == 0 { - return nil - } - rows := &sql.Rows{} - query, args := sql.Select("CONSTRAINT_NAME"). - From(sql.Table("CHECK_CONSTRAINTS").Schema("INFORMATION_SCHEMA")). - Where(sql.And( - d.matchSchema("CONSTRAINT_SCHEMA"), - sql.EQ("TABLE_NAME", t.Name), - sql.Like("CHECK_CLAUSE", "json_valid(%)"), - )). - Query() - if err := tx.Query(ctx, query, args, rows); err != nil { - return fmt.Errorf("mysql: query table constraints %w", err) - } - // Call Close in cases of failures (Close is idempotent). - defer rows.Close() - names := make([]string, 0, len(columns)) - if err := sql.ScanSlice(rows, &names); err != nil { - return fmt.Errorf("mysql: scan table constraints: %w", err) - } - if err := rows.Err(); err != nil { - return err - } - if err := rows.Close(); err != nil { - return err - } - for _, name := range names { - c, ok := columns[name] - if ok { - c.Type = field.TypeJSON - } - } - return nil -} - -// mariadb reports if the migration runs on MariaDB and returns the semver string. -func (d *MySQL) mariadb() (string, bool) { - idx := strings.Index(d.version, "MariaDB") - if idx == -1 { - return "", false - } - return d.version[:idx-1], true -} - -// parseColumn returns column parts, size and signed-info from a MySQL type. -func parseColumn(typ string) (parts []string, size int64, unsigned bool, err error) { - switch parts = strings.FieldsFunc(typ, func(r rune) bool { - return r == '(' || r == ')' || r == ' ' || r == ',' - }); parts[0] { - case "tinyint", "smallint", "mediumint", "int", "bigint": - switch { - case len(parts) == 2 && parts[1] == "unsigned": // int unsigned - unsigned = true - case len(parts) == 3: // int(10) unsigned - unsigned = true - fallthrough - case len(parts) == 2: // int(10) - size, err = strconv.ParseInt(parts[1], 10, 0) - } - case "varbinary", "varchar", "char", "binary": - if len(parts) > 1 { - size, err = strconv.ParseInt(parts[1], 10, 64) - } - } - if err != nil { - return parts, size, unsigned, fmt.Errorf("converting %s size to int: %w", parts[0], err) - } - return parts, size, unsigned, nil -} - -// fkNames returns the foreign-key names of a column. -func (d *MySQL) fkNames(ctx context.Context, tx dialect.Tx, table, column string) ([]string, error) { - query, args := sql.Select("CONSTRAINT_NAME").From(sql.Table("KEY_COLUMN_USAGE").Schema("INFORMATION_SCHEMA")). - Where(sql.And( - sql.EQ("TABLE_NAME", table), - sql.EQ("COLUMN_NAME", column), - // NULL for unique and primary-key constraints. - sql.NotNull("POSITION_IN_UNIQUE_CONSTRAINT"), - d.matchSchema(), - )). - Query() - var ( - names []string - rows = &sql.Rows{} - ) - if err := tx.Query(ctx, query, args, rows); err != nil { - return nil, fmt.Errorf("mysql: reading constraint names %w", err) - } - defer rows.Close() - if err := sql.ScanSlice(rows, &names); err != nil { - return nil, err - } - return names, nil -} - -// defaultSize returns the default size for MySQL/MariaDB varchar type -// based on column size, charset and table indexes, in order to avoid -// index prefix key limit (767) for older versions of MySQL/MariaDB. -func (d *MySQL) defaultSize(c *Column) int64 { - size := DefaultStringLen - version, checked := d.version, "5.7.0" - if v, ok := d.mariadb(); ok { - version, checked = v, "10.2.2" - } - switch { - // Version is >= 5.7 for MySQL, or >= 10.2.2 for MariaDB. - case compareVersions(version, checked) != -1: - // Column is non-unique, or not part of any index (reaching - // the error 1071). - case !c.Unique && len(c.indexes) == 0 && !c.PrimaryKey(): - default: - size = 191 - } - return size -} - -// needsConversion reports if column "old" needs to be converted -// (by table altering) to column "new". -func (d *MySQL) needsConversion(old, new *Column) bool { - return d.cType(old) != d.cType(new) -} - -// indexModified used by the migration differ to check if the index was modified. -func (d *MySQL) indexModified(old, new *Index) bool { - oldParts, newParts := indexParts(old), indexParts(new) - if len(oldParts) != len(newParts) { - return true - } - for column, oldPart := range oldParts { - newPart, ok := newParts[column] - if !ok || oldPart != newPart { - return true - } - } - return false -} - -// indexParts returns a map holding the sub_part mapping if exists. -func indexParts(idx *Index) map[string]uint { - parts := make(map[string]uint) - if idx.Annotation == nil { - return parts - } - // If prefix (without a name) was defined on the - // annotation, map it to the single column index. - if idx.Annotation.Prefix > 0 && len(idx.Columns) == 1 { - parts[idx.Columns[0].Name] = idx.Annotation.Prefix - } - for column, part := range idx.Annotation.PrefixColumns { - parts[column] = part - } - return parts -} - -// Atlas integration. - func (d *MySQL) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) { return mysql.Open(&db{ExecQuerier: conn}) } @@ -988,23 +271,56 @@ func (d *MySQL) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) error return nil } -func indexType(idx *Index, d string) (string, bool) { - ant := idx.Annotation - if ant == nil { - return "", false - } - if ant.Types != nil && ant.Types[d] != "" { - return ant.Types[d], true - } - if ant.Type != "" { - return ant.Type, true - } - return "", false -} - -func (MySQL) atTypeRangeSQL(ts ...string) string { +func (*MySQL) atTypeRangeSQL(ts ...string) string { for i := range ts { ts[i] = fmt.Sprintf("('%s')", ts[i]) } return fmt.Sprintf("INSERT INTO `%s` (`type`) VALUES %s", TypeTable, strings.Join(ts, ", ")) } + +// mariadb reports if the migration runs on MariaDB and returns the semver string. +func (d *MySQL) mariadb() (string, bool) { + idx := strings.Index(d.version, "MariaDB") + if idx == -1 { + return "", false + } + return d.version[:idx-1], true +} + +// defaultSize returns the default size for MySQL/MariaDB varchar type +// based on column size, charset and table indexes, in order to avoid +// index prefix key limit (767) for older versions of MySQL/MariaDB. +func (d *MySQL) defaultSize(c *Column) int64 { + size := DefaultStringLen + version, checked := d.version, "5.7.0" + if v, ok := d.mariadb(); ok { + version, checked = v, "10.2.2" + } + switch { + // Version is >= 5.7 for MySQL, or >= 10.2.2 for MariaDB. + case compareVersions(version, checked) != -1: + // Column is non-unique, or not part of any index (reaching + // the error 1071). + case !c.Unique && len(c.indexes) == 0 && !c.PrimaryKey(): + default: + size = 191 + } + return size +} + +// indexParts returns a map holding the sub_part mapping if exists. +func indexParts(idx *Index) map[string]uint { + parts := make(map[string]uint) + if idx.Annotation == nil { + return parts + } + // If prefix (without a name) was defined on the + // annotation, map it to the single column index. + if idx.Annotation.Prefix > 0 && len(idx.Columns) == 1 { + parts[idx.Columns[0].Name] = idx.Annotation.Prefix + } + for column, part := range idx.Annotation.PrefixColumns { + parts[column] = part + } + return parts +} diff --git a/dialect/sql/schema/mysql_test.go b/dialect/sql/schema/mysql_test.go deleted file mode 100644 index e93eff605..000000000 --- a/dialect/sql/schema/mysql_test.go +++ /dev/null @@ -1,1410 +0,0 @@ -// Copyright 2019-present Facebook Inc. All rights reserved. -// This source code is licensed under the Apache 2.0 license found -// in the LICENSE file in the root directory of this source tree. - -package schema - -import ( - "context" - "math" - "regexp" - "strings" - "testing" - - "entgo.io/ent/dialect" - "entgo.io/ent/dialect/entsql" - "entgo.io/ent/dialect/sql" - "entgo.io/ent/schema/field" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/require" -) - -func TestMySQL_Create(t *testing.T) { - tests := []struct { - name string - tables []*Table - options []MigrateOption - before func(mysqlMock) - wantErr bool - }{ - { - name: "tx failed", - before: func(mock mysqlMock) { - mock.ExpectBegin(). - WillReturnError(sqlmock.ErrCancelled) - }, - wantErr: true, - }, - { - name: "no tables", - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.ExpectCommit() - }, - }, - { - name: "create new table", - tables: []*Table{ - { - Name: "users", - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "age", Type: field.TypeInt}, - {Name: "doc", Type: field.TypeJSON, Nullable: true}, - {Name: "enums", Type: field.TypeEnum, Enums: []string{"a", "b"}}, - {Name: "uuid", Type: field.TypeUUID, Nullable: true}, - {Name: "ts", Type: field.TypeTime}, - {Name: "ts_default", Type: field.TypeTime, Default: "CURRENT_TIMESTAMP"}, - {Name: "datetime", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "datetime"}, Default: "CURRENT_TIMESTAMP"}, - {Name: "decimal", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.MySQL: "decimal(6,2)"}}, - {Name: "unsigned decimal", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.MySQL: "decimal(6,2) unsigned"}}, - {Name: "float", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.MySQL: "float"}, Default: "0"}, - }, - Annotation: &entsql.Annotation{ - Charset: "utf8", - Collation: "utf8_general_ci", - Options: "ENGINE = INNODB", - Check: "price > 0", - Checks: map[string]string{ - "valid_age": "age > 0", - "valid_name": "name <> ''", - }, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.8") - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NULL, `age` bigint NOT NULL, `doc` json NULL, `enums` enum('a', 'b') NOT NULL, `uuid` char(36) binary NULL, `ts` timestamp NULL, `ts_default` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, `datetime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, `decimal` decimal(6,2) NOT NULL, `unsigned decimal` decimal(6,2) unsigned NOT NULL, `float` float NOT NULL DEFAULT '0', PRIMARY KEY(`id`), CHECK (price > 0), CONSTRAINT `valid_age` CHECK (age > 0), CONSTRAINT `valid_name` CHECK (name <> '')) CHARACTER SET utf8 COLLATE utf8_general_ci ENGINE = INNODB")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "create new table with specific field collation", - tables: []*Table{ - { - Name: "users", - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "address", Type: field.TypeString, Nullable: true, Collation: "utf8_unicode_ci"}, - {Name: "age", Type: field.TypeInt}, - {Name: "doc", Type: field.TypeJSON, Nullable: true}, - {Name: "enums", Type: field.TypeEnum, Enums: []string{"a", "b"}}, - {Name: "uuid", Type: field.TypeUUID, Nullable: true}, - {Name: "datetime", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "datetime"}, Default: "CURRENT_TIMESTAMP"}, - {Name: "decimal", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.MySQL: "decimal(6,2)"}}, - }, - Annotation: &entsql.Annotation{ - Charset: "utf8", - Collation: "utf8_general_ci", - Options: "ENGINE = INNODB", - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.33") - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NULL, `address` varchar(255) NULL COLLATE utf8_unicode_ci, `age` bigint NOT NULL, `doc` json NULL, `enums` enum('a', 'b') NOT NULL, `uuid` char(36) binary NULL, `datetime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, `decimal` decimal(6,2) NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8 COLLATE utf8_general_ci ENGINE = INNODB")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "create new table 5.6", - tables: []*Table{ - { - Name: "users", - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "age", Type: field.TypeInt}, - {Name: "name", Type: field.TypeString, Unique: true}, - {Name: "doc", Type: field.TypeJSON, Nullable: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.6.35") - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `age` bigint NOT NULL, `name` varchar(191) UNIQUE NOT NULL, `doc` longblob NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "create new table with foreign key", - tables: func() []*Table { - var ( - c1 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "created_at", Type: field.TypeTime}, - } - c2 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString}, - {Name: "owner_id", Type: field.TypeInt, Nullable: true}, - } - t1 = &Table{ - Name: "users", - Columns: c1, - PrimaryKey: c1[0:1], - } - t2 = &Table{ - Name: "pets", - Columns: c2, - PrimaryKey: c2[0:1], - ForeignKeys: []*ForeignKey{ - { - Symbol: "pets_owner", - Columns: c2[2:], - RefTable: t1, - RefColumns: c1[0:1], - OnDelete: Cascade, - }, - }, - } - ) - return []*Table{t1, t2} - }(), - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` timestamp NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("pets", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` bigint NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.fkExists("pets_owner", false) - mock.ExpectExec(escape("ALTER TABLE `pets` ADD CONSTRAINT `pets_owner` FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "create new table with foreign key disabled", - options: []MigrateOption{ - WithForeignKeys(false), - }, - tables: func() []*Table { - var ( - c1 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "created_at", Type: field.TypeTime}, - } - c2 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString}, - {Name: "owner_id", Type: field.TypeInt, Nullable: true}, - } - t1 = &Table{ - Name: "users", - Columns: c1, - PrimaryKey: c1[0:1], - } - t2 = &Table{ - Name: "pets", - Columns: c2, - PrimaryKey: c2[0:1], - ForeignKeys: []*ForeignKey{ - { - Symbol: "pets_owner", - Columns: c2[2:], - RefTable: t1, - RefColumns: c1[0:1], - OnDelete: Cascade, - }, - }, - } - ) - return []*Table{t1, t2} - }(), - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` timestamp NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("pets", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` bigint NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add columns to table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, - {Name: "mediumtext", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.MySQL: "mediumtext"}}, - {Name: "uuid", Type: field.TypeUUID, Nullable: true}, - {Name: "date", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{dialect.MySQL: "date"}}, - {Name: "age", Type: field.TypeInt}, - {Name: "tiny", Type: field.TypeInt8}, - {Name: "tiny_unsigned", Type: field.TypeUint8}, - {Name: "small", Type: field.TypeInt16}, - {Name: "small_unsigned", Type: field.TypeUint16}, - {Name: "big", Type: field.TypeInt64}, - {Name: "big_unsigned", Type: field.TypeUint64}, - {Name: "decimal", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.MySQL: "decimal(6,2)"}}, - {Name: "unsigned_decimal", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.MySQL: "decimal(6,2) unsigned"}}, - {Name: "ts", Type: field.TypeTime}, - {Name: "timestamp", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "TIMESTAMP"}, Default: "CURRENT_TIMESTAMP"}, - {Name: "float", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.MySQL: "float"}, Default: "0"}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("8.0.19") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). - AddRow("text", "longtext", "YES", "YES", "NULL", "", "", "", nil, nil). - AddRow("mediumtext", "mediumtext", "YES", "YES", "NULL", "", "", "", nil, nil). - AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin", nil, nil). - AddRow("date", "date", "YES", "YES", "NULL", "", "", "", nil, nil). - // 8.0.19: new int column type formats - AddRow("tiny", "tinyint", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("tiny_unsigned", "tinyint unsigned", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("small", "smallint", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("small_unsigned", "smallint unsigned", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("big", "bigint", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("big_unsigned", "bigint unsigned", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("decimal", "decimal(6,2)", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("unsigned_decimal", "decimal(6,2) unsigned", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("timestamp", "timestamp", "NO", "NO", "CURRENT_TIMESTAMP", "DEFAULT_GENERATED on update CURRENT_TIMESTAMP", "", "", nil, nil). - AddRow("float", "float", "NO", "NO", "0", "0", "", "", nil, nil)) - - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` bigint NOT NULL, ADD COLUMN `ts` timestamp NOT NULL, MODIFY COLUMN `timestamp` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "enums", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "enums1", Type: field.TypeEnum, Enums: []string{"a", "b"}}, // add enum. - {Name: "enums2", Type: field.TypeEnum, Enums: []string{"a"}}, // remove enum. - {Name: "enums3", Type: field.TypeEnum, Enums: []string{"a", "b c"}}, // no changes. - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). - AddRow("enums1", "enum('a')", "YES", "NO", "NULL", "", "", "", nil, nil). - AddRow("enums2", "enum('b', 'a')", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("enums3", "enum('a', 'b c')", "NO", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `enums1` enum('a', 'b') NOT NULL, MODIFY COLUMN `enums2` enum('a') NOT NULL")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "datetime and timestamp", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "datetime"}, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "datetime"}, Nullable: true}, - {Name: "deleted_at", Type: field.TypeTime, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("created_at", "datetime", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("updated_at", "timestamp", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("deleted_at", "datetime", "NO", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `updated_at` datetime NULL, MODIFY COLUMN `deleted_at` timestamp NULL")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add int column with default value to table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "age", Type: field.TypeInt, Default: 10}, - {Name: "doc", Type: field.TypeJSON, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.6.0") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). - AddRow("doc", "longblob", "YES", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` bigint NOT NULL DEFAULT 10")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add blob columns", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "tiny", Type: field.TypeBytes, Size: 100}, - {Name: "blob", Type: field.TypeBytes, Size: 1e3}, - {Name: "medium", Type: field.TypeBytes, Size: 1e5}, - {Name: "long", Type: field.TypeBytes, Size: 1e8}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `tiny` tinyblob NOT NULL, ADD COLUMN `blob` blob NOT NULL, ADD COLUMN `medium` mediumblob NOT NULL, ADD COLUMN `long` longblob NOT NULL")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add binary column", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "binary", Type: field.TypeBytes, Size: 20, SchemaType: map[string]string{dialect.MySQL: "binary(20)"}}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("8.0.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `binary` binary(20) NOT NULL")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "accept varbinary columns", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "tiny", Type: field.TypeBytes, Size: 100}, - {Name: "medium", Type: field.TypeBytes, Size: math.MaxUint32}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("tiny", "varbinary(255)", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("medium", "varbinary(255)", "NO", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `medium` longblob NOT NULL")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add float column with default value to table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "age", Type: field.TypeFloat64, Default: 10.1}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec("ALTER TABLE `users` ADD COLUMN `age` double NOT NULL DEFAULT 10.1"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add bool column with default value", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "age", Type: field.TypeBool, Default: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec("ALTER TABLE `users` ADD COLUMN `age` boolean NOT NULL DEFAULT true"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add string column with default value", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "nick", Type: field.TypeString, Default: "unknown"}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `nick` varchar(255) NOT NULL DEFAULT 'unknown'")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add column with unsupported default value", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "nick", Type: field.TypeString, Size: 1 << 17, Default: "unknown"}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `nick` longtext NOT NULL")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "drop columns", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - options: []MigrateOption{WithDropColumn(true)}, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` DROP COLUMN `name`")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "modify column to nullable", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "age", Type: field.TypeInt}, - {Name: "name", Type: field.TypeString, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil). - AddRow("age", "bigint(20)", "NO", "NO", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `name` varchar(255) NULL")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "apply uniqueness on column", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "age", Type: field.TypeInt, Unique: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("age", "bigint(20)", "NO", "", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - // create the unique index. - mock.ExpectExec(escape("CREATE UNIQUE INDEX `age` ON `users`(`age`)")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "remove uniqueness from column without option", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "age", Type: field.TypeInt}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("age", "bigint(20)", "NO", "UNI", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1"). - AddRow("age", "age", nil, "0", "1")) - mock.ExpectCommit() - }, - }, - { - name: "remove uniqueness from column with option", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "age", Type: field.TypeInt}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - options: []MigrateOption{WithDropIndex(true)}, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("age", "bigint(20)", "NO", "UNI", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1"). - AddRow("age", "age", nil, "0", "1")) - // check if a foreign-key needs to be dropped. - mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`KEY_COLUMN_USAGE` WHERE `TABLE_NAME` = ? AND `COLUMN_NAME` = ? AND `POSITION_IN_UNIQUE_CONSTRAINT` IS NOT NULL AND `TABLE_SCHEMA` = (SELECT DATABASE())")). - WithArgs("users", "age"). - WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"})) - // drop the unique index. - mock.ExpectExec(escape("DROP INDEX `age` ON `users`")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "increase index sub_part", - tables: func() []*Table { - t := &Table{ - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "text", Type: field.TypeString, Size: math.MaxInt32, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - Indexes: []*Index{ - {Name: "prefix_text", Annotation: &entsql.IndexAnnotation{Prefix: 100}}, - }, - } - t.Indexes[0].Columns = t.Columns[1:] - return []*Table{t} - }(), - options: []MigrateOption{WithDropIndex(true)}, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("text", "longtext", "YES", "NO", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1"). - AddRow("prefix_text", "text", "50", "0", "1")) - mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`KEY_COLUMN_USAGE` WHERE `TABLE_NAME` = ? AND `COLUMN_NAME` = ? AND `POSITION_IN_UNIQUE_CONSTRAINT` IS NOT NULL AND `TABLE_SCHEMA` = (SELECT DATABASE())")). - WithArgs("users", "text"). - WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"})) - // modify index by dropping and creating it. - mock.ExpectExec(escape("DROP INDEX `prefix_text` ON `users`")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(escape("CREATE INDEX `prefix_text` ON `users`(`text`(100))")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "ignore foreign keys on index dropping", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "parent_id", Type: field.TypeInt, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - ForeignKeys: []*ForeignKey{ - { - Symbol: "parent_id", - Columns: []*Column{ - {Name: "parent_id", Type: field.TypeInt, Nullable: true}, - }, - }, - }, - }, - }, - options: []MigrateOption{WithDropIndex(true)}, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("parent_id", "bigint(20)", "YES", "NULL", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1"). - AddRow("old_index", "old", nil, "0", "1"). - AddRow("parent_id", "parent_id", nil, "0", "1")) - // drop the unique index. - mock.ExpectExec(escape("DROP INDEX `old_index` ON `users`")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // foreign key already exist. - mock.fkExists("parent_id", true) - mock.ExpectCommit() - }, - }, - { - name: "drop foreign key with column and index", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - options: []MigrateOption{WithDropIndex(true), WithDropColumn(true)}, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("parent_id", "bigint(20)", "YES", "NULL", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1"). - AddRow("parent_id", "parent_id", nil, "0", "1")) - // check if a foreign-key needs to be dropped. - mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`KEY_COLUMN_USAGE` WHERE `TABLE_NAME` = ? AND `COLUMN_NAME` = ? AND `POSITION_IN_UNIQUE_CONSTRAINT` IS NOT NULL AND `TABLE_SCHEMA` = (SELECT DATABASE())")). - WithArgs("users", "parent_id"). - WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"}).AddRow("users_parent_id")) - mock.ExpectExec(escape("ALTER TABLE `users` DROP FOREIGN KEY `users_parent_id`")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // drop the unique index. - mock.ExpectExec(escape("DROP INDEX `parent_id` ON `users`")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // drop the unique index. - mock.ExpectExec(escape("ALTER TABLE `users` DROP COLUMN `parent_id`")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "create a new simple-index for the foreign-key", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "parent_id", Type: field.TypeInt, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - options: []MigrateOption{WithDropIndex(true), WithDropColumn(true)}, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("parent_id", "bigint(20)", "YES", "NULL", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1"). - AddRow("parent_id", "parent_id", nil, "0", "1")) - // check if there's a foreign-key that is associated with this index. - mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`KEY_COLUMN_USAGE` WHERE `TABLE_NAME` = ? AND `COLUMN_NAME` = ? AND `POSITION_IN_UNIQUE_CONSTRAINT` IS NOT NULL AND `TABLE_SCHEMA` = (SELECT DATABASE())")). - WithArgs("users", "parent_id"). - WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"}).AddRow("users_parent_id")) - // create a new index, to replace the old one (that needs to be dropped). - mock.ExpectExec(escape("CREATE INDEX `users_parent_id` ON `users`(`parent_id`)")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // drop the unique index. - mock.ExpectExec(escape("DROP INDEX `parent_id` ON `users`")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add edge to table", - tables: func() []*Table { - var ( - c1 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "spouse_id", Type: field.TypeInt, Nullable: true}, - } - t1 = &Table{ - Name: "users", - Columns: c1, - PrimaryKey: c1[0:1], - ForeignKeys: []*ForeignKey{ - { - Symbol: "user_spouse" + strings.Repeat("_", 64), // super long fk. - Columns: c1[2:], - RefColumns: c1[0:1], - OnDelete: Cascade, - }, - }, - } - ) - t1.ForeignKeys[0].RefTable = t1 - return []*Table{t1} - }(), - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` bigint NULL")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.fkExists("user_spouse_____________________390ed76f91d3c57cd3516e7690f621dc", false) - mock.ExpectExec("ALTER TABLE `users` ADD CONSTRAINT `.{64}` FOREIGN KEY\\(`spouse_id`\\) REFERENCES `users`\\(`id`\\) ON DELETE CASCADE"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "universal id for all tables", - tables: []*Table{ - NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - }, - options: []MigrateOption{WithGlobalUniqueID(true)}, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("ent_types", false) - // create ent_types table. - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `ent_types`(`id` bigint unsigned AUTO_INCREMENT NOT NULL, `type` varchar(255) UNIQUE NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set users id range. - mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). - WithArgs("users"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 0")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectQuery(escape("SELECT COUNT(*) FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("groups"). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `groups`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set groups id range. - mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). - WithArgs("groups"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(escape("ALTER TABLE `groups` AUTO_INCREMENT = 4294967296")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "universal id for new tables", - tables: []*Table{ - NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - }, - options: []MigrateOption{WithGlobalUniqueID(true)}, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("ent_types", true) - // query ent_types table. - mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). - WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) - mock.tableExists("users", true) - // users table has no changes. - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - // query groups table. - mock.ExpectQuery(escape("SELECT COUNT(*) FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("groups"). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `groups`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set groups id range. - mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). - WithArgs("groups"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(escape("ALTER TABLE `groups` AUTO_INCREMENT = 4294967296")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "universal id for restored tables", - tables: []*Table{ - NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - }, - options: []MigrateOption{WithGlobalUniqueID(true)}, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("ent_types", true) - // query ent_types table. - mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). - WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set users id range (without inserting to ent_types). - mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 0")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("groups", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `groups`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set groups id range. - mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). - WithArgs("groups"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(escape("ALTER TABLE `groups` AUTO_INCREMENT = 4294967296")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "universal id mismatch with ent_types", - tables: []*Table{ - NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - }, - options: []MigrateOption{WithGlobalUniqueID(true)}, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("ent_types", true) - // query ent_types table. - mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). - WillReturnRows(sqlmock.NewRows([]string{"type"}). - AddRow("deleted"). - AddRow("users")) - mock.tableExists("users", true) - // users table has no changes. - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - // query the auto-increment value. - mock.ExpectQuery(escape("SELECT `AUTO_INCREMENT` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"AUTO_INCREMENT"}). - AddRow(1)) - // restore the auto-increment counter. - mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 4294967296")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "no modify numeric column", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.MySQL: "decimal(6,4)"}}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("price", "decimal(6,4)", "NO", "YES", "NULL", "", "", "", "6", "4")) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectCommit() - }, - }, - { - name: "modify numeric column", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.MySQL: "decimal(6,4)"}}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("5.7.23") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("price", "decimal(6,4)", "NO", "YES", "NULL", "", "", "", "5", "4")) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `price` decimal(6,4) NOT NULL")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - // MariaDB specific tests. - { - name: "mariadb/10.2.32/create table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "json", Type: field.TypeJSON, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("10.2.32-MariaDB") - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `json` json NULL CHECK (JSON_VALID(`json`)), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "mariadb/10.3.13/create table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "json", Type: field.TypeJSON, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("10.3.13-MariaDB-1:10.3.13+maria~bionic") - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `json` json NULL CHECK (JSON_VALID(`json`)), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "mariadb/10.5.8/create table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "json", Type: field.TypeJSON, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("10.5.8-MariaDB") - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `json` json NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "mariadb/10.5.8/table exists", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "json", Type: field.TypeJSON, Nullable: true}, - {Name: "longtext", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock mysqlMock) { - mock.start("10.5.8-MariaDB-1:10.5.8+maria~focal") - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). - AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). - AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). - AddRow("json", "longtext", "YES", "YES", "NULL", "", "utf8mb4", "utf8mb4_bin", nil, nil). - AddRow("longtext", "longtext", "YES", "YES", "NULL", "", "utf8mb4", "utf8mb4_bin", nil, nil)) - mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). - AddRow("PRIMARY", "id", nil, "0", "1")) - mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`CHECK_CONSTRAINTS` WHERE `CONSTRAINT_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? AND `CHECK_CLAUSE` LIKE ?")). - WithArgs("users", "json_valid(%)"). - WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"}). - AddRow("json")) - mock.ExpectCommit() - }, - }, - { - name: "mariadb/10.1.37/create table", - tables: []*Table{ - { - Name: "users", - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "age", Type: field.TypeInt}, - {Name: "name", Type: field.TypeString, Unique: true}, - }, - }, - }, - options: []MigrateOption{WithGlobalUniqueID(true)}, - before: func(mock mysqlMock) { - mock.start("10.1.48-MariaDB-1~bionic") - mock.tableExists("ent_types", false) - // create ent_types table. - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `ent_types`(`id` bigint unsigned AUTO_INCREMENT NOT NULL, `type` varchar(191) UNIQUE NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `age` bigint NOT NULL, `name` varchar(191) UNIQUE NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set users id range. - mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). - WithArgs("users"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 0")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db, mock, err := sqlmock.New() - require.NoError(t, err) - tt.before(mysqlMock{mock}) - migrate, err := NewMigrate(sql.OpenDB("mysql", db), append(tt.options, WithAtlas(false))...) - require.NoError(t, err) - err = migrate.Create(context.Background(), tt.tables...) - require.Equal(t, tt.wantErr, err != nil, err) - }) - } -} - -type mysqlMock struct { - sqlmock.Sqlmock -} - -func (m mysqlMock) start(version string) { - m.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")). - WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", version)) - m.ExpectBegin() -} - -func (m mysqlMock) tableExists(table string, exists bool) { - count := 0 - if exists { - count = 1 - } - m.ExpectQuery(escape("SELECT COUNT(*) FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). - WithArgs(table). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) -} - -func (m mysqlMock) fkExists(fk string, exists bool) { - count := 0 - if exists { - count = 1 - } - m.ExpectQuery(escape("SELECT COUNT(*) FROM `INFORMATION_SCHEMA`.`TABLE_CONSTRAINTS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")). - WithArgs("FOREIGN KEY", fk). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) -} - -func escape(query string) string { - rows := strings.Split(query, "\n") - for i := range rows { - rows[i] = strings.TrimPrefix(rows[i], " ") - } - query = strings.Join(rows, " ") - return strings.TrimSpace(regexp.QuoteMeta(query)) + "$" -} diff --git a/dialect/sql/schema/postgres.go b/dialect/sql/schema/postgres.go index b778b2e5f..133801087 100644 --- a/dialect/sql/schema/postgres.go +++ b/dialect/sql/schema/postgres.go @@ -10,7 +10,6 @@ import ( "reflect" "strconv" "strings" - "unicode" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" @@ -21,7 +20,7 @@ import ( "ariga.io/atlas/sql/schema" ) -// Postgres is a postgres migration driver. +// Postgres adapter for Atlas migration engine. type Postgres struct { dialect.Driver schema string @@ -67,453 +66,6 @@ func (d *Postgres) tableExist(ctx context.Context, conn dialect.ExecQuerier, nam return exist(ctx, conn, query, args...) } -// tableExist checks if a foreign-key exists in the current schema. -func (d *Postgres) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) { - query, args := sql.Dialect(dialect.Postgres). - Select(sql.Count("*")).From(sql.Table("table_constraints").Schema("information_schema")). - Where(sql.And( - d.matchSchema(), - sql.EQ("constraint_type", "FOREIGN KEY"), - sql.EQ("constraint_name", name), - )).Query() - return exist(ctx, tx, query, args...) -} - -// setRange sets restart the identity column to the given offset. Used by the universal-id option. -func (d *Postgres) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error { - if value == 0 { - value = 1 // RESTART value cannot be < 1. - } - pk := "id" - if len(t.PrimaryKey) == 1 { - pk = t.PrimaryKey[0].Name - } - return conn.Exec(ctx, fmt.Sprintf("ALTER TABLE %q ALTER COLUMN %q RESTART WITH %d", t.Name, pk, value), []any{}, nil) -} - -// table loads the current table description from the database. -func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) { - rows := &sql.Rows{} - query, args := sql.Dialect(dialect.Postgres). - Select( - "column_name", "data_type", "is_nullable", "column_default", "udt_name", - "numeric_precision", "numeric_scale", "character_maximum_length", - ). - From(sql.Table("columns").Schema("information_schema")). - Where(sql.And( - d.matchSchema(), - sql.EQ("table_name", name), - )).Query() - if err := tx.Query(ctx, query, args, rows); err != nil { - return nil, fmt.Errorf("postgres: reading table description %w", err) - } - // Call `Close` in cases of failures (`Close` is idempotent). - defer rows.Close() - t := NewTable(name) - for rows.Next() { - c := &Column{} - if err := d.scanColumn(c, rows); err != nil { - return nil, err - } - t.AddColumn(c) - } - if err := rows.Err(); err != nil { - return nil, err - } - if err := rows.Close(); err != nil { - return nil, fmt.Errorf("closing rows %w", err) - } - idxs, err := d.indexes(ctx, tx, name) - if err != nil { - return nil, err - } - // Populate the index information to the table and its columns. - // We do it manually, because PK and uniqueness information does - // not exist when querying the information_schema.COLUMNS above. - for _, idx := range idxs { - switch { - case idx.primary: - for _, name := range idx.columns { - c, ok := t.column(name) - if !ok { - return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name) - } - c.Key = PrimaryKey - t.PrimaryKey = append(t.PrimaryKey, c) - } - case idx.Unique && len(idx.columns) == 1: - name := idx.columns[0] - c, ok := t.column(name) - if !ok { - return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name) - } - c.Key = UniqueKey - c.Unique = true - fallthrough - default: - t.addIndex(idx) - } - } - return t, nil -} - -// indexesQuery holds a query format for retrieving -// table indexes of the current schema. -const indexesQuery = ` -SELECT i.relname AS index_name, - a.attname AS column_name, - idx.indisprimary AS primary, - idx.indisunique AS unique, - array_position(idx.indkey, a.attnum) as seq_in_index -FROM pg_class t, - pg_class i, - pg_index idx, - pg_attribute a, - pg_namespace n -WHERE t.oid = idx.indrelid - AND i.oid = idx.indexrelid - AND n.oid = t.relnamespace - AND a.attrelid = t.oid - AND a.attnum = ANY(idx.indkey) - AND t.relkind = 'r' - AND n.nspname = %s - AND t.relname = '%s' -ORDER BY index_name, seq_in_index; -` - -// indexesQuery returns the query (and its placeholders) for getting table indexes. -func (d *Postgres) indexesQuery(table string) (string, []any) { - if d.schema != "" { - return fmt.Sprintf(indexesQuery, "$1", table), []any{d.schema} - } - return fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", table), nil -} - -func (d *Postgres) indexes(ctx context.Context, tx dialect.Tx, table string) (Indexes, error) { - rows := &sql.Rows{} - query, args := d.indexesQuery(table) - if err := tx.Query(ctx, query, args, rows); err != nil { - return nil, fmt.Errorf("querying indexes for table %s: %w", table, err) - } - defer rows.Close() - var ( - idxs Indexes - names = make(map[string]*Index) - ) - for rows.Next() { - var ( - seqindex int - name, column string - unique, primary bool - ) - if err := rows.Scan(&name, &column, &primary, &unique, &seqindex); err != nil { - return nil, fmt.Errorf("scanning index description: %w", err) - } - // If the index is prefixed with the table, it may was added by - // `addIndex` and it should be trimmed. But, since entc prefixes - // all indexes with schema-type, for uncountable types (like, media - // or equipment) this isn't correct, and we fallback for the real-name. - short := strings.TrimPrefix(name, table+"_") - idx, ok := names[short] - if !ok { - idx = &Index{Name: short, Unique: unique, primary: primary, realname: name} - idxs = append(idxs, idx) - names[short] = idx - } - idx.columns = append(idx.columns, column) - } - if err := rows.Err(); err != nil { - return nil, err - } - return idxs, nil -} - -// maxCharSize defines the maximum size of limited character types in Postgres (10 MB). -const maxCharSize = 10 << 20 - -// scanColumn scans the information a column from column description. -func (d *Postgres) scanColumn(c *Column, rows *sql.Rows) error { - var ( - nullable sql.NullString - defaults sql.NullString - udt sql.NullString - numericPrecision sql.NullInt64 - numericScale sql.NullInt64 - characterMaximumLen sql.NullInt64 - ) - if err := rows.Scan(&c.Name, &c.typ, &nullable, &defaults, &udt, &numericPrecision, &numericScale, &characterMaximumLen); err != nil { - return fmt.Errorf("scanning column description: %w", err) - } - if nullable.Valid { - c.Nullable = nullable.String == "YES" - } - switch c.typ { - case "boolean": - c.Type = field.TypeBool - case "smallint": - c.Type = field.TypeInt16 - case "integer": - c.Type = field.TypeInt32 - case "bigint": - c.Type = field.TypeInt64 - case "real": - c.Type = field.TypeFloat32 - case "double precision": - c.Type = field.TypeFloat64 - case "numeric", "decimal": - c.Type = field.TypeFloat64 - // If precision is specified then we should take that into account. - if numericPrecision.Valid { - schemaType := fmt.Sprintf("%s(%d,%d)", c.typ, numericPrecision.Int64, numericScale.Int64) - c.SchemaType = map[string]string{dialect.Postgres: schemaType} - } - case "text": - c.Type = field.TypeString - c.Size = maxCharSize + 1 - case "character", "character varying": - c.Type = field.TypeString - // If character maximum length is specified then we should take that into account. - if characterMaximumLen.Valid { - schemaType := fmt.Sprintf("varchar(%d)", characterMaximumLen.Int64) - c.SchemaType = map[string]string{dialect.Postgres: schemaType} - } - case "date", "time with time zone", "time without time zone", "timestamp with time zone", "timestamp without time zone": - c.Type = field.TypeTime - case "bytea": - c.Type = field.TypeBytes - case "jsonb": - c.Type = field.TypeJSON - case "uuid": - c.Type = field.TypeUUID - case "cidr", "inet", "macaddr", "macaddr8": - c.Type = field.TypeOther - case "point", "line", "lseg", "box", "path", "polygon", "circle": - c.Type = field.TypeOther - case "ARRAY": - c.Type = field.TypeOther - if !udt.Valid { - return fmt.Errorf("missing array type for column %q", c.Name) - } - // Note that for ARRAY types, the 'udt_name' column holds the array type - // prefixed with '_'. For example, for 'integer[]' the result is '_int', - // and for 'text[N][M]' the result is also '_text'. That's because, the - // database ignores any size or multi-dimensions constraints. - c.SchemaType = map[string]string{dialect.Postgres: "ARRAY"} - c.typ = udt.String - case "USER-DEFINED", "tstzrange", "interval": - c.Type = field.TypeOther - if !udt.Valid { - return fmt.Errorf("missing user defined type for column %q", c.Name) - } - c.SchemaType = map[string]string{dialect.Postgres: udt.String} - } - switch { - case !defaults.Valid || c.Type == field.TypeTime || callExpr(defaults.String): - return nil - case strings.Contains(defaults.String, "::"): - parts := strings.Split(defaults.String, "::") - defaults.String = strings.Trim(parts[0], "'") - fallthrough - default: - return c.ScanDefault(defaults.String) - } -} - -// tBuilder returns the TableBuilder for the given table. -func (d *Postgres) tBuilder(t *Table) *sql.TableBuilder { - b := sql.Dialect(dialect.Postgres). - CreateTable(t.Name).IfNotExists() - for _, c := range t.Columns { - b.Column(d.addColumn(c)) - } - for _, pk := range t.PrimaryKey { - b.PrimaryKey(pk.Name) - } - if t.Annotation != nil { - addChecks(b, t.Annotation) - } - return b -} - -// cType returns the PostgreSQL string type for this column. -func (d *Postgres) cType(c *Column) (t string) { - if c.SchemaType != nil && c.SchemaType[dialect.Postgres] != "" { - return c.SchemaType[dialect.Postgres] - } - switch c.Type { - case field.TypeBool: - t = "boolean" - case field.TypeUint8, field.TypeInt8, field.TypeInt16, field.TypeUint16: - t = "smallint" - case field.TypeInt32, field.TypeUint32: - t = "int" - case field.TypeInt, field.TypeUint, field.TypeInt64, field.TypeUint64: - t = "bigint" - case field.TypeFloat32: - t = c.scanTypeOr("real") - case field.TypeFloat64: - t = c.scanTypeOr("double precision") - case field.TypeBytes: - t = "bytea" - case field.TypeJSON: - t = "jsonb" - case field.TypeUUID: - t = "uuid" - case field.TypeString: - t = "varchar" - if c.Size > maxCharSize { - t = "text" - } - case field.TypeTime: - t = c.scanTypeOr("timestamp with time zone") - case field.TypeEnum: - // Currently, the support for enums is weak (application level only. - // like SQLite). Dialect needs to create and maintain its enum type. - t = "varchar" - case field.TypeOther: - t = c.typ - default: - panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name)) - } - return t -} - -// addColumn returns the ColumnBuilder for adding the given column to a table. -func (d *Postgres) addColumn(c *Column) *sql.ColumnBuilder { - b := sql.Dialect(dialect.Postgres). - Column(c.Name).Type(d.cType(c)).Attr(c.Attr) - c.unique(b) - if c.Increment { - b.Attr("GENERATED BY DEFAULT AS IDENTITY") - } - c.nullable(b) - d.writeDefault(b, c, "DEFAULT") - if c.Collation != "" { - b.Attr("COLLATE " + strconv.Quote(c.Collation)) - } - return b -} - -// writeDefault writes the `DEFAULT` clause to column builder -// if exists and supported by the driver. -func (d *Postgres) writeDefault(b *sql.ColumnBuilder, c *Column, clause string) { - if c.Default == nil || !c.supportDefault() { - return - } - attr := fmt.Sprint(c.Default) - switch v := c.Default.(type) { - case bool: - attr = strconv.FormatBool(v) - case string: - if t := c.Type; t != field.TypeUUID && t != field.TypeTime && !t.Numeric() { - // Escape single quote by replacing each with 2. - attr = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''")) - } - } - b.Attr(clause + " " + attr) -} - -// alterColumn returns list of ColumnBuilder for applying in order to alter a column. -func (d *Postgres) alterColumn(c *Column) (ops []*sql.ColumnBuilder) { - b := sql.Dialect(dialect.Postgres) - ops = append(ops, b.Column(c.Name).Type(d.cType(c))) - if c.Nullable { - ops = append(ops, b.Column(c.Name).Attr("DROP NOT NULL")) - } else { - ops = append(ops, b.Column(c.Name).Attr("SET NOT NULL")) - } - if c.Default != nil && c.supportDefault() { - ops = append(ops, d.writeSetDefault(b.Column(c.Name), c)) - } - return ops -} - -func (d *Postgres) writeSetDefault(b *sql.ColumnBuilder, c *Column) *sql.ColumnBuilder { - d.writeDefault(b, c, "SET DEFAULT") - return b -} - -// hasUniqueName reports if the index has a unique name in the schema. -func hasUniqueName(i *Index) bool { - // Trim the "_key" suffix if it was added by Postgres for implicit indexes. - name := strings.TrimSuffix(i.Name, "_key") - suffix := strings.Join(i.columnNames(), "_") - if !strings.HasSuffix(name, suffix) { - return true // Assume it has a custom storage-key. - } - // The codegen prefixes by default indexes with the type name. - // For example, an index "users"("name"), will named as "user_name". - return name != suffix -} - -// addIndex returns the query for adding an index to PostgreSQL. -func (d *Postgres) addIndex(i *Index, table string) *sql.IndexBuilder { - name := i.Name - if !hasUniqueName(i) { - // Since index name should be unique in pg_class for schema, - // we prefix it with the table name and remove on read. - name = fmt.Sprintf("%s_%s", table, i.Name) - } - idx := sql.Dialect(dialect.Postgres). - CreateIndex(name).IfNotExists().Table(table) - if i.Unique { - idx.Unique() - } - for _, c := range i.Columns { - idx.Column(c.Name) - } - return idx -} - -// dropIndex drops a Postgres index. -func (d *Postgres) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error { - name := idx.Name - build := sql.Dialect(dialect.Postgres) - if prefix := table + "_"; !strings.HasPrefix(name, prefix) && !hasUniqueName(idx) { - name = prefix + name - } - query, args := sql.Dialect(dialect.Postgres). - Select(sql.Count("*")).From(sql.Table("table_constraints").Schema("information_schema")). - Where(sql.And( - d.matchSchema(), - sql.EQ("constraint_type", "UNIQUE"), - sql.EQ("constraint_name", name), - )). - Query() - exists, err := exist(ctx, tx, query, args...) - if err != nil { - return err - } - query, args = build.DropIndex(name).Query() - if exists { - query, args = build.AlterTable(table).DropConstraint(name).Query() - } - return tx.Exec(ctx, query, args, nil) -} - -// isImplicitIndex reports if the index was created implicitly for the unique column. -func (d *Postgres) isImplicitIndex(idx *Index, col *Column) bool { - return strings.TrimSuffix(idx.Name, "_key") == col.Name && col.Unique -} - -// renameColumn returns the statement for renaming a column. -func (d *Postgres) renameColumn(t *Table, old, new *Column) sql.Querier { - return sql.Dialect(dialect.Postgres). - AlterTable(t.Name). - RenameColumn(old.Name, new.Name) -} - -// renameIndex returns the statement for renaming an index. -func (d *Postgres) renameIndex(t *Table, old, new *Index) sql.Querier { - if sfx := "_key"; strings.HasSuffix(old.Name, sfx) && !strings.HasSuffix(new.Name, sfx) { - new.Name += sfx - } - if pfx := t.Name + "_"; strings.HasPrefix(old.realname, pfx) && !strings.HasPrefix(new.Name, pfx) { - new.Name = pfx + new.Name - } - return sql.Dialect(dialect.Postgres).AlterIndex(old.realname).Rename(new.Name) -} - // matchSchema returns the predicate for matching table schema. func (d *Postgres) matchSchema(columns ...string) *sql.Predicate { column := "table_schema" @@ -526,156 +78,8 @@ func (d *Postgres) matchSchema(columns ...string) *sql.Predicate { return sql.EQ(column, sql.Raw("CURRENT_SCHEMA()")) } -// tables returns the query for getting the in the schema. -func (d *Postgres) tables() sql.Querier { - return sql.Dialect(dialect.Postgres). - Select("table_name"). - From(sql.Table("tables").Schema("information_schema")). - Where(d.matchSchema()) -} - -// alterColumns returns the queries for applying the columns change-set. -func (d *Postgres) alterColumns(table string, add, modify, drop []*Column) sql.Queries { - b := sql.Dialect(dialect.Postgres).AlterTable(table) - for _, c := range add { - b.AddColumn(d.addColumn(c)) - } - for _, c := range modify { - b.ModifyColumns(d.alterColumn(c)...) - } - for _, c := range drop { - b.DropColumn(sql.Dialect(dialect.Postgres).Column(c.Name)) - } - if len(b.Queries) == 0 { - return nil - } - return sql.Queries{b} -} - -// needsConversion reports if column "old" needs to be converted -// (by table altering) to column "new". -func (d *Postgres) needsConversion(old, new *Column) bool { - oldT, newT := d.cType(old), d.cType(new) - return oldT != newT && (oldT != "ARRAY" || !arrayType(newT)) -} - -// callExpr reports if the given string ~looks like a function call expression. -func callExpr(s string) bool { - if parts := strings.Split(s, "::"); !strings.HasSuffix(s, ")") && strings.HasSuffix(parts[0], ")") { - s = parts[0] - } - i, j := strings.IndexByte(s, '('), strings.LastIndexByte(s, ')') - if i == -1 || i > j || j != len(s)-1 { - return false - } - for i, r := range s[:i] { - if !isAlpha(r, i > 0) { - return false - } - } - return true -} - -func isAlpha(r rune, digit bool) bool { - return 'a' <= r && r <= 'z' || 'A' <= r && r <= 'Z' || r == '_' || digit && '0' <= r && r <= '9' -} - -// arrayType reports if the given string is an array type (e.g. int[], text[2]). -func arrayType(t string) bool { - i, j := strings.LastIndexByte(t, '['), strings.LastIndexByte(t, ']') - if i == -1 || j == -1 { - return false - } - for _, r := range t[i+1 : j] { - if !unicode.IsDigit(r) { - return false - } - } - return true -} - -// foreignKeys populates the tables foreign keys using the information_schema tables -func (d *Postgres) foreignKeys(ctx context.Context, tx dialect.Tx, tables []*Table) error { - var tableLookup = make(map[string]*Table) - for _, t := range tables { - tableLookup[t.Name] = t - } - for _, t := range tables { - rows := &sql.Rows{} - query := fmt.Sprintf(fkQuery, t.Name) - if err := tx.Query(ctx, query, []any{}, rows); err != nil { - return fmt.Errorf("querying foreign keys for table %s: %w", t.Name, err) - } - defer rows.Close() - var tableFksLookup = make(map[string]*ForeignKey) - for rows.Next() { - var tableSchema, constraintName, tableName, columnName, refTableSchema, refTableName, refColumnName string - if err := rows.Scan(&tableSchema, &constraintName, &tableName, &columnName, &refTableSchema, &refTableName, &refColumnName); err != nil { - return fmt.Errorf("scanning index description: %w", err) - } - refTable := tableLookup[refTableName] - if refTable == nil { - return fmt.Errorf("could not find table: %s", refTableName) - } - column, ok := t.column(columnName) - if !ok { - return fmt.Errorf("could not find column: %s on table: %s", columnName, tableName) - } - refColumn, ok := refTable.column(refColumnName) - if !ok { - return fmt.Errorf("could not find ref column: %s on ref table: %s", refTableName, refColumnName) - } - if fk, ok := tableFksLookup[constraintName]; ok { - if _, ok := fk.column(columnName); !ok { - fk.Columns = append(fk.Columns, column) - } - if _, ok := fk.refColumn(refColumnName); !ok { - fk.RefColumns = append(fk.RefColumns, refColumn) - } - } else { - newFk := &ForeignKey{ - Symbol: constraintName, - Columns: []*Column{column}, - RefTable: refTable, - RefColumns: []*Column{refColumn}, - } - tableFksLookup[constraintName] = newFk - t.AddForeignKey(newFk) - } - } - if err := rows.Close(); err != nil { - return err - } - if err := rows.Err(); err != nil { - return err - } - } - return nil -} - -// fkQuery holds a query format for retrieving -// foreign keys of the current schema. -const fkQuery = ` -SELECT tc.table_schema, - tc.constraint_name, - tc.table_name, - kcu.column_name, - ccu.table_schema AS foreign_table_schema, - ccu.table_name AS foreign_table_name, - ccu.column_name AS foreign_column_name -FROM information_schema.table_constraints AS tc - JOIN information_schema.key_column_usage AS kcu - ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema - JOIN information_schema.constraint_column_usage AS ccu - ON ccu.constraint_name = tc.constraint_name - AND ccu.table_schema = tc.table_schema -WHERE tc.constraint_type = 'FOREIGN KEY' - AND tc.table_name = '%s' -order by constraint_name, kcu.ordinal_position; -` - -// Atlas integration. +// maxCharSize defines the maximum size of limited character types in Postgres (10 MB). +const maxCharSize = 10 << 20 func (d *Postgres) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) { return postgres.Open(&db{ExecQuerier: conn}) @@ -843,7 +247,7 @@ func (d *Postgres) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) er return nil } -func (Postgres) atTypeRangeSQL(ts ...string) string { +func (*Postgres) atTypeRangeSQL(ts ...string) string { for i := range ts { ts[i] = fmt.Sprintf("('%s')", ts[i]) } diff --git a/dialect/sql/schema/postgres_test.go b/dialect/sql/schema/postgres_test.go deleted file mode 100644 index 69b346851..000000000 --- a/dialect/sql/schema/postgres_test.go +++ /dev/null @@ -1,1036 +0,0 @@ -// Copyright 2019-present Facebook Inc. All rights reserved. -// This source code is licensed under the Apache 2.0 license found -// in the LICENSE file in the root directory of this source tree. - -package schema - -import ( - "context" - "fmt" - "math" - "strings" - "testing" - - "entgo.io/ent/dialect" - "entgo.io/ent/dialect/entsql" - "entgo.io/ent/dialect/sql" - "entgo.io/ent/schema/field" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/require" -) - -func TestPostgres_Create(t *testing.T) { - tests := []struct { - name string - tables []*Table - options []MigrateOption - before func(pgMock) - wantErr bool - }{ - { - name: "tx failed", - before: func(mock pgMock) { - mock.ExpectBegin().WillReturnError(sqlmock.ErrCancelled) - }, - wantErr: true, - }, - { - name: "unsupported version", - before: func(mock pgMock) { - mock.start("90000") - }, - wantErr: true, - }, - { - name: "no tables", - before: func(mock pgMock) { - mock.start("120000") - mock.ExpectCommit() - }, - }, - { - name: "create new table", - tables: []*Table{ - { - Name: "users", - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - Columns: []*Column{ - {Name: "id", Type: field.TypeUUID, Default: "uuid_generate_v4()"}, - {Name: "block_size", Type: field.TypeInt, Default: "current_setting('block_size')::bigint"}, - {Name: "name", Type: field.TypeString, Nullable: true, Collation: "he_IL"}, - {Name: "age", Type: field.TypeInt}, - {Name: "doc", Type: field.TypeJSON, Nullable: true}, - {Name: "enums", Type: field.TypeEnum, Enums: []string{"a", "b"}, Default: "a"}, - {Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.Postgres: "numeric(5,2)"}}, - {Name: "strings", Type: field.TypeOther, SchemaType: map[string]string{dialect.Postgres: "text[]"}, Nullable: true}, - {Name: "fixed_string", Type: field.TypeString, SchemaType: map[string]string{dialect.Postgres: "varchar(100)"}}, - }, - Annotation: &entsql.Annotation{ - Check: "price > 0", - Checks: map[string]string{ - "valid_age": "age > 0", - "valid_name": "name <> ''", - }, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", false) - mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" uuid NOT NULL DEFAULT uuid_generate_v4(), "block_size" bigint NOT NULL DEFAULT current_setting('block_size')::bigint, "name" varchar NULL COLLATE "he_IL", "age" bigint NOT NULL, "doc" jsonb NULL, "enums" varchar NOT NULL DEFAULT 'a', "price" numeric(5,2) NOT NULL, "strings" text[] NULL, "fixed_string" varchar(100) NOT NULL, PRIMARY KEY("id"), CHECK (price > 0), CONSTRAINT "valid_age" CHECK (age > 0), CONSTRAINT "valid_name" CHECK (name <> ''))`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "create new table with foreign key", - tables: func() []*Table { - var ( - c1 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "created_at", Type: field.TypeTime}, - {Name: "inet", Type: field.TypeString, Unique: true, SchemaType: map[string]string{dialect.Postgres: "inet"}}, - } - c2 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString}, - {Name: "owner_id", Type: field.TypeInt, Nullable: true}, - } - t1 = &Table{ - Name: "users", - Columns: c1, - PrimaryKey: c1[0:1], - } - t2 = &Table{ - Name: "pets", - Columns: c2, - PrimaryKey: c2[0:1], - ForeignKeys: []*ForeignKey{ - { - Symbol: "pets_owner", - Columns: c2[2:], - RefTable: t1, - RefColumns: c1[0:1], - OnDelete: Cascade, - }, - }, - } - ) - return []*Table{t1, t2} - }(), - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", false) - mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NULL, "created_at" timestamp with time zone NOT NULL, "inet" inet UNIQUE NOT NULL, PRIMARY KEY("id"))`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("pets", false) - mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "pets"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NOT NULL, "owner_id" bigint NULL, PRIMARY KEY("id"))`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.fkExists("pets_owner", false) - mock.ExpectExec(escape(`ALTER TABLE "pets" ADD CONSTRAINT "pets_owner" FOREIGN KEY("owner_id") REFERENCES "users"("id") ON DELETE CASCADE`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "create new table with foreign key disabled", - options: []MigrateOption{ - WithForeignKeys(false), - }, - tables: func() []*Table { - var ( - c1 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "created_at", Type: field.TypeTime}, - } - c2 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString}, - {Name: "owner_id", Type: field.TypeInt, Nullable: true}, - } - t1 = &Table{ - Name: "users", - Columns: c1, - PrimaryKey: c1[0:1], - } - t2 = &Table{ - Name: "pets", - Columns: c2, - PrimaryKey: c2[0:1], - ForeignKeys: []*ForeignKey{ - { - Symbol: "pets_owner", - Columns: c2[2:], - RefTable: t1, - RefColumns: c1[0:1], - OnDelete: Cascade, - }, - }, - } - ) - return []*Table{t1, t2} - }(), - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", false) - mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NULL, "created_at" timestamp with time zone NOT NULL, PRIMARY KEY("id"))`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("pets", false) - mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "pets"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NOT NULL, "owner_id" bigint NULL, PRIMARY KEY("id"))`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "scan table with default", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "block_size", Type: field.TypeInt, Default: "current_setting('block_size')::bigint"}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "nextval('users_colname_seq'::regclass)", "int4", nil, nil, nil). - AddRow("block_size", "bigint", "NO", "current_setting('block_size')::bigint", "int4", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "block_size" TYPE bigint, ALTER COLUMN "block_size" SET NOT NULL, ALTER COLUMN "block_size" SET DEFAULT current_setting('block_size')::bigint`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "scan table with custom type", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "custom", Type: field.TypeOther, SchemaType: map[string]string{dialect.Postgres: "customtype"}}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "nextval('users_colname_seq'::regclass)", "NULL", nil, nil, nil). - AddRow("custom", "USER-DEFINED", "NO", "NULL", "customtype", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectCommit() - }, - }, - { - name: "add column to table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "uuid", Type: field.TypeUUID, Nullable: true}, - {Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, - {Name: "age", Type: field.TypeInt}, - {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{dialect.Postgres: "date"}, Default: "CURRENT_DATE"}, - {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "date"}, Nullable: true}, - {Name: "deleted_at", Type: field.TypeTime, Nullable: true}, - {Name: "cidr", Type: field.TypeString, SchemaType: map[string]string{dialect.Postgres: "cidr"}}, - {Name: "point", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "point"}}, - {Name: "line", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "line"}}, - {Name: "lseg", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "lseg"}}, - {Name: "box", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "box"}}, - {Name: "path", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "path"}}, - {Name: "polygon", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "polygon"}}, - {Name: "circle", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "circle"}}, - {Name: "macaddr", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "macaddr"}}, - {Name: "macaddr8", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "macaddr8"}}, - {Name: "strings", Type: field.TypeOther, SchemaType: map[string]string{dialect.Postgres: "text[]"}, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character varying", "YES", "NULL", "varchar", nil, nil, nil). - AddRow("uuid", "uuid", "YES", "NULL", "uuid", nil, nil, nil). - AddRow("created_at", "date", "NO", "CURRENT_DATE", "date", nil, nil, nil). - AddRow("updated_at", "timestamp with time zone", "YES", "NULL", "timestamptz", nil, nil, nil). - AddRow("deleted_at", "date", "YES", "NULL", "date", nil, nil, nil). - AddRow("text", "text", "YES", "NULL", "text", nil, nil, nil). - AddRow("cidr", "cidr", "NO", "NULL", "cidr", nil, nil, nil). - AddRow("inet", "inet", "YES", "NULL", "inet", nil, nil, nil). - AddRow("point", "point", "YES", "NULL", "point", nil, nil, nil). - AddRow("line", "line", "YES", "NULL", "line", nil, nil, nil). - AddRow("lseg", "lseg", "YES", "NULL", "lseg", nil, nil, nil). - AddRow("box", "box", "YES", "NULL", "box", nil, nil, nil). - AddRow("path", "path", "YES", "NULL", "path", nil, nil, nil). - AddRow("polygon", "polygon", "YES", "NULL", "polygon", nil, nil, nil). - AddRow("circle", "circle", "YES", "NULL", "circle", nil, nil, nil). - AddRow("macaddr", "macaddr", "YES", "NULL", "macaddr", nil, nil, nil). - AddRow("macaddr8", "macaddr8", "YES", "NULL", "macaddr8", nil, nil, nil). - AddRow("strings", "ARRAY", "YES", "NULL", "_text", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "age" bigint NOT NULL, ALTER COLUMN "created_at" TYPE date, ALTER COLUMN "created_at" SET NOT NULL, ALTER COLUMN "created_at" SET DEFAULT CURRENT_DATE, ALTER COLUMN "deleted_at" TYPE timestamp with time zone, ALTER COLUMN "deleted_at" DROP NOT NULL`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add int column with default value to table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "age", Type: field.TypeInt, Default: 10}, - {Name: "doc", Type: field.TypeJSON, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil). - AddRow("doc", "jsonb", "YES", "NULL", "jsonb", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "age" bigint NOT NULL DEFAULT 10`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add blob columns", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "blob", Type: field.TypeBytes, Size: 1e3}, - {Name: "longblob", Type: field.TypeBytes, Size: 1e6}, - {Name: "doc", Type: field.TypeJSON, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil). - AddRow("doc", "jsonb", "YES", "NULL", "jsonb", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "blob" bytea NOT NULL, ADD COLUMN "longblob" bytea NOT NULL`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add float column with default value to table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "age", Type: field.TypeFloat64, Default: 10.1}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "age" double precision NOT NULL DEFAULT 10.1`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add bool column with default value to table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "age", Type: field.TypeBool, Default: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "age" boolean NOT NULL DEFAULT true`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add string column with default value to table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "nick", Type: field.TypeString, Default: "unknown"}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "nick" varchar NOT NULL DEFAULT 'unknown'`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "drop column to table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - options: []MigrateOption{WithDropColumn(true)}, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" DROP COLUMN "name"`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "modify column to nullable", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character", "NO", "NULL", "bpchar", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "name" TYPE varchar, ALTER COLUMN "name" DROP NOT NULL`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "modify column default value", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Default: "unknown"}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character", "NO", "NULL", "bpchar", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "name" TYPE varchar, ALTER COLUMN "name" SET NOT NULL, ALTER COLUMN "name" SET DEFAULT 'unknown'`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "apply uniqueness on column", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "age", Type: field.TypeInt, Unique: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("age", "bigint", "NO", "NULL", "int8", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`CREATE UNIQUE INDEX IF NOT EXISTS "users_age" ON "users"("age")`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "remove uniqueness from column without option", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "age", Type: field.TypeInt}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("age", "bigint", "NO", "NULL", "int8", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0). - AddRow("users_age_key", "age", "f", "t", 0)) - mock.ExpectCommit() - }, - }, - { - name: "remove uniqueness from column with option", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "age", Type: field.TypeInt}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - options: []MigrateOption{WithDropIndex(true)}, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("age", "bigint", "NO", "NULL", "int8", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0). - AddRow("users_age_key", "age", "f", "t", 0)) - mock.ExpectQuery(escape(`SELECT COUNT(*) FROM "information_schema"."table_constraints" WHERE "table_schema" = CURRENT_SCHEMA() AND "constraint_type" = $1 AND "constraint_name" = $2`)). - WithArgs("UNIQUE", "users_age_key"). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) - mock.ExpectExec(escape(`ALTER TABLE "users" DROP CONSTRAINT "users_age_key"`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "add and remove indexes", - tables: func() []*Table { - c1 := []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - // Add implicit index. - {Name: "age", Type: field.TypeInt, Unique: true}, - {Name: "score", Type: field.TypeInt}, - } - c2 := []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "score", Type: field.TypeInt}, - {Name: "email", Type: field.TypeString}, - } - return []*Table{ - { - Name: "users", - Columns: c1, - PrimaryKey: c1[0:1], - Indexes: Indexes{ - // Change non-unique index to unique. - {Name: "user_score", Columns: c1[2:3], Unique: true}, - }, - }, - { - Name: "equipment", - Columns: c2, - PrimaryKey: c2[0:1], - Indexes: Indexes{ - {Name: "equipment_score", Columns: c2[1:2]}, - // Index should not be changed. - {Name: "equipment_email", Unique: true, Columns: c2[2:]}, - }, - }, - } - }(), - options: []MigrateOption{WithDropIndex(true)}, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("age", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("score", "bigint", "NO", "NULL", "int8", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0). - AddRow("user_score", "score", "f", "f", 0)) - mock.ExpectQuery(escape(`SELECT COUNT(*) FROM "information_schema"."table_constraints" WHERE "table_schema" = CURRENT_SCHEMA() AND "constraint_type" = $1 AND "constraint_name" = $2`)). - WithArgs("UNIQUE", "user_score"). - WillReturnRows(sqlmock.NewRows([]string{"count"}). - AddRow(0)) - mock.ExpectExec(escape(`DROP INDEX "user_score"`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(escape(`CREATE UNIQUE INDEX IF NOT EXISTS "users_age" ON "users"("age")`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(escape(`CREATE UNIQUE INDEX IF NOT EXISTS "user_score" ON "users"("score")`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("equipment", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("equipment"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("score", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("email", "character varying", "YES", "NULL", "varchar", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "equipment"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0). - AddRow("equipment_score", "score", "f", "f", 0). - AddRow("equipment_email", "email", "f", "t", 0)) - mock.ExpectCommit() - }, - }, - { - name: "add edge to table", - tables: func() []*Table { - var ( - c1 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "spouse_id", Type: field.TypeInt, Nullable: true}, - } - t1 = &Table{ - Name: "users", - Columns: c1, - PrimaryKey: c1[0:1], - ForeignKeys: []*ForeignKey{ - { - Symbol: "user_spouse" + strings.Repeat("_", 64), // super long fk. - Columns: c1[2:], - RefColumns: c1[0:1], - OnDelete: Cascade, - }, - }, - } - ) - t1.ForeignKeys[0].RefTable = t1 - return []*Table{t1} - }(), - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "YES", "NULL", "int8", nil, nil, nil). - AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "spouse_id" bigint NULL`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.fkExists("user_spouse____________________390ed76f91d3c57cd3516e7690f621dc", false) - mock.ExpectExec(`ALTER TABLE "users" ADD CONSTRAINT ".{63}" FOREIGN KEY\("spouse_id"\) REFERENCES "users"\("id"\) ON DELETE CASCADE`). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "universal id for all tables", - tables: []*Table{ - NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - }, - options: []MigrateOption{WithGlobalUniqueID(true)}, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("ent_types", false) - // create ent_types table. - mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "ent_types"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "type" varchar UNIQUE NOT NULL, PRIMARY KEY("id"))`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("users", false) - mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set users id range. - mock.ExpectExec(escape(`INSERT INTO "ent_types" ("type") VALUES ($1)`)). - WithArgs("users"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(`ALTER TABLE "users" ALTER COLUMN "id" RESTART WITH 1`). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("groups", false) - mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "groups"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set groups id range. - mock.ExpectExec(escape(`INSERT INTO "ent_types" ("type") VALUES ($1)`)). - WithArgs("groups"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(`ALTER TABLE "groups" ALTER COLUMN "id" RESTART WITH 4294967296`). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "universal id for new tables", - tables: []*Table{ - NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - }, - options: []MigrateOption{WithGlobalUniqueID(true)}, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("ent_types", true) - // query ent_types table. - mock.ExpectQuery(`SELECT "type" FROM "ent_types" ORDER BY "id" ASC`). - WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) - // query users table. - mock.tableExists("users", true) - // users table has no changes. - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "YES", "NULL", "int8", nil, nil, nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - // query groups table. - mock.tableExists("groups", false) - mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "groups"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set groups id range. - mock.ExpectExec(escape(`INSERT INTO "ent_types" ("type") VALUES ($1)`)). - WithArgs("groups"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(`ALTER TABLE "groups" ALTER COLUMN "id" RESTART WITH 4294967296`). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "universal id for restored tables", - tables: []*Table{ - NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - }, - options: []MigrateOption{WithGlobalUniqueID(true)}, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("ent_types", true) - // query ent_types table. - mock.ExpectQuery(`SELECT "type" FROM "ent_types" ORDER BY "id" ASC`). - WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) - // query and create users (restored table). - mock.tableExists("users", false) - mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set users id range (without inserting to ent_types). - mock.ExpectExec(`ALTER TABLE "users" ALTER COLUMN "id" RESTART WITH 1`). - WillReturnResult(sqlmock.NewResult(0, 1)) - // query groups table. - mock.tableExists("groups", false) - mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "groups"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set groups id range. - mock.ExpectExec(escape(`INSERT INTO "ent_types" ("type") VALUES ($1)`)). - WithArgs("groups"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(`ALTER TABLE "groups" ALTER COLUMN "id" RESTART WITH 4294967296`). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "no modify numeric column", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.Postgres: "numeric(6,4)"}}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("price", "numeric", "NO", "NULL", "numeric", "6", "4", nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectCommit() - }, - }, - { - name: "modify numeric column", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "price", Type: field.TypeFloat64, Nullable: false, SchemaType: map[string]string{dialect.Postgres: "numeric(6,4)"}}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("price", "numeric", "NO", "NULL", "numeric", "5", "4", nil)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "price" TYPE numeric(6,4), ALTER COLUMN "price" SET NOT NULL`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - { - name: "no modify fixed size varchar column", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, SchemaType: map[string]string{dialect.Postgres: "varchar(20)"}}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character varying", "NO", "NULL", "varchar", nil, nil, 20)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectCommit() - }, - }, - { - name: "modify fixed size varchar column", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, SchemaType: map[string]string{dialect.Postgres: "varchar(20)"}, Default: "unknown"}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock pgMock) { - mock.start("120000") - mock.tableExists("users", true) - mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). - AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). - AddRow("name", "character varying", "NO", "NULL", "varchar", nil, nil, 10)) - mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). - WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). - AddRow("users_pkey", "id", "t", "t", 0)) - mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "name" TYPE varchar(20), ALTER COLUMN "name" SET NOT NULL, ALTER COLUMN "name" SET DEFAULT 'unknown'`)). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db, mock, err := sqlmock.New() - require.NoError(t, err) - tt.before(pgMock{mock}) - migrate, err := NewMigrate(sql.OpenDB("postgres", db), append(tt.options, WithAtlas(false))...) - require.NoError(t, err) - err = migrate.Create(context.Background(), tt.tables...) - require.Equal(t, tt.wantErr, err != nil, err) - }) - } -} - -type pgMock struct { - sqlmock.Sqlmock -} - -func (m pgMock) start(version string) { - m.ExpectQuery(escape("SHOW server_version_num")). - WillReturnRows(sqlmock.NewRows([]string{"server_version_num"}).AddRow(version)) - m.ExpectBegin() -} - -func (m pgMock) tableExists(table string, exists bool) { - count := 0 - if exists { - count = 1 - } - m.ExpectQuery(escape(`SELECT COUNT(*) FROM "information_schema"."tables" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). - WithArgs(table). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) -} - -func (m pgMock) fkExists(fk string, exists bool) { - count := 0 - if exists { - count = 1 - } - m.ExpectQuery(escape(`SELECT COUNT(*) FROM "information_schema"."table_constraints" WHERE "table_schema" = CURRENT_SCHEMA() AND "constraint_type" = $1 AND "constraint_name" = $2`)). - WithArgs("FOREIGN KEY", fk). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) -} diff --git a/dialect/sql/schema/schema.go b/dialect/sql/schema/schema.go index 1e10073c6..1978594ba 100644 --- a/dialect/sql/schema/schema.go +++ b/dialect/sql/schema/schema.go @@ -7,7 +7,6 @@ package schema import ( "fmt" - "sort" "strconv" "strings" @@ -186,30 +185,6 @@ func (t *Table) index(name string) (*Index, bool) { return nil, false } -// hasIndex reports if the table has at least one index that matches the given names. -func (t *Table) hasIndex(names ...string) bool { - for i := range names { - if names[i] == "" { - continue - } - if _, ok := t.index(names[i]); ok { - return true - } - } - return false -} - -// fk returns a table foreign-key by its symbol. -// faster than map lookup for most cases. -func (t *Table) fk(symbol string) (*ForeignKey, bool) { - for _, fk := range t.ForeignKeys { - if fk.Symbol == symbol { - return fk, true - } - } - return nil, false -} - // CopyTables returns a deep-copy of the given tables. This utility function is // useful for copying the generated schema tables (i.e. migrate.Tables) before // running schema migration when there is a need for execute multiple migrations @@ -417,27 +392,6 @@ func (c *Column) ScanDefault(value string) error { return nil } -// defaultValue adds the `DEFAULT` attribute to the column. -// Note that, in SQLite if a NOT NULL constraint is specified, -// then the column must have a default value which not NULL. -func (c *Column) defaultValue(b *sql.ColumnBuilder) { - if c.Default == nil || !c.supportDefault() { - return - } - // Has default and the database supports adding this default. - attr := fmt.Sprint(c.Default) - switch v := c.Default.(type) { - case bool: - attr = strconv.FormatBool(v) - case string: - if t := c.Type; t != field.TypeUUID && t != field.TypeTime { - // Escape single quote by replacing each with 2. - attr = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''")) - } - } - b.Attr("DEFAULT " + attr) -} - // supportDefault reports if the column type supports default value. func (c Column) supportDefault() bool { switch t := c.Type; t { @@ -450,25 +404,6 @@ func (c Column) supportDefault() bool { } } -// unique adds the `UNIQUE` attribute if the column is a unique type. -// it is exist in a different function to share the common declaration -// between the two dialects. -func (c *Column) unique(b *sql.ColumnBuilder) { - if c.Unique { - b.Attr("UNIQUE") - } -} - -// nullable adds the `NULL`/`NOT NULL` attribute to the column if it exists in -// a different function to share the common declaration between the two dialects. -func (c *Column) nullable(b *sql.ColumnBuilder) { - attr := Null - if !c.Nullable { - attr = "NOT " + attr - } - b.Attr(attr) -} - // scanTypeOr returns the scanning type or the given value. func (c *Column) scanTypeOr(t string) string { if c.typ != "" { @@ -487,24 +422,6 @@ type ForeignKey struct { OnDelete ReferenceOption // action on delete. } -func (fk ForeignKey) column(name string) (*Column, bool) { - for _, c := range fk.Columns { - if c.Name == name { - return c, true - } - } - return nil, false -} - -func (fk ForeignKey) refColumn(name string) (*Column, bool) { - for _, c := range fk.RefColumns { - if c.Name == name { - return c, true - } - } - return nil, false -} - // DSL returns a default DSL query for a foreign-key. func (fk ForeignKey) DSL() *sql.ForeignKeyBuilder { cols := make([]string, len(fk.Columns)) @@ -551,7 +468,6 @@ type Index struct { Columns []*Column // actual table columns. Annotation *entsql.IndexAnnotation // index annotation. columns []string // columns loaded from query scan. - primary bool // primary key index. realname string // real name in the database (Postgres only). } @@ -573,32 +489,6 @@ func (i *Index) DropBuilder(table string) *sql.DropIndexBuilder { return idx } -// sameAs reports if the index has the same properties -// as the given index (except the name). -func (i *Index) sameAs(idx *Index) bool { - if i.Unique != idx.Unique || len(i.Columns) != len(idx.Columns) { - return false - } - for j, c := range i.Columns { - if c.Name != idx.Columns[j].Name { - return false - } - } - return true -} - -// columnNames returns the names of the columns of the index. -func (i *Index) columnNames() []string { - if len(i.columns) > 0 { - return i.columns - } - columns := make([]string, 0, len(i.Columns)) - for _, c := range i.Columns { - columns = append(columns, c.Name) - } - return columns -} - // Indexes used for scanning all sql.Rows into a list of indexes, because // multiple sql rows can represent the same index (multi-columns indexes). type Indexes []*Index @@ -673,33 +563,16 @@ func compare(v1, v2 int) int { return 1 } -// addChecks appends the CHECK clauses from the entsql.Annotation. -func addChecks(t *sql.TableBuilder, ant *entsql.Annotation) { - if check := ant.Check; check != "" { - t.Checks(func(b *sql.Builder) { - b.WriteString("CHECK " + checkExpr(check)) - }) +func indexType(idx *Index, d string) (string, bool) { + ant := idx.Annotation + if ant == nil { + return "", false } - if checks := ant.Checks; len(ant.Checks) > 0 { - names := make([]string, 0, len(checks)) - for name := range checks { - names = append(names, name) - } - sort.Strings(names) - for _, name := range names { - name := name - t.Checks(func(b *sql.Builder) { - b.WriteString("CONSTRAINT ").Ident(name).WriteString(" CHECK " + checkExpr(checks[name])) - }) - } + if ant.Types != nil && ant.Types[d] != "" { + return ant.Types[d], true } -} - -// checkExpr formats the CHECK expression. -func checkExpr(expr string) string { - expr = strings.TrimSpace(expr) - if !strings.HasPrefix(expr, "(") && !strings.HasSuffix(expr, ")") { - expr = "(" + expr + ")" - } - return expr + if ant.Type != "" { + return ant.Type, true + } + return "", false } diff --git a/dialect/sql/schema/sqlite.go b/dialect/sql/schema/sqlite.go index 5e315cd08..33708e881 100644 --- a/dialect/sql/schema/sqlite.go +++ b/dialect/sql/schema/sqlite.go @@ -22,7 +22,7 @@ import ( ) type ( - // SQLite is an SQLite migration driver. + // SQLite adapter for Atlas migration engine. SQLite struct { dialect.Driver WithForeignKeys bool @@ -88,309 +88,6 @@ func (d *SQLite) tableExist(ctx context.Context, conn dialect.ExecQuerier, name return exist(ctx, conn, query, args...) } -// setRange sets the start value of table PK. -// SQLite tracks the AUTOINCREMENT in the "sqlite_sequence" table that is created and initialized automatically -// whenever a table that contains an AUTOINCREMENT column is created. However, it populates to it a rows (for tables) -// only after the first insertion. Therefore, we check. If a record (for the given table) already exists in the "sqlite_sequence" -// table, we updated it. Otherwise, we insert a new value. -func (d *SQLite) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error { - query, args := sql.Select().Count(). - From(sql.Table("sqlite_sequence")). - Where(sql.EQ("name", t.Name)). - Query() - exists, err := exist(ctx, conn, query, args...) - switch { - case err != nil: - return err - case exists: - query, args = sql.Update("sqlite_sequence").Set("seq", value).Where(sql.EQ("name", t.Name)).Query() - default: // !exists - query, args = sql.Insert("sqlite_sequence").Columns("name", "seq").Values(t.Name, value).Query() - } - return conn.Exec(ctx, query, args, nil) -} - -func (d *SQLite) tBuilder(t *Table) *sql.TableBuilder { - b := sql.CreateTable(t.Name) - for _, c := range t.Columns { - b.Column(d.addColumn(c)) - } - if t.Annotation != nil { - addChecks(b, t.Annotation) - } - // Unlike in MySQL, we're not able to add foreign-key constraints to table - // after it was created, and adding them to the `CREATE TABLE` statement is - // not always valid (because circular foreign-keys situation is possible). - // We stay consistent by not using constraints at all, and just defining the - // foreign keys in the `CREATE TABLE` statement. - if d.WithForeignKeys { - for _, fk := range t.ForeignKeys { - b.ForeignKeys(fk.DSL()) - } - } - // If it's an ID based primary key with autoincrement, we add - // the `PRIMARY KEY` clause to the column declaration. Otherwise, - // we append it to the constraint clause. - if len(t.PrimaryKey) == 1 && t.PrimaryKey[0].Increment { - return b - } - for _, pk := range t.PrimaryKey { - b.PrimaryKey(pk.Name) - } - return b -} - -// cType returns the SQLite string type for the given column. -func (*SQLite) cType(c *Column) (t string) { - if c.SchemaType != nil && c.SchemaType[dialect.SQLite] != "" { - return c.SchemaType[dialect.SQLite] - } - switch c.Type { - case field.TypeBool: - t = "bool" - case field.TypeInt8, field.TypeUint8, field.TypeInt16, field.TypeUint16, field.TypeInt32, - field.TypeUint32, field.TypeUint, field.TypeInt, field.TypeInt64, field.TypeUint64: - t = "integer" - case field.TypeBytes: - t = "blob" - case field.TypeString, field.TypeEnum: - // SQLite does not impose any length restrictions on - // the length of strings, BLOBs or numeric values. - t = fmt.Sprintf("varchar(%d)", DefaultStringLen) - case field.TypeFloat32, field.TypeFloat64: - t = "real" - case field.TypeTime: - t = "datetime" - case field.TypeJSON: - t = "json" - case field.TypeUUID: - t = "uuid" - case field.TypeOther: - t = c.typ - default: - panic(fmt.Sprintf("unsupported type %q for column %q", c.Type, c.Name)) - } - return t -} - -// addColumn returns the DSL query for adding the given column to a table. -func (d *SQLite) addColumn(c *Column) *sql.ColumnBuilder { - b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr) - c.unique(b) - if c.PrimaryKey() && c.Increment { - b.Attr("PRIMARY KEY AUTOINCREMENT") - } - c.nullable(b) - c.defaultValue(b) - return b -} - -// addIndex returns the query for adding an index to SQLite. -func (d *SQLite) addIndex(i *Index, table string) *sql.IndexBuilder { - return i.Builder(table).IfNotExists() -} - -// dropIndex drops a SQLite index. -func (d *SQLite) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error { - query, args := idx.DropBuilder("").Query() - return tx.Exec(ctx, query, args, nil) -} - -// fkExist returns always true to disable foreign-keys creation after the table was created. -func (d *SQLite) fkExist(context.Context, dialect.Tx, string) (bool, error) { return true, nil } - -// table returns always error to indicate that SQLite dialect doesn't support incremental migration. -func (d *SQLite) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) { - rows := &sql.Rows{} - query, args := sql.Select("name", "type", "notnull", "dflt_value", "pk"). - From(sql.Table(fmt.Sprintf("pragma_table_info('%s')", name)).Unquote()). - OrderBy("pk"). - Query() - if err := tx.Query(ctx, query, args, rows); err != nil { - return nil, fmt.Errorf("sqlite: reading table description %w", err) - } - // Call Close in cases of failures (Close is idempotent). - defer rows.Close() - t := NewTable(name) - for rows.Next() { - c := &Column{} - if err := d.scanColumn(c, rows); err != nil { - return nil, fmt.Errorf("sqlite: %w", err) - } - if c.PrimaryKey() { - t.PrimaryKey = append(t.PrimaryKey, c) - } - t.AddColumn(c) - } - if err := rows.Err(); err != nil { - return nil, err - } - if err := rows.Close(); err != nil { - return nil, fmt.Errorf("sqlite: closing rows %w", err) - } - indexes, err := d.indexes(ctx, tx, name) - if err != nil { - return nil, err - } - // Add and link indexes to table columns. - for _, idx := range indexes { - switch { - case idx.primary: - case idx.Unique && len(idx.columns) == 1: - name := idx.columns[0] - c, ok := t.column(name) - if !ok { - return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name) - } - c.Key = UniqueKey - c.Unique = true - fallthrough - default: - t.addIndex(idx) - } - } - return t, nil -} - -// table loads the table indexes from the database. -func (d *SQLite) indexes(ctx context.Context, tx dialect.Tx, name string) (Indexes, error) { - rows := &sql.Rows{} - query, args := sql.Select("name", "unique", "origin"). - From(sql.Table(fmt.Sprintf("pragma_index_list('%s')", name)).Unquote()). - Query() - if err := tx.Query(ctx, query, args, rows); err != nil { - return nil, fmt.Errorf("reading table indexes %w", err) - } - defer rows.Close() - var idx Indexes - for rows.Next() { - i := &Index{} - origin := sql.NullString{} - if err := rows.Scan(&i.Name, &i.Unique, &origin); err != nil { - return nil, fmt.Errorf("scanning index description %w", err) - } - i.primary = origin.String == "pk" - idx = append(idx, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - if err := rows.Close(); err != nil { - return nil, fmt.Errorf("closing rows %w", err) - } - for i := range idx { - columns, err := d.indexColumns(ctx, tx, idx[i].Name) - if err != nil { - return nil, err - } - idx[i].columns = columns - // Normalize implicit index names to ent naming convention. See: - // https://github.com/sqlite/sqlite/blob/e937df8/src/build.c#L3583 - if len(columns) == 1 && strings.HasPrefix(idx[i].Name, "sqlite_autoindex_"+name) { - idx[i].Name = columns[0] - } - } - return idx, nil -} - -// indexColumns loads index columns from index info. -func (d *SQLite) indexColumns(ctx context.Context, tx dialect.Tx, name string) ([]string, error) { - rows := &sql.Rows{} - query, args := sql.Select("name"). - From(sql.Table(fmt.Sprintf("pragma_index_info('%s')", name)).Unquote()). - OrderBy("seqno"). - Query() - if err := tx.Query(ctx, query, args, rows); err != nil { - return nil, fmt.Errorf("reading table indexes %w", err) - } - defer rows.Close() - var names []string - if err := sql.ScanSlice(rows, &names); err != nil { - return nil, err - } - return names, nil -} - -// scanColumn scans the column information from SQLite column description. -func (d *SQLite) scanColumn(c *Column, rows *sql.Rows) error { - var ( - pk sql.NullInt64 - notnull sql.NullInt64 - defaults sql.NullString - ) - if err := rows.Scan(&c.Name, &c.typ, ¬null, &defaults, &pk); err != nil { - return fmt.Errorf("scanning column description: %w", err) - } - c.Nullable = notnull.Int64 == 0 - if pk.Int64 > 0 { - c.Key = PrimaryKey - } - if c.typ == "" { - return fmt.Errorf("missing type information for column %q", c.Name) - } - parts, size, _, err := parseColumn(c.typ) - if err != nil { - return err - } - switch strings.ToLower(parts[0]) { - case "bool", "boolean": - c.Type = field.TypeBool - case "blob": - c.Type = field.TypeBytes - case "integer": - // All integer types have the same "type affinity". - c.Type = field.TypeInt - case "real", "float", "double": - c.Type = field.TypeFloat64 - case "datetime": - c.Type = field.TypeTime - case "json": - c.Type = field.TypeJSON - case "uuid": - c.Type = field.TypeUUID - case "varchar", "char", "text": - c.Size = size - c.Type = field.TypeString - case "decimal", "numeric": - c.Type = field.TypeOther - } - if defaults.Valid { - return c.ScanDefault(defaults.String) - } - return nil -} - -// alterColumns returns the queries for applying the columns change-set. -func (d *SQLite) alterColumns(table string, add, _, _ []*Column) sql.Queries { - queries := make(sql.Queries, 0, len(add)) - for i := range add { - c := d.addColumn(add[i]) - if fk := add[i].foreign; fk != nil { - c.Constraint(fk.DSL()) - } - queries = append(queries, sql.Dialect(dialect.SQLite).AlterTable(table).AddColumn(c)) - } - // Modifying and dropping columns is not supported and disabled until we - // will support https://www.sqlite.org/lang_altertable.html#otheralter - return queries -} - -// tables returns the query for getting the in the schema. -func (d *SQLite) tables() sql.Querier { - return sql.Select("name"). - From(sql.Table("sqlite_schema")). - Where(sql.EQ("type", "table")) -} - -// needsConversion reports if column "old" needs to be converted -// (by table altering) to column "new". -func (d *SQLite) needsConversion(old, new *Column) bool { - c1, c2 := d.cType(old), d.cType(new) - return c1 != c2 && old.typ != c2 -} - -// Atlas integration. - func (d *SQLite) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) { return sqlite.Open(&db{ExecQuerier: conn}) } diff --git a/dialect/sql/schema/sqlite_test.go b/dialect/sql/schema/sqlite_test.go deleted file mode 100644 index 433bb8bf9..000000000 --- a/dialect/sql/schema/sqlite_test.go +++ /dev/null @@ -1,478 +0,0 @@ -// Copyright 2019-present Facebook Inc. All rights reserved. -// This source code is licensed under the Apache 2.0 license found -// in the LICENSE file in the root directory of this source tree. - -package schema - -import ( - "context" - "fmt" - "math" - "testing" - - "entgo.io/ent/dialect" - "entgo.io/ent/dialect/sql" - "entgo.io/ent/schema/field" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/require" -) - -func TestSQLite_Create(t *testing.T) { - tests := []struct { - name string - tables []*Table - options []MigrateOption - before func(sqliteMock) - wantErr bool - }{ - { - name: "tx failed", - before: func(mock sqliteMock) { - mock.ExpectBegin().WillReturnError(sqlmock.ErrCancelled) - }, - wantErr: true, - }, - { - name: "fk disabled", - before: func(mock sqliteMock) { - mock.ExpectBegin() - mock.ExpectQuery("PRAGMA foreign_keys"). - WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(0)) - mock.ExpectRollback() - }, - wantErr: true, - }, - { - name: "no tables", - before: func(mock sqliteMock) { - mock.start() - mock.commit() - }, - }, - { - name: "create new table", - tables: []*Table{ - { - Name: "users", - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "age", Type: field.TypeInt}, - {Name: "doc", Type: field.TypeJSON, Nullable: true}, - {Name: "uuid", Type: field.TypeUUID, Nullable: true}, - {Name: "decimal", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.SQLite: "decimal(6,2)"}}, - }, - }, - }, - before: func(mock sqliteMock) { - mock.start() - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `age` integer NOT NULL, `doc` json NULL, `uuid` uuid NULL, `decimal` decimal(6,2) NOT NULL)")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.commit() - }, - }, - { - name: "create new table with foreign key", - tables: func() []*Table { - var ( - c1 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "created_at", Type: field.TypeTime}, - } - c2 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString}, - {Name: "owner_id", Type: field.TypeInt, Nullable: true}, - } - t1 = &Table{ - Name: "users", - Columns: c1, - PrimaryKey: c1[0:1], - } - t2 = &Table{ - Name: "pets", - Columns: c2, - PrimaryKey: c2[0:1], - ForeignKeys: []*ForeignKey{ - { - Symbol: "pets_owner", - Columns: c2[2:], - RefTable: t1, - RefColumns: c1[0:1], - OnDelete: Cascade, - }, - }, - } - ) - return []*Table{t1, t2} - }(), - before: func(mock sqliteMock) { - mock.start() - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` datetime NOT NULL)")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("pets", false) - mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL, FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.commit() - }, - }, - { - name: "create new table with foreign key disabled", - options: []MigrateOption{ - WithForeignKeys(false), - }, - tables: func() []*Table { - var ( - c1 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "created_at", Type: field.TypeTime}, - } - c2 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString}, - {Name: "owner_id", Type: field.TypeInt, Nullable: true}, - } - t1 = &Table{ - Name: "users", - Columns: c1, - PrimaryKey: c1[0:1], - } - t2 = &Table{ - Name: "pets", - Columns: c2, - PrimaryKey: c2[0:1], - ForeignKeys: []*ForeignKey{ - { - Symbol: "pets_owner", - Columns: c2[2:], - RefTable: t1, - RefColumns: c1[0:1], - OnDelete: Cascade, - }, - }, - } - ) - return []*Table{t1, t2} - }(), - before: func(mock sqliteMock) { - mock.start() - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` datetime NOT NULL)")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("pets", false) - mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL)")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.commit() - }, - }, - { - name: "add column to table", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, - {Name: "uuid", Type: field.TypeUUID, Nullable: true}, - {Name: "age", Type: field.TypeInt, Default: 0}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock sqliteMock) { - mock.start() - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). - WithArgs(). - WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). - AddRow("name", "varchar(255)", 0, nil, 0). - AddRow("text", "text", 0, "NULL", 0). - AddRow("uuid", "uuid", 0, "Null", 0). - AddRow("id", "integer", 1, "NULL", 1)) - mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). - WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) - mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.commit() - }, - }, - { - name: "datetime and timestamp", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock sqliteMock) { - mock.start() - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). - WithArgs(). - WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). - AddRow("created_at", "datetime", 0, nil, 0). - AddRow("id", "integer", 1, "NULL", 1)) - mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). - WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) - mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `updated_at` datetime NULL")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.commit() - }, - }, - { - name: "add blob columns", - tables: []*Table{ - { - Name: "blobs", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "old_tiny", Type: field.TypeBytes, Size: 100}, - {Name: "old_blob", Type: field.TypeBytes, Size: 1e3}, - {Name: "old_medium", Type: field.TypeBytes, Size: 1e5}, - {Name: "old_long", Type: field.TypeBytes, Size: 1e8}, - {Name: "new_tiny", Type: field.TypeBytes, Size: 100}, - {Name: "new_blob", Type: field.TypeBytes, Size: 1e3}, - {Name: "new_medium", Type: field.TypeBytes, Size: 1e5}, - {Name: "new_long", Type: field.TypeBytes, Size: 1e8}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock sqliteMock) { - mock.start() - mock.tableExists("blobs", true) - mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('blobs') ORDER BY `pk`")). - WithArgs(). - WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). - AddRow("old_tiny", "blob", 1, nil, 0). - AddRow("old_blob", "blob", 1, nil, 0). - AddRow("old_medium", "blob", 1, nil, 0). - AddRow("old_long", "blob", 1, nil, 0). - AddRow("id", "integer", 1, "NULL", 1)) - mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('blobs')")). - WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"})) - for _, c := range []string{"tiny", "blob", "medium", "long"} { - mock.ExpectExec(escape(fmt.Sprintf("ALTER TABLE `blobs` ADD COLUMN `new_%s` blob NOT NULL", c))). - WillReturnResult(sqlmock.NewResult(0, 1)) - } - mock.commit() - }, - }, - { - name: "add columns with default values", - tables: []*Table{ - { - Name: "users", - Columns: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Default: "unknown"}, - {Name: "active", Type: field.TypeBool, Default: false}, - }, - PrimaryKey: []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - }, - }, - }, - before: func(mock sqliteMock) { - mock.start() - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). - WithArgs(). - WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). - AddRow("id", "integer", 1, "NULL", 1)) - mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). - WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) - mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `name` varchar(255) NOT NULL DEFAULT 'unknown'")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `active` bool NOT NULL DEFAULT false")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.commit() - }, - }, - { - name: "add edge to table", - tables: func() []*Table { - var ( - c1 = []*Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "spouse_id", Type: field.TypeInt, Nullable: true}, - } - t1 = &Table{ - Name: "users", - Columns: c1, - PrimaryKey: c1[0:1], - ForeignKeys: []*ForeignKey{ - { - Symbol: "user_spouse", - Columns: c1[2:], - RefColumns: c1[0:1], - OnDelete: Cascade, - }, - }, - } - ) - t1.ForeignKeys[0].RefTable = t1 - return []*Table{t1} - }(), - before: func(mock sqliteMock) { - mock.start() - mock.tableExists("users", true) - mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). - WithArgs(). - WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). - AddRow("name", "varchar(255)", 1, "NULL", 0). - AddRow("id", "integer", 1, "NULL", 1)) - mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). - WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) - mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` integer NULL CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE CASCADE")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.commit() - }, - }, - { - name: "universal id for all tables", - tables: []*Table{ - NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - }, - options: []MigrateOption{WithGlobalUniqueID(true)}, - before: func(mock sqliteMock) { - mock.start() - // creating ent_types table. - mock.tableExists("ent_types", false) - mock.ExpectExec(escape("CREATE TABLE `ent_types`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `type` varchar(255) UNIQUE NOT NULL)")). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set users id range. - mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). - WithArgs("users"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) - mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")). - WithArgs("users", 0). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("groups", false) - mock.ExpectExec(escape("CREATE TABLE `groups`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set groups id range. - mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). - WithArgs("groups"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")). - WithArgs("groups"). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) - mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")). - WithArgs("groups", 1<<32). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.commit() - }, - }, - { - name: "universal id for restored tables", - tables: []*Table{ - NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), - }, - options: []MigrateOption{WithGlobalUniqueID(true)}, - before: func(mock sqliteMock) { - mock.start() - // query ent_types table. - mock.tableExists("ent_types", true) - mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). - WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) - mock.tableExists("users", false) - mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set users id range (without inserting to ent_types). - mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")). - WithArgs("users"). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) - mock.ExpectExec(escape("UPDATE `sqlite_sequence` SET `seq` = ? WHERE `name` = ?")). - WithArgs(0, "users"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.tableExists("groups", false) - mock.ExpectExec(escape("CREATE TABLE `groups`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")). - WillReturnResult(sqlmock.NewResult(0, 1)) - // set groups id range. - mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). - WithArgs("groups"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")). - WithArgs("groups"). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) - mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")). - WithArgs("groups", 1<<32). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.commit() - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db, mock, err := sqlmock.New() - require.NoError(t, err) - tt.before(sqliteMock{mock}) - migrate, err := NewMigrate(sql.OpenDB("sqlite3", db), append(tt.options, WithAtlas(false))...) - require.NoError(t, err) - err = migrate.Create(context.Background(), tt.tables...) - require.Equal(t, tt.wantErr, err != nil, err) - }) - } -} - -type sqliteMock struct { - sqlmock.Sqlmock -} - -func (m sqliteMock) start() { - m.ExpectQuery("PRAGMA foreign_keys"). - WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(1)) - m.ExpectExec("PRAGMA foreign_keys = off"). - WillReturnResult(sqlmock.NewResult(0, 1)) - m.ExpectBegin() - m.ExpectQuery("PRAGMA foreign_key_check"). - WillReturnRows(sqlmock.NewRows([]string{})) // empty -} - -func (m sqliteMock) commit() { - m.ExpectQuery("PRAGMA foreign_key_check"). - WillReturnRows(sqlmock.NewRows([]string{})) // empty - m.ExpectCommit() - m.ExpectExec("PRAGMA foreign_keys = on"). - WillReturnResult(sqlmock.NewResult(0, 1)) -} - -func (m sqliteMock) tableExists(table string, exists bool) { - count := 0 - if exists { - count = 1 - } - m.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_master` WHERE `type` = ? AND `name` = ?")). - WithArgs("table", table). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) -} diff --git a/entc/gen/globalid_test.go b/entc/gen/globalid_test.go index 9adfb92fd..1b8d07967 100644 --- a/entc/gen/globalid_test.go +++ b/entc/gen/globalid_test.go @@ -22,7 +22,6 @@ func TestIncrementStartAnnotation(t *testing.T) { Name: "T1", Annotations: gen.Annotations{a.Name(): a}, }, - {Name: "T2"}, } c = &gen.Config{ Package: "entc/gen", @@ -45,7 +44,7 @@ func TestIncrementStartAnnotation(t *testing.T) { require.NotNil(t, g) // Duplicated increment starting values are not allowed. - s = append(s, &load.Schema{ + s = append(s, &load.Schema{Name: "T2"}, &load.Schema{ Name: "T3", Annotations: gen.Annotations{a.Name(): &entsql.Annotation{IncrementStart: p(1 << 32)}}, }) diff --git a/examples/traversal/example_test.go b/examples/traversal/example_test.go index 34e7086fb..717d117b2 100644 --- a/examples/traversal/example_test.go +++ b/examples/traversal/example_test.go @@ -9,8 +9,6 @@ import ( "fmt" "log" - "entgo.io/ent/dialect/sql/schema" - "entgo.io/ent/examples/traversal/ent" "entgo.io/ent/examples/traversal/ent/group" "entgo.io/ent/examples/traversal/ent/pet" @@ -27,7 +25,7 @@ func Example_Traversal() { defer client.Close() ctx := context.Background() // Run the auto migration tool. - if err := client.Schema.Create(ctx, schema.WithAtlas(true)); err != nil { + if err := client.Schema.Create(ctx); err != nil { log.Fatalf("failed creating schema resources: %v", err) } if err := Gen(ctx, client); err != nil { @@ -120,7 +118,7 @@ func Gen(ctx context.Context, client *ent.Client) error { func Traverse(ctx context.Context, client *ent.Client) error { owner, err := client.Group. // GroupClient. Query(). // Query builder. - Where(group.Name("Github")). // Filter only Github group (only 1). + Where(group.Name("Github")). // Filter only GitHub group (only 1). QueryAdmin(). // Getting Dan. QueryFriends(). // Getting Dan's friends: [Ariel]. QueryPets(). // Their pets: [Pedro, Xabi].