dialect/sql: support scanning json fields (#3022)

This commit is contained in:
Ariel Mashraki
2022-10-18 16:18:50 +03:00
committed by GitHub
parent d17c28bee0
commit 1bc4d48a51
5 changed files with 173 additions and 21 deletions

View File

@@ -7,9 +7,11 @@ package sql
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
"strings"
"time"
)
// ScanOne scans one row to the given value. It fails if the rows holds more than 1 row.
@@ -113,8 +115,11 @@ func ScanSlice(rows ColumnScanner, v any) error {
if err := rows.Scan(values...); err != nil {
return fmt.Errorf("sql/scan: failed scanning rows: %w", err)
}
vv := reflect.Append(rv, scan.value(values...))
rv.Set(vv)
vv, err := scan.value(values...)
if err != nil {
return err
}
rv.Set(reflect.Append(rv, vv))
}
return rows.Err()
}
@@ -124,7 +129,7 @@ type rowScan struct {
// column types of a row.
columns []reflect.Type
// value functions that converts the row columns (result) to a reflect.Value.
value func(v ...any) reflect.Value
value func(v ...any) (reflect.Value, error)
}
// values returns a []any from the configured column types.
@@ -142,8 +147,8 @@ func scanType(typ reflect.Type, columns []string) (*rowScan, error) {
case assignable(typ):
return &rowScan{
columns: []reflect.Type{typ},
value: func(v ...any) reflect.Value {
return reflect.Indirect(reflect.ValueOf(v[0]))
value: func(v ...any) (reflect.Value, error) {
return reflect.Indirect(reflect.ValueOf(v[0])), nil
},
}, nil
case k == reflect.Ptr:
@@ -155,7 +160,22 @@ func scanType(typ reflect.Type, columns []string) (*rowScan, error) {
}
}
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
var (
scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
nullJSONType = reflect.TypeOf((*nullJSON)(nil)).Elem()
)
// nullJSON represents a json.RawMessage that may be NULL.
type nullJSON json.RawMessage
// Scan implements the sql.Scanner interface.
func (j *nullJSON) Scan(v interface{}) error {
if v == nil {
return nil
}
*j = v.([]byte)
return nil
}
// assignable reports if the given type can be assigned directly by `Rows.Scan`.
func assignable(typ reflect.Type) bool {
@@ -170,7 +190,7 @@ func assignable(typ reflect.Type) bool {
return true
}
// scanStruct returns the a configuration for scanning an sql.Row into a struct.
// scanStruct returns the configuration for scanning a sql.Row into a struct.
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
var (
scan = &rowScan{}
@@ -183,7 +203,7 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
if f.PkgPath != "" {
continue
}
// Support 1-level embedding to accepts types as `type T struct {ent.T; V int}`.
// Support 1-level embedding to accept types as `type T struct {ent.T; V int}`.
if typ := f.Type; f.Anonymous && typ.Kind() == reflect.Struct {
for j := 0; j < typ.NumField(); j++ {
names[columnName(typ.Field(j))] = []int{i, j}
@@ -204,14 +224,19 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
if len(idx) > 1 {
rtype = rtype.Field(idx[1]).Type
}
if !nillable(rtype) {
// Create a pointer to the actual reflect
// types to accept optional struct fields.
switch {
// If the field is not support by the standard
// convertAssign, assume it is a JSON field.
case !supportsScan(rtype):
rtype = nullJSONType
// Create a pointer to the actual reflect
// types to accept optional struct fields.
case !nillable(rtype):
rtype = reflect.PtrTo(rtype)
}
scan.columns = append(scan.columns, rtype)
}
scan.value = func(vs ...any) reflect.Value {
scan.value = func(vs ...any) (reflect.Value, error) {
st := reflect.New(typ).Elem()
for i, v := range vs {
rv := reflect.Indirect(reflect.ValueOf(v))
@@ -219,16 +244,27 @@ func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
continue
}
idx := idxs[i]
rvalue := st.Field(idx[0])
rvalue, ft := st.Field(idx[0]), st.Type().Field(idx[0])
if len(idx) > 1 {
rvalue = rvalue.Field(idx[1])
// Embedded field.
rvalue, ft = rvalue.Field(idx[1]), ft.Type.Field(idx[1])
}
if !nillable(rvalue.Type()) {
switch {
case rv.Type() == nullJSONType:
if rv = reflect.Indirect(rv); rv.IsNil() {
continue
}
if err := json.Unmarshal(rv.Bytes(), rvalue.Addr().Interface()); err != nil {
return reflect.Value{}, fmt.Errorf("unmarshal field %q: %w", ft.Name, err)
}
case !nillable(rvalue.Type()):
rv = reflect.Indirect(rv)
fallthrough
default:
rvalue.Set(rv)
}
rvalue.Set(rv)
}
return st
return st, nil
}
return scan, nil
}
@@ -261,12 +297,36 @@ func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
return nil, err
}
wrap := scan.value
scan.value = func(vs ...any) reflect.Value {
v := wrap(vs...)
scan.value = func(vs ...any) (reflect.Value, error) {
v, err := wrap(vs...)
if err != nil {
return reflect.Value{}, err
}
pt := reflect.PtrTo(v.Type())
pv := reflect.New(pt.Elem())
pv.Elem().Set(v)
return pv
return pv, nil
}
return scan, nil
}
func supportsScan(t reflect.Type) bool {
if t.Implements(scannerType) || reflect.PtrTo(t).Implements(scannerType) {
return true
}
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
switch t.Kind() {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.Pointer, reflect.String:
return true
case reflect.Slice:
return t == reflect.TypeOf(sql.RawBytes(nil)) || t == reflect.TypeOf([]byte(nil))
case reflect.Interface:
return t == reflect.TypeOf((*any)(nil)).Elem()
default:
return t == reflect.TypeOf(time.Time{}) || t.Implements(scannerType)
}
}