mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
fix: checks for error after rows.Next call (#480)
This commit is contained in:
@@ -17,6 +17,7 @@ type ColumnScanner interface {
|
||||
Next() bool
|
||||
Scan(...interface{}) error
|
||||
Columns() ([]string, error)
|
||||
Err() error
|
||||
}
|
||||
|
||||
// ScanOne scans one row to the given value. It fails if the rows holds more than 1 row.
|
||||
@@ -29,6 +30,9 @@ func ScanOne(rows ColumnScanner, v interface{}) error {
|
||||
return fmt.Errorf("sql/scan: unexpected number of columns: %d", n)
|
||||
}
|
||||
if !rows.Next() {
|
||||
if rows.Err() != nil {
|
||||
return rows.Err()
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
if err := rows.Scan(v); err != nil {
|
||||
@@ -37,7 +41,7 @@ func ScanOne(rows ColumnScanner, v interface{}) error {
|
||||
if rows.Next() {
|
||||
return fmt.Errorf("sql/scan: expect exactly one row in result set")
|
||||
}
|
||||
return nil
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// ScanInt64 scans and returns an int64 from the rows columns.
|
||||
@@ -92,7 +96,7 @@ func ScanSlice(rows ColumnScanner, v interface{}) error {
|
||||
vv := reflect.Append(rv, scan.value(values...))
|
||||
rv.Set(vv)
|
||||
}
|
||||
return nil
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// rowScan is the configuration for scanning one sql.Row.
|
||||
|
||||
@@ -31,6 +31,9 @@ func (d *MySQL) init(ctx context.Context, tx dialect.Tx) error {
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
if rows.Err() != nil {
|
||||
return rows.Err()
|
||||
}
|
||||
return fmt.Errorf("mysql: version variable was not found")
|
||||
}
|
||||
version := make([]string, 2)
|
||||
@@ -75,6 +78,9 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
|
||||
}
|
||||
t.AddColumn(c)
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, fmt.Errorf("mysql: closing rows %v", err)
|
||||
}
|
||||
@@ -426,6 +432,9 @@ func (d *MySQL) scanIndexes(rows *sql.Rows) (Indexes, error) {
|
||||
}
|
||||
idx.columns = append(idx.columns, column)
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,9 @@ func (d *Postgres) init(ctx context.Context, tx dialect.Tx) error {
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
if rows.Err() != nil {
|
||||
return rows.Err()
|
||||
}
|
||||
return fmt.Errorf("server_version_num variable was not found")
|
||||
}
|
||||
var version string
|
||||
@@ -93,6 +96,9 @@ func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Tabl
|
||||
}
|
||||
t.AddColumn(c)
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, fmt.Errorf("closing rows %v", err)
|
||||
}
|
||||
@@ -184,6 +190,9 @@ func (d *Postgres) indexes(ctx context.Context, tx dialect.Tx, table string) (In
|
||||
}
|
||||
idx.columns = append(idx.columns, column)
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
return idxs, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -168,6 +168,9 @@ func (d *SQLite) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
|
||||
}
|
||||
t.AddColumn(c)
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: closing rows %v", err)
|
||||
}
|
||||
@@ -215,6 +218,9 @@ func (d *SQLite) indexes(ctx context.Context, tx dialect.Tx, name string) (Index
|
||||
i.primary = origin.String == "pk"
|
||||
idx = append(idx, i)
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, fmt.Errorf("closing rows %v", err)
|
||||
}
|
||||
|
||||
@@ -516,7 +516,7 @@ func QueryEdges(ctx context.Context, drv dialect.Driver, spec *EdgeQuerySpec) er
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
type query struct {
|
||||
@@ -540,7 +540,7 @@ func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) {
|
||||
@@ -723,6 +723,9 @@ func (u *updater) setTableColumns(update *sql.UpdateBuilder, addEdges, clearEdge
|
||||
func (u *updater) scan(rows *sql.Rows) error {
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
if rows.Err() != nil {
|
||||
return rows.Err()
|
||||
}
|
||||
return &NotFoundError{table: u.Node.Table, id: u.Node.ID.Value}
|
||||
}
|
||||
if err := rows.Scan(u.ScanValues...); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user