dialect/sqlscan: support scanning []uint8 type (#797)

Fixed #796
This commit is contained in:
Ariel Mashraki
2020-09-24 11:13:23 +03:00
committed by GitHub
parent 47fef27bc6
commit 142773b73d
3 changed files with 35 additions and 3 deletions

View File

@@ -119,9 +119,7 @@ func (r *rowScan) values() []interface{} {
// scanType returns rowScan for the given reflect.Type.
func scanType(typ reflect.Type, columns []string) (*rowScan, error) {
switch k := typ.Kind(); {
case k == reflect.Interface && typ.NumMethod() == 0:
fallthrough // interface{}
case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64:
case assignable(typ):
return &rowScan{
columns: []reflect.Type{typ},
value: func(v ...interface{}) reflect.Value {
@@ -137,6 +135,21 @@ func scanType(typ reflect.Type, columns []string) (*rowScan, error) {
}
}
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
// assignable reports if the given type can be assigned directly by `Rows.Scan`.
func assignable(typ reflect.Type) bool {
switch k := typ.Kind(); {
case typ.Implements(scannerType):
case k == reflect.Interface && typ.NumMethod() == 0:
case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64:
case (k == reflect.Slice || k == reflect.Array) && typ.Elem().Kind() == reflect.Uint8:
default:
return false
}
return true
}
// scanStruct returns the a configuration for scanning an sql.Row into a struct.
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
var (

View File

@@ -10,6 +10,7 @@ import (
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
@@ -93,6 +94,23 @@ func TestScanSlice(t *testing.T) {
require.False(t, v6[0].Name.Valid)
require.False(t, v6[1].Age.Valid)
require.Equal(t, "a8m", v6[1].Name.String)
u1, u2 := uuid.New().String(), uuid.New().String()
mock = sqlmock.NewRows([]string{"ids"}).
AddRow([]byte(u1)).
AddRow([]byte(u2))
var ids []uuid.UUID
require.NoError(t, ScanSlice(toRows(mock), &ids))
require.Equal(t, u1, ids[0].String())
require.Equal(t, u2, ids[1].String())
mock = sqlmock.NewRows([]string{"pids"}).
AddRow([]byte(u1)).
AddRow([]byte(u2))
var pids []*uuid.UUID
require.NoError(t, ScanSlice(toRows(mock), &pids))
require.Equal(t, u1, pids[0].String())
require.Equal(t, u2, pids[1].String())
}
func TestScanSlicePtr(t *testing.T) {