dialect/sql: support scanning json fields (#3022)

This commit is contained in:
Ariel Mashraki
2022-10-18 16:18:50 +03:00
committed by GitHub
parent d17c28bee0
commit 1bc4d48a51
5 changed files with 173 additions and 21 deletions

View File

@@ -7,9 +7,11 @@ package sql
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
"strings"
"time"
)
// ScanOne scans one row to the given value. It fails if the rows holds more than 1 row.
@@ -113,8 +115,11 @@ func ScanSlice(rows ColumnScanner, v any) error {
if err := rows.Scan(values...); err != nil {
return fmt.Errorf("sql/scan: failed scanning rows: %w", err)
}
vv := reflect.Append(rv, scan.value(values...))
rv.Set(vv)
vv, err := scan.value(values...)
if err != nil {
return err
}
rv.Set(reflect.Append(rv, vv))
}
return rows.Err()
}
@@ -124,7 +129,7 @@ type rowScan struct {
// column types of a row.
columns []reflect.Type
// value functions that converts the row columns (result) to a reflect.Value.
value func(v ...any) reflect.Value
value func(v ...any) (reflect.Value, error)
}
// values returns a []any from the configured column types.
@@ -142,8 +147,8 @@ func scanType(typ reflect.Type, columns []string) (*rowScan, error) {
case assignable(typ):
return &rowScan{
columns: []reflect.Type{typ},
value: func(v ...any) reflect.Value {
return reflect.Indirect(reflect.ValueOf(v[0]))
value: func(v ...any) (reflect.Value, error) {
return reflect.Indirect(reflect.ValueOf(v[0])), nil
},
}, nil
case k == reflect.Ptr:
@@ -155,7 +160,22 @@ func scanType(typ reflect.Type, columns []string) (*rowScan, error) {
}
}
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
var (
scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
nullJSONType = reflect.TypeOf((*nullJSON)(nil)).Elem()
)
// nullJSON represents a json.RawMessage that may be NULL.
type nullJSON json.RawMessage
// Scan implements the sql.Scanner interface.
func (j *nullJSON) Scan(v interface{}) error {
if v == nil {
return nil
}
*j = v.([]byte)
return nil
}
// assignable reports if the given type can be assigned directly by `Rows.Scan`.
func assignable(typ reflect.Type) bool {
@@ -170,7 +190,7 @@ func assignable(typ reflect.Type) bool {
return true
}
// scanStruct returns the a configuration for scanning an sql.Row into a struct.
// scanStruct returns the configuration for scanning a sql.Row into a struct.
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
var (
scan = &rowScan{}
@@ -183,7 +203,7 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
if f.PkgPath != "" {
continue
}
// Support 1-level embedding to accepts types as `type T struct {ent.T; V int}`.
// Support 1-level embedding to accept types as `type T struct {ent.T; V int}`.
if typ := f.Type; f.Anonymous && typ.Kind() == reflect.Struct {
for j := 0; j < typ.NumField(); j++ {
names[columnName(typ.Field(j))] = []int{i, j}
@@ -204,14 +224,19 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
if len(idx) > 1 {
rtype = rtype.Field(idx[1]).Type
}
if !nillable(rtype) {
// Create a pointer to the actual reflect
// types to accept optional struct fields.
switch {
// If the field is not support by the standard
// convertAssign, assume it is a JSON field.
case !supportsScan(rtype):
rtype = nullJSONType
// Create a pointer to the actual reflect
// types to accept optional struct fields.
case !nillable(rtype):
rtype = reflect.PtrTo(rtype)
}
scan.columns = append(scan.columns, rtype)
}
scan.value = func(vs ...any) reflect.Value {
scan.value = func(vs ...any) (reflect.Value, error) {
st := reflect.New(typ).Elem()
for i, v := range vs {
rv := reflect.Indirect(reflect.ValueOf(v))
@@ -219,16 +244,27 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
continue
}
idx := idxs[i]
rvalue := st.Field(idx[0])
rvalue, ft := st.Field(idx[0]), st.Type().Field(idx[0])
if len(idx) > 1 {
rvalue = rvalue.Field(idx[1])
// Embedded field.
rvalue, ft = rvalue.Field(idx[1]), ft.Type.Field(idx[1])
}
if !nillable(rvalue.Type()) {
switch {
case rv.Type() == nullJSONType:
if rv = reflect.Indirect(rv); rv.IsNil() {
continue
}
if err := json.Unmarshal(rv.Bytes(), rvalue.Addr().Interface()); err != nil {
return reflect.Value{}, fmt.Errorf("unmarshal field %q: %w", ft.Name, err)
}
case !nillable(rvalue.Type()):
rv = reflect.Indirect(rv)
fallthrough
default:
rvalue.Set(rv)
}
rvalue.Set(rv)
}
return st
return st, nil
}
return scan, nil
}
@@ -261,12 +297,36 @@ func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
return nil, err
}
wrap := scan.value
scan.value = func(vs ...any) reflect.Value {
v := wrap(vs...)
scan.value = func(vs ...any) (reflect.Value, error) {
v, err := wrap(vs...)
if err != nil {
return reflect.Value{}, err
}
pt := reflect.PtrTo(v.Type())
pv := reflect.New(pt.Elem())
pv.Elem().Set(v)
return pv
return pv, nil
}
return scan, nil
}
func supportsScan(t reflect.Type) bool {
if t.Implements(scannerType) || reflect.PtrTo(t).Implements(scannerType) {
return true
}
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
switch t.Kind() {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.Pointer, reflect.String:
return true
case reflect.Slice:
return t == reflect.TypeOf(sql.RawBytes(nil)) || t == reflect.TypeOf([]byte(nil))
case reflect.Interface:
return t == reflect.TypeOf((*any)(nil)).Elem()
default:
return t == reflect.TypeOf(time.Time{}) || t.Implements(scannerType)
}
}

View File

@@ -7,6 +7,7 @@ package sql
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"testing"
"github.com/DATA-DOG/go-sqlmock"
@@ -139,6 +140,71 @@ func TestScanSlice(t *testing.T) {
require.Empty(t, pp)
}
func TestScanJSON(t *testing.T) {
mock := sqlmock.NewRows([]string{"v", "p"}).
AddRow([]byte(`{"i": 1, "s":"a8m"}`), []byte(`{"i": 1, "s":"a8m"}`)).
AddRow([]byte(`{"i": 2, "s":"tmr"}`), []byte(`{"i": 2, "s":"tmr"}`)).
AddRow([]byte(nil), []byte(`null`)).
AddRow(nil, nil)
var v1 []*struct {
V struct {
I int `json:"i"`
S string `json:"s"`
} `json:"v"`
P *struct {
I int `json:"i"`
S string `json:"s"`
} `json:"p"`
}
require.NoError(t, ScanSlice(toRows(mock), &v1))
require.Equal(t, 1, v1[0].V.I)
require.Equal(t, "a8m", v1[0].V.S)
require.Equal(t, v1[0].V, *v1[0].P)
require.Equal(t, 2, v1[1].V.I)
require.Equal(t, "tmr", v1[1].V.S)
require.Equal(t, v1[1].V, *v1[1].P)
require.Equal(t, 0, v1[2].V.I)
require.Equal(t, "", v1[2].V.S)
require.Nil(t, v1[2].P)
require.Equal(t, 0, v1[3].V.I)
require.Equal(t, "", v1[3].V.S)
require.Nil(t, v1[3].P)
mock = sqlmock.NewRows([]string{"v", "p"}).
AddRow([]byte(`[1]`), []byte(`[1]`)).
AddRow([]byte(`[2]`), []byte(`[2]`))
var v2 []*struct {
V []int `json:"v"`
P *[]int `json:"p"`
}
require.NoError(t, ScanSlice(toRows(mock), &v2))
require.Equal(t, []int{1}, v2[0].V)
require.Equal(t, v2[0].V, *v2[0].P)
require.Equal(t, []int{2}, v2[1].V)
require.Equal(t, v2[1].V, *v2[1].P)
mock = sqlmock.NewRows([]string{"v", "p"}).
AddRow([]byte(`null`), []byte(`{}`)).
AddRow(nil, nil)
var v3 []*struct {
V json.RawMessage `json:"v"`
P *json.RawMessage `json:"p"`
}
require.NoError(t, ScanSlice(toRows(mock), &v3))
require.Equal(t, json.RawMessage("null"), v3[0].V)
require.Equal(t, json.RawMessage("{}"), *v3[0].P)
require.Equal(t, json.RawMessage(nil), v3[1].V)
require.Nil(t, v3[1].P)
// Unmarshal errors.
mock = sqlmock.NewRows([]string{"v", "p"}).
AddRow([]byte(`{invalid}`), []byte(`{}`))
require.EqualError(t, ScanSlice(toRows(mock), &v1), `unmarshal field "V": invalid character 'i' looking for beginning of object key string`)
mock = sqlmock.NewRows([]string{"v", "p"}).
AddRow([]byte(``), []byte(``))
require.EqualError(t, ScanSlice(toRows(mock), &v1), `unmarshal field "V": unexpected end of JSON input`)
}
func TestScanNestedStruct(t *testing.T) {
mock := sqlmock.NewRows([]string{"name", "age"}).
AddRow("foo", 1).