diff --git a/dialect/sql/scan.go b/dialect/sql/scan.go index aa9652c13..6cfa01f0d 100644 --- a/dialect/sql/scan.go +++ b/dialect/sql/scan.go @@ -204,17 +204,29 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) { if len(idx) > 1 { rtype = rtype.Field(idx[1]).Type } + if !nillable(rtype) { + // Create a pointer to the actual reflect + // types to accept optional struct fields. + rtype = reflect.PtrTo(rtype) + } scan.columns = append(scan.columns, rtype) } scan.value = func(vs ...interface{}) reflect.Value { st := reflect.New(typ).Elem() for i, v := range vs { + rv := reflect.Indirect(reflect.ValueOf(v)) + if rv.IsNil() { + continue + } idx := idxs[i] rvalue := st.Field(idx[0]) if len(idx) > 1 { rvalue = rvalue.Field(idx[1]) } - rvalue.Set(reflect.Indirect(reflect.ValueOf(v))) + if !nillable(rvalue.Type()) { + rv = reflect.Indirect(rv) + } + rvalue.Set(rv) } return st } @@ -232,6 +244,15 @@ func columnName(f reflect.StructField) string { return name } +// nillable reports if the reflect-type can have nil value. +func nillable(t reflect.Type) bool { + switch t.Kind() { + case reflect.Interface, reflect.Slice, reflect.Map, reflect.Ptr, reflect.UnsafePointer: + return true + } + return false +} + // scanPtr wraps the underlying type with rowScan. func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) { typ = typ.Elem() diff --git a/dialect/sql/scan_test.go b/dialect/sql/scan_test.go index 5022ae13d..a107f9354 100644 --- a/dialect/sql/scan_test.go +++ b/dialect/sql/scan_test.go @@ -142,7 +142,8 @@ func TestScanSlice(t *testing.T) { func TestScanNestedStruct(t *testing.T) { mock := sqlmock.NewRows([]string{"name", "age"}). AddRow("foo", 1). - AddRow("bar", 2) + AddRow("bar", 2). + AddRow("baz", nil) type T struct{ Name string } var v []struct { T @@ -153,6 +154,22 @@ func TestScanNestedStruct(t *testing.T) { require.Equal(t, 1, v[0].Age) require.Equal(t, "bar", v[1].Name) require.Equal(t, 2, v[1].Age) + require.Equal(t, "baz", v[2].Name) + require.Equal(t, 0, v[2].Age) + + mock = sqlmock.NewRows([]string{"name", "age"}). + AddRow("foo", 1). + AddRow("bar", nil) + type T1 struct{ Name **string } + var v1 []struct { + T1 + Age *int + } + require.NoError(t, ScanSlice(toRows(mock), &v1)) + require.Equal(t, "foo", **v1[0].Name) + require.Equal(t, "bar", **v1[1].Name) + require.Equal(t, 1, *v1[0].Age) + require.Nil(t, v1[1].Age) } func TestScanSlicePtr(t *testing.T) {