mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sqlscan: add ScanInt64 to be used by sqlgraph
Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/213 Reviewed By: alexsn Differential Revision: D18763694 fbshipit-source-id: 890b35fcc2a28914b276ce65477788b4ddaeebf9
This commit is contained in:
committed by
Facebook Github Bot
parent
a4fac2db3b
commit
a5e4a9cf54
@@ -351,25 +351,14 @@ func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBui
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, fmt.Errorf("no rows found for query: %v", query)
|
||||
}
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
return ScanInt64(rows)
|
||||
}
|
||||
// MySQL, SQLite, etc.
|
||||
var res sql.Result
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
@@ -18,7 +19,38 @@ type ColumnScanner interface {
|
||||
Columns() ([]string, error)
|
||||
}
|
||||
|
||||
// ScanSlice scans the given ColumnScanner (basically, sql.Rows or sql.Rows) into the given slice.
|
||||
// ScanInt64 scans and returns an int64 from the rows columns.
|
||||
func ScanInt64(rows ColumnScanner) (int64, error) {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("sql/scan: failed getting column names: %v", err)
|
||||
}
|
||||
if n := len(columns); n != 1 {
|
||||
return 0, fmt.Errorf("sql/scan: unexpected number of columns: %d", n)
|
||||
}
|
||||
if !rows.Next() {
|
||||
return 0, sql.ErrNoRows
|
||||
}
|
||||
var n int64
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if rows.Next() {
|
||||
return 0, fmt.Errorf("sql/scan: expect exactly one row in result set")
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// ScanInt scans and returns an int from the rows columns.
|
||||
func ScanInt(rows ColumnScanner) (int, error) {
|
||||
n, err := ScanInt64(rows)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(n), nil
|
||||
}
|
||||
|
||||
// ScanSlice scans the given ColumnScanner (basically, sql.Row or sql.Rows) into the given slice.
|
||||
func ScanSlice(rows ColumnScanner, v interface{}) error {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
@@ -12,47 +13,47 @@ import (
|
||||
)
|
||||
|
||||
func TestScanSlice(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"name"}).
|
||||
mock := sqlmock.NewRows([]string{"name"}).
|
||||
AddRow("foo").
|
||||
AddRow("bar")
|
||||
var v0 []string
|
||||
require.NoError(t, scanSlice(rows, &v0))
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v0))
|
||||
require.Equal(t, []string{"foo", "bar"}, v0)
|
||||
|
||||
rows = sqlmock.NewRows([]string{"age"}).
|
||||
mock = sqlmock.NewRows([]string{"age"}).
|
||||
AddRow(1).
|
||||
AddRow(2)
|
||||
var v1 []int
|
||||
require.NoError(t, scanSlice(rows, &v1))
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v1))
|
||||
require.Equal(t, []int{1, 2}, v1)
|
||||
|
||||
rows = sqlmock.NewRows([]string{"name", "COUNT(*)"}).
|
||||
mock = sqlmock.NewRows([]string{"name", "COUNT(*)"}).
|
||||
AddRow("foo", 1).
|
||||
AddRow("bar", 2)
|
||||
var v2 []struct {
|
||||
Name string
|
||||
Count int
|
||||
}
|
||||
require.NoError(t, scanSlice(rows, &v2))
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v2))
|
||||
require.Equal(t, "foo", v2[0].Name)
|
||||
require.Equal(t, "bar", v2[1].Name)
|
||||
require.Equal(t, 1, v2[0].Count)
|
||||
require.Equal(t, 2, v2[1].Count)
|
||||
|
||||
rows = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
|
||||
mock = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
|
||||
AddRow("foo", 1).
|
||||
AddRow("bar", 2)
|
||||
var v3 []struct {
|
||||
Count int
|
||||
Name string `json:"nick_name"`
|
||||
}
|
||||
require.NoError(t, scanSlice(rows, &v3))
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v3))
|
||||
require.Equal(t, "foo", v3[0].Name)
|
||||
require.Equal(t, "bar", v3[1].Name)
|
||||
require.Equal(t, 1, v3[0].Count)
|
||||
require.Equal(t, 2, v3[1].Count)
|
||||
|
||||
rows = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
|
||||
mock = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
|
||||
AddRow("foo", 1).
|
||||
AddRow("bar", 2)
|
||||
var v4 []*struct {
|
||||
@@ -60,20 +61,20 @@ func TestScanSlice(t *testing.T) {
|
||||
Name string `json:"nick_name"`
|
||||
Ignored string `json:"string"`
|
||||
}
|
||||
require.NoError(t, scanSlice(rows, &v4))
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v4))
|
||||
require.Equal(t, "foo", v4[0].Name)
|
||||
require.Equal(t, "bar", v4[1].Name)
|
||||
require.Equal(t, 1, v4[0].Count)
|
||||
require.Equal(t, 2, v4[1].Count)
|
||||
|
||||
rows = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
|
||||
mock = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
|
||||
AddRow("foo", 1).
|
||||
AddRow("bar", 2)
|
||||
var v5 []*struct {
|
||||
Count int
|
||||
Name string `json:"name" sql:"nick_name"`
|
||||
}
|
||||
require.NoError(t, scanSlice(rows, &v5))
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v5))
|
||||
require.Equal(t, "foo", v5[0].Name)
|
||||
require.Equal(t, "bar", v5[1].Name)
|
||||
require.Equal(t, 1, v5[0].Count)
|
||||
@@ -81,39 +82,66 @@ func TestScanSlice(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestScanSlicePtr(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"name"}).
|
||||
mock := sqlmock.NewRows([]string{"name"}).
|
||||
AddRow("foo").
|
||||
AddRow("bar")
|
||||
var v0 []*string
|
||||
require.NoError(t, scanSlice(rows, &v0))
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v0))
|
||||
require.Equal(t, "foo", *v0[0])
|
||||
require.Equal(t, "bar", *v0[1])
|
||||
|
||||
rows = sqlmock.NewRows([]string{"age"}).
|
||||
mock = sqlmock.NewRows([]string{"age"}).
|
||||
AddRow(1).
|
||||
AddRow(2)
|
||||
var v1 []**int
|
||||
require.NoError(t, scanSlice(rows, &v1))
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v1))
|
||||
require.Equal(t, 1, **v1[0])
|
||||
require.Equal(t, 2, **v1[1])
|
||||
|
||||
rows = sqlmock.NewRows([]string{"age", "name"}).
|
||||
mock = sqlmock.NewRows([]string{"age", "name"}).
|
||||
AddRow(1, "a8m").
|
||||
AddRow(2, "nati")
|
||||
var v2 []*struct {
|
||||
Age *int
|
||||
Name **string
|
||||
}
|
||||
require.NoError(t, scanSlice(rows, &v2))
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v2))
|
||||
require.Equal(t, 1, *v2[0].Age)
|
||||
require.Equal(t, "a8m", **v2[0].Name)
|
||||
require.Equal(t, 2, *v2[1].Age)
|
||||
require.Equal(t, "nati", **v2[1].Name)
|
||||
}
|
||||
|
||||
func scanSlice(mrows *sqlmock.Rows, v interface{}) error {
|
||||
func TestScanInt64(t *testing.T) {
|
||||
mock := sqlmock.NewRows([]string{"age"}).
|
||||
AddRow("10").
|
||||
AddRow("20")
|
||||
n, err := ScanInt64(toRows(mock))
|
||||
require.Error(t, err)
|
||||
require.Zero(t, n)
|
||||
|
||||
mock = sqlmock.NewRows([]string{"age", "count"}).
|
||||
AddRow("10", "1")
|
||||
n, err = ScanInt64(toRows(mock))
|
||||
require.Error(t, err)
|
||||
require.Zero(t, n)
|
||||
|
||||
mock = sqlmock.NewRows([]string{"count"}).
|
||||
AddRow(10)
|
||||
n, err = ScanInt64(toRows(mock))
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 10, n)
|
||||
|
||||
mock = sqlmock.NewRows([]string{"count"}).
|
||||
AddRow("10")
|
||||
n, err = ScanInt64(toRows(mock))
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 10, n)
|
||||
}
|
||||
|
||||
func toRows(mrows *sqlmock.Rows) *sql.Rows {
|
||||
db, mock, _ := sqlmock.New()
|
||||
mock.ExpectQuery("").WillReturnRows(mrows)
|
||||
rows, _ := db.Query("")
|
||||
return ScanSlice(rows, v)
|
||||
return rows
|
||||
}
|
||||
|
||||
@@ -423,12 +423,9 @@ func exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}
|
||||
return false, fmt.Errorf("reading schema information %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, fmt.Errorf("no rows returned")
|
||||
}
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
return false, fmt.Errorf("scanning count")
|
||||
n, err := sql.ScanInt(rows)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return n > 0, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user