Files
ent/dialect/sql/scan.go
facebook-github-bot 267e3c15bd Initial commit
fbshipit-source-id: c79a38536e3c128dce1b2948615b72ec9779ed22
2019-06-16 04:37:51 -07:00

124 lines
3.3 KiB
Go

package sql
import (
"fmt"
"reflect"
"strings"
)
// ColumnScanner abstracts the subset of sql.Rows behaviour needed for
// scanning: listing the result column names, advancing to the next row,
// and copying the current row into destination values.
type ColumnScanner interface {
	Columns() ([]string, error)
	Next() bool
	Scan(dest ...interface{}) error
}
// ScanSlice scans every row of the given ColumnScanner (typically *sql.Rows)
// and appends the results to the slice that v points to.
//
// v must be a non-nil pointer to a slice. Supported element types are:
//   - string, bool, integer and float kinds, in which case the result is
//     expected to have a single column;
//   - structs, whose fields are matched to columns by scanStruct;
//   - pointers to such structs.
//
// If rows also provides an Err() error method (as *sql.Rows does), it is
// consulted after the last row so that an iteration terminated by an error
// is reported instead of being mistaken for a complete result.
func ScanSlice(rows ColumnScanner, v interface{}) error {
	columns, err := rows.Columns()
	if err != nil {
		return fmt.Errorf("sql/scan: failed getting column names: %v", err)
	}
	rv := reflect.Indirect(reflect.ValueOf(v))
	if k := rv.Kind(); k != reflect.Slice {
		return fmt.Errorf("sql/scan: invalid type %s. expected slice as an argument", k)
	}
	// A slice passed by value is not settable: rv.Set below would panic on
	// the first row (and the results would be lost anyway).
	if !rv.CanSet() {
		return fmt.Errorf("sql/scan: ScanSlice(non-pointer %T)", v)
	}
	var (
		scan *rowScan
		typ  = rv.Type().Elem()
	)
	switch k := typ.Kind(); {
	case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64:
		scan = &rowScan{
			columns: []reflect.Type{typ},
			value: func(v ...interface{}) reflect.Value {
				return reflect.Indirect(reflect.ValueOf(v[0]))
			},
		}
	case k == reflect.Ptr && typ.Elem().Kind() == reflect.Struct:
		// Only pointers to structs are accepted here; scanStruct calls
		// NumField, which panics for any non-struct type.
		if scan, err = scanStruct(typ.Elem(), columns); err != nil {
			return err
		}
		wrap := scan.value
		scan.value = func(vs ...interface{}) reflect.Value {
			// Copy the scanned struct into a newly allocated *T.
			sv := wrap(vs...)
			pv := reflect.New(sv.Type())
			pv.Elem().Set(sv)
			return pv
		}
	case k == reflect.Struct:
		if scan, err = scanStruct(typ, columns); err != nil {
			return err
		}
	default:
		return fmt.Errorf("sql/scan: unsupported type ([]%s)", k)
	}
	if n, m := len(columns), len(scan.columns); n > m {
		return fmt.Errorf("sql/scan: columns do not match (%d > %d)", n, m)
	}
	for rows.Next() {
		// Fresh destinations per row: the scanned values become part of
		// the appended element.
		values := scan.values()
		if err := rows.Scan(values...); err != nil {
			return fmt.Errorf("sql/scan: failed scanning rows: %v", err)
		}
		rv.Set(reflect.Append(rv, scan.value(values...)))
	}
	// ColumnScanner does not declare Err, but when the concrete type has it,
	// Next returning false may mean "failed" rather than "done".
	if r, ok := rows.(interface{ Err() error }); ok {
		if err := r.Err(); err != nil {
			return fmt.Errorf("sql/scan: failed iterating rows: %v", err)
		}
	}
	return nil
}
// rowScan holds what is needed to scan a single sql.Row: the Go type to
// allocate for each result column, and a function that assembles the
// scanned destinations into the element appended to the output slice.
type rowScan struct {
	// columns lists, in column order, the type allocated as the scan
	// destination for each result column.
	columns []reflect.Type
	// value combines the scanned destinations (passed in column order)
	// into a single reflect.Value.
	value func(...interface{}) reflect.Value
}
// values allocates one zero-valued destination per configured column type
// and returns pointers to them (as interface{}), in column order, ready to
// be handed to Scan.
func (r *rowScan) values() []interface{} {
	dst := make([]interface{}, 0, len(r.columns))
	for _, t := range r.columns {
		dst = append(dst, reflect.New(t).Interface())
	}
	return dst
}
// scanStruct returns a configuration for scanning an sql.Row into a value of
// the struct type typ.
//
// Every exported field is matched to a column by a lowercase name: the field
// name by default, or the name from its `json` struct tag. Following
// encoding/json, a tag of "-" excludes the field and a tag whose name part is
// empty (e.g. `json:",omitempty"`) keeps the default name. Unexported fields
// are ignored because reflection cannot set them.
//
// Column names are lowercased and cut at the first "(" before matching, so a
// column such as COUNT(*) matches a field named "count". An error is returned
// for any column without a matching field. typ must be a struct type.
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
	var (
		scan  = &rowScan{}
		names = make(map[string]int, typ.NumField())
		idx   = make([]int, 0, len(columns))
	)
	for i := 0; i < typ.NumField(); i++ {
		f := typ.Field(i)
		// Unexported: Value.Set on such a field panics.
		if f.PkgPath != "" {
			continue
		}
		name := strings.ToLower(f.Name)
		if tag, ok := f.Tag.Lookup("json"); ok {
			if tag == "-" {
				continue
			}
			// Lowercase the tag name as well, since column names are
			// lowercased below; otherwise mixed-case tags never match.
			if tagName := strings.Split(tag, ",")[0]; tagName != "" {
				name = strings.ToLower(tagName)
			}
		}
		names[name] = i
	}
	for _, c := range columns {
		// normalize columns if necessary, for example: COUNT(*) => count.
		name := strings.ToLower(strings.Split(c, "(")[0])
		i, ok := names[name]
		if !ok {
			return nil, fmt.Errorf("sql/scan: missing struct field for column: %s (%s)", c, name)
		}
		idx = append(idx, i)
		scan.columns = append(scan.columns, typ.Field(i).Type)
	}
	scan.value = func(vs ...interface{}) reflect.Value {
		st := reflect.New(typ).Elem()
		for i, v := range vs {
			st.Field(idx[i]).Set(reflect.Indirect(reflect.ValueOf(v)))
		}
		return st
	}
	return scan, nil
}