mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql: support recursive pointer types
Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/210 Reviewed By: alexsn Differential Revision: D18757642 fbshipit-source-id: 6dec318f68e2f14f1c7ffdf2dcd185be130b2f77
This commit is contained in:
committed by
Facebook Github Bot
parent
80dab06a57
commit
0f4fc12cc5
@@ -28,37 +28,9 @@ func ScanSlice(rows ColumnScanner, v interface{}) error {
|
||||
if k := rv.Kind(); k != reflect.Slice {
|
||||
return fmt.Errorf("sql/scan: invalid type %s. expected slice as an argument", k)
|
||||
}
|
||||
var (
|
||||
scan *rowScan
|
||||
typ = rv.Type().Elem()
|
||||
)
|
||||
switch k := typ.Kind(); {
|
||||
case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64:
|
||||
scan = &rowScan{
|
||||
columns: []reflect.Type{typ},
|
||||
value: func(v ...interface{}) reflect.Value {
|
||||
return reflect.Indirect(reflect.ValueOf(v[0]))
|
||||
},
|
||||
}
|
||||
case k == reflect.Ptr:
|
||||
typ = typ.Elem()
|
||||
if scan, err = scanStruct(typ, columns); err != nil {
|
||||
return err
|
||||
}
|
||||
wrap := scan.value
|
||||
scan.value = func(vs ...interface{}) reflect.Value {
|
||||
v := wrap(vs...)
|
||||
pt := reflect.PtrTo(v.Type())
|
||||
pv := reflect.New(pt.Elem())
|
||||
pv.Elem().Set(v)
|
||||
return pv
|
||||
}
|
||||
case k == reflect.Struct:
|
||||
if scan, err = scanStruct(typ, columns); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("sql/scan: unsupported type ([]%s)", k)
|
||||
scan, err := scanType(rv.Type().Elem(), columns)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n, m := len(columns), len(scan.columns); n > m {
|
||||
return fmt.Errorf("sql/scan: columns do not match (%d > %d)", n, m)
|
||||
@@ -91,6 +63,25 @@ func (r *rowScan) values() []interface{} {
|
||||
return values
|
||||
}
|
||||
|
||||
// scanType returns rowScan for the given reflect.Type.
|
||||
func scanType(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
switch k := typ.Kind(); {
|
||||
case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64:
|
||||
return &rowScan{
|
||||
columns: []reflect.Type{typ},
|
||||
value: func(v ...interface{}) reflect.Value {
|
||||
return reflect.Indirect(reflect.ValueOf(v[0]))
|
||||
},
|
||||
}, nil
|
||||
case k == reflect.Ptr:
|
||||
return scanPtr(typ, columns)
|
||||
case k == reflect.Struct:
|
||||
return scanStruct(typ, columns)
|
||||
default:
|
||||
return nil, fmt.Errorf("sql/scan: unsupported type ([]%s)", k)
|
||||
}
|
||||
}
|
||||
|
||||
// scanStruct returns the a configuration for scanning an sql.Row into a struct.
|
||||
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
var (
|
||||
@@ -127,3 +118,21 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
}
|
||||
return scan, nil
|
||||
}
|
||||
|
||||
// scanPtr wraps the underlying type with rowScan.
|
||||
func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
typ = typ.Elem()
|
||||
scan, err := scanType(typ, columns)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wrap := scan.value
|
||||
scan.value = func(vs ...interface{}) reflect.Value {
|
||||
v := wrap(vs...)
|
||||
pt := reflect.PtrTo(v.Type())
|
||||
pv := reflect.New(pt.Elem())
|
||||
pv.Elem().Set(v)
|
||||
return pv
|
||||
}
|
||||
return scan, nil
|
||||
}
|
||||
|
||||
@@ -5,105 +5,115 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestScanSlice(t *testing.T) {
|
||||
rows := &mockRows{
|
||||
columns: []string{"name"},
|
||||
values: [][]interface{}{{"foo"}, {"bar"}},
|
||||
}
|
||||
rows := sqlmock.NewRows([]string{"name"}).
|
||||
AddRow("foo").
|
||||
AddRow("bar")
|
||||
var v0 []string
|
||||
require.NoError(t, ScanSlice(rows, &v0))
|
||||
require.NoError(t, scanSlice(rows, &v0))
|
||||
require.Equal(t, []string{"foo", "bar"}, v0)
|
||||
|
||||
rows = &mockRows{
|
||||
columns: []string{"age"},
|
||||
values: [][]interface{}{{1}, {2}},
|
||||
}
|
||||
rows = sqlmock.NewRows([]string{"age"}).
|
||||
AddRow(1).
|
||||
AddRow(2)
|
||||
var v1 []int
|
||||
require.NoError(t, ScanSlice(rows, &v1))
|
||||
require.NoError(t, scanSlice(rows, &v1))
|
||||
require.Equal(t, []int{1, 2}, v1)
|
||||
|
||||
rows = &mockRows{
|
||||
columns: []string{"name", "COUNT(*)"},
|
||||
values: [][]interface{}{{"foo", 1}, {"bar", 2}},
|
||||
}
|
||||
rows = sqlmock.NewRows([]string{"name", "COUNT(*)"}).
|
||||
AddRow("foo", 1).
|
||||
AddRow("bar", 2)
|
||||
var v2 []struct {
|
||||
Name string
|
||||
Count int
|
||||
}
|
||||
require.NoError(t, ScanSlice(rows, &v2))
|
||||
require.NoError(t, scanSlice(rows, &v2))
|
||||
require.Equal(t, "foo", v2[0].Name)
|
||||
require.Equal(t, "bar", v2[1].Name)
|
||||
require.Equal(t, 1, v2[0].Count)
|
||||
require.Equal(t, 2, v2[1].Count)
|
||||
|
||||
rows = &mockRows{
|
||||
columns: []string{"nick_name", "COUNT(*)"},
|
||||
values: [][]interface{}{{"foo", 1}, {"bar", 2}},
|
||||
}
|
||||
rows = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
|
||||
AddRow("foo", 1).
|
||||
AddRow("bar", 2)
|
||||
var v3 []struct {
|
||||
Count int
|
||||
Name string `json:"nick_name"`
|
||||
}
|
||||
require.NoError(t, ScanSlice(rows, &v3))
|
||||
require.NoError(t, scanSlice(rows, &v3))
|
||||
require.Equal(t, "foo", v3[0].Name)
|
||||
require.Equal(t, "bar", v3[1].Name)
|
||||
require.Equal(t, 1, v3[0].Count)
|
||||
require.Equal(t, 2, v3[1].Count)
|
||||
|
||||
rows = &mockRows{
|
||||
columns: []string{"nick_name", "COUNT(*)"},
|
||||
values: [][]interface{}{{"foo", 1}, {"bar", 2}},
|
||||
}
|
||||
rows = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
|
||||
AddRow("foo", 1).
|
||||
AddRow("bar", 2)
|
||||
var v4 []*struct {
|
||||
Count int
|
||||
Name string `json:"nick_name"`
|
||||
Ignored string `json:"string"`
|
||||
}
|
||||
require.NoError(t, ScanSlice(rows, &v4))
|
||||
require.NoError(t, scanSlice(rows, &v4))
|
||||
require.Equal(t, "foo", v4[0].Name)
|
||||
require.Equal(t, "bar", v4[1].Name)
|
||||
require.Equal(t, 1, v4[0].Count)
|
||||
require.Equal(t, 2, v4[1].Count)
|
||||
|
||||
rows = &mockRows{
|
||||
columns: []string{"nick_name", "COUNT(*)"},
|
||||
values: [][]interface{}{{"foo", 1}, {"bar", 2}},
|
||||
}
|
||||
rows = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}).
|
||||
AddRow("foo", 1).
|
||||
AddRow("bar", 2)
|
||||
var v5 []*struct {
|
||||
Count int
|
||||
Name string `json:"name" sql:"nick_name"`
|
||||
}
|
||||
require.NoError(t, ScanSlice(rows, &v5))
|
||||
require.NoError(t, scanSlice(rows, &v5))
|
||||
require.Equal(t, "foo", v5[0].Name)
|
||||
require.Equal(t, "bar", v5[1].Name)
|
||||
require.Equal(t, 1, v5[0].Count)
|
||||
require.Equal(t, 2, v5[1].Count)
|
||||
}
|
||||
|
||||
type mockRows struct {
|
||||
columns []string
|
||||
values [][]interface{}
|
||||
func TestScanSlicePtr(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"name"}).
|
||||
AddRow("foo").
|
||||
AddRow("bar")
|
||||
var v0 []*string
|
||||
require.NoError(t, scanSlice(rows, &v0))
|
||||
require.Equal(t, "foo", *v0[0])
|
||||
require.Equal(t, "bar", *v0[1])
|
||||
|
||||
rows = sqlmock.NewRows([]string{"age"}).
|
||||
AddRow(1).
|
||||
AddRow(2)
|
||||
var v1 []**int
|
||||
require.NoError(t, scanSlice(rows, &v1))
|
||||
require.Equal(t, 1, **v1[0])
|
||||
require.Equal(t, 2, **v1[1])
|
||||
|
||||
rows = sqlmock.NewRows([]string{"age", "name"}).
|
||||
AddRow(1, "a8m").
|
||||
AddRow(2, "nati")
|
||||
var v2 []*struct {
|
||||
Age *int
|
||||
Name **string
|
||||
}
|
||||
require.NoError(t, scanSlice(rows, &v2))
|
||||
require.Equal(t, 1, *v2[0].Age)
|
||||
require.Equal(t, "a8m", **v2[0].Name)
|
||||
require.Equal(t, 2, *v2[1].Age)
|
||||
require.Equal(t, "nati", **v2[1].Name)
|
||||
}
|
||||
|
||||
func (m mockRows) Columns() ([]string, error) { return m.columns, nil }
|
||||
|
||||
func (m mockRows) Next() bool { return len(m.values) > 0 }
|
||||
|
||||
func (m *mockRows) Scan(vs ...interface{}) error {
|
||||
if len(m.values) == 0 {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
row := m.values[0]
|
||||
m.values = m.values[1:]
|
||||
for i := range vs {
|
||||
reflect.Indirect(reflect.ValueOf(vs[i])).Set(reflect.ValueOf(row[i]))
|
||||
}
|
||||
return nil
|
||||
func scanSlice(mrows *sqlmock.Rows, v interface{}) error {
|
||||
db, mock, _ := sqlmock.New()
|
||||
mock.ExpectQuery("").WillReturnRows(mrows)
|
||||
rows, _ := db.Query("")
|
||||
return ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user