Files
ent/dialect/sql/scan.go

264 lines
7.0 KiB
Go

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package sql
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"strings"
)
// ScanOne reads a single-column, single-row result from rows into v.
// It returns sql.ErrNoRows when the result set is empty, and an error
// when the result holds more than one column or more than one row.
func ScanOne(rows ColumnScanner, v interface{}) error {
	names, err := rows.Columns()
	switch {
	case err != nil:
		return fmt.Errorf("sql/scan: failed getting column names: %w", err)
	case len(names) != 1:
		return fmt.Errorf("sql/scan: unexpected number of columns: %d", len(names))
	}
	if !rows.Next() {
		// Tell an iteration failure apart from an empty result set.
		if iterErr := rows.Err(); iterErr != nil {
			return iterErr
		}
		return sql.ErrNoRows
	}
	if scanErr := rows.Scan(v); scanErr != nil {
		return scanErr
	}
	// A second row violates the "exactly one row" contract.
	if rows.Next() {
		return fmt.Errorf("sql/scan: expect exactly one row in result set")
	}
	return rows.Err()
}
// ScanInt64 reads the single value held by rows (one row, one column)
// and returns it as an int64. On failure it returns 0 and the error
// reported by ScanOne.
func ScanInt64(rows ColumnScanner) (int64, error) {
	var out int64
	err := ScanOne(rows, &out)
	if err != nil {
		return 0, err
	}
	return out, nil
}
// ScanInt reads the single value held by rows via ScanInt64 and
// converts it to an int. On failure it returns 0 and the error.
func ScanInt(rows ColumnScanner) (int, error) {
	wide, err := ScanInt64(rows)
	if err == nil {
		return int(wide), nil
	}
	return 0, err
}
// ScanString reads the single value held by rows (one row, one column)
// and returns it as a string. On failure it returns the empty string
// and the error reported by ScanOne.
func ScanString(rows ColumnScanner) (string, error) {
	var out string
	err := ScanOne(rows, &out)
	if err != nil {
		return "", err
	}
	return out, nil
}
// ScanValue reads the single value held by rows (one row, one column)
// and returns it as a driver.Value.
//
// On failure it returns a nil value together with the error reported by
// ScanOne, so callers never observe a non-nil placeholder value.
func ScanValue(rows ColumnScanner) (driver.Value, error) {
	var v driver.Value
	if err := ScanOne(rows, &v); err != nil {
		return nil, err
	}
	return v, nil
}
// ScanSlice reads every row of the given ColumnScanner (basically, sql.Row
// or sql.Rows) and appends one element per row to the slice that v points to.
//
// v must be a non-nil pointer to a slice; the slice element type decides,
// through scanType, how each row is decoded. An error is returned when v has
// the wrong shape, when the result has more columns than the element type
// accepts, or when reading the rows fails.
func ScanSlice(rows ColumnScanner, v interface{}) error {
	names, err := rows.Columns()
	if err != nil {
		return fmt.Errorf("sql/scan: failed getting column names: %w", err)
	}
	ptr := reflect.ValueOf(v)
	if ptr.Kind() != reflect.Ptr {
		// An untyped nil has no type to report, so it shares the nil message.
		if t := reflect.TypeOf(v); t != nil {
			return fmt.Errorf("sql/scan: ScanSlice(non-pointer %s)", t)
		}
		return fmt.Errorf("sql/scan: ScanSlice(nil)")
	}
	if ptr.IsNil() {
		return fmt.Errorf("sql/scan: ScanSlice(nil)")
	}
	dst := ptr.Elem()
	if kind := dst.Kind(); kind != reflect.Slice {
		return fmt.Errorf("sql/scan: invalid type %s. expected slice as an argument", kind)
	}
	scan, err := scanType(dst.Type().Elem(), names)
	if err != nil {
		return err
	}
	if got, max := len(names), len(scan.columns); got > max {
		return fmt.Errorf("sql/scan: columns do not match (%d > %d)", got, max)
	}
	for rows.Next() {
		// Fresh destinations per row, so elements never share storage.
		dest := scan.values()
		if err := rows.Scan(dest...); err != nil {
			return fmt.Errorf("sql/scan: failed scanning rows: %w", err)
		}
		dst.Set(reflect.Append(dst, scan.value(dest...)))
	}
	return rows.Err()
}
// rowScan describes how a single sql.Row is scanned and converted.
type rowScan struct {
	// columns holds the Go type used as scan destination for each column.
	columns []reflect.Type
	// value turns the scanned destinations (as produced by values) into
	// the final reflect.Value for the row.
	value func(v ...interface{}) reflect.Value
}

// values allocates one fresh scan destination per configured column type.
// Each entry is a pointer to a new zero value of the matching type.
func (r *rowScan) values() []interface{} {
	dest := make([]interface{}, 0, len(r.columns))
	for _, typ := range r.columns {
		dest = append(dest, reflect.New(typ).Interface())
	}
	return dest
}
// scanType builds the rowScan configuration for the given element type.
//
// Types that Rows.Scan can fill directly use a single column; pointers and
// structs are delegated to scanPtr and scanStruct. Any other kind is rejected.
func scanType(typ reflect.Type, columns []string) (*rowScan, error) {
	if assignable(typ) {
		direct := &rowScan{
			columns: []reflect.Type{typ},
			value: func(v ...interface{}) reflect.Value {
				// The destination is a pointer to typ; unwrap it.
				return reflect.Indirect(reflect.ValueOf(v[0]))
			},
		}
		return direct, nil
	}
	switch kind := typ.Kind(); kind {
	case reflect.Ptr:
		return scanPtr(typ, columns)
	case reflect.Struct:
		return scanStruct(typ, columns)
	default:
		return nil, fmt.Errorf("sql/scan: unsupported type ([]%s)", kind)
	}
}
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()

