schema/field: use actual go type in generated interfaces (#1428)

This commit is contained in:
Ariel Mashraki
2021-04-07 09:53:44 +03:00
committed by GitHub
parent 3fe9d1081e
commit 2cc1c628dc
107 changed files with 4932 additions and 2455 deletions

View File

@@ -1056,12 +1056,12 @@ func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) {
tv := indirect(t)
info := &TypeInfo{
Type: d.Info.Type,
Ident: tv.String(),
Ident: t.String(),
PkgPath: tv.PkgPath(),
RType: &RType{
rtype: tv,
rtype: t,
Kind: t.Kind(),
Name: tv.Name(),
Kind: tv.Kind(),
PkgPath: tv.PkgPath(),
Methods: make(map[string]struct{ In, Out []*RType }, t.NumMethod()),
},
@@ -1070,8 +1070,10 @@ func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) {
case reflect.Slice, reflect.Array, reflect.Ptr, reflect.Map:
info.Nillable = true
}
switch {
case t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType):
switch pt := reflect.PtrTo(t); {
case pt.Implements(valueScannerType):
t = pt
fallthrough
case t.Implements(valueScannerType):
n := t.NumMethod()
for i := 0; i < n; i++ {
@@ -1088,11 +1090,9 @@ func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) {
}
info.RType.Methods[m.Name] = struct{ In, Out []*RType }{in, out}
}
case t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType):
default:
d.Err = fmt.Errorf("GoType must be a %q type or ValueScanner", expectType)
if pt := reflect.PtrTo(t); pt.Implements(valueScannerType) {
d.Err = fmt.Errorf("%s. Use %s instead", d.Err, pt)
}
}
d.Info = info
}
@@ -1120,6 +1120,7 @@ var (
bytesType = reflect.TypeOf([]byte(nil))
timeType = reflect.TypeOf(time.Time{})
stringType = reflect.TypeOf("")
valuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
valueScannerType = reflect.TypeOf((*ValueScanner)(nil)).Elem()
)

View File

