From 142773b73d344287a24bf4319d949e6a2182ecf9 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Thu, 24 Sep 2020 11:13:23 +0300 Subject: [PATCH] dialect/sqlscan: support scanning []uint8 type (#797) Fixed #796 --- dialect/sql/scan.go | 19 ++++++++++++++++--- dialect/sql/scan_test.go | 18 ++++++++++++++++++ entc/integration/customid/customid_test.go | 1 + 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/dialect/sql/scan.go b/dialect/sql/scan.go index e23752fc2..73f041514 100644 --- a/dialect/sql/scan.go +++ b/dialect/sql/scan.go @@ -119,9 +119,7 @@ func (r *rowScan) values() []interface{} { // scanType returns rowScan for the given reflect.Type. func scanType(typ reflect.Type, columns []string) (*rowScan, error) { switch k := typ.Kind(); { - case k == reflect.Interface && typ.NumMethod() == 0: - fallthrough // interface{} - case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64: + case assignable(typ): return &rowScan{ columns: []reflect.Type{typ}, value: func(v ...interface{}) reflect.Value { @@ -137,6 +135,21 @@ func scanType(typ reflect.Type, columns []string) (*rowScan, error) { } } +var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +// assignable reports if the given type can be assigned directly by `Rows.Scan`. +func assignable(typ reflect.Type) bool { + switch k := typ.Kind(); { + case typ.Implements(scannerType): + case k == reflect.Interface && typ.NumMethod() == 0: + case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64: + case (k == reflect.Slice || k == reflect.Array) && typ.Elem().Kind() == reflect.Uint8: + default: + return false + } + return true +} + // scanStruct returns the a configuration for scanning an sql.Row into a struct. func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) { var ( diff --git a/dialect/sql/scan_test.go b/dialect/sql/scan_test.go index b1291dcca..c0f0671e6 100644 --- a/dialect/sql/scan_test.go +++ b/dialect/sql/scan_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/DATA-DOG/go-sqlmock" + "github.com/google/uuid" "github.com/stretchr/testify/require" ) @@ -93,6 +94,23 @@ func TestScanSlice(t *testing.T) { require.False(t, v6[0].Name.Valid) require.False(t, v6[1].Age.Valid) require.Equal(t, "a8m", v6[1].Name.String) + + u1, u2 := uuid.New().String(), uuid.New().String() + mock = sqlmock.NewRows([]string{"ids"}). + AddRow([]byte(u1)). + AddRow([]byte(u2)) + var ids []uuid.UUID + require.NoError(t, ScanSlice(toRows(mock), &ids)) + require.Equal(t, u1, ids[0].String()) + require.Equal(t, u2, ids[1].String()) + + mock = sqlmock.NewRows([]string{"pids"}). + AddRow([]byte(u1)). + AddRow([]byte(u2)) + var pids []*uuid.UUID + require.NoError(t, ScanSlice(toRows(mock), &pids)) + require.Equal(t, u1, pids[0].String()) + require.Equal(t, u2, pids[1].String()) } func TestScanSlicePtr(t *testing.T) { diff --git a/entc/integration/customid/customid_test.go b/entc/integration/customid/customid_test.go index c776cea3f..10bf18602 100644 --- a/entc/integration/customid/customid_test.go +++ b/entc/integration/customid/customid_test.go @@ -101,6 +101,7 @@ func CustomID(t *testing.T, client *ent.Client) { require.Equal(t, 2, lnk.QueryLinks().CountX(ctx)) require.Equal(t, lnk.ID, chd.QueryLinks().OnlyX(ctx).ID) require.Equal(t, lnk.ID, blb.QueryLinks().OnlyX(ctx).ID) + require.Len(t, client.Blob.Query().IDsX(ctx), 3) pedro := client.Pet.Create().SetID("pedro").SetOwner(a8m).SaveX(ctx) require.Equal(t, a8m.ID, pedro.QueryOwner().OnlyIDX(ctx))