schema/field: add GoType option for string fields (#500)

This commit is contained in:
Ariel Mashraki
2020-05-25 20:05:17 +03:00
committed by GitHub
parent 31690c7e60
commit 100d300094
58 changed files with 845 additions and 273 deletions

View File

@@ -5,6 +5,7 @@
package field
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
@@ -14,31 +15,6 @@ import (
"time"
)
// A Descriptor for field configuration.
type Descriptor struct {
Tag string // struct tag.
Size int // varchar size.
Name string // field name.
Info *TypeInfo // field type info.
Unique bool // unique index of field.
Nillable bool // nillable struct field.
Optional bool // nullable field in database.
Immutable bool // create-only field.
Default interface{} // default value on create.
UpdateDefault interface{} // default value on update.
Validators []interface{} // validator functions.
StorageKey string // sql column or gremlin property.
Enums []string // enum values.
Sensitive bool // sensitive info string field.
SchemaType map[string]string // override the schema type.
err error
}
// Err returns the error, if any, that was added by the field builder.
func (d *Descriptor) Err() error {
return d.err
}
// String returns a new Field with type string.
func String(name string) *stringBuilder {
return &stringBuilder{&Descriptor{
@@ -280,6 +256,16 @@ func (b *stringBuilder) SchemaType(types map[string]string) *stringBuilder {
return b
}
// GoType overrides the default Go type with a custom one.
//
// field.String("dir").
// GoType(http.Dir("dir"))
//
func (b *stringBuilder) GoType(typ interface{}) *stringBuilder {
b.desc.goType(typ, reflect.String)
return b
}
// Descriptor implements the ent.Field interface by returning its descriptor.
func (b *stringBuilder) Descriptor() *Descriptor {
return b.desc
@@ -697,3 +683,86 @@ func (b *uuidBuilder) SchemaType(types map[string]string) *uuidBuilder {
func (b *uuidBuilder) Descriptor() *Descriptor {
return b.desc
}
// A Descriptor for field configuration.
type Descriptor struct {
Tag string // struct tag.
Size int // varchar size.
Name string // field name.
Info *TypeInfo // field type info.
Unique bool // unique index of field.
Nillable bool // nillable struct field.
Optional bool // nullable field in database.
Immutable bool // create-only field.
Default interface{} // default value on create.
UpdateDefault interface{} // default value on update.
Validators []interface{} // validator functions.
StorageKey string // sql column or gremlin property.
Enums []string // enum values.
Sensitive bool // sensitive info string field.
SchemaType map[string]string // override the schema type.
err error
}
// Err returns the error, if any, that was added by the field builder.
func (d *Descriptor) Err() error {
return d.err
}
func (d *Descriptor) goType(typ interface{}, expectKind reflect.Kind) {
t := reflect.TypeOf(typ)
tv := indirect(t)
info := &TypeInfo{
Type: TypeString,
Ident: t.String(),
PkgPath: tv.PkgPath(),
RType: &RType{
Name: tv.Name(),
Kind: tv.Kind(),
PkgPath: tv.PkgPath(),
Methods: make(map[string]struct{ In, Out []*RType }, t.NumMethod()),
},
}
switch t.Kind() {
case reflect.Slice, reflect.Array, reflect.Ptr, reflect.Map:
info.Nillable = true
}
switch {
case tv.Kind() == expectKind:
case t.Implements(valueScannerType):
n := t.NumMethod()
for i := 0; i < n; i++ {
m := t.Method(i)
in := make([]*RType, m.Type.NumIn()-1)
for j := range in {
arg := m.Type.In(j + 1)
in[j] = &RType{Name: arg.Name(), Kind: arg.Kind(), PkgPath: arg.PkgPath()}
}
out := make([]*RType, m.Type.NumOut())
for j := range out {
ret := m.Type.Out(j)
out[j] = &RType{Name: ret.Name(), Kind: ret.Kind(), PkgPath: ret.PkgPath()}
}
info.RType.Methods[m.Name] = struct{ In, Out []*RType }{in, out}
}
default:
d.err = fmt.Errorf("GoType must be a %q type or ValueScanner", expectKind)
}
d.Info = info
}
var valueScannerType = reflect.TypeOf((*ValueScanner)(nil)).Elem()
// ValueScanner is the interface that groups the Value and the Scan methods.
type ValueScanner interface {
driver.Valuer
sql.Scanner
}
// indirect returns the type at the end of indirection.
func indirect(t reflect.Type) reflect.Type {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}

View File

@@ -5,7 +5,10 @@
package field_test
import (
"database/sql"
"net/http"
"net/url"
"reflect"
"regexp"
"testing"
"time"
@@ -101,6 +104,48 @@ func TestString(t *testing.T) {
assert.True(t, fd.Unique)
assert.Len(t, fd.Validators, 2)
assert.True(t, fd.Sensitive)
fd = field.String("name").GoType(http.Dir("dir")).Descriptor()
assert.NoError(t, fd.Err())
assert.Equal(t, "http.Dir", fd.Info.Ident)
assert.Equal(t, "net/http", fd.Info.PkgPath)
assert.Equal(t, "http.Dir", fd.Info.String())
assert.False(t, fd.Info.Nillable)
assert.False(t, fd.Info.ValueScanner())
fd = field.String("name").GoType(http.MethodOptions).Descriptor()
assert.NoError(t, fd.Err())
assert.Equal(t, "string", fd.Info.Ident)
assert.Equal(t, "", fd.Info.PkgPath)
assert.Equal(t, "string", fd.Info.String())
assert.False(t, fd.Info.Nillable)
fd = field.String("nullable_name").GoType(&sql.NullString{}).Descriptor()
assert.NoError(t, fd.Err())
assert.Equal(t, "*sql.NullString", fd.Info.Ident)
assert.Equal(t, "database/sql", fd.Info.PkgPath)
assert.Equal(t, "*sql.NullString", fd.Info.String())
assert.True(t, fd.Info.Nillable)
assert.True(t, fd.Info.ValueScanner())
assert.False(t, fd.Info.Stringer())
assert.True(t, fd.Info.RType.TypeEqual(reflect.TypeOf(sql.NullString{})))
assert.True(t, fd.Info.RType.TypeEqual(reflect.TypeOf(&sql.NullString{})))
type tURL struct {
field.ValueScanner
*url.URL
}
fd = field.String("nullable_url").GoType(&tURL{}).Descriptor()
assert.Equal(t, "*field_test.tURL", fd.Info.Ident)
assert.Equal(t, "github.com/facebookincubator/ent/schema/field_test", fd.Info.PkgPath)
assert.Equal(t, "*field_test.tURL", fd.Info.String())
assert.True(t, fd.Info.ValueScanner())
assert.True(t, fd.Info.Stringer())
fd = field.String("name").GoType(1).Descriptor()
assert.Error(t, fd.Err())
fd = field.String("name").GoType(struct{}{}).Descriptor()
assert.Error(t, fd.Err())
}
func TestTime(t *testing.T) {

View File

@@ -4,7 +4,11 @@
package field
import "strings"
import (
"fmt"
"reflect"
"strings"
)
// A Type represents a field type.
type Type uint8
@@ -72,6 +76,7 @@ type TypeInfo struct {
Ident string
PkgPath string
Nillable bool // slices or pointers.
RType *RType
}
// String returns the string representation of a type.
@@ -101,6 +106,18 @@ func (t TypeInfo) ConstName() string {
return t.Type.ConstName()
}
// ValueScanner indicates if this type implements the ValueScanner interface.
func (t TypeInfo) ValueScanner() bool {
return t.RType.implements(valueScannerType)
}
var stringerType = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()
// Stringer indicates if this type implements the Stringer interface.
func (t TypeInfo) Stringer() bool {
return t.RType.implements(stringerType)
}
var (
typeNames = [...]string{
TypeInvalid: "invalid",
@@ -132,3 +149,45 @@ var (
TypeBytes: "TypeBytes",
}
)
// RType holds a serializable reflect.Type information of
// Go object. Used by the entc package.
type RType struct {
Name string
Kind reflect.Kind
PkgPath string
Methods map[string]struct{ In, Out []*RType }
}
// TypeEqual tests if the RType is equal to given reflect.Type.
func (r *RType) TypeEqual(t reflect.Type) bool {
t = indirect(t)
return r.Name == t.Name() && r.Kind == t.Kind() && r.PkgPath == t.PkgPath()
}
func (r *RType) implements(typ reflect.Type) bool {
if r == nil {
return false
}
n := typ.NumMethod()
for i := 0; i < n; i++ {
m0 := typ.Method(i)
m1, ok := r.Methods[m0.Name]
if !ok || len(m1.In) != m0.Type.NumIn() || len(m1.Out) != m0.Type.NumOut() {
return false
}
in := m0.Type.NumIn()
for j := 0; j < in; j++ {
if !m1.In[j].TypeEqual(m0.Type.In(j)) {
return false
}
}
out := m0.Type.NumOut()
for j := 0; j < out; j++ {
if !m1.Out[j].TypeEqual(m0.Type.Out(j)) {
return false
}
}
}
return true
}