diff --git a/schema/field/field_test.go b/schema/field/field_test.go index fb74a7045..a5632cea1 100644 --- a/schema/field/field_test.go +++ b/schema/field/field_test.go @@ -62,6 +62,30 @@ func TestInt(t *testing.T) { assert.Equal(t, field.TypeUint16, field.Uint16("age").Descriptor().Info.Type) assert.Equal(t, field.TypeUint32, field.Uint32("age").Descriptor().Info.Type) assert.Equal(t, field.TypeUint64, field.Uint64("age").Descriptor().Info.Type) + + type Count int + fd = field.Int("active").GoType(Count(0)).Descriptor() + assert.NoError(t, fd.Err()) + assert.Equal(t, "field_test.Count", fd.Info.Ident) + assert.Equal(t, "github.com/facebookincubator/ent/schema/field_test", fd.Info.PkgPath) + assert.Equal(t, "field_test.Count", fd.Info.String()) + assert.False(t, fd.Info.Nillable) + assert.False(t, fd.Info.ValueScanner()) + + fd = field.Int("count").GoType(&sql.NullInt64{}).Descriptor() + assert.NoError(t, fd.Err()) + assert.Equal(t, "sql.NullInt64", fd.Info.Ident) + assert.Equal(t, "database/sql", fd.Info.PkgPath) + assert.Equal(t, "sql.NullInt64", fd.Info.String()) + assert.True(t, fd.Info.Nillable) + assert.True(t, fd.Info.ValueScanner()) + + fd = field.Int("count").GoType(false).Descriptor() + assert.Error(t, fd.Err()) + fd = field.Int("count").GoType(struct{}{}).Descriptor() + assert.Error(t, fd.Err()) + fd = field.Int("count").GoType(new(Count)).Descriptor() + assert.Error(t, fd.Err()) } func TestFloat(t *testing.T) { diff --git a/schema/field/gen/numeric.tmpl b/schema/field/gen/numeric.tmpl index 1439d5c00..63be9bdb5 100644 --- a/schema/field/gen/numeric.tmpl +++ b/schema/field/gen/numeric.tmpl @@ -5,7 +5,10 @@ package field -import "errors" +import ( + "errors" + "reflect" +) //go:generate go run gen/gen.go @@ -33,7 +36,7 @@ func Float32(name string) *float32Builder { return &float32Builder{&Descriptor{ }} } -{{ range $_, $t := $.Ints }} +{{ range $t := $.Ints }} {{ $builder := printf "%sBuilder" $t }} // {{ $builder }} is the builder for {{ $t }} field. @@ -160,6 +163,17 @@ func (b *{{ $builder }}) SchemaType(types map[string]string) *{{ $builder }} { return b } +{{ $tt := title $t.String }} +// GoType overrides the default Go type with a custom one. +// +// field.{{ $tt }}("{{ $t }}"). +// GoType(pkg.{{ $tt }}(0)) +// +func (b *{{ $builder }}) GoType(typ interface{}) *{{ $builder }} { + b.desc.goType(typ, {{ $t }}Type) + return b +} + // Descriptor implements the ent.Field interface by returning its descriptor. func (b *{{ $builder }}) Descriptor() *Descriptor { return b.desc @@ -167,6 +181,12 @@ func (b *{{ $builder }}) Descriptor() *Descriptor { {{ end }} +var ( + {{- range $t := $.Ints }} + {{ $t }}Type = reflect.TypeOf({{ $t }}(0)) + {{- end }} +) + {{ range $t := $.Floats }} {{ $builder := printf "%sBuilder" $t }} diff --git a/schema/field/numeric.go b/schema/field/numeric.go index 20937d285..7e7fd4f6a 100644 --- a/schema/field/numeric.go +++ b/schema/field/numeric.go @@ -4,7 +4,10 @@ package field -import "errors" +import ( + "errors" + "reflect" +) //go:generate go run gen/gen.go @@ -226,6 +229,16 @@ func (b *intBuilder) SchemaType(types map[string]string) *intBuilder { return b } +// GoType overrides the default Go type with a custom one. +// +// field.Int("int"). +// GoType(pkg.Int(0)) +// +func (b *intBuilder) GoType(typ interface{}) *intBuilder { + b.desc.goType(typ, intType) + return b +} + // Descriptor implements the ent.Field interface by returning its descriptor. func (b *intBuilder) Descriptor() *Descriptor { return b.desc @@ -343,6 +356,16 @@ func (b *uintBuilder) SchemaType(types map[string]string) *uintBuilder { return b } +// GoType overrides the default Go type with a custom one. +// +// field.Uint("uint"). +// GoType(pkg.Uint(0)) +// +func (b *uintBuilder) GoType(typ interface{}) *uintBuilder { + b.desc.goType(typ, uintType) + return b +} + // Descriptor implements the ent.Field interface by returning its descriptor. func (b *uintBuilder) Descriptor() *Descriptor { return b.desc @@ -470,6 +493,16 @@ func (b *int8Builder) SchemaType(types map[string]string) *int8Builder { return b } +// GoType overrides the default Go type with a custom one. +// +// field.Int8("int8"). +// GoType(pkg.Int8(0)) +// +func (b *int8Builder) GoType(typ interface{}) *int8Builder { + b.desc.goType(typ, int8Type) + return b +} + // Descriptor implements the ent.Field interface by returning its descriptor. func (b *int8Builder) Descriptor() *Descriptor { return b.desc @@ -597,6 +630,16 @@ func (b *int16Builder) SchemaType(types map[string]string) *int16Builder { return b } +// GoType overrides the default Go type with a custom one. +// +// field.Int16("int16"). +// GoType(pkg.Int16(0)) +// +func (b *int16Builder) GoType(typ interface{}) *int16Builder { + b.desc.goType(typ, int16Type) + return b +} + // Descriptor implements the ent.Field interface by returning its descriptor. func (b *int16Builder) Descriptor() *Descriptor { return b.desc @@ -724,6 +767,16 @@ func (b *int32Builder) SchemaType(types map[string]string) *int32Builder { return b } +// GoType overrides the default Go type with a custom one. +// +// field.Int32("int32"). +// GoType(pkg.Int32(0)) +// +func (b *int32Builder) GoType(typ interface{}) *int32Builder { + b.desc.goType(typ, int32Type) + return b +} + // Descriptor implements the ent.Field interface by returning its descriptor. func (b *int32Builder) Descriptor() *Descriptor { return b.desc @@ -851,6 +904,16 @@ func (b *int64Builder) SchemaType(types map[string]string) *int64Builder { return b } +// GoType overrides the default Go type with a custom one. +// +// field.Int64("int64"). +// GoType(pkg.Int64(0)) +// +func (b *int64Builder) GoType(typ interface{}) *int64Builder { + b.desc.goType(typ, int64Type) + return b +} + // Descriptor implements the ent.Field interface by returning its descriptor. func (b *int64Builder) Descriptor() *Descriptor { return b.desc @@ -968,6 +1031,16 @@ func (b *uint8Builder) SchemaType(types map[string]string) *uint8Builder { return b } +// GoType overrides the default Go type with a custom one. +// +// field.Uint8("uint8"). +// GoType(pkg.Uint8(0)) +// +func (b *uint8Builder) GoType(typ interface{}) *uint8Builder { + b.desc.goType(typ, uint8Type) + return b +} + // Descriptor implements the ent.Field interface by returning its descriptor. func (b *uint8Builder) Descriptor() *Descriptor { return b.desc @@ -1085,6 +1158,16 @@ func (b *uint16Builder) SchemaType(types map[string]string) *uint16Builder { return b } +// GoType overrides the default Go type with a custom one. +// +// field.Uint16("uint16"). +// GoType(pkg.Uint16(0)) +// +func (b *uint16Builder) GoType(typ interface{}) *uint16Builder { + b.desc.goType(typ, uint16Type) + return b +} + // Descriptor implements the ent.Field interface by returning its descriptor. func (b *uint16Builder) Descriptor() *Descriptor { return b.desc @@ -1202,6 +1285,16 @@ func (b *uint32Builder) SchemaType(types map[string]string) *uint32Builder { return b } +// GoType overrides the default Go type with a custom one. +// +// field.Uint32("uint32"). +// GoType(pkg.Uint32(0)) +// +func (b *uint32Builder) GoType(typ interface{}) *uint32Builder { + b.desc.goType(typ, uint32Type) + return b +} + // Descriptor implements the ent.Field interface by returning its descriptor. func (b *uint32Builder) Descriptor() *Descriptor { return b.desc @@ -1319,11 +1412,34 @@ func (b *uint64Builder) SchemaType(types map[string]string) *uint64Builder { return b } +// GoType overrides the default Go type with a custom one. +// +// field.Uint64("uint64"). +// GoType(pkg.Uint64(0)) +// +func (b *uint64Builder) GoType(typ interface{}) *uint64Builder { + b.desc.goType(typ, uint64Type) + return b +} + // Descriptor implements the ent.Field interface by returning its descriptor. func (b *uint64Builder) Descriptor() *Descriptor { return b.desc } +var ( + intType = reflect.TypeOf(int(0)) + uintType = reflect.TypeOf(uint(0)) + int8Type = reflect.TypeOf(int8(0)) + int16Type = reflect.TypeOf(int16(0)) + int32Type = reflect.TypeOf(int32(0)) + int64Type = reflect.TypeOf(int64(0)) + uint8Type = reflect.TypeOf(uint8(0)) + uint16Type = reflect.TypeOf(uint16(0)) + uint32Type = reflect.TypeOf(uint32(0)) + uint64Type = reflect.TypeOf(uint64(0)) +) + // float64Builder is the builder for float fields. type float64Builder struct { desc *Descriptor