mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
entc/gen: add support for tx hooks (#575)
This commit is contained in:
@@ -67,7 +67,8 @@ func Open(driverName, dataSourceName string, options ...Option) (*Client, error)
|
||||
}
|
||||
}
|
||||
|
||||
// Tx returns a new transactional client.
|
||||
// Tx returns a new transactional client. The provided context
|
||||
// is used until the transaction is committed or rolled back.
|
||||
func (c *Client) Tx(ctx context.Context) (*Tx, error) {
|
||||
if _, ok := c.driver.(*txDriver); ok {
|
||||
return nil, fmt.Errorf("ent: cannot start a transaction within a transaction")
|
||||
@@ -78,6 +79,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
|
||||
}
|
||||
cfg := config{driver: tx, log: c.log, debug: c.debug, hooks: c.hooks}
|
||||
return &Tx{
|
||||
ctx: ctx,
|
||||
config: cfg,
|
||||
Car: NewCarClient(cfg),
|
||||
Group: NewGroupClient(cfg),
|
||||
|
||||
@@ -29,41 +29,119 @@ type Tx struct {
|
||||
|
||||
// completion callbacks.
|
||||
mu sync.Mutex
|
||||
onCommit []func(error)
|
||||
onRollback []func(error)
|
||||
onCommit []CommitHook
|
||||
onRollback []RollbackHook
|
||||
|
||||
// ctx lives for the life of the transaction. It is
|
||||
// the same context used by the underlying connection.
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
type (
|
||||
// Committer is the interface that wraps the Committer method.
|
||||
Committer interface {
|
||||
Commit(context.Context, *Tx) error
|
||||
}
|
||||
|
||||
// The CommitFunc type is an adapter to allow the use of ordinary
|
||||
// function as a Committer. If f is a function with the appropriate
|
||||
// signature, CommitFunc(f) is a Committer that calls f.
|
||||
CommitFunc func(context.Context, *Tx) error
|
||||
|
||||
// CommitHook defines the "commit middleware". A function that gets a Committer
|
||||
// and returns a Committer. For example:
|
||||
//
|
||||
// hook := func(next ent.Committer) ent.Committer {
|
||||
// return ent.CommitFunc(func(context.Context, tx *ent.Tx) error {
|
||||
// // Do some stuff before.
|
||||
// if err := next.Commit(ctx, tx); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// // Do some stuff after.
|
||||
// return nil
|
||||
// })
|
||||
// }
|
||||
//
|
||||
CommitHook func(Committer) Committer
|
||||
)
|
||||
|
||||
// Commit calls f(ctx, m).
|
||||
func (f CommitFunc) Commit(ctx context.Context, tx *Tx) error {
|
||||
return f(ctx, tx)
|
||||
}
|
||||
|
||||
// Commit commits the transaction.
|
||||
func (tx *Tx) Commit() error {
|
||||
err := tx.config.driver.(*txDriver).tx.Commit()
|
||||
txDriver := tx.config.driver.(*txDriver)
|
||||
var fn Committer = CommitFunc(func(context.Context, *Tx) error {
|
||||
return txDriver.tx.Commit()
|
||||
})
|
||||
tx.mu.Lock()
|
||||
defer tx.mu.Unlock()
|
||||
for _, f := range tx.onCommit {
|
||||
f(err)
|
||||
hooks := append([]CommitHook(nil), tx.onCommit...)
|
||||
tx.mu.Unlock()
|
||||
for i := len(hooks) - 1; i >= 0; i-- {
|
||||
fn = hooks[i](fn)
|
||||
}
|
||||
return err
|
||||
return fn.Commit(tx.ctx, tx)
|
||||
}
|
||||
|
||||
// OnCommit adds a function to call on commit.
|
||||
func (tx *Tx) OnCommit(f func(error)) {
|
||||
// OnCommit adds a hook to call on commit.
|
||||
func (tx *Tx) OnCommit(f CommitHook) {
|
||||
tx.mu.Lock()
|
||||
defer tx.mu.Unlock()
|
||||
tx.onCommit = append(tx.onCommit, f)
|
||||
}
|
||||
|
||||
// Rollback rollbacks the transaction.
|
||||
func (tx *Tx) Rollback() error {
|
||||
err := tx.config.driver.(*txDriver).tx.Rollback()
|
||||
tx.mu.Lock()
|
||||
defer tx.mu.Unlock()
|
||||
for _, f := range tx.onRollback {
|
||||
f(err)
|
||||
type (
|
||||
// Rollbacker is the interface that wraps the Rollbacker method.
|
||||
Rollbacker interface {
|
||||
Rollback(context.Context, *Tx) error
|
||||
}
|
||||
return err
|
||||
|
||||
// The RollbackFunc type is an adapter to allow the use of ordinary
|
||||
// function as a Rollbacker. If f is a function with the appropriate
|
||||
// signature, RollbackFunc(f) is a Rollbacker that calls f.
|
||||
RollbackFunc func(context.Context, *Tx) error
|
||||
|
||||
// RollbackHook defines the "rollback middleware". A function that gets a Rollbacker
|
||||
// and returns a Rollbacker. For example:
|
||||
//
|
||||
// hook := func(next ent.Rollbacker) ent.Rollbacker {
|
||||
// return ent.RollbackFunc(func(context.Context, tx *ent.Tx) error {
|
||||
// // Do some stuff before.
|
||||
// if err := next.Rollback(ctx, tx); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// // Do some stuff after.
|
||||
// return nil
|
||||
// })
|
||||
// }
|
||||
//
|
||||
RollbackHook func(Rollbacker) Rollbacker
|
||||
)
|
||||
|
||||
// Rollback calls f(ctx, m).
|
||||
func (f RollbackFunc) Rollback(ctx context.Context, tx *Tx) error {
|
||||
return f(ctx, tx)
|
||||
}
|
||||
|
||||
// OnRollback adds a function to call on rollback.
|
||||
func (tx *Tx) OnRollback(f func(error)) {
|
||||
// Rollback rollbacks the transaction.
|
||||
func (tx *Tx) Rollback() error {
|
||||
txDriver := tx.config.driver.(*txDriver)
|
||||
var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error {
|
||||
return txDriver.tx.Rollback()
|
||||
})
|
||||
tx.mu.Lock()
|
||||
hooks := append([]RollbackHook(nil), tx.onRollback...)
|
||||
tx.mu.Unlock()
|
||||
for i := len(hooks) - 1; i >= 0; i-- {
|
||||
fn = hooks[i](fn)
|
||||
}
|
||||
return fn.Rollback(tx.ctx, tx)
|
||||
}
|
||||
|
||||
// OnRollback adds a hook to call on rollback.
|
||||
func (tx *Tx) OnRollback(f RollbackHook) {
|
||||
tx.mu.Lock()
|
||||
defer tx.mu.Unlock()
|
||||
tx.onRollback = append(tx.onRollback, f)
|
||||
|
||||
Reference in New Issue
Block a user