diff --git a/dialect/sql/driver.go b/dialect/sql/driver.go index 64492197a..643e15063 100644 --- a/dialect/sql/driver.go +++ b/dialect/sql/driver.go @@ -7,6 +7,7 @@ package sql import ( "context" "database/sql" + "database/sql/driver" "fmt" "strings" @@ -15,7 +16,7 @@ import ( // Driver is a dialect.Driver implementation for SQL based databases. type Driver struct { - conn + Conn dialect string } @@ -25,17 +26,17 @@ func Open(driver, source string) (*Driver, error) { if err != nil { return nil, err } - return &Driver{conn{db}, driver}, nil + return &Driver{Conn{db}, driver}, nil } // OpenDB wraps the given database/sql.DB method with a Driver. func OpenDB(driver string, db *sql.DB) *Driver { - return &Driver{conn{db}, driver} + return &Driver{Conn{db}, driver} } // DB returns the underlying *sql.DB instance. func (d Driver) DB() *sql.DB { - return d.conn.ExecQuerier.(*sql.DB) + return d.ExecQuerier.(*sql.DB) } // Dialect implements the dialect.Dialect method. @@ -51,45 +52,43 @@ func (d Driver) Dialect() string { // Tx starts and returns a transaction. func (d *Driver) Tx(ctx context.Context) (dialect.Tx, error) { - return d.BeginTx(ctx, &sql.TxOptions{}) + return d.BeginTx(ctx, nil) } // BeginTx starts a transaction with options. func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, error) { - tx, err := d.ExecQuerier.(*sql.DB).BeginTx(ctx, opts) + tx, err := d.DB().BeginTx(ctx, opts) if err != nil { return nil, err } - return &Tx{conn{tx}}, nil + return &Tx{ + ExecQuerier: Conn{tx}, + Tx: tx, + }, nil } // Close closes the underlying connection. -func (d *Driver) Close() error { return d.ExecQuerier.(*sql.DB).Close() } +func (d *Driver) Close() error { return d.DB().Close() } -// Tx wraps the sql.Tx for implementing the dialect.Tx interface. +// Tx implements dialect.Tx interface. type Tx struct { - conn + dialect.ExecQuerier + driver.Tx } -// Commit commits the transaction. -func (t *Tx) Commit() error { return t.ExecQuerier.(*sql.Tx).Commit() } - -// Rollback rollback the transaction. -func (t *Tx) Rollback() error { return t.ExecQuerier.(*sql.Tx).Rollback() } - // ExecQuerier wraps the standard Exec and Query methods. type ExecQuerier interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } -// shared connection ExecQuerier between Driver and Tx. -type conn struct { +// Conn implements dialect.ExecQuerier given ExecQuerier. +type Conn struct { ExecQuerier } // Exec implements the dialect.Exec method. -func (c *conn) Exec(ctx context.Context, query string, args, v interface{}) error { +func (c Conn) Exec(ctx context.Context, query string, args, v interface{}) error { argv, ok := args.([]interface{}) if !ok { return fmt.Errorf("dialect/sql: invalid type %T. expect []interface{} for args", v) @@ -112,7 +111,7 @@ func (c *conn) Exec(ctx context.Context, query string, args, v interface{}) erro } // Query implements the dialect.Query method. -func (c *conn) Query(ctx context.Context, query string, args, v interface{}) error { +func (c Conn) Query(ctx context.Context, query string, args, v interface{}) error { vr, ok := v.(*Rows) if !ok { return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Rows", v)