mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/schema: remove deprecated legacy migration engine (#4294)
This commit is contained in:
@@ -29,10 +29,6 @@ type Atlas struct {
|
||||
atDriver migrate.Driver
|
||||
sqlDialect sqlDialect
|
||||
|
||||
legacy bool // if the legacy migration engine instead of Atlas should be used
|
||||
withFixture bool // deprecated: with fks rename fixture
|
||||
sum bool // deprecated: sum file generation will be required
|
||||
|
||||
indent string // plan indentation
|
||||
errNoPlan bool // no plan error enabled
|
||||
universalID bool // global unique ids
|
||||
@@ -67,7 +63,7 @@ func Diff(ctx context.Context, u, name string, tables []*Table, opts ...MigrateO
|
||||
|
||||
// NewMigrate creates a new Atlas form the given dialect.Driver.
|
||||
func NewMigrate(drv dialect.Driver, opts ...MigrateOption) (*Atlas, error) {
|
||||
a := &Atlas{driver: drv, withForeignKeys: true, mode: ModeInspect, sum: true}
|
||||
a := &Atlas{driver: drv, withForeignKeys: true, mode: ModeInspect}
|
||||
for _, opt := range opts {
|
||||
opt(a)
|
||||
}
|
||||
@@ -84,7 +80,7 @@ func NewMigrateURL(u string, opts ...MigrateOption) (*Atlas, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a := &Atlas{url: parsed, withForeignKeys: true, mode: ModeInspect, sum: true}
|
||||
a := &Atlas{url: parsed, withForeignKeys: true, mode: ModeInspect}
|
||||
for _, opt := range opts {
|
||||
opt(a)
|
||||
}
|
||||
@@ -106,13 +102,6 @@ func NewMigrateURL(u string, opts ...MigrateOption) (*Atlas, error) {
|
||||
func (a *Atlas) Create(ctx context.Context, tables ...*Table) (err error) {
|
||||
a.setupTables(tables)
|
||||
var creator Creator = CreateFunc(a.create)
|
||||
if a.legacy {
|
||||
m, err := a.legacyMigrate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
creator = CreateFunc(m.create)
|
||||
}
|
||||
for i := len(a.hooks) - 1; i >= 0; i-- {
|
||||
creator = a.hooks[i](creator)
|
||||
}
|
||||
@@ -132,13 +121,9 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
|
||||
return errors.New("no migration directory given")
|
||||
}
|
||||
opts := []migrate.PlannerOption{migrate.WithFormatter(a.fmt)}
|
||||
if a.sum {
|
||||
// Validate the migration directory before proceeding.
|
||||
if err := migrate.Validate(a.dir); err != nil {
|
||||
return fmt.Errorf("validating migration directory: %w", err)
|
||||
}
|
||||
} else {
|
||||
opts = append(opts, migrate.DisableChecksum())
|
||||
// Validate the migration directory before proceeding.
|
||||
if err := migrate.Validate(a.dir); err != nil {
|
||||
return fmt.Errorf("validating migration directory: %w", err)
|
||||
}
|
||||
a.setupTables(tables)
|
||||
// Set up connections.
|
||||
@@ -488,18 +473,6 @@ func WithApplyHook(hooks ...ApplyHook) MigrateOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithAtlas is an opt-out option for v0.11 indicating the migration
|
||||
// should be executed using the deprecated legacy engine.
|
||||
// Note, in future versions, this option is going to be removed
|
||||
// and the Atlas (https://atlasgo.io) based migration engine should be used.
|
||||
//
|
||||
// Deprecated: The legacy engine will be removed.
|
||||
func WithAtlas(b bool) MigrateOption {
|
||||
return func(a *Atlas) {
|
||||
a.legacy = !b
|
||||
}
|
||||
}
|
||||
|
||||
// WithDir sets the atlas migration directory to use to store migration files.
|
||||
func WithDir(dir migrate.Dir) MigrateOption {
|
||||
return func(a *Atlas) {
|
||||
@@ -522,22 +495,6 @@ func WithDialect(d string) MigrateOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithSumFile instructs atlas to generate a migration directory integrity sum file.
|
||||
//
|
||||
// Deprecated: generating the sum file is now opt-out. This method will be removed in future versions.
|
||||
func WithSumFile() MigrateOption {
|
||||
return func(a *Atlas) {}
|
||||
}
|
||||
|
||||
// DisableChecksum instructs atlas to skip migration directory integrity sum file generation.
|
||||
//
|
||||
// Deprecated: generating the sum file will no longer be optional in future versions.
|
||||
func DisableChecksum() MigrateOption {
|
||||
return func(a *Atlas) {
|
||||
a.sum = false
|
||||
}
|
||||
}
|
||||
|
||||
// WithMigrationMode instructs atlas how to compute the current state of the schema. This can be done by either
|
||||
// replaying (ModeReplay) the migration directory on the connected database, or by inspecting (ModeInspect) the
|
||||
// connection. Currently, ModeReplay is opt-in, and ModeInspect is the default. In future versions, ModeReplay will
|
||||
@@ -626,15 +583,9 @@ func (a *Atlas) init() error {
|
||||
a.fmt = sqltool.GolangMigrateFormatter
|
||||
}
|
||||
}
|
||||
if a.mode == ModeReplay {
|
||||
// ModeReplay requires a migration directory.
|
||||
if a.dir == nil {
|
||||
return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires versioned migrations: WithDir()")
|
||||
}
|
||||
// ModeReplay requires sum file generation.
|
||||
if !a.sum {
|
||||
return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires migration directory integrity file")
|
||||
}
|
||||
// ModeReplay requires a migration directory.
|
||||
if a.mode == ModeReplay && a.dir == nil {
|
||||
return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires versioned migrations: WithDir()")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1238,32 +1189,6 @@ func (r *diffDriver) SchemaDiff(from, to *schema.Schema, opts ...schema.DiffOpti
|
||||
return d.Diff(from, to)
|
||||
}
|
||||
|
||||
// legacyMigrate returns a configured legacy migration engine (before Atlas) to keep backwards compatibility.
|
||||
//
|
||||
// Deprecated: Will be removed alongside legacy migration support.
|
||||
func (a *Atlas) legacyMigrate() (*Migrate, error) {
|
||||
m := &Migrate{
|
||||
universalID: a.universalID,
|
||||
dropColumns: a.dropColumns,
|
||||
dropIndexes: a.dropIndexes,
|
||||
withFixture: a.withFixture,
|
||||
withForeignKeys: a.withForeignKeys,
|
||||
hooks: a.hooks,
|
||||
atlas: a,
|
||||
}
|
||||
switch a.dialect {
|
||||
case dialect.MySQL:
|
||||
m.sqlDialect = &MySQL{Driver: a.driver}
|
||||
case dialect.SQLite:
|
||||
m.sqlDialect = &SQLite{Driver: a.driver, WithForeignKeys: a.withForeignKeys}
|
||||
case dialect.Postgres:
|
||||
m.sqlDialect = &Postgres{Driver: a.driver}
|
||||
default:
|
||||
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// removeAttr is a temporary patch due to compiler errors we get by using the generic
|
||||
// schema.RemoveAttr function (<autogenerated>:1: internal compiler error: panic: ...).
|
||||
// Can be removed in Go 1.20. See: https://github.com/golang/go/issues/54302.
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
// InspectOption allows for managing schema configuration using functional options.
|
||||
type InspectOption func(inspect *Inspector)
|
||||
|
||||
// WithSchema provides a schema (named-database) for reading the tables from.
|
||||
func WithSchema(schema string) InspectOption {
|
||||
return func(m *Inspector) {
|
||||
m.schema = schema
|
||||
}
|
||||
}
|
||||
|
||||
// An Inspector provides methods for inspecting database tables.
|
||||
type Inspector struct {
|
||||
sqlDialect
|
||||
schema string
|
||||
}
|
||||
|
||||
// NewInspect returns an inspector for the given SQL driver.
|
||||
func NewInspect(d dialect.Driver, opts ...InspectOption) (*Inspector, error) {
|
||||
i := &Inspector{}
|
||||
for _, opt := range opts {
|
||||
opt(i)
|
||||
}
|
||||
switch d.Dialect() {
|
||||
case dialect.MySQL:
|
||||
i.sqlDialect = &MySQL{Driver: d, schema: i.schema}
|
||||
case dialect.SQLite:
|
||||
i.sqlDialect = &SQLite{Driver: d}
|
||||
case dialect.Postgres:
|
||||
i.sqlDialect = &Postgres{Driver: d, schema: i.schema}
|
||||
default:
|
||||
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", d.Dialect())
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// Tables returns the tables in the schema.
|
||||
func (i *Inspector) Tables(ctx context.Context) ([]*Table, error) {
|
||||
names, err := i.tables(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tx := dialect.NopTx(i.sqlDialect)
|
||||
tables := make([]*Table, 0, len(names))
|
||||
for _, name := range names {
|
||||
t, err := i.table(ctx, tx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tables = append(tables, t)
|
||||
}
|
||||
|
||||
fki, ok := i.sqlDialect.(interface {
|
||||
foreignKeys(context.Context, dialect.Tx, []*Table) error
|
||||
})
|
||||
if ok {
|
||||
if err := fki.foreignKeys(ctx, tx, tables); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (i *Inspector) tables(ctx context.Context) ([]string, error) {
|
||||
t, ok := i.sqlDialect.(interface{ tables() sql.Querier })
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sql/schema: %q driver does not support inspection", i.Dialect())
|
||||
}
|
||||
query, args := t.tables().Query()
|
||||
var (
|
||||
names []string
|
||||
rows = &sql.Rows{}
|
||||
)
|
||||
if err := i.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("%q driver: reading table names %w", i.Dialect(), err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if err := sql.ScanSlice(rows, &names); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
@@ -1,333 +0,0 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/schema/field"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInspector_Tables(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options []InspectOption
|
||||
before map[string]func(mysqlMock)
|
||||
tables func(drv string) []*Table
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "default schema",
|
||||
before: map[string]func(mysqlMock){
|
||||
dialect.MySQL: func(mock mysqlMock) {
|
||||
mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE())")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
|
||||
},
|
||||
dialect.SQLite: func(mock mysqlMock) {
|
||||
mock.ExpectQuery(escape("SELECT `name` FROM `sqlite_schema` WHERE `type` = ?")).
|
||||
WithArgs("table").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}))
|
||||
},
|
||||
dialect.Postgres: func(mock mysqlMock) {
|
||||
mock.ExpectQuery(escape(`SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = CURRENT_SCHEMA()`)).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}))
|
||||
},
|
||||
},
|
||||
tables: func(drv string) []*Table {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom schema",
|
||||
options: []InspectOption{WithSchema("public")},
|
||||
before: map[string]func(mysqlMock){
|
||||
dialect.MySQL: func(mock mysqlMock) {
|
||||
mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = ?")).
|
||||
WithArgs("public").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).
|
||||
AddRow("users").
|
||||
AddRow("pets").
|
||||
AddRow("groups").
|
||||
AddRow("user_groups"))
|
||||
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")).
|
||||
WithArgs("public", "users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}).
|
||||
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil).
|
||||
AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil).
|
||||
AddRow("text", "longtext", "YES", "YES", "NULL", "", "", "", nil, nil).
|
||||
AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin", nil, nil).
|
||||
AddRow("price", "decimal(6, 4)", "NO", "YES", "NULL", "", "", "", "6", "4").
|
||||
AddRow("bank_id", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil))
|
||||
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
|
||||
WithArgs("public", "users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}).
|
||||
AddRow("PRIMARY", "id", nil, "0", "1"))
|
||||
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")).
|
||||
WithArgs("public", "pets").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}).
|
||||
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil).
|
||||
AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil).
|
||||
AddRow("user_pets", "bigint(20)", "YES", "YES", "NULL", "", "", "", nil, nil))
|
||||
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
|
||||
WithArgs("public", "pets").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}).
|
||||
AddRow("PRIMARY", "id", nil, "0", "1"))
|
||||
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")).
|
||||
WithArgs("public", "groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}).
|
||||
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil).
|
||||
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil))
|
||||
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
|
||||
WithArgs("public", "groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}).
|
||||
AddRow("PRIMARY", "id", nil, "0", "1"))
|
||||
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")).
|
||||
WithArgs("public", "user_groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}).
|
||||
AddRow("user_id", "bigint(20)", "NO", "YES", "NULL", "", "", "", nil, nil).
|
||||
AddRow("group_id", "bigint(20)", "NO", "YES", "NULL", "", "", "", nil, nil))
|
||||
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
|
||||
WithArgs("public", "user_groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}))
|
||||
},
|
||||
dialect.SQLite: func(mock mysqlMock) {
|
||||
mock.ExpectQuery(escape("SELECT `name` FROM `sqlite_schema` WHERE `type` = ?")).
|
||||
WithArgs("table").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
||||
AddRow("users").
|
||||
AddRow("pets").
|
||||
AddRow("groups").
|
||||
AddRow("user_groups"))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")).
|
||||
WithArgs().
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
|
||||
AddRow("id", "integer", 1, "NULL", 1).
|
||||
AddRow("name", "varchar(255)", 0, "NULL", 0).
|
||||
AddRow("text", "text", 0, "NULL", 0).
|
||||
AddRow("uuid", "uuid", 0, "NULL", 0).
|
||||
AddRow("price", "real", 1, "NULL", 0).
|
||||
AddRow("bank_id", "varchar(255)", 1, "NULL", 0))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"}))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('pets') ORDER BY `pk`")).
|
||||
WithArgs().
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
|
||||
AddRow("id", "integer", 1, "NULL", 1).
|
||||
AddRow("name", "varchar(255)", 0, "NULL", 0).
|
||||
AddRow("user_pets", "integer", 0, "NULL", 0))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('pets')")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"}))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('groups') ORDER BY `pk`")).
|
||||
WithArgs().
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
|
||||
AddRow("id", "integer", 1, "NULL", 1).
|
||||
AddRow("name", "varchar(255)", 1, "NULL", 0))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('groups')")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"}))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('user_groups') ORDER BY `pk`")).
|
||||
WithArgs().
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
|
||||
AddRow("user_id", "integer", 1, "NULL", 0).
|
||||
AddRow("group_id", "integer", 1, "NULL", 0))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('user_groups')")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"}))
|
||||
},
|
||||
dialect.Postgres: func(mock mysqlMock) {
|
||||
mock.ExpectQuery(escape(`SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = $1`)).
|
||||
WithArgs("public").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
||||
AddRow("users").
|
||||
AddRow("pets").
|
||||
AddRow("groups").
|
||||
AddRow("user_groups"))
|
||||
mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)).
|
||||
WithArgs("public", "users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}).
|
||||
AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil).
|
||||
AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil).
|
||||
AddRow("text", "text", "YES", "NULL", "text", nil, nil, nil).
|
||||
AddRow("uuid", "uuid", "YES", "NULL", "uuid", nil, nil, nil).
|
||||
AddRow("price", "numeric", "NO", "NULL", "numeric", "6", "4", nil).
|
||||
AddRow("bank_id", "character", "NO", "NULL", "bpchar", nil, nil, 20))
|
||||
mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "users"))).
|
||||
WithArgs("public").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}).
|
||||
AddRow("users_pkey", "id", "t", "t", 0))
|
||||
mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)).
|
||||
WithArgs("public", "pets").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}).
|
||||
AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil).
|
||||
AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil).
|
||||
AddRow("user_pets", "bigint", "YES", "NULL", "int8", nil, nil, nil))
|
||||
mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "pets"))).
|
||||
WithArgs("public").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}).
|
||||
AddRow("pets_pkey", "id", "t", "t", 0))
|
||||
mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)).
|
||||
WithArgs("public", "groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}).
|
||||
AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil).
|
||||
AddRow("name", "character", "NO", "NULL", "bpchar", nil, nil, nil))
|
||||
mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "groups"))).
|
||||
WithArgs("public").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}).
|
||||
AddRow("groups_pkey", "id", "t", "t", 0))
|
||||
mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)).
|
||||
WithArgs("public", "user_groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}).
|
||||
AddRow("user_id", "bigint", "NO", "NULL", "int8", nil, nil, nil).
|
||||
AddRow("group_id", "bigint", "NO", "NULL", "int8", nil, nil, nil))
|
||||
mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "user_groups"))).
|
||||
WithArgs("public").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}))
|
||||
mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "users"))).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"}))
|
||||
mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "pets"))).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"}).
|
||||
AddRow("public", "pet_users_pets", "pets", "user_pets", "public", "users", "id"))
|
||||
mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "groups"))).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"}))
|
||||
mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "user_groups"))).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"}).
|
||||
AddRow("public", "user_groups_group_id", "user_groups", "group_id", "public", "groups", "id").
|
||||
AddRow("public", "user_groups_user_id", "user_groups", "user_id", "public", "users", "id"))
|
||||
},
|
||||
},
|
||||
tables: func(drv string) []*Table {
|
||||
var (
|
||||
c1 = []*Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "name", Type: field.TypeString, Size: 255, Nullable: true},
|
||||
{Name: "text", Type: field.TypeString, Size: math.MaxInt32, Nullable: true},
|
||||
{Name: "uuid", Type: field.TypeUUID, Nullable: true},
|
||||
{Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{
|
||||
dialect.MySQL: "decimal(6,4)",
|
||||
dialect.Postgres: "numeric(6,4)",
|
||||
}},
|
||||
{Name: "bank_id", Type: field.TypeString, SchemaType: map[string]string{
|
||||
dialect.Postgres: "varchar(20)",
|
||||
}},
|
||||
}
|
||||
t1 = &Table{
|
||||
Name: "users",
|
||||
Columns: c1,
|
||||
PrimaryKey: c1[0:1],
|
||||
}
|
||||
c2 = []*Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "name", Type: field.TypeString, Size: 255, Nullable: true},
|
||||
{Name: "user_pets", Type: field.TypeInt64, Nullable: true},
|
||||
}
|
||||
t2 = &Table{
|
||||
Name: "pets",
|
||||
Columns: c2,
|
||||
PrimaryKey: c2[0:1],
|
||||
}
|
||||
c3 = []*Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "name", Type: field.TypeString},
|
||||
}
|
||||
t3 = &Table{
|
||||
Name: "groups",
|
||||
Columns: c3,
|
||||
PrimaryKey: c3[0:1],
|
||||
}
|
||||
c4 = []*Column{
|
||||
{Name: "user_id", Type: field.TypeInt64},
|
||||
{Name: "group_id", Type: field.TypeInt64},
|
||||
}
|
||||
t4 = &Table{
|
||||
Name: "user_groups",
|
||||
Columns: c4,
|
||||
}
|
||||
)
|
||||
|
||||
// Only postgres currently supports foreign key inspection
|
||||
if drv == dialect.Postgres {
|
||||
t2.ForeignKeys = []*ForeignKey{
|
||||
{
|
||||
Symbol: "pet_users_pets",
|
||||
Columns: []*Column{c2[2]},
|
||||
RefTable: t1,
|
||||
RefColumns: []*Column{c1[0]},
|
||||
},
|
||||
}
|
||||
t4.ForeignKeys = []*ForeignKey{
|
||||
{
|
||||
Symbol: "user_groups_group_id",
|
||||
Columns: []*Column{c4[1]},
|
||||
RefTable: t3,
|
||||
RefColumns: []*Column{c3[0]},
|
||||
},
|
||||
{
|
||||
Symbol: "user_groups_user_id",
|
||||
Columns: []*Column{c4[0]},
|
||||
RefTable: t1,
|
||||
RefColumns: []*Column{c1[0]},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return []*Table{t1, t2, t3, t4}
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
for drv := range tt.before {
|
||||
t.Run(path.Join(drv, tt.name), func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
tt.before[drv](mysqlMock{mock})
|
||||
inspect, err := NewInspect(sql.OpenDB(drv, db), tt.options...)
|
||||
require.NoError(t, err)
|
||||
tables, err := inspect.Tables(context.Background())
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
tablesMatch(t, drv, tables, tt.tables(drv))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func tablesMatch(t *testing.T, drv string, got, expected []*Table) {
|
||||
require.Equal(t, len(expected), len(got))
|
||||
for i := range got {
|
||||
columnsMatch(t, drv, got[i].Columns, expected[i].Columns)
|
||||
columnsMatch(t, drv, got[i].PrimaryKey, expected[i].PrimaryKey)
|
||||
foreignKeysMatch(t, drv, got[i].ForeignKeys, expected[i].ForeignKeys)
|
||||
}
|
||||
}
|
||||
|
||||
func columnsMatch(t *testing.T, drv string, got, expected []*Column) {
|
||||
require.Equal(t, len(expected), len(got))
|
||||
for i := range got {
|
||||
c1, c2 := got[i], expected[i]
|
||||
require.Equal(t, c2.Name, c1.Name)
|
||||
require.Equal(t, c2.Nullable, c1.Nullable)
|
||||
require.True(t, c1.Type == c2.Type || c1.ConvertibleTo(c2), "mismatched types: %s - %s", c1.Type, c2.Type)
|
||||
if c2.SchemaType[drv] != "" {
|
||||
require.Equal(t, c2.SchemaType[drv], c1.SchemaType[drv])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func foreignKeysMatch(t *testing.T, drv string, expected []*ForeignKey, got []*ForeignKey) {
|
||||
require.Equal(t, len(expected), len(got))
|
||||
for i := range got {
|
||||
fk1, fk2 := got[i], expected[i]
|
||||
require.Equal(t, fk2.Symbol, fk1.Symbol)
|
||||
require.Equal(t, fk2.RefTable.Name, fk1.RefTable.Name)
|
||||
columnsMatch(t, drv, fk1.Columns, fk2.Columns)
|
||||
columnsMatch(t, drv, fk1.RefColumns, fk2.RefColumns)
|
||||
}
|
||||
}
|
||||
@@ -73,17 +73,6 @@ func WithDropIndex(b bool) MigrateOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithFixture sets the foreign-key renaming option to the migration when upgrading
|
||||
// sqlDialect from v0.1.0 (issue-#285). Defaults to false.
|
||||
//
|
||||
// Deprecated: This option is no longer needed with the Atlas based
|
||||
// migration engine, which now is the default.
|
||||
func WithFixture(b bool) MigrateOption {
|
||||
return func(a *Atlas) {
|
||||
a.withFixture = b
|
||||
}
|
||||
}
|
||||
|
||||
// WithForeignKeys enables creating foreign-key in ddl. Defaults to true.
|
||||
func WithForeignKeys(b bool) MigrateOption {
|
||||
return func(a *Atlas) {
|
||||
@@ -127,480 +116,6 @@ func (f CreateFunc) Create(ctx context.Context, tables ...*Table) error {
|
||||
return f(ctx, tables...)
|
||||
}
|
||||
|
||||
// Migrate runs the migration logic for the SQL dialects.
|
||||
//
|
||||
// Deprecated: Use the new Atlas struct instead.
|
||||
type Migrate struct {
|
||||
sqlDialect
|
||||
atlas *Atlas // Atlas this Migrate is based on
|
||||
|
||||
universalID bool // global unique ids
|
||||
dropColumns bool // drop deleted columns
|
||||
dropIndexes bool // drop deleted indexes
|
||||
withFixture bool // with fks rename fixture
|
||||
withForeignKeys bool // with foreign keys
|
||||
typeRanges []string // types order by their range
|
||||
hooks []Hook // hooks to apply before creation
|
||||
}
|
||||
|
||||
// Create creates all schema resources in the database. It works in an "append-only"
|
||||
// mode, which means, it only creates tables, appends columns to tables or modifies column types.
|
||||
//
|
||||
// Column can be modified by turning into a NULL from NOT NULL, or having a type conversion not
|
||||
// resulting data altering. From example, changing varchar(255) to varchar(120) is invalid, but
|
||||
// changing varchar(120) to varchar(255) is valid. For more info, see the convert function below.
|
||||
//
|
||||
// Note that SQLite dialect does not support (this moment) the "append-only" mode describe above,
|
||||
// since it's used only for testing.
|
||||
func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
|
||||
m.setupTables(tables)
|
||||
var creator Creator = CreateFunc(m.create)
|
||||
for i := len(m.hooks) - 1; i >= 0; i-- {
|
||||
creator = m.hooks[i](creator)
|
||||
}
|
||||
return creator.Create(ctx, tables...)
|
||||
}
|
||||
|
||||
func (m *Migrate) create(ctx context.Context, tables ...*Table) error {
|
||||
if err := m.init(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
tx, err := m.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if m.universalID {
|
||||
if err := m.types(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
}
|
||||
if err := m.txCreate(ctx, tx, tables...); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (m *Migrate) txCreate(ctx context.Context, tx dialect.Tx, tables ...*Table) error {
|
||||
for _, t := range tables {
|
||||
switch exist, err := m.tableExist(ctx, tx, t.Name); {
|
||||
case err != nil:
|
||||
return err
|
||||
case exist:
|
||||
curr, err := m.table(ctx, tx, t.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.verify(ctx, tx, curr); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.fixture(ctx, tx, curr, t); err != nil {
|
||||
return err
|
||||
}
|
||||
change, err := m.changeSet(curr, t)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating changeset for %q: %w", t.Name, err)
|
||||
}
|
||||
if err := m.apply(ctx, tx, t.Name, change); err != nil {
|
||||
return err
|
||||
}
|
||||
default: // !exist
|
||||
query, args := m.tBuilder(t).Query()
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("create table %q: %w", t.Name, err)
|
||||
}
|
||||
// If global unique identifier is enabled, and it's not
|
||||
// a relation table, allocate a range for the table pk.
|
||||
if m.universalID && len(t.PrimaryKey) == 1 {
|
||||
if err := m.allocPKRange(ctx, tx, t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// indexes.
|
||||
for _, idx := range t.Indexes {
|
||||
query, args := m.addIndex(idx, t.Name).Query()
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("create index %q: %w", idx.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !m.withForeignKeys {
|
||||
return nil
|
||||
}
|
||||
// Create foreign keys after tables were created/altered,
|
||||
// because circular foreign-key constraints are possible.
|
||||
for _, t := range tables {
|
||||
if len(t.ForeignKeys) == 0 {
|
||||
continue
|
||||
}
|
||||
fks := make([]*ForeignKey, 0, len(t.ForeignKeys))
|
||||
for _, fk := range t.ForeignKeys {
|
||||
exist, err := m.fkExist(ctx, tx, fk.Symbol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exist {
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
}
|
||||
if len(fks) == 0 {
|
||||
continue
|
||||
}
|
||||
b := sql.Dialect(m.Dialect()).AlterTable(t.Name)
|
||||
for _, fk := range fks {
|
||||
b.AddForeignKey(fk.DSL())
|
||||
}
|
||||
query, args := b.Query()
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("create foreign keys for %q: %w", t.Name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// apply changes on the given table.
|
||||
func (m *Migrate) apply(ctx context.Context, tx dialect.Tx, table string, change *changes) error {
|
||||
// Constraints should be dropped before dropping columns, because if a column
|
||||
// is a part of multi-column constraints (like, unique index), ALTER TABLE
|
||||
// might fail if the intermediate state violates the constraints.
|
||||
if m.dropIndexes {
|
||||
if pr, ok := m.sqlDialect.(preparer); ok {
|
||||
if err := pr.prepare(ctx, tx, change, table); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, idx := range change.index.drop {
|
||||
if err := m.dropIndex(ctx, tx, idx, table); err != nil {
|
||||
return fmt.Errorf("drop index of table %q: %w", table, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
var drop []*Column
|
||||
if m.dropColumns {
|
||||
drop = change.column.drop
|
||||
}
|
||||
queries := m.alterColumns(table, change.column.add, change.column.modify, drop)
|
||||
// If there's actual action to execute on ALTER TABLE.
|
||||
for i := range queries {
|
||||
query, args := queries[i].Query()
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("alter table %q: %w", table, err)
|
||||
}
|
||||
}
|
||||
for _, idx := range change.index.add {
|
||||
query, args := m.addIndex(idx, table).Query()
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("create index %q: %w", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// changes to apply on existing table.
|
||||
type changes struct {
|
||||
// column changes.
|
||||
column struct {
|
||||
add []*Column
|
||||
drop []*Column
|
||||
modify []*Column
|
||||
}
|
||||
// index changes.
|
||||
index struct {
|
||||
add Indexes
|
||||
drop Indexes
|
||||
}
|
||||
}
|
||||
|
||||
// dropColumn returns the dropped column by name (if any).
|
||||
func (c *changes) dropColumn(name string) (*Column, bool) {
|
||||
for _, col := range c.column.drop {
|
||||
if col.Name == name {
|
||||
return col, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// changeSet returns a changes object to be applied on existing table.
|
||||
// It fails if one of the changes is invalid.
|
||||
func (m *Migrate) changeSet(curr, new *Table) (*changes, error) {
|
||||
change := &changes{}
|
||||
// pks.
|
||||
if len(curr.PrimaryKey) != len(new.PrimaryKey) {
|
||||
return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
|
||||
}
|
||||
for i := range curr.PrimaryKey {
|
||||
if curr.PrimaryKey[i].Name != new.PrimaryKey[i].Name {
|
||||
return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
|
||||
}
|
||||
}
|
||||
// Add or modify columns.
|
||||
for _, c1 := range new.Columns {
|
||||
// Ignore primary keys.
|
||||
if c1.PrimaryKey() {
|
||||
continue
|
||||
}
|
||||
switch c2, ok := curr.column(c1.Name); {
|
||||
case !ok:
|
||||
change.column.add = append(change.column.add, c1)
|
||||
case !c2.Type.Valid():
|
||||
return nil, fmt.Errorf("invalid type %q for column %q", c2.typ, c2.Name)
|
||||
// Modify a non-unique column to unique.
|
||||
case c1.Unique && !c2.Unique:
|
||||
// Make sure the table does not have unique index for this column
|
||||
// before adding it to the changeset, because there are 2 ways to
|
||||
// configure uniqueness on sqlDialect.Field (using the Unique modifier or
|
||||
// adding rule on the Indexes option).
|
||||
if idx, ok := curr.index(c1.Name); !ok || !idx.Unique {
|
||||
change.index.add.append(&Index{
|
||||
Name: c1.Name,
|
||||
Unique: true,
|
||||
Columns: []*Column{c1},
|
||||
columns: []string{c1.Name},
|
||||
})
|
||||
}
|
||||
// Modify a unique column to non-unique.
|
||||
case !c1.Unique && c2.Unique:
|
||||
// If the uniqueness was defined on the Indexes option,
|
||||
// or was moved from the Unique modifier to the Indexes.
|
||||
if idx, ok := new.index(c1.Name); ok && idx.Unique {
|
||||
continue
|
||||
}
|
||||
idx, ok := curr.index(c2.Name)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing index to drop for unique column %q", c2.Name)
|
||||
}
|
||||
change.index.drop.append(idx)
|
||||
// Extending column types.
|
||||
case m.needsConversion(c2, c1):
|
||||
if !c2.ConvertibleTo(c1) {
|
||||
return nil, fmt.Errorf("changing column type for %q is invalid (%s != %s)", c1.Name, m.cType(c1), m.cType(c2))
|
||||
}
|
||||
fallthrough
|
||||
// Change nullability of a column.
|
||||
case c1.Nullable != c2.Nullable:
|
||||
change.column.modify = append(change.column.modify, c1)
|
||||
// Change default value.
|
||||
case c1.Default != nil && c2.Default == nil:
|
||||
change.column.modify = append(change.column.modify, c1)
|
||||
}
|
||||
}
|
||||
// Drop columns.
|
||||
for _, c1 := range curr.Columns {
|
||||
// If a column was dropped, multi-columns indexes that are associated with this column will
|
||||
// no longer behave the same. Therefore, these indexes should be dropped too. There's no need
|
||||
// to do it explicitly (here), because entc will remove them from the schema specification,
|
||||
// and they will be dropped in the block below.
|
||||
if _, ok := new.column(c1.Name); !ok {
|
||||
change.column.drop = append(change.column.drop, c1)
|
||||
}
|
||||
}
|
||||
// Add or modify indexes.
|
||||
for _, idx1 := range new.Indexes {
|
||||
switch idx2, ok := curr.index(idx1.Name); {
|
||||
case !ok:
|
||||
change.index.add.append(idx1)
|
||||
// Changing index cardinality require drop and create.
|
||||
case idx1.Unique != idx2.Unique:
|
||||
change.index.drop.append(idx2)
|
||||
change.index.add.append(idx1)
|
||||
default:
|
||||
im, ok := m.sqlDialect.(interface{ indexModified(old, new *Index) bool })
|
||||
// If the dialect supports comparing indexes.
|
||||
if ok && im.indexModified(idx2, idx1) {
|
||||
change.index.drop.append(idx2)
|
||||
change.index.add.append(idx1)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Drop indexes.
|
||||
for _, idx := range curr.Indexes {
|
||||
if _, isFK := new.fk(idx.Name); !isFK && !new.hasIndex(idx.Name, idx.realname) {
|
||||
change.index.drop.append(idx)
|
||||
}
|
||||
}
|
||||
return change, nil
|
||||
}
|
||||
|
||||
// fixture is a special migration code for renaming foreign-key columns (issue-#285).
|
||||
func (m *Migrate) fixture(ctx context.Context, tx dialect.Tx, curr, new *Table) error {
|
||||
d, ok := m.sqlDialect.(fkRenamer)
|
||||
if !m.withFixture || !m.withForeignKeys || !ok {
|
||||
return nil
|
||||
}
|
||||
rename := make(map[string]*Index)
|
||||
for _, fk := range new.ForeignKeys {
|
||||
ok, err := m.fkExist(ctx, tx, fk.Symbol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking foreign-key existence %q: %w", fk.Symbol, err)
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
column, err := m.fkColumn(ctx, tx, fk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newcol := fk.Columns[0]
|
||||
if column == newcol.Name {
|
||||
continue
|
||||
}
|
||||
query, args := d.renameColumn(curr, &Column{Name: column}, newcol).Query()
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("rename column %q: %w", column, err)
|
||||
}
|
||||
prev, ok := curr.column(column)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// Find all indexes that ~maybe need to be renamed.
|
||||
for _, idx := range prev.indexes {
|
||||
switch _, ok := new.index(idx.Name); {
|
||||
// Ignore indexes that exist in the schema, PKs.
|
||||
case ok || idx.primary:
|
||||
// Index that was created implicitly for a unique
|
||||
// column needs to be renamed to the column name.
|
||||
case d.isImplicitIndex(idx, prev):
|
||||
idx2 := &Index{Name: newcol.Name, Unique: true, Columns: []*Column{newcol}}
|
||||
query, args := d.renameIndex(curr, idx, idx2).Query()
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("rename index %q: %w", prev.Name, err)
|
||||
}
|
||||
idx.Name = idx2.Name
|
||||
default:
|
||||
rename[idx.Name] = idx
|
||||
}
|
||||
}
|
||||
// Update the name of the loaded column, so `changeSet` won't create it.
|
||||
prev.Name = newcol.Name
|
||||
}
|
||||
// Go over the indexes that need to be renamed
|
||||
// and find their ~identical in the new schema.
|
||||
for _, idx := range rename {
|
||||
Find:
|
||||
// Find its ~identical in the new schema, and rename it
|
||||
// if it doesn't exist.
|
||||
for _, idx2 := range new.Indexes {
|
||||
if _, ok := curr.index(idx2.Name); ok {
|
||||
continue
|
||||
}
|
||||
if idx.sameAs(idx2) {
|
||||
query, args := d.renameIndex(curr, idx, idx2).Query()
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("rename index %q: %w", idx.Name, err)
|
||||
}
|
||||
idx.Name = idx2.Name
|
||||
break Find
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verify that the auto-increment counter is correct for table with universal-id support.
|
||||
func (m *Migrate) verify(ctx context.Context, tx dialect.Tx, t *Table) error {
|
||||
vr, ok := m.sqlDialect.(verifyRanger)
|
||||
if !ok || !m.universalID {
|
||||
return nil
|
||||
}
|
||||
id := indexOf(m.typeRanges, t.Name)
|
||||
if id == -1 {
|
||||
return nil
|
||||
}
|
||||
return vr.verifyRange(ctx, tx, t, int64(id<<32))
|
||||
}
|
||||
|
||||
// types loads the type list from the type store. It will create the types table, if it does not exist yet.
|
||||
func (m *Migrate) types(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
exists, err := m.tableExist(ctx, tx, TypeTable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
t := NewTypesTable()
|
||||
query, args := m.tBuilder(t).Query()
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("create types table: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Dialect(m.Dialect()).
|
||||
Select("type").From(sql.Table(TypeTable)).OrderBy(sql.Asc("id")).Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return fmt.Errorf("query types table: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, &m.typeRanges)
|
||||
}
|
||||
|
||||
func (m *Migrate) allocPKRange(ctx context.Context, conn dialect.ExecQuerier, t *Table) error {
|
||||
r, err := m.pkRange(ctx, conn, t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return m.setRange(ctx, conn, t, r)
|
||||
}
|
||||
|
||||
func (m *Migrate) pkRange(ctx context.Context, conn dialect.ExecQuerier, t *Table) (int64, error) {
|
||||
id := indexOf(m.typeRanges, t.Name)
|
||||
// If the table re-created, re-use its range from
|
||||
// the past. Otherwise, allocate a new id-range.
|
||||
if id == -1 {
|
||||
if len(m.typeRanges) > MaxTypes {
|
||||
return 0, fmt.Errorf("max number of types exceeded: %d", MaxTypes)
|
||||
}
|
||||
query, args := sql.Dialect(m.Dialect()).Insert(TypeTable).Columns("type").Values(t.Name).Query()
|
||||
if err := conn.Exec(ctx, query, args, nil); err != nil {
|
||||
return 0, fmt.Errorf("insert into ent_types: %w", err)
|
||||
}
|
||||
id = len(m.typeRanges)
|
||||
m.typeRanges = append(m.typeRanges, t.Name)
|
||||
}
|
||||
return int64(id << 32), nil
|
||||
}
|
||||
|
||||
// fkColumn returns the column name of a foreign-key.
|
||||
func (m *Migrate) fkColumn(ctx context.Context, tx dialect.Tx, fk *ForeignKey) (string, error) {
|
||||
t1 := sql.Table("INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS t1").Unquote().As("t1")
|
||||
t2 := sql.Table("INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t2").Unquote().As("t2")
|
||||
query, args := sql.Dialect(m.Dialect()).
|
||||
Select("column_name").
|
||||
From(t1).
|
||||
Join(t2).
|
||||
On(t1.C("constraint_name"), t2.C("constraint_name")).
|
||||
Where(sql.And(
|
||||
sql.EQ(t2.C("constraint_type"), sql.Raw("'FOREIGN KEY'")),
|
||||
m.sqlDialect.(fkRenamer).matchSchema(t2.C("table_schema")),
|
||||
m.sqlDialect.(fkRenamer).matchSchema(t1.C("table_schema")),
|
||||
sql.EQ(t2.C("constraint_name"), fk.Symbol),
|
||||
)).
|
||||
Query()
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return "", fmt.Errorf("reading foreign-key %q column: %w", fk.Symbol, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
column, err := sql.ScanString(rows)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("scanning foreign-key %q column: %w", fk.Symbol, err)
|
||||
}
|
||||
return column, nil
|
||||
}
|
||||
|
||||
// setup ensures the table is configured properly, like table columns
|
||||
// are linked to their indexes, and PKs columns are defined.
|
||||
func (m *Migrate) setupTables(tables []*Table) { m.atlas.setupTables(tables) }
|
||||
|
||||
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
|
||||
func rollback(tx dialect.Tx, err error) error {
|
||||
err = fmt.Errorf("sql/schema: %w", err)
|
||||
if rerr := tx.Rollback(); rerr != nil {
|
||||
err = fmt.Errorf("%w: %v", err, rerr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// exist checks if the given COUNT query returns a value >= 1.
|
||||
func exist(ctx context.Context, conn dialect.ExecQuerier, query string, args ...any) (bool, error) {
|
||||
rows := &sql.Rows{}
|
||||
@@ -628,30 +143,7 @@ type sqlDialect interface {
|
||||
atBuilder
|
||||
dialect.Driver
|
||||
init(context.Context) error
|
||||
table(context.Context, dialect.Tx, string) (*Table, error)
|
||||
tableExist(context.Context, dialect.ExecQuerier, string) (bool, error)
|
||||
fkExist(context.Context, dialect.Tx, string) (bool, error)
|
||||
setRange(context.Context, dialect.ExecQuerier, *Table, int64) error
|
||||
dropIndex(context.Context, dialect.Tx, *Index, string) error
|
||||
// table, column and index builder per dialect.
|
||||
cType(*Column) string
|
||||
tBuilder(*Table) *sql.TableBuilder
|
||||
addIndex(*Index, string) *sql.IndexBuilder
|
||||
alterColumns(table string, add, modify, drop []*Column) sql.Queries
|
||||
needsConversion(*Column, *Column) bool
|
||||
}
|
||||
|
||||
type preparer interface {
|
||||
prepare(context.Context, dialect.Tx, *changes, string) error
|
||||
}
|
||||
|
||||
// fkRenamer is used by the fixture migration (to solve #285),
|
||||
// and it's implemented by the different dialects for renaming FKs.
|
||||
type fkRenamer interface {
|
||||
matchSchema(...string) *sql.Predicate
|
||||
isImplicitIndex(*Index, *Column) bool
|
||||
renameIndex(*Table, *Index, *Index) sql.Querier
|
||||
renameColumn(*Table, *Column, *Column) sql.Querier
|
||||
}
|
||||
|
||||
// verifyRanger wraps the method for verifying global-id range correctness.
|
||||
|
||||
@@ -28,53 +28,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigrateHookOmitTable(t *testing.T) {
|
||||
db, mk, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
tables := []*Table{{Name: "users"}, {Name: "pets"}}
|
||||
mock := mysqlMock{mk}
|
||||
mock.start("5.7.23")
|
||||
mock.tableExists("pets", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
m, err := NewMigrate(sql.OpenDB("mysql", db), WithHooks(func(next Creator) Creator {
|
||||
return CreateFunc(func(ctx context.Context, tables ...*Table) error {
|
||||
return next.Create(ctx, tables[1])
|
||||
})
|
||||
}), WithAtlas(false))
|
||||
require.NoError(t, err)
|
||||
err = m.Create(context.Background(), tables...)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestMigrateHookAddTable(t *testing.T) {
|
||||
db, mk, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
tables := []*Table{{Name: "users"}}
|
||||
mock := mysqlMock{mk}
|
||||
mock.start("5.7.23")
|
||||
mock.tableExists("users", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.tableExists("pets", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
m, err := NewMigrate(sql.OpenDB("mysql", db), WithHooks(func(next Creator) Creator {
|
||||
return CreateFunc(func(ctx context.Context, tables ...*Table) error {
|
||||
return next.Create(ctx, tables[0], &Table{Name: "pets"})
|
||||
})
|
||||
}), WithAtlas(false))
|
||||
require.NoError(t, err)
|
||||
err = m.Create(context.Background(), tables...)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestMigrate_Formatter(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/entsql"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/schema/field"
|
||||
|
||||
@@ -22,7 +21,7 @@ import (
|
||||
"ariga.io/atlas/sql/schema"
|
||||
)
|
||||
|
||||
// MySQL is a MySQL migration driver.
|
||||
// MySQL adapter for Atlas migration engine.
|
||||
type MySQL struct {
|
||||
dialect.Driver
|
||||
schema string
|
||||
@@ -59,532 +58,6 @@ func (d *MySQL) tableExist(ctx context.Context, conn dialect.ExecQuerier, name s
|
||||
return exist(ctx, conn, query, args...)
|
||||
}
|
||||
|
||||
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLE_CONSTRAINTS").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.And(
|
||||
d.matchSchema(),
|
||||
sql.EQ("CONSTRAINT_TYPE", "FOREIGN KEY"),
|
||||
sql.EQ("CONSTRAINT_NAME", name),
|
||||
)).Query()
|
||||
return exist(ctx, tx, query, args...)
|
||||
}
|
||||
|
||||
// table loads the current table description from the database.
|
||||
func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Select(
|
||||
"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name",
|
||||
"numeric_precision", "numeric_scale",
|
||||
).
|
||||
From(sql.Table("COLUMNS").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.And(
|
||||
d.matchSchema(),
|
||||
sql.EQ("TABLE_NAME", name)),
|
||||
).Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("mysql: reading table description %w", err)
|
||||
}
|
||||
// Call Close in cases of failures (Close is idempotent).
|
||||
defer rows.Close()
|
||||
t := NewTable(name)
|
||||
for rows.Next() {
|
||||
c := &Column{}
|
||||
if err := d.scanColumn(c, rows); err != nil {
|
||||
return nil, fmt.Errorf("mysql: %w", err)
|
||||
}
|
||||
t.AddColumn(c)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, fmt.Errorf("mysql: closing rows %w", err)
|
||||
}
|
||||
indexes, err := d.indexes(ctx, tx, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Add and link indexes to table columns.
|
||||
for _, idx := range indexes {
|
||||
t.addIndex(idx)
|
||||
}
|
||||
if _, ok := d.mariadb(); ok {
|
||||
if err := d.normalizeJSON(ctx, tx, t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// table loads the table indexes from the database.
|
||||
func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, t *Table) ([]*Index, error) {
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Select("index_name", "column_name", "sub_part", "non_unique", "seq_in_index").
|
||||
From(sql.Table("STATISTICS").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.And(
|
||||
d.matchSchema(),
|
||||
sql.EQ("TABLE_NAME", t.Name),
|
||||
)).
|
||||
OrderBy("index_name", "seq_in_index").
|
||||
Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("mysql: reading index description %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
idx, err := d.scanIndexes(rows, t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: %w", err)
|
||||
}
|
||||
return idx, nil
|
||||
}
|
||||
|
||||
func (d *MySQL) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error {
|
||||
return conn.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", t.Name, value), []any{}, nil)
|
||||
}
|
||||
|
||||
func (d *MySQL) verifyRange(ctx context.Context, tx dialect.ExecQuerier, t *Table, expected int64) error {
|
||||
if expected == 0 {
|
||||
return nil
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Select("AUTO_INCREMENT").
|
||||
From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.And(
|
||||
d.matchSchema(),
|
||||
sql.EQ("TABLE_NAME", t.Name),
|
||||
)).
|
||||
Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return fmt.Errorf("mysql: query auto_increment %w", err)
|
||||
}
|
||||
// Call Close in cases of failures (Close is idempotent).
|
||||
defer rows.Close()
|
||||
actual := &sql.NullInt64{}
|
||||
if err := sql.ScanOne(rows, actual); err != nil {
|
||||
return fmt.Errorf("mysql: scan auto_increment %w", err)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Table is empty and auto-increment is not configured. This can happen
|
||||
// because MySQL (< 8.0) stores the auto-increment counter in main memory
|
||||
// (not persistent), and the value is reset on restart (if table is empty).
|
||||
if actual.Int64 <= 1 {
|
||||
return d.setRange(ctx, tx, t, expected)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// tBuilder returns the MySQL DSL query for table creation.
|
||||
func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder {
|
||||
b := sql.CreateTable(t.Name).IfNotExists()
|
||||
for _, c := range t.Columns {
|
||||
b.Column(d.addColumn(c))
|
||||
}
|
||||
for _, pk := range t.PrimaryKey {
|
||||
b.PrimaryKey(pk.Name)
|
||||
}
|
||||
// Charset and collation config on MySQL table.
|
||||
// These options can be overridden by the entsql annotation.
|
||||
b.Charset("utf8mb4").Collate("utf8mb4_bin")
|
||||
if t.Annotation != nil {
|
||||
if charset := t.Annotation.Charset; charset != "" {
|
||||
b.Charset(charset)
|
||||
}
|
||||
if collate := t.Annotation.Collation; collate != "" {
|
||||
b.Collate(collate)
|
||||
}
|
||||
if opts := t.Annotation.Options; opts != "" {
|
||||
b.Options(opts)
|
||||
}
|
||||
addChecks(b, t.Annotation)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// cType returns the MySQL string type for the given column.
|
||||
func (d *MySQL) cType(c *Column) (t string) {
|
||||
if c.SchemaType != nil && c.SchemaType[dialect.MySQL] != "" {
|
||||
// MySQL returns the column type lower cased.
|
||||
return strings.ToLower(c.SchemaType[dialect.MySQL])
|
||||
}
|
||||
switch c.Type {
|
||||
case field.TypeBool:
|
||||
t = "boolean"
|
||||
case field.TypeInt8:
|
||||
t = "tinyint"
|
||||
case field.TypeUint8:
|
||||
t = "tinyint unsigned"
|
||||
case field.TypeInt16:
|
||||
t = "smallint"
|
||||
case field.TypeUint16:
|
||||
t = "smallint unsigned"
|
||||
case field.TypeInt32:
|
||||
t = "int"
|
||||
case field.TypeUint32:
|
||||
t = "int unsigned"
|
||||
case field.TypeInt, field.TypeInt64:
|
||||
t = "bigint"
|
||||
case field.TypeUint, field.TypeUint64:
|
||||
t = "bigint unsigned"
|
||||
case field.TypeBytes:
|
||||
size := int64(math.MaxUint16)
|
||||
if c.Size > 0 {
|
||||
size = c.Size
|
||||
}
|
||||
switch {
|
||||
case size <= math.MaxUint8:
|
||||
t = "tinyblob"
|
||||
case size <= math.MaxUint16:
|
||||
t = "blob"
|
||||
case size < 1<<24:
|
||||
t = "mediumblob"
|
||||
case size <= math.MaxUint32:
|
||||
t = "longblob"
|
||||
}
|
||||
case field.TypeJSON:
|
||||
t = "json"
|
||||
if compareVersions(d.version, "5.7.8") == -1 {
|
||||
t = "longblob"
|
||||
}
|
||||
case field.TypeString:
|
||||
size := c.Size
|
||||
if size == 0 {
|
||||
size = d.defaultSize(c)
|
||||
}
|
||||
switch {
|
||||
case c.typ == "tinytext", c.typ == "text":
|
||||
t = c.typ
|
||||
case size <= math.MaxUint16:
|
||||
t = fmt.Sprintf("varchar(%d)", size)
|
||||
case size == 1<<24-1:
|
||||
t = "mediumtext"
|
||||
default:
|
||||
t = "longtext"
|
||||
}
|
||||
case field.TypeFloat32, field.TypeFloat64:
|
||||
t = c.scanTypeOr("double")
|
||||
case field.TypeTime:
|
||||
t = c.scanTypeOr("timestamp")
|
||||
// In MariaDB or in MySQL < v8.0.2, the TIMESTAMP column has both `DEFAULT CURRENT_TIMESTAMP`
|
||||
// and `ON UPDATE CURRENT_TIMESTAMP` if neither is specified explicitly. this behavior is
|
||||
// suppressed if the column is defined with a `DEFAULT` clause or with the `NULL` attribute.
|
||||
if _, maria := d.mariadb(); maria || compareVersions(d.version, "8.0.2") == -1 && c.Default == nil {
|
||||
c.Nullable = c.Attr == ""
|
||||
}
|
||||
case field.TypeEnum:
|
||||
values := make([]string, len(c.Enums))
|
||||
for i, e := range c.Enums {
|
||||
values[i] = fmt.Sprintf("'%s'", e)
|
||||
}
|
||||
t = fmt.Sprintf("enum(%s)", strings.Join(values, ", "))
|
||||
case field.TypeUUID:
|
||||
t = "char(36) binary"
|
||||
if d.supportsUUID() {
|
||||
t = "uuid"
|
||||
}
|
||||
case field.TypeOther:
|
||||
t = c.typ
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// addColumn returns the DSL query for adding the given column to a table.
|
||||
// The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable].
|
||||
func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder {
|
||||
b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr)
|
||||
c.unique(b)
|
||||
if c.Increment {
|
||||
b.Attr("AUTO_INCREMENT")
|
||||
}
|
||||
c.nullable(b)
|
||||
c.defaultValue(b)
|
||||
if c.Collation != "" {
|
||||
b.Attr("COLLATE " + c.Collation)
|
||||
}
|
||||
if c.Type == field.TypeJSON {
|
||||
// Manually add a `CHECK` clause for older versions of MariaDB for validating the
|
||||
// JSON documents. This constraint is automatically included from version 10.4.3.
|
||||
if version, ok := d.mariadb(); ok && compareVersions(version, "10.4.3") == -1 {
|
||||
b.Check(func(b *sql.Builder) {
|
||||
b.WriteString("JSON_VALID(").Ident(c.Name).WriteByte(')')
|
||||
})
|
||||
}
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// addIndex returns the querying for adding an index to MySQL.
|
||||
func (d *MySQL) addIndex(i *Index, table string) *sql.IndexBuilder {
|
||||
idx := sql.CreateIndex(i.Name).Table(table)
|
||||
if i.Unique {
|
||||
idx.Unique()
|
||||
}
|
||||
parts := indexParts(i)
|
||||
for _, c := range i.Columns {
|
||||
part, ok := parts[c.Name]
|
||||
if !ok || part == 0 {
|
||||
idx.Column(c.Name)
|
||||
} else {
|
||||
idx.Column(fmt.Sprintf("%s(%d)", idx.Builder.Quote(c.Name), part))
|
||||
}
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
// dropIndex drops a MySQL index.
|
||||
func (d *MySQL) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error {
|
||||
query, args := idx.DropBuilder(table).Query()
|
||||
return tx.Exec(ctx, query, args, nil)
|
||||
}
|
||||
|
||||
// prepare runs preparation work that needs to be done to apply the change-set.
|
||||
func (d *MySQL) prepare(ctx context.Context, tx dialect.Tx, change *changes, table string) error {
|
||||
for _, idx := range change.index.drop {
|
||||
switch n := len(idx.columns); {
|
||||
case n == 0:
|
||||
return fmt.Errorf("index %q has no columns", idx.Name)
|
||||
case n > 1:
|
||||
continue // not a foreign-key index.
|
||||
}
|
||||
var qr sql.Querier
|
||||
Switch:
|
||||
switch col, ok := change.dropColumn(idx.columns[0]); {
|
||||
// If both the index and the column need to be dropped, the foreign-key
|
||||
// constraint that is associated with them need to be dropped as well.
|
||||
case ok:
|
||||
names, err := d.fkNames(ctx, tx, table, col.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(names) == 1 {
|
||||
qr = sql.AlterTable(table).DropForeignKey(names[0])
|
||||
}
|
||||
// If the uniqueness was dropped from a foreign-key column,
|
||||
// create a "simple index" if no other index exist for it.
|
||||
case !ok && idx.Unique && len(idx.Columns) > 0:
|
||||
col := idx.Columns[0]
|
||||
for _, idx2 := range col.indexes {
|
||||
if idx2 != idx && len(idx2.columns) == 1 {
|
||||
break Switch
|
||||
}
|
||||
}
|
||||
names, err := d.fkNames(ctx, tx, table, col.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(names) == 1 {
|
||||
qr = sql.CreateIndex(names[0]).Table(table).Columns(col.Name)
|
||||
}
|
||||
}
|
||||
if qr != nil {
|
||||
query, args := qr.Query()
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scanColumn scans the column information from MySQL column description.
|
||||
func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
|
||||
var (
|
||||
nullable sql.NullString
|
||||
defaults sql.NullString
|
||||
numericPrecision sql.NullInt64
|
||||
numericScale sql.NullInt64
|
||||
)
|
||||
if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &sql.NullString{}, &sql.NullString{}, &numericPrecision, &numericScale); err != nil {
|
||||
return fmt.Errorf("scanning column description: %w", err)
|
||||
}
|
||||
c.Unique = c.UniqueKey()
|
||||
if nullable.Valid {
|
||||
c.Nullable = nullable.String == "YES"
|
||||
}
|
||||
if c.typ == "" {
|
||||
return fmt.Errorf("missing type information for column %q", c.Name)
|
||||
}
|
||||
parts, size, unsigned, err := parseColumn(c.typ)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch parts[0] {
|
||||
case "mediumint", "int":
|
||||
c.Type = field.TypeInt32
|
||||
if unsigned {
|
||||
c.Type = field.TypeUint32
|
||||
}
|
||||
case "smallint":
|
||||
c.Type = field.TypeInt16
|
||||
if unsigned {
|
||||
c.Type = field.TypeUint16
|
||||
}
|
||||
case "bigint":
|
||||
c.Type = field.TypeInt64
|
||||
if unsigned {
|
||||
c.Type = field.TypeUint64
|
||||
}
|
||||
case "tinyint":
|
||||
switch {
|
||||
case size == 1:
|
||||
c.Type = field.TypeBool
|
||||
case unsigned:
|
||||
c.Type = field.TypeUint8
|
||||
default:
|
||||
c.Type = field.TypeInt8
|
||||
}
|
||||
case "double", "float":
|
||||
c.Type = field.TypeFloat64
|
||||
case "numeric", "decimal":
|
||||
c.Type = field.TypeFloat64
|
||||
// If precision is specified then we should take that into account.
|
||||
if numericPrecision.Valid {
|
||||
schemaType := fmt.Sprintf("%s(%d,%d)", parts[0], numericPrecision.Int64, numericScale.Int64)
|
||||
c.SchemaType = map[string]string{dialect.MySQL: schemaType}
|
||||
}
|
||||
case "time", "timestamp", "date", "datetime":
|
||||
c.Type = field.TypeTime
|
||||
// The mapping from schema defaults to database
|
||||
// defaults is not supported for TypeTime fields.
|
||||
defaults = sql.NullString{}
|
||||
case "tinyblob":
|
||||
c.Size = math.MaxUint8
|
||||
c.Type = field.TypeBytes
|
||||
case "blob":
|
||||
c.Size = math.MaxUint16
|
||||
c.Type = field.TypeBytes
|
||||
case "mediumblob":
|
||||
c.Size = 1<<24 - 1
|
||||
c.Type = field.TypeBytes
|
||||
case "longblob":
|
||||
c.Size = math.MaxUint32
|
||||
c.Type = field.TypeBytes
|
||||
case "binary", "varbinary":
|
||||
c.Type = field.TypeBytes
|
||||
c.Size = size
|
||||
case "varchar":
|
||||
c.Type = field.TypeString
|
||||
c.Size = size
|
||||
case "text":
|
||||
c.Size = math.MaxUint16
|
||||
c.Type = field.TypeString
|
||||
case "mediumtext":
|
||||
c.Size = 1<<24 - 1
|
||||
c.Type = field.TypeString
|
||||
case "longtext":
|
||||
c.Size = math.MaxInt32
|
||||
c.Type = field.TypeString
|
||||
case "json":
|
||||
c.Type = field.TypeJSON
|
||||
case "enum":
|
||||
c.Type = field.TypeEnum
|
||||
// Parse the enum values according to the MySQL format.
|
||||
// github.com/mysql/mysql-server/blob/8.0/sql/field.cc#Field_enum::sql_type
|
||||
values := strings.TrimSuffix(strings.TrimPrefix(c.typ, "enum("), ")")
|
||||
if values == "" {
|
||||
return fmt.Errorf("mysql: unexpected enum type: %q", c.typ)
|
||||
}
|
||||
parts := strings.Split(values, "','")
|
||||
for i := range parts {
|
||||
c.Enums = append(c.Enums, strings.Trim(parts[i], "'"))
|
||||
}
|
||||
case "char":
|
||||
c.Type = field.TypeOther
|
||||
// UUID field has length of 36 characters (32 alphanumeric characters and 4 hyphens).
|
||||
if size == 36 {
|
||||
c.Type = field.TypeUUID
|
||||
}
|
||||
case "point", "geometry", "linestring", "polygon":
|
||||
c.Type = field.TypeOther
|
||||
default:
|
||||
return fmt.Errorf("unknown column type %q for version %q", parts[0], d.version)
|
||||
}
|
||||
if defaults.Valid {
|
||||
return c.ScanDefault(defaults.String)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scanIndexes scans sql.Rows into an Indexes list. The query for returning the rows,
|
||||
// should return the following 5 columns: INDEX_NAME, COLUMN_NAME, SUB_PART, NON_UNIQUE,
|
||||
// SEQ_IN_INDEX. SEQ_IN_INDEX specifies the position of the column in the index columns.
|
||||
func (d *MySQL) scanIndexes(rows *sql.Rows, t *Table) (Indexes, error) {
|
||||
var (
|
||||
i Indexes
|
||||
names = make(map[string]*Index)
|
||||
)
|
||||
for rows.Next() {
|
||||
var (
|
||||
name string
|
||||
column string
|
||||
nonuniq bool
|
||||
seqindex int
|
||||
subpart sql.NullInt64
|
||||
)
|
||||
if err := rows.Scan(&name, &column, &subpart, &nonuniq, &seqindex); err != nil {
|
||||
return nil, fmt.Errorf("scanning index description: %w", err)
|
||||
}
|
||||
// Skip primary keys.
|
||||
if name == "PRIMARY" {
|
||||
c, ok := t.column(column)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing primary-key column: %q", column)
|
||||
}
|
||||
t.PrimaryKey = append(t.PrimaryKey, c)
|
||||
continue
|
||||
}
|
||||
idx, ok := names[name]
|
||||
if !ok {
|
||||
idx = &Index{Name: name, Unique: !nonuniq, Annotation: &entsql.IndexAnnotation{}}
|
||||
i = append(i, idx)
|
||||
names[name] = idx
|
||||
}
|
||||
idx.columns = append(idx.columns, column)
|
||||
if subpart.Int64 > 0 {
|
||||
if idx.Annotation.PrefixColumns == nil {
|
||||
idx.Annotation.PrefixColumns = make(map[string]uint)
|
||||
}
|
||||
idx.Annotation.PrefixColumns[column] = uint(subpart.Int64)
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// isImplicitIndex reports if the index was created implicitly for the unique column.
|
||||
func (d *MySQL) isImplicitIndex(idx *Index, col *Column) bool {
|
||||
// We execute `CHANGE COLUMN` on older versions of MySQL (<8.0), which
|
||||
// auto create the new index. The old one, will be dropped in `changeSet`.
|
||||
if compareVersions(d.version, "8.0.0") >= 0 {
|
||||
return idx.Name == col.Name && col.Unique
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// renameColumn returns the statement for renaming a column in
|
||||
// MySQL based on its version.
|
||||
func (d *MySQL) renameColumn(t *Table, old, new *Column) sql.Querier {
|
||||
q := sql.AlterTable(t.Name)
|
||||
if compareVersions(d.version, "8.0.0") >= 0 {
|
||||
return q.RenameColumn(old.Name, new.Name)
|
||||
}
|
||||
return q.ChangeColumn(old.Name, d.addColumn(new))
|
||||
}
|
||||
|
||||
// renameIndex returns the statement for renaming an index.
|
||||
func (d *MySQL) renameIndex(t *Table, old, new *Index) sql.Querier {
|
||||
q := sql.AlterTable(t.Name)
|
||||
if compareVersions(d.version, "5.7.0") >= 0 {
|
||||
return q.RenameIndex(old.Name, new.Name)
|
||||
}
|
||||
return q.DropIndex(old.Name).AddIndex(new.Builder(t.Name))
|
||||
}
|
||||
|
||||
// matchSchema returns the predicate for matching table schema.
|
||||
func (d *MySQL) matchSchema(columns ...string) *sql.Predicate {
|
||||
column := "TABLE_SCHEMA"
|
||||
@@ -597,196 +70,6 @@ func (d *MySQL) matchSchema(columns ...string) *sql.Predicate {
|
||||
return sql.EQ(column, sql.Raw("(SELECT DATABASE())"))
|
||||
}
|
||||
|
||||
// tables returns the query for getting the in the schema.
|
||||
func (d *MySQL) tables() sql.Querier {
|
||||
return sql.Select("TABLE_NAME").
|
||||
From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
|
||||
Where(d.matchSchema())
|
||||
}
|
||||
|
||||
// alterColumns returns the queries for applying the columns change-set.
|
||||
func (d *MySQL) alterColumns(table string, add, modify, drop []*Column) sql.Queries {
|
||||
b := sql.Dialect(dialect.MySQL).AlterTable(table)
|
||||
for _, c := range add {
|
||||
b.AddColumn(d.addColumn(c))
|
||||
}
|
||||
for _, c := range modify {
|
||||
b.ModifyColumn(d.addColumn(c))
|
||||
}
|
||||
for _, c := range drop {
|
||||
b.DropColumn(sql.Dialect(dialect.MySQL).Column(c.Name))
|
||||
}
|
||||
if len(b.Queries) == 0 {
|
||||
return nil
|
||||
}
|
||||
return sql.Queries{b}
|
||||
}
|
||||
|
||||
// normalizeJSON normalize MariaDB longtext columns to type JSON.
|
||||
func (d *MySQL) normalizeJSON(ctx context.Context, tx dialect.Tx, t *Table) error {
|
||||
columns := make(map[string]*Column)
|
||||
for _, c := range t.Columns {
|
||||
if c.typ == "longtext" {
|
||||
columns[c.Name] = c
|
||||
}
|
||||
}
|
||||
if len(columns) == 0 {
|
||||
return nil
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Select("CONSTRAINT_NAME").
|
||||
From(sql.Table("CHECK_CONSTRAINTS").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.And(
|
||||
d.matchSchema("CONSTRAINT_SCHEMA"),
|
||||
sql.EQ("TABLE_NAME", t.Name),
|
||||
sql.Like("CHECK_CLAUSE", "json_valid(%)"),
|
||||
)).
|
||||
Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return fmt.Errorf("mysql: query table constraints %w", err)
|
||||
}
|
||||
// Call Close in cases of failures (Close is idempotent).
|
||||
defer rows.Close()
|
||||
names := make([]string, 0, len(columns))
|
||||
if err := sql.ScanSlice(rows, &names); err != nil {
|
||||
return fmt.Errorf("mysql: scan table constraints: %w", err)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, name := range names {
|
||||
c, ok := columns[name]
|
||||
if ok {
|
||||
c.Type = field.TypeJSON
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mariadb reports if the migration runs on MariaDB and returns the semver string.
|
||||
func (d *MySQL) mariadb() (string, bool) {
|
||||
idx := strings.Index(d.version, "MariaDB")
|
||||
if idx == -1 {
|
||||
return "", false
|
||||
}
|
||||
return d.version[:idx-1], true
|
||||
}
|
||||
|
||||
// parseColumn returns column parts, size and signed-info from a MySQL type.
|
||||
func parseColumn(typ string) (parts []string, size int64, unsigned bool, err error) {
|
||||
switch parts = strings.FieldsFunc(typ, func(r rune) bool {
|
||||
return r == '(' || r == ')' || r == ' ' || r == ','
|
||||
}); parts[0] {
|
||||
case "tinyint", "smallint", "mediumint", "int", "bigint":
|
||||
switch {
|
||||
case len(parts) == 2 && parts[1] == "unsigned": // int unsigned
|
||||
unsigned = true
|
||||
case len(parts) == 3: // int(10) unsigned
|
||||
unsigned = true
|
||||
fallthrough
|
||||
case len(parts) == 2: // int(10)
|
||||
size, err = strconv.ParseInt(parts[1], 10, 0)
|
||||
}
|
||||
case "varbinary", "varchar", "char", "binary":
|
||||
if len(parts) > 1 {
|
||||
size, err = strconv.ParseInt(parts[1], 10, 64)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return parts, size, unsigned, fmt.Errorf("converting %s size to int: %w", parts[0], err)
|
||||
}
|
||||
return parts, size, unsigned, nil
|
||||
}
|
||||
|
||||
// fkNames returns the foreign-key names of a column.
|
||||
func (d *MySQL) fkNames(ctx context.Context, tx dialect.Tx, table, column string) ([]string, error) {
|
||||
query, args := sql.Select("CONSTRAINT_NAME").From(sql.Table("KEY_COLUMN_USAGE").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.And(
|
||||
sql.EQ("TABLE_NAME", table),
|
||||
sql.EQ("COLUMN_NAME", column),
|
||||
// NULL for unique and primary-key constraints.
|
||||
sql.NotNull("POSITION_IN_UNIQUE_CONSTRAINT"),
|
||||
d.matchSchema(),
|
||||
)).
|
||||
Query()
|
||||
var (
|
||||
names []string
|
||||
rows = &sql.Rows{}
|
||||
)
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("mysql: reading constraint names %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if err := sql.ScanSlice(rows, &names); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
// defaultSize returns the default size for MySQL/MariaDB varchar type
|
||||
// based on column size, charset and table indexes, in order to avoid
|
||||
// index prefix key limit (767) for older versions of MySQL/MariaDB.
|
||||
func (d *MySQL) defaultSize(c *Column) int64 {
|
||||
size := DefaultStringLen
|
||||
version, checked := d.version, "5.7.0"
|
||||
if v, ok := d.mariadb(); ok {
|
||||
version, checked = v, "10.2.2"
|
||||
}
|
||||
switch {
|
||||
// Version is >= 5.7 for MySQL, or >= 10.2.2 for MariaDB.
|
||||
case compareVersions(version, checked) != -1:
|
||||
// Column is non-unique, or not part of any index (reaching
|
||||
// the error 1071).
|
||||
case !c.Unique && len(c.indexes) == 0 && !c.PrimaryKey():
|
||||
default:
|
||||
size = 191
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// needsConversion reports if column "old" needs to be converted
|
||||
// (by table altering) to column "new".
|
||||
func (d *MySQL) needsConversion(old, new *Column) bool {
|
||||
return d.cType(old) != d.cType(new)
|
||||
}
|
||||
|
||||
// indexModified used by the migration differ to check if the index was modified.
|
||||
func (d *MySQL) indexModified(old, new *Index) bool {
|
||||
oldParts, newParts := indexParts(old), indexParts(new)
|
||||
if len(oldParts) != len(newParts) {
|
||||
return true
|
||||
}
|
||||
for column, oldPart := range oldParts {
|
||||
newPart, ok := newParts[column]
|
||||
if !ok || oldPart != newPart {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// indexParts returns a map holding the sub_part mapping if exists.
|
||||
func indexParts(idx *Index) map[string]uint {
|
||||
parts := make(map[string]uint)
|
||||
if idx.Annotation == nil {
|
||||
return parts
|
||||
}
|
||||
// If prefix (without a name) was defined on the
|
||||
// annotation, map it to the single column index.
|
||||
if idx.Annotation.Prefix > 0 && len(idx.Columns) == 1 {
|
||||
parts[idx.Columns[0].Name] = idx.Annotation.Prefix
|
||||
}
|
||||
for column, part := range idx.Annotation.PrefixColumns {
|
||||
parts[column] = part
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// Atlas integration.
|
||||
|
||||
func (d *MySQL) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) {
|
||||
return mysql.Open(&db{ExecQuerier: conn})
|
||||
}
|
||||
@@ -988,23 +271,56 @@ func (d *MySQL) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func indexType(idx *Index, d string) (string, bool) {
|
||||
ant := idx.Annotation
|
||||
if ant == nil {
|
||||
return "", false
|
||||
}
|
||||
if ant.Types != nil && ant.Types[d] != "" {
|
||||
return ant.Types[d], true
|
||||
}
|
||||
if ant.Type != "" {
|
||||
return ant.Type, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (MySQL) atTypeRangeSQL(ts ...string) string {
|
||||
func (*MySQL) atTypeRangeSQL(ts ...string) string {
|
||||
for i := range ts {
|
||||
ts[i] = fmt.Sprintf("('%s')", ts[i])
|
||||
}
|
||||
return fmt.Sprintf("INSERT INTO `%s` (`type`) VALUES %s", TypeTable, strings.Join(ts, ", "))
|
||||
}
|
||||
|
||||
// mariadb reports if the migration runs on MariaDB and returns the semver string.
|
||||
func (d *MySQL) mariadb() (string, bool) {
|
||||
idx := strings.Index(d.version, "MariaDB")
|
||||
if idx == -1 {
|
||||
return "", false
|
||||
}
|
||||
return d.version[:idx-1], true
|
||||
}
|
||||
|
||||
// defaultSize returns the default size for MySQL/MariaDB varchar type
|
||||
// based on column size, charset and table indexes, in order to avoid
|
||||
// index prefix key limit (767) for older versions of MySQL/MariaDB.
|
||||
func (d *MySQL) defaultSize(c *Column) int64 {
|
||||
size := DefaultStringLen
|
||||
version, checked := d.version, "5.7.0"
|
||||
if v, ok := d.mariadb(); ok {
|
||||
version, checked = v, "10.2.2"
|
||||
}
|
||||
switch {
|
||||
// Version is >= 5.7 for MySQL, or >= 10.2.2 for MariaDB.
|
||||
case compareVersions(version, checked) != -1:
|
||||
// Column is non-unique, or not part of any index (reaching
|
||||
// the error 1071).
|
||||
case !c.Unique && len(c.indexes) == 0 && !c.PrimaryKey():
|
||||
default:
|
||||
size = 191
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// indexParts returns a map holding the sub_part mapping if exists.
|
||||
func indexParts(idx *Index) map[string]uint {
|
||||
parts := make(map[string]uint)
|
||||
if idx.Annotation == nil {
|
||||
return parts
|
||||
}
|
||||
// If prefix (without a name) was defined on the
|
||||
// annotation, map it to the single column index.
|
||||
if idx.Annotation.Prefix > 0 && len(idx.Columns) == 1 {
|
||||
parts[idx.Columns[0].Name] = idx.Annotation.Prefix
|
||||
}
|
||||
for column, part := range idx.Annotation.PrefixColumns {
|
||||
parts[column] = part
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,7 +10,6 @@ import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
@@ -21,7 +20,7 @@ import (
|
||||
"ariga.io/atlas/sql/schema"
|
||||
)
|
||||
|
||||
// Postgres is a postgres migration driver.
|
||||
// Postgres adapter for Atlas migration engine.
|
||||
type Postgres struct {
|
||||
dialect.Driver
|
||||
schema string
|
||||
@@ -67,453 +66,6 @@ func (d *Postgres) tableExist(ctx context.Context, conn dialect.ExecQuerier, nam
|
||||
return exist(ctx, conn, query, args...)
|
||||
}
|
||||
|
||||
// tableExist checks if a foreign-key exists in the current schema.
|
||||
func (d *Postgres) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
query, args := sql.Dialect(dialect.Postgres).
|
||||
Select(sql.Count("*")).From(sql.Table("table_constraints").Schema("information_schema")).
|
||||
Where(sql.And(
|
||||
d.matchSchema(),
|
||||
sql.EQ("constraint_type", "FOREIGN KEY"),
|
||||
sql.EQ("constraint_name", name),
|
||||
)).Query()
|
||||
return exist(ctx, tx, query, args...)
|
||||
}
|
||||
|
||||
// setRange sets restart the identity column to the given offset. Used by the universal-id option.
|
||||
func (d *Postgres) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error {
|
||||
if value == 0 {
|
||||
value = 1 // RESTART value cannot be < 1.
|
||||
}
|
||||
pk := "id"
|
||||
if len(t.PrimaryKey) == 1 {
|
||||
pk = t.PrimaryKey[0].Name
|
||||
}
|
||||
return conn.Exec(ctx, fmt.Sprintf("ALTER TABLE %q ALTER COLUMN %q RESTART WITH %d", t.Name, pk, value), []any{}, nil)
|
||||
}
|
||||
|
||||
// table loads the current table description from the database.
|
||||
func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Dialect(dialect.Postgres).
|
||||
Select(
|
||||
"column_name", "data_type", "is_nullable", "column_default", "udt_name",
|
||||
"numeric_precision", "numeric_scale", "character_maximum_length",
|
||||
).
|
||||
From(sql.Table("columns").Schema("information_schema")).
|
||||
Where(sql.And(
|
||||
d.matchSchema(),
|
||||
sql.EQ("table_name", name),
|
||||
)).Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("postgres: reading table description %w", err)
|
||||
}
|
||||
// Call `Close` in cases of failures (`Close` is idempotent).
|
||||
defer rows.Close()
|
||||
t := NewTable(name)
|
||||
for rows.Next() {
|
||||
c := &Column{}
|
||||
if err := d.scanColumn(c, rows); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.AddColumn(c)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, fmt.Errorf("closing rows %w", err)
|
||||
}
|
||||
idxs, err := d.indexes(ctx, tx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Populate the index information to the table and its columns.
|
||||
// We do it manually, because PK and uniqueness information does
|
||||
// not exist when querying the information_schema.COLUMNS above.
|
||||
for _, idx := range idxs {
|
||||
switch {
|
||||
case idx.primary:
|
||||
for _, name := range idx.columns {
|
||||
c, ok := t.column(name)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name)
|
||||
}
|
||||
c.Key = PrimaryKey
|
||||
t.PrimaryKey = append(t.PrimaryKey, c)
|
||||
}
|
||||
case idx.Unique && len(idx.columns) == 1:
|
||||
name := idx.columns[0]
|
||||
c, ok := t.column(name)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name)
|
||||
}
|
||||
c.Key = UniqueKey
|
||||
c.Unique = true
|
||||
fallthrough
|
||||
default:
|
||||
t.addIndex(idx)
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// indexesQuery holds a query format for retrieving
|
||||
// table indexes of the current schema.
|
||||
const indexesQuery = `
|
||||
SELECT i.relname AS index_name,
|
||||
a.attname AS column_name,
|
||||
idx.indisprimary AS primary,
|
||||
idx.indisunique AS unique,
|
||||
array_position(idx.indkey, a.attnum) as seq_in_index
|
||||
FROM pg_class t,
|
||||
pg_class i,
|
||||
pg_index idx,
|
||||
pg_attribute a,
|
||||
pg_namespace n
|
||||
WHERE t.oid = idx.indrelid
|
||||
AND i.oid = idx.indexrelid
|
||||
AND n.oid = t.relnamespace
|
||||
AND a.attrelid = t.oid
|
||||
AND a.attnum = ANY(idx.indkey)
|
||||
AND t.relkind = 'r'
|
||||
AND n.nspname = %s
|
||||
AND t.relname = '%s'
|
||||
ORDER BY index_name, seq_in_index;
|
||||
`
|
||||
|
||||
// indexesQuery returns the query (and its placeholders) for getting table indexes.
|
||||
func (d *Postgres) indexesQuery(table string) (string, []any) {
|
||||
if d.schema != "" {
|
||||
return fmt.Sprintf(indexesQuery, "$1", table), []any{d.schema}
|
||||
}
|
||||
return fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", table), nil
|
||||
}
|
||||
|
||||
func (d *Postgres) indexes(ctx context.Context, tx dialect.Tx, table string) (Indexes, error) {
|
||||
rows := &sql.Rows{}
|
||||
query, args := d.indexesQuery(table)
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("querying indexes for table %s: %w", table, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var (
|
||||
idxs Indexes
|
||||
names = make(map[string]*Index)
|
||||
)
|
||||
for rows.Next() {
|
||||
var (
|
||||
seqindex int
|
||||
name, column string
|
||||
unique, primary bool
|
||||
)
|
||||
if err := rows.Scan(&name, &column, &primary, &unique, &seqindex); err != nil {
|
||||
return nil, fmt.Errorf("scanning index description: %w", err)
|
||||
}
|
||||
// If the index is prefixed with the table, it may was added by
|
||||
// `addIndex` and it should be trimmed. But, since entc prefixes
|
||||
// all indexes with schema-type, for uncountable types (like, media
|
||||
// or equipment) this isn't correct, and we fallback for the real-name.
|
||||
short := strings.TrimPrefix(name, table+"_")
|
||||
idx, ok := names[short]
|
||||
if !ok {
|
||||
idx = &Index{Name: short, Unique: unique, primary: primary, realname: name}
|
||||
idxs = append(idxs, idx)
|
||||
names[short] = idx
|
||||
}
|
||||
idx.columns = append(idx.columns, column)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return idxs, nil
|
||||
}
|
||||
|
||||
// maxCharSize defines the maximum size of limited character types in Postgres (10 MB).
|
||||
const maxCharSize = 10 << 20
|
||||
|
||||
// scanColumn scans the information a column from column description.
|
||||
func (d *Postgres) scanColumn(c *Column, rows *sql.Rows) error {
|
||||
var (
|
||||
nullable sql.NullString
|
||||
defaults sql.NullString
|
||||
udt sql.NullString
|
||||
numericPrecision sql.NullInt64
|
||||
numericScale sql.NullInt64
|
||||
characterMaximumLen sql.NullInt64
|
||||
)
|
||||
if err := rows.Scan(&c.Name, &c.typ, &nullable, &defaults, &udt, &numericPrecision, &numericScale, &characterMaximumLen); err != nil {
|
||||
return fmt.Errorf("scanning column description: %w", err)
|
||||
}
|
||||
if nullable.Valid {
|
||||
c.Nullable = nullable.String == "YES"
|
||||
}
|
||||
switch c.typ {
|
||||
case "boolean":
|
||||
c.Type = field.TypeBool
|
||||
case "smallint":
|
||||
c.Type = field.TypeInt16
|
||||
case "integer":
|
||||
c.Type = field.TypeInt32
|
||||
case "bigint":
|
||||
c.Type = field.TypeInt64
|
||||
case "real":
|
||||
c.Type = field.TypeFloat32
|
||||
case "double precision":
|
||||
c.Type = field.TypeFloat64
|
||||
case "numeric", "decimal":
|
||||
c.Type = field.TypeFloat64
|
||||
// If precision is specified then we should take that into account.
|
||||
if numericPrecision.Valid {
|
||||
schemaType := fmt.Sprintf("%s(%d,%d)", c.typ, numericPrecision.Int64, numericScale.Int64)
|
||||
c.SchemaType = map[string]string{dialect.Postgres: schemaType}
|
||||
}
|
||||
case "text":
|
||||
c.Type = field.TypeString
|
||||
c.Size = maxCharSize + 1
|
||||
case "character", "character varying":
|
||||
c.Type = field.TypeString
|
||||
// If character maximum length is specified then we should take that into account.
|
||||
if characterMaximumLen.Valid {
|
||||
schemaType := fmt.Sprintf("varchar(%d)", characterMaximumLen.Int64)
|
||||
c.SchemaType = map[string]string{dialect.Postgres: schemaType}
|
||||
}
|
||||
case "date", "time with time zone", "time without time zone", "timestamp with time zone", "timestamp without time zone":
|
||||
c.Type = field.TypeTime
|
||||
case "bytea":
|
||||
c.Type = field.TypeBytes
|
||||
case "jsonb":
|
||||
c.Type = field.TypeJSON
|
||||
case "uuid":
|
||||
c.Type = field.TypeUUID
|
||||
case "cidr", "inet", "macaddr", "macaddr8":
|
||||
c.Type = field.TypeOther
|
||||
case "point", "line", "lseg", "box", "path", "polygon", "circle":
|
||||
c.Type = field.TypeOther
|
||||
case "ARRAY":
|
||||
c.Type = field.TypeOther
|
||||
if !udt.Valid {
|
||||
return fmt.Errorf("missing array type for column %q", c.Name)
|
||||
}
|
||||
// Note that for ARRAY types, the 'udt_name' column holds the array type
|
||||
// prefixed with '_'. For example, for 'integer[]' the result is '_int',
|
||||
// and for 'text[N][M]' the result is also '_text'. That's because, the
|
||||
// database ignores any size or multi-dimensions constraints.
|
||||
c.SchemaType = map[string]string{dialect.Postgres: "ARRAY"}
|
||||
c.typ = udt.String
|
||||
case "USER-DEFINED", "tstzrange", "interval":
|
||||
c.Type = field.TypeOther
|
||||
if !udt.Valid {
|
||||
return fmt.Errorf("missing user defined type for column %q", c.Name)
|
||||
}
|
||||
c.SchemaType = map[string]string{dialect.Postgres: udt.String}
|
||||
}
|
||||
switch {
|
||||
case !defaults.Valid || c.Type == field.TypeTime || callExpr(defaults.String):
|
||||
return nil
|
||||
case strings.Contains(defaults.String, "::"):
|
||||
parts := strings.Split(defaults.String, "::")
|
||||
defaults.String = strings.Trim(parts[0], "'")
|
||||
fallthrough
|
||||
default:
|
||||
return c.ScanDefault(defaults.String)
|
||||
}
|
||||
}
|
||||
|
||||
// tBuilder returns the TableBuilder for the given table.
|
||||
func (d *Postgres) tBuilder(t *Table) *sql.TableBuilder {
|
||||
b := sql.Dialect(dialect.Postgres).
|
||||
CreateTable(t.Name).IfNotExists()
|
||||
for _, c := range t.Columns {
|
||||
b.Column(d.addColumn(c))
|
||||
}
|
||||
for _, pk := range t.PrimaryKey {
|
||||
b.PrimaryKey(pk.Name)
|
||||
}
|
||||
if t.Annotation != nil {
|
||||
addChecks(b, t.Annotation)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// cType returns the PostgreSQL string type for this column.
|
||||
func (d *Postgres) cType(c *Column) (t string) {
|
||||
if c.SchemaType != nil && c.SchemaType[dialect.Postgres] != "" {
|
||||
return c.SchemaType[dialect.Postgres]
|
||||
}
|
||||
switch c.Type {
|
||||
case field.TypeBool:
|
||||
t = "boolean"
|
||||
case field.TypeUint8, field.TypeInt8, field.TypeInt16, field.TypeUint16:
|
||||
t = "smallint"
|
||||
case field.TypeInt32, field.TypeUint32:
|
||||
t = "int"
|
||||
case field.TypeInt, field.TypeUint, field.TypeInt64, field.TypeUint64:
|
||||
t = "bigint"
|
||||
case field.TypeFloat32:
|
||||
t = c.scanTypeOr("real")
|
||||
case field.TypeFloat64:
|
||||
t = c.scanTypeOr("double precision")
|
||||
case field.TypeBytes:
|
||||
t = "bytea"
|
||||
case field.TypeJSON:
|
||||
t = "jsonb"
|
||||
case field.TypeUUID:
|
||||
t = "uuid"
|
||||
case field.TypeString:
|
||||
t = "varchar"
|
||||
if c.Size > maxCharSize {
|
||||
t = "text"
|
||||
}
|
||||
case field.TypeTime:
|
||||
t = c.scanTypeOr("timestamp with time zone")
|
||||
case field.TypeEnum:
|
||||
// Currently, the support for enums is weak (application level only.
|
||||
// like SQLite). Dialect needs to create and maintain its enum type.
|
||||
t = "varchar"
|
||||
case field.TypeOther:
|
||||
t = c.typ
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// addColumn returns the ColumnBuilder for adding the given column to a table.
|
||||
func (d *Postgres) addColumn(c *Column) *sql.ColumnBuilder {
|
||||
b := sql.Dialect(dialect.Postgres).
|
||||
Column(c.Name).Type(d.cType(c)).Attr(c.Attr)
|
||||
c.unique(b)
|
||||
if c.Increment {
|
||||
b.Attr("GENERATED BY DEFAULT AS IDENTITY")
|
||||
}
|
||||
c.nullable(b)
|
||||
d.writeDefault(b, c, "DEFAULT")
|
||||
if c.Collation != "" {
|
||||
b.Attr("COLLATE " + strconv.Quote(c.Collation))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// writeDefault writes the `DEFAULT` clause to column builder
|
||||
// if exists and supported by the driver.
|
||||
func (d *Postgres) writeDefault(b *sql.ColumnBuilder, c *Column, clause string) {
|
||||
if c.Default == nil || !c.supportDefault() {
|
||||
return
|
||||
}
|
||||
attr := fmt.Sprint(c.Default)
|
||||
switch v := c.Default.(type) {
|
||||
case bool:
|
||||
attr = strconv.FormatBool(v)
|
||||
case string:
|
||||
if t := c.Type; t != field.TypeUUID && t != field.TypeTime && !t.Numeric() {
|
||||
// Escape single quote by replacing each with 2.
|
||||
attr = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''"))
|
||||
}
|
||||
}
|
||||
b.Attr(clause + " " + attr)
|
||||
}
|
||||
|
||||
// alterColumn returns list of ColumnBuilder for applying in order to alter a column.
|
||||
func (d *Postgres) alterColumn(c *Column) (ops []*sql.ColumnBuilder) {
|
||||
b := sql.Dialect(dialect.Postgres)
|
||||
ops = append(ops, b.Column(c.Name).Type(d.cType(c)))
|
||||
if c.Nullable {
|
||||
ops = append(ops, b.Column(c.Name).Attr("DROP NOT NULL"))
|
||||
} else {
|
||||
ops = append(ops, b.Column(c.Name).Attr("SET NOT NULL"))
|
||||
}
|
||||
if c.Default != nil && c.supportDefault() {
|
||||
ops = append(ops, d.writeSetDefault(b.Column(c.Name), c))
|
||||
}
|
||||
return ops
|
||||
}
|
||||
|
||||
func (d *Postgres) writeSetDefault(b *sql.ColumnBuilder, c *Column) *sql.ColumnBuilder {
|
||||
d.writeDefault(b, c, "SET DEFAULT")
|
||||
return b
|
||||
}
|
||||
|
||||
// hasUniqueName reports if the index has a unique name in the schema.
|
||||
func hasUniqueName(i *Index) bool {
|
||||
// Trim the "_key" suffix if it was added by Postgres for implicit indexes.
|
||||
name := strings.TrimSuffix(i.Name, "_key")
|
||||
suffix := strings.Join(i.columnNames(), "_")
|
||||
if !strings.HasSuffix(name, suffix) {
|
||||
return true // Assume it has a custom storage-key.
|
||||
}
|
||||
// The codegen prefixes by default indexes with the type name.
|
||||
// For example, an index "users"("name"), will named as "user_name".
|
||||
return name != suffix
|
||||
}
|
||||
|
||||
// addIndex returns the query for adding an index to PostgreSQL.
|
||||
func (d *Postgres) addIndex(i *Index, table string) *sql.IndexBuilder {
|
||||
name := i.Name
|
||||
if !hasUniqueName(i) {
|
||||
// Since index name should be unique in pg_class for schema,
|
||||
// we prefix it with the table name and remove on read.
|
||||
name = fmt.Sprintf("%s_%s", table, i.Name)
|
||||
}
|
||||
idx := sql.Dialect(dialect.Postgres).
|
||||
CreateIndex(name).IfNotExists().Table(table)
|
||||
if i.Unique {
|
||||
idx.Unique()
|
||||
}
|
||||
for _, c := range i.Columns {
|
||||
idx.Column(c.Name)
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
// dropIndex drops a Postgres index.
|
||||
func (d *Postgres) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error {
|
||||
name := idx.Name
|
||||
build := sql.Dialect(dialect.Postgres)
|
||||
if prefix := table + "_"; !strings.HasPrefix(name, prefix) && !hasUniqueName(idx) {
|
||||
name = prefix + name
|
||||
}
|
||||
query, args := sql.Dialect(dialect.Postgres).
|
||||
Select(sql.Count("*")).From(sql.Table("table_constraints").Schema("information_schema")).
|
||||
Where(sql.And(
|
||||
d.matchSchema(),
|
||||
sql.EQ("constraint_type", "UNIQUE"),
|
||||
sql.EQ("constraint_name", name),
|
||||
)).
|
||||
Query()
|
||||
exists, err := exist(ctx, tx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
query, args = build.DropIndex(name).Query()
|
||||
if exists {
|
||||
query, args = build.AlterTable(table).DropConstraint(name).Query()
|
||||
}
|
||||
return tx.Exec(ctx, query, args, nil)
|
||||
}
|
||||
|
||||
// isImplicitIndex reports if the index was created implicitly for the unique column.
|
||||
func (d *Postgres) isImplicitIndex(idx *Index, col *Column) bool {
|
||||
return strings.TrimSuffix(idx.Name, "_key") == col.Name && col.Unique
|
||||
}
|
||||
|
||||
// renameColumn returns the statement for renaming a column.
|
||||
func (d *Postgres) renameColumn(t *Table, old, new *Column) sql.Querier {
|
||||
return sql.Dialect(dialect.Postgres).
|
||||
AlterTable(t.Name).
|
||||
RenameColumn(old.Name, new.Name)
|
||||
}
|
||||
|
||||
// renameIndex returns the statement for renaming an index.
|
||||
func (d *Postgres) renameIndex(t *Table, old, new *Index) sql.Querier {
|
||||
if sfx := "_key"; strings.HasSuffix(old.Name, sfx) && !strings.HasSuffix(new.Name, sfx) {
|
||||
new.Name += sfx
|
||||
}
|
||||
if pfx := t.Name + "_"; strings.HasPrefix(old.realname, pfx) && !strings.HasPrefix(new.Name, pfx) {
|
||||
new.Name = pfx + new.Name
|
||||
}
|
||||
return sql.Dialect(dialect.Postgres).AlterIndex(old.realname).Rename(new.Name)
|
||||
}
|
||||
|
||||
// matchSchema returns the predicate for matching table schema.
|
||||
func (d *Postgres) matchSchema(columns ...string) *sql.Predicate {
|
||||
column := "table_schema"
|
||||
@@ -526,156 +78,8 @@ func (d *Postgres) matchSchema(columns ...string) *sql.Predicate {
|
||||
return sql.EQ(column, sql.Raw("CURRENT_SCHEMA()"))
|
||||
}
|
||||
|
||||
// tables returns the query for getting the in the schema.
|
||||
func (d *Postgres) tables() sql.Querier {
|
||||
return sql.Dialect(dialect.Postgres).
|
||||
Select("table_name").
|
||||
From(sql.Table("tables").Schema("information_schema")).
|
||||
Where(d.matchSchema())
|
||||
}
|
||||
|
||||
// alterColumns returns the queries for applying the columns change-set.
|
||||
func (d *Postgres) alterColumns(table string, add, modify, drop []*Column) sql.Queries {
|
||||
b := sql.Dialect(dialect.Postgres).AlterTable(table)
|
||||
for _, c := range add {
|
||||
b.AddColumn(d.addColumn(c))
|
||||
}
|
||||
for _, c := range modify {
|
||||
b.ModifyColumns(d.alterColumn(c)...)
|
||||
}
|
||||
for _, c := range drop {
|
||||
b.DropColumn(sql.Dialect(dialect.Postgres).Column(c.Name))
|
||||
}
|
||||
if len(b.Queries) == 0 {
|
||||
return nil
|
||||
}
|
||||
return sql.Queries{b}
|
||||
}
|
||||
|
||||
// needsConversion reports if column "old" needs to be converted
|
||||
// (by table altering) to column "new".
|
||||
func (d *Postgres) needsConversion(old, new *Column) bool {
|
||||
oldT, newT := d.cType(old), d.cType(new)
|
||||
return oldT != newT && (oldT != "ARRAY" || !arrayType(newT))
|
||||
}
|
||||
|
||||
// callExpr reports if the given string ~looks like a function call expression.
|
||||
func callExpr(s string) bool {
|
||||
if parts := strings.Split(s, "::"); !strings.HasSuffix(s, ")") && strings.HasSuffix(parts[0], ")") {
|
||||
s = parts[0]
|
||||
}
|
||||
i, j := strings.IndexByte(s, '('), strings.LastIndexByte(s, ')')
|
||||
if i == -1 || i > j || j != len(s)-1 {
|
||||
return false
|
||||
}
|
||||
for i, r := range s[:i] {
|
||||
if !isAlpha(r, i > 0) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isAlpha(r rune, digit bool) bool {
|
||||
return 'a' <= r && r <= 'z' || 'A' <= r && r <= 'Z' || r == '_' || digit && '0' <= r && r <= '9'
|
||||
}
|
||||
|
||||
// arrayType reports if the given string is an array type (e.g. int[], text[2]).
|
||||
func arrayType(t string) bool {
|
||||
i, j := strings.LastIndexByte(t, '['), strings.LastIndexByte(t, ']')
|
||||
if i == -1 || j == -1 {
|
||||
return false
|
||||
}
|
||||
for _, r := range t[i+1 : j] {
|
||||
if !unicode.IsDigit(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// foreignKeys populates the tables foreign keys using the information_schema tables
|
||||
func (d *Postgres) foreignKeys(ctx context.Context, tx dialect.Tx, tables []*Table) error {
|
||||
var tableLookup = make(map[string]*Table)
|
||||
for _, t := range tables {
|
||||
tableLookup[t.Name] = t
|
||||
}
|
||||
for _, t := range tables {
|
||||
rows := &sql.Rows{}
|
||||
query := fmt.Sprintf(fkQuery, t.Name)
|
||||
if err := tx.Query(ctx, query, []any{}, rows); err != nil {
|
||||
return fmt.Errorf("querying foreign keys for table %s: %w", t.Name, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var tableFksLookup = make(map[string]*ForeignKey)
|
||||
for rows.Next() {
|
||||
var tableSchema, constraintName, tableName, columnName, refTableSchema, refTableName, refColumnName string
|
||||
if err := rows.Scan(&tableSchema, &constraintName, &tableName, &columnName, &refTableSchema, &refTableName, &refColumnName); err != nil {
|
||||
return fmt.Errorf("scanning index description: %w", err)
|
||||
}
|
||||
refTable := tableLookup[refTableName]
|
||||
if refTable == nil {
|
||||
return fmt.Errorf("could not find table: %s", refTableName)
|
||||
}
|
||||
column, ok := t.column(columnName)
|
||||
if !ok {
|
||||
return fmt.Errorf("could not find column: %s on table: %s", columnName, tableName)
|
||||
}
|
||||
refColumn, ok := refTable.column(refColumnName)
|
||||
if !ok {
|
||||
return fmt.Errorf("could not find ref column: %s on ref table: %s", refTableName, refColumnName)
|
||||
}
|
||||
if fk, ok := tableFksLookup[constraintName]; ok {
|
||||
if _, ok := fk.column(columnName); !ok {
|
||||
fk.Columns = append(fk.Columns, column)
|
||||
}
|
||||
if _, ok := fk.refColumn(refColumnName); !ok {
|
||||
fk.RefColumns = append(fk.RefColumns, refColumn)
|
||||
}
|
||||
} else {
|
||||
newFk := &ForeignKey{
|
||||
Symbol: constraintName,
|
||||
Columns: []*Column{column},
|
||||
RefTable: refTable,
|
||||
RefColumns: []*Column{refColumn},
|
||||
}
|
||||
tableFksLookup[constraintName] = newFk
|
||||
t.AddForeignKey(newFk)
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fkQuery holds a query format for retrieving
|
||||
// foreign keys of the current schema.
|
||||
const fkQuery = `
|
||||
SELECT tc.table_schema,
|
||||
tc.constraint_name,
|
||||
tc.table_name,
|
||||
kcu.column_name,
|
||||
ccu.table_schema AS foreign_table_schema,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name
|
||||
FROM information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_name = '%s'
|
||||
order by constraint_name, kcu.ordinal_position;
|
||||
`
|
||||
|
||||
// Atlas integration.
|
||||
// maxCharSize defines the maximum size of limited character types in Postgres (10 MB).
|
||||
const maxCharSize = 10 << 20
|
||||
|
||||
func (d *Postgres) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) {
|
||||
return postgres.Open(&db{ExecQuerier: conn})
|
||||
@@ -843,7 +247,7 @@ func (d *Postgres) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Postgres) atTypeRangeSQL(ts ...string) string {
|
||||
func (*Postgres) atTypeRangeSQL(ts ...string) string {
|
||||
for i := range ts {
|
||||
ts[i] = fmt.Sprintf("('%s')", ts[i])
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,6 @@ package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -186,30 +185,6 @@ func (t *Table) index(name string) (*Index, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// hasIndex reports if the table has at least one index that matches the given names.
|
||||
func (t *Table) hasIndex(names ...string) bool {
|
||||
for i := range names {
|
||||
if names[i] == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := t.index(names[i]); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// fk returns a table foreign-key by its symbol.
|
||||
// faster than map lookup for most cases.
|
||||
func (t *Table) fk(symbol string) (*ForeignKey, bool) {
|
||||
for _, fk := range t.ForeignKeys {
|
||||
if fk.Symbol == symbol {
|
||||
return fk, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// CopyTables returns a deep-copy of the given tables. This utility function is
|
||||
// useful for copying the generated schema tables (i.e. migrate.Tables) before
|
||||
// running schema migration when there is a need for execute multiple migrations
|
||||
@@ -417,27 +392,6 @@ func (c *Column) ScanDefault(value string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// defaultValue adds the `DEFAULT` attribute to the column.
|
||||
// Note that, in SQLite if a NOT NULL constraint is specified,
|
||||
// then the column must have a default value which not NULL.
|
||||
func (c *Column) defaultValue(b *sql.ColumnBuilder) {
|
||||
if c.Default == nil || !c.supportDefault() {
|
||||
return
|
||||
}
|
||||
// Has default and the database supports adding this default.
|
||||
attr := fmt.Sprint(c.Default)
|
||||
switch v := c.Default.(type) {
|
||||
case bool:
|
||||
attr = strconv.FormatBool(v)
|
||||
case string:
|
||||
if t := c.Type; t != field.TypeUUID && t != field.TypeTime {
|
||||
// Escape single quote by replacing each with 2.
|
||||
attr = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''"))
|
||||
}
|
||||
}
|
||||
b.Attr("DEFAULT " + attr)
|
||||
}
|
||||
|
||||
// supportDefault reports if the column type supports default value.
|
||||
func (c Column) supportDefault() bool {
|
||||
switch t := c.Type; t {
|
||||
@@ -450,25 +404,6 @@ func (c Column) supportDefault() bool {
|
||||
}
|
||||
}
|
||||
|
||||
// unique adds the `UNIQUE` attribute if the column is a unique type.
|
||||
// it is exist in a different function to share the common declaration
|
||||
// between the two dialects.
|
||||
func (c *Column) unique(b *sql.ColumnBuilder) {
|
||||
if c.Unique {
|
||||
b.Attr("UNIQUE")
|
||||
}
|
||||
}
|
||||
|
||||
// nullable adds the `NULL`/`NOT NULL` attribute to the column if it exists in
|
||||
// a different function to share the common declaration between the two dialects.
|
||||
func (c *Column) nullable(b *sql.ColumnBuilder) {
|
||||
attr := Null
|
||||
if !c.Nullable {
|
||||
attr = "NOT " + attr
|
||||
}
|
||||
b.Attr(attr)
|
||||
}
|
||||
|
||||
// scanTypeOr returns the scanning type or the given value.
|
||||
func (c *Column) scanTypeOr(t string) string {
|
||||
if c.typ != "" {
|
||||
@@ -487,24 +422,6 @@ type ForeignKey struct {
|
||||
OnDelete ReferenceOption // action on delete.
|
||||
}
|
||||
|
||||
func (fk ForeignKey) column(name string) (*Column, bool) {
|
||||
for _, c := range fk.Columns {
|
||||
if c.Name == name {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (fk ForeignKey) refColumn(name string) (*Column, bool) {
|
||||
for _, c := range fk.RefColumns {
|
||||
if c.Name == name {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// DSL returns a default DSL query for a foreign-key.
|
||||
func (fk ForeignKey) DSL() *sql.ForeignKeyBuilder {
|
||||
cols := make([]string, len(fk.Columns))
|
||||
@@ -551,7 +468,6 @@ type Index struct {
|
||||
Columns []*Column // actual table columns.
|
||||
Annotation *entsql.IndexAnnotation // index annotation.
|
||||
columns []string // columns loaded from query scan.
|
||||
primary bool // primary key index.
|
||||
realname string // real name in the database (Postgres only).
|
||||
}
|
||||
|
||||
@@ -573,32 +489,6 @@ func (i *Index) DropBuilder(table string) *sql.DropIndexBuilder {
|
||||
return idx
|
||||
}
|
||||
|
||||
// sameAs reports if the index has the same properties
|
||||
// as the given index (except the name).
|
||||
func (i *Index) sameAs(idx *Index) bool {
|
||||
if i.Unique != idx.Unique || len(i.Columns) != len(idx.Columns) {
|
||||
return false
|
||||
}
|
||||
for j, c := range i.Columns {
|
||||
if c.Name != idx.Columns[j].Name {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// columnNames returns the names of the columns of the index.
|
||||
func (i *Index) columnNames() []string {
|
||||
if len(i.columns) > 0 {
|
||||
return i.columns
|
||||
}
|
||||
columns := make([]string, 0, len(i.Columns))
|
||||
for _, c := range i.Columns {
|
||||
columns = append(columns, c.Name)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
// Indexes used for scanning all sql.Rows into a list of indexes, because
|
||||
// multiple sql rows can represent the same index (multi-columns indexes).
|
||||
type Indexes []*Index
|
||||
@@ -673,33 +563,16 @@ func compare(v1, v2 int) int {
|
||||
return 1
|
||||
}
|
||||
|
||||
// addChecks appends the CHECK clauses from the entsql.Annotation.
|
||||
func addChecks(t *sql.TableBuilder, ant *entsql.Annotation) {
|
||||
if check := ant.Check; check != "" {
|
||||
t.Checks(func(b *sql.Builder) {
|
||||
b.WriteString("CHECK " + checkExpr(check))
|
||||
})
|
||||
func indexType(idx *Index, d string) (string, bool) {
|
||||
ant := idx.Annotation
|
||||
if ant == nil {
|
||||
return "", false
|
||||
}
|
||||
if checks := ant.Checks; len(ant.Checks) > 0 {
|
||||
names := make([]string, 0, len(checks))
|
||||
for name := range checks {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
name := name
|
||||
t.Checks(func(b *sql.Builder) {
|
||||
b.WriteString("CONSTRAINT ").Ident(name).WriteString(" CHECK " + checkExpr(checks[name]))
|
||||
})
|
||||
}
|
||||
if ant.Types != nil && ant.Types[d] != "" {
|
||||
return ant.Types[d], true
|
||||
}
|
||||
}
|
||||
|
||||
// checkExpr formats the CHECK expression.
|
||||
func checkExpr(expr string) string {
|
||||
expr = strings.TrimSpace(expr)
|
||||
if !strings.HasPrefix(expr, "(") && !strings.HasSuffix(expr, ")") {
|
||||
expr = "(" + expr + ")"
|
||||
}
|
||||
return expr
|
||||
if ant.Type != "" {
|
||||
return ant.Type, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
// SQLite is an SQLite migration driver.
|
||||
// SQLite adapter for Atlas migration engine.
|
||||
SQLite struct {
|
||||
dialect.Driver
|
||||
WithForeignKeys bool
|
||||
@@ -88,309 +88,6 @@ func (d *SQLite) tableExist(ctx context.Context, conn dialect.ExecQuerier, name
|
||||
return exist(ctx, conn, query, args...)
|
||||
}
|
||||
|
||||
// setRange sets the start value of table PK.
|
||||
// SQLite tracks the AUTOINCREMENT in the "sqlite_sequence" table that is created and initialized automatically
|
||||
// whenever a table that contains an AUTOINCREMENT column is created. However, it populates to it a rows (for tables)
|
||||
// only after the first insertion. Therefore, we check. If a record (for the given table) already exists in the "sqlite_sequence"
|
||||
// table, we updated it. Otherwise, we insert a new value.
|
||||
func (d *SQLite) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error {
|
||||
query, args := sql.Select().Count().
|
||||
From(sql.Table("sqlite_sequence")).
|
||||
Where(sql.EQ("name", t.Name)).
|
||||
Query()
|
||||
exists, err := exist(ctx, conn, query, args...)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
case exists:
|
||||
query, args = sql.Update("sqlite_sequence").Set("seq", value).Where(sql.EQ("name", t.Name)).Query()
|
||||
default: // !exists
|
||||
query, args = sql.Insert("sqlite_sequence").Columns("name", "seq").Values(t.Name, value).Query()
|
||||
}
|
||||
return conn.Exec(ctx, query, args, nil)
|
||||
}
|
||||
|
||||
func (d *SQLite) tBuilder(t *Table) *sql.TableBuilder {
|
||||
b := sql.CreateTable(t.Name)
|
||||
for _, c := range t.Columns {
|
||||
b.Column(d.addColumn(c))
|
||||
}
|
||||
if t.Annotation != nil {
|
||||
addChecks(b, t.Annotation)
|
||||
}
|
||||
// Unlike in MySQL, we're not able to add foreign-key constraints to table
|
||||
// after it was created, and adding them to the `CREATE TABLE` statement is
|
||||
// not always valid (because circular foreign-keys situation is possible).
|
||||
// We stay consistent by not using constraints at all, and just defining the
|
||||
// foreign keys in the `CREATE TABLE` statement.
|
||||
if d.WithForeignKeys {
|
||||
for _, fk := range t.ForeignKeys {
|
||||
b.ForeignKeys(fk.DSL())
|
||||
}
|
||||
}
|
||||
// If it's an ID based primary key with autoincrement, we add
|
||||
// the `PRIMARY KEY` clause to the column declaration. Otherwise,
|
||||
// we append it to the constraint clause.
|
||||
if len(t.PrimaryKey) == 1 && t.PrimaryKey[0].Increment {
|
||||
return b
|
||||
}
|
||||
for _, pk := range t.PrimaryKey {
|
||||
b.PrimaryKey(pk.Name)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// cType returns the SQLite string type for the given column.
|
||||
func (*SQLite) cType(c *Column) (t string) {
|
||||
if c.SchemaType != nil && c.SchemaType[dialect.SQLite] != "" {
|
||||
return c.SchemaType[dialect.SQLite]
|
||||
}
|
||||
switch c.Type {
|
||||
case field.TypeBool:
|
||||
t = "bool"
|
||||
case field.TypeInt8, field.TypeUint8, field.TypeInt16, field.TypeUint16, field.TypeInt32,
|
||||
field.TypeUint32, field.TypeUint, field.TypeInt, field.TypeInt64, field.TypeUint64:
|
||||
t = "integer"
|
||||
case field.TypeBytes:
|
||||
t = "blob"
|
||||
case field.TypeString, field.TypeEnum:
|
||||
// SQLite does not impose any length restrictions on
|
||||
// the length of strings, BLOBs or numeric values.
|
||||
t = fmt.Sprintf("varchar(%d)", DefaultStringLen)
|
||||
case field.TypeFloat32, field.TypeFloat64:
|
||||
t = "real"
|
||||
case field.TypeTime:
|
||||
t = "datetime"
|
||||
case field.TypeJSON:
|
||||
t = "json"
|
||||
case field.TypeUUID:
|
||||
t = "uuid"
|
||||
case field.TypeOther:
|
||||
t = c.typ
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported type %q for column %q", c.Type, c.Name))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// addColumn returns the DSL query for adding the given column to a table.
|
||||
func (d *SQLite) addColumn(c *Column) *sql.ColumnBuilder {
|
||||
b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr)
|
||||
c.unique(b)
|
||||
if c.PrimaryKey() && c.Increment {
|
||||
b.Attr("PRIMARY KEY AUTOINCREMENT")
|
||||
}
|
||||
c.nullable(b)
|
||||
c.defaultValue(b)
|
||||
return b
|
||||
}
|
||||
|
||||
// addIndex returns the query for adding an index to SQLite.
|
||||
func (d *SQLite) addIndex(i *Index, table string) *sql.IndexBuilder {
|
||||
return i.Builder(table).IfNotExists()
|
||||
}
|
||||
|
||||
// dropIndex drops a SQLite index.
|
||||
func (d *SQLite) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error {
|
||||
query, args := idx.DropBuilder("").Query()
|
||||
return tx.Exec(ctx, query, args, nil)
|
||||
}
|
||||
|
||||
// fkExist returns always true to disable foreign-keys creation after the table was created.
|
||||
func (d *SQLite) fkExist(context.Context, dialect.Tx, string) (bool, error) { return true, nil }
|
||||
|
||||
// table returns always error to indicate that SQLite dialect doesn't support incremental migration.
|
||||
func (d *SQLite) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Select("name", "type", "notnull", "dflt_value", "pk").
|
||||
From(sql.Table(fmt.Sprintf("pragma_table_info('%s')", name)).Unquote()).
|
||||
OrderBy("pk").
|
||||
Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: reading table description %w", err)
|
||||
}
|
||||
// Call Close in cases of failures (Close is idempotent).
|
||||
defer rows.Close()
|
||||
t := NewTable(name)
|
||||
for rows.Next() {
|
||||
c := &Column{}
|
||||
if err := d.scanColumn(c, rows); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: %w", err)
|
||||
}
|
||||
if c.PrimaryKey() {
|
||||
t.PrimaryKey = append(t.PrimaryKey, c)
|
||||
}
|
||||
t.AddColumn(c)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: closing rows %w", err)
|
||||
}
|
||||
indexes, err := d.indexes(ctx, tx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Add and link indexes to table columns.
|
||||
for _, idx := range indexes {
|
||||
switch {
|
||||
case idx.primary:
|
||||
case idx.Unique && len(idx.columns) == 1:
|
||||
name := idx.columns[0]
|
||||
c, ok := t.column(name)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name)
|
||||
}
|
||||
c.Key = UniqueKey
|
||||
c.Unique = true
|
||||
fallthrough
|
||||
default:
|
||||
t.addIndex(idx)
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// table loads the table indexes from the database.
|
||||
func (d *SQLite) indexes(ctx context.Context, tx dialect.Tx, name string) (Indexes, error) {
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Select("name", "unique", "origin").
|
||||
From(sql.Table(fmt.Sprintf("pragma_index_list('%s')", name)).Unquote()).
|
||||
Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("reading table indexes %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var idx Indexes
|
||||
for rows.Next() {
|
||||
i := &Index{}
|
||||
origin := sql.NullString{}
|
||||
if err := rows.Scan(&i.Name, &i.Unique, &origin); err != nil {
|
||||
return nil, fmt.Errorf("scanning index description %w", err)
|
||||
}
|
||||
i.primary = origin.String == "pk"
|
||||
idx = append(idx, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, fmt.Errorf("closing rows %w", err)
|
||||
}
|
||||
for i := range idx {
|
||||
columns, err := d.indexColumns(ctx, tx, idx[i].Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idx[i].columns = columns
|
||||
// Normalize implicit index names to ent naming convention. See:
|
||||
// https://github.com/sqlite/sqlite/blob/e937df8/src/build.c#L3583
|
||||
if len(columns) == 1 && strings.HasPrefix(idx[i].Name, "sqlite_autoindex_"+name) {
|
||||
idx[i].Name = columns[0]
|
||||
}
|
||||
}
|
||||
return idx, nil
|
||||
}
|
||||
|
||||
// indexColumns loads index columns from index info.
|
||||
func (d *SQLite) indexColumns(ctx context.Context, tx dialect.Tx, name string) ([]string, error) {
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Select("name").
|
||||
From(sql.Table(fmt.Sprintf("pragma_index_info('%s')", name)).Unquote()).
|
||||
OrderBy("seqno").
|
||||
Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("reading table indexes %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var names []string
|
||||
if err := sql.ScanSlice(rows, &names); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
// scanColumn scans the column information from SQLite column description.
|
||||
func (d *SQLite) scanColumn(c *Column, rows *sql.Rows) error {
|
||||
var (
|
||||
pk sql.NullInt64
|
||||
notnull sql.NullInt64
|
||||
defaults sql.NullString
|
||||
)
|
||||
if err := rows.Scan(&c.Name, &c.typ, ¬null, &defaults, &pk); err != nil {
|
||||
return fmt.Errorf("scanning column description: %w", err)
|
||||
}
|
||||
c.Nullable = notnull.Int64 == 0
|
||||
if pk.Int64 > 0 {
|
||||
c.Key = PrimaryKey
|
||||
}
|
||||
if c.typ == "" {
|
||||
return fmt.Errorf("missing type information for column %q", c.Name)
|
||||
}
|
||||
parts, size, _, err := parseColumn(c.typ)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch strings.ToLower(parts[0]) {
|
||||
case "bool", "boolean":
|
||||
c.Type = field.TypeBool
|
||||
case "blob":
|
||||
c.Type = field.TypeBytes
|
||||
case "integer":
|
||||
// All integer types have the same "type affinity".
|
||||
c.Type = field.TypeInt
|
||||
case "real", "float", "double":
|
||||
c.Type = field.TypeFloat64
|
||||
case "datetime":
|
||||
c.Type = field.TypeTime
|
||||
case "json":
|
||||
c.Type = field.TypeJSON
|
||||
case "uuid":
|
||||
c.Type = field.TypeUUID
|
||||
case "varchar", "char", "text":
|
||||
c.Size = size
|
||||
c.Type = field.TypeString
|
||||
case "decimal", "numeric":
|
||||
c.Type = field.TypeOther
|
||||
}
|
||||
if defaults.Valid {
|
||||
return c.ScanDefault(defaults.String)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// alterColumns returns the queries for applying the columns change-set.
|
||||
func (d *SQLite) alterColumns(table string, add, _, _ []*Column) sql.Queries {
|
||||
queries := make(sql.Queries, 0, len(add))
|
||||
for i := range add {
|
||||
c := d.addColumn(add[i])
|
||||
if fk := add[i].foreign; fk != nil {
|
||||
c.Constraint(fk.DSL())
|
||||
}
|
||||
queries = append(queries, sql.Dialect(dialect.SQLite).AlterTable(table).AddColumn(c))
|
||||
}
|
||||
// Modifying and dropping columns is not supported and disabled until we
|
||||
// will support https://www.sqlite.org/lang_altertable.html#otheralter
|
||||
return queries
|
||||
}
|
||||
|
||||
// tables returns the query for getting the in the schema.
|
||||
func (d *SQLite) tables() sql.Querier {
|
||||
return sql.Select("name").
|
||||
From(sql.Table("sqlite_schema")).
|
||||
Where(sql.EQ("type", "table"))
|
||||
}
|
||||
|
||||
// needsConversion reports if column "old" needs to be converted
|
||||
// (by table altering) to column "new".
|
||||
func (d *SQLite) needsConversion(old, new *Column) bool {
|
||||
c1, c2 := d.cType(old), d.cType(new)
|
||||
return c1 != c2 && old.typ != c2
|
||||
}
|
||||
|
||||
// Atlas integration.
|
||||
|
||||
func (d *SQLite) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) {
|
||||
return sqlite.Open(&db{ExecQuerier: conn})
|
||||
}
|
||||
|
||||
@@ -1,478 +0,0 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/schema/field"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSQLite_Create(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tables []*Table
|
||||
options []MigrateOption
|
||||
before func(sqliteMock)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "tx failed",
|
||||
before: func(mock sqliteMock) {
|
||||
mock.ExpectBegin().WillReturnError(sqlmock.ErrCancelled)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "fk disabled",
|
||||
before: func(mock sqliteMock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery("PRAGMA foreign_keys").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(0))
|
||||
mock.ExpectRollback()
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no tables",
|
||||
before: func(mock sqliteMock) {
|
||||
mock.start()
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create new table",
|
||||
tables: []*Table{
|
||||
{
|
||||
Name: "users",
|
||||
PrimaryKey: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
},
|
||||
Columns: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "name", Type: field.TypeString, Nullable: true},
|
||||
{Name: "age", Type: field.TypeInt},
|
||||
{Name: "doc", Type: field.TypeJSON, Nullable: true},
|
||||
{Name: "uuid", Type: field.TypeUUID, Nullable: true},
|
||||
{Name: "decimal", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.SQLite: "decimal(6,2)"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
before: func(mock sqliteMock) {
|
||||
mock.start()
|
||||
mock.tableExists("users", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `age` integer NOT NULL, `doc` json NULL, `uuid` uuid NULL, `decimal` decimal(6,2) NOT NULL)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create new table with foreign key",
|
||||
tables: func() []*Table {
|
||||
var (
|
||||
c1 = []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "name", Type: field.TypeString, Nullable: true},
|
||||
{Name: "created_at", Type: field.TypeTime},
|
||||
}
|
||||
c2 = []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "name", Type: field.TypeString},
|
||||
{Name: "owner_id", Type: field.TypeInt, Nullable: true},
|
||||
}
|
||||
t1 = &Table{
|
||||
Name: "users",
|
||||
Columns: c1,
|
||||
PrimaryKey: c1[0:1],
|
||||
}
|
||||
t2 = &Table{
|
||||
Name: "pets",
|
||||
Columns: c2,
|
||||
PrimaryKey: c2[0:1],
|
||||
ForeignKeys: []*ForeignKey{
|
||||
{
|
||||
Symbol: "pets_owner",
|
||||
Columns: c2[2:],
|
||||
RefTable: t1,
|
||||
RefColumns: c1[0:1],
|
||||
OnDelete: Cascade,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
return []*Table{t1, t2}
|
||||
}(),
|
||||
before: func(mock sqliteMock) {
|
||||
mock.start()
|
||||
mock.tableExists("users", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` datetime NOT NULL)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.tableExists("pets", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL, FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create new table with foreign key disabled",
|
||||
options: []MigrateOption{
|
||||
WithForeignKeys(false),
|
||||
},
|
||||
tables: func() []*Table {
|
||||
var (
|
||||
c1 = []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "name", Type: field.TypeString, Nullable: true},
|
||||
{Name: "created_at", Type: field.TypeTime},
|
||||
}
|
||||
c2 = []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "name", Type: field.TypeString},
|
||||
{Name: "owner_id", Type: field.TypeInt, Nullable: true},
|
||||
}
|
||||
t1 = &Table{
|
||||
Name: "users",
|
||||
Columns: c1,
|
||||
PrimaryKey: c1[0:1],
|
||||
}
|
||||
t2 = &Table{
|
||||
Name: "pets",
|
||||
Columns: c2,
|
||||
PrimaryKey: c2[0:1],
|
||||
ForeignKeys: []*ForeignKey{
|
||||
{
|
||||
Symbol: "pets_owner",
|
||||
Columns: c2[2:],
|
||||
RefTable: t1,
|
||||
RefColumns: c1[0:1],
|
||||
OnDelete: Cascade,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
return []*Table{t1, t2}
|
||||
}(),
|
||||
before: func(mock sqliteMock) {
|
||||
mock.start()
|
||||
mock.tableExists("users", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` datetime NOT NULL)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.tableExists("pets", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add column to table",
|
||||
tables: []*Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "name", Type: field.TypeString, Nullable: true},
|
||||
{Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32},
|
||||
{Name: "uuid", Type: field.TypeUUID, Nullable: true},
|
||||
{Name: "age", Type: field.TypeInt, Default: 0},
|
||||
},
|
||||
PrimaryKey: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
before: func(mock sqliteMock) {
|
||||
mock.start()
|
||||
mock.tableExists("users", true)
|
||||
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")).
|
||||
WithArgs().
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
|
||||
AddRow("name", "varchar(255)", 0, nil, 0).
|
||||
AddRow("text", "text", 0, "NULL", 0).
|
||||
AddRow("uuid", "uuid", 0, "Null", 0).
|
||||
AddRow("id", "integer", 1, "NULL", 1))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "datetime and timestamp",
|
||||
tables: []*Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "created_at", Type: field.TypeTime, Nullable: true},
|
||||
{Name: "updated_at", Type: field.TypeTime, Nullable: true},
|
||||
},
|
||||
PrimaryKey: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
before: func(mock sqliteMock) {
|
||||
mock.start()
|
||||
mock.tableExists("users", true)
|
||||
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")).
|
||||
WithArgs().
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
|
||||
AddRow("created_at", "datetime", 0, nil, 0).
|
||||
AddRow("id", "integer", 1, "NULL", 1))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `updated_at` datetime NULL")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add blob columns",
|
||||
tables: []*Table{
|
||||
{
|
||||
Name: "blobs",
|
||||
Columns: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "old_tiny", Type: field.TypeBytes, Size: 100},
|
||||
{Name: "old_blob", Type: field.TypeBytes, Size: 1e3},
|
||||
{Name: "old_medium", Type: field.TypeBytes, Size: 1e5},
|
||||
{Name: "old_long", Type: field.TypeBytes, Size: 1e8},
|
||||
{Name: "new_tiny", Type: field.TypeBytes, Size: 100},
|
||||
{Name: "new_blob", Type: field.TypeBytes, Size: 1e3},
|
||||
{Name: "new_medium", Type: field.TypeBytes, Size: 1e5},
|
||||
{Name: "new_long", Type: field.TypeBytes, Size: 1e8},
|
||||
},
|
||||
PrimaryKey: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
before: func(mock sqliteMock) {
|
||||
mock.start()
|
||||
mock.tableExists("blobs", true)
|
||||
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('blobs') ORDER BY `pk`")).
|
||||
WithArgs().
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
|
||||
AddRow("old_tiny", "blob", 1, nil, 0).
|
||||
AddRow("old_blob", "blob", 1, nil, 0).
|
||||
AddRow("old_medium", "blob", 1, nil, 0).
|
||||
AddRow("old_long", "blob", 1, nil, 0).
|
||||
AddRow("id", "integer", 1, "NULL", 1))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('blobs')")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"}))
|
||||
for _, c := range []string{"tiny", "blob", "medium", "long"} {
|
||||
mock.ExpectExec(escape(fmt.Sprintf("ALTER TABLE `blobs` ADD COLUMN `new_%s` blob NOT NULL", c))).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
}
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add columns with default values",
|
||||
tables: []*Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "name", Type: field.TypeString, Default: "unknown"},
|
||||
{Name: "active", Type: field.TypeBool, Default: false},
|
||||
},
|
||||
PrimaryKey: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
before: func(mock sqliteMock) {
|
||||
mock.start()
|
||||
mock.tableExists("users", true)
|
||||
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")).
|
||||
WithArgs().
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
|
||||
AddRow("id", "integer", 1, "NULL", 1))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `name` varchar(255) NOT NULL DEFAULT 'unknown'")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `active` bool NOT NULL DEFAULT false")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add edge to table",
|
||||
tables: func() []*Table {
|
||||
var (
|
||||
c1 = []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "name", Type: field.TypeString, Nullable: true},
|
||||
{Name: "spouse_id", Type: field.TypeInt, Nullable: true},
|
||||
}
|
||||
t1 = &Table{
|
||||
Name: "users",
|
||||
Columns: c1,
|
||||
PrimaryKey: c1[0:1],
|
||||
ForeignKeys: []*ForeignKey{
|
||||
{
|
||||
Symbol: "user_spouse",
|
||||
Columns: c1[2:],
|
||||
RefColumns: c1[0:1],
|
||||
OnDelete: Cascade,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
t1.ForeignKeys[0].RefTable = t1
|
||||
return []*Table{t1}
|
||||
}(),
|
||||
before: func(mock sqliteMock) {
|
||||
mock.start()
|
||||
mock.tableExists("users", true)
|
||||
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")).
|
||||
WithArgs().
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
|
||||
AddRow("name", "varchar(255)", 1, "NULL", 0).
|
||||
AddRow("id", "integer", 1, "NULL", 1))
|
||||
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` integer NULL CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE CASCADE")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "universal id for all tables",
|
||||
tables: []*Table{
|
||||
NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
},
|
||||
options: []MigrateOption{WithGlobalUniqueID(true)},
|
||||
before: func(mock sqliteMock) {
|
||||
mock.start()
|
||||
// creating ent_types table.
|
||||
mock.tableExists("ent_types", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `ent_types`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `type` varchar(255) UNIQUE NOT NULL)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.tableExists("users", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set users id range.
|
||||
mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")).
|
||||
WithArgs("users").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
|
||||
WithArgs("users", 0).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.tableExists("groups", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `groups`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set groups id range.
|
||||
mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")).
|
||||
WithArgs("groups").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")).
|
||||
WithArgs("groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
|
||||
WithArgs("groups", 1<<32).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "universal id for restored tables",
|
||||
tables: []*Table{
|
||||
NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
},
|
||||
options: []MigrateOption{WithGlobalUniqueID(true)},
|
||||
before: func(mock sqliteMock) {
|
||||
mock.start()
|
||||
// query ent_types table.
|
||||
mock.tableExists("ent_types", true)
|
||||
mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users"))
|
||||
mock.tableExists("users", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set users id range (without inserting to ent_types).
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
mock.ExpectExec(escape("UPDATE `sqlite_sequence` SET `seq` = ? WHERE `name` = ?")).
|
||||
WithArgs(0, "users").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.tableExists("groups", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `groups`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set groups id range.
|
||||
mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")).
|
||||
WithArgs("groups").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")).
|
||||
WithArgs("groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
|
||||
WithArgs("groups", 1<<32).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
tt.before(sqliteMock{mock})
|
||||
migrate, err := NewMigrate(sql.OpenDB("sqlite3", db), append(tt.options, WithAtlas(false))...)
|
||||
require.NoError(t, err)
|
||||
err = migrate.Create(context.Background(), tt.tables...)
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type sqliteMock struct {
|
||||
sqlmock.Sqlmock
|
||||
}
|
||||
|
||||
func (m sqliteMock) start() {
|
||||
m.ExpectQuery("PRAGMA foreign_keys").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(1))
|
||||
m.ExpectExec("PRAGMA foreign_keys = off").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
m.ExpectBegin()
|
||||
m.ExpectQuery("PRAGMA foreign_key_check").
|
||||
WillReturnRows(sqlmock.NewRows([]string{})) // empty
|
||||
}
|
||||
|
||||
func (m sqliteMock) commit() {
|
||||
m.ExpectQuery("PRAGMA foreign_key_check").
|
||||
WillReturnRows(sqlmock.NewRows([]string{})) // empty
|
||||
m.ExpectCommit()
|
||||
m.ExpectExec("PRAGMA foreign_keys = on").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
}
|
||||
|
||||
func (m sqliteMock) tableExists(table string, exists bool) {
|
||||
count := 0
|
||||
if exists {
|
||||
count = 1
|
||||
}
|
||||
m.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_master` WHERE `type` = ? AND `name` = ?")).
|
||||
WithArgs("table", table).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count))
|
||||
}
|
||||
Reference in New Issue
Block a user