From b3041725d2c7821673ffbe2dc22b605df0d36ca2 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Mon, 19 Apr 2021 18:28:15 +0300 Subject: [PATCH] dialect/sql/sqlscan: allow scanning values to embedded struct fields --- dialect/sql/scan.go | 45 +++++++++++++++++++++++++++++----------- dialect/sql/scan_test.go | 16 ++++++++++++++ 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/dialect/sql/scan.go b/dialect/sql/scan.go index 00ff991b1..aa9652c13 100644 --- a/dialect/sql/scan.go +++ b/dialect/sql/scan.go @@ -174,8 +174,8 @@ func assignable(typ reflect.Type) bool { func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) { var ( scan = &rowScan{} - idx = make([]int, 0, typ.NumField()) - names = make(map[string]int, typ.NumField()) + idxs = make([][]int, 0, typ.NumField()) + names = make(map[string][]int, typ.NumField()) ) for i := 0; i < typ.NumField(); i++ { f := typ.Field(i) @@ -183,34 +183,55 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) { if f.PkgPath != "" { continue } - name := strings.ToLower(f.Name) - if tag, ok := f.Tag.Lookup("sql"); ok { - name = tag - } else if tag, ok := f.Tag.Lookup("json"); ok { - name = strings.Split(tag, ",")[0] + // Support 1-level embedding to accepts types as `type T struct {ent.T; V int}`. + if typ := f.Type; f.Anonymous && typ.Kind() == reflect.Struct { + for j := 0; j < typ.NumField(); j++ { + names[columnName(typ.Field(j))] = []int{i, j} + } + continue } - names[name] = i + names[columnName(f)] = []int{i} } for _, c := range columns { // Normalize columns if necessary, for example: COUNT(*) => count. name := strings.ToLower(strings.Split(c, "(")[0]) - i, ok := names[name] + idx, ok := names[name] if !ok { return nil, fmt.Errorf("sql/scan: missing struct field for column: %s (%s)", c, name) } - idx = append(idx, i) - scan.columns = append(scan.columns, typ.Field(i).Type) + idxs = append(idxs, idx) + rtype := typ.Field(idx[0]).Type + if len(idx) > 1 { + rtype = rtype.Field(idx[1]).Type + } + scan.columns = append(scan.columns, rtype) } scan.value = func(vs ...interface{}) reflect.Value { st := reflect.New(typ).Elem() for i, v := range vs { - st.Field(idx[i]).Set(reflect.Indirect(reflect.ValueOf(v))) + idx := idxs[i] + rvalue := st.Field(idx[0]) + if len(idx) > 1 { + rvalue = rvalue.Field(idx[1]) + } + rvalue.Set(reflect.Indirect(reflect.ValueOf(v))) } return st } return scan, nil } +// columnName returns the column name of a struct-field. +func columnName(f reflect.StructField) string { + name := strings.ToLower(f.Name) + if tag, ok := f.Tag.Lookup("sql"); ok { + name = tag + } else if tag, ok := f.Tag.Lookup("json"); ok { + name = strings.Split(tag, ",")[0] + } + return name +} + // scanPtr wraps the underlying type with rowScan. func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) { typ = typ.Elem() diff --git a/dialect/sql/scan_test.go b/dialect/sql/scan_test.go index 08036f18f..5022ae13d 100644 --- a/dialect/sql/scan_test.go +++ b/dialect/sql/scan_test.go @@ -139,6 +139,22 @@ func TestScanSlice(t *testing.T) { require.Empty(t, pp) } +func TestScanNestedStruct(t *testing.T) { + mock := sqlmock.NewRows([]string{"name", "age"}). + AddRow("foo", 1). + AddRow("bar", 2) + type T struct{ Name string } + var v []struct { + T + Age int + } + require.NoError(t, ScanSlice(toRows(mock), &v)) + require.Equal(t, "foo", v[0].Name) + require.Equal(t, 1, v[0].Age) + require.Equal(t, "bar", v[1].Name) + require.Equal(t, 2, v[1].Age) +} + func TestScanSlicePtr(t *testing.T) { mock := sqlmock.NewRows([]string{"name"}). AddRow("foo").