dialect/sql/sqlscan: allow scanning values to embedded struct fields

This commit is contained in:
Ariel Mashraki
2021-04-19 18:28:15 +03:00
committed by Ariel Mashraki
parent 497fca4c96
commit b3041725d2
2 changed files with 49 additions and 12 deletions

View File

@@ -174,8 +174,8 @@ func assignable(typ reflect.Type) bool {
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
var (
scan = &rowScan{}
idx = make([]int, 0, typ.NumField())
names = make(map[string]int, typ.NumField())
idxs = make([][]int, 0, typ.NumField())
names = make(map[string][]int, typ.NumField())
)
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
@@ -183,34 +183,55 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
if f.PkgPath != "" {
continue
}
name := strings.ToLower(f.Name)
if tag, ok := f.Tag.Lookup("sql"); ok {
name = tag
} else if tag, ok := f.Tag.Lookup("json"); ok {
name = strings.Split(tag, ",")[0]
// Support 1-level embedding to accepts types as `type T struct {ent.T; V int}`.
if typ := f.Type; f.Anonymous && typ.Kind() == reflect.Struct {
for j := 0; j < typ.NumField(); j++ {
names[columnName(typ.Field(j))] = []int{i, j}
}
continue
}
names[name] = i
names[columnName(f)] = []int{i}
}
for _, c := range columns {
// Normalize columns if necessary, for example: COUNT(*) => count.
name := strings.ToLower(strings.Split(c, "(")[0])
i, ok := names[name]
idx, ok := names[name]
if !ok {
return nil, fmt.Errorf("sql/scan: missing struct field for column: %s (%s)", c, name)
}
idx = append(idx, i)
scan.columns = append(scan.columns, typ.Field(i).Type)
idxs = append(idxs, idx)
rtype := typ.Field(idx[0]).Type
if len(idx) > 1 {
rtype = rtype.Field(idx[1]).Type
}
scan.columns = append(scan.columns, rtype)
}
scan.value = func(vs ...interface{}) reflect.Value {
st := reflect.New(typ).Elem()
for i, v := range vs {
st.Field(idx[i]).Set(reflect.Indirect(reflect.ValueOf(v)))
idx := idxs[i]
rvalue := st.Field(idx[0])
if len(idx) > 1 {
rvalue = rvalue.Field(idx[1])
}
rvalue.Set(reflect.Indirect(reflect.ValueOf(v)))
}
return st
}
return scan, nil
}
// columnName returns the column name of a struct-field.
func columnName(f reflect.StructField) string {
name := strings.ToLower(f.Name)
if tag, ok := f.Tag.Lookup("sql"); ok {
name = tag
} else if tag, ok := f.Tag.Lookup("json"); ok {
name = strings.Split(tag, ",")[0]
}
return name
}
// scanPtr wraps the underlying type with rowScan.
func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
typ = typ.Elem()

View File

@@ -139,6 +139,22 @@ func TestScanSlice(t *testing.T) {
require.Empty(t, pp)
}
func TestScanNestedStruct(t *testing.T) {
mock := sqlmock.NewRows([]string{"name", "age"}).
AddRow("foo", 1).
AddRow("bar", 2)
type T struct{ Name string }
var v []struct {
T
Age int
}
require.NoError(t, ScanSlice(toRows(mock), &v))
require.Equal(t, "foo", v[0].Name)
require.Equal(t, 1, v[0].Age)
require.Equal(t, "bar", v[1].Name)
require.Equal(t, 2, v[1].Age)
}
func TestScanSlicePtr(t *testing.T) {
mock := sqlmock.NewRows([]string{"name"}).
AddRow("foo").