mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
entc/gen: attach tx hooks to underlying driver (#2980)
Allow attaching hooks to new instances of ent.Tx. For example, ent.Mutation.Tx().OnCommit.
This commit is contained in:
@@ -24,12 +24,6 @@ type Tx struct {
|
||||
// lazily loaded.
|
||||
client *Client
|
||||
clientOnce sync.Once
|
||||
|
||||
// completion callbacks.
|
||||
mu sync.Mutex
|
||||
onCommit []CommitHook
|
||||
onRollback []RollbackHook
|
||||
|
||||
// ctx lives for the life of the transaction. It is
|
||||
// the same context used by the underlying connection.
|
||||
ctx context.Context
|
||||
@@ -74,9 +68,9 @@ func (tx *Tx) Commit() error {
|
||||
var fn Committer = CommitFunc(func(context.Context, *Tx) error {
|
||||
return txDriver.tx.Commit()
|
||||
})
|
||||
tx.mu.Lock()
|
||||
hooks := append([]CommitHook(nil), tx.onCommit...)
|
||||
tx.mu.Unlock()
|
||||
txDriver.mu.Lock()
|
||||
hooks := append([]CommitHook(nil), txDriver.onCommit...)
|
||||
txDriver.mu.Unlock()
|
||||
for i := len(hooks) - 1; i >= 0; i-- {
|
||||
fn = hooks[i](fn)
|
||||
}
|
||||
@@ -85,9 +79,10 @@ func (tx *Tx) Commit() error {
|
||||
|
||||
// OnCommit adds a hook to call on commit.
|
||||
func (tx *Tx) OnCommit(f CommitHook) {
|
||||
tx.mu.Lock()
|
||||
defer tx.mu.Unlock()
|
||||
tx.onCommit = append(tx.onCommit, f)
|
||||
txDriver := tx.config.driver.(*txDriver)
|
||||
txDriver.mu.Lock()
|
||||
txDriver.onCommit = append(txDriver.onCommit, f)
|
||||
txDriver.mu.Unlock()
|
||||
}
|
||||
|
||||
type (
|
||||
@@ -129,9 +124,9 @@ func (tx *Tx) Rollback() error {
|
||||
var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error {
|
||||
return txDriver.tx.Rollback()
|
||||
})
|
||||
tx.mu.Lock()
|
||||
hooks := append([]RollbackHook(nil), tx.onRollback...)
|
||||
tx.mu.Unlock()
|
||||
txDriver.mu.Lock()
|
||||
hooks := append([]RollbackHook(nil), txDriver.onRollback...)
|
||||
txDriver.mu.Unlock()
|
||||
for i := len(hooks) - 1; i >= 0; i-- {
|
||||
fn = hooks[i](fn)
|
||||
}
|
||||
@@ -140,9 +135,10 @@ func (tx *Tx) Rollback() error {
|
||||
|
||||
// OnRollback adds a hook to call on rollback.
|
||||
func (tx *Tx) OnRollback(f RollbackHook) {
|
||||
tx.mu.Lock()
|
||||
defer tx.mu.Unlock()
|
||||
tx.onRollback = append(tx.onRollback, f)
|
||||
txDriver := tx.config.driver.(*txDriver)
|
||||
txDriver.mu.Lock()
|
||||
txDriver.onRollback = append(txDriver.onRollback, f)
|
||||
txDriver.mu.Unlock()
|
||||
}
|
||||
|
||||
// Client returns a Client that binds to current transaction.
|
||||
@@ -175,6 +171,10 @@ type txDriver struct {
|
||||
drv dialect.Driver
|
||||
// tx is the underlying transaction.
|
||||
tx dialect.Tx
|
||||
// completion hooks.
|
||||
mu sync.Mutex
|
||||
onCommit []CommitHook
|
||||
onRollback []RollbackHook
|
||||
}
|
||||
|
||||
// newTx creates a new transactional driver.
|
||||
|
||||
@@ -394,3 +394,32 @@ func TestConditions(t *testing.T) {
|
||||
client.User.Update().Where(user.ID(alexsn.ID)).AddWorth(100).SaveX(ctx)
|
||||
client.User.DeleteOne(alexsn).ExecX(ctx)
|
||||
}
|
||||
|
||||
func TestRuntimeTx(t *testing.T) {
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1", enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)))
|
||||
defer client.Close()
|
||||
client.Card.Use(func(next ent.Mutator) ent.Mutator {
|
||||
return hook.CardFunc(func(ctx context.Context, m *ent.CardMutation) (ent.Value, error) {
|
||||
v, err := next.Mutate(ctx, m)
|
||||
require.NoError(t, err)
|
||||
tx, err := m.Tx()
|
||||
require.NoError(t, err)
|
||||
tx.OnCommit(func(next ent.Committer) ent.Committer {
|
||||
return ent.CommitFunc(func(ctx context.Context, tx *ent.Tx) error {
|
||||
// Ensure the transaction can see the created card.
|
||||
tx.Card.GetX(ctx, v.(*ent.Card).ID)
|
||||
// Cause the transaction to fail.
|
||||
require.NoError(t, tx.Rollback())
|
||||
return fmt.Errorf("fail")
|
||||
})
|
||||
})
|
||||
return v, nil
|
||||
})
|
||||
})
|
||||
ctx := context.Background()
|
||||
tx, err := client.Tx(ctx)
|
||||
require.NoError(t, err)
|
||||
tx.Card.Create().SetNumber("9876").ExecX(ctx)
|
||||
require.EqualError(t, tx.Commit(), "fail")
|
||||
require.Zero(t, client.Card.Query().CountX(ctx), "database is empty")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user