mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
800 lines
27 KiB
Go
800 lines
27 KiB
Go
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package hooks
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sort"
|
|
"testing"
|
|
"time"
|
|
|
|
"entgo.io/ent/dialect/sql"
|
|
|
|
"entgo.io/ent/dialect"
|
|
|
|
"entgo.io/ent/entc/integration/hooks/ent"
|
|
"entgo.io/ent/entc/integration/hooks/ent/card"
|
|
"entgo.io/ent/entc/integration/hooks/ent/enttest"
|
|
"entgo.io/ent/entc/integration/hooks/ent/hook"
|
|
"entgo.io/ent/entc/integration/hooks/ent/intercept"
|
|
"entgo.io/ent/entc/integration/hooks/ent/migrate"
|
|
"entgo.io/ent/entc/integration/hooks/ent/pet"
|
|
"entgo.io/ent/entc/integration/hooks/ent/schema"
|
|
"entgo.io/ent/entc/integration/hooks/ent/user"
|
|
|
|
entgo "entgo.io/ent"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestSchemaHooks(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1", enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)))
|
|
defer client.Close()
|
|
err := client.Card.Create().SetNumber("123").Exec(ctx)
|
|
require.EqualError(t, err, "card number is too short", "error is returned from hook")
|
|
crd := client.Card.Create().SetNumber("1234").SaveX(ctx)
|
|
require.Equal(t, "unknown", crd.Name, "name was set by hook")
|
|
client.Card.Use(func(next ent.Mutator) ent.Mutator {
|
|
return hook.CardFunc(func(ctx context.Context, m *ent.CardMutation) (ent.Value, error) {
|
|
name, ok := m.Name()
|
|
require.True(t, !ok && name == "", "should be the first hook to execute")
|
|
return next.Mutate(ctx, m)
|
|
})
|
|
})
|
|
client.Card.Create().SetNumber("1234").SaveX(ctx)
|
|
err = client.Card.Update().Exec(ctx)
|
|
require.EqualError(t, err, "OpUpdate operation is not allowed")
|
|
|
|
err = client.User.Update().SetPassword("pass").Exec(ctx)
|
|
require.EqualError(t, err, "password cannot be edited on update-many")
|
|
err = client.User.Update().ClearPassword().Exec(ctx)
|
|
require.EqualError(t, err, "password cannot be edited on update-many")
|
|
}
|
|
|
|
func TestRuntimeHooks(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1", enttest.WithOptions(ent.Log(t.Log)), enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)))
|
|
defer client.Close()
|
|
var calls int
|
|
client.Use(func(next ent.Mutator) ent.Mutator {
|
|
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
|
calls++
|
|
return next.Mutate(ctx, m)
|
|
})
|
|
})
|
|
client.Card.Create().SetNumber("1234").SaveX(ctx)
|
|
client.User.Create().SetName("a8m").SaveX(ctx)
|
|
require.Equal(t, 2, calls)
|
|
client = client.Debug()
|
|
client.Card.Create().SetNumber("1234").SaveX(ctx)
|
|
client.User.Create().SetName("a8m").SaveX(ctx)
|
|
require.Equal(t, 4, calls, "debug client should keep the same hooks")
|
|
}
|
|
|
|
func TestRuntimeChain(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
defer client.Close()
|
|
var (
|
|
chain hook.Chain
|
|
values []int
|
|
)
|
|
for value := 0; value < 5; value++ {
|
|
chain = chain.Append(func(next ent.Mutator) ent.Mutator {
|
|
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
|
values = append(values, value)
|
|
return next.Mutate(ctx, m)
|
|
})
|
|
})
|
|
}
|
|
client.User.Use(chain.Hook())
|
|
client.User.Create().SetName("alexsn").SaveX(ctx)
|
|
require.Len(t, values, 5)
|
|
require.True(t, sort.IntsAreSorted(values))
|
|
}
|
|
|
|
func TestMutationClient(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1", enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)))
|
|
defer client.Close()
|
|
client.Card.Use(func(next ent.Mutator) ent.Mutator {
|
|
return hook.CardFunc(func(ctx context.Context, m *ent.CardMutation) (ent.Value, error) {
|
|
id, _ := m.OwnerID()
|
|
usr := m.Client().User.GetX(ctx, id)
|
|
m.SetName(usr.Name)
|
|
return next.Mutate(ctx, m)
|
|
})
|
|
})
|
|
a8m := client.User.Create().SetName("a8m").SaveX(ctx)
|
|
crd := client.Card.Create().SetNumber("1234").SetOwner(a8m).SaveX(ctx)
|
|
require.Equal(t, a8m.Name, crd.Name)
|
|
}
|
|
|
|
func TestMutatorClient(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
defer client.Close()
|
|
client.Use(
|
|
hook.On(
|
|
func(next ent.Mutator) ent.Mutator {
|
|
return hook.CardFunc(func(ctx context.Context, m *ent.CardMutation) (ent.Value, error) {
|
|
op := ent.OpUpdate
|
|
if _, exists := m.ID(); exists && m.Op().Is(ent.OpDeleteOne) {
|
|
op = ent.OpUpdateOne
|
|
}
|
|
// Ensure card was not expired before.
|
|
m.Where(card.ExpiredAtIsNil())
|
|
// Count the number of affected records.
|
|
ids, err := m.IDs(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Change the operation to update.
|
|
m.SetOp(op)
|
|
// Record when card was expired.
|
|
m.SetExpiredAt(time.Now())
|
|
// Execute the update operation.
|
|
if _, err := m.Client().Mutate(ctx, m); err != nil {
|
|
return nil, err
|
|
}
|
|
return len(ids), nil
|
|
})
|
|
},
|
|
ent.OpDelete|ent.OpDeleteOne,
|
|
),
|
|
)
|
|
c1 := client.Card.Create().SetNumber("1234").SaveX(ctx)
|
|
client.Card.DeleteOne(c1).ExecX(ctx)
|
|
expired := client.Card.Query().OnlyX(ctx)
|
|
require.False(t, expired.ExpiredAt.IsZero())
|
|
|
|
client.Card.Create().SetNumber("4567").ExecX(ctx)
|
|
client.Card.Create().SetNumber("7890").ExecX(ctx)
|
|
client.Card.Delete().Where(card.Number("4567")).ExecX(ctx)
|
|
cards := client.Card.Query().Order(ent.Asc(card.FieldNumber)).AllX(ctx)
|
|
require.Len(t, cards, 3)
|
|
require.True(t, cards[0].ExpiredAt.Equal(expired.ExpiredAt), "expired field should not be updated")
|
|
require.False(t, cards[1].ExpiredAt.IsZero())
|
|
require.True(t, cards[2].ExpiredAt.IsZero())
|
|
|
|
client.Card.Delete().ExecX(ctx)
|
|
require.False(t, client.Card.Query().Where(card.ExpiredAtIsNil()).ExistX(ctx))
|
|
}
|
|
|
|
func TestMutationTx(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1", enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)))
|
|
defer client.Close()
|
|
client.Card.Use(func(next ent.Mutator) ent.Mutator {
|
|
return hook.CardFunc(func(ctx context.Context, m *ent.CardMutation) (ent.Value, error) {
|
|
tx, err := m.Tx()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := tx.Rollback(); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, fmt.Errorf("rolled back")
|
|
})
|
|
})
|
|
tx, err := client.Tx(ctx)
|
|
require.NoError(t, err)
|
|
a8m := tx.User.Create().SetName("a8m").SaveX(ctx)
|
|
crd, err := tx.Card.Create().SetNumber("1234").SetOwner(a8m).Save(ctx)
|
|
require.EqualError(t, err, "rolled back")
|
|
require.Nil(t, crd)
|
|
_, err = tx.Card.Query().All(ctx)
|
|
require.Error(t, err, "tx already rolled back")
|
|
}
|
|
|
|
func TestDeletion(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1", enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)))
|
|
defer client.Close()
|
|
client.User.Use(func(next ent.Mutator) ent.Mutator {
|
|
return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
|
|
if !m.Op().Is(ent.OpDeleteOne) {
|
|
return next.Mutate(ctx, m)
|
|
}
|
|
id, ok := m.ID()
|
|
if !ok {
|
|
return nil, fmt.Errorf("missing id")
|
|
}
|
|
m.Client().Card.Delete().Where(card.HasOwnerWith(user.ID(id))).ExecX(ctx)
|
|
return next.Mutate(ctx, m)
|
|
})
|
|
})
|
|
a8m := client.User.Create().SetName("a8m").SaveX(ctx)
|
|
for i := 0; i < 5; i++ {
|
|
client.Card.Create().SetNumber(fmt.Sprintf("card-%d", i)).SetOwner(a8m).SaveX(ctx)
|
|
}
|
|
client.User.DeleteOne(a8m).ExecX(ctx)
|
|
require.Zero(t, client.User.Query().CountX(ctx))
|
|
require.Zero(t, client.Card.Query().CountX(ctx))
|
|
}
|
|
|
|
func TestMutationIDs(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
defer client.Close()
|
|
count := make(map[ent.Op]int)
|
|
client.User.Use(
|
|
hook.Unless(
|
|
func(next ent.Mutator) ent.Mutator {
|
|
return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
|
|
count[m.Op()]++
|
|
ids, err := m.IDs(ctx)
|
|
require.NoError(t, err)
|
|
require.Len(t, ids, 1)
|
|
require.Equal(t, count[m.Op()], ids[0])
|
|
return next.Mutate(ctx, m)
|
|
})
|
|
},
|
|
ent.OpCreate,
|
|
),
|
|
)
|
|
for i := 0; i < 5; i++ {
|
|
owner := client.User.Create().SetName(fmt.Sprintf("owner-%d", i)).SaveX(ctx)
|
|
client.Card.Create().SetNumber(fmt.Sprintf("card-%d", i)).SetOwner(owner).ExecX(ctx)
|
|
}
|
|
for i := 0; i < 5; i++ {
|
|
p := user.And(user.Name(fmt.Sprintf("owner-%d", i)), user.HasCardsWith(card.Number(fmt.Sprintf("card-%d", i))))
|
|
client.User.Update().AddVersion(1).Where(p).ExecX(ctx)
|
|
client.User.Delete().Where(p).ExecX(ctx)
|
|
}
|
|
}
|
|
|
|
func TestPostCreation(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
defer client.Close()
|
|
client.Card.Use(hook.On(func(next ent.Mutator) ent.Mutator {
|
|
return hook.CardFunc(func(ctx context.Context, m *ent.CardMutation) (ent.Value, error) {
|
|
id, exists := m.ID()
|
|
require.False(t, exists, "id should not exist pre mutation")
|
|
require.Zero(t, id)
|
|
value, err := next.Mutate(ctx, m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
id, exists = m.ID()
|
|
require.True(t, exists, "id should exist post mutation")
|
|
require.NotZero(t, id)
|
|
require.True(t, id == value.(*ent.Card).ID)
|
|
return value, nil
|
|
})
|
|
}, ent.OpCreate))
|
|
client.Card.Create().SetNumber("12345").SetName("a8m").SaveX(ctx)
|
|
client.Card.CreateBulk(client.Card.Create().SetNumber("12345")).SaveX(ctx)
|
|
}
|
|
|
|
func TestUpdateAfterCreation(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
defer client.Close()
|
|
client.User.Use(hook.On(func(next ent.Mutator) ent.Mutator {
|
|
return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
|
|
value, err := next.Mutate(ctx, m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
existingUser, ok := value.(*ent.User)
|
|
require.Truef(t, ok, "value should be of type %T", existingUser)
|
|
require.Equal(t, 1, existingUser.Version, "version does not match the original value")
|
|
|
|
// After the user was created, return its updated version (a new object).
|
|
newUser := m.Client().User.UpdateOne(existingUser).
|
|
SetVersion(2).
|
|
SaveX(ctx)
|
|
return newUser, nil
|
|
})
|
|
}, ent.OpCreate))
|
|
|
|
u := client.User.Create().SetName("a8m").SetVersion(1).SaveX(ctx)
|
|
require.Equal(t, 2, u.Version, "version mutation in hook should have propagated back to call site")
|
|
}
|
|
|
|
func TestUpdateAfterUpdateOne(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
defer client.Close()
|
|
client.User.Use(hook.On(func(next ent.Mutator) ent.Mutator {
|
|
return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
|
|
value, err := next.Mutate(ctx, m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
u, ok := value.(*ent.User)
|
|
require.Truef(t, ok, "value should be of type %T", u)
|
|
require.Equal(t, 2, u.Version, "version does not match the original value")
|
|
|
|
// After the user was created, return its updated version (a new object). Don't use UpdateOne because it
|
|
// will cause recursive calls to this hook.
|
|
m.Client().User.Update().
|
|
Where(user.IDEQ(u.ID)).
|
|
SetVersion(3).
|
|
SaveX(ctx)
|
|
|
|
return m.Client().User.Get(ctx, u.ID)
|
|
})
|
|
}, ent.OpUpdateOne))
|
|
|
|
u := client.User.Create().SetName("a8m").SetVersion(1).SaveX(ctx)
|
|
u = client.User.UpdateOne(u).SetVersion(2).SaveX(ctx)
|
|
|
|
require.Equal(t, 3, u.Version, "version mutation in hook should have propagated back to call site")
|
|
}
|
|
|
|
func TestOldValues(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1", enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)))
|
|
defer client.Close()
|
|
|
|
// Querying old fields post mutation should fail.
|
|
client.Card.Use(hook.On(func(next ent.Mutator) ent.Mutator {
|
|
return hook.CardFunc(func(ctx context.Context, m *ent.CardMutation) (ent.Value, error) {
|
|
value, err := next.Mutate(ctx, m)
|
|
require.NoError(t, err)
|
|
_, err = m.OldNumber(ctx)
|
|
require.Error(t, err)
|
|
return value, nil
|
|
})
|
|
}, ent.OpUpdateOne))
|
|
crd := client.Card.Create().SetNumber("1234").SetName("a8m").SaveX(ctx)
|
|
client.Card.UpdateOneID(crd.ID).SetName("a8m").SaveX(ctx)
|
|
|
|
// A typed hook.
|
|
client.User.Use(hook.On(func(next ent.Mutator) ent.Mutator {
|
|
return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
|
|
name, err := m.OldName(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
require.Equal(t, "a8m", name)
|
|
return next.Mutate(ctx, m)
|
|
})
|
|
}, ent.OpUpdateOne))
|
|
// A generic hook (executed on all types).
|
|
client.Use(hook.On(func(next ent.Mutator) ent.Mutator {
|
|
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
|
namer, ok := m.(interface {
|
|
OldName(context.Context) (string, error)
|
|
})
|
|
if !ok {
|
|
// Skip if the mutation does not have
|
|
// a method for getting the old name.
|
|
return next.Mutate(ctx, m)
|
|
}
|
|
name, err := namer.OldName(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
require.Equal(t, "a8m", name)
|
|
value, err := next.Mutate(ctx, m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
_, err = namer.OldName(ctx)
|
|
require.NoError(t, err)
|
|
return value, nil
|
|
})
|
|
}, ent.OpUpdateOne))
|
|
// A generic hook (executed on all types).
|
|
client.Use(hook.Unless(func(next ent.Mutator) ent.Mutator {
|
|
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
|
namer, ok := m.(interface {
|
|
OldName(context.Context) (string, error)
|
|
})
|
|
if !ok {
|
|
// Skip if the mutation does not have
|
|
// a method for getting the old name.
|
|
return next.Mutate(ctx, m)
|
|
}
|
|
name, err := namer.OldName(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
require.Equal(t, "a8m", name)
|
|
value, err := next.Mutate(ctx, m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
_, err = namer.OldName(ctx)
|
|
require.NoError(t, err)
|
|
return value, nil
|
|
})
|
|
}, ^ent.OpUpdateOne))
|
|
a8m := client.User.Create().SetName("a8m").SaveX(ctx)
|
|
require.Equal(t, "a8m", a8m.Name)
|
|
err := client.User.UpdateOne(a8m).SetName("Ariel").SetVersion(a8m.Version).Exec(ctx)
|
|
require.EqualError(t, err, "version field must be incremented by 1")
|
|
a8m = client.User.UpdateOne(a8m).SetName("Ariel").SetVersion(a8m.Version + 1).SaveX(ctx)
|
|
require.Equal(t, "Ariel", a8m.Name)
|
|
}
|
|
|
|
func TestConditions(t *testing.T) {
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1", enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)))
|
|
defer client.Close()
|
|
|
|
var calls int
|
|
defer func() { require.Equal(t, 2, calls) }()
|
|
client.Card.Use(hook.If(func(next ent.Mutator) ent.Mutator {
|
|
return hook.CardFunc(func(ctx context.Context, m *ent.CardMutation) (ent.Value, error) {
|
|
require.True(t, m.Op().Is(ent.OpUpdateOne))
|
|
calls++
|
|
return next.Mutate(ctx, m)
|
|
})
|
|
}, hook.Or(
|
|
hook.HasFields(card.FieldName),
|
|
hook.HasClearedFields(card.FieldName),
|
|
)))
|
|
client.User.Use(hook.If(func(next ent.Mutator) ent.Mutator {
|
|
return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
|
|
require.True(t, m.Op().Is(ent.OpUpdate))
|
|
incr, exists := m.AddedWorth()
|
|
require.True(t, exists)
|
|
require.EqualValues(t, 100, incr)
|
|
return next.Mutate(ctx, m)
|
|
})
|
|
}, hook.HasAddedFields(user.FieldWorth)))
|
|
|
|
ctx := context.Background()
|
|
crd := client.Card.Create().SetNumber("9876").SaveX(ctx)
|
|
crd = crd.Update().SetName("alexsn").SaveX(ctx)
|
|
crd = crd.Update().ClearName().SaveX(ctx)
|
|
client.Card.DeleteOne(crd).ExecX(ctx)
|
|
|
|
alexsn := client.User.Create().SetName("alexsn").SaveX(ctx)
|
|
client.User.Update().Where(user.ID(alexsn.ID)).AddWorth(100).SaveX(ctx)
|
|
client.User.DeleteOne(alexsn).ExecX(ctx)
|
|
}
|
|
|
|
func TestRuntimeTx(t *testing.T) {
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1", enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)))
|
|
defer client.Close()
|
|
client.Card.Use(func(next ent.Mutator) ent.Mutator {
|
|
return hook.CardFunc(func(ctx context.Context, m *ent.CardMutation) (ent.Value, error) {
|
|
v, err := next.Mutate(ctx, m)
|
|
require.NoError(t, err)
|
|
tx, err := m.Tx()
|
|
require.NoError(t, err)
|
|
tx.OnCommit(func(next ent.Committer) ent.Committer {
|
|
return ent.CommitFunc(func(ctx context.Context, tx *ent.Tx) error {
|
|
// Ensure the transaction can see the created card.
|
|
tx.Card.GetX(ctx, v.(*ent.Card).ID)
|
|
// Cause the transaction to fail.
|
|
require.NoError(t, tx.Rollback())
|
|
return fmt.Errorf("fail")
|
|
})
|
|
})
|
|
return v, nil
|
|
})
|
|
})
|
|
ctx := context.Background()
|
|
tx, err := client.Tx(ctx)
|
|
require.NoError(t, err)
|
|
tx.Card.Create().SetNumber("9876").ExecX(ctx)
|
|
require.EqualError(t, tx.Commit(), "fail")
|
|
require.Zero(t, client.Card.Query().CountX(ctx), "database is empty")
|
|
}
|
|
|
|
func TestInterceptor_Sanity(t *testing.T) {
|
|
ctx := context.Background()
|
|
t.Run("All", func(t *testing.T) {
|
|
var (
|
|
calls int
|
|
client = enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
)
|
|
defer client.Close()
|
|
client.Intercept(
|
|
ent.InterceptFunc(func(next ent.Querier) ent.Querier {
|
|
return ent.QuerierFunc(func(ctx context.Context, query ent.Query) (ent.Value, error) {
|
|
_, err := intercept.NewQuery(query)
|
|
require.NoError(t, err)
|
|
calls++
|
|
nodes, err := next.Query(ctx, query)
|
|
require.NoError(t, err)
|
|
require.IsType(t, ([]*ent.User)(nil), nodes)
|
|
return nodes, nil
|
|
})
|
|
}),
|
|
intercept.TraverseFunc(func(ctx context.Context, query intercept.Query) error {
|
|
calls++
|
|
return nil
|
|
}),
|
|
)
|
|
client.User.Query().AllX(ctx)
|
|
require.Equal(t, 2, calls)
|
|
})
|
|
t.Run("Count", func(t *testing.T) {
|
|
var (
|
|
calls int
|
|
client = enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
)
|
|
defer client.Close()
|
|
client.Intercept(
|
|
ent.InterceptFunc(func(next ent.Querier) ent.Querier {
|
|
return ent.QuerierFunc(func(ctx context.Context, query ent.Query) (ent.Value, error) {
|
|
calls++
|
|
count, err := next.Query(ctx, query)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 0, count)
|
|
return count, nil
|
|
})
|
|
}),
|
|
intercept.TraverseFunc(func(ctx context.Context, query intercept.Query) error {
|
|
calls++
|
|
return nil
|
|
}),
|
|
)
|
|
client.User.Query().CountX(ctx)
|
|
require.Equal(t, 2, calls)
|
|
})
|
|
t.Run("IDs", func(t *testing.T) {
|
|
var (
|
|
calls int
|
|
client = enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
users = client.User.CreateBulk(
|
|
client.User.Create().SetName("a8m"),
|
|
client.User.Create().SetName("nati"),
|
|
).SaveX(ctx)
|
|
)
|
|
defer client.Close()
|
|
client.Intercept(
|
|
ent.InterceptFunc(func(next ent.Querier) ent.Querier {
|
|
return ent.QuerierFunc(func(ctx context.Context, query ent.Query) (ent.Value, error) {
|
|
calls++
|
|
ids, err := next.Query(ctx, query)
|
|
require.NoError(t, err)
|
|
// IDs() uses Select() under the hood, therefore,
|
|
// scanned values are returned as values and not pointers.
|
|
require.IsType(t, ([]int)(nil), ids)
|
|
require.Equal(t, []int{users[0].ID}, ids)
|
|
// Values can be changed and affect the return value.
|
|
return append(ids.([]int), 10, 20), nil
|
|
})
|
|
}),
|
|
intercept.TraverseFunc(func(ctx context.Context, query intercept.Query) error {
|
|
calls++
|
|
query.WhereP(user.ID(users[0].ID))
|
|
return nil
|
|
}),
|
|
)
|
|
require.Equal(t, []int{users[0].ID, 10, 20}, client.User.Query().IDsX(ctx))
|
|
require.Equal(t, 2, calls)
|
|
})
|
|
t.Run("GroupBy", func(t *testing.T) {
|
|
type n2c struct {
|
|
N string `sql:"name"`
|
|
C int `sql:"count"`
|
|
}
|
|
var (
|
|
calls int
|
|
client = enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
)
|
|
defer client.Close()
|
|
client.Intercept(
|
|
ent.InterceptFunc(func(next ent.Querier) ent.Querier {
|
|
return ent.QuerierFunc(func(ctx context.Context, query ent.Query) (ent.Value, error) {
|
|
calls++
|
|
vs, err := next.Query(ctx, query)
|
|
require.NoError(t, err)
|
|
return append(vs.([]n2c), n2c{N: "fake", C: 10}), nil
|
|
})
|
|
}),
|
|
intercept.TraverseFunc(func(ctx context.Context, query intercept.Query) error {
|
|
calls++
|
|
query.WhereP(card.NameNEQ("a8m"))
|
|
return nil
|
|
}),
|
|
)
|
|
client.Card.CreateBulk(
|
|
client.Card.Create().SetName("a8m").SetNumber("1234"),
|
|
client.Card.Create().SetName("a8m").SetNumber("5678"),
|
|
client.Card.Create().SetName("nati").SetNumber("9876"),
|
|
).ExecX(ctx)
|
|
var vs []n2c
|
|
require.NoError(t, client.Card.Query().GroupBy(user.FieldName).Aggregate(ent.Count()).Scan(ctx, &vs))
|
|
require.Equal(t, []n2c{{"nati", 1}, {"fake", 10}}, vs)
|
|
require.Equal(t, 2, calls)
|
|
})
|
|
}
|
|
|
|
func TestSoftDelete(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
defer client.Close()
|
|
|
|
a8m := client.User.Create().SetName("a8m").SaveX(ctx)
|
|
pets := client.Pet.CreateBulk(
|
|
client.Pet.Create().SetName("a").SetOwner(a8m),
|
|
client.Pet.Create().SetName("b").SetOwner(a8m),
|
|
// Set delete_time manually.
|
|
client.Pet.Create().SetName("c").SetOwner(a8m).SetDeleteTime(time.Now()),
|
|
).SaveX(ctx)
|
|
require.Equal(t, []int{pets[0].ID, pets[1].ID}, client.Pet.Query().Order(ent.Asc(pet.FieldID)).IDsX(ctx))
|
|
|
|
// DeleteOne using the API.
|
|
client.Pet.DeleteOne(pets[1]).ExecX(ctx)
|
|
require.Equal(t, pets[0].ID, client.Pet.Query().OnlyIDX(ctx))
|
|
|
|
// Delete using the API.
|
|
n := client.Pet.Delete().ExecX(ctx)
|
|
require.Equal(t, 1, n)
|
|
require.False(t, client.Pet.Query().ExistX(ctx))
|
|
|
|
// Query entities through the interceptor.
|
|
require.Zero(t, client.Pet.Query().CountX(ctx))
|
|
require.Empty(t, client.Pet.Query().AllX(ctx))
|
|
// Skip soft-deleted interceptors.
|
|
require.Equal(t, 3, client.Pet.Query().CountX(schema.SkipSoftDelete(ctx)))
|
|
require.Len(t, client.Pet.Query().AllX(schema.SkipSoftDelete(ctx)), 3)
|
|
|
|
client.Pet.CreateBulk(
|
|
client.Pet.Create().SetName("d"),
|
|
client.Pet.Create().SetName("d"),
|
|
client.Pet.Create().SetName("d"),
|
|
).ExecX(ctx)
|
|
|
|
// Select entities through the interceptor.
|
|
names := client.Pet.Query().Unique(true).Select(pet.FieldName).StringsX(ctx)
|
|
require.Equal(t, []string{"d"}, names)
|
|
names = client.Pet.Query().Unique(true).Order(ent.Asc(pet.FieldName)).Select(pet.FieldName).StringsX(schema.SkipSoftDelete(ctx))
|
|
require.Equal(t, []string{"a", "b", "c", "d"}, names)
|
|
|
|
// Group by.
|
|
var n2c []struct {
|
|
N string `sql:"name"`
|
|
C int `sql:"count"`
|
|
}
|
|
client.Pet.Query().GroupBy(pet.FieldName).Aggregate(ent.Count()).ScanX(ctx, &n2c)
|
|
require.Equal(t, "d", n2c[0].N)
|
|
require.Equal(t, 3, n2c[0].C)
|
|
n2c = nil
|
|
client.Pet.Query().GroupBy(pet.FieldName).Aggregate(ent.Count()).ScanX(schema.SkipSoftDelete(ctx), &n2c)
|
|
require.Len(t, n2c, 4)
|
|
|
|
// Edge traversals.
|
|
require.False(t, a8m.QueryPets().ExistX(ctx), "interceptors should be applied on edge traversals")
|
|
require.False(t, client.User.Query().QueryPets().ExistX(ctx))
|
|
require.Equal(t, 3, a8m.QueryPets().CountX(schema.SkipSoftDelete(ctx)))
|
|
|
|
// Eager-loading edges.
|
|
a8m = client.User.Query().WithPets().OnlyX(ctx)
|
|
require.Empty(t, a8m.Edges.Pets)
|
|
}
|
|
|
|
func TestTraverseUnique(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
defer client.Close()
|
|
|
|
a8m := client.User.Create().SetName("a8m").SaveX(ctx)
|
|
client.Pet.CreateBulk(
|
|
client.Pet.Create().SetName("a").SetOwner(a8m),
|
|
client.Pet.Create().SetName("b").SetOwner(a8m),
|
|
).ExecX(ctx)
|
|
require.Equal(t, 1, client.Pet.Query().QueryOwner().CountX(ctx))
|
|
|
|
// Disable unique traversal using interceptors.
|
|
client.User.Intercept(
|
|
intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error {
|
|
q.Unique(false)
|
|
return nil
|
|
}),
|
|
)
|
|
// The JOIN with pets will return the same owner twice, one for each pet.
|
|
require.Equal(t, 2, client.Pet.Query().QueryOwner().CountX(ctx))
|
|
}
|
|
|
|
// The following example demonstrates how to write interceptors that
|
|
// can be used by multiple packages/projects using generics.
|
|
func TestSharedInterceptor(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
defer client.Close()
|
|
client.User.Create().SetName("a8m").ExecX(ctx)
|
|
client.User.Create().SetName("nati").ExecX(ctx)
|
|
require.Len(t, client.User.Query().AllX(ctx), 2)
|
|
client.Intercept(SharedLimiter(intercept.NewQuery, 1))
|
|
require.Len(t, client.User.Query().AllX(ctx), 1)
|
|
}
|
|
|
|
// Project-level example. The usage of "entgo" package demonstrates how to
|
|
// write a generic interceptor that can be used by any ent-based project.
|
|
func SharedLimiter[Q interface{ Limit(int) }](f func(entgo.Query) (Q, error), limit int) entgo.Interceptor {
|
|
return entgo.InterceptFunc(func(next entgo.Querier) entgo.Querier {
|
|
return entgo.QuerierFunc(func(ctx context.Context, query entgo.Query) (ent.Value, error) {
|
|
l, err := f(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
l.Limit(limit)
|
|
return next.Query(ctx, query)
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestTypedTraverser(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
defer client.Close()
|
|
a8m, nat := client.User.Create().SetName("a8m").SaveX(ctx), client.User.Create().SetName("nati").SetActive(false).SaveX(ctx)
|
|
client.Pet.CreateBulk(
|
|
client.Pet.Create().SetName("a").SetOwner(a8m),
|
|
client.Pet.Create().SetName("b").SetOwner(a8m),
|
|
client.Pet.Create().SetName("c").SetOwner(nat),
|
|
).ExecX(ctx)
|
|
|
|
// Get all pets of all users.
|
|
if n := client.User.Query().QueryPets().CountX(ctx); n != 3 {
|
|
t.Errorf("got %d pets, want 3", n)
|
|
}
|
|
|
|
// Add an interceptor that filters out inactive users.
|
|
client.User.Intercept(
|
|
ent.TraverseFunc(func(ctx context.Context, query ent.Query) error {
|
|
if q, ok := query.(*ent.UserQuery); ok {
|
|
q.Where(user.Active(true))
|
|
}
|
|
return nil
|
|
}),
|
|
)
|
|
// Only pets of active users are returned.
|
|
if n := client.User.Query().QueryPets().CountX(ctx); n != 2 {
|
|
t.Errorf("got %d pets, want 2", n)
|
|
}
|
|
}
|
|
|
|
func TestLimitInterceptor(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
defer client.Close()
|
|
client.User.Create().SetName("a8m").SaveX(ctx)
|
|
client.User.Create().SetName("nati").SaveX(ctx)
|
|
require.Len(t, client.User.Query().AllX(ctx), 2)
|
|
client.Intercept(
|
|
intercept.Func(func(ctx context.Context, q intercept.Query) error {
|
|
q.Limit(1)
|
|
return nil
|
|
}),
|
|
)
|
|
require.Len(t, client.User.Query().AllX(ctx), 1)
|
|
}
|
|
|
|
func TestFilterTraverseFunc(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
|
|
defer client.Close()
|
|
a8m, nat := client.User.Create().SetName("a8m").SaveX(ctx), client.User.Create().SetName("nati").SetActive(false).SaveX(ctx)
|
|
client.Pet.CreateBulk(
|
|
client.Pet.Create().SetName("a").SetOwner(a8m),
|
|
client.Pet.Create().SetName("b").SetOwner(a8m),
|
|
client.Pet.Create().SetName("c").SetOwner(nat),
|
|
).ExecX(ctx)
|
|
// Get all pets of all users.
|
|
if n := client.User.Query().QueryPets().CountX(ctx); n != 3 {
|
|
t.Errorf("got %d pets, want 3", n)
|
|
}
|
|
|
|
// Add an interceptor that filters out inactive users.
|
|
client.User.Intercept(
|
|
intercept.TraverseFunc(func(ctx context.Context, query intercept.Query) error {
|
|
query.WhereP(func(s *sql.Selector) {
|
|
s.Where(sql.EQ("active", true))
|
|
})
|
|
return nil
|
|
}),
|
|
)
|
|
// Only pets of active users are returned.
|
|
if n := client.User.Query().QueryPets().CountX(ctx); n != 2 {
|
|
t.Errorf("got %d pets, want 2", n)
|
|
}
|
|
}
|