diff --git a/dialect/sql/scan.go b/dialect/sql/scan.go index b880fc962..8482c7fa2 100644 --- a/dialect/sql/scan.go +++ b/dialect/sql/scan.go @@ -28,37 +28,9 @@ func ScanSlice(rows ColumnScanner, v interface{}) error { if k := rv.Kind(); k != reflect.Slice { return fmt.Errorf("sql/scan: invalid type %s. expected slice as an argument", k) } - var ( - scan *rowScan - typ = rv.Type().Elem() - ) - switch k := typ.Kind(); { - case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64: - scan = &rowScan{ - columns: []reflect.Type{typ}, - value: func(v ...interface{}) reflect.Value { - return reflect.Indirect(reflect.ValueOf(v[0])) - }, - } - case k == reflect.Ptr: - typ = typ.Elem() - if scan, err = scanStruct(typ, columns); err != nil { - return err - } - wrap := scan.value - scan.value = func(vs ...interface{}) reflect.Value { - v := wrap(vs...) - pt := reflect.PtrTo(v.Type()) - pv := reflect.New(pt.Elem()) - pv.Elem().Set(v) - return pv - } - case k == reflect.Struct: - if scan, err = scanStruct(typ, columns); err != nil { - return err - } - default: - return fmt.Errorf("sql/scan: unsupported type ([]%s)", k) + scan, err := scanType(rv.Type().Elem(), columns) + if err != nil { + return err } if n, m := len(columns), len(scan.columns); n > m { return fmt.Errorf("sql/scan: columns do not match (%d > %d)", n, m) @@ -91,6 +63,25 @@ func (r *rowScan) values() []interface{} { return values } +// scanType returns rowScan for the given reflect.Type. +func scanType(typ reflect.Type, columns []string) (*rowScan, error) { + switch k := typ.Kind(); { + case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64: + return &rowScan{ + columns: []reflect.Type{typ}, + value: func(v ...interface{}) reflect.Value { + return reflect.Indirect(reflect.ValueOf(v[0])) + }, + }, nil + case k == reflect.Ptr: + return scanPtr(typ, columns) + case k == reflect.Struct: + return scanStruct(typ, columns) + default: + return nil, fmt.Errorf("sql/scan: unsupported type ([]%s)", k) + } +} + // scanStruct returns the a configuration for scanning an sql.Row into a struct. func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) { var ( @@ -127,3 +118,21 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) { } return scan, nil } + +// scanPtr wraps the underlying type with rowScan. +func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) { + typ = typ.Elem() + scan, err := scanType(typ, columns) + if err != nil { + return nil, err + } + wrap := scan.value + scan.value = func(vs ...interface{}) reflect.Value { + v := wrap(vs...) + pt := reflect.PtrTo(v.Type()) + pv := reflect.New(pt.Elem()) + pv.Elem().Set(v) + return pv + } + return scan, nil +} diff --git a/dialect/sql/scan_test.go b/dialect/sql/scan_test.go index 884149772..6c3845069 100644 --- a/dialect/sql/scan_test.go +++ b/dialect/sql/scan_test.go @@ -5,105 +5,115 @@ package sql import ( - "database/sql" - "reflect" "testing" + "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" ) func TestScanSlice(t *testing.T) { - rows := &mockRows{ - columns: []string{"name"}, - values: [][]interface{}{{"foo"}, {"bar"}}, - } + rows := sqlmock.NewRows([]string{"name"}). + AddRow("foo"). + AddRow("bar") var v0 []string - require.NoError(t, ScanSlice(rows, &v0)) + require.NoError(t, scanSlice(rows, &v0)) require.Equal(t, []string{"foo", "bar"}, v0) - rows = &mockRows{ - columns: []string{"age"}, - values: [][]interface{}{{1}, {2}}, - } + rows = sqlmock.NewRows([]string{"age"}). + AddRow(1). + AddRow(2) var v1 []int - require.NoError(t, ScanSlice(rows, &v1)) + require.NoError(t, scanSlice(rows, &v1)) require.Equal(t, []int{1, 2}, v1) - rows = &mockRows{ - columns: []string{"name", "COUNT(*)"}, - values: [][]interface{}{{"foo", 1}, {"bar", 2}}, - } + rows = sqlmock.NewRows([]string{"name", "COUNT(*)"}). + AddRow("foo", 1). + AddRow("bar", 2) var v2 []struct { Name string Count int } - require.NoError(t, ScanSlice(rows, &v2)) + require.NoError(t, scanSlice(rows, &v2)) require.Equal(t, "foo", v2[0].Name) require.Equal(t, "bar", v2[1].Name) require.Equal(t, 1, v2[0].Count) require.Equal(t, 2, v2[1].Count) - rows = &mockRows{ - columns: []string{"nick_name", "COUNT(*)"}, - values: [][]interface{}{{"foo", 1}, {"bar", 2}}, - } + rows = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}). + AddRow("foo", 1). + AddRow("bar", 2) var v3 []struct { Count int Name string `json:"nick_name"` } - require.NoError(t, ScanSlice(rows, &v3)) + require.NoError(t, scanSlice(rows, &v3)) require.Equal(t, "foo", v3[0].Name) require.Equal(t, "bar", v3[1].Name) require.Equal(t, 1, v3[0].Count) require.Equal(t, 2, v3[1].Count) - rows = &mockRows{ - columns: []string{"nick_name", "COUNT(*)"}, - values: [][]interface{}{{"foo", 1}, {"bar", 2}}, - } + rows = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}). + AddRow("foo", 1). + AddRow("bar", 2) var v4 []*struct { Count int Name string `json:"nick_name"` Ignored string `json:"string"` } - require.NoError(t, ScanSlice(rows, &v4)) + require.NoError(t, scanSlice(rows, &v4)) require.Equal(t, "foo", v4[0].Name) require.Equal(t, "bar", v4[1].Name) require.Equal(t, 1, v4[0].Count) require.Equal(t, 2, v4[1].Count) - rows = &mockRows{ - columns: []string{"nick_name", "COUNT(*)"}, - values: [][]interface{}{{"foo", 1}, {"bar", 2}}, - } + rows = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}). + AddRow("foo", 1). + AddRow("bar", 2) var v5 []*struct { Count int Name string `json:"name" sql:"nick_name"` } - require.NoError(t, ScanSlice(rows, &v5)) + require.NoError(t, scanSlice(rows, &v5)) require.Equal(t, "foo", v5[0].Name) require.Equal(t, "bar", v5[1].Name) require.Equal(t, 1, v5[0].Count) require.Equal(t, 2, v5[1].Count) } -type mockRows struct { - columns []string - values [][]interface{} +func TestScanSlicePtr(t *testing.T) { + rows := sqlmock.NewRows([]string{"name"}). + AddRow("foo"). + AddRow("bar") + var v0 []*string + require.NoError(t, scanSlice(rows, &v0)) + require.Equal(t, "foo", *v0[0]) + require.Equal(t, "bar", *v0[1]) + + rows = sqlmock.NewRows([]string{"age"}). + AddRow(1). + AddRow(2) + var v1 []**int + require.NoError(t, scanSlice(rows, &v1)) + require.Equal(t, 1, **v1[0]) + require.Equal(t, 2, **v1[1]) + + rows = sqlmock.NewRows([]string{"age", "name"}). + AddRow(1, "a8m"). + AddRow(2, "nati") + var v2 []*struct { + Age *int + Name **string + } + require.NoError(t, scanSlice(rows, &v2)) + require.Equal(t, 1, *v2[0].Age) + require.Equal(t, "a8m", **v2[0].Name) + require.Equal(t, 2, *v2[1].Age) + require.Equal(t, "nati", **v2[1].Name) } -func (m mockRows) Columns() ([]string, error) { return m.columns, nil } - -func (m mockRows) Next() bool { return len(m.values) > 0 } - -func (m *mockRows) Scan(vs ...interface{}) error { - if len(m.values) == 0 { - return sql.ErrNoRows - } - row := m.values[0] - m.values = m.values[1:] - for i := range vs { - reflect.Indirect(reflect.ValueOf(vs[i])).Set(reflect.ValueOf(row[i])) - } - return nil +func scanSlice(mrows *sqlmock.Rows, v interface{}) error { + db, mock, _ := sqlmock.New() + mock.ExpectQuery("").WillReturnRows(mrows) + rows, _ := db.Query("") + return ScanSlice(rows, v) }