dialect/sql/sqlscan: allow scanning optional values to non-pointer fields

This commit is contained in:
Ariel Mashraki
2021-04-20 09:13:35 +03:00
committed by Ariel Mashraki
parent b3041725d2
commit 89d1bcd80c
2 changed files with 40 additions and 2 deletions

View File

@@ -204,17 +204,29 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
if len(idx) > 1 {
rtype = rtype.Field(idx[1]).Type
}
if !nillable(rtype) {
// Create a pointer to the actual reflect
// types to accept optional struct fields.
rtype = reflect.PtrTo(rtype)
}
scan.columns = append(scan.columns, rtype)
}
scan.value = func(vs ...interface{}) reflect.Value {
st := reflect.New(typ).Elem()
for i, v := range vs {
rv := reflect.Indirect(reflect.ValueOf(v))
if rv.IsNil() {
continue
}
idx := idxs[i]
rvalue := st.Field(idx[0])
if len(idx) > 1 {
rvalue = rvalue.Field(idx[1])
}
rvalue.Set(reflect.Indirect(reflect.ValueOf(v)))
if !nillable(rvalue.Type()) {
rv = reflect.Indirect(rv)
}
rvalue.Set(rv)
}
return st
}
@@ -232,6 +244,15 @@ func columnName(f reflect.StructField) string {
return name
}
// nillable reports if the reflect-type can have nil value.
func nillable(t reflect.Type) bool {
switch t.Kind() {
case reflect.Interface, reflect.Slice, reflect.Map, reflect.Ptr, reflect.UnsafePointer:
return true
}
return false
}
// scanPtr wraps the underlying type with rowScan.
func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
typ = typ.Elem()

View File

@@ -142,7 +142,8 @@ func TestScanSlice(t *testing.T) {
func TestScanNestedStruct(t *testing.T) {
mock := sqlmock.NewRows([]string{"name", "age"}).
AddRow("foo", 1).
AddRow("bar", 2)
AddRow("bar", 2).
AddRow("baz", nil)
type T struct{ Name string }
var v []struct {
T
@@ -153,6 +154,22 @@ func TestScanNestedStruct(t *testing.T) {
require.Equal(t, 1, v[0].Age)
require.Equal(t, "bar", v[1].Name)
require.Equal(t, 2, v[1].Age)
require.Equal(t, "baz", v[2].Name)
require.Equal(t, 0, v[2].Age)
mock = sqlmock.NewRows([]string{"name", "age"}).
AddRow("foo", 1).
AddRow("bar", nil)
type T1 struct{ Name **string }
var v1 []struct {
T1
Age *int
}
require.NoError(t, ScanSlice(toRows(mock), &v1))
require.Equal(t, "foo", **v1[0].Name)
require.Equal(t, "bar", **v1[1].Name)
require.Equal(t, 1, *v1[0].Age)
require.Nil(t, v1[1].Age)
}
func TestScanSlicePtr(t *testing.T) {