mirror of
https://github.com/ent/ent.git
synced 2026-05-04 00:20:58 +03:00
292 lines
10 KiB
Go
292 lines
10 KiB
Go
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package schema
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"text/template"
|
|
"time"
|
|
|
|
"ariga.io/atlas/sql/migrate"
|
|
"ariga.io/atlas/sql/schema"
|
|
"entgo.io/ent/schema/field"
|
|
|
|
"entgo.io/ent/dialect"
|
|
"entgo.io/ent/dialect/sql"
|
|
"github.com/DATA-DOG/go-sqlmock"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
func TestMigrateHookOmitTable(t *testing.T) {
|
|
db, mk, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
|
|
tables := []*Table{{Name: "users"}, {Name: "pets"}}
|
|
mock := mysqlMock{mk}
|
|
mock.start("5.7.23")
|
|
mock.tableExists("pets", false)
|
|
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")).
|
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
|
mock.ExpectCommit()
|
|
|
|
m, err := NewMigrate(sql.OpenDB("mysql", db), WithHooks(func(next Creator) Creator {
|
|
return CreateFunc(func(ctx context.Context, tables ...*Table) error {
|
|
return next.Create(ctx, tables[1])
|
|
})
|
|
}))
|
|
require.NoError(t, err)
|
|
err = m.Create(context.Background(), tables...)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestMigrateHookAddTable(t *testing.T) {
|
|
db, mk, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
|
|
tables := []*Table{{Name: "users"}}
|
|
mock := mysqlMock{mk}
|
|
mock.start("5.7.23")
|
|
mock.tableExists("users", false)
|
|
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")).
|
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
|
mock.tableExists("pets", false)
|
|
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")).
|
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
|
mock.ExpectCommit()
|
|
|
|
m, err := NewMigrate(sql.OpenDB("mysql", db), WithHooks(func(next Creator) Creator {
|
|
return CreateFunc(func(ctx context.Context, tables ...*Table) error {
|
|
return next.Create(ctx, tables[0], &Table{Name: "pets"})
|
|
})
|
|
}))
|
|
require.NoError(t, err)
|
|
err = m.Create(context.Background(), tables...)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestMigrate_Diff(t *testing.T) {
|
|
db, err := sql.Open(dialect.SQLite, "file:test?mode=memory&_fk=1")
|
|
require.NoError(t, err)
|
|
|
|
p := t.TempDir()
|
|
d, err := migrate.NewLocalDir(p)
|
|
require.NoError(t, err)
|
|
|
|
m, err := NewMigrate(db, WithDir(d))
|
|
require.NoError(t, err)
|
|
require.NoError(t, m.Diff(context.Background(), &Table{Name: "users"}))
|
|
v := time.Now().UTC().Format("20060102150405")
|
|
requireFileEqual(t, filepath.Join(p, v+"_changes.up.sql"), "-- create \"users\" table\nCREATE TABLE `users` (, PRIMARY KEY ());\n")
|
|
requireFileEqual(t, filepath.Join(p, v+"_changes.down.sql"), "-- reverse: create \"users\" table\nDROP TABLE `users`;\n")
|
|
require.NoFileExists(t, filepath.Join(p, "atlas.sum"))
|
|
|
|
// Test integrity file.
|
|
p = t.TempDir()
|
|
d, err = migrate.NewLocalDir(p)
|
|
require.NoError(t, err)
|
|
m, err = NewMigrate(db, WithDir(d), WithSumFile())
|
|
require.NoError(t, err)
|
|
require.NoError(t, m.Diff(context.Background(), &Table{Name: "users"}))
|
|
requireFileEqual(t, filepath.Join(p, v+"_changes.up.sql"), "-- create \"users\" table\nCREATE TABLE `users` (, PRIMARY KEY ());\n")
|
|
requireFileEqual(t, filepath.Join(p, v+"_changes.down.sql"), "-- reverse: create \"users\" table\nDROP TABLE `users`;\n")
|
|
require.FileExists(t, filepath.Join(p, "atlas.sum"))
|
|
require.NoError(t, d.WriteFile("tmp.sql", nil))
|
|
require.ErrorIs(t, m.Diff(context.Background(), &Table{Name: "users"}), migrate.ErrChecksumMismatch)
|
|
|
|
// Test type store.
|
|
idCol := []*Column{{Name: "id", Type: field.TypeInt, Increment: true}}
|
|
p = t.TempDir()
|
|
d, err = migrate.NewLocalDir(p)
|
|
require.NoError(t, err)
|
|
f, err := migrate.NewTemplateFormatter(
|
|
template.Must(template.New("").Parse("{{ .Name }}.sql")),
|
|
template.Must(template.New("").Parse(
|
|
`{{ range .Changes }}{{ printf "%s;\n" .Cmd }}{{ end }}`,
|
|
)),
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
// If using global unique ID and versioned migrations,
|
|
// consent for the file based type store has to be given explicitly.
|
|
_, err = NewMigrate(db, WithDir(d), WithGlobalUniqueID(true))
|
|
require.ErrorIs(t, err, errConsent)
|
|
require.Contains(t, err.Error(), "WithUniversalID")
|
|
require.Contains(t, err.Error(), "WithGlobalUniqueID")
|
|
require.Contains(t, err.Error(), "WithDir")
|
|
|
|
m, err = NewMigrate(db, WithFormatter(f), WithDir(d), WithUniversalID(), WithSumFile())
|
|
require.NoError(t, err)
|
|
require.IsType(t, &dirTypeStore{}, m.typeStore)
|
|
require.NoError(t, m.Diff(context.Background(),
|
|
&Table{Name: "users", Columns: idCol, PrimaryKey: idCol},
|
|
&Table{Name: "groups", Columns: idCol, PrimaryKey: idCol, Indexes: []*Index{{Name: "short", Columns: idCol}, {Name: "long_" + strings.Repeat("_", 60), Columns: idCol}}},
|
|
))
|
|
requireFileEqual(t, filepath.Join(p, ".ent_types"), atlasDirective+"users,groups")
|
|
changesSQL := strings.Join([]string{
|
|
"CREATE TABLE `users` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT);",
|
|
"CREATE TABLE `groups` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT);",
|
|
fmt.Sprintf("INSERT INTO sqlite_sequence (name, seq) VALUES (\"groups\", %d);", 1<<32),
|
|
"CREATE INDEX `short` ON `groups` (`id`);",
|
|
"CREATE INDEX `long____________________________1cb2e7e47a309191385af4ad320875b1` ON `groups` (`id`);",
|
|
"CREATE TABLE `ent_types` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `type` text NOT NULL);",
|
|
"CREATE UNIQUE INDEX `ent_types_type_key` ON `ent_types` (`type`);",
|
|
"INSERT INTO `ent_types` (`type`) VALUES ('users'), ('groups');", "",
|
|
}, "\n")
|
|
requireFileEqual(t, filepath.Join(p, "changes.sql"), changesSQL)
|
|
|
|
// types file cannot be part of the sum file.
|
|
require.FileExists(t, filepath.Join(p, "atlas.sum"))
|
|
sum, err := os.ReadFile(filepath.Join(p, "atlas.sum"))
|
|
require.NoError(t, err)
|
|
require.NotContains(t, string(sum), ".ent_types")
|
|
|
|
// Adding another node will result in a new entry to the TypeTable (without actually creating it).
|
|
_, err = db.ExecContext(context.Background(), changesSQL, nil, nil)
|
|
require.NoError(t, err)
|
|
require.NoError(t, m.NamedDiff(context.Background(), "changes_2", &Table{Name: "pets", Columns: idCol, PrimaryKey: idCol}))
|
|
requireFileEqual(t, filepath.Join(p, ".ent_types"), atlasDirective+"users,groups,pets")
|
|
requireFileEqual(t,
|
|
filepath.Join(p, "changes_2.sql"), strings.Join([]string{
|
|
"CREATE TABLE `pets` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT);",
|
|
fmt.Sprintf("INSERT INTO sqlite_sequence (name, seq) VALUES (\"pets\", %d);", 2<<32),
|
|
"INSERT INTO `ent_types` (`type`) VALUES ('pets');", "",
|
|
}, "\n"))
|
|
|
|
// types file cannot be part of the sum file.
|
|
require.FileExists(t, filepath.Join(p, "atlas.sum"))
|
|
sum, err = os.ReadFile(filepath.Join(p, "atlas.sum"))
|
|
require.NoError(t, err)
|
|
require.NotContains(t, string(sum), ".ent_types")
|
|
|
|
// Checksum will be updated as well.
|
|
require.NoError(t, migrate.Validate(d))
|
|
|
|
// Running diff against an existing database without having a types file yet
|
|
// will result in the types file respect the "old" order of pk allocations.
|
|
switchAllocs := func(one, two string) {
|
|
for _, stmt := range []string{
|
|
"DELETE FROM `ent_types`;",
|
|
fmt.Sprintf("INSERT INTO `ent_types` (`type`) VALUES ('%s'), ('%s');", one, two),
|
|
} {
|
|
_, err = db.ExecContext(context.Background(), stmt)
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
switchAllocs("groups", "users")
|
|
p = t.TempDir()
|
|
d, err = migrate.NewLocalDir(p)
|
|
require.NoError(t, err)
|
|
m, err = NewMigrate(db, WithFormatter(f), WithDir(d), WithUniversalID())
|
|
require.NoError(t, err)
|
|
|
|
require.NoError(t, m.Diff(context.Background(),
|
|
&Table{Name: "users", Columns: idCol, PrimaryKey: idCol},
|
|
&Table{Name: "groups", Columns: idCol, PrimaryKey: idCol},
|
|
))
|
|
requireFileEqual(t, filepath.Join(p, ".ent_types"), atlasDirective+"groups,users")
|
|
require.NoFileExists(t, filepath.Join(p, "changes.sql"))
|
|
|
|
// Drifts in the types file and types database will be detected,
|
|
switchAllocs("users", "groups")
|
|
require.ErrorContains(t, m.Diff(context.Background()), fmt.Sprintf(
|
|
"type allocation range drift detected: %v <> %v: see %s for more information",
|
|
[]string{"users", "groups"},
|
|
[]string{"groups", "users"},
|
|
"https://entgo.io/docs/versioned-migrations#moving-from-auto-migration-to-versioned-migrations",
|
|
))
|
|
}
|
|
|
|
func requireFileEqual(t *testing.T, name, contents string) {
|
|
c, err := os.ReadFile(name)
|
|
require.NoError(t, err)
|
|
require.Equal(t, contents, string(c))
|
|
}
|
|
|
|
func TestMigrateWithoutForeignKeys(t *testing.T) {
|
|
tbl := &schema.Table{
|
|
Name: "tbl",
|
|
Columns: []*schema.Column{
|
|
{Name: "id", Type: &schema.ColumnType{Type: &schema.IntegerType{T: "bigint"}}},
|
|
},
|
|
}
|
|
fk := &schema.ForeignKey{
|
|
Symbol: "fk",
|
|
Table: tbl,
|
|
Columns: tbl.Columns[1:],
|
|
RefTable: tbl,
|
|
RefColumns: tbl.Columns[:1],
|
|
OnUpdate: schema.NoAction,
|
|
OnDelete: schema.Cascade,
|
|
}
|
|
tbl.ForeignKeys = append(tbl.ForeignKeys, fk)
|
|
t.Run("AddTable", func(t *testing.T) {
|
|
mdiff := DiffFunc(func(_, _ *schema.Schema) ([]schema.Change, error) {
|
|
return []schema.Change{
|
|
&schema.AddTable{
|
|
T: tbl,
|
|
},
|
|
}, nil
|
|
})
|
|
df, err := withoutForeignKeys(mdiff).Diff(nil, nil)
|
|
require.NoError(t, err)
|
|
require.Len(t, df, 1)
|
|
actual, ok := df[0].(*schema.AddTable)
|
|
require.True(t, ok)
|
|
require.Nil(t, actual.T.ForeignKeys)
|
|
})
|
|
t.Run("ModifyTable", func(t *testing.T) {
|
|
mdiff := DiffFunc(func(_, _ *schema.Schema) ([]schema.Change, error) {
|
|
return []schema.Change{
|
|
&schema.ModifyTable{
|
|
T: tbl,
|
|
Changes: []schema.Change{
|
|
&schema.AddIndex{
|
|
I: &schema.Index{
|
|
Name: "id_key",
|
|
Parts: []*schema.IndexPart{
|
|
{C: tbl.Columns[0]},
|
|
},
|
|
},
|
|
},
|
|
&schema.DropForeignKey{
|
|
F: fk,
|
|
},
|
|
&schema.AddForeignKey{
|
|
F: fk,
|
|
},
|
|
&schema.ModifyForeignKey{
|
|
From: fk,
|
|
To: fk,
|
|
Change: schema.ChangeRefColumn,
|
|
},
|
|
&schema.AddColumn{
|
|
C: &schema.Column{Name: "name", Type: &schema.ColumnType{Type: &schema.StringType{T: "varchar(255)"}}},
|
|
},
|
|
},
|
|
},
|
|
}, nil
|
|
})
|
|
df, err := withoutForeignKeys(mdiff).Diff(nil, nil)
|
|
require.NoError(t, err)
|
|
require.Len(t, df, 1)
|
|
actual, ok := df[0].(*schema.ModifyTable)
|
|
require.True(t, ok)
|
|
require.Len(t, actual.Changes, 2)
|
|
addIndex, ok := actual.Changes[0].(*schema.AddIndex)
|
|
require.True(t, ok)
|
|
require.EqualValues(t, "id_key", addIndex.I.Name)
|
|
addColumn, ok := actual.Changes[1].(*schema.AddColumn)
|
|
require.True(t, ok)
|
|
require.EqualValues(t, "name", addColumn.C.Name)
|
|
})
|
|
}
|