mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
schema/field: add GoType option for string fields (#500)
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
package field
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -14,31 +15,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// A Descriptor for field configuration.
|
||||
type Descriptor struct {
|
||||
Tag string // struct tag.
|
||||
Size int // varchar size.
|
||||
Name string // field name.
|
||||
Info *TypeInfo // field type info.
|
||||
Unique bool // unique index of field.
|
||||
Nillable bool // nillable struct field.
|
||||
Optional bool // nullable field in database.
|
||||
Immutable bool // create-only field.
|
||||
Default interface{} // default value on create.
|
||||
UpdateDefault interface{} // default value on update.
|
||||
Validators []interface{} // validator functions.
|
||||
StorageKey string // sql column or gremlin property.
|
||||
Enums []string // enum values.
|
||||
Sensitive bool // sensitive info string field.
|
||||
SchemaType map[string]string // override the schema type.
|
||||
err error
|
||||
}
|
||||
|
||||
// Err returns the error, if any, that was added by the field builder.
|
||||
func (d *Descriptor) Err() error {
|
||||
return d.err
|
||||
}
|
||||
|
||||
// String returns a new Field with type string.
|
||||
func String(name string) *stringBuilder {
|
||||
return &stringBuilder{&Descriptor{
|
||||
@@ -280,6 +256,16 @@ func (b *stringBuilder) SchemaType(types map[string]string) *stringBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// GoType overrides the default Go type with a custom one.
|
||||
//
|
||||
// field.String("dir").
|
||||
// GoType(http.Dir("dir"))
|
||||
//
|
||||
func (b *stringBuilder) GoType(typ interface{}) *stringBuilder {
|
||||
b.desc.goType(typ, reflect.String)
|
||||
return b
|
||||
}
|
||||
|
||||
// Descriptor implements the ent.Field interface by returning its descriptor.
|
||||
func (b *stringBuilder) Descriptor() *Descriptor {
|
||||
return b.desc
|
||||
@@ -697,3 +683,86 @@ func (b *uuidBuilder) SchemaType(types map[string]string) *uuidBuilder {
|
||||
func (b *uuidBuilder) Descriptor() *Descriptor {
|
||||
return b.desc
|
||||
}
|
||||
|
||||
// A Descriptor for field configuration.
|
||||
type Descriptor struct {
|
||||
Tag string // struct tag.
|
||||
Size int // varchar size.
|
||||
Name string // field name.
|
||||
Info *TypeInfo // field type info.
|
||||
Unique bool // unique index of field.
|
||||
Nillable bool // nillable struct field.
|
||||
Optional bool // nullable field in database.
|
||||
Immutable bool // create-only field.
|
||||
Default interface{} // default value on create.
|
||||
UpdateDefault interface{} // default value on update.
|
||||
Validators []interface{} // validator functions.
|
||||
StorageKey string // sql column or gremlin property.
|
||||
Enums []string // enum values.
|
||||
Sensitive bool // sensitive info string field.
|
||||
SchemaType map[string]string // override the schema type.
|
||||
err error
|
||||
}
|
||||
|
||||
// Err returns the error, if any, that was added by the field builder.
|
||||
func (d *Descriptor) Err() error {
|
||||
return d.err
|
||||
}
|
||||
|
||||
func (d *Descriptor) goType(typ interface{}, expectKind reflect.Kind) {
|
||||
t := reflect.TypeOf(typ)
|
||||
tv := indirect(t)
|
||||
info := &TypeInfo{
|
||||
Type: TypeString,
|
||||
Ident: t.String(),
|
||||
PkgPath: tv.PkgPath(),
|
||||
RType: &RType{
|
||||
Name: tv.Name(),
|
||||
Kind: tv.Kind(),
|
||||
PkgPath: tv.PkgPath(),
|
||||
Methods: make(map[string]struct{ In, Out []*RType }, t.NumMethod()),
|
||||
},
|
||||
}
|
||||
switch t.Kind() {
|
||||
case reflect.Slice, reflect.Array, reflect.Ptr, reflect.Map:
|
||||
info.Nillable = true
|
||||
}
|
||||
switch {
|
||||
case tv.Kind() == expectKind:
|
||||
case t.Implements(valueScannerType):
|
||||
n := t.NumMethod()
|
||||
for i := 0; i < n; i++ {
|
||||
m := t.Method(i)
|
||||
in := make([]*RType, m.Type.NumIn()-1)
|
||||
for j := range in {
|
||||
arg := m.Type.In(j + 1)
|
||||
in[j] = &RType{Name: arg.Name(), Kind: arg.Kind(), PkgPath: arg.PkgPath()}
|
||||
}
|
||||
out := make([]*RType, m.Type.NumOut())
|
||||
for j := range out {
|
||||
ret := m.Type.Out(j)
|
||||
out[j] = &RType{Name: ret.Name(), Kind: ret.Kind(), PkgPath: ret.PkgPath()}
|
||||
}
|
||||
info.RType.Methods[m.Name] = struct{ In, Out []*RType }{in, out}
|
||||
}
|
||||
default:
|
||||
d.err = fmt.Errorf("GoType must be a %q type or ValueScanner", expectKind)
|
||||
}
|
||||
d.Info = info
|
||||
}
|
||||
|
||||
var valueScannerType = reflect.TypeOf((*ValueScanner)(nil)).Elem()
|
||||
|
||||
// ValueScanner is the interface that groups the Value and the Scan methods.
|
||||
type ValueScanner interface {
|
||||
driver.Valuer
|
||||
sql.Scanner
|
||||
}
|
||||
|
||||
// indirect returns the type at the end of indirection.
|
||||
func indirect(t reflect.Type) reflect.Type {
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -5,7 +5,10 @@
|
||||
package field_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -101,6 +104,48 @@ func TestString(t *testing.T) {
|
||||
assert.True(t, fd.Unique)
|
||||
assert.Len(t, fd.Validators, 2)
|
||||
assert.True(t, fd.Sensitive)
|
||||
|
||||
fd = field.String("name").GoType(http.Dir("dir")).Descriptor()
|
||||
assert.NoError(t, fd.Err())
|
||||
assert.Equal(t, "http.Dir", fd.Info.Ident)
|
||||
assert.Equal(t, "net/http", fd.Info.PkgPath)
|
||||
assert.Equal(t, "http.Dir", fd.Info.String())
|
||||
assert.False(t, fd.Info.Nillable)
|
||||
assert.False(t, fd.Info.ValueScanner())
|
||||
|
||||
fd = field.String("name").GoType(http.MethodOptions).Descriptor()
|
||||
assert.NoError(t, fd.Err())
|
||||
assert.Equal(t, "string", fd.Info.Ident)
|
||||
assert.Equal(t, "", fd.Info.PkgPath)
|
||||
assert.Equal(t, "string", fd.Info.String())
|
||||
assert.False(t, fd.Info.Nillable)
|
||||
|
||||
fd = field.String("nullable_name").GoType(&sql.NullString{}).Descriptor()
|
||||
assert.NoError(t, fd.Err())
|
||||
assert.Equal(t, "*sql.NullString", fd.Info.Ident)
|
||||
assert.Equal(t, "database/sql", fd.Info.PkgPath)
|
||||
assert.Equal(t, "*sql.NullString", fd.Info.String())
|
||||
assert.True(t, fd.Info.Nillable)
|
||||
assert.True(t, fd.Info.ValueScanner())
|
||||
assert.False(t, fd.Info.Stringer())
|
||||
assert.True(t, fd.Info.RType.TypeEqual(reflect.TypeOf(sql.NullString{})))
|
||||
assert.True(t, fd.Info.RType.TypeEqual(reflect.TypeOf(&sql.NullString{})))
|
||||
|
||||
type tURL struct {
|
||||
field.ValueScanner
|
||||
*url.URL
|
||||
}
|
||||
fd = field.String("nullable_url").GoType(&tURL{}).Descriptor()
|
||||
assert.Equal(t, "*field_test.tURL", fd.Info.Ident)
|
||||
assert.Equal(t, "github.com/facebookincubator/ent/schema/field_test", fd.Info.PkgPath)
|
||||
assert.Equal(t, "*field_test.tURL", fd.Info.String())
|
||||
assert.True(t, fd.Info.ValueScanner())
|
||||
assert.True(t, fd.Info.Stringer())
|
||||
|
||||
fd = field.String("name").GoType(1).Descriptor()
|
||||
assert.Error(t, fd.Err())
|
||||
fd = field.String("name").GoType(struct{}{}).Descriptor()
|
||||
assert.Error(t, fd.Err())
|
||||
}
|
||||
|
||||
func TestTime(t *testing.T) {
|
||||
|
||||
@@ -4,7 +4,11 @@
|
||||
|
||||
package field
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A Type represents a field type.
|
||||
type Type uint8
|
||||
@@ -72,6 +76,7 @@ type TypeInfo struct {
|
||||
Ident string
|
||||
PkgPath string
|
||||
Nillable bool // slices or pointers.
|
||||
RType *RType
|
||||
}
|
||||
|
||||
// String returns the string representation of a type.
|
||||
@@ -101,6 +106,18 @@ func (t TypeInfo) ConstName() string {
|
||||
return t.Type.ConstName()
|
||||
}
|
||||
|
||||
// ValueScanner indicates if this type implements the ValueScanner interface.
|
||||
func (t TypeInfo) ValueScanner() bool {
|
||||
return t.RType.implements(valueScannerType)
|
||||
}
|
||||
|
||||
var stringerType = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()
|
||||
|
||||
// Stringer indicates if this type implements the Stringer interface.
|
||||
func (t TypeInfo) Stringer() bool {
|
||||
return t.RType.implements(stringerType)
|
||||
}
|
||||
|
||||
var (
|
||||
typeNames = [...]string{
|
||||
TypeInvalid: "invalid",
|
||||
@@ -132,3 +149,45 @@ var (
|
||||
TypeBytes: "TypeBytes",
|
||||
}
|
||||
)
|
||||
|
||||
// RType holds a serializable reflect.Type information of
|
||||
// Go object. Used by the entc package.
|
||||
type RType struct {
|
||||
Name string
|
||||
Kind reflect.Kind
|
||||
PkgPath string
|
||||
Methods map[string]struct{ In, Out []*RType }
|
||||
}
|
||||
|
||||
// TypeEqual tests if the RType is equal to given reflect.Type.
|
||||
func (r *RType) TypeEqual(t reflect.Type) bool {
|
||||
t = indirect(t)
|
||||
return r.Name == t.Name() && r.Kind == t.Kind() && r.PkgPath == t.PkgPath()
|
||||
}
|
||||
|
||||
func (r *RType) implements(typ reflect.Type) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
n := typ.NumMethod()
|
||||
for i := 0; i < n; i++ {
|
||||
m0 := typ.Method(i)
|
||||
m1, ok := r.Methods[m0.Name]
|
||||
if !ok || len(m1.In) != m0.Type.NumIn() || len(m1.Out) != m0.Type.NumOut() {
|
||||
return false
|
||||
}
|
||||
in := m0.Type.NumIn()
|
||||
for j := 0; j < in; j++ {
|
||||
if !m1.In[j].TypeEqual(m0.Type.In(j)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
out := m0.Type.NumOut()
|
||||
for j := 0; j < out; j++ {
|
||||
if !m1.Out[j].TypeEqual(m0.Type.Out(j)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user