mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql: support scanning json fields (#3022)
This commit is contained in:
@@ -7,9 +7,11 @@ package sql
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ScanOne scans one row to the given value. It fails if the rows holds more than 1 row.
|
||||
@@ -113,8 +115,11 @@ func ScanSlice(rows ColumnScanner, v any) error {
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
return fmt.Errorf("sql/scan: failed scanning rows: %w", err)
|
||||
}
|
||||
vv := reflect.Append(rv, scan.value(values...))
|
||||
rv.Set(vv)
|
||||
vv, err := scan.value(values...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rv.Set(reflect.Append(rv, vv))
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
@@ -124,7 +129,7 @@ type rowScan struct {
|
||||
// column types of a row.
|
||||
columns []reflect.Type
|
||||
// value functions that converts the row columns (result) to a reflect.Value.
|
||||
value func(v ...any) reflect.Value
|
||||
value func(v ...any) (reflect.Value, error)
|
||||
}
|
||||
|
||||
// values returns a []any from the configured column types.
|
||||
@@ -142,8 +147,8 @@ func scanType(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
case assignable(typ):
|
||||
return &rowScan{
|
||||
columns: []reflect.Type{typ},
|
||||
value: func(v ...any) reflect.Value {
|
||||
return reflect.Indirect(reflect.ValueOf(v[0]))
|
||||
value: func(v ...any) (reflect.Value, error) {
|
||||
return reflect.Indirect(reflect.ValueOf(v[0])), nil
|
||||
},
|
||||
}, nil
|
||||
case k == reflect.Ptr:
|
||||
@@ -155,7 +160,22 @@ func scanType(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
|
||||
var (
|
||||
scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
|
||||
nullJSONType = reflect.TypeOf((*nullJSON)(nil)).Elem()
|
||||
)
|
||||
|
||||
// nullJSON represents a json.RawMessage that may be NULL.
|
||||
type nullJSON json.RawMessage
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (j *nullJSON) Scan(v interface{}) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
*j = v.([]byte)
|
||||
return nil
|
||||
}
|
||||
|
||||
// assignable reports if the given type can be assigned directly by `Rows.Scan`.
|
||||
func assignable(typ reflect.Type) bool {
|
||||
@@ -170,7 +190,7 @@ func assignable(typ reflect.Type) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// scanStruct returns the a configuration for scanning an sql.Row into a struct.
|
||||
// scanStruct returns the configuration for scanning a sql.Row into a struct.
|
||||
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
var (
|
||||
scan = &rowScan{}
|
||||
@@ -183,7 +203,7 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
if f.PkgPath != "" {
|
||||
continue
|
||||
}
|
||||
// Support 1-level embedding to accepts types as `type T struct {ent.T; V int}`.
|
||||
// Support 1-level embedding to accept types as `type T struct {ent.T; V int}`.
|
||||
if typ := f.Type; f.Anonymous && typ.Kind() == reflect.Struct {
|
||||
for j := 0; j < typ.NumField(); j++ {
|
||||
names[columnName(typ.Field(j))] = []int{i, j}
|
||||
@@ -204,14 +224,19 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
if len(idx) > 1 {
|
||||
rtype = rtype.Field(idx[1]).Type
|
||||
}
|
||||
if !nillable(rtype) {
|
||||
// Create a pointer to the actual reflect
|
||||
// types to accept optional struct fields.
|
||||
switch {
|
||||
// If the field is not support by the standard
|
||||
// convertAssign, assume it is a JSON field.
|
||||
case !supportsScan(rtype):
|
||||
rtype = nullJSONType
|
||||
// Create a pointer to the actual reflect
|
||||
// types to accept optional struct fields.
|
||||
case !nillable(rtype):
|
||||
rtype = reflect.PtrTo(rtype)
|
||||
}
|
||||
scan.columns = append(scan.columns, rtype)
|
||||
}
|
||||
scan.value = func(vs ...any) reflect.Value {
|
||||
scan.value = func(vs ...any) (reflect.Value, error) {
|
||||
st := reflect.New(typ).Elem()
|
||||
for i, v := range vs {
|
||||
rv := reflect.Indirect(reflect.ValueOf(v))
|
||||
@@ -219,16 +244,27 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
continue
|
||||
}
|
||||
idx := idxs[i]
|
||||
rvalue := st.Field(idx[0])
|
||||
rvalue, ft := st.Field(idx[0]), st.Type().Field(idx[0])
|
||||
if len(idx) > 1 {
|
||||
rvalue = rvalue.Field(idx[1])
|
||||
// Embedded field.
|
||||
rvalue, ft = rvalue.Field(idx[1]), ft.Type.Field(idx[1])
|
||||
}
|
||||
if !nillable(rvalue.Type()) {
|
||||
switch {
|
||||
case rv.Type() == nullJSONType:
|
||||
if rv = reflect.Indirect(rv); rv.IsNil() {
|
||||
continue
|
||||
}
|
||||
if err := json.Unmarshal(rv.Bytes(), rvalue.Addr().Interface()); err != nil {
|
||||
return reflect.Value{}, fmt.Errorf("unmarshal field %q: %w", ft.Name, err)
|
||||
}
|
||||
case !nillable(rvalue.Type()):
|
||||
rv = reflect.Indirect(rv)
|
||||
fallthrough
|
||||
default:
|
||||
rvalue.Set(rv)
|
||||
}
|
||||
rvalue.Set(rv)
|
||||
}
|
||||
return st
|
||||
return st, nil
|
||||
}
|
||||
return scan, nil
|
||||
}
|
||||
@@ -261,12 +297,36 @@ func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
return nil, err
|
||||
}
|
||||
wrap := scan.value
|
||||
scan.value = func(vs ...any) reflect.Value {
|
||||
v := wrap(vs...)
|
||||
scan.value = func(vs ...any) (reflect.Value, error) {
|
||||
v, err := wrap(vs...)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
pt := reflect.PtrTo(v.Type())
|
||||
pv := reflect.New(pt.Elem())
|
||||
pv.Elem().Set(v)
|
||||
return pv
|
||||
return pv, nil
|
||||
}
|
||||
return scan, nil
|
||||
}
|
||||
|
||||
func supportsScan(t reflect.Type) bool {
|
||||
if t.Implements(scannerType) || reflect.PtrTo(t).Implements(scannerType) {
|
||||
return true
|
||||
}
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
switch t.Kind() {
|
||||
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64, reflect.Pointer, reflect.String:
|
||||
return true
|
||||
case reflect.Slice:
|
||||
return t == reflect.TypeOf(sql.RawBytes(nil)) || t == reflect.TypeOf([]byte(nil))
|
||||
case reflect.Interface:
|
||||
return t == reflect.TypeOf((*any)(nil)).Elem()
|
||||
default:
|
||||
return t == reflect.TypeOf(time.Time{}) || t.Implements(scannerType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ package sql
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
@@ -139,6 +140,71 @@ func TestScanSlice(t *testing.T) {
|
||||
require.Empty(t, pp)
|
||||
}
|
||||
|
||||
func TestScanJSON(t *testing.T) {
|
||||
mock := sqlmock.NewRows([]string{"v", "p"}).
|
||||
AddRow([]byte(`{"i": 1, "s":"a8m"}`), []byte(`{"i": 1, "s":"a8m"}`)).
|
||||
AddRow([]byte(`{"i": 2, "s":"tmr"}`), []byte(`{"i": 2, "s":"tmr"}`)).
|
||||
AddRow([]byte(nil), []byte(`null`)).
|
||||
AddRow(nil, nil)
|
||||
var v1 []*struct {
|
||||
V struct {
|
||||
I int `json:"i"`
|
||||
S string `json:"s"`
|
||||
} `json:"v"`
|
||||
P *struct {
|
||||
I int `json:"i"`
|
||||
S string `json:"s"`
|
||||
} `json:"p"`
|
||||
}
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v1))
|
||||
require.Equal(t, 1, v1[0].V.I)
|
||||
require.Equal(t, "a8m", v1[0].V.S)
|
||||
require.Equal(t, v1[0].V, *v1[0].P)
|
||||
require.Equal(t, 2, v1[1].V.I)
|
||||
require.Equal(t, "tmr", v1[1].V.S)
|
||||
require.Equal(t, v1[1].V, *v1[1].P)
|
||||
require.Equal(t, 0, v1[2].V.I)
|
||||
require.Equal(t, "", v1[2].V.S)
|
||||
require.Nil(t, v1[2].P)
|
||||
require.Equal(t, 0, v1[3].V.I)
|
||||
require.Equal(t, "", v1[3].V.S)
|
||||
require.Nil(t, v1[3].P)
|
||||
|
||||
mock = sqlmock.NewRows([]string{"v", "p"}).
|
||||
AddRow([]byte(`[1]`), []byte(`[1]`)).
|
||||
AddRow([]byte(`[2]`), []byte(`[2]`))
|
||||
var v2 []*struct {
|
||||
V []int `json:"v"`
|
||||
P *[]int `json:"p"`
|
||||
}
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v2))
|
||||
require.Equal(t, []int{1}, v2[0].V)
|
||||
require.Equal(t, v2[0].V, *v2[0].P)
|
||||
require.Equal(t, []int{2}, v2[1].V)
|
||||
require.Equal(t, v2[1].V, *v2[1].P)
|
||||
|
||||
mock = sqlmock.NewRows([]string{"v", "p"}).
|
||||
AddRow([]byte(`null`), []byte(`{}`)).
|
||||
AddRow(nil, nil)
|
||||
var v3 []*struct {
|
||||
V json.RawMessage `json:"v"`
|
||||
P *json.RawMessage `json:"p"`
|
||||
}
|
||||
require.NoError(t, ScanSlice(toRows(mock), &v3))
|
||||
require.Equal(t, json.RawMessage("null"), v3[0].V)
|
||||
require.Equal(t, json.RawMessage("{}"), *v3[0].P)
|
||||
require.Equal(t, json.RawMessage(nil), v3[1].V)
|
||||
require.Nil(t, v3[1].P)
|
||||
|
||||
// Unmarshal errors.
|
||||
mock = sqlmock.NewRows([]string{"v", "p"}).
|
||||
AddRow([]byte(`{invalid}`), []byte(`{}`))
|
||||
require.EqualError(t, ScanSlice(toRows(mock), &v1), `unmarshal field "V": invalid character 'i' looking for beginning of object key string`)
|
||||
mock = sqlmock.NewRows([]string{"v", "p"}).
|
||||
AddRow([]byte(``), []byte(``))
|
||||
require.EqualError(t, ScanSlice(toRows(mock), &v1), `unmarshal field "V": unexpected end of JSON input`)
|
||||
}
|
||||
|
||||
func TestScanNestedStruct(t *testing.T) {
|
||||
mock := sqlmock.NewRows([]string{"name", "age"}).
|
||||
AddRow("foo", 1).
|
||||
|
||||
Reference in New Issue
Block a user