diff --git a/schema/field/field.go b/schema/field/field.go index f54cb41d2..245332d5d 100644 --- a/schema/field/field.go +++ b/schema/field/field.go @@ -405,6 +405,16 @@ func (b *boolBuilder) StorageKey(key string) *boolBuilder { return b } +// GoType overrides the default Go type with a custom one. +// +// field.Bool("deleted"). +// GoType(&sql.NullBool{}) +// +func (b *boolBuilder) GoType(typ interface{}) *boolBuilder { + b.desc.goType(typ, reflect.Bool) + return b +} + // Descriptor implements the ent.Field interface by returning its descriptor. func (b *boolBuilder) Descriptor() *Descriptor { return b.desc diff --git a/schema/field/field_test.go b/schema/field/field_test.go index 2977655a1..6fb9300d4 100644 --- a/schema/field/field_test.go +++ b/schema/field/field_test.go @@ -77,13 +77,36 @@ func TestFloat(t *testing.T) { } func TestBool(t *testing.T) { - f := field.Bool("active").Default(true).Immutable() - fd := f.Descriptor() + fd := field.Bool("active").Default(true).Immutable().Descriptor() assert.Equal(t, "active", fd.Name) assert.Equal(t, field.TypeBool, fd.Info.Type) assert.NotNil(t, fd.Default) assert.True(t, fd.Immutable) assert.Equal(t, true, fd.Default) + + type Status bool + fd = field.Bool("active").GoType(Status(false)).Descriptor() + assert.NoError(t, fd.Err()) + assert.Equal(t, "field_test.Status", fd.Info.Ident) + assert.Equal(t, "github.com/facebookincubator/ent/schema/field_test", fd.Info.PkgPath) + assert.Equal(t, "field_test.Status", fd.Info.String()) + assert.False(t, fd.Info.Nillable) + assert.False(t, fd.Info.ValueScanner()) + + fd = field.Bool("deleted").GoType(&sql.NullBool{}).Descriptor() + assert.NoError(t, fd.Err()) + assert.Equal(t, "sql.NullBool", fd.Info.Ident) + assert.Equal(t, "database/sql", fd.Info.PkgPath) + assert.Equal(t, "sql.NullBool", fd.Info.String()) + assert.True(t, fd.Info.Nillable) + assert.True(t, fd.Info.ValueScanner()) + + fd = field.Bool("active").GoType(1).Descriptor() + assert.Error(t, fd.Err()) + fd = field.Bool("active").GoType(struct{}{}).Descriptor() + assert.Error(t, fd.Err()) + fd = field.Bool("active").GoType(new(Status)).Descriptor() + assert.Error(t, fd.Err()) } func TestBytes(t *testing.T) {