dialect/sqlscan: add ScanInt64 to be used by sqlgraph

Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/213

Reviewed By: alexsn

Differential Revision: D18763694

fbshipit-source-id: 890b35fcc2a28914b276ce65477788b4ddaeebf9
This commit is contained in:
Ariel Mashraki
2019-12-03 01:43:34 -08:00
committed by Facebook Github Bot
parent a4fac2db3b
commit a5e4a9cf54
4 changed files with 86 additions and 40 deletions

View File

@@ -351,25 +351,14 @@ func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBui
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, fmt.Errorf("no rows found for query: %v", query)
}
var id int64
if err := rows.Scan(&id); err != nil {
return 0, err
}
return id, nil
return ScanInt64(rows)
}
// MySQL, SQLite, etc.
var res sql.Result
if err := tx.Exec(ctx, query, args, &res); err != nil {
return 0, err
}
id, err := res.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
return res.LastInsertId()
}
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.

View File

@@ -5,6 +5,7 @@
package sql
import (
"database/sql"
"fmt"
"reflect"
"strings"
@@ -18,7 +19,38 @@ type ColumnScanner interface {
Columns() ([]string, error)
}
// ScanSlice scans the given ColumnScanner (basically, sql.Rows or sql.Rows) into the given slice.
// ScanInt64 scans and returns an int64 from the rows columns.
func ScanInt64(rows ColumnScanner) (int64, error) {
columns, err := rows.Columns()
if err != nil {
return 0, fmt.Errorf("sql/scan: failed getting column names: %v", err)
}
if n := len(columns); n != 1 {
return 0, fmt.Errorf("sql/scan: unexpected number of columns: %d", n)
}
if !rows.Next() {
return 0, sql.ErrNoRows
}
var n int64
if err := rows.Scan(&n); err != nil {
return 0, err
}
if rows.Next() {
return 0, fmt.Errorf("sql/scan: expect exactly one row in result set")
}
return n, nil
}
// ScanInt scans and returns an int from the rows columns.
func ScanInt(rows ColumnScanner) (int, error) {
n, err := ScanInt64(rows)
if err != nil {
return 0, err
}
return int(n), nil
}
// ScanSlice scans the given ColumnScanner (basically, sql.Row or sql.Rows) into the given slice.
func ScanSlice(rows ColumnScanner, v interface{}) error {
columns, err := rows.Columns()
if err != nil {

View File

@@ -5,6 +5,7 @@
package sql
import (
"database/sql"
"testing"
"github.com/DATA-DOG/go-sqlmock"
@@ -12,47 +13,47 @@ import (
)
func TestScanSlice(t *testing.T) {
rows := sqlmock.NewRows([]string{"name"}).
mock := sqlmock.NewRows([]string{"name"}).
AddRow("foo").
AddRow("bar")
var v0 []string
require.NoError(t, scanSlice(rows, &v0))
require.NoError(t, ScanSlice(toRows(mock), &v0))
require.Equal(t, []string{"foo", "bar"}, v0)
rows = sqlmock.NewRows([]string{"age"}).
mock = sqlmock.NewRows([]string{"age"}).
AddRow(1).
AddRow(2)
var v1 []int
require.NoError(t, scanSlice(rows, &v1))
require.NoError(t, ScanSlice(toRows(mock), &v1))
require.Equal(t, []int{1, 2}, v1)
rows = sqlmock.NewRows([]string{"name", "COUNT(*)"}).
mock = sqlmock.NewRows([]string{"name", "COUNT(*)"}).
AddRow("foo", 1).
AddRow("bar", 2)
var v2 []struct {
Name string
Count int
}
require.NoError(t, scanSlice(rows, &v2))
require.NoError(t, ScanSlice(toRows(mock), &v2))
require.Equal(t, "foo", v2[0].Name)
require.Equal(t, "bar", v2[1].Name)
require.Equal(t, 1, v2[0].Count)
require.Equal(t, 2, v2[1].Count)
rows = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
mock = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
AddRow("foo", 1).
AddRow("bar", 2)
var v3 []struct {
Count int
Name string `json:"nick_name"`
}
require.NoError(t, scanSlice(rows, &v3))
require.NoError(t, ScanSlice(toRows(mock), &v3))
require.Equal(t, "foo", v3[0].Name)
require.Equal(t, "bar", v3[1].Name)
require.Equal(t, 1, v3[0].Count)
require.Equal(t, 2, v3[1].Count)
rows = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
mock = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
AddRow("foo", 1).
AddRow("bar", 2)
var v4 []*struct {
@@ -60,20 +61,20 @@ func TestScanSlice(t *testing.T) {
Name string `json:"nick_name"`
Ignored string `json:"string"`
}
require.NoError(t, scanSlice(rows, &v4))
require.NoError(t, ScanSlice(toRows(mock), &v4))
require.Equal(t, "foo", v4[0].Name)
require.Equal(t, "bar", v4[1].Name)
require.Equal(t, 1, v4[0].Count)
require.Equal(t, 2, v4[1].Count)
rows = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
mock = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
AddRow("foo", 1).
AddRow("bar", 2)
var v5 []*struct {
Count int
Name string `json:"name" sql:"nick_name"`
}
require.NoError(t, scanSlice(rows, &v5))
require.NoError(t, ScanSlice(toRows(mock), &v5))
require.Equal(t, "foo", v5[0].Name)
require.Equal(t, "bar", v5[1].Name)
require.Equal(t, 1, v5[0].Count)
@@ -81,39 +82,66 @@ func TestScanSlice(t *testing.T) {
}
func TestScanSlicePtr(t *testing.T) {
rows := sqlmock.NewRows([]string{"name"}).
mock := sqlmock.NewRows([]string{"name"}).
AddRow("foo").
AddRow("bar")
var v0 []*string
require.NoError(t, scanSlice(rows, &v0))
require.NoError(t, ScanSlice(toRows(mock), &v0))
require.Equal(t, "foo", *v0[0])
require.Equal(t, "bar", *v0[1])
rows = sqlmock.NewRows([]string{"age"}).
mock = sqlmock.NewRows([]string{"age"}).
AddRow(1).
AddRow(2)
var v1 []**int
require.NoError(t, scanSlice(rows, &v1))
require.NoError(t, ScanSlice(toRows(mock), &v1))
require.Equal(t, 1, **v1[0])
require.Equal(t, 2, **v1[1])
rows = sqlmock.NewRows([]string{"age", "name"}).
mock = sqlmock.NewRows([]string{"age", "name"}).
AddRow(1, "a8m").
AddRow(2, "nati")
var v2 []*struct {
Age *int
Name **string
}
require.NoError(t, scanSlice(rows, &v2))
require.NoError(t, ScanSlice(toRows(mock), &v2))
require.Equal(t, 1, *v2[0].Age)
require.Equal(t, "a8m", **v2[0].Name)
require.Equal(t, 2, *v2[1].Age)
require.Equal(t, "nati", **v2[1].Name)
}
func scanSlice(mrows *sqlmock.Rows, v interface{}) error {
func TestScanInt64(t *testing.T) {
mock := sqlmock.NewRows([]string{"age"}).
AddRow("10").
AddRow("20")
n, err := ScanInt64(toRows(mock))
require.Error(t, err)
require.Zero(t, n)
mock = sqlmock.NewRows([]string{"age", "count"}).
AddRow("10", "1")
n, err = ScanInt64(toRows(mock))
require.Error(t, err)
require.Zero(t, n)
mock = sqlmock.NewRows([]string{"count"}).
AddRow(10)
n, err = ScanInt64(toRows(mock))
require.NoError(t, err)
require.EqualValues(t, 10, n)
mock = sqlmock.NewRows([]string{"count"}).
AddRow("10")
n, err = ScanInt64(toRows(mock))
require.NoError(t, err)
require.EqualValues(t, 10, n)
}
func toRows(mrows *sqlmock.Rows) *sql.Rows {
db, mock, _ := sqlmock.New()
mock.ExpectQuery("").WillReturnRows(mrows)
rows, _ := db.Query("")
return ScanSlice(rows, v)
return rows
}

View File

@@ -423,12 +423,9 @@ func exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}
return false, fmt.Errorf("reading schema information %v", err)
}
defer rows.Close()
if !rows.Next() {
return false, fmt.Errorf("no rows returned")
}
var n int
if err := rows.Scan(&n); err != nil {
return false, fmt.Errorf("scanning count")
n, err := sql.ScanInt(rows)
if err != nil {
return false, err
}
return n > 0, nil
}