mirror of
https://github.com/ent/ent.git
synced 2026-05-03 16:10:59 +03:00
185 lines
4.7 KiB
Go
185 lines
4.7 KiB
Go
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package sql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"entgo.io/ent/dialect"
|
|
)
|
|
|
|
// Driver is a dialect.Driver implementation for SQL based databases.
|
|
type Driver struct {
|
|
Conn
|
|
dialect string
|
|
}
|
|
|
|
// NewDriver creates a new Driver with the given Conn and dialect.
|
|
func NewDriver(dialect string, c Conn) *Driver {
|
|
return &Driver{dialect: dialect, Conn: c}
|
|
}
|
|
|
|
// Open wraps the database/sql.Open method and returns a dialect.Driver that implements the an ent/dialect.Driver interface.
|
|
func Open(dialect, source string) (*Driver, error) {
|
|
db, err := sql.Open(dialect, source)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewDriver(dialect, Conn{db}), nil
|
|
}
|
|
|
|
// OpenDB wraps the given database/sql.DB method with a Driver.
|
|
func OpenDB(dialect string, db *sql.DB) *Driver {
|
|
return NewDriver(dialect, Conn{db})
|
|
}
|
|
|
|
// DB returns the underlying *sql.DB instance.
|
|
func (d Driver) DB() *sql.DB {
|
|
return d.ExecQuerier.(*sql.DB)
|
|
}
|
|
|
|
// Dialect implements the dialect.Dialect method.
|
|
func (d Driver) Dialect() string {
|
|
// If the underlying driver is wrapped with a telemetry driver.
|
|
for _, name := range []string{dialect.MySQL, dialect.SQLite, dialect.Postgres} {
|
|
if strings.HasPrefix(d.dialect, name) {
|
|
return name
|
|
}
|
|
}
|
|
return d.dialect
|
|
}
|
|
|
|
// Tx starts and returns a transaction.
|
|
func (d *Driver) Tx(ctx context.Context) (dialect.Tx, error) {
|
|
return d.BeginTx(ctx, nil)
|
|
}
|
|
|
|
// BeginTx starts a transaction with options.
|
|
func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, error) {
|
|
tx, err := d.DB().BeginTx(ctx, opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Tx{
|
|
Conn: Conn{tx},
|
|
Tx: tx,
|
|
}, nil
|
|
}
|
|
|
|
// Close closes the underlying connection.
|
|
func (d *Driver) Close() error { return d.DB().Close() }
|
|
|
|
// Tx implements dialect.Tx interface.
|
|
type Tx struct {
|
|
Conn
|
|
driver.Tx
|
|
}
|
|
|
|
// ExecQuerier is the pair of standard database/sql methods that Conn needs
// in order to run statements: one for execution and one for querying.
type ExecQuerier interface {
	// ExecContext has the signature of the database/sql ExecContext methods.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	// QueryContext has the signature of the database/sql QueryContext methods.
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
|
|
|
|
// Conn implements dialect.ExecQuerier given ExecQuerier.
|
|
type Conn struct {
|
|
ExecQuerier
|
|
}
|
|
|
|
// Exec implements the dialect.Exec method.
|
|
func (c Conn) Exec(ctx context.Context, query string, args, v any) error {
|
|
argv, ok := args.([]any)
|
|
if !ok {
|
|
return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", v)
|
|
}
|
|
switch v := v.(type) {
|
|
case nil:
|
|
if _, err := c.ExecContext(ctx, query, argv...); err != nil {
|
|
return err
|
|
}
|
|
case *sql.Result:
|
|
res, err := c.ExecContext(ctx, query, argv...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*v = res
|
|
default:
|
|
return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Result", v)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Query implements the dialect.Query method.
|
|
func (c Conn) Query(ctx context.Context, query string, args, v any) error {
|
|
vr, ok := v.(*Rows)
|
|
if !ok {
|
|
return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Rows", v)
|
|
}
|
|
argv, ok := args.([]any)
|
|
if !ok {
|
|
return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", args)
|
|
}
|
|
rows, err := c.QueryContext(ctx, query, argv...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*vr = Rows{rows}
|
|
return nil
|
|
}
|
|
|
|
// Compile-time check that *Driver satisfies the dialect.Driver interface.
var _ dialect.Driver = (*Driver)(nil)
|
|
|
|
type (
|
|
// Rows wraps the sql.Rows to avoid locks copy.
|
|
Rows struct{ ColumnScanner }
|
|
// Result is an alias to sql.Result.
|
|
Result = sql.Result
|
|
// NullBool is an alias to sql.NullBool.
|
|
NullBool = sql.NullBool
|
|
// NullInt64 is an alias to sql.NullInt64.
|
|
NullInt64 = sql.NullInt64
|
|
// NullString is an alias to sql.NullString.
|
|
NullString = sql.NullString
|
|
// NullFloat64 is an alias to sql.NullFloat64.
|
|
NullFloat64 = sql.NullFloat64
|
|
// NullTime represents a time.Time that may be null.
|
|
NullTime = sql.NullTime
|
|
// TxOptions holds the transaction options to be used in DB.BeginTx.
|
|
TxOptions = sql.TxOptions
|
|
)
|
|
|
|
// NullScanner wraps an sql.Scanner whose scanned value may be NULL.
// It implements the sql.Scanner interface itself, so it can be used as a
// scan destination in the same way as the Null* types of this package.
type NullScanner struct {
	// S is the wrapped scanner that receives non-NULL values.
	S sql.Scanner
	// Valid is true if the Scan value is not NULL.
	Valid bool
}
|
|
|
|
// Scan implements the Scanner interface.
|
|
func (n *NullScanner) Scan(value any) error {
|
|
n.Valid = value != nil
|
|
if n.Valid {
|
|
return n.S.Scan(value)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ColumnScanner is the interface that wraps the standard sql.Rows methods
// used for scanning database rows.
type ColumnScanner interface {
	// Iteration.
	Next() bool
	NextResultSet() bool

	// Reading the current row.
	Scan(dest ...any) error

	// Column metadata.
	Columns() ([]string, error)
	ColumnTypes() ([]*sql.ColumnType, error)

	// Error reporting and cleanup.
	Err() error
	Close() error
}
|