mirror of
https://github.com/ent/ent.git
synced 2026-03-05 19:35:23 +03:00
dialect/sql: reset session variables when query is done (#4364)
This commit is contained in:
@@ -33,12 +33,12 @@ func Open(dialect, source string) (*Driver, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewDriver(dialect, Conn{db}), nil
|
||||
return NewDriver(dialect, Conn{db, dialect}), nil
|
||||
}
|
||||
|
||||
// OpenDB wraps the given database/sql.DB method with a Driver.
|
||||
func OpenDB(dialect string, db *sql.DB) *Driver {
|
||||
return NewDriver(dialect, Conn{db})
|
||||
return NewDriver(dialect, Conn{db, dialect})
|
||||
}
|
||||
|
||||
// DB returns the underlying *sql.DB instance.
|
||||
@@ -69,7 +69,7 @@ func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, erro
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{
|
||||
Conn: Conn{tx},
|
||||
Conn: Conn{tx, d.dialect},
|
||||
Tx: tx,
|
||||
}, nil
|
||||
}
|
||||
@@ -103,6 +103,17 @@ func WithVar(ctx context.Context, name, value string) context.Context {
|
||||
return context.WithValue(ctx, ctxVarsKey{}, sv)
|
||||
}
|
||||
|
||||
// VarFromContext returns the session variable value from the context.
|
||||
func VarFromContext(ctx context.Context, name string) (string, bool) {
|
||||
sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
|
||||
for _, s := range sv.vars {
|
||||
if s.k == name {
|
||||
return s.v, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// WithIntVar calls WithVar with the string representation of the value.
|
||||
func WithIntVar(ctx context.Context, name string, value int) context.Context {
|
||||
return WithVar(ctx, name, strconv.Itoa(value))
|
||||
@@ -117,6 +128,7 @@ type ExecQuerier interface {
|
||||
// Conn implements dialect.ExecQuerier given ExecQuerier.
|
||||
type Conn struct {
|
||||
ExecQuerier
|
||||
dialect string
|
||||
}
|
||||
|
||||
// Exec implements the dialect.Exec method.
|
||||
@@ -184,8 +196,10 @@ func (c Conn) maySetVars(ctx context.Context) (ExecQuerier, func() error, error)
|
||||
return c, nil, nil
|
||||
}
|
||||
var (
|
||||
ex ExecQuerier // Underlying ExecQuerier.
|
||||
cf func() error // Close function.
|
||||
ex ExecQuerier // Underlying ExecQuerier.
|
||||
cf func() error // Close function.
|
||||
reset []string // Reset variables.
|
||||
seen = make(map[string]struct{}, len(sv.vars))
|
||||
)
|
||||
switch e := c.ExecQuerier.(type) {
|
||||
case *sql.Tx:
|
||||
@@ -198,6 +212,15 @@ func (c Conn) maySetVars(ctx context.Context) (ExecQuerier, func() error, error)
|
||||
ex, cf = conn, conn.Close
|
||||
}
|
||||
for _, s := range sv.vars {
|
||||
if _, ok := seen[s.k]; !ok {
|
||||
switch c.dialect {
|
||||
case dialect.Postgres:
|
||||
reset = append(reset, fmt.Sprintf("RESET %s", s.k))
|
||||
case dialect.MySQL:
|
||||
reset = append(reset, fmt.Sprintf("SET %s = NULL", s.k))
|
||||
}
|
||||
seen[s.k] = struct{}{}
|
||||
}
|
||||
if _, err := ex.ExecContext(ctx, fmt.Sprintf("SET %s = '%s'", s.k, s.v)); err != nil {
|
||||
if cf != nil {
|
||||
err = errors.Join(err, cf())
|
||||
@@ -205,6 +228,18 @@ func (c Conn) maySetVars(ctx context.Context) (ExecQuerier, func() error, error)
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
// If there are variables to reset, and we need to return the
|
||||
// connection to the pool, we need to clean up the variables.
|
||||
if cls := cf; cf != nil && len(reset) > 0 {
|
||||
cf = func() error {
|
||||
for _, q := range reset {
|
||||
if _, err := ex.ExecContext(ctx, q); err != nil {
|
||||
return errors.Join(err, cls())
|
||||
}
|
||||
}
|
||||
return cls()
|
||||
}
|
||||
}
|
||||
return ex, cf, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -16,9 +18,10 @@ func TestWithVars(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
db.SetMaxOpenConns(1)
|
||||
drv := OpenDB("sqlite3", db)
|
||||
drv := OpenDB(dialect.Postgres, db)
|
||||
mock.ExpectExec("SET foo = 'bar'").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery("SELECT 1").WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||
mock.ExpectExec("RESET foo").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
rows := &Rows{}
|
||||
err = drv.Query(
|
||||
WithVar(context.Background(), "foo", "bar"),
|
||||
@@ -27,12 +30,13 @@ func TestWithVars(t *testing.T) {
|
||||
rows,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
require.NoError(t, rows.Close(), "rows should be closed to release the connection")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
|
||||
mock.ExpectExec("SET foo = 'bar'").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("SET foo = 'baz'").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery("SELECT 1").WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||
mock.ExpectExec("RESET foo").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
err = drv.Query(
|
||||
WithVar(WithVar(context.Background(), "foo", "bar"), "foo", "baz"),
|
||||
"SELECT 1",
|
||||
@@ -40,8 +44,8 @@ func TestWithVars(t *testing.T) {
|
||||
rows,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
require.NoError(t, rows.Close(), "rows should be closed to release the connection")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("SET foo = 'bar'").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
@@ -63,6 +67,7 @@ func TestWithVars(t *testing.T) {
|
||||
|
||||
mock.ExpectExec("SET foo = 'qux'").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("INSERT INTO users DEFAULT VALUES").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("RESET foo").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
err = drv.Exec(
|
||||
WithVar(context.Background(), "foo", "qux"),
|
||||
"INSERT INTO users DEFAULT VALUES",
|
||||
@@ -75,6 +80,7 @@ func TestWithVars(t *testing.T) {
|
||||
|
||||
mock.ExpectExec("SET foo = 'foo'").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("INSERT INTO users DEFAULT VALUES").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("RESET foo").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
err = drv.Exec(
|
||||
WithVar(context.Background(), "foo", "foo"),
|
||||
"INSERT INTO users DEFAULT VALUES",
|
||||
|
||||
Reference in New Issue
Block a user