mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
entc/gen/mutation: add IDs method for mutations
This commit is contained in:
committed by
Ariel Mashraki
parent
e8e4401d27
commit
caa673826a
@@ -8,6 +8,7 @@ package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -80,7 +81,7 @@ func withCardID(id int) cardOption {
|
||||
m.oldValue = func(ctx context.Context) (*Card, error) {
|
||||
once.Do(func() {
|
||||
if m.done {
|
||||
err = fmt.Errorf("querying old values post mutation is not allowed")
|
||||
err = errors.New("querying old values post mutation is not allowed")
|
||||
} else {
|
||||
value, err = m.Client().Card.Get(ctx, id)
|
||||
}
|
||||
@@ -113,7 +114,7 @@ func (m CardMutation) Client() *Client {
|
||||
// it returns an error otherwise.
|
||||
func (m CardMutation) Tx() (*Tx, error) {
|
||||
if _, ok := m.driver.(*txDriver); !ok {
|
||||
return nil, fmt.Errorf("ent: mutation is not running in a transaction")
|
||||
return nil, errors.New("ent: mutation is not running in a transaction")
|
||||
}
|
||||
tx := &Tx{config: m.config}
|
||||
tx.init()
|
||||
@@ -129,6 +130,25 @@ func (m *CardMutation) ID() (id int, exists bool) {
|
||||
return *m.id, true
|
||||
}
|
||||
|
||||
// IDs queries the database and returns the entity ids that match the mutation's predicate.
|
||||
// That means, if the mutation is applied within a transaction with an isolation level such
|
||||
// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
|
||||
// or updated by the mutation.
|
||||
func (m *CardMutation) IDs(ctx context.Context) ([]int, error) {
|
||||
switch {
|
||||
case m.op.Is(OpUpdateOne | OpDeleteOne):
|
||||
id, exists := m.ID()
|
||||
if exists {
|
||||
return []int{id}, nil
|
||||
}
|
||||
fallthrough
|
||||
case m.op.Is(OpUpdate | OpDelete):
|
||||
return m.Client().Card.Query().Where(m.predicates...).IDs(ctx)
|
||||
default:
|
||||
return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
|
||||
}
|
||||
}
|
||||
|
||||
// SetNumber sets the "number" field.
|
||||
func (m *CardMutation) SetNumber(s string) {
|
||||
m.number = &s
|
||||
@@ -148,10 +168,10 @@ func (m *CardMutation) Number() (r string, exists bool) {
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *CardMutation) OldNumber(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, fmt.Errorf("OldNumber is only allowed on UpdateOne operations")
|
||||
return v, errors.New("OldNumber is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, fmt.Errorf("OldNumber requires an ID field in the mutation")
|
||||
return v, errors.New("OldNumber requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
@@ -184,10 +204,10 @@ func (m *CardMutation) Name() (r string, exists bool) {
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *CardMutation) OldName(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, fmt.Errorf("OldName is only allowed on UpdateOne operations")
|
||||
return v, errors.New("OldName is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, fmt.Errorf("OldName requires an ID field in the mutation")
|
||||
return v, errors.New("OldName requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
@@ -233,10 +253,10 @@ func (m *CardMutation) CreatedAt() (r time.Time, exists bool) {
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *CardMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, fmt.Errorf("OldCreatedAt is only allowed on UpdateOne operations")
|
||||
return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, fmt.Errorf("OldCreatedAt requires an ID field in the mutation")
|
||||
return v, errors.New("OldCreatedAt requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
@@ -269,10 +289,10 @@ func (m *CardMutation) InHook() (r string, exists bool) {
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *CardMutation) OldInHook(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, fmt.Errorf("OldInHook is only allowed on UpdateOne operations")
|
||||
return v, errors.New("OldInHook is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, fmt.Errorf("OldInHook requires an ID field in the mutation")
|
||||
return v, errors.New("OldInHook requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
@@ -632,7 +652,7 @@ func withUserID(id int) userOption {
|
||||
m.oldValue = func(ctx context.Context) (*User, error) {
|
||||
once.Do(func() {
|
||||
if m.done {
|
||||
err = fmt.Errorf("querying old values post mutation is not allowed")
|
||||
err = errors.New("querying old values post mutation is not allowed")
|
||||
} else {
|
||||
value, err = m.Client().User.Get(ctx, id)
|
||||
}
|
||||
@@ -665,7 +685,7 @@ func (m UserMutation) Client() *Client {
|
||||
// it returns an error otherwise.
|
||||
func (m UserMutation) Tx() (*Tx, error) {
|
||||
if _, ok := m.driver.(*txDriver); !ok {
|
||||
return nil, fmt.Errorf("ent: mutation is not running in a transaction")
|
||||
return nil, errors.New("ent: mutation is not running in a transaction")
|
||||
}
|
||||
tx := &Tx{config: m.config}
|
||||
tx.init()
|
||||
@@ -681,6 +701,25 @@ func (m *UserMutation) ID() (id int, exists bool) {
|
||||
return *m.id, true
|
||||
}
|
||||
|
||||
// IDs queries the database and returns the entity ids that match the mutation's predicate.
|
||||
// That means, if the mutation is applied within a transaction with an isolation level such
|
||||
// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
|
||||
// or updated by the mutation.
|
||||
func (m *UserMutation) IDs(ctx context.Context) ([]int, error) {
|
||||
switch {
|
||||
case m.op.Is(OpUpdateOne | OpDeleteOne):
|
||||
id, exists := m.ID()
|
||||
if exists {
|
||||
return []int{id}, nil
|
||||
}
|
||||
fallthrough
|
||||
case m.op.Is(OpUpdate | OpDelete):
|
||||
return m.Client().User.Query().Where(m.predicates...).IDs(ctx)
|
||||
default:
|
||||
return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
|
||||
}
|
||||
}
|
||||
|
||||
// SetVersion sets the "version" field.
|
||||
func (m *UserMutation) SetVersion(i int) {
|
||||
m.version = &i
|
||||
@@ -701,10 +740,10 @@ func (m *UserMutation) Version() (r int, exists bool) {
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UserMutation) OldVersion(ctx context.Context) (v int, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, fmt.Errorf("OldVersion is only allowed on UpdateOne operations")
|
||||
return v, errors.New("OldVersion is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, fmt.Errorf("OldVersion requires an ID field in the mutation")
|
||||
return v, errors.New("OldVersion requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
@@ -756,10 +795,10 @@ func (m *UserMutation) Name() (r string, exists bool) {
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UserMutation) OldName(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, fmt.Errorf("OldName is only allowed on UpdateOne operations")
|
||||
return v, errors.New("OldName is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, fmt.Errorf("OldName requires an ID field in the mutation")
|
||||
return v, errors.New("OldName requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
@@ -793,10 +832,10 @@ func (m *UserMutation) Worth() (r uint, exists bool) {
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UserMutation) OldWorth(ctx context.Context) (v uint, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, fmt.Errorf("OldWorth is only allowed on UpdateOne operations")
|
||||
return v, errors.New("OldWorth is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, fmt.Errorf("OldWorth requires an ID field in the mutation")
|
||||
return v, errors.New("OldWorth requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -152,6 +152,37 @@ func TestDeletion(t *testing.T) {
|
||||
require.Zero(t, client.Card.Query().CountX(ctx))
|
||||
}
|
||||
|
||||
func TestMutationIDs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
count := make(map[ent.Op]int)
|
||||
client.User.Use(
|
||||
hook.Unless(
|
||||
func(next ent.Mutator) ent.Mutator {
|
||||
return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
|
||||
count[m.Op()]++
|
||||
ids, err := m.IDs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, ids, 1)
|
||||
require.Equal(t, count[m.Op()], ids[0])
|
||||
return next.Mutate(ctx, m)
|
||||
})
|
||||
},
|
||||
ent.OpCreate,
|
||||
),
|
||||
)
|
||||
for i := 0; i < 5; i++ {
|
||||
owner := client.User.Create().SetName(fmt.Sprintf("owner-%d", i)).SaveX(ctx)
|
||||
client.Card.Create().SetNumber(fmt.Sprintf("card-%d", i)).SetOwner(owner).ExecX(ctx)
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
p := user.And(user.Name(fmt.Sprintf("owner-%d", i)), user.HasCardsWith(card.Number(fmt.Sprintf("card-%d", i))))
|
||||
client.User.Update().AddVersion(1).Where(p).ExecX(ctx)
|
||||
client.User.Delete().Where(p).ExecX(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostCreation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
|
||||
Reference in New Issue
Block a user