Files
ent/dialect/sql/driver.go

306 lines
7.9 KiB
Go

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package sql
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"strconv"
"strings"
"entgo.io/ent/dialect"
)
// Driver is a dialect.Driver implementation for SQL based databases.
// It embeds a Conn, so Exec and Query run against the wrapped ExecQuerier
// (a *sql.DB when the Driver is created with Open or OpenDB).
type Driver struct {
Conn
// dialect is the name the Driver was created with. It may be the name of a
// wrapping (e.g. telemetry) driver; Dialect() maps it to the base dialect.
dialect string
}
// NewDriver returns a Driver that runs statements through c and records the
// given dialect name (see Dialect for how that name is reported).
func NewDriver(dialect string, c Conn) *Driver {
	drv := new(Driver)
	drv.Conn = c
	drv.dialect = dialect
	return drv
}
// Open opens a database with database/sql.Open and wraps it in a Driver that
// implements the ent/dialect.Driver interface. The dialect name is also used
// as the database/sql driver name.
func Open(dialect, source string) (*Driver, error) {
	db, err := sql.Open(dialect, source)
	if err != nil {
		return nil, err
	}
	return OpenDB(dialect, db), nil
}
// OpenDB wraps an already opened *sql.DB in a Driver for the given dialect.
func OpenDB(dialect string, db *sql.DB) *Driver {
	conn := Conn{ExecQuerier: db, dialect: dialect}
	return NewDriver(dialect, conn)
}
// DB returns the underlying *sql.DB instance.
//
// It panics if the Driver's ExecQuerier is not a *sql.DB, for example when
// the Driver was built with NewDriver around a Conn holding another type.
func (d Driver) DB() *sql.DB {
	db := d.Conn.ExecQuerier.(*sql.DB)
	return db
}
// Dialect implements the dialect.Dialect method.
//
// The stored name may belong to a wrapping (e.g. telemetry) driver registered
// under a name that starts with one of the known dialects. Such names are
// mapped back to the base dialect, and any other name is returned unchanged.
func (d Driver) Dialect() string {
	switch {
	case strings.HasPrefix(d.dialect, dialect.MySQL):
		return dialect.MySQL
	case strings.HasPrefix(d.dialect, dialect.SQLite):
		return dialect.SQLite
	case strings.HasPrefix(d.dialect, dialect.Postgres):
		return dialect.Postgres
	default:
		return d.dialect
	}
}
// Tx starts and returns a transaction with default options.
// It is shorthand for BeginTx(ctx, nil).
func (d *Driver) Tx(ctx context.Context) (dialect.Tx, error) {
	var defaults *TxOptions // nil: let database/sql use its default options
	return d.BeginTx(ctx, defaults)
}
// BeginTx starts a transaction with options on the underlying *sql.DB.
// Statements issued through the returned Tx run on that transaction and keep
// the Driver's dialect name.
func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, error) {
	sqlTx, err := d.DB().BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}
	wrapped := &Tx{Tx: sqlTx}
	wrapped.Conn = Conn{ExecQuerier: sqlTx, dialect: d.dialect}
	return wrapped, nil
}
// Close closes the underlying *sql.DB returned by DB.
func (d *Driver) Close() error {
	db := d.DB()
	return db.Close()
}
// Tx implements dialect.Tx interface.
// Statements run through the embedded Conn (bound to the transaction by
// Driver.BeginTx), while Commit and Rollback come from the embedded driver.Tx.
type Tx struct {
Conn
driver.Tx
}
// ctxVarsKey is the key used for attaching and reading the context variables.
type ctxVarsKey struct{}

// sessionVars holds session/transaction variables to set before every statement.
// Variables are kept in the order they were added.
type sessionVars struct {
	vars []struct{ k, v string }
}

// WithVar returns a new context that holds the session variable to be executed
// before every query. The returned context carries the variables of ctx plus
// name=value. Adding a name that is already present keeps the earlier entry as
// well. Both are set, in order, so the most recently added value takes effect.
func WithVar(ctx context.Context, name, value string) context.Context {
	sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
	// Copy into a fresh backing array instead of appending in place. sv.vars is
	// shared with ctx, so an in-place append could write into spare capacity
	// that a sibling context derived from ctx already uses, overwriting its
	// variable.
	vars := make([]struct{ k, v string }, len(sv.vars), len(sv.vars)+1)
	copy(vars, sv.vars)
	vars = append(vars, struct{ k, v string }{k: name, v: value})
	return context.WithValue(ctx, ctxVarsKey{}, sessionVars{vars: vars})
}

