mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/sqlscan: allow scanning optional values to non-pointer fields
This commit is contained in:
committed by
Ariel Mashraki
parent
b3041725d2
commit
89d1bcd80c
@@ -204,17 +204,29 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
if len(idx) > 1 {
|
||||
rtype = rtype.Field(idx[1]).Type
|
||||
}
|
||||
if !nillable(rtype) {
|
||||
// Create a pointer to the actual reflect
|
||||
// types to accept optional struct fields.
|
||||
rtype = reflect.PtrTo(rtype)
|
||||
}
|
||||
scan.columns = append(scan.columns, rtype)
|
||||
}
|
||||
scan.value = func(vs ...interface{}) reflect.Value {
|
||||
st := reflect.New(typ).Elem()
|
||||
for i, v := range vs {
|
||||
rv := reflect.Indirect(reflect.ValueOf(v))
|
||||
if rv.IsNil() {
|
||||
continue
|
||||
}
|
||||
idx := idxs[i]
|
||||
rvalue := st.Field(idx[0])
|
||||
if len(idx) > 1 {
|
||||
rvalue = rvalue.Field(idx[1])
|
||||
}
|
||||
rvalue.Set(reflect.Indirect(reflect.ValueOf(v)))
|
||||
if !nillable(rvalue.Type()) {
|
||||
rv = reflect.Indirect(rv)
|
||||
}
|
||||
rvalue.Set(rv)
|
||||
}
|
||||
return st
|
||||
}
|
||||
@@ -232,6 +244,15 @@ func columnName(f reflect.StructField) string {
|
||||
return name
|
||||
}
|
||||
|
||||
// nillable reports if the reflect-type can have nil value.
|
||||
func nillable(t reflect.Type) bool {
|
||||
switch t.Kind() {
|
||||
case reflect.Interface, reflect.Slice, reflect.Map, reflect.Ptr, reflect.UnsafePointer:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// scanPtr wraps the underlying type with rowScan.
|
||||
func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
typ = typ.Elem()
|
||||
|
||||
@@ -142,7 +142,8 @@ func TestScanSlice(t *testing.T) {
|
||||
func TestScanNestedStruct(t *testing.T) {
|
||||
mock := sqlmock.NewRows([]string{"name", "age"}).
|
||||
AddRow("foo", 1).
|
||||
AddRow("bar", 2)
|
||||
AddRow("bar", 2).
|
||||
AddRow("baz", nil)
|
||||
type T struct{ Name string }
|
||||
var v []struct {
|
||||
T
|
||||
@@ -153,6 +154,22 @@ func TestScanNestedStruct(t *testing.T) {
|
||||
require.Equal(t, 1, v[0].Age)
|
||||
require.Equal(t, "bar", v[1].Name)
|
||||
require.Equal(t, 2, v[1].Age)
|
||||
require.Equal(t, "baz", v[2].Name)
|
||||
require.Equal(t, 0, v[2].Age)
|
||||
|
||||
mock = sqlmock.NewRows([]string{"name", "age"}).
|
||||
AddRow("foo", 1).
|
||||
AddRow("bar", nil)
|
||||
type T1 struct{ Name **string }
|
||||
var v1 []struct {
|
||||
T1
|
||||
Age *int
|
||||
}
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v1))
|
||||
require.Equal(t, "foo", **v1[0].Name)
|
||||
require.Equal(t, "bar", **v1[1].Name)
|
||||
require.Equal(t, 1, *v1[0].Age)
|
||||
require.Nil(t, v1[1].Age)
|
||||
}
|
||||
|
||||
func TestScanSlicePtr(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user