From 1bc4d48a51437aeda1779ec2af35107c6919f0b9 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Tue, 18 Oct 2022 16:18:50 +0300 Subject: [PATCH] dialect/sql: support scanning json fields (#3022) --- dialect/sql/scan.go | 100 ++++++++++++++++++----- dialect/sql/scan_test.go | 66 +++++++++++++++ entc/integration/json/ent/schema/user.go | 1 + entc/integration/json/ent/user.go | 2 +- entc/integration/json/json_test.go | 25 ++++++ 5 files changed, 173 insertions(+), 21 deletions(-) diff --git a/dialect/sql/scan.go b/dialect/sql/scan.go index b82e25c22..597824673 100644 --- a/dialect/sql/scan.go +++ b/dialect/sql/scan.go @@ -7,9 +7,11 @@ package sql import ( "database/sql" "database/sql/driver" + "encoding/json" "fmt" "reflect" "strings" + "time" ) // ScanOne scans one row to the given value. It fails if the rows holds more than 1 row. @@ -113,8 +115,11 @@ func ScanSlice(rows ColumnScanner, v any) error { if err := rows.Scan(values...); err != nil { return fmt.Errorf("sql/scan: failed scanning rows: %w", err) } - vv := reflect.Append(rv, scan.value(values...)) - rv.Set(vv) + vv, err := scan.value(values...) + if err != nil { + return err + } + rv.Set(reflect.Append(rv, vv)) } return rows.Err() } @@ -124,7 +129,7 @@ type rowScan struct { // column types of a row. columns []reflect.Type // value functions that converts the row columns (result) to a reflect.Value. - value func(v ...any) reflect.Value + value func(v ...any) (reflect.Value, error) } // values returns a []any from the configured column types. @@ -142,8 +147,8 @@ func scanType(typ reflect.Type, columns []string) (*rowScan, error) { case assignable(typ): return &rowScan{ columns: []reflect.Type{typ}, - value: func(v ...any) reflect.Value { - return reflect.Indirect(reflect.ValueOf(v[0])) + value: func(v ...any) (reflect.Value, error) { + return reflect.Indirect(reflect.ValueOf(v[0])), nil }, }, nil case k == reflect.Ptr: @@ -155,7 +160,22 @@ func scanType(typ reflect.Type, columns []string) (*rowScan, error) { } } -var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() +var ( + scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + nullJSONType = reflect.TypeOf((*nullJSON)(nil)).Elem() +) + +// nullJSON represents a json.RawMessage that may be NULL. +type nullJSON json.RawMessage + +// Scan implements the sql.Scanner interface. +func (j *nullJSON) Scan(v interface{}) error { + if v == nil { + return nil + } + *j = v.([]byte) + return nil +} // assignable reports if the given type can be assigned directly by `Rows.Scan`. func assignable(typ reflect.Type) bool { @@ -170,7 +190,7 @@ func assignable(typ reflect.Type) bool { return true } -// scanStruct returns the a configuration for scanning an sql.Row into a struct. +// scanStruct returns the configuration for scanning a sql.Row into a struct. func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) { var ( scan = &rowScan{} @@ -183,7 +203,7 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) { if f.PkgPath != "" { continue } - // Support 1-level embedding to accepts types as `type T struct {ent.T; V int}`. + // Support 1-level embedding to accept types as `type T struct {ent.T; V int}`. if typ := f.Type; f.Anonymous && typ.Kind() == reflect.Struct { for j := 0; j < typ.NumField(); j++ { names[columnName(typ.Field(j))] = []int{i, j} @@ -204,14 +224,19 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) { if len(idx) > 1 { rtype = rtype.Field(idx[1]).Type } - if !nillable(rtype) { - // Create a pointer to the actual reflect - // types to accept optional struct fields. + switch { + // If the field is not support by the standard + // convertAssign, assume it is a JSON field. + case !supportsScan(rtype): + rtype = nullJSONType + // Create a pointer to the actual reflect + // types to accept optional struct fields. + case !nillable(rtype): rtype = reflect.PtrTo(rtype) } scan.columns = append(scan.columns, rtype) } - scan.value = func(vs ...any) reflect.Value { + scan.value = func(vs ...any) (reflect.Value, error) { st := reflect.New(typ).Elem() for i, v := range vs { rv := reflect.Indirect(reflect.ValueOf(v)) @@ -219,16 +244,27 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) { continue } idx := idxs[i] - rvalue := st.Field(idx[0]) + rvalue, ft := st.Field(idx[0]), st.Type().Field(idx[0]) if len(idx) > 1 { - rvalue = rvalue.Field(idx[1]) + // Embedded field. + rvalue, ft = rvalue.Field(idx[1]), ft.Type.Field(idx[1]) } - if !nillable(rvalue.Type()) { + switch { + case rv.Type() == nullJSONType: + if rv = reflect.Indirect(rv); rv.IsNil() { + continue + } + if err := json.Unmarshal(rv.Bytes(), rvalue.Addr().Interface()); err != nil { + return reflect.Value{}, fmt.Errorf("unmarshal field %q: %w", ft.Name, err) + } + case !nillable(rvalue.Type()): rv = reflect.Indirect(rv) + fallthrough + default: + rvalue.Set(rv) } - rvalue.Set(rv) } - return st + return st, nil } return scan, nil } @@ -261,12 +297,36 @@ func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) { return nil, err } wrap := scan.value - scan.value = func(vs ...any) reflect.Value { - v := wrap(vs...) + scan.value = func(vs ...any) (reflect.Value, error) { + v, err := wrap(vs...) + if err != nil { + return reflect.Value{}, err + } pt := reflect.PtrTo(v.Type()) pv := reflect.New(pt.Elem()) pv.Elem().Set(v) - return pv + return pv, nil } return scan, nil } + +func supportsScan(t reflect.Type) bool { + if t.Implements(scannerType) || reflect.PtrTo(t).Implements(scannerType) { + return true + } + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + switch t.Kind() { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.Pointer, reflect.String: + return true + case reflect.Slice: + return t == reflect.TypeOf(sql.RawBytes(nil)) || t == reflect.TypeOf([]byte(nil)) + case reflect.Interface: + return t == reflect.TypeOf((*any)(nil)).Elem() + default: + return t == reflect.TypeOf(time.Time{}) || t.Implements(scannerType) + } +} diff --git a/dialect/sql/scan_test.go b/dialect/sql/scan_test.go index a107f9354..f66c22eb2 100644 --- a/dialect/sql/scan_test.go +++ b/dialect/sql/scan_test.go @@ -7,6 +7,7 @@ package sql import ( "database/sql" "database/sql/driver" + "encoding/json" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -139,6 +140,71 @@ func TestScanSlice(t *testing.T) { require.Empty(t, pp) } +func TestScanJSON(t *testing.T) { + mock := sqlmock.NewRows([]string{"v", "p"}). + AddRow([]byte(`{"i": 1, "s":"a8m"}`), []byte(`{"i": 1, "s":"a8m"}`)). + AddRow([]byte(`{"i": 2, "s":"tmr"}`), []byte(`{"i": 2, "s":"tmr"}`)). + AddRow([]byte(nil), []byte(`null`)). + AddRow(nil, nil) + var v1 []*struct { + V struct { + I int `json:"i"` + S string `json:"s"` + } `json:"v"` + P *struct { + I int `json:"i"` + S string `json:"s"` + } `json:"p"` + } + require.NoError(t, ScanSlice(toRows(mock), &v1)) + require.Equal(t, 1, v1[0].V.I) + require.Equal(t, "a8m", v1[0].V.S) + require.Equal(t, v1[0].V, *v1[0].P) + require.Equal(t, 2, v1[1].V.I) + require.Equal(t, "tmr", v1[1].V.S) + require.Equal(t, v1[1].V, *v1[1].P) + require.Equal(t, 0, v1[2].V.I) + require.Equal(t, "", v1[2].V.S) + require.Nil(t, v1[2].P) + require.Equal(t, 0, v1[3].V.I) + require.Equal(t, "", v1[3].V.S) + require.Nil(t, v1[3].P) + + mock = sqlmock.NewRows([]string{"v", "p"}). + AddRow([]byte(`[1]`), []byte(`[1]`)). + AddRow([]byte(`[2]`), []byte(`[2]`)) + var v2 []*struct { + V []int `json:"v"` + P *[]int `json:"p"` + } + require.NoError(t, ScanSlice(toRows(mock), &v2)) + require.Equal(t, []int{1}, v2[0].V) + require.Equal(t, v2[0].V, *v2[0].P) + require.Equal(t, []int{2}, v2[1].V) + require.Equal(t, v2[1].V, *v2[1].P) + + mock = sqlmock.NewRows([]string{"v", "p"}). + AddRow([]byte(`null`), []byte(`{}`)). + AddRow(nil, nil) + var v3 []*struct { + V json.RawMessage `json:"v"` + P *json.RawMessage `json:"p"` + } + require.NoError(t, ScanSlice(toRows(mock), &v3)) + require.Equal(t, json.RawMessage("null"), v3[0].V) + require.Equal(t, json.RawMessage("{}"), *v3[0].P) + require.Equal(t, json.RawMessage(nil), v3[1].V) + require.Nil(t, v3[1].P) + + // Unmarshal errors. + mock = sqlmock.NewRows([]string{"v", "p"}). + AddRow([]byte(`{invalid}`), []byte(`{}`)) + require.EqualError(t, ScanSlice(toRows(mock), &v1), `unmarshal field "V": invalid character 'i' looking for beginning of object key string`) + mock = sqlmock.NewRows([]string{"v", "p"}). + AddRow([]byte(``), []byte(``)) + require.EqualError(t, ScanSlice(toRows(mock), &v1), `unmarshal field "V": unexpected end of JSON input`) +} + func TestScanNestedStruct(t *testing.T) { mock := sqlmock.NewRows([]string{"name", "age"}). AddRow("foo", 1). diff --git a/entc/integration/json/ent/schema/user.go b/entc/integration/json/ent/schema/user.go index 0cba518c1..0c439f44b 100644 --- a/entc/integration/json/ent/schema/user.go +++ b/entc/integration/json/ent/schema/user.go @@ -30,6 +30,7 @@ func (User) Fields() []ent.Field { Optional(), field.JSON("URLs", []*url.URL{}). StorageKey("urls"). + StructTag(`json:"urls,omitempty"`). Optional(), field.JSON("raw", json.RawMessage{}). Optional(), diff --git a/entc/integration/json/ent/user.go b/entc/integration/json/ent/user.go index 6740f242f..3392bee23 100644 --- a/entc/integration/json/ent/user.go +++ b/entc/integration/json/ent/user.go @@ -28,7 +28,7 @@ type User struct { // URL holds the value of the "url" field. URL *url.URL `json:"url,omitempty"` // URLs holds the value of the "URLs" field. - URLs []*url.URL `json:"URLs,omitempty"` + URLs []*url.URL `json:"urls,omitempty"` // Raw holds the value of the "raw" field. Raw json.RawMessage `json:"raw,omitempty"` // Dirs holds the value of the "dirs" field. diff --git a/entc/integration/json/json_test.go b/entc/integration/json/json_test.go index 8f9ef12e2..28be10709 100644 --- a/entc/integration/json/json_test.go +++ b/entc/integration/json/json_test.go @@ -55,6 +55,7 @@ func TestMySQL(t *testing.T) { Strings(t, client) Predicates(t, client) } + Scan(t, client) }) } } @@ -86,6 +87,7 @@ func TestMaria(t *testing.T) { NetAddr(t, client) RawMessage(t, client) Predicates(t, client) + Scan(t, client) }) } } @@ -117,6 +119,7 @@ func TestPostgres(t *testing.T) { NetAddr(t, client) RawMessage(t, client) Predicates(t, client) + Scan(t, client) }) } } @@ -137,6 +140,7 @@ func TestSQLite(t *testing.T) { NetAddr(t, client) RawMessage(t, client) Predicates(t, client) + Scan(t, client) } func Ints(t *testing.T, client *ent.Client) { @@ -596,3 +600,24 @@ func Predicates(t *testing.T, client *ent.Client) { require.Equal(t, 4, n) }) } + +func Scan(t *testing.T, client *ent.Client) { + ctx := context.Background() + all := client.User.Query().Order(ent.Asc(user.FieldID)).AllX(ctx) + require.NotEmpty(t, all) + var scanned []*ent.User + // Select all non-sensitive fields. + client.User.Query().Order(ent.Asc(user.FieldID)).Select(user.Columns[:len(user.Columns)-2]...).ScanX(ctx, &scanned) + require.Equal(t, len(all), len(scanned)) + for i := range all { + require.Equal(t, all[i].ID, scanned[i].ID) + require.Equal(t, all[i].T, scanned[i].T) + require.Equal(t, all[i].URL, scanned[i].URL) + require.Equal(t, all[i].URLs, scanned[i].URLs) + require.Equal(t, all[i].Dirs, scanned[i].Dirs) + require.Equal(t, all[i].Raw, scanned[i].Raw) + require.Equal(t, all[i].Ints, scanned[i].Ints) + require.Equal(t, all[i].Floats, scanned[i].Floats) + require.Equal(t, all[i].Strings, scanned[i].Strings) + } +}