// VarFromContext returns the value of the named session variable stored in ctx
// and whether it was found. When the name was added more than once, the most
// recently added value is returned. That is the value in effect after all the
// variables have been set in order.
func VarFromContext(ctx context.Context, name string) (string, bool) {
	sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
	for i := len(sv.vars) - 1; i >= 0; i-- {
		if sv.vars[i].k == name {
			return sv.vars[i].v, true
		}
	}
	return "", false
}

// WithIntVar calls WithVar with the string representation of the value.
func WithIntVar(ctx context.Context, name string, value int) context.Context {
	return WithVar(ctx, name, strconv.Itoa(value))
}
// ExecQuerier wraps the standard Exec and Query methods.
// *sql.DB, *sql.Tx and *sql.Conn all have this method set.
type ExecQuerier interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
// Conn implements dialect.ExecQuerier given ExecQuerier.
type Conn struct {
ExecQuerier
// dialect selects how session variables set via WithVar are reset
// before a dedicated connection is released (see maySetVars).
dialect string
}
// Exec implements the dialect.Exec method.
//
// args must be a []any holding the statement arguments. v is either nil, to
// discard the result, or a *sql.Result that receives it. Session variables
// attached to ctx with WithVar are applied first (see maySetVars). Their
// cleanup runs when Exec returns, and any cleanup error is joined with the
// returned error.
func (c Conn) Exec(ctx context.Context, query string, args, v any) (rerr error) {
	argv, ok := args.([]any)
	if !ok {
		// Report the type of args: this check is about the arguments, not v.
		return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", args)
	}
	ex, cf, err := c.maySetVars(ctx)
	if err != nil {
		return err
	}
	if cf != nil {
		defer func() { rerr = errors.Join(rerr, cf()) }()
	}
	switch v := v.(type) {
	case nil:
		if _, err := ex.ExecContext(ctx, query, argv...); err != nil {
			return err
		}
	case *sql.Result:
		res, err := ex.ExecContext(ctx, query, argv...)
		if err != nil {
			return err
		}
		*v = res
	default:
		return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Result", v)
	}
	return nil
}
// Query implements the dialect.Query method.
//
// v must be a *Rows, which receives the result set, and args must be a []any.
// When session variables from ctx required a dedicated connection, the
// returned rows release it on Close. If the query itself fails, the
// connection is released right away and both errors are joined.
func (c Conn) Query(ctx context.Context, query string, args, v any) error {
	out, isRows := v.(*Rows)
	if !isRows {
		return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Rows", v)
	}
	argv, isSlice := args.([]any)
	if !isSlice {
		return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", args)
	}
	ex, closeFn, err := c.maySetVars(ctx)
	if err != nil {
		return err
	}
	rows, qerr := ex.QueryContext(ctx, query, argv...)
	switch {
	case qerr != nil && closeFn != nil:
		return errors.Join(qerr, closeFn())
	case qerr != nil:
		return qerr
	case closeFn != nil:
		*out = Rows{rowsWithCloser{rows, closeFn}}
	default:
		*out = Rows{rows}
	}
	return nil
}
// maySetVars applies the session variables stored in ctx (see WithVar) before
// a statement is executed.
//
// It returns the ExecQuerier the statement must run on, and an optional
// cleanup function the caller must invoke once the statement (and, for
// queries, its rows) is done:
//   - Without variables in ctx, c itself is returned and cleanup is nil.
//   - On a *sql.Tx the variables are set on the transaction and cleanup is nil.
//   - On a *sql.DB a dedicated connection is taken from the pool, so the SET
//     statements and the statement itself share one session. The cleanup
//     resets the variables that were set (PostgreSQL and MySQL dialects) and
//     returns the connection to the pool. It runs at most once; later calls
//     return nil.
//
// Any other ExecQuerier type is rejected with an error.
func (c Conn) maySetVars(ctx context.Context) (ExecQuerier, func() error, error) {
	sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
	if len(sv.vars) == 0 {
		return c, nil, nil
	}
	var (
		ex      ExecQuerier  // Underlying ExecQuerier.
		release func() error // Returns the dedicated connection to the pool, if any.
		reset   []string     // Statements undoing the variables set so far.
		seen    = make(map[string]struct{}, len(sv.vars))
	)
	switch e := c.ExecQuerier.(type) {
	case *sql.Tx:
		ex = e
	case *sql.DB:
		conn, err := e.Conn(ctx)
		if err != nil {
			return nil, nil, err
		}
		ex, release = conn, conn.Close
	default:
		// Without this case ex would stay nil and ExecContext below would panic.
		return nil, nil, fmt.Errorf("dialect/sql: session variables are not supported for %T", e)
	}
	// Match the dialect by prefix, as Driver.Dialect does, so that a wrapped
	// driver name still gets its variables reset before the connection goes
	// back to the pool.
	var resetStmt func(name string) string
	switch {
	case strings.HasPrefix(c.dialect, dialect.Postgres):
		resetStmt = func(name string) string { return fmt.Sprintf("RESET %s", name) }
	case strings.HasPrefix(c.dialect, dialect.MySQL):
		resetStmt = func(name string) string { return fmt.Sprintf("SET %s = NULL", name) }
	}
	// cleanup undoes the variables set so far and releases the connection.
	// It stays nil for transactions, which this function does not release.
	var cleanup func() error
	if release != nil {
		var done bool
		cleanup = func() error {
			// Rows may be closed more than once (e.g. an explicit Close plus a
			// deferred one). Only the first call touches the connection.
			if done {
				return nil
			}
			done = true
			// NOTE(review): the resets run with ctx. If ctx is already canceled
			// they fail and the variables stay set on the pooled connection.
			var errs []error
			for _, q := range reset {
				if _, err := ex.ExecContext(ctx, q); err != nil {
					errs = append(errs, err)
				}
			}
			if len(errs) == 0 {
				return release()
			}
			return errors.Join(append(errs, release())...)
		}
	}
	for _, s := range sv.vars {
		// NOTE(review): the name and value are interpolated verbatim. A value
		// containing a single quote breaks the statement, and values derived
		// from untrusted input could inject SQL. Validate them before WithVar.
		if _, err := ex.ExecContext(ctx, fmt.Sprintf("SET %s = '%s'", s.k, s.v)); err != nil {
			if cleanup != nil {
				// Undo the variables already set before the connection is
				// returned to the pool.
				err = errors.Join(err, cleanup())
			}
			return nil, nil, err
		}
		if _, ok := seen[s.k]; ok || resetStmt == nil {
			continue
		}
		seen[s.k] = struct{}{}
		reset = append(reset, resetStmt(s.k))
	}
	return ex, cleanup, nil
}
// Compile-time assertion that *Driver implements dialect.Driver.
var _ dialect.Driver = (*Driver)(nil)
type (
// Rows wraps a ColumnScanner (for example the *sql.Rows produced by
// Conn.Query). sql.Rows contains locks, so it is held behind this interface
// and never copied by value.
Rows struct{ ColumnScanner }
// Result is an alias to sql.Result.
Result = sql.Result
// NullBool is an alias to sql.NullBool.
NullBool = sql.NullBool
// NullInt64 is an alias to sql.NullInt64.
NullInt64 = sql.NullInt64
// NullString is an alias to sql.NullString.
NullString = sql.NullString
// NullFloat64 is an alias to sql.NullFloat64.
NullFloat64 = sql.NullFloat64
// NullTime is an alias to sql.NullTime, a time.Time that may be null.
NullTime = sql.NullTime
// TxOptions is an alias to sql.TxOptions, the transaction options used in DB.BeginTx.
TxOptions = sql.TxOptions
)
// NullScanner implements the sql.Scanner interface such that it can be used
// as a scan destination, similar to the sql.Null* types. A NULL value is not
// forwarded to S; it only clears Valid.
type NullScanner struct {
	S     sql.Scanner
	Valid bool // Valid is true if the Scan value is not NULL.
}

// Scan implements the Scanner interface. A nil (NULL) value marks the scanner
// invalid without calling S. Any other value marks it valid and is passed to
// S.Scan, whose error is returned.
func (n *NullScanner) Scan(value any) error {
	if value == nil {
		n.Valid = false
		return nil
	}
	n.Valid = true
	return n.S.Scan(value)
}
// ColumnScanner is the interface that wraps the standard sql.Rows methods
// used for scanning database rows. *sql.Rows satisfies it.
type ColumnScanner interface {
	Close() error
	ColumnTypes() ([]*sql.ColumnType, error)
	Columns() ([]string, error)
	Err() error
	Next() bool
	NextResultSet() bool
	Scan(dest ...any) error
}

// rowsWithCloser decorates a ColumnScanner so that closing it also runs an
// extra hook (closer), e.g. to release a connection dedicated to the query.
type rowsWithCloser struct {
	ColumnScanner
	closer func() error
}

// Close closes the underlying ColumnScanner and then runs the closer hook.
// The hook runs even when closing the rows fails; both errors are joined.
func (r rowsWithCloser) Close() error {
	return errors.Join(r.ColumnScanner.Close(), r.closer())
}