mirror of
https://github.com/ent/ent.git
synced 2026-05-03 16:10:59 +03:00
150 lines
4.0 KiB
Go
150 lines
4.0 KiB
Go
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package sql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/facebookincubator/ent/dialect"
|
|
)
|
|
|
|
// Driver is a dialect.Driver implementation for SQL based databases.
|
|
type Driver struct {
|
|
conn
|
|
dialect string
|
|
}
|
|
|
|
// Open wraps the database/sql.Open method and returns a dialect.Driver that implements the an ent/dialect.Driver interface.
|
|
func Open(driver, source string) (*Driver, error) {
|
|
db, err := sql.Open(driver, source)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Driver{conn{db}, driver}, nil
|
|
}
|
|
|
|
// OpenDB wraps the given database/sql.DB method with a Driver.
|
|
func OpenDB(driver string, db *sql.DB) *Driver {
|
|
return &Driver{conn{db}, driver}
|
|
}
|
|
|
|
// DB returns the underlying *sql.DB instance.
|
|
func (d Driver) DB() *sql.DB {
|
|
return d.conn.ExecQuerier.(*sql.DB)
|
|
}
|
|
|
|
// Dialect implements the dialect.Dialect method.
|
|
func (d Driver) Dialect() string {
|
|
// if the underlying driver is wrapped with opencensus driver.
|
|
for _, name := range []string{dialect.MySQL, dialect.SQLite} {
|
|
if strings.HasPrefix(d.dialect, name) {
|
|
return name
|
|
}
|
|
}
|
|
return d.dialect
|
|
}
|
|
|
|
// Tx starts and returns a transaction.
|
|
func (d *Driver) Tx(ctx context.Context) (dialect.Tx, error) {
|
|
return d.BeginTx(ctx, &sql.TxOptions{})
|
|
}
|
|
|
|
// BeginTx starts a transaction with options.
|
|
func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, error) {
|
|
tx, err := d.ExecQuerier.(*sql.DB).BeginTx(ctx, opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Tx{conn{tx}}, nil
|
|
}
|
|
|
|
// Close closes the underlying connection.
|
|
func (d *Driver) Close() error { return d.ExecQuerier.(*sql.DB).Close() }
|
|
|
|
// Tx wraps the sql.Tx for implementing the dialect.Tx interface.
|
|
type Tx struct {
|
|
conn
|
|
}
|
|
|
|
// Commit commits the transaction.
|
|
func (t *Tx) Commit() error { return t.ExecQuerier.(*sql.Tx).Commit() }
|
|
|
|
// Rollback rollback the transaction.
|
|
func (t *Tx) Rollback() error { return t.ExecQuerier.(*sql.Tx).Rollback() }
|
|
|
|
// ExecQuerier wraps the standard Exec and Query methods.
|
|
type ExecQuerier interface {
|
|
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
|
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
|
}
|
|
|
|
// shared connection ExecQuerier between Driver and Tx.
|
|
type conn struct {
|
|
ExecQuerier
|
|
}
|
|
|
|
// Exec implements the dialect.Exec method.
|
|
func (c *conn) Exec(ctx context.Context, query string, args, v interface{}) error {
|
|
argv, ok := args.([]interface{})
|
|
if !ok {
|
|
return fmt.Errorf("dialect/sql: invalid type %T. expect []interface{} for args", v)
|
|
}
|
|
switch v := v.(type) {
|
|
case nil:
|
|
if _, err := c.ExecContext(ctx, query, argv...); err != nil {
|
|
return err
|
|
}
|
|
case *sql.Result:
|
|
res, err := c.ExecContext(ctx, query, argv...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*v = res
|
|
default:
|
|
return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Result", v)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Exec implements the dialect.Query method.
|
|
func (c *conn) Query(ctx context.Context, query string, args, v interface{}) error {
|
|
vr, ok := v.(*Rows)
|
|
if !ok {
|
|
return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Rows", v)
|
|
}
|
|
argv, ok := args.([]interface{})
|
|
if !ok {
|
|
return fmt.Errorf("dialect/sql: invalid type %T. expect []interface{} for args", args)
|
|
}
|
|
rows, err := c.QueryContext(ctx, query, argv...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*vr = Rows{rows}
|
|
return nil
|
|
}
|
|
|
|
var _ dialect.Driver = (*Driver)(nil)
|
|
|
|
type (
|
|
// Rows wraps the sql.Rows to avoid locks copy.
|
|
Rows struct{ *sql.Rows }
|
|
// Result is an alias to sql.Result.
|
|
Result = sql.Result
|
|
// NullBool is an alias to sql.NullBool.
|
|
NullBool = sql.NullBool
|
|
// NullInt64 is an alias to sql.NullInt64.
|
|
NullInt64 = sql.NullInt64
|
|
// NullString is an alias to sql.NullString.
|
|
NullString = sql.NullString
|
|
// NullFloat64 is an alias to sql.NullFloat64.
|
|
NullFloat64 = sql.NullFloat64
|
|
// TxOptions holds the transaction options to be used in DB.BeginTx.
|
|
TxOptions = sql.TxOptions
|
|
)
|