dialect/sql: reset session variables when query is done (#4364)

This commit is contained in:
Ariel Mashraki
2025-03-28 06:21:31 +03:00
committed by GitHub
parent 9db6f4df43
commit 95a3e5f970
2 changed files with 49 additions and 8 deletions

View File

@@ -33,12 +33,12 @@ func Open(dialect, source string) (*Driver, error) {
if err != nil {
return nil, err
}
return NewDriver(dialect, Conn{db}), nil
return NewDriver(dialect, Conn{db, dialect}), nil
}
// OpenDB wraps the given database/sql.DB method with a Driver.
func OpenDB(dialect string, db *sql.DB) *Driver {
return NewDriver(dialect, Conn{db})
return NewDriver(dialect, Conn{db, dialect})
}
// DB returns the underlying *sql.DB instance.
@@ -69,7 +69,7 @@ func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, erro
return nil, err
}
return &Tx{
Conn: Conn{tx},
Conn: Conn{tx, d.dialect},
Tx: tx,
}, nil
}
@@ -103,6 +103,17 @@ func WithVar(ctx context.Context, name, value string) context.Context {
return context.WithValue(ctx, ctxVarsKey{}, sv)
}
// VarFromContext returns the session variable value from the context.
func VarFromContext(ctx context.Context, name string) (string, bool) {
sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
for _, s := range sv.vars {
if s.k == name {
return s.v, true
}
}
return "", false
}
// WithIntVar calls WithVar with the string representation of the value.
func WithIntVar(ctx context.Context, name string, value int) context.Context {
return WithVar(ctx, name, strconv.Itoa(value))
@@ -117,6 +128,7 @@ type ExecQuerier interface {
// Conn implements dialect.ExecQuerier given ExecQuerier.
type Conn struct {
ExecQuerier
dialect string
}
// Exec implements the dialect.Exec method.
@@ -186,6 +198,8 @@ func (c Conn) maySetVars(ctx context.Context) (ExecQuerier, func() error, error)
var (
ex ExecQuerier // Underlying ExecQuerier.
cf func() error // Close function.
reset []string // Reset variables.
seen = make(map[string]struct{}, len(sv.vars))
)
switch e := c.ExecQuerier.(type) {
case *sql.Tx:
@@ -198,6 +212,15 @@ func (c Conn) maySetVars(ctx context.Context) (ExecQuerier, func() error, error)
ex, cf = conn, conn.Close
}
for _, s := range sv.vars {
if _, ok := seen[s.k]; !ok {
switch c.dialect {
case dialect.Postgres:
reset = append(reset, fmt.Sprintf("RESET %s", s.k))
case dialect.MySQL:
reset = append(reset, fmt.Sprintf("SET %s = NULL", s.k))
}
seen[s.k] = struct{}{}
}
if _, err := ex.ExecContext(ctx, fmt.Sprintf("SET %s = '%s'", s.k, s.v)); err != nil {
if cf != nil {
err = errors.Join(err, cf())
@@ -205,6 +228,18 @@ func (c Conn) maySetVars(ctx context.Context) (ExecQuerier, func() error, error)
return nil, nil, err
}
}
// If there are variables to reset, and we need to return the
// connection to the pool, we need to clean up the variables.
if cls := cf; cf != nil && len(reset) > 0 {
cf = func() error {
for _, q := range reset {
if _, err := ex.ExecContext(ctx, q); err != nil {
return errors.Join(err, cls())
}
}
return cls()
}
}
return ex, cf, nil
}

View File

@@ -8,6 +8,8 @@ import (
"context"
"testing"
"entgo.io/ent/dialect"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
@@ -16,9 +18,10 @@ func TestWithVars(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
db.SetMaxOpenConns(1)
drv := OpenDB("sqlite3", db)
drv := OpenDB(dialect.Postgres, db)
mock.ExpectExec("SET foo = 'bar'").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectQuery("SELECT 1").WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
mock.ExpectExec("RESET foo").WillReturnResult(sqlmock.NewResult(0, 0))
rows := &Rows{}
err = drv.Query(
WithVar(context.Background(), "foo", "bar"),
@@ -27,12 +30,13 @@ func TestWithVars(t *testing.T) {
rows,
)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
require.NoError(t, rows.Close(), "rows should be closed to release the connection")
require.NoError(t, mock.ExpectationsWereMet())
mock.ExpectExec("SET foo = 'bar'").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("SET foo = 'baz'").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectQuery("SELECT 1").WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
mock.ExpectExec("RESET foo").WillReturnResult(sqlmock.NewResult(0, 0))
err = drv.Query(
WithVar(WithVar(context.Background(), "foo", "bar"), "foo", "baz"),
"SELECT 1",
@@ -40,8 +44,8 @@ func TestWithVars(t *testing.T) {
rows,
)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
require.NoError(t, rows.Close(), "rows should be closed to release the connection")
require.NoError(t, mock.ExpectationsWereMet())
mock.ExpectBegin()
mock.ExpectExec("SET foo = 'bar'").WillReturnResult(sqlmock.NewResult(0, 0))
@@ -63,6 +67,7 @@ func TestWithVars(t *testing.T) {
mock.ExpectExec("SET foo = 'qux'").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO users DEFAULT VALUES").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("RESET foo").WillReturnResult(sqlmock.NewResult(0, 0))
err = drv.Exec(
WithVar(context.Background(), "foo", "qux"),
"INSERT INTO users DEFAULT VALUES",
@@ -75,6 +80,7 @@ func TestWithVars(t *testing.T) {
mock.ExpectExec("SET foo = 'foo'").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO users DEFAULT VALUES").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("RESET foo").WillReturnResult(sqlmock.NewResult(0, 0))
err = drv.Exec(
WithVar(context.Background(), "foo", "foo"),
"INSERT INTO users DEFAULT VALUES",