mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql: export conn type (#767)
This commit is contained in:
@@ -7,6 +7,7 @@ package sql
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -15,7 +16,7 @@ import (
|
||||
|
||||
// Driver is a dialect.Driver implementation for SQL based databases.
|
||||
type Driver struct {
|
||||
conn
|
||||
Conn
|
||||
dialect string
|
||||
}
|
||||
|
||||
@@ -25,17 +26,17 @@ func Open(driver, source string) (*Driver, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Driver{conn{db}, driver}, nil
|
||||
return &Driver{Conn{db}, driver}, nil
|
||||
}
|
||||
|
||||
// OpenDB wraps the given database/sql.DB method with a Driver.
|
||||
func OpenDB(driver string, db *sql.DB) *Driver {
|
||||
return &Driver{conn{db}, driver}
|
||||
return &Driver{Conn{db}, driver}
|
||||
}
|
||||
|
||||
// DB returns the underlying *sql.DB instance.
|
||||
func (d Driver) DB() *sql.DB {
|
||||
return d.conn.ExecQuerier.(*sql.DB)
|
||||
return d.ExecQuerier.(*sql.DB)
|
||||
}
|
||||
|
||||
// Dialect implements the dialect.Dialect method.
|
||||
@@ -51,45 +52,43 @@ func (d Driver) Dialect() string {
|
||||
|
||||
// Tx starts and returns a transaction.
|
||||
func (d *Driver) Tx(ctx context.Context) (dialect.Tx, error) {
|
||||
return d.BeginTx(ctx, &sql.TxOptions{})
|
||||
return d.BeginTx(ctx, nil)
|
||||
}
|
||||
|
||||
// BeginTx starts a transaction with options.
|
||||
func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, error) {
|
||||
tx, err := d.ExecQuerier.(*sql.DB).BeginTx(ctx, opts)
|
||||
tx, err := d.DB().BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{conn{tx}}, nil
|
||||
return &Tx{
|
||||
ExecQuerier: Conn{tx},
|
||||
Tx: tx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying connection.
|
||||
func (d *Driver) Close() error { return d.ExecQuerier.(*sql.DB).Close() }
|
||||
func (d *Driver) Close() error { return d.DB().Close() }
|
||||
|
||||
// Tx wraps the sql.Tx for implementing the dialect.Tx interface.
|
||||
// Tx implements dialect.Tx interface.
|
||||
type Tx struct {
|
||||
conn
|
||||
dialect.ExecQuerier
|
||||
driver.Tx
|
||||
}
|
||||
|
||||
// Commit commits the transaction.
|
||||
func (t *Tx) Commit() error { return t.ExecQuerier.(*sql.Tx).Commit() }
|
||||
|
||||
// Rollback rollback the transaction.
|
||||
func (t *Tx) Rollback() error { return t.ExecQuerier.(*sql.Tx).Rollback() }
|
||||
|
||||
// ExecQuerier wraps the standard Exec and Query methods.
|
||||
type ExecQuerier interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// shared connection ExecQuerier between Driver and Tx.
|
||||
type conn struct {
|
||||
// Conn implements dialect.ExecQuerier given ExecQuerier.
|
||||
type Conn struct {
|
||||
ExecQuerier
|
||||
}
|
||||
|
||||
// Exec implements the dialect.Exec method.
|
||||
func (c *conn) Exec(ctx context.Context, query string, args, v interface{}) error {
|
||||
func (c Conn) Exec(ctx context.Context, query string, args, v interface{}) error {
|
||||
argv, ok := args.([]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("dialect/sql: invalid type %T. expect []interface{} for args", v)
|
||||
@@ -112,7 +111,7 @@ func (c *conn) Exec(ctx context.Context, query string, args, v interface{}) erro
|
||||
}
|
||||
|
||||
// Query implements the dialect.Query method.
|
||||
func (c *conn) Query(ctx context.Context, query string, args, v interface{}) error {
|
||||
func (c Conn) Query(ctx context.Context, query string, args, v interface{}) error {
|
||||
vr, ok := v.(*Rows)
|
||||
if !ok {
|
||||
return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Rows", v)
|
||||
|
||||
Reference in New Issue
Block a user