// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package customid import ( "context" "database/sql" "fmt" "net" "strconv" "testing" "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/schema" "entgo.io/ent/entc/integration/customid/ent" "entgo.io/ent/entc/integration/customid/ent/blob" "entgo.io/ent/entc/integration/customid/ent/doc" "entgo.io/ent/entc/integration/customid/ent/intsid" "entgo.io/ent/entc/integration/customid/ent/pet" "entgo.io/ent/entc/integration/customid/ent/token" "entgo.io/ent/entc/integration/customid/ent/user" "entgo.io/ent/entc/integration/customid/sid" "entgo.io/ent/schema/field" atlas "ariga.io/atlas/sql/schema" "github.com/go-sql-driver/mysql" "github.com/google/uuid" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/require" ) func TestMySQL(t *testing.T) { for version, port := range map[string]int{"56": 3306, "57": 3307, "8": 3308} { addr := net.JoinHostPort("localhost", strconv.Itoa(port)) t.Run(version, func(t *testing.T) { cfg := mysql.Config{ User: "root", Passwd: "pass", Net: "tcp", Addr: addr, AllowNativePasswords: true, ParseTime: true, } db, err := sql.Open("mysql", cfg.FormatDSN()) require.NoError(t, err) defer db.Close() _, err = db.Exec("CREATE DATABASE IF NOT EXISTS custom_id") require.NoError(t, err, "creating database") defer db.Exec("DROP DATABASE IF EXISTS custom_id") cfg.DBName = "custom_id" client, err := ent.Open("mysql", cfg.FormatDSN()) require.NoError(t, err, "connecting to custom_id database") err = client.Schema.Create(context.Background(), schema.WithHooks(clearDefault, skipBytesID)) require.NoError(t, err) CustomID(t, client) }) } } func TestPostgres(t *testing.T) { for version, port := range map[string]int{"10": 5430, "11": 5431, "12": 5433, "13": 5434} { t.Run(version, func(t *testing.T) { dsn := fmt.Sprintf("host=localhost port=%d user=postgres password=pass sslmode=disable dbname=test", port) db, err := sql.Open(dialect.Postgres, dsn) require.NoError(t, err) defer db.Close() _, err = db.Exec("CREATE SCHEMA IF NOT EXISTS custom_id") require.NoError(t, err, "creating schema") _, err = db.Exec("SET search_path TO custom_id") require.NoError(t, err, "setting schema") _, err = db.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp" SCHEMA custom_id`) require.NoError(t, err, "creating extension") defer db.Exec(`DROP EXTENSION "uuid-ossp"`) defer db.Exec("DROP SCHEMA custom_id CASCADE") client := ent.NewClient(ent.Driver(entsql.OpenDB(dialect.Postgres, db))) err = client.Schema.Create(context.Background(), schema.WithDiffHook(expectOnePetsIndex)) require.NoError(t, err) CustomID(t, client) BytesID(t, client) }) } } func TestSQLite(t *testing.T) { client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") require.NoError(t, err) defer client.Close() require.NoError(t, client.Schema.Create(context.Background(), schema.WithHooks(clearDefault))) CustomID(t, client) BytesID(t, client) } func CustomID(t *testing.T, client *ent.Client) { ctx := context.Background() nat := client.User.Create().SaveX(ctx) require.Equal(t, 1, nat.ID) _, err := client.User.Create().SetID(1).Save(ctx) require.True(t, ent.IsConstraintError(err), "duplicate id") a8m := client.User.Create().SetID(5).SaveX(ctx) require.Equal(t, 5, a8m.ID) hub := client.Group.Create().SetID(3).AddUsers(a8m, nat).SaveX(ctx) require.Equal(t, 3, hub.ID) require.Equal(t, []int{1, 5}, hub.QueryUsers().Order(ent.Asc(user.FieldID)).IDsX(ctx)) blb := client.Blob.Create().SaveX(ctx) require.NotEmpty(t, blb.ID, "use default value") id := uuid.New() chd := client.Blob.Create().SetID(id).SetParent(blb).SaveX(ctx) require.Equal(t, id, chd.ID, "use provided id") require.Equal(t, blb.ID, chd.QueryParent().OnlyX(ctx).ID) lnk := client.Blob.Create().SetID(uuid.New()).AddLinks(chd, blb).SaveX(ctx) require.Equal(t, 2, lnk.QueryLinks().CountX(ctx)) require.Equal(t, lnk.ID, chd.QueryLinks().OnlyX(ctx).ID) require.Equal(t, lnk.ID, blb.QueryLinks().OnlyX(ctx).ID) require.Len(t, client.Blob.Query().IDsX(ctx), 3) links := lnk.QueryBlobLinks().AllX(ctx) require.Len(t, links, 2) require.Equal(t, lnk.ID, links[0].BlobID) require.NotEqual(t, uuid.Nil, links[0].LinkID) require.Equal(t, lnk.ID, links[1].BlobID) require.NotEqual(t, uuid.Nil, links[1].LinkID) pedro := client.Pet.Create().SetID("pedro").SetOwner(a8m).SaveX(ctx) require.Equal(t, a8m.ID, pedro.QueryOwner().OnlyIDX(ctx)) require.Equal(t, pedro.ID, a8m.QueryPets().OnlyIDX(ctx)) xabi := client.Pet.Create().SetID("xabi").AddFriends(pedro).SetBestFriend(pedro).SaveX(ctx) require.Equal(t, "xabi", xabi.ID) pedro = client.Pet.Query().Where(pet.HasOwnerWith(user.ID(a8m.ID))).OnlyX(ctx) require.Equal(t, "pedro", pedro.ID) pets := client.Pet.Query().WithFriends().WithBestFriend().Order(ent.Asc(pet.FieldID)).AllX(ctx) require.Len(t, pets, 2) require.Equal(t, pedro.ID, pets[0].ID) require.NotNil(t, pets[0].Edges.BestFriend) require.Equal(t, xabi.ID, pets[0].Edges.BestFriend.ID) require.Len(t, pets[0].Edges.Friends, 1) require.Equal(t, xabi.ID, pets[0].Edges.Friends[0].ID) require.Equal(t, xabi.ID, pets[1].ID) require.NotNil(t, pets[1].Edges.BestFriend) require.Equal(t, pedro.ID, pets[1].Edges.BestFriend.ID) require.Len(t, pets[1].Edges.Friends, 1) require.Equal(t, pedro.ID, pets[1].Edges.Friends[0].ID) bee := client.Car.Create().SetModel("Chevrolet Camaro").SetOwner(pedro).SaveX(ctx) require.NotNil(t, bee) bee = client.Car.Query().WithOwner().OnlyX(ctx) require.Equal(t, "Chevrolet Camaro", bee.Model) require.NotNil(t, bee.Edges.Owner) require.Equal(t, pedro.ID, bee.Edges.Owner.ID) pets = client.Pet.CreateBulk( client.Pet.Create().SetID("luna").SetOwner(a8m).AddFriends(xabi), client.Pet.Create().SetID("layla").SetOwner(a8m).AddFriendIDs(pedro.ID), client.Pet.Create().AddFriends(pedro, xabi), ).SaveX(ctx) require.Equal(t, "luna", pets[0].ID) require.Equal(t, xabi.ID, pets[0].QueryFriends().OnlyIDX(ctx)) require.Equal(t, "layla", pets[1].ID) require.Equal(t, pedro.ID, pets[1].QueryFriends().OnlyIDX(ctx)) require.Equal(t, []string{"pedro", "xabi"}, pets[2].QueryFriends().Order(ent.Asc(pet.FieldID)).IDsX(ctx)) u1, u2 := uuid.New(), uuid.New() blobs := client.Blob.CreateBulk( client.Blob.Create().SetID(u1), client.Blob.Create().SetID(u2), ).SaveX(ctx) require.Equal(t, u1, blobs[0].ID) require.Equal(t, u2, blobs[1].ID) parent := client.Note.Create().SetText("parent").SaveX(ctx) require.NotEmpty(t, parent.ID) require.NotEmpty(t, parent.Text) child := client.Note.Create().SetText("child").SetParent(parent).SaveX(ctx) require.NotEmpty(t, child.QueryParent().OnlyIDX(ctx)) pdoc := client.Doc.Create().SetText("parent").SaveX(ctx) require.NotEmpty(t, pdoc.ID) require.NotEmpty(t, pdoc.Text) cdoc := client.Doc.Create().SetText("child").SetParent(pdoc).SaveX(ctx) require.NotEmpty(t, cdoc.QueryParent().OnlyIDX(ctx)) t.Run("IntSID", func(t *testing.T) { root := client.IntSID.Create().SaveX(ctx) require.EqualValues(t, sid.ID("1"), root.ID) children := client.IntSID.CreateBulk( client.IntSID.Create().SetParent(root), client.IntSID.Create().SetParent(root), ).SaveX(ctx) require.EqualValues(t, sid.ID("2"), children[0].ID) require.EqualValues(t, sid.ID("3"), children[1].ID) el := client.IntSID.Query().Where(intsid.ID(root.ID)).WithChildren().AllX(ctx) require.EqualValues(t, 1, len(el)) require.EqualValues(t, 2, len(el[0].Edges.Children)) cid := sid.ID("100") child := client.IntSID.Create().SetID(cid).SetParent(root).SaveX(ctx) require.EqualValues(t, cid, child.ID) require.EqualValues(t, root.ID, child.QueryParent().OnlyX(ctx).ID) }) t.Run("Upsert", func(t *testing.T) { id := uuid.New() client.Blob.Create(). SetID(id). OnConflictColumns(blob.FieldID). UpdateNewValues(). ExecX(ctx) require.Zero(t, client.Blob.GetX(ctx, id).Count) client.Blob.Create(). SetID(id). OnConflictColumns(blob.FieldID). Update(func(set *ent.BlobUpsert) { set.AddCount(1) }). ExecX(ctx) require.Equal(t, 1, client.Blob.GetX(ctx, id).Count) d := client.Doc.Create().SaveX(ctx) client.Doc.Create(). SetID(d.ID). OnConflictColumns(doc.FieldID). SetText("Hello World"). UpdateNewValues(). ExecX(ctx) require.Equal(t, "Hello World", client.Doc.GetX(ctx, d.ID).Text) }) t.Run("Other ID", func(t *testing.T) { o := client.Other.Create().SaveX(ctx) require.NotEmpty(t, o.ID.String()) o = client.Other.Create().SetID(sid.NewLength(15)).SaveX(ctx) require.NotEmpty(t, o.ID.String()) }) t.Run("CustomID edge", func(t *testing.T) { a := client.Account.Create().SetEmail("test@example.org").SaveX(ctx) require.NotEmpty(t, a.ID) tk := client.Token.Create().SetAccountID(a.ID).SetBody("token").SaveX(ctx) require.NotEmpty(t, tk.ID) ta := client.Token.Query().Where(token.Body("token")).WithAccount().FirstX(ctx) require.Equal(t, tk.ID, ta.ID) require.NotNil(t, ta.Edges.Account) require.Equal(t, a.ID, ta.Edges.Account.ID) }) t.Run("UUID compatible", func(t *testing.T) { l := client.Link.Create().SaveX(ctx) require.NotEmpty(t, l.ID) require.Len(t, l.LinkInformation, 1) require.Equal(t, "ent", l.LinkInformation["ent"].Name) require.Equal(t, "https://entgo.io/", l.LinkInformation["ent"].Link) }) } func BytesID(t *testing.T, client *ent.Client) { ctx := context.Background() s := client.Session.Create().SaveX(ctx) require.NotEmpty(t, s.ID) client.Device.Create().SetActiveSession(s).AddSessionIDs(s.ID).SaveX(ctx) d := client.Device.Query().WithActiveSession().WithSessions().OnlyX(ctx) require.Equal(t, s.ID, d.Edges.ActiveSession.ID) require.Equal(t, s.ID, d.Edges.Sessions[0].ID) } // clearDefault clears the id's default for non-postgres dialects. func clearDefault(c schema.Creator) schema.Creator { return schema.CreateFunc(func(ctx context.Context, tables ...*schema.Table) error { // Drop DEFAULT clause for MySQL without changing the tables. ct := make([]*schema.Table, len(tables)) copy(ct, tables) *ct[1] = *tables[1] ct[1].Columns = append([]*schema.Column(nil), tables[1].Columns...) *ct[1].Columns[0] = *tables[1].Columns[0] ct[1].Columns[0].Default = nil return c.Create(ctx, ct...) }) } // skipBytesID tables with blob ids from the migration. func skipBytesID(c schema.Creator) schema.Creator { return schema.CreateFunc(func(ctx context.Context, tables ...*schema.Table) error { t := make([]*schema.Table, 0, len(tables)) for i := range tables { if tables[i].PrimaryKey[0].Type == field.TypeBytes { continue } t = append(t, tables[i]) } return c.Create(ctx, t...) }) } // expectOnePetsIndex expects that pets table contains only one index. func expectOnePetsIndex(next schema.Differ) schema.Differ { return schema.DiffFunc(func(current, desired *atlas.Schema) ([]atlas.Change, error) { changes, err := next.Diff(current, desired) for _, c := range changes { addT, ok := c.(*atlas.AddTable) if !ok || addT.T.Name != pet.Table { continue } if n := len(addT.T.Indexes); n != 1 { return nil, fmt.Errorf("expect only one index, but got: %d", n) } } return changes, err }) }