mirror of
https://github.com/ent/ent.git
synced 2026-03-05 19:35:23 +03:00
schema/field: use actual go type in generated interfaces (#1428)
This commit is contained in:
@@ -1056,12 +1056,12 @@ func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) {
|
||||
tv := indirect(t)
|
||||
info := &TypeInfo{
|
||||
Type: d.Info.Type,
|
||||
Ident: tv.String(),
|
||||
Ident: t.String(),
|
||||
PkgPath: tv.PkgPath(),
|
||||
RType: &RType{
|
||||
rtype: tv,
|
||||
rtype: t,
|
||||
Kind: t.Kind(),
|
||||
Name: tv.Name(),
|
||||
Kind: tv.Kind(),
|
||||
PkgPath: tv.PkgPath(),
|
||||
Methods: make(map[string]struct{ In, Out []*RType }, t.NumMethod()),
|
||||
},
|
||||
@@ -1070,8 +1070,10 @@ func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) {
|
||||
case reflect.Slice, reflect.Array, reflect.Ptr, reflect.Map:
|
||||
info.Nillable = true
|
||||
}
|
||||
switch {
|
||||
case t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType):
|
||||
switch pt := reflect.PtrTo(t); {
|
||||
case pt.Implements(valueScannerType):
|
||||
t = pt
|
||||
fallthrough
|
||||
case t.Implements(valueScannerType):
|
||||
n := t.NumMethod()
|
||||
for i := 0; i < n; i++ {
|
||||
@@ -1088,11 +1090,9 @@ func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) {
|
||||
}
|
||||
info.RType.Methods[m.Name] = struct{ In, Out []*RType }{in, out}
|
||||
}
|
||||
case t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType):
|
||||
default:
|
||||
d.Err = fmt.Errorf("GoType must be a %q type or ValueScanner", expectType)
|
||||
if pt := reflect.PtrTo(t); pt.Implements(valueScannerType) {
|
||||
d.Err = fmt.Errorf("%s. Use %s instead", d.Err, pt)
|
||||
}
|
||||
}
|
||||
d.Info = info
|
||||
}
|
||||
@@ -1120,6 +1120,7 @@ var (
|
||||
bytesType = reflect.TypeOf([]byte(nil))
|
||||
timeType = reflect.TypeOf(time.Time{})
|
||||
stringType = reflect.TypeOf("")
|
||||
valuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||
valueScannerType = reflect.TypeOf((*ValueScanner)(nil)).Elem()
|
||||
)
|
||||
|
||||
|
||||
@@ -77,14 +77,12 @@ func TestInt(t *testing.T) {
|
||||
|
||||
fd = field.Int("count").GoType(&sql.NullInt64{}).Descriptor()
|
||||
assert.NoError(t, fd.Err)
|
||||
assert.Equal(t, "sql.NullInt64", fd.Info.Ident)
|
||||
assert.Equal(t, "*sql.NullInt64", fd.Info.Ident)
|
||||
assert.Equal(t, "database/sql", fd.Info.PkgPath)
|
||||
assert.Equal(t, "sql.NullInt64", fd.Info.String())
|
||||
assert.Equal(t, "*sql.NullInt64", fd.Info.String())
|
||||
assert.True(t, fd.Info.Nillable)
|
||||
assert.True(t, fd.Info.ValueScanner())
|
||||
|
||||
fd = field.Int("count").GoType(sql.NullInt64{}).Descriptor()
|
||||
assert.EqualError(t, fd.Err, `GoType must be a "int" type or ValueScanner. Use *sql.NullInt64 instead`)
|
||||
fd = field.Int("count").GoType(false).Descriptor()
|
||||
assert.EqualError(t, fd.Err, `GoType must be a "int" type or ValueScanner`)
|
||||
fd = field.Int("count").GoType(struct{}{}).Descriptor()
|
||||
@@ -132,9 +130,9 @@ func TestFloat(t *testing.T) {
|
||||
|
||||
fd = field.Float("count").GoType(&sql.NullFloat64{}).Descriptor()
|
||||
assert.NoError(t, fd.Err)
|
||||
assert.Equal(t, "sql.NullFloat64", fd.Info.Ident)
|
||||
assert.Equal(t, "*sql.NullFloat64", fd.Info.Ident)
|
||||
assert.Equal(t, "database/sql", fd.Info.PkgPath)
|
||||
assert.Equal(t, "sql.NullFloat64", fd.Info.String())
|
||||
assert.Equal(t, "*sql.NullFloat64", fd.Info.String())
|
||||
assert.True(t, fd.Info.Nillable)
|
||||
assert.True(t, fd.Info.ValueScanner())
|
||||
|
||||
@@ -166,9 +164,9 @@ func TestBool(t *testing.T) {
|
||||
|
||||
fd = field.Bool("deleted").GoType(&sql.NullBool{}).Descriptor()
|
||||
assert.NoError(t, fd.Err)
|
||||
assert.Equal(t, "sql.NullBool", fd.Info.Ident)
|
||||
assert.Equal(t, "*sql.NullBool", fd.Info.Ident)
|
||||
assert.Equal(t, "database/sql", fd.Info.PkgPath)
|
||||
assert.Equal(t, "sql.NullBool", fd.Info.String())
|
||||
assert.Equal(t, "*sql.NullBool", fd.Info.String())
|
||||
assert.True(t, fd.Info.Nillable)
|
||||
assert.True(t, fd.Info.ValueScanner())
|
||||
|
||||
@@ -180,6 +178,13 @@ func TestBool(t *testing.T) {
|
||||
assert.Error(t, fd.Err)
|
||||
}
|
||||
|
||||
type Pair struct {
|
||||
K, V []byte
|
||||
}
|
||||
|
||||
func (*Pair) Scan(interface{}) error { return nil }
|
||||
func (Pair) Value() (driver.Value, error) { return nil, nil }
|
||||
|
||||
func TestBytes(t *testing.T) {
|
||||
fd := field.Bytes("active").Default([]byte("{}")).Comment("comment").Descriptor()
|
||||
assert.Equal(t, "active", fd.Name)
|
||||
@@ -196,18 +201,15 @@ func TestBytes(t *testing.T) {
|
||||
assert.True(t, fd.Info.Nillable)
|
||||
assert.False(t, fd.Info.ValueScanner())
|
||||
|
||||
fd = field.Bytes("blob").GoType(&sql.NullString{}).Descriptor()
|
||||
fd = field.Bytes("blob").GoType(sql.NullString{}).Descriptor()
|
||||
assert.NoError(t, fd.Err)
|
||||
assert.Equal(t, "sql.NullString", fd.Info.Ident)
|
||||
assert.Equal(t, "database/sql", fd.Info.PkgPath)
|
||||
assert.Equal(t, "sql.NullString", fd.Info.String())
|
||||
assert.True(t, fd.Info.Nillable)
|
||||
assert.False(t, fd.Info.Nillable)
|
||||
assert.True(t, fd.Info.ValueScanner())
|
||||
|
||||
fd = field.Bytes("uuid").
|
||||
GoType(&uuid.UUID{}).
|
||||
DefaultFunc(uuid.New).
|
||||
Descriptor()
|
||||
fd = field.Bytes("uuid").GoType(uuid.UUID{}).DefaultFunc(uuid.New).Descriptor()
|
||||
assert.NoError(t, fd.Err)
|
||||
assert.Equal(t, "uuid.UUID", fd.Info.Ident)
|
||||
assert.Equal(t, "github.com/google/uuid", fd.Info.PkgPath)
|
||||
@@ -216,6 +218,18 @@ func TestBytes(t *testing.T) {
|
||||
assert.True(t, fd.Info.ValueScanner())
|
||||
assert.NotEmpty(t, fd.Default.(func() uuid.UUID)())
|
||||
|
||||
fd = field.Bytes("uuid").
|
||||
GoType(uuid.UUID{}).
|
||||
DefaultFunc(uuid.New).
|
||||
Descriptor()
|
||||
assert.NoError(t, fd.Err)
|
||||
assert.Equal(t, "uuid.UUID", fd.Info.String())
|
||||
fd = field.Bytes("pair").
|
||||
GoType(&Pair{}).
|
||||
Descriptor()
|
||||
assert.NoError(t, fd.Err)
|
||||
assert.Equal(t, "*field_test.Pair", fd.Info.String())
|
||||
|
||||
fd = field.Bytes("blob").GoType(1).Descriptor()
|
||||
assert.Error(t, fd.Err)
|
||||
fd = field.Bytes("blob").GoType(struct{}{}).Descriptor()
|
||||
@@ -263,7 +277,7 @@ func TestString_DefaultFunc(t *testing.T) {
|
||||
assert.Error(t, fd.Err, "`var _ http.Dir = f2()` should fail")
|
||||
|
||||
f3 := func() sql.NullString { return sql.NullString{} }
|
||||
fd = field.String("str").GoType(&sql.NullString{}).DefaultFunc(f3).Descriptor()
|
||||
fd = field.String("str").GoType(sql.NullString{}).DefaultFunc(f3).Descriptor()
|
||||
assert.NoError(t, fd.Err)
|
||||
|
||||
type S string
|
||||
@@ -272,6 +286,16 @@ func TestString_DefaultFunc(t *testing.T) {
|
||||
assert.Error(t, fd.Err, "`var _ http.Dir = f4()` should fail")
|
||||
}
|
||||
|
||||
type VString string
|
||||
|
||||
func (s *VString) Scan(interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s VString) Value() (driver.Value, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
fd := field.String("name").
|
||||
DefaultFunc(func() string {
|
||||
@@ -311,23 +335,27 @@ func TestString(t *testing.T) {
|
||||
|
||||
fd = field.String("nullable_name").GoType(&sql.NullString{}).Descriptor()
|
||||
assert.NoError(t, fd.Err)
|
||||
assert.Equal(t, "sql.NullString", fd.Info.Ident)
|
||||
assert.Equal(t, "*sql.NullString", fd.Info.Ident)
|
||||
assert.Equal(t, "database/sql", fd.Info.PkgPath)
|
||||
assert.Equal(t, "sql.NullString", fd.Info.String())
|
||||
assert.Equal(t, "*sql.NullString", fd.Info.String())
|
||||
assert.True(t, fd.Info.Nillable)
|
||||
assert.True(t, fd.Info.ValueScanner())
|
||||
assert.False(t, fd.Info.Stringer())
|
||||
assert.True(t, fd.Info.RType.TypeEqual(reflect.TypeOf(sql.NullString{})))
|
||||
assert.True(t, fd.Info.RType.TypeEqual(reflect.TypeOf(&sql.NullString{})))
|
||||
|
||||
fd = field.String("nullable_name").GoType(VString("")).Descriptor()
|
||||
assert.True(t, fd.Info.Valuer())
|
||||
assert.True(t, fd.Info.ValueScanner())
|
||||
assert.False(t, fd.Info.Stringer())
|
||||
|
||||
type tURL struct {
|
||||
field.ValueScanner
|
||||
*url.URL
|
||||
}
|
||||
fd = field.String("nullable_url").GoType(&tURL{}).Descriptor()
|
||||
assert.Equal(t, "field_test.tURL", fd.Info.Ident)
|
||||
assert.Equal(t, "*field_test.tURL", fd.Info.Ident)
|
||||
assert.Equal(t, "entgo.io/ent/schema/field_test", fd.Info.PkgPath)
|
||||
assert.Equal(t, "field_test.tURL", fd.Info.String())
|
||||
assert.Equal(t, "*field_test.tURL", fd.Info.String())
|
||||
assert.True(t, fd.Info.ValueScanner())
|
||||
assert.True(t, fd.Info.Stringer())
|
||||
|
||||
@@ -373,9 +401,9 @@ func TestTime(t *testing.T) {
|
||||
|
||||
fd = field.Time("deleted_at").GoType(&sql.NullTime{}).Descriptor()
|
||||
assert.NoError(t, fd.Err)
|
||||
assert.Equal(t, "sql.NullTime", fd.Info.Ident)
|
||||
assert.Equal(t, "*sql.NullTime", fd.Info.Ident)
|
||||
assert.Equal(t, "database/sql", fd.Info.PkgPath)
|
||||
assert.Equal(t, "sql.NullTime", fd.Info.String())
|
||||
assert.Equal(t, "*sql.NullTime", fd.Info.String())
|
||||
assert.True(t, fd.Info.Nillable)
|
||||
assert.True(t, fd.Info.ValueScanner())
|
||||
|
||||
@@ -513,13 +541,13 @@ func (c custom) Value() (driver.Value, error) {
|
||||
func TestField_Other(t *testing.T) {
|
||||
fd := field.Other("other", &custom{}).
|
||||
Unique().
|
||||
Default(custom{}).
|
||||
Default(&custom{}).
|
||||
SchemaType(map[string]string{dialect.Postgres: "varchar"}).
|
||||
Descriptor()
|
||||
assert.NoError(t, fd.Err)
|
||||
assert.Equal(t, "other", fd.Name)
|
||||
assert.True(t, fd.Unique)
|
||||
assert.Equal(t, "field_test.custom", fd.Info.String())
|
||||
assert.Equal(t, "*field_test.custom", fd.Info.String())
|
||||
assert.Equal(t, "entgo.io/ent/schema/field_test", fd.Info.PkgPath)
|
||||
assert.NotNil(t, fd.Default)
|
||||
|
||||
@@ -529,13 +557,13 @@ func TestField_Other(t *testing.T) {
|
||||
|
||||
fd = field.Other("other", &custom{}).
|
||||
SchemaType(map[string]string{dialect.Postgres: "varchar"}).
|
||||
Default(func() custom { return custom{} }).
|
||||
Default(func() *custom { return &custom{} }).
|
||||
Descriptor()
|
||||
assert.NoError(t, fd.Err)
|
||||
|
||||
fd = field.Other("other", &custom{}).
|
||||
SchemaType(map[string]string{dialect.Postgres: "varchar"}).
|
||||
Default(func() *custom { return &custom{} }).
|
||||
Default(func() custom { return custom{} }).
|
||||
Descriptor()
|
||||
assert.Error(t, fd.Err, "invalid default value")
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package field
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
@@ -122,6 +123,11 @@ func (t TypeInfo) ValueScanner() bool {
|
||||
return t.RType.implements(valueScannerType)
|
||||
}
|
||||
|
||||
// ValueScanner indicates if this type implements the driver.Valuer interface.
|
||||
func (t TypeInfo) Valuer() bool {
|
||||
return t.RType.implements(valuerType)
|
||||
}
|
||||
|
||||
// Comparable reports whether values of this type are comparable.
|
||||
func (t TypeInfo) Comparable() bool {
|
||||
switch t.Type {
|
||||
@@ -189,10 +195,23 @@ type RType struct {
|
||||
rtype reflect.Type
|
||||
}
|
||||
|
||||
// TypeEqual tests if the RType is equal to given reflect.Type.
|
||||
// TypeEqual reports if the underlying type is equal to the RType (after pointer indirections).
|
||||
func (r *RType) TypeEqual(t reflect.Type) bool {
|
||||
t = indirect(t)
|
||||
return r.Name == t.Name() && r.Kind == t.Kind() && r.PkgPath == t.PkgPath()
|
||||
tv := indirect(t)
|
||||
return r.Name == tv.Name() && r.Kind == t.Kind() && r.PkgPath == tv.PkgPath()
|
||||
}
|
||||
|
||||
// RType returns the string value of the indirect reflect.Type.
|
||||
func (r *RType) String() string {
|
||||
if r.rtype != nil {
|
||||
return r.rtype.String()
|
||||
}
|
||||
return path.Base(r.PkgPath) + "." + r.Name
|
||||
}
|
||||
|
||||
// IsPtr reports if the reflect-type is a pointer type.
|
||||
func (r *RType) IsPtr() bool {
|
||||
return r != nil && r.Kind == reflect.Ptr
|
||||
}
|
||||
|
||||
func (r *RType) implements(typ reflect.Type) bool {
|
||||
|
||||
Reference in New Issue
Block a user