mirror of
https://github.com/ent/ent.git
synced 2026-05-28 09:49:08 +03:00
dialect/sql/sqlscan: allow scanning optional values to non-pointer fields
This commit is contained in:
committed by
Ariel Mashraki
parent
b3041725d2
commit
89d1bcd80c
@@ -204,17 +204,29 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
|
|||||||
if len(idx) > 1 {
|
if len(idx) > 1 {
|
||||||
rtype = rtype.Field(idx[1]).Type
|
rtype = rtype.Field(idx[1]).Type
|
||||||
}
|
}
|
||||||
|
if !nillable(rtype) {
|
||||||
|
// Create a pointer to the actual reflect
|
||||||
|
// types to accept optional struct fields.
|
||||||
|
rtype = reflect.PtrTo(rtype)
|
||||||
|
}
|
||||||
scan.columns = append(scan.columns, rtype)
|
scan.columns = append(scan.columns, rtype)
|
||||||
}
|
}
|
||||||
scan.value = func(vs ...interface{}) reflect.Value {
|
scan.value = func(vs ...interface{}) reflect.Value {
|
||||||
st := reflect.New(typ).Elem()
|
st := reflect.New(typ).Elem()
|
||||||
for i, v := range vs {
|
for i, v := range vs {
|
||||||
|
rv := reflect.Indirect(reflect.ValueOf(v))
|
||||||
|
if rv.IsNil() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
idx := idxs[i]
|
idx := idxs[i]
|
||||||
rvalue := st.Field(idx[0])
|
rvalue := st.Field(idx[0])
|
||||||
if len(idx) > 1 {
|
if len(idx) > 1 {
|
||||||
rvalue = rvalue.Field(idx[1])
|
rvalue = rvalue.Field(idx[1])
|
||||||
}
|
}
|
||||||
rvalue.Set(reflect.Indirect(reflect.ValueOf(v)))
|
if !nillable(rvalue.Type()) {
|
||||||
|
rv = reflect.Indirect(rv)
|
||||||
|
}
|
||||||
|
rvalue.Set(rv)
|
||||||
}
|
}
|
||||||
return st
|
return st
|
||||||
}
|
}
|
||||||
@@ -232,6 +244,15 @@ func columnName(f reflect.StructField) string {
|
|||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nillable reports if the reflect-type can have nil value.
|
||||||
|
func nillable(t reflect.Type) bool {
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Interface, reflect.Slice, reflect.Map, reflect.Ptr, reflect.UnsafePointer:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// scanPtr wraps the underlying type with rowScan.
|
// scanPtr wraps the underlying type with rowScan.
|
||||||
func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
|
func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||||
typ = typ.Elem()
|
typ = typ.Elem()
|
||||||
|
|||||||
@@ -142,7 +142,8 @@ func TestScanSlice(t *testing.T) {
|
|||||||
func TestScanNestedStruct(t *testing.T) {
|
func TestScanNestedStruct(t *testing.T) {
|
||||||
mock := sqlmock.NewRows([]string{"name", "age"}).
|
mock := sqlmock.NewRows([]string{"name", "age"}).
|
||||||
AddRow("foo", 1).
|
AddRow("foo", 1).
|
||||||
AddRow("bar", 2)
|
AddRow("bar", 2).
|
||||||
|
AddRow("baz", nil)
|
||||||
type T struct{ Name string }
|
type T struct{ Name string }
|
||||||
var v []struct {
|
var v []struct {
|
||||||
T
|
T
|
||||||
@@ -153,6 +154,22 @@ func TestScanNestedStruct(t *testing.T) {
|
|||||||
require.Equal(t, 1, v[0].Age)
|
require.Equal(t, 1, v[0].Age)
|
||||||
require.Equal(t, "bar", v[1].Name)
|
require.Equal(t, "bar", v[1].Name)
|
||||||
require.Equal(t, 2, v[1].Age)
|
require.Equal(t, 2, v[1].Age)
|
||||||
|
require.Equal(t, "baz", v[2].Name)
|
||||||
|
require.Equal(t, 0, v[2].Age)
|
||||||
|
|
||||||
|
mock = sqlmock.NewRows([]string{"name", "age"}).
|
||||||
|
AddRow("foo", 1).
|
||||||
|
AddRow("bar", nil)
|
||||||
|
type T1 struct{ Name **string }
|
||||||
|
var v1 []struct {
|
||||||
|
T1
|
||||||
|
Age *int
|
||||||
|
}
|
||||||
|
require.NoError(t, ScanSlice(toRows(mock), &v1))
|
||||||
|
require.Equal(t, "foo", **v1[0].Name)
|
||||||
|
require.Equal(t, "bar", **v1[1].Name)
|
||||||
|
require.Equal(t, 1, *v1[0].Age)
|
||||||
|
require.Nil(t, v1[1].Age)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScanSlicePtr(t *testing.T) {
|
func TestScanSlicePtr(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user