{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}

{{/* gotype: entgo.io/ent/entc/gen.Graph */}}

{{ define "tx" }}

{{ template "header" $ }}

import (
	"context"
	"sync"

	"entgo.io/ent/dialect"
)

// Tx is a transactional client that is created by calling Client.Tx().
// It exposes one typed client per schema node, all sharing the same config.
type Tx struct {
	config
	{{- range $n := $.Nodes }}
		// {{ $n.Name }} is the client for interacting with the {{ $n.Name }} builders.
		{{ $n.Name }} *{{ $n.Name }}Client
	{{- end }}

	// lazily loaded.
	client *Client
	clientOnce sync.Once
	// ctx lives for the life of the transaction. It is
	// the same context used by the underlying connection.
	ctx context.Context
}

{{/*
Commit and Rollback have the same shape, so their interface, adapter, hook type
and Tx methods are generated once per entry of this method-name to
interface-name table.
*/}}
{{ $funcs := dict "Commit" "Committer" "Rollback" "Rollbacker" }}
{{ range $func := keys $funcs }}
	{{ $iface := get $funcs $func }}
	type (
		// {{ $iface }} is the interface that wraps the {{ $func }} method.
		{{ $iface }} interface {
			{{ $func }}(context.Context, *Tx) error
		}

		// The {{ $func }}Func type is an adapter to allow the use of an ordinary
		// function as a {{ $iface }}. If f is a function with the appropriate
		// signature, {{ $func }}Func(f) is a {{ $iface }} that calls f.
		{{ $func }}Func func(context.Context, *Tx) error

		// {{ $func }}Hook defines the "{{ lower $func }} middleware". A function that gets a {{ $iface }}
		// and returns a {{ $iface }}. For example:
		//
		//	hook := func(next ent.{{ $iface }}) ent.{{ $iface }} {
		//		return ent.{{ $func }}Func(func(ctx context.Context, tx *ent.Tx) error {
		//			// Do some stuff before.
		//			if err := next.{{ $func }}(ctx, tx); err != nil {
		//				return err
		//			}
		//			// Do some stuff after.
		//			return nil
		//		})
		//	}
		//
		{{ $func }}Hook func({{ $iface }}) {{ $iface }}
	)

	// {{ $func }} calls f(ctx, tx).
	func (f {{ $func }}Func) {{ $func }}(ctx context.Context, tx *Tx) error {
		return f(ctx, tx)
	}

	{{/* Name of the txDriver field holding the hooks for this method: onCommit or onRollback. */}}
	{{- $onFuncs := print "on" $func }}

	// {{ $func }} {{ lower $func }}s the transaction.
	//
	// The hooks registered with On{{ $func }} wrap the {{ lower $func }} of the underlying
	// transaction; the hook that was registered first is the outermost one.
	func (tx *Tx) {{ $func }}() error {
		txDriver := tx.config.driver.(*txDriver)
		var fn {{ $iface }} = {{ $func }}Func(func(context.Context, *Tx) error {
			return txDriver.tx.{{ $func }}()
		})
		// Copy the hooks while holding the lock; they run after it is released.
		txDriver.mu.Lock()
		hooks := append([]{{ $func }}Hook(nil), txDriver.{{ $onFuncs }}...)
		txDriver.mu.Unlock()
		// Wrap from last to first so that hooks[0] ends up outermost.
		for i := len(hooks) - 1; i >= 0; i-- {
			fn = hooks[i](fn)
		}
		return fn.{{ $func }}(tx.ctx, tx)
	}

	// On{{ $func }} adds a hook to call on {{ lower $func }}.
	func (tx *Tx) On{{ $func }}(f {{ $func }}Hook) {
		txDriver := tx.config.driver.(*txDriver)
		txDriver.mu.Lock()
		txDriver.{{ $onFuncs }} = append(txDriver.{{ $onFuncs }}, f)
		txDriver.mu.Unlock()
	}
{{- end }}

// Client returns a Client that binds to current transaction.
// The Client is created on the first call and reused afterwards.
func (tx *Tx) Client() *Client {
	tx.clientOnce.Do(func() {
		tx.client = &Client{config: tx.config}
		tx.client.init()
	})
	return tx.client
}

// init creates the typed client of every node from the transaction config.
func (tx *Tx) init() {
	{{- range $n := $.Nodes }}
		tx.{{ $n.Name }} = New{{ $n.Name }}Client(tx.config)
	{{- end }}
}

{{/* first node for doc example */}}
{{- $first := index $.Nodes 0 }}

// txDriver wraps the given dialect.Tx with a nop dialect.Driver implementation.
// The idea is to support transactions without adding any extra code to the builders.
// When a builder calls to driver.Tx(), it gets the same dialect.Tx instance.
// Commit and Rollback are nop for the internal builders and the user must call one
// of them in order to commit or rollback the transaction.
//
// If a closed transaction is embedded in one of the generated entities, and the entity
// applies a query, for example: {{ $first.Name }}.QueryXXX(), the query will be executed
// through the driver which created this transaction.
//
// Note that txDriver is not goroutine safe.
type txDriver struct {
	// the driver we started the transaction from.
	drv dialect.Driver
	// tx is the underlying transaction.
	tx dialect.Tx
	// completion hooks.
	mu sync.Mutex
	onCommit []CommitHook
	onRollback []RollbackHook
}

// newTx creates a new transactional driver.
// It starts a transaction on drv and returns the error of drv.Tx, if any.
func newTx(ctx context.Context, drv dialect.Driver) (*txDriver, error) {
	tx, err := drv.Tx(ctx)
	if err != nil {
		return nil, err
	}
	return &txDriver{tx: tx, drv: drv}, nil
}

// Tx returns the transaction wrapper (txDriver) to avoid Commit or Rollback calls
// from the internal builders. Should be called only by the internal builders.
func (tx *txDriver) Tx(context.Context) (dialect.Tx, error) { return tx, nil }

// Dialect returns the dialect of the driver we started the transaction from.
func (tx *txDriver) Dialect() string { return tx.drv.Dialect() }

// Close is a nop close.
func (*txDriver) Close() error { return nil }

// Commit is a nop commit for the internal builders.
// User must call `Tx.Commit` in order to commit the transaction.
func (*txDriver) Commit() error { return nil }

// Rollback is a nop rollback for the internal builders.
// User must call `Tx.Rollback` in order to rollback the transaction.
func (*txDriver) Rollback() error { return nil }

// Exec calls tx.Exec.
func (tx *txDriver) Exec(ctx context.Context, query string, args, v any) error {
	return tx.tx.Exec(ctx, query, args, v)
}

// Query calls tx.Query.
func (tx *txDriver) Query(ctx context.Context, query string, args, v any) error {
	return tx.tx.Query(ctx, query, args, v)
}

var _ dialect.Driver = (*txDriver)(nil)

{{/*
Extension point: each template returned by matchTemplate for the
tx/additional patterns below is executed with the graph as its data.
*/}}
{{- with $tmpls := matchTemplate "tx/additional/*" "tx/additional/*/*" }}
	{{- range $tmpl := $tmpls }}
		{{- xtemplate $tmpl $ }}
	{{- end }}
{{- end }}

{{ end }}