mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/schema: universl id allocation support
Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/9 Reviewed By: alexsn Differential Revision: D16252229 fbshipit-source-id: 795b6556d322e5c1ff5fb826c3b06ba5421ac857
This commit is contained in:
committed by
Facebook Github Bot
parent
ad051e6d72
commit
b5cdb810b8
307
dialect/sql/schema/migrate.go
Normal file
307
dialect/sql/schema/migrate.go
Normal file
@@ -0,0 +1,307 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
|
||||
"fbc/ent/dialect"
|
||||
"fbc/ent/dialect/sql"
|
||||
"fbc/ent/field"
|
||||
)
|
||||
|
||||
const (
|
||||
// TypeTable holds the table name holding the type information.
|
||||
TypeTable = "ent_types"
|
||||
// MaxTypes defines the max number of types can be created when
|
||||
// defining universal ids. The left 16-bits are reserved.
|
||||
MaxTypes = math.MaxUint16
|
||||
)
|
||||
|
||||
// MigrateOption allows for managing schema configuration using functional options.
|
||||
type MigrateOption func(m *Migrate)
|
||||
|
||||
// WithGlobalUniqueID sets the universal ids options to the migration.
|
||||
func WithGlobalUniqueID(b bool) MigrateOption {
|
||||
return func(o *Migrate) {
|
||||
o.universalID = b
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate runs the migrations logic for the SQL dialects.
|
||||
type Migrate struct {
|
||||
sqlDialect
|
||||
universalID bool // global unique id flag.
|
||||
typeRanges []string // types order by their range.
|
||||
}
|
||||
|
||||
// NewMigrate create a migration structure for the given SQL driver.
|
||||
func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) {
|
||||
m := &Migrate{}
|
||||
switch d.Dialect() {
|
||||
case dialect.MySQL:
|
||||
m.sqlDialect = &MySQL{Driver: d}
|
||||
case dialect.SQLite:
|
||||
m.sqlDialect = &SQLite{Driver: d}
|
||||
default:
|
||||
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", d.Dialect())
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Create creates all schema resources in the database. It works in an "append-only"
|
||||
// mode, which means, it only create tables, append column to tables or modifying column type.
|
||||
//
|
||||
// Column can be modified by turning into a NULL from NOT NULL, or having a type conversion not
|
||||
// resulting data altering. From example, changing varchar(255) to varchar(120) is invalid, but
|
||||
// changing varchar(120) to varchar(255) is valid. For more info, see the convert function below.
|
||||
//
|
||||
// Note that SQLite dialect does not support (this moment) the "append-only" mode describe above,
|
||||
// since it's used only for testing.
|
||||
func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
|
||||
tx, err := m.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.init(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
if m.universalID {
|
||||
if err := m.types(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
}
|
||||
if err := m.create(ctx, tx, tables...); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (m *Migrate) create(ctx context.Context, tx dialect.Tx, tables ...*Table) error {
|
||||
for _, t := range tables {
|
||||
switch exist, err := m.tableExist(ctx, tx, t.Name); {
|
||||
case err != nil:
|
||||
return err
|
||||
case exist:
|
||||
curr, err := m.table(ctx, tx, t.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
change, err := m.changeSet(curr, t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(change.add) != 0 || len(change.modify) != 0 {
|
||||
b := sql.AlterTable(curr.Name)
|
||||
for _, c := range change.add {
|
||||
b.AddColumn(m.cBuilder(c))
|
||||
}
|
||||
for _, c := range change.modify {
|
||||
b.ModifyColumn(m.cBuilder(c))
|
||||
}
|
||||
query, args := b.Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return fmt.Errorf("alter table %q: %v", t.Name, err)
|
||||
}
|
||||
}
|
||||
if len(change.indexes) > 0 {
|
||||
panic("missing implementation")
|
||||
}
|
||||
default: // !exist
|
||||
query, args := m.tBuilder(t).Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return fmt.Errorf("create table %q: %v", t.Name, err)
|
||||
}
|
||||
// if global unique identifier is enabled and it's not a relation table,
|
||||
// allocate a range for the table pk.
|
||||
if m.universalID && len(t.PrimaryKey) == 1 {
|
||||
if err := m.allocPKRange(ctx, tx, t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// create foreign keys after tables were created/altered,
|
||||
// because circular foreign-key constraints are possible.
|
||||
for _, t := range tables {
|
||||
if len(t.ForeignKeys) == 0 {
|
||||
continue
|
||||
}
|
||||
fks := make([]*ForeignKey, 0, len(t.ForeignKeys))
|
||||
for _, fk := range t.ForeignKeys {
|
||||
fk.Symbol = symbol(fk.Symbol)
|
||||
exist, err := m.fkExist(ctx, tx, fk.Symbol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exist {
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
}
|
||||
if len(fks) == 0 {
|
||||
continue
|
||||
}
|
||||
b := sql.AlterTable(t.Name)
|
||||
for _, fk := range fks {
|
||||
b.AddForeignKey(fk.DSL())
|
||||
}
|
||||
query, args := b.Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return fmt.Errorf("create foreign keys for %q: %v", t.Name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// changes to apply on existing table.
|
||||
type changes struct {
|
||||
add []*Column
|
||||
modify []*Column
|
||||
indexes []*Index
|
||||
}
|
||||
|
||||
// changeSet returns a changes object to be applied on existing table.
|
||||
// It fails if one of the changes is invalid.
|
||||
func (m *Migrate) changeSet(curr, new *Table) (*changes, error) {
|
||||
change := &changes{}
|
||||
// pks.
|
||||
if len(curr.PrimaryKey) != len(new.PrimaryKey) {
|
||||
return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
|
||||
}
|
||||
sort.Slice(new.PrimaryKey, func(i, j int) bool { return new.PrimaryKey[i].Name < new.PrimaryKey[j].Name })
|
||||
sort.Slice(curr.PrimaryKey, func(i, j int) bool { return curr.PrimaryKey[i].Name < curr.PrimaryKey[j].Name })
|
||||
for i := range curr.PrimaryKey {
|
||||
if curr.PrimaryKey[i].Name != new.PrimaryKey[i].Name {
|
||||
return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
|
||||
}
|
||||
}
|
||||
// columns.
|
||||
for _, c1 := range new.Columns {
|
||||
switch c2, ok := curr.column(c1.Name); {
|
||||
case !ok:
|
||||
change.add = append(change.add, c1)
|
||||
case c1.Unique != c2.Unique:
|
||||
return nil, fmt.Errorf("changing column cardinality for %q is invalid", c1.Name)
|
||||
case m.cType(c1) != m.cType(c2):
|
||||
if !c2.ConvertibleTo(c1) {
|
||||
return nil, fmt.Errorf("changing column type for %q is invalid (%s != %s)", c1.Name, m.cType(c1), m.cType(c2))
|
||||
}
|
||||
fallthrough
|
||||
case c1.Charset != "" && c1.Charset != c2.Charset || c1.Collation != "" && c1.Charset != c2.Collation:
|
||||
change.modify = append(change.modify, c1)
|
||||
}
|
||||
}
|
||||
// indexes.
|
||||
for _, idx1 := range new.Indexes {
|
||||
switch idx2, ok := curr.index(idx1.Name); {
|
||||
case !ok:
|
||||
change.indexes = append(change.indexes, idx1)
|
||||
case idx1.Unique != idx2.Unique:
|
||||
return nil, fmt.Errorf("changing index %q uniqness is invalid", idx1.Name)
|
||||
}
|
||||
}
|
||||
return change, nil
|
||||
}
|
||||
|
||||
// types loads the type list from the database.
|
||||
// If the table does not create, it will create one.
|
||||
func (m *Migrate) types(ctx context.Context, tx dialect.Tx) error {
|
||||
exists, err := m.tableExist(ctx, tx, TypeTable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
t := NewTable(TypeTable).
|
||||
AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}).
|
||||
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true})
|
||||
query, args := m.tBuilder(t).Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return fmt.Errorf("create types table: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Select("type").From(sql.Table(TypeTable)).OrderBy(sql.Asc("id")).Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return fmt.Errorf("query types table: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, &m.typeRanges)
|
||||
}
|
||||
|
||||
func (m *Migrate) allocPKRange(ctx context.Context, tx dialect.Tx, t *Table) error {
|
||||
id := -1
|
||||
// if the table re-created, re-use its range from the past.
|
||||
for i, name := range m.typeRanges {
|
||||
if name == t.Name {
|
||||
id = i
|
||||
break
|
||||
}
|
||||
}
|
||||
// allocate a new id-range.
|
||||
if id == -1 {
|
||||
if len(m.typeRanges) > MaxTypes {
|
||||
return fmt.Errorf("max number of types exceeded: %d", MaxTypes)
|
||||
}
|
||||
query, args := sql.Insert(TypeTable).Columns("type").Values(t.Name).Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return fmt.Errorf("insert into type: %v", err)
|
||||
}
|
||||
id = len(m.typeRanges)
|
||||
m.typeRanges = append(m.typeRanges, t.Name)
|
||||
}
|
||||
// set the id offset for table.
|
||||
return m.setRange(ctx, tx, t.Name, id<<32)
|
||||
}
|
||||
|
||||
// symbol makes sure the symbol length is not longer than the maxlength in MySQL standard (64).
|
||||
func symbol(name string) string {
|
||||
if len(name) <= 64 {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("%s_%x", name[:31], md5.Sum([]byte(name)))
|
||||
}
|
||||
|
||||
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
|
||||
func rollback(tx dialect.Tx, err error) error {
|
||||
err = fmt.Errorf("sql/schema: %v", err)
|
||||
if rerr := tx.Rollback(); rerr != nil {
|
||||
err = fmt.Errorf("%s: %v", err.Error(), rerr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// exist checks if the given COUNT query returns a value >= 1.
|
||||
func exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return false, fmt.Errorf("reading schema information %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, fmt.Errorf("no rows returned")
|
||||
}
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
return false, fmt.Errorf("scanning count")
|
||||
}
|
||||
return n > 0, nil
|
||||
}
|
||||
|
||||
type sqlDialect interface {
|
||||
dialect.Driver
|
||||
init(context.Context, dialect.Tx) error
|
||||
table(context.Context, dialect.Tx, string) (*Table, error)
|
||||
tableExist(context.Context, dialect.Tx, string) (bool, error)
|
||||
fkExist(context.Context, dialect.Tx, string) (bool, error)
|
||||
setRange(context.Context, dialect.Tx, string, int) error
|
||||
// table and column builder per dialect.
|
||||
cType(*Column) string
|
||||
tBuilder(*Table) *sql.TableBuilder
|
||||
cBuilder(*Column) *sql.ColumnBuilder
|
||||
}
|
||||
@@ -2,9 +2,7 @@ package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"fbc/ent/dialect"
|
||||
"fbc/ent/dialect/sql"
|
||||
@@ -13,140 +11,37 @@ import (
|
||||
// MySQL is a mysql migration driver.
|
||||
type MySQL struct {
|
||||
dialect.Driver
|
||||
version string
|
||||
}
|
||||
|
||||
// Create creates all schema resources in the database. It works in an "append-only"
|
||||
// mode, which means, it only create tables, append column to tables or modifying column type.
|
||||
//
|
||||
// Column can be modified by turning into a NULL from NOT NULL, or having a type conversion not
|
||||
// resulting data altering. From example, changing varchar(255) to varchar(120) is invalid, but
|
||||
// changing varchar(120) to varchar(255) is valid. For more info, see the convert function below.
|
||||
func (d *MySQL) Create(ctx context.Context, tables ...*Table) error {
|
||||
tx, err := d.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.create(ctx, tx, tables...); err != nil {
|
||||
return rollback(tx, fmt.Errorf("dialect/mysql: %v", err))
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (d *MySQL) create(ctx context.Context, tx dialect.Tx, tables ...*Table) error {
|
||||
version, err := d.version(ctx, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, t := range tables {
|
||||
switch exist, err := d.tableExist(ctx, tx, t.Name); {
|
||||
case err != nil:
|
||||
return err
|
||||
case exist:
|
||||
curr, err := d.table(ctx, tx, t.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
change, err := changeSet(curr, t, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(change.add) != 0 || len(change.modify) != 0 {
|
||||
b := sql.AlterTable(curr.Name)
|
||||
for _, c := range change.add {
|
||||
b.AddColumn(c.MySQL(version))
|
||||
}
|
||||
for _, c := range change.modify {
|
||||
b.ModifyColumn(c.MySQL(version))
|
||||
}
|
||||
query, args := b.Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return fmt.Errorf("alter table %q: %v", t.Name, err)
|
||||
}
|
||||
}
|
||||
if len(change.indexes) > 0 {
|
||||
panic("missing implementation")
|
||||
}
|
||||
default: // !exist
|
||||
query, args := t.MySQL(version).Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return fmt.Errorf("create table %q: %v", t.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
// create foreign keys after tables were created/altered,
|
||||
// because circular foreign-key constraints are possible.
|
||||
for _, t := range tables {
|
||||
if len(t.ForeignKeys) == 0 {
|
||||
continue
|
||||
}
|
||||
fks := make([]*ForeignKey, 0, len(t.ForeignKeys))
|
||||
for _, fk := range t.ForeignKeys {
|
||||
fk.Symbol = symbol(fk.Symbol)
|
||||
exist, err := d.fkExist(ctx, tx, fk.Symbol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exist {
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
}
|
||||
if len(fks) == 0 {
|
||||
continue
|
||||
}
|
||||
b := sql.AlterTable(t.Name)
|
||||
for _, fk := range fks {
|
||||
b.AddForeignKey(fk.DSL())
|
||||
}
|
||||
query, args := b.Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return fmt.Errorf("create foreign keys for %q: %v", t.Name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *MySQL) version(ctx context.Context, tx dialect.Tx) (string, error) {
|
||||
// init loads the MySQL version from the database for later use in the migration process.
|
||||
func (d *MySQL) init(ctx context.Context, tx dialect.Tx) error {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, "SHOW VARIABLES LIKE 'version'", []interface{}{}, rows); err != nil {
|
||||
return "", fmt.Errorf("dialect/mysql: querying mysql version %v", err)
|
||||
return fmt.Errorf("mysql: querying mysql version %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return "", fmt.Errorf("dialect/mysql: version variable was not found")
|
||||
return fmt.Errorf("mysql: version variable was not found")
|
||||
}
|
||||
version := make([]string, 2)
|
||||
if err := rows.Scan(&version[0], &version[1]); err != nil {
|
||||
return "", fmt.Errorf("dialect/mysql: scanning mysql version: %v", err)
|
||||
return fmt.Errorf("mysql: scanning mysql version: %v", err)
|
||||
}
|
||||
return version[1], nil
|
||||
d.version = version[1]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
query, args := sql.Select(sql.Count("*")).From(sql.Table("INFORMATION_SCHEMA.TABLES").Unquote()).
|
||||
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).Query()
|
||||
return d.exist(ctx, tx, query, args...)
|
||||
return exist(ctx, tx, query, args...)
|
||||
}
|
||||
|
||||
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
query, args := sql.Select(sql.Count("*")).From(sql.Table("INFORMATION_SCHEMA.TABLE_CONSTRAINTS").Unquote()).
|
||||
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("CONSTRAINT_TYPE", "FOREIGN KEY").And().EQ("CONSTRAINT_NAME", name)).Query()
|
||||
return d.exist(ctx, tx, query, args...)
|
||||
}
|
||||
|
||||
func (d *MySQL) exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return false, fmt.Errorf("dialect/mysql: reading schema information %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, fmt.Errorf("dialect/mysql: no rows returned")
|
||||
}
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
return false, fmt.Errorf("dialect/mysql: scanning count")
|
||||
}
|
||||
return n > 0, nil
|
||||
return exist(ctx, tx, query, args...)
|
||||
}
|
||||
|
||||
// table loads the current table description from the database.
|
||||
@@ -156,14 +51,14 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
|
||||
From(sql.Table("INFORMATION_SCHEMA.COLUMNS").Unquote()).
|
||||
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("dialect/mysql: reading table description %v", err)
|
||||
return nil, fmt.Errorf("mysql: reading table description %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
t := &Table{Name: name}
|
||||
for rows.Next() {
|
||||
c := &Column{}
|
||||
if err := c.ScanMySQL(rows); err != nil {
|
||||
return nil, fmt.Errorf("dialect/mysql: %v", err)
|
||||
return nil, fmt.Errorf("mysql: %v", err)
|
||||
}
|
||||
if c.PrimaryKey() {
|
||||
t.PrimaryKey = append(t.PrimaryKey, c)
|
||||
@@ -173,68 +68,10 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// changes to apply on existing table.
|
||||
type changes struct {
|
||||
add []*Column
|
||||
modify []*Column
|
||||
indexes []*Index
|
||||
func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, name string, value int) error {
|
||||
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", name, value), []interface{}{}, new(sql.Result))
|
||||
}
|
||||
|
||||
// changeSet returns a changes object to be applied on existing table.
|
||||
// It fails if one of the changes is invalid.
|
||||
func changeSet(curr, new *Table, version string) (*changes, error) {
|
||||
change := &changes{}
|
||||
// pks.
|
||||
if len(curr.PrimaryKey) != len(new.PrimaryKey) {
|
||||
return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
|
||||
}
|
||||
sort.Slice(new.PrimaryKey, func(i, j int) bool { return new.PrimaryKey[i].Name < new.PrimaryKey[j].Name })
|
||||
sort.Slice(curr.PrimaryKey, func(i, j int) bool { return curr.PrimaryKey[i].Name < curr.PrimaryKey[j].Name })
|
||||
for i := range curr.PrimaryKey {
|
||||
if curr.PrimaryKey[i].Name != new.PrimaryKey[i].Name {
|
||||
return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
|
||||
}
|
||||
}
|
||||
// columns.
|
||||
for _, c1 := range new.Columns {
|
||||
switch c2, ok := curr.column(c1.Name); {
|
||||
case !ok:
|
||||
change.add = append(change.add, c1)
|
||||
case c1.Unique != c2.Unique:
|
||||
return nil, fmt.Errorf("changing column cardinality for %q is invalid", c1.Name)
|
||||
case c1.MySQLType(version) != c2.MySQLType(version):
|
||||
if !c2.ConvertibleTo(c1) {
|
||||
return nil, fmt.Errorf("changing column type for %q is invalid (%s != %s)", c1.Name, c1.MySQLType(version), c2.MySQLType(version))
|
||||
}
|
||||
fallthrough
|
||||
case c1.Charset != "" && c1.Charset != c2.Charset || c1.Collation != "" && c1.Charset != c2.Collation:
|
||||
change.modify = append(change.modify, c1)
|
||||
}
|
||||
}
|
||||
// indexes.
|
||||
for _, idx1 := range new.Indexes {
|
||||
switch idx2, ok := curr.index(idx1.Name); {
|
||||
case !ok:
|
||||
change.indexes = append(change.indexes, idx1)
|
||||
case idx1.Unique != idx2.Unique:
|
||||
return nil, fmt.Errorf("changing index %q uniqness is invalid", idx1.Name)
|
||||
}
|
||||
}
|
||||
return change, nil
|
||||
}
|
||||
|
||||
// symbol makes sure the symbol length is not longer than the maxlength in MySQL standard (64).
|
||||
func symbol(name string) string {
|
||||
if len(name) <= 64 {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("%s_%x", name[:31], md5.Sum([]byte(name)))
|
||||
}
|
||||
|
||||
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
|
||||
func rollback(tx dialect.Tx, err error) error {
|
||||
if rerr := tx.Rollback(); rerr != nil {
|
||||
err = fmt.Errorf("%s: %v", err.Error(), rerr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
func (d *MySQL) cType(c *Column) string { return c.MySQLType(d.version) }
|
||||
func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder { return t.MySQL(d.version) }
|
||||
func (d *MySQL) cBuilder(c *Column) *sql.ColumnBuilder { return c.MySQL(d.version) }
|
||||
|
||||
@@ -18,6 +18,7 @@ func TestMySQL_Create(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tables []*Table
|
||||
options []MigrateOption
|
||||
before func(sqlmock.Sqlmock)
|
||||
wantErr bool
|
||||
}{
|
||||
@@ -261,14 +262,135 @@ func TestMySQL_Create(t *testing.T) {
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "universal id for all tables",
|
||||
tables: []*Table{
|
||||
NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
},
|
||||
options: []MigrateOption{WithGlobalUniqueID(true)},
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("ent_types").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
// create ent_types table.
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `ent_types`(`id` bigint AUTO_INCREMENT, `type` varchar(255) UNIQUE, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set users id range.
|
||||
mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")).
|
||||
WithArgs("users").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 0")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `groups`(`id` bigint AUTO_INCREMENT, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set groups id range.
|
||||
mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")).
|
||||
WithArgs("groups").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(escape("ALTER TABLE `groups` AUTO_INCREMENT = 4294967296")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "universal id for new tables",
|
||||
tables: []*Table{
|
||||
NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
},
|
||||
options: []MigrateOption{WithGlobalUniqueID(true)},
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("ent_types").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
// query ent_types table.
|
||||
mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
// users table has no changes.
|
||||
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
|
||||
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `groups`(`id` bigint AUTO_INCREMENT, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set groups id range.
|
||||
mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")).
|
||||
WithArgs("groups").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(escape("ALTER TABLE `groups` AUTO_INCREMENT = 4294967296")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "universal id for restored tables",
|
||||
tables: []*Table{
|
||||
NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
},
|
||||
options: []MigrateOption{WithGlobalUniqueID(true)},
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("ent_types").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
// query ent_types table.
|
||||
mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set users id range (without inserting to ent_types).
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 0")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `groups`(`id` bigint AUTO_INCREMENT, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set groups id range.
|
||||
mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")).
|
||||
WithArgs("groups").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(escape("ALTER TABLE `groups` AUTO_INCREMENT = 4294967296")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
tt.before(mock)
|
||||
mysql := &MySQL{sql.OpenDB("mysql", db)}
|
||||
err = mysql.Create(context.Background(), tt.tables...)
|
||||
migrate, err := NewMigrate(sql.OpenDB("mysql", db), tt.options...)
|
||||
require.NoError(t, err)
|
||||
err = migrate.Create(context.Background(), tt.tables...)
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -34,6 +34,12 @@ func (t *Table) AddForeignKey(fk *ForeignKey) *Table {
|
||||
return t
|
||||
}
|
||||
|
||||
// AddColumn adds a new column to the table.
|
||||
func (t *Table) AddColumn(c *Column) *Table {
|
||||
t.Columns = append(t.Columns, c)
|
||||
return t
|
||||
}
|
||||
|
||||
// MySQL returns the MySQL DSL query for table creation.
|
||||
func (t *Table) MySQL(version string) *sql.TableBuilder {
|
||||
b := sql.CreateTable(t.Name).IfNotExists()
|
||||
|
||||
@@ -13,40 +13,16 @@ type SQLite struct {
|
||||
dialect.Driver
|
||||
}
|
||||
|
||||
// Create creates all tables resources in the database.
|
||||
func (d *SQLite) Create(ctx context.Context, tables ...*Table) error {
|
||||
tx, err := d.Tx(ctx)
|
||||
// init makes sure that foreign_keys support is enabled.
|
||||
func (d *SQLite) init(ctx context.Context, tx dialect.Tx) error {
|
||||
on, err := exist(ctx, tx, "PRAGMA foreign_keys")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.create(ctx, tx, tables...); err != nil {
|
||||
return rollback(tx, fmt.Errorf("dialect/sqlite: %v", err))
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (d *SQLite) create(ctx context.Context, tx dialect.Tx, tables ...*Table) error {
|
||||
on, err := d.fkEnabled(ctx, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check foreign_keys pragma: %v", err)
|
||||
return fmt.Errorf("sqlite: check foreign_keys pragma: %v", err)
|
||||
}
|
||||
if !on {
|
||||
// foreign_keys pragma is off, either enable it by execute "PRAGMA foreign_keys=ON"
|
||||
// or add the following parameter in the connection string "_fk=1".
|
||||
return fmt.Errorf("foreign_keys pragma is off: missing %q is the connection string", "_fk=1")
|
||||
}
|
||||
for _, t := range tables {
|
||||
exist, err := d.tableExist(ctx, tx, t.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exist {
|
||||
continue
|
||||
}
|
||||
query, args := t.SQLite().Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return fmt.Errorf("create table %q: %v", t.Name, err)
|
||||
}
|
||||
return fmt.Errorf("sqlite: foreign_keys pragma is off: missing %q is the connection string", "_fk=1")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -56,25 +32,35 @@ func (d *SQLite) tableExist(ctx context.Context, tx dialect.Tx, name string) (bo
|
||||
From(sql.Table("sqlite_master")).
|
||||
Where(sql.EQ("type", "table").And().EQ("name", name)).
|
||||
Query()
|
||||
return d.exist(ctx, tx, query, args...)
|
||||
return exist(ctx, tx, query, args...)
|
||||
}
|
||||
|
||||
func (d *SQLite) fkEnabled(ctx context.Context, tx dialect.Tx) (bool, error) {
|
||||
return d.exist(ctx, tx, "PRAGMA foreign_keys")
|
||||
// setRange sets the start value of table PK.
|
||||
// SQLite tracks the AUTOINCREMENT in the "sqlite_sequence" table that is created and initialized automatically
|
||||
// whenever a table that contains an AUTOINCREMENT column is created. However, it populates to it a rows (for tables)
|
||||
// only after the first insertion. Therefore, we check. If a record (for the given table) already exists in the "sqlite_sequence"
|
||||
// table, we updated it. Otherwise, we insert a new value.
|
||||
func (d *SQLite) setRange(ctx context.Context, tx dialect.Tx, name string, value int) error {
|
||||
query, args := sql.Select().Count().
|
||||
From(sql.Table("sqlite_sequence")).
|
||||
Where(sql.EQ("name", name)).
|
||||
Query()
|
||||
exists, err := exist(ctx, tx, query, args...)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
case exists:
|
||||
query, args = sql.Update("sqlite_sequence").Set("seq", value).Where(sql.EQ("name", name)).Query()
|
||||
default: // !exists
|
||||
query, args = sql.Insert("sqlite_sequence").Columns("name", "seq").Values(name, value).Query()
|
||||
}
|
||||
return tx.Exec(ctx, query, args, new(sql.Result))
|
||||
}
|
||||
|
||||
func (d *SQLite) exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return false, fmt.Errorf("reading schema information %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, fmt.Errorf("no rows returned")
|
||||
}
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
return false, fmt.Errorf("scanning count")
|
||||
}
|
||||
return n > 0, nil
|
||||
}
|
||||
func (d *SQLite) cType(c *Column) string { return c.SQLiteType() }
|
||||
func (d *SQLite) tBuilder(t *Table) *sql.TableBuilder { return t.SQLite() }
|
||||
func (d *SQLite) cBuilder(c *Column) *sql.ColumnBuilder { return c.SQLite() }
|
||||
|
||||
// fkExist returns always tru to disable foreign-keys creation after the table was created.
|
||||
func (d *SQLite) fkExist(context.Context, dialect.Tx, string) (bool, error) { return true, nil }
|
||||
func (d *SQLite) table(context.Context, dialect.Tx, string) (*Table, error) { return nil, nil }
|
||||
|
||||
@@ -16,6 +16,7 @@ func TestSQLite_Create(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tables []*Table
|
||||
options []MigrateOption
|
||||
before func(sqlmock.Sqlmock)
|
||||
wantErr bool
|
||||
}{
|
||||
@@ -125,14 +126,112 @@ func TestSQLite_Create(t *testing.T) {
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "universal id for all tables",
|
||||
tables: []*Table{
|
||||
NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
},
|
||||
options: []MigrateOption{WithGlobalUniqueID(true)},
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery("PRAGMA foreign_keys").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(1))
|
||||
// creating ent_types table.
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_master` WHERE `type` = ? AND `name` = ?")).
|
||||
WithArgs("table", "ent_types").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE `ent_types`(`id` integer PRIMARY KEY AUTOINCREMENT, `type` varchar(255) UNIQUE)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_master` WHERE `type` = ? AND `name` = ?")).
|
||||
WithArgs("table", "users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set users id range.
|
||||
mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")).
|
||||
WithArgs("users").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
|
||||
WithArgs("users", 0).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_master` WHERE `type` = ? AND `name` = ?")).
|
||||
WithArgs("table", "groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE `groups`(`id` integer PRIMARY KEY AUTOINCREMENT)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set groups id range.
|
||||
mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")).
|
||||
WithArgs("groups").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")).
|
||||
WithArgs("groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
|
||||
WithArgs("groups", 1<<32).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "universal id for restored tables",
|
||||
tables: []*Table{
|
||||
NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}),
|
||||
},
|
||||
options: []MigrateOption{WithGlobalUniqueID(true)},
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery("PRAGMA foreign_keys").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(1))
|
||||
// query ent_types table.
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_master` WHERE `type` = ? AND `name` = ?")).
|
||||
WithArgs("table", "ent_types").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_master` WHERE `type` = ? AND `name` = ?")).
|
||||
WithArgs("table", "users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set users id range (without inserting to ent_types).
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
mock.ExpectExec(escape("UPDATE `sqlite_sequence` SET `seq` = ? WHERE `name` = ?")).
|
||||
WithArgs(0, "users").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_master` WHERE `type` = ? AND `name` = ?")).
|
||||
WithArgs("table", "groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE `groups`(`id` integer PRIMARY KEY AUTOINCREMENT)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// set groups id range.
|
||||
mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")).
|
||||
WithArgs("groups").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")).
|
||||
WithArgs("groups").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
|
||||
WithArgs("groups", 1<<32).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
tt.before(mock)
|
||||
sqlite := &SQLite{sql.OpenDB("sqlite", db)}
|
||||
err = sqlite.Create(context.Background(), tt.tables...)
|
||||
migrate, err := NewMigrate(sql.OpenDB("sqlite3", db), tt.options...)
|
||||
require.NoError(t, err)
|
||||
err = migrate.Create(context.Background(), tt.tables...)
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user