doc: explain how to use policies in migrations (#4141)

This commit is contained in:
Ariel Mashraki
2024-07-14 14:53:22 +03:00
committed by GitHub
parent 1073ce511e
commit 9f61938bcc
37 changed files with 5768 additions and 7 deletions

View File

@@ -8,7 +8,9 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"strconv"
"strings"
"entgo.io/ent/dialect"
@@ -81,6 +83,31 @@ type Tx struct {
driver.Tx
}
// ctyVarsKey is the key used for attaching and reading the context variables.
type ctxVarsKey struct{}
// sessionVars holds sessions/transactions variables to set before every statement.
type sessionVars struct {
vars []struct{ k, v string }
}
// WithVar returns a new context that holds the session variable to be executed before every query.
func WithVar(ctx context.Context, name, value string) context.Context {
sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
sv.vars = append(sv.vars, struct {
k, v string
}{
k: name,
v: value,
})
return context.WithValue(ctx, ctxVarsKey{}, sv)
}
// WithIntVar calls WithVar with the string representation of the value.
func WithIntVar(ctx context.Context, name string, value int) context.Context {
return WithVar(ctx, name, strconv.Itoa(value))
}
// ExecQuerier wraps the standard Exec and Query methods.
type ExecQuerier interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
@@ -93,18 +120,25 @@ type Conn struct {
}
// Exec implements the dialect.Exec method.
func (c Conn) Exec(ctx context.Context, query string, args, v any) error {
func (c Conn) Exec(ctx context.Context, query string, args, v any) (rerr error) {
argv, ok := args.([]any)
if !ok {
return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", v)
}
ex, cf, err := c.maySetVars(ctx)
if err != nil {
return err
}
if cf != nil {
defer func() { rerr = errors.Join(rerr, cf()) }()
}
switch v := v.(type) {
case nil:
if _, err := c.ExecContext(ctx, query, argv...); err != nil {
if _, err := ex.ExecContext(ctx, query, argv...); err != nil {
return err
}
case *sql.Result:
res, err := c.ExecContext(ctx, query, argv...)
res, err := ex.ExecContext(ctx, query, argv...)
if err != nil {
return err
}
@@ -125,14 +159,55 @@ func (c Conn) Query(ctx context.Context, query string, args, v any) error {
if !ok {
return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", args)
}
rows, err := c.QueryContext(ctx, query, argv...)
ex, cf, err := c.maySetVars(ctx)
if err != nil {
return err
}
rows, err := ex.QueryContext(ctx, query, argv...)
if err != nil {
if cf != nil {
err = errors.Join(err, cf())
}
return err
}
*vr = Rows{rows}
if cf != nil {
vr.ColumnScanner = rowsWithCloser{rows, cf}
}
return nil
}
// maySetVars sets the session variables before executing a query.
func (c Conn) maySetVars(ctx context.Context) (ExecQuerier, func() error, error) {
sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
if len(sv.vars) == 0 {
return c, nil, nil
}
var (
ex ExecQuerier // Underlying ExecQuerier.
cf func() error // Close function.
)
switch e := c.ExecQuerier.(type) {
case *sql.Tx:
ex = e
case *sql.DB:
conn, err := e.Conn(ctx)
if err != nil {
return nil, nil, err
}
ex, cf = conn, conn.Close
}
for _, s := range sv.vars {
if _, err := ex.ExecContext(ctx, fmt.Sprintf("SET %s = '%s'", s.k, s.v)); err != nil {
if cf != nil {
err = errors.Join(err, cf())
}
return nil, nil, err
}
}
return ex, cf, nil
}
var _ dialect.Driver = (*Driver)(nil)
type (
@@ -154,9 +229,8 @@ type (
TxOptions = sql.TxOptions
)
// NullScanner represents an sql.Scanner that may be null.
// NullScanner implements the sql.Scanner interface so it can
// be used as a scan destination, similar to the types above.
// NullScanner implements the sql.Scanner interface such that it
// can be used as a scan destination, similar to the types above.
type NullScanner struct {
S sql.Scanner
Valid bool // Valid is true if the Scan value is not NULL.
@@ -182,3 +256,15 @@ type ColumnScanner interface {
NextResultSet() bool
Scan(dest ...any) error
}
// rowsWithCloser wraps the ColumnScanner interface with a custom Close hook.
type rowsWithCloser struct {
ColumnScanner
closer func() error
}
// Close closes the underlying ColumnScanner and calls the custom closer.
func (r rowsWithCloser) Close() error {
err := r.ColumnScanner.Close()
return errors.Join(err, r.closer())
}