mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
doc: explain how to use policies in migrations (#4141)
This commit is contained in:
@@ -8,7 +8,9 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
@@ -81,6 +83,31 @@ type Tx struct {
|
||||
driver.Tx
|
||||
}
|
||||
|
||||
// ctyVarsKey is the key used for attaching and reading the context variables.
|
||||
type ctxVarsKey struct{}
|
||||
|
||||
// sessionVars holds sessions/transactions variables to set before every statement.
|
||||
type sessionVars struct {
|
||||
vars []struct{ k, v string }
|
||||
}
|
||||
|
||||
// WithVar returns a new context that holds the session variable to be executed before every query.
|
||||
func WithVar(ctx context.Context, name, value string) context.Context {
|
||||
sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
|
||||
sv.vars = append(sv.vars, struct {
|
||||
k, v string
|
||||
}{
|
||||
k: name,
|
||||
v: value,
|
||||
})
|
||||
return context.WithValue(ctx, ctxVarsKey{}, sv)
|
||||
}
|
||||
|
||||
// WithIntVar calls WithVar with the string representation of the value.
|
||||
func WithIntVar(ctx context.Context, name string, value int) context.Context {
|
||||
return WithVar(ctx, name, strconv.Itoa(value))
|
||||
}
|
||||
|
||||
// ExecQuerier wraps the standard Exec and Query methods.
|
||||
type ExecQuerier interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
@@ -93,18 +120,25 @@ type Conn struct {
|
||||
}
|
||||
|
||||
// Exec implements the dialect.Exec method.
|
||||
func (c Conn) Exec(ctx context.Context, query string, args, v any) error {
|
||||
func (c Conn) Exec(ctx context.Context, query string, args, v any) (rerr error) {
|
||||
argv, ok := args.([]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", v)
|
||||
}
|
||||
ex, cf, err := c.maySetVars(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cf != nil {
|
||||
defer func() { rerr = errors.Join(rerr, cf()) }()
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case nil:
|
||||
if _, err := c.ExecContext(ctx, query, argv...); err != nil {
|
||||
if _, err := ex.ExecContext(ctx, query, argv...); err != nil {
|
||||
return err
|
||||
}
|
||||
case *sql.Result:
|
||||
res, err := c.ExecContext(ctx, query, argv...)
|
||||
res, err := ex.ExecContext(ctx, query, argv...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -125,14 +159,55 @@ func (c Conn) Query(ctx context.Context, query string, args, v any) error {
|
||||
if !ok {
|
||||
return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", args)
|
||||
}
|
||||
rows, err := c.QueryContext(ctx, query, argv...)
|
||||
ex, cf, err := c.maySetVars(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, err := ex.QueryContext(ctx, query, argv...)
|
||||
if err != nil {
|
||||
if cf != nil {
|
||||
err = errors.Join(err, cf())
|
||||
}
|
||||
return err
|
||||
}
|
||||
*vr = Rows{rows}
|
||||
if cf != nil {
|
||||
vr.ColumnScanner = rowsWithCloser{rows, cf}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// maySetVars sets the session variables before executing a query.
|
||||
func (c Conn) maySetVars(ctx context.Context) (ExecQuerier, func() error, error) {
|
||||
sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
|
||||
if len(sv.vars) == 0 {
|
||||
return c, nil, nil
|
||||
}
|
||||
var (
|
||||
ex ExecQuerier // Underlying ExecQuerier.
|
||||
cf func() error // Close function.
|
||||
)
|
||||
switch e := c.ExecQuerier.(type) {
|
||||
case *sql.Tx:
|
||||
ex = e
|
||||
case *sql.DB:
|
||||
conn, err := e.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
ex, cf = conn, conn.Close
|
||||
}
|
||||
for _, s := range sv.vars {
|
||||
if _, err := ex.ExecContext(ctx, fmt.Sprintf("SET %s = '%s'", s.k, s.v)); err != nil {
|
||||
if cf != nil {
|
||||
err = errors.Join(err, cf())
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
return ex, cf, nil
|
||||
}
|
||||
|
||||
var _ dialect.Driver = (*Driver)(nil)
|
||||
|
||||
type (
|
||||
@@ -154,9 +229,8 @@ type (
|
||||
TxOptions = sql.TxOptions
|
||||
)
|
||||
|
||||
// NullScanner represents an sql.Scanner that may be null.
|
||||
// NullScanner implements the sql.Scanner interface so it can
|
||||
// be used as a scan destination, similar to the types above.
|
||||
// NullScanner implements the sql.Scanner interface such that it
|
||||
// can be used as a scan destination, similar to the types above.
|
||||
type NullScanner struct {
|
||||
S sql.Scanner
|
||||
Valid bool // Valid is true if the Scan value is not NULL.
|
||||
@@ -182,3 +256,15 @@ type ColumnScanner interface {
|
||||
NextResultSet() bool
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
// rowsWithCloser wraps the ColumnScanner interface with a custom Close hook.
|
||||
type rowsWithCloser struct {
|
||||
ColumnScanner
|
||||
closer func() error
|
||||
}
|
||||
|
||||
// Close closes the underlying ColumnScanner and calls the custom closer.
|
||||
func (r rowsWithCloser) Close() error {
|
||||
err := r.ColumnScanner.Close()
|
||||
return errors.Join(err, r.closer())
|
||||
}
|
||||
|
||||
87
dialect/sql/driver_test.go
Normal file
87
dialect/sql/driver_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWithVars(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
db.SetMaxOpenConns(1)
|
||||
drv := OpenDB("sqlite3", db)
|
||||
mock.ExpectExec("SET foo = 'bar'").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery("SELECT 1").WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||
rows := &Rows{}
|
||||
err = drv.Query(
|
||||
WithVar(context.Background(), "foo", "bar"),
|
||||
"SELECT 1",
|
||||
[]any{},
|
||||
rows,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
require.NoError(t, rows.Close(), "rows should be closed to release the connection")
|
||||
|
||||
mock.ExpectExec("SET foo = 'bar'").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("SET foo = 'baz'").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery("SELECT 1").WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||
err = drv.Query(
|
||||
WithVar(WithVar(context.Background(), "foo", "bar"), "foo", "baz"),
|
||||
"SELECT 1",
|
||||
[]any{},
|
||||
rows,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
require.NoError(t, rows.Close(), "rows should be closed to release the connection")
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("SET foo = 'bar'").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery("SELECT 1").WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||
mock.ExpectCommit()
|
||||
tx, err := drv.Tx(context.Background())
|
||||
require.NoError(t, err)
|
||||
err = tx.Query(
|
||||
WithVar(context.Background(), "foo", "bar"),
|
||||
"SELECT 1",
|
||||
[]any{},
|
||||
rows,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tx.Commit())
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
// Rows should not be closed to release the session,
|
||||
// as a transaction is always scoped to a single connection.
|
||||
|
||||
mock.ExpectExec("SET foo = 'qux'").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("INSERT INTO users DEFAULT VALUES").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
err = drv.Exec(
|
||||
WithVar(context.Background(), "foo", "qux"),
|
||||
"INSERT INTO users DEFAULT VALUES",
|
||||
[]any{},
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
// No rows are returned, so no need to close them.
|
||||
|
||||
mock.ExpectExec("SET foo = 'foo'").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("INSERT INTO users DEFAULT VALUES").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
err = drv.Exec(
|
||||
WithVar(context.Background(), "foo", "foo"),
|
||||
"INSERT INTO users DEFAULT VALUES",
|
||||
[]any{},
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
// No rows are returned, so no need to close them.
|
||||
}
|
||||
Reference in New Issue
Block a user