// assignable reports whether `Rows.Scan` can fill a value of the given type
// directly: sql.Scanner implementations, the empty interface, strings, the
// kinds from bool through float64, and byte slices or byte arrays.
func assignable(typ reflect.Type) bool {
	if typ.Implements(scannerType) {
		return true
	}
	switch kind := typ.Kind(); kind {
	case reflect.Interface:
		// Only the empty interface accepts any scanned value.
		return typ.NumMethod() == 0
	case reflect.String:
		return true
	case reflect.Slice, reflect.Array:
		return typ.Elem().Kind() == reflect.Uint8
	default:
		return kind >= reflect.Bool && kind <= reflect.Float64
	}
}
// scanStruct returns the configuration for scanning an sql.Row into a struct
// of the given type.
//
// Every exported field is registered under its column name (see columnName).
// Exported fields of an embedded struct are registered as well, one level
// deep. Each entry in columns is normalized (lower-cased and cut at the first
// "(") and must match a registered field; otherwise an error is returned.
//
// Fields whose type cannot hold nil are scanned through a pointer, so that a
// NULL column leaves the field at its zero value instead of failing the scan.
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
	var (
		scan  = &rowScan{}
		idxs  = make([][]int, 0, len(columns))
		names = make(map[string][]int, typ.NumField())
	)
	for i := 0; i < typ.NumField(); i++ {
		f := typ.Field(i)
		// Skip unexported fields.
		if f.PkgPath != "" {
			continue
		}
		// Support 1-level embedding to accept types as `type T struct {ent.T; V int}`.
		if ft := f.Type; f.Anonymous && ft.Kind() == reflect.Struct {
			for j := 0; j < ft.NumField(); j++ {
				ef := ft.Field(j)
				// Unexported fields cannot be set through reflection; mapping
				// them would make the value function panic on assignment.
				if ef.PkgPath != "" {
					continue
				}
				names[columnName(ef)] = []int{i, j}
			}
			continue
		}
		names[columnName(f)] = []int{i}
	}
	for _, c := range columns {
		// Normalize columns if necessary, for example: COUNT(*) => count.
		name := strings.ToLower(strings.Split(c, "(")[0])
		idx, ok := names[name]
		if !ok {
			return nil, fmt.Errorf("sql/scan: missing struct field for column: %s (%s)", c, name)
		}
		idxs = append(idxs, idx)
		rtype := typ.Field(idx[0]).Type
		if len(idx) > 1 {
			rtype = rtype.Field(idx[1]).Type
		}
		if !nillable(rtype) {
			// Create a pointer to the actual reflect
			// types to accept optional struct fields.
			rtype = reflect.PtrTo(rtype)
		}
		scan.columns = append(scan.columns, rtype)
	}
	scan.value = func(vs ...interface{}) reflect.Value {
		st := reflect.New(typ).Elem()
		for i, v := range vs {
			// v points to a value of the column type, which is always
			// nil-able; a nil there means the column was NULL.
			rv := reflect.Indirect(reflect.ValueOf(v))
			if rv.IsNil() {
				continue
			}
			idx := idxs[i]
			rvalue := st.Field(idx[0])
			if len(idx) > 1 {
				rvalue = rvalue.Field(idx[1])
			}
			if !nillable(rvalue.Type()) {
				// Undo the pointer wrapping added for non-nil-able fields.
				rv = reflect.Indirect(rv)
			}
			rvalue.Set(rv)
		}
		return st
	}
	return scan, nil
}
// columnName returns the column name of a struct-field.
//
// The name is taken from the `sql` tag when that tag is present, otherwise
// from the name part of the `json` tag (the text before the first comma).
// When the selected tag carries no name, for example `json:",omitempty"`,
// or when neither tag exists, the lower-cased field name is used.
func columnName(f reflect.StructField) string {
	name := strings.ToLower(f.Name)
	if tag, ok := f.Tag.Lookup("sql"); ok {
		if tag != "" {
			return tag
		}
		return name
	}
	if tag, ok := f.Tag.Lookup("json"); ok {
		// Drop tag options such as ",omitempty".
		if i := strings.IndexByte(tag, ','); i >= 0 {
			tag = tag[:i]
		}
		if tag != "" {
			return tag
		}
	}
	return name
}
// nillable reports whether a value of the given reflect-type can be nil.
// Interfaces, slices, maps, pointers and unsafe pointers qualify.
func nillable(t reflect.Type) bool {
	k := t.Kind()
	return k == reflect.Interface ||
		k == reflect.Slice ||
		k == reflect.Map ||
		k == reflect.Ptr ||
		k == reflect.UnsafePointer
}
// scanPtr builds the rowScan for a pointer type: it reuses the configuration
// of the pointed-to type and wraps its value function so that every decoded
// row is copied into a freshly allocated value, whose pointer is returned.
func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
	scan, err := scanType(typ.Elem(), columns)
	if err != nil {
		return nil, err
	}
	inner := scan.value
	scan.value = func(vs ...interface{}) reflect.Value {
		elem := inner(vs...)
		// Allocate a new value of the element's type and copy the row into it.
		ptr := reflect.New(elem.Type())
		ptr.Elem().Set(elem)
		return ptr
	}
	return scan, nil
}