mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/sqlscan: allow scanning values to embedded struct fields
This commit is contained in:
committed by
Ariel Mashraki
parent
497fca4c96
commit
b3041725d2
@@ -174,8 +174,8 @@ func assignable(typ reflect.Type) bool {
|
||||
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
var (
|
||||
scan = &rowScan{}
|
||||
idx = make([]int, 0, typ.NumField())
|
||||
names = make(map[string]int, typ.NumField())
|
||||
idxs = make([][]int, 0, typ.NumField())
|
||||
names = make(map[string][]int, typ.NumField())
|
||||
)
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
@@ -183,34 +183,55 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
if f.PkgPath != "" {
|
||||
continue
|
||||
}
|
||||
name := strings.ToLower(f.Name)
|
||||
if tag, ok := f.Tag.Lookup("sql"); ok {
|
||||
name = tag
|
||||
} else if tag, ok := f.Tag.Lookup("json"); ok {
|
||||
name = strings.Split(tag, ",")[0]
|
||||
// Support 1-level embedding to accepts types as `type T struct {ent.T; V int}`.
|
||||
if typ := f.Type; f.Anonymous && typ.Kind() == reflect.Struct {
|
||||
for j := 0; j < typ.NumField(); j++ {
|
||||
names[columnName(typ.Field(j))] = []int{i, j}
|
||||
}
|
||||
continue
|
||||
}
|
||||
names[name] = i
|
||||
names[columnName(f)] = []int{i}
|
||||
}
|
||||
for _, c := range columns {
|
||||
// Normalize columns if necessary, for example: COUNT(*) => count.
|
||||
name := strings.ToLower(strings.Split(c, "(")[0])
|
||||
i, ok := names[name]
|
||||
idx, ok := names[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sql/scan: missing struct field for column: %s (%s)", c, name)
|
||||
}
|
||||
idx = append(idx, i)
|
||||
scan.columns = append(scan.columns, typ.Field(i).Type)
|
||||
idxs = append(idxs, idx)
|
||||
rtype := typ.Field(idx[0]).Type
|
||||
if len(idx) > 1 {
|
||||
rtype = rtype.Field(idx[1]).Type
|
||||
}
|
||||
scan.columns = append(scan.columns, rtype)
|
||||
}
|
||||
scan.value = func(vs ...interface{}) reflect.Value {
|
||||
st := reflect.New(typ).Elem()
|
||||
for i, v := range vs {
|
||||
st.Field(idx[i]).Set(reflect.Indirect(reflect.ValueOf(v)))
|
||||
idx := idxs[i]
|
||||
rvalue := st.Field(idx[0])
|
||||
if len(idx) > 1 {
|
||||
rvalue = rvalue.Field(idx[1])
|
||||
}
|
||||
rvalue.Set(reflect.Indirect(reflect.ValueOf(v)))
|
||||
}
|
||||
return st
|
||||
}
|
||||
return scan, nil
|
||||
}
|
||||
|
||||
// columnName returns the column name of a struct-field.
|
||||
func columnName(f reflect.StructField) string {
|
||||
name := strings.ToLower(f.Name)
|
||||
if tag, ok := f.Tag.Lookup("sql"); ok {
|
||||
name = tag
|
||||
} else if tag, ok := f.Tag.Lookup("json"); ok {
|
||||
name = strings.Split(tag, ",")[0]
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// scanPtr wraps the underlying type with rowScan.
|
||||
func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
typ = typ.Elem()
|
||||
|
||||
@@ -139,6 +139,22 @@ func TestScanSlice(t *testing.T) {
|
||||
require.Empty(t, pp)
|
||||
}
|
||||
|
||||
func TestScanNestedStruct(t *testing.T) {
|
||||
mock := sqlmock.NewRows([]string{"name", "age"}).
|
||||
AddRow("foo", 1).
|
||||
AddRow("bar", 2)
|
||||
type T struct{ Name string }
|
||||
var v []struct {
|
||||
T
|
||||
Age int
|
||||
}
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v))
|
||||
require.Equal(t, "foo", v[0].Name)
|
||||
require.Equal(t, 1, v[0].Age)
|
||||
require.Equal(t, "bar", v[1].Name)
|
||||
require.Equal(t, 2, v[1].Age)
|
||||
}
|
||||
|
||||
func TestScanSlicePtr(t *testing.T) {
|
||||
mock := sqlmock.NewRows([]string{"name"}).
|
||||
AddRow("foo").
|
||||
|
||||
Reference in New Issue
Block a user