@@ -77,14 +77,12 @@ func TestInt(t *testing.T) {
fd = field.Int("count").GoType(&sql.NullInt64{}).Descriptor()
assert.NoError(t, fd.Err)
assert.Equal(t, "sql.NullInt64", fd.Info.Ident)
assert.Equal(t, "*sql.NullInt64", fd.Info.Ident)
assert.Equal(t, "database/sql", fd.Info.PkgPath)
assert.Equal(t, "sql.NullInt64", fd.Info.String())
assert.Equal(t, "*sql.NullInt64", fd.Info.String())
assert.True(t, fd.Info.Nillable)
assert.True(t, fd.Info.ValueScanner())
fd = field.Int("count").GoType(sql.NullInt64{}).Descriptor()
assert.EqualError(t, fd.Err, `GoType must be a "int" type or ValueScanner. Use *sql.NullInt64 instead`)
fd = field.Int("count").GoType(false).Descriptor()
assert.EqualError(t, fd.Err, `GoType must be a "int" type or ValueScanner`)
fd = field.Int("count").GoType(struct{}{}).Descriptor()
@@ -132,9 +130,9 @@ func TestFloat(t *testing.T) {
fd = field.Float("count").GoType(&sql.NullFloat64{}).Descriptor()
assert.NoError(t, fd.Err)
assert.Equal(t, "sql.NullFloat64", fd.Info.Ident)
assert.Equal(t, "*sql.NullFloat64", fd.Info.Ident)
assert.Equal(t, "database/sql", fd.Info.PkgPath)
assert.Equal(t, "sql.NullFloat64", fd.Info.String())
assert.Equal(t, "*sql.NullFloat64", fd.Info.String())
assert.True(t, fd.Info.Nillable)
assert.True(t, fd.Info.ValueScanner())
@@ -166,9 +164,9 @@ func TestBool(t *testing.T) {
fd = field.Bool("deleted").GoType(&sql.NullBool{}).Descriptor()
assert.NoError(t, fd.Err)
assert.Equal(t, "sql.NullBool", fd.Info.Ident)
assert.Equal(t, "*sql.NullBool", fd.Info.Ident)
assert.Equal(t, "database/sql", fd.Info.PkgPath)
assert.Equal(t, "sql.NullBool", fd.Info.String())
assert.Equal(t, "*sql.NullBool", fd.Info.String())
assert.True(t, fd.Info.Nillable)
assert.True(t, fd.Info.ValueScanner())
@@ -180,6 +178,13 @@ func TestBool(t *testing.T) {
assert.Error(t, fd.Err)
}
type Pair struct {
K, V []byte
}
func (*Pair) Scan(interface{}) error { return nil }
func (Pair) Value() (driver.Value, error) { return nil, nil }
func TestBytes(t *testing.T) {
fd := field.Bytes("active").Default([]byte("{}")).Comment("comment").Descriptor()
assert.Equal(t, "active", fd.Name)
@@ -196,18 +201,15 @@ func TestBytes(t *testing.T) {
assert.True(t, fd.Info.Nillable)
assert.False(t, fd.Info.ValueScanner())
fd = field.Bytes("blob").GoType(&sql.NullString{}).Descriptor()
fd = field.Bytes("blob").GoType(sql.NullString{}).Descriptor()
assert.NoError(t, fd.Err)
assert.Equal(t, "sql.NullString", fd.Info.Ident)
assert.Equal(t, "database/sql", fd.Info.PkgPath)
assert.Equal(t, "sql.NullString", fd.Info.String())
assert.True(t, fd.Info.Nillable)
assert.False(t, fd.Info.Nillable)
assert.True(t, fd.Info.ValueScanner())
fd = field.Bytes("uuid").
GoType(&uuid.UUID{}).
DefaultFunc(uuid.New).
Descriptor()
fd = field.Bytes("uuid").GoType(uuid.UUID{}).DefaultFunc(uuid.New).Descriptor()
assert.NoError(t, fd.Err)
assert.Equal(t, "uuid.UUID", fd.Info.Ident)
assert.Equal(t, "github.com/google/uuid", fd.Info.PkgPath)
@@ -216,6 +218,18 @@ func TestBytes(t *testing.T) {
assert.True(t, fd.Info.ValueScanner())
assert.NotEmpty(t, fd.Default.(func() uuid.UUID)())
fd = field.Bytes("uuid").
GoType(uuid.UUID{}).
DefaultFunc(uuid.New).
Descriptor()
assert.NoError(t, fd.Err)
assert.Equal(t, "uuid.UUID", fd.Info.String())
fd = field.Bytes("pair").
GoType(&Pair{}).
Descriptor()
assert.NoError(t, fd.Err)
assert.Equal(t, "*field_test.Pair", fd.Info.String())
fd = field.Bytes("blob").GoType(1).Descriptor()
assert.Error(t, fd.Err)
fd = field.Bytes("blob").GoType(struct{}{}).Descriptor()
@@ -263,7 +277,7 @@ func TestString_DefaultFunc(t *testing.T) {
assert.Error(t, fd.Err, "`var _ http.Dir = f2()` should fail")
f3 := func() sql.NullString { return sql.NullString{} }
fd = field.String("str").GoType(&sql.NullString{}).DefaultFunc(f3).Descriptor()
fd = field.String("str").GoType(sql.NullString{}).DefaultFunc(f3).Descriptor()
assert.NoError(t, fd.Err)
type S string
@@ -272,6 +286,16 @@ func TestString_DefaultFunc(t *testing.T) {
assert.Error(t, fd.Err, "`var _ http.Dir = f4()` should fail")
}
type VString string
func (s *VString) Scan(interface{}) error {
return nil
}
func (s VString) Value() (driver.Value, error) {
return "", nil
}
func TestString(t *testing.T) {
fd := field.String("name").
DefaultFunc(func() string {
@@ -311,23 +335,27 @@ func TestString(t *testing.T) {
fd = field.String("nullable_name").GoType(&sql.NullString{}).Descriptor()
assert.NoError(t, fd.Err)
assert.Equal(t, "sql.NullString", fd.Info.Ident)
assert.Equal(t, "*sql.NullString", fd.Info.Ident)
assert.Equal(t, "database/sql", fd.Info.PkgPath)
assert.Equal(t, "sql.NullString", fd.Info.String())
assert.Equal(t, "*sql.NullString", fd.Info.String())
assert.True(t, fd.Info.Nillable)
assert.True(t, fd.Info.ValueScanner())
assert.False(t, fd.Info.Stringer())
assert.True(t, fd.Info.RType.TypeEqual(reflect.TypeOf(sql.NullString{})))
assert.True(t, fd.Info.RType.TypeEqual(reflect.TypeOf(&sql.NullString{})))
fd = field.String("nullable_name").GoType(VString("")).Descriptor()
assert.True(t, fd.Info.Valuer())
assert.True(t, fd.Info.ValueScanner())
assert.False(t, fd.Info.Stringer())
type tURL struct {
field.ValueScanner
*url.URL
}
fd = field.String("nullable_url").GoType(&tURL{}).Descriptor()
assert.Equal(t, "field_test.tURL", fd.Info.Ident)
assert.Equal(t, "*field_test.tURL", fd.Info.Ident)
assert.Equal(t, "entgo.io/ent/schema/field_test", fd.Info.PkgPath)
assert.Equal(t, "field_test.tURL", fd.Info.String())
assert.Equal(t, "*field_test.tURL", fd.Info.String())
assert.True(t, fd.Info.ValueScanner())
assert.True(t, fd.Info.Stringer())
@@ -373,9 +401,9 @@ func TestTime(t *testing.T) {
fd = field.Time("deleted_at").GoType(&sql.NullTime{}).Descriptor()
assert.NoError(t, fd.Err)
assert.Equal(t, "sql.NullTime", fd.Info.Ident)
assert.Equal(t, "*sql.NullTime", fd.Info.Ident)
assert.Equal(t, "database/sql", fd.Info.PkgPath)
assert.Equal(t, "sql.NullTime", fd.Info.String())
assert.Equal(t, "*sql.NullTime", fd.Info.String())
assert.True(t, fd.Info.Nillable)
assert.True(t, fd.Info.ValueScanner())
@@ -513,13 +541,13 @@ func (c custom) Value() (driver.Value, error) {
func TestField_Other(t *testing.T) {
fd := field.Other("other", &custom{}).
Unique().
Default(custom{}).
Default(&custom{}).
SchemaType(map[string]string{dialect.Postgres: "varchar"}).
Descriptor()
assert.NoError(t, fd.Err)
assert.Equal(t, "other", fd.Name)
assert.True(t, fd.Unique)
assert.Equal(t, "field_test.custom", fd.Info.String())
assert.Equal(t, "*field_test.custom", fd.Info.String())
assert.Equal(t, "entgo.io/ent/schema/field_test", fd.Info.PkgPath)
assert.NotNil(t, fd.Default)
@@ -529,13 +557,13 @@ func TestField_Other(t *testing.T) {
fd = field.Other("other", &custom{}).
SchemaType(map[string]string{dialect.Postgres: "varchar"}).
Default(func() custom { return custom{} }).
Default(func() *custom { return &custom{} }).
Descriptor()
assert.NoError(t, fd.Err)
fd = field.Other("other", &custom{}).
SchemaType(map[string]string{dialect.Postgres: "varchar"}).
Default(func() *custom { return &custom{} }).
Default(func() custom { return custom{} }).
Descriptor()
assert.Error(t, fd.Err, "invalid default value")
}

View File

@@ -6,6 +6,7 @@ package field
import (
"fmt"
"path"
"reflect"
"strings"
)
@@ -122,6 +123,11 @@ func (t TypeInfo) ValueScanner() bool {
return t.RType.implements(valueScannerType)
}
// ValueScanner indicates if this type implements the driver.Valuer interface.
func (t TypeInfo) Valuer() bool {
return t.RType.implements(valuerType)
}
// Comparable reports whether values of this type are comparable.
func (t TypeInfo) Comparable() bool {
switch t.Type {
@@ -189,10 +195,23 @@ type RType struct {
rtype reflect.Type
}
// TypeEqual tests if the RType is equal to given reflect.Type.
// TypeEqual reports if the underlying type is equal to the RType (after pointer indirections).
func (r *RType) TypeEqual(t reflect.Type) bool {
t = indirect(t)
return r.Name == t.Name() && r.Kind == t.Kind() && r.PkgPath == t.PkgPath()
tv := indirect(t)
return r.Name == tv.Name() && r.Kind == t.Kind() && r.PkgPath == tv.PkgPath()
}
// RType returns the string value of the indirect reflect.Type.
func (r *RType) String() string {
if r.rtype != nil {
return r.rtype.String()
}
return path.Base(r.PkgPath) + "." + r.Name
}
// IsPtr reports if the reflect-type is a pointer type.
func (r *RType) IsPtr() bool {
return r != nil && r.Kind == reflect.Ptr
}
func (r *RType) implements(typ reflect.Type) bool {