schema/field: add GoType option for time fields (#520)

This commit is contained in:
Ariel Mashraki
2020-06-03 14:36:49 +03:00
committed by GitHub
parent 5c5dee7e6f
commit fff0b1a9ed
2 changed files with 45 additions and 6 deletions

View File

@@ -262,7 +262,7 @@ func (b *stringBuilder) SchemaType(types map[string]string) *stringBuilder {
// GoType(http.Dir("dir"))
//
func (b *stringBuilder) GoType(typ interface{}) *stringBuilder {
b.desc.goType(typ, reflect.String)
b.desc.goType(typ, stringType)
return b
}
@@ -337,6 +337,16 @@ func (b *timeBuilder) StorageKey(key string) *timeBuilder {
return b
}
// GoType overrides the default Go type with a custom one.
//
// field.Time("deleted_at").
// GoType(&sql.NullTime{})
//
func (b *timeBuilder) GoType(typ interface{}) *timeBuilder {
b.desc.goType(typ, timeType)
return b
}
// Descriptor implements the ent.Field interface by returning its descriptor.
func (b *timeBuilder) Descriptor() *Descriptor {
return b.desc
@@ -411,7 +421,7 @@ func (b *boolBuilder) StorageKey(key string) *boolBuilder {
// GoType(&sql.NullBool{})
//
func (b *boolBuilder) GoType(typ interface{}) *boolBuilder {
b.desc.goType(typ, reflect.Bool)
b.desc.goType(typ, boolType)
return b
}
@@ -719,7 +729,7 @@ func (d *Descriptor) Err() error {
return d.err
}
func (d *Descriptor) goType(typ interface{}, expectKind reflect.Kind) {
func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) {
t := reflect.TypeOf(typ)
tv := indirect(t)
info := &TypeInfo{
@@ -738,7 +748,7 @@ func (d *Descriptor) goType(typ interface{}, expectKind reflect.Kind) {
info.Nillable = true
}
switch {
case t.Kind() == expectKind:
case t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType):
case t.Implements(valueScannerType):
n := t.NumMethod()
for i := 0; i < n; i++ {
@@ -756,12 +766,17 @@ func (d *Descriptor) goType(typ interface{}, expectKind reflect.Kind) {
info.RType.Methods[m.Name] = struct{ In, Out []*RType }{in, out}
}
default:
d.err = fmt.Errorf("GoType must be a %q type or ValueScanner", expectKind)
d.err = fmt.Errorf("GoType must be a %q type or ValueScanner", expectType)
}
d.Info = info
}
var valueScannerType = reflect.TypeOf((*ValueScanner)(nil)).Elem()
var (
boolType = reflect.TypeOf(false)
timeType = reflect.TypeOf(time.Time{})
stringType = reflect.TypeOf("string")
valueScannerType = reflect.TypeOf((*ValueScanner)(nil)).Elem()
)
// ValueScanner is the interface that groups the Value and the Scan methods.
type ValueScanner interface {

View File

@@ -193,6 +193,30 @@ func TestTime(t *testing.T) {
Descriptor()
assert.Equal(t, "updated_at", fd.Name)
assert.Equal(t, now, fd.UpdateDefault.(func() time.Time)())
type Time time.Time
fd = field.Time("deleted_at").GoType(Time{}).Descriptor()
assert.NoError(t, fd.Err())
assert.Equal(t, "field_test.Time", fd.Info.Ident)
assert.Equal(t, "github.com/facebookincubator/ent/schema/field_test", fd.Info.PkgPath)
assert.Equal(t, "field_test.Time", fd.Info.String())
assert.False(t, fd.Info.Nillable)
assert.False(t, fd.Info.ValueScanner())
fd = field.Time("deleted_at").GoType(&sql.NullTime{}).Descriptor()
assert.NoError(t, fd.Err())
assert.Equal(t, "sql.NullTime", fd.Info.Ident)
assert.Equal(t, "database/sql", fd.Info.PkgPath)
assert.Equal(t, "sql.NullTime", fd.Info.String())
assert.True(t, fd.Info.Nillable)
assert.True(t, fd.Info.ValueScanner())
fd = field.Time("active").GoType(1).Descriptor()
assert.Error(t, fd.Err())
fd = field.Time("active").GoType(struct{}{}).Descriptor()
assert.Error(t, fd.Err())
fd = field.Time("active").GoType(new(Time)).Descriptor()
assert.Error(t, fd.Err())
}
func TestJSON(t *testing.T) {