Files
ent/entc/integration/migrate/migrate_test.go
Ariel Mashraki 35a098fdbb dialect/sql/schema: fix bug in atlas integration when working WithDropIndex/WithDropColumn (#2374)
* dialect/sql/schema: fix no change condition in atlas

* dialect/sql/schema: fix bug in atlas integration when working WithDropIndex/WithDropColumn

Co-authored-by: Zeev Manilovich <zeevmanilovich@gmail.com>
2022-03-03 21:52:44 +02:00

387 lines
15 KiB
Go

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package migrate
import (
"context"
"fmt"
"math"
"strconv"
"strings"
"testing"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/schema"
"entgo.io/ent/entc/integration/migrate/entv1"
migratev1 "entgo.io/ent/entc/integration/migrate/entv1/migrate"
userv1 "entgo.io/ent/entc/integration/migrate/entv1/user"
"entgo.io/ent/entc/integration/migrate/entv2"
"entgo.io/ent/entc/integration/migrate/entv2/conversion"
migratev2 "entgo.io/ent/entc/integration/migrate/entv2/migrate"
"entgo.io/ent/entc/integration/migrate/entv2/user"
"ariga.io/atlas/sql/migrate"
atlas "ariga.io/atlas/sql/schema"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
)
func TestMySQL(t *testing.T) {
for version, port := range map[string]int{"56": 3306, "57": 3307, "8": 3308} {
t.Run(version, func(t *testing.T) {
root, err := sql.Open(dialect.MySQL, fmt.Sprintf("root:pass@tcp(localhost:%d)/", port))
require.NoError(t, err)
defer root.Close()
ctx := context.Background()
err = root.Exec(ctx, "CREATE DATABASE IF NOT EXISTS migrate", []interface{}{}, new(sql.Result))
require.NoError(t, err, "creating database")
defer root.Exec(ctx, "DROP DATABASE IF EXISTS migrate", []interface{}{}, new(sql.Result))
drv, err := sql.Open("mysql", fmt.Sprintf("root:pass@tcp(localhost:%d)/migrate?parseTime=True", port))
require.NoError(t, err, "connecting to migrate database")
clientv1 := entv1.NewClient(entv1.Driver(drv))
clientv2 := entv2.NewClient(entv2.Driver(drv))
V1ToV2(t, drv.Dialect(), clientv1, clientv2)
if version == "8" {
CheckConstraint(t, clientv2)
}
NicknameSearch(t, clientv2)
})
}
}
func TestPostgres(t *testing.T) {
for version, port := range map[string]int{"10": 5430, "11": 5431, "12": 5432, "13": 5433, "14": 5434} {
t.Run(version, func(t *testing.T) {
dsn := fmt.Sprintf("host=localhost port=%d user=postgres password=pass sslmode=disable", port)
root, err := sql.Open(dialect.Postgres, dsn)
require.NoError(t, err)
defer root.Close()
ctx := context.Background()
err = root.Exec(ctx, "DROP DATABASE IF EXISTS migrate", []interface{}{}, new(sql.Result))
require.NoError(t, err)
err = root.Exec(ctx, "CREATE DATABASE migrate", []interface{}{}, new(sql.Result))
require.NoError(t, err, "creating database")
defer root.Exec(ctx, "DROP DATABASE migrate", []interface{}{}, new(sql.Result))
drv, err := sql.Open(dialect.Postgres, dsn+" dbname=migrate")
require.NoError(t, err, "connecting to migrate database")
defer drv.Close()
err = drv.Exec(ctx, "CREATE TYPE customtype as range (subtype = time)", []interface{}{}, new(sql.Result))
require.NoError(t, err, "creating custom type")
clientv1 := entv1.NewClient(entv1.Driver(drv))
clientv2 := entv2.NewClient(entv2.Driver(drv))
V1ToV2(t, drv.Dialect(), clientv1, clientv2)
CheckConstraint(t, clientv2)
})
}
}
func TestSQLite(t *testing.T) {
drv, err := sql.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
require.NoError(t, err)
defer drv.Close()
ctx := context.Background()
client := entv2.NewClient(entv2.Driver(drv))
require.NoError(
t,
client.Schema.Create(
ctx,
migratev2.WithGlobalUniqueID(true),
migratev2.WithDropIndex(true),
migratev2.WithDropColumn(true),
schema.WithDiffHook(func(next schema.Differ) schema.Differ {
return schema.DiffFunc(func(current, desired *atlas.Schema) ([]atlas.Change, error) {
// Example to hook into the diff process.
changes, err := next.Diff(current, desired)
if err != nil {
return nil, err
}
// After diff, you can filter
// changes or return new ones.
return changes, nil
})
}),
schema.WithApplyHook(func(next schema.Applier) schema.Applier {
return schema.ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error {
// Example to hook into the apply process, or implement
// a custom applier. For example, write to a file.
//
// for _, c := range plan.Changes {
// fmt.Printf("%s: %s", c.Comment, c.Cmd)
// if err := conn.Exec(ctx, c.Cmd, c.Args, nil); err != nil {
// return err
// }
// }
//
return next.Apply(ctx, conn, plan)
})
}),
),
)
SanityV2(t, drv.Dialect(), client)
u := client.User.Create().SetAge(1).SetName("x").SetNickname("x'").SetPhone("y").SaveX(ctx)
idRange(t, client.Car.Create().SetOwner(u).SaveX(ctx).ID, 0, 1<<32)
idRange(t, client.Conversion.Create().SaveX(ctx).ID, 1<<32-1, 2<<32)
idRange(t, client.CustomType.Create().SaveX(ctx).ID, 2<<32-1, 3<<32)
idRange(t, client.Group.Create().SaveX(ctx).ID, 3<<32-1, 4<<32)
idRange(t, client.Media.Create().SaveX(ctx).ID, 4<<32-1, 5<<32)
idRange(t, client.Pet.Create().SaveX(ctx).ID, 5<<32-1, 6<<32)
idRange(t, u.ID, 6<<32-1, 7<<32)
// Override the default behavior of LIKE in SQLite.
// https://www.sqlite.org/pragma.html#pragma_case_sensitive_like
_, err = drv.ExecContext(ctx, "PRAGMA case_sensitive_like=1")
require.NoError(t, err)
EqualFold(t, client)
ContainsFold(t, client)
CheckConstraint(t, client)
}
func TestStorageKey(t *testing.T) {
require.Equal(t, "user_pet_id", migratev2.PetsTable.ForeignKeys[0].Symbol)
require.Equal(t, "user_friend_id1", migratev2.FriendsTable.ForeignKeys[0].Symbol)
require.Equal(t, "user_friend_id2", migratev2.FriendsTable.ForeignKeys[1].Symbol)
}
func V1ToV2(t *testing.T, dialect string, clientv1 *entv1.Client, clientv2 *entv2.Client) {
ctx := context.Background()
// Run migration and execute queries on v1.
require.NoError(t, clientv1.Schema.Create(ctx, migratev1.WithGlobalUniqueID(true), schema.WithAtlas(true)))
SanityV1(t, dialect, clientv1)
// Run migration and execute queries on v2.
require.NoError(t, clientv2.Schema.Create(ctx, migratev2.WithGlobalUniqueID(true), migratev2.WithDropIndex(true), migratev2.WithDropColumn(true), schema.WithAtlas(true)))
require.NoError(t, clientv2.Schema.Create(ctx, migratev2.WithGlobalUniqueID(true), migratev2.WithDropIndex(true), migratev2.WithDropColumn(true), schema.WithAtlas(true)), "should not create additional resources on multiple runs")
SanityV2(t, dialect, clientv2)
u := clientv2.User.Create().SetAge(1).SetName("foo").SetNickname("nick_foo").SetPhone("phone").SaveX(ctx)
idRange(t, clientv2.Car.Create().SetOwner(u).SaveX(ctx).ID, 0, 1<<32)
idRange(t, clientv2.Conversion.Create().SaveX(ctx).ID, 1<<32-1, 2<<32)
// Since "users" created in the migration of v1, it will occupy the range of 1<<32-1 ... 2<<32-1,
// even though they are ordered differently in the migration of v2 (groups, pets, users).
idRange(t, u.ID, 3<<32-1, 4<<32)
idRange(t, clientv2.Group.Create().SaveX(ctx).ID, 4<<32-1, 5<<32)
idRange(t, clientv2.Media.Create().SaveX(ctx).ID, 5<<32-1, 6<<32)
idRange(t, clientv2.Pet.Create().SaveX(ctx).ID, 6<<32-1, 7<<32)
// SQL specific predicates.
EqualFold(t, clientv2)
ContainsFold(t, clientv2)
// "renamed" field was renamed to "new_name".
exist := clientv2.User.Query().Where(user.NewName("renamed")).ExistX(ctx)
require.True(t, exist, "expect renamed column to have previous values")
}
func SanityV1(t *testing.T, dbdialect string, client *entv1.Client) {
ctx := context.Background()
u := client.User.Create().SetAge(1).SetName("foo").SetNickname("nick_foo").SetRenamed("renamed").SaveX(ctx)
require.EqualValues(t, 1, u.Age)
require.Equal(t, "foo", u.Name)
err := client.User.Create().SetAge(2).SetName("foobarbazqux").Exec(ctx)
require.Error(t, err, "name is limited to 10 chars")
// Unique index on (name, address).
client.User.Create().SetAge(3).SetName("foo").SetNickname("nick_foo_2").SetAddress("tlv").SetState(userv1.StateLoggedIn).SaveX(ctx)
err = client.User.Create().SetAge(4).SetName("foo").SetAddress("tlv").Exec(ctx)
require.Error(t, err)
// Blob type limited to 255.
u = u.Update().SetBlob([]byte("hello")).SaveX(ctx)
require.Equal(t, "hello", string(u.Blob))
err = u.Update().SetBlob(make([]byte, 256)).Exec(ctx)
require.True(t, strings.Contains(t.Name(), "Postgres") || err != nil, "blob should be limited on SQLite and MySQL")
// Invalid enum value.
err = client.User.Create().SetAge(1).SetName("bar").SetNickname("nick_bar").SetState("unknown").Exec(ctx)
require.Error(t, err)
// Conversions
client.Conversion.Create().
SetName("zero").
SetInt8ToString(0).
SetUint8ToString(0).
SetInt16ToString(0).
SetUint16ToString(0).
SetInt32ToString(0).
SetUint32ToString(0).
SetInt64ToString(0).
SetUint64ToString(0).
SaveX(ctx)
client.Conversion.Create().
SetName("min").
SetInt8ToString(math.MinInt8).
SetUint8ToString(0).
SetInt16ToString(math.MinInt16).
SetUint16ToString(0).
SetInt32ToString(math.MinInt32).
SetUint32ToString(0).
SetInt64ToString(math.MinInt64).
SetUint64ToString(0).
SaveX(ctx)
creator := client.Conversion.Create().
SetName("max").
SetInt8ToString(math.MaxInt8).
SetUint8ToString(math.MaxUint8).
SetInt16ToString(math.MaxInt16).
SetUint16ToString(math.MaxUint16).
SetInt32ToString(math.MaxInt32).
SetUint32ToString(math.MaxUint32).
SetInt64ToString(math.MaxInt64).
SetUint64ToString(math.MaxUint64)
if dbdialect == dialect.Postgres {
// Postgres does not support unsigned types.
creator.SetInt8ToString(math.MaxInt8).
SetUint8ToString(math.MaxInt8).
SetUint16ToString(math.MaxInt16).
SetUint32ToString(math.MaxInt32).
SetUint32ToString(math.MaxInt32).
SetUint64ToString(math.MaxInt64)
}
creator.SaveX(ctx)
}
func SanityV2(t *testing.T, dbdialect string, client *entv2.Client) {
ctx := context.Background()
if dbdialect != dialect.SQLite {
require.True(t, client.User.Query().ExistX(ctx), "table 'users' should contain rows after running the migration")
users := client.User.Query().Select(user.FieldCreatedAt).AllX(ctx)
for i := range users {
require.False(t, users[i].CreatedAt.IsZero(), "default 'CURRENT_TIMESTAMP' should fill previous rows")
}
}
u := client.User.Create().SetAge(1).SetName("bar").SetNickname("nick_bar").SetPhone("100").SetBuffer([]byte("{}")).SetState(user.StateLoggedOut).SaveX(ctx)
require.Equal(t, 1, u.Age)
require.Equal(t, "bar", u.Name)
require.Equal(t, []byte("{}"), u.Buffer)
u = u.Update().SetBuffer([]byte("[]")).SaveX(ctx)
require.Equal(t, []byte("[]"), u.Buffer)
require.Equal(t, user.StateLoggedOut, u.State)
err := u.Update().SetState(user.State("boring")).Exec(ctx)
require.Error(t, err, "invalid enum value")
u = u.Update().SetState(user.StateOnline).SaveX(ctx)
require.Equal(t, user.StateOnline, u.State)
err = client.User.Create().SetAge(1).SetName("foobarbazqux").SetNickname("nick_bar").SetPhone("200").Exec(ctx)
require.NoError(t, err, "name is not limited to 10 chars and nickname is not unique")
// New unique index was added to (age, phone).
err = client.User.Create().SetAge(1).SetName("foo").SetPhone("200").SetNickname("nick_bar").Exec(ctx)
require.Error(t, err)
require.True(t, entv2.IsConstraintError(err))
// Ensure all rows in the database have the same default for the `title` column.
require.Equal(
t,
client.User.Query().CountX(ctx),
client.User.Query().Where(user.Title(user.DefaultTitle)).CountX(ctx),
)
// Blob type was extended.
u, err = u.Update().SetBlob(make([]byte, 256)).SetState(user.StateLoggedOut).Save(ctx)
require.NoError(t, err, "data type blob was extended in v2")
require.Equal(t, make([]byte, 256), u.Blob)
if dbdialect != dialect.SQLite {
// Conversions
zero := client.Conversion.Query().Where(conversion.Name("zero")).OnlyX(ctx)
require.Equal(t, strconv.Itoa(0), zero.Int8ToString)
require.Equal(t, strconv.Itoa(0), zero.Uint8ToString)
require.Equal(t, strconv.Itoa(0), zero.Int16ToString)
require.Equal(t, strconv.Itoa(0), zero.Uint16ToString)
require.Equal(t, strconv.Itoa(0), zero.Int32ToString)
require.Equal(t, strconv.Itoa(0), zero.Uint32ToString)
require.Equal(t, strconv.Itoa(0), zero.Int64ToString)
require.Equal(t, strconv.Itoa(0), zero.Uint64ToString)
min := client.Conversion.Query().Where(conversion.Name("min")).OnlyX(ctx)
require.Equal(t, strconv.Itoa(math.MinInt8), min.Int8ToString)
require.Equal(t, strconv.Itoa(0), min.Uint8ToString)
require.Equal(t, strconv.Itoa(math.MinInt16), min.Int16ToString)
require.Equal(t, strconv.Itoa(0), min.Uint16ToString)
require.Equal(t, strconv.Itoa(math.MinInt32), min.Int32ToString)
require.Equal(t, strconv.Itoa(0), min.Uint32ToString)
require.Equal(t, strconv.Itoa(math.MinInt64), min.Int64ToString)
require.Equal(t, strconv.Itoa(0), min.Uint64ToString)
max := client.Conversion.Query().Where(conversion.Name("max")).OnlyX(ctx)
require.Equal(t, strconv.Itoa(math.MaxInt8), max.Int8ToString)
require.Equal(t, strconv.Itoa(math.MaxInt16), max.Int16ToString)
require.Equal(t, strconv.Itoa(math.MaxInt32), max.Int32ToString)
require.Equal(t, strconv.Itoa(math.MaxInt64), max.Int64ToString)
if dbdialect == dialect.Postgres {
require.Equal(t, strconv.Itoa(math.MaxInt8), max.Uint8ToString)
require.Equal(t, strconv.Itoa(math.MaxInt16), max.Uint16ToString)
require.Equal(t, strconv.Itoa(math.MaxInt32), max.Uint32ToString)
require.Equal(t, strconv.Itoa(math.MaxInt64), max.Uint64ToString)
} else {
require.Equal(t, strconv.Itoa(math.MaxUint8), max.Uint8ToString)
require.Equal(t, strconv.Itoa(math.MaxUint16), max.Uint16ToString)
require.Equal(t, strconv.Itoa(math.MaxUint32), max.Uint32ToString)
require.Equal(t, strconv.FormatUint(math.MaxUint64, 10), max.Uint64ToString)
}
}
}
func CheckConstraint(t *testing.T, client *entv2.Client) {
ctx := context.Background()
t.Log("testing check constraints")
err := client.Media.Create().SetText("boring").Exec(ctx)
require.Error(t, err)
err = client.Media.Create().SetSourceURI("entgo.io").Exec(ctx)
require.Error(t, err)
}
func NicknameSearch(t *testing.T, client *entv2.Client) {
ctx := context.Background()
names := client.User.Query().
Where(func(s *sql.Selector) {
s.Where(sql.P(func(b *sql.Builder) {
b.WriteString("MATCH(").Ident(user.FieldNickname).WriteString(") AGAINST(").Arg("nick_bar | nick_foo").WriteString(")")
}))
}).
Unique(true).
Order(entv2.Asc(user.FieldNickname)).
Select(user.FieldNickname).
StringsX(ctx)
require.Equal(t, []string{"nick_bar", "nick_foo"}, names)
}
func EqualFold(t *testing.T, client *entv2.Client) {
ctx := context.Background()
t.Log("testing equal-fold on sql specific dialects")
client.User.Create().SetAge(37).SetName("Alex").SetNickname("alexsn").SetPhone("123456789").SaveX(ctx)
require.False(t, client.User.Query().Where(user.NameEQ("alex")).ExistX(ctx))
require.True(t, client.User.Query().Where(user.NameEqualFold("alex")).ExistX(ctx))
}
func ContainsFold(t *testing.T, client *entv2.Client) {
ctx := context.Background()
t.Log("testing contains-fold on sql specific dialects")
client.User.Create().SetAge(30).SetName("Mashraki").SetNickname("a8m").SetPhone("102030").SaveX(ctx)
require.Zero(t, client.User.Query().Where(user.NameContains("mash")).CountX(ctx))
require.Equal(t, 1, client.User.Query().Where(user.NameContainsFold("mash")).CountX(ctx))
require.Equal(t, 1, client.User.Query().Where(user.NameContainsFold("Raki")).CountX(ctx))
}
func idRange(t *testing.T, id, l, h int) {
require.Truef(t, id > l && id < h, "id %s should be between %d to %d", id, l, h)
}