fix: checks for error after rows.Next call (#480)

This commit is contained in:
Ciaran Liedeman
2020-05-09 15:23:47 +02:00
committed by GitHub
parent 7f260c3ae0
commit f59abad425
5 changed files with 35 additions and 4 deletions

View File

@@ -17,6 +17,7 @@ type ColumnScanner interface {
Next() bool
Scan(...interface{}) error
Columns() ([]string, error)
Err() error
}
// ScanOne scans one row to the given value. It fails if the rows holds more than 1 row.
@@ -29,6 +30,9 @@ func ScanOne(rows ColumnScanner, v interface{}) error {
return fmt.Errorf("sql/scan: unexpected number of columns: %d", n)
}
if !rows.Next() {
if rows.Err() != nil {
return rows.Err()
}
return sql.ErrNoRows
}
if err := rows.Scan(v); err != nil {
@@ -37,7 +41,7 @@ func ScanOne(rows ColumnScanner, v interface{}) error {
if rows.Next() {
return fmt.Errorf("sql/scan: expect exactly one row in result set")
}
return nil
return rows.Err()
}
// ScanInt64 scans and returns an int64 from the rows columns.
@@ -92,7 +96,7 @@ func ScanSlice(rows ColumnScanner, v interface{}) error {
vv := reflect.Append(rv, scan.value(values...))
rv.Set(vv)
}
return nil
return rows.Err()
}
// rowScan is the configuration for scanning one sql.Row.

View File

@@ -31,6 +31,9 @@ func (d *MySQL) init(ctx context.Context, tx dialect.Tx) error {
}
defer rows.Close()
if !rows.Next() {
if rows.Err() != nil {
return rows.Err()
}
return fmt.Errorf("mysql: version variable was not found")
}
version := make([]string, 2)
@@ -75,6 +78,9 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
}
t.AddColumn(c)
}
if rows.Err() != nil {
return nil, rows.Err()
}
if err := rows.Close(); err != nil {
return nil, fmt.Errorf("mysql: closing rows %v", err)
}
@@ -426,6 +432,9 @@ func (d *MySQL) scanIndexes(rows *sql.Rows) (Indexes, error) {
}
idx.columns = append(idx.columns, column)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return i, nil
}

View File

@@ -29,6 +29,9 @@ func (d *Postgres) init(ctx context.Context, tx dialect.Tx) error {
}
defer rows.Close()
if !rows.Next() {
if rows.Err() != nil {
return rows.Err()
}
return fmt.Errorf("server_version_num variable was not found")
}
var version string
@@ -93,6 +96,9 @@ func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Tabl
}
t.AddColumn(c)
}
if rows.Err() != nil {
return nil, rows.Err()
}
if err := rows.Close(); err != nil {
return nil, fmt.Errorf("closing rows %v", err)
}
@@ -184,6 +190,9 @@ func (d *Postgres) indexes(ctx context.Context, tx dialect.Tx, table string) (In
}
idx.columns = append(idx.columns, column)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return idxs, nil
}

View File

@@ -168,6 +168,9 @@ func (d *SQLite) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
}
t.AddColumn(c)
}
if rows.Err() != nil {
return nil, rows.Err()
}
if err := rows.Close(); err != nil {
return nil, fmt.Errorf("sqlite: closing rows %v", err)
}
@@ -215,6 +218,9 @@ func (d *SQLite) indexes(ctx context.Context, tx dialect.Tx, name string) (Index
i.primary = origin.String == "pk"
idx = append(idx, i)
}
if rows.Err() != nil {
return nil, rows.Err()
}
if err := rows.Close(); err != nil {
return nil, fmt.Errorf("closing rows %v", err)
}

View File

@@ -516,7 +516,7 @@ func QueryEdges(ctx context.Context, drv dialect.Driver, spec *EdgeQuerySpec) er
return err
}
}
return nil
return rows.Err()
}
type query struct {
@@ -540,7 +540,7 @@ func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
return err
}
}
return nil
return rows.Err()
}
func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) {
@@ -723,6 +723,9 @@ func (u *updater) setTableColumns(update *sql.UpdateBuilder, addEdges, clearEdge
func (u *updater) scan(rows *sql.Rows) error {
defer rows.Close()
if !rows.Next() {
if rows.Err() != nil {
return rows.Err()
}
return &NotFoundError{table: u.Node.Table, id: u.Node.ID.Value}
}
if err := rows.Scan(u.ScanValues...); err != nil {