mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
338 lines
8.9 KiB
Go
338 lines
8.9 KiB
Go
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package sql
|
|
|
|
import (
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"encoding/json"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// ScanOne scans one row to the given value. It fails if the rows holds more than 1 row.
|
|
func ScanOne(rows ColumnScanner, v any) error {
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
return fmt.Errorf("sql/scan: failed getting column names: %w", err)
|
|
}
|
|
if n := len(columns); n != 1 {
|
|
return fmt.Errorf("sql/scan: unexpected number of columns: %d", n)
|
|
}
|
|
if !rows.Next() {
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
return sql.ErrNoRows
|
|
}
|
|
if err := rows.Scan(v); err != nil {
|
|
return err
|
|
}
|
|
if rows.Next() {
|
|
return fmt.Errorf("sql/scan: expect exactly one row in result set")
|
|
}
|
|
return rows.Err()
|
|
}
|
|
|
|
// ScanInt64 scans and returns an int64 from the rows.
|
|
func ScanInt64(rows ColumnScanner) (int64, error) {
|
|
var n int64
|
|
if err := ScanOne(rows, &n); err != nil {
|
|
return 0, err
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
// ScanInt scans and returns an int from the rows.
|
|
func ScanInt(rows ColumnScanner) (int, error) {
|
|
n, err := ScanInt64(rows)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return int(n), nil
|
|
}
|
|
|
|
// ScanBool scans and returns a boolean from the rows.
|
|
func ScanBool(rows ColumnScanner) (bool, error) {
|
|
var b bool
|
|
if err := ScanOne(rows, &b); err != nil {
|
|
return false, err
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
// ScanString scans and returns a string from the rows.
|
|
func ScanString(rows ColumnScanner) (string, error) {
|
|
var s string
|
|
if err := ScanOne(rows, &s); err != nil {
|
|
return "", err
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
// ScanValue scans and returns a driver.Value from the rows.
|
|
func ScanValue(rows ColumnScanner) (driver.Value, error) {
|
|
var v driver.Value
|
|
if err := ScanOne(rows, &v); err != nil {
|
|
return "", err
|
|
}
|
|
return v, nil
|
|
}
|
|
|
|
// ScanSlice scans the given ColumnScanner (basically, sql.Row or sql.Rows) into the given slice.
|
|
func ScanSlice(rows ColumnScanner, v any) error {
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
return fmt.Errorf("sql/scan: failed getting column names: %w", err)
|
|
}
|
|
rv := reflect.ValueOf(v)
|
|
switch {
|
|
case rv.Kind() != reflect.Ptr:
|
|
if t := reflect.TypeOf(v); t != nil {
|
|
return fmt.Errorf("sql/scan: ScanSlice(non-pointer %s)", t)
|
|
}
|
|
fallthrough
|
|
case rv.IsNil():
|
|
return fmt.Errorf("sql/scan: ScanSlice(nil)")
|
|
}
|
|
rv = reflect.Indirect(rv)
|
|
if k := rv.Kind(); k != reflect.Slice {
|
|
return fmt.Errorf("sql/scan: invalid type %s. expected slice as an argument", k)
|
|
}
|
|
scan, err := scanType(rv.Type().Elem(), columns)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if n, m := len(columns), len(scan.columns); n > m {
|
|
return fmt.Errorf("sql/scan: columns do not match (%d > %d)", n, m)
|
|
}
|
|
for rows.Next() {
|
|
values := scan.values()
|
|
if err := rows.Scan(values...); err != nil {
|
|
return fmt.Errorf("sql/scan: failed scanning rows: %w", err)
|
|
}
|
|
vv, err := scan.value(values...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rv.Set(reflect.Append(rv, vv))
|
|
}
|
|
return rows.Err()
|
|
}
|
|
|
|
// rowScan is the configuration for scanning one sql.Row.
|
|
type rowScan struct {
|
|
// column types of a row.
|
|
columns []reflect.Type
|
|
// value functions that converts the row columns (result) to a reflect.Value.
|
|
value func(v ...any) (reflect.Value, error)
|
|
}
|
|
|
|
// values returns a []any from the configured column types.
|
|
func (r *rowScan) values() []any {
|
|
values := make([]any, len(r.columns))
|
|
for i := range r.columns {
|
|
values[i] = reflect.New(r.columns[i]).Interface()
|
|
}
|
|
return values
|
|
}
|
|
|
|
// scanType returns rowScan for the given reflect.Type.
|
|
func scanType(typ reflect.Type, columns []string) (*rowScan, error) {
|
|
switch k := typ.Kind(); {
|
|
case assignable(typ):
|
|
return &rowScan{
|
|
columns: []reflect.Type{typ},
|
|
value: func(v ...any) (reflect.Value, error) {
|
|
return reflect.Indirect(reflect.ValueOf(v[0])), nil
|
|
},
|
|
}, nil
|
|
case k == reflect.Ptr:
|
|
return scanPtr(typ, columns)
|
|
case k == reflect.Struct:
|
|
return scanStruct(typ, columns)
|
|
default:
|
|
return nil, fmt.Errorf("sql/scan: unsupported type ([]%s)", k)
|
|
}
|
|
}
|
|
|
|
var (
|
|
scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
|
|
nullJSONType = reflect.TypeOf((*nullJSON)(nil)).Elem()
|
|
)
|
|
|
|
// nullJSON represents a json.RawMessage that may be NULL.
|
|
type nullJSON json.RawMessage
|
|
|
|
// Scan implements the sql.Scanner interface.
|
|
func (j *nullJSON) Scan(v interface{}) error {
|
|
if v == nil {
|
|
return nil
|
|
}
|
|
*j = v.([]byte)
|
|
return nil
|
|
}
|
|
|
|
// assignable reports if the given type can be assigned directly by `Rows.Scan`.
|
|
func assignable(typ reflect.Type) bool {
|
|
switch k := typ.Kind(); {
|
|
case typ.Implements(scannerType):
|
|
case k == reflect.Interface && typ.NumMethod() == 0:
|
|
case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64:
|
|
case (k == reflect.Slice || k == reflect.Array) && typ.Elem().Kind() == reflect.Uint8:
|
|
default:
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// scanStruct returns the configuration for scanning a sql.Row into a struct.
|
|
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
|
|
var (
|
|
scan = &rowScan{}
|
|
idxs = make([][]int, 0, typ.NumField())
|
|
names = make(map[string][]int, typ.NumField())
|
|
)
|
|
for i := 0; i < typ.NumField(); i++ {
|
|
f := typ.Field(i)
|
|
// Skip unexported fields.
|
|
if f.PkgPath != "" {
|
|
continue
|
|
}
|
|
// Support 1-level embedding to accept types as `type T struct {ent.T; V int}`.
|
|
if typ := f.Type; f.Anonymous && typ.Kind() == reflect.Struct {
|
|
for j := 0; j < typ.NumField(); j++ {
|
|
names[columnName(typ.Field(j))] = []int{i, j}
|
|
}
|
|
continue
|
|
}
|
|
names[columnName(f)] = []int{i}
|
|
}
|
|
for _, c := range columns {
|
|
var idx []int
|
|
// Normalize columns if necessary,
|
|
// for example: COUNT(*) => count.
|
|
switch name := strings.Split(c, "(")[0]; {
|
|
case names[name] != nil:
|
|
idx = names[name]
|
|
case names[strings.ToLower(name)] != nil:
|
|
idx = names[strings.ToLower(name)]
|
|
default:
|
|
return nil, fmt.Errorf("sql/scan: missing struct field for column: %s (%s)", c, name)
|
|
}
|
|
idxs = append(idxs, idx)
|
|
rtype := typ.Field(idx[0]).Type
|
|
if len(idx) > 1 {
|
|
rtype = rtype.Field(idx[1]).Type
|
|
}
|
|
switch {
|
|
// If the field is not support by the standard
|
|
// convertAssign, assume it is a JSON field.
|
|
case !supportsScan(rtype):
|
|
rtype = nullJSONType
|
|
// Create a pointer to the actual reflect
|
|
// types to accept optional struct fields.
|
|
case !nillable(rtype):
|
|
rtype = reflect.PtrTo(rtype)
|
|
}
|
|
scan.columns = append(scan.columns, rtype)
|
|
}
|
|
scan.value = func(vs ...any) (reflect.Value, error) {
|
|
st := reflect.New(typ).Elem()
|
|
for i, v := range vs {
|
|
rv := reflect.Indirect(reflect.ValueOf(v))
|
|
if rv.IsNil() {
|
|
continue
|
|
}
|
|
idx := idxs[i]
|
|
rvalue, ft := st.Field(idx[0]), st.Type().Field(idx[0])
|
|
if len(idx) > 1 {
|
|
// Embedded field.
|
|
rvalue, ft = rvalue.Field(idx[1]), ft.Type.Field(idx[1])
|
|
}
|
|
switch {
|
|
case rv.Type() == nullJSONType:
|
|
if rv = reflect.Indirect(rv); rv.IsNil() {
|
|
continue
|
|
}
|
|
if err := json.Unmarshal(rv.Bytes(), rvalue.Addr().Interface()); err != nil {
|
|
return reflect.Value{}, fmt.Errorf("unmarshal field %q: %w", ft.Name, err)
|
|
}
|
|
case !nillable(rvalue.Type()):
|
|
rv = reflect.Indirect(rv)
|
|
fallthrough
|
|
default:
|
|
rvalue.Set(rv)
|
|
}
|
|
}
|
|
return st, nil
|
|
}
|
|
return scan, nil
|
|
}
|
|
|
|
// columnName returns the column name of a struct-field.
|
|
func columnName(f reflect.StructField) string {
|
|
name := strings.ToLower(f.Name)
|
|
if tag, ok := f.Tag.Lookup("sql"); ok {
|
|
name = tag
|
|
} else if tag, ok := f.Tag.Lookup("json"); ok {
|
|
name = strings.Split(tag, ",")[0]
|
|
}
|
|
return name
|
|
}
|
|
|
|
// nillable reports if the reflect-type can have nil value.
|
|
func nillable(t reflect.Type) bool {
|
|
switch t.Kind() {
|
|
case reflect.Interface, reflect.Slice, reflect.Map, reflect.Ptr, reflect.UnsafePointer:
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// scanPtr wraps the underlying type with rowScan.
|
|
func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
|
|
typ = typ.Elem()
|
|
scan, err := scanType(typ, columns)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
wrap := scan.value
|
|
scan.value = func(vs ...any) (reflect.Value, error) {
|
|
v, err := wrap(vs...)
|
|
if err != nil {
|
|
return reflect.Value{}, err
|
|
}
|
|
pt := reflect.PtrTo(v.Type())
|
|
pv := reflect.New(pt.Elem())
|
|
pv.Elem().Set(v)
|
|
return pv, nil
|
|
}
|
|
return scan, nil
|
|
}
|
|
|
|
func supportsScan(t reflect.Type) bool {
|
|
if t.Implements(scannerType) || reflect.PtrTo(t).Implements(scannerType) {
|
|
return true
|
|
}
|
|
if t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
switch t.Kind() {
|
|
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
|
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
|
reflect.Float32, reflect.Float64, reflect.Pointer, reflect.String:
|
|
return true
|
|
case reflect.Slice:
|
|
return t == reflect.TypeOf(sql.RawBytes(nil)) || t == reflect.TypeOf([]byte(nil))
|
|
case reflect.Interface:
|
|
return t == reflect.TypeOf((*any)(nil)).Elem()
|
|
default:
|
|
return t == reflect.TypeOf(time.Time{}) || t.Implements(scannerType)
|
|
}
|
|
}
|