dialect/sql: export conn type (#767)

This commit is contained in:
Alex Snast
2020-09-15 16:17:43 +03:00
committed by GitHub
parent 2c0d7e5a42
commit bd4d2a553c

View File

@@ -7,6 +7,7 @@ package sql
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"strings"
@@ -15,7 +16,7 @@ import (
// Driver is a dialect.Driver implementation for SQL based databases.
type Driver struct {
conn
Conn
dialect string
}
@@ -25,17 +26,17 @@ func Open(driver, source string) (*Driver, error) {
if err != nil {
return nil, err
}
return &Driver{conn{db}, driver}, nil
return &Driver{Conn{db}, driver}, nil
}
// OpenDB wraps the given database/sql.DB method with a Driver.
func OpenDB(driver string, db *sql.DB) *Driver {
return &Driver{conn{db}, driver}
return &Driver{Conn{db}, driver}
}
// DB returns the underlying *sql.DB instance.
func (d Driver) DB() *sql.DB {
return d.conn.ExecQuerier.(*sql.DB)
return d.ExecQuerier.(*sql.DB)
}
// Dialect implements the dialect.Dialect method.
@@ -51,45 +52,43 @@ func (d Driver) Dialect() string {
// Tx starts and returns a transaction.
func (d *Driver) Tx(ctx context.Context) (dialect.Tx, error) {
return d.BeginTx(ctx, &sql.TxOptions{})
return d.BeginTx(ctx, nil)
}
// BeginTx starts a transaction with options.
func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, error) {
tx, err := d.ExecQuerier.(*sql.DB).BeginTx(ctx, opts)
tx, err := d.DB().BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &Tx{conn{tx}}, nil
return &Tx{
ExecQuerier: Conn{tx},
Tx: tx,
}, nil
}
// Close closes the underlying connection.
func (d *Driver) Close() error { return d.ExecQuerier.(*sql.DB).Close() }
func (d *Driver) Close() error { return d.DB().Close() }
// Tx wraps the sql.Tx for implementing the dialect.Tx interface.
// Tx implements dialect.Tx interface.
type Tx struct {
conn
dialect.ExecQuerier
driver.Tx
}
// Commit commits the transaction.
func (t *Tx) Commit() error { return t.ExecQuerier.(*sql.Tx).Commit() }
// Rollback rollback the transaction.
func (t *Tx) Rollback() error { return t.ExecQuerier.(*sql.Tx).Rollback() }
// ExecQuerier wraps the standard Exec and Query methods.
type ExecQuerier interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
// shared connection ExecQuerier between Driver and Tx.
type conn struct {
// Conn implements dialect.ExecQuerier given ExecQuerier.
type Conn struct {
ExecQuerier
}
// Exec implements the dialect.Exec method.
func (c *conn) Exec(ctx context.Context, query string, args, v interface{}) error {
func (c Conn) Exec(ctx context.Context, query string, args, v interface{}) error {
argv, ok := args.([]interface{})
if !ok {
return fmt.Errorf("dialect/sql: invalid type %T. expect []interface{} for args", v)
@@ -112,7 +111,7 @@ func (c *conn) Exec(ctx context.Context, query string, args, v interface{}) erro
}
// Query implements the dialect.Query method.
func (c *conn) Query(ctx context.Context, query string, args, v interface{}) error {
func (c Conn) Query(ctx context.Context, query string, args, v interface{}) error {
vr, ok := v.(*Rows)
if !ok {
return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Rows", v)