Merge pull request #1125 from facebook/gotypedefault

Gotypedefault
This commit is contained in:
Ariel Mashraki
2021-01-04 12:01:41 +02:00
committed by GitHub
12 changed files with 164 additions and 14 deletions

View File

@@ -7,7 +7,7 @@ linters-settings:
dupl:
threshold: 100
funlen:
lines: 120
lines: 130
statements: 80
goheader:
template: |-

View File

@@ -8,7 +8,9 @@ package fieldtype
import (
"fmt"
"net"
"github.com/facebook/ent/dialect/sql"
"github.com/facebook/ent/entc/integration/ent/role"
)
@@ -193,8 +195,14 @@ var (
ValidateOptionalInt32Validator func(int32) error
// NdirValidator is a validator for the "ndir" field. It is called by the builders before save.
NdirValidator func(string) error
// DefaultStr holds the default value on creation for the str field.
DefaultStr func() sql.NullString
// DefaultNullStr holds the default value on creation for the null_str field.
DefaultNullStr func() sql.NullString
// LinkValidator is a validator for the "link" field. It is called by the builders before save.
LinkValidator func(string) error
// DefaultIP holds the default value on creation for the ip field.
DefaultIP func() net.IP
// MACValidator is a validator for the "mac" field. It is called by the builders before save.
MACValidator func(string) error
)

View File

@@ -612,6 +612,18 @@ func (ftc *FieldTypeCreate) SaveX(ctx context.Context) *FieldType {
// defaults sets the default values of the builder before save.
func (ftc *FieldTypeCreate) defaults() {
if _, ok := ftc.mutation.Str(); !ok {
v := fieldtype.DefaultStr()
ftc.mutation.SetStr(v)
}
if _, ok := ftc.mutation.NullStr(); !ok {
v := fieldtype.DefaultNullStr()
ftc.mutation.SetNullStr(v)
}
if _, ok := ftc.mutation.IP(); !ok {
v := fieldtype.DefaultIP()
ftc.mutation.SetIP(v)
}
if _, ok := ftc.mutation.Role(); !ok {
v := fieldtype.DefaultRole
ftc.mutation.SetRole(v)

View File

@@ -7,8 +7,10 @@
package ent
import (
"net"
"time"
"github.com/facebook/ent/dialect/sql"
"github.com/facebook/ent/entc/integration/ent/card"
"github.com/facebook/ent/entc/integration/ent/fieldtype"
"github.com/facebook/ent/entc/integration/ent/file"
@@ -56,10 +58,22 @@ func init() {
fieldtypeDescNdir := fieldtypeFields[27].Descriptor()
// fieldtype.NdirValidator is a validator for the "ndir" field. It is called by the builders before save.
fieldtype.NdirValidator = fieldtypeDescNdir.Validators[0].(func(string) error)
// fieldtypeDescStr is the schema descriptor for str field.
fieldtypeDescStr := fieldtypeFields[28].Descriptor()
// fieldtype.DefaultStr holds the default value on creation for the str field.
fieldtype.DefaultStr = fieldtypeDescStr.Default.(func() sql.NullString)
// fieldtypeDescNullStr is the schema descriptor for null_str field.
fieldtypeDescNullStr := fieldtypeFields[29].Descriptor()
// fieldtype.DefaultNullStr holds the default value on creation for the null_str field.
fieldtype.DefaultNullStr = fieldtypeDescNullStr.Default.(func() sql.NullString)
// fieldtypeDescLink is the schema descriptor for link field.
fieldtypeDescLink := fieldtypeFields[30].Descriptor()
// fieldtype.LinkValidator is a validator for the "link" field. It is called by the builders before save.
fieldtype.LinkValidator = fieldtypeDescLink.Validators[0].(func(string) error)
// fieldtypeDescIP is the schema descriptor for ip field.
fieldtypeDescIP := fieldtypeFields[36].Descriptor()
// fieldtype.DefaultIP holds the default value on creation for the ip field.
fieldtype.DefaultIP = fieldtypeDescIP.Default.(func() net.IP)
// fieldtypeDescMAC is the schema descriptor for mac field.
fieldtypeDescMAC := fieldtypeFields[45].Descriptor()
// fieldtype.MACValidator is a validator for the "mac" field. It is called by the builders before save.

View File

@@ -79,11 +79,17 @@ func (FieldType) Fields() []ent.Field {
GoType(http.Dir("ndir")),
field.String("str").
Optional().
GoType(&sql.NullString{}),
GoType(&sql.NullString{}).
DefaultFunc(func() sql.NullString {
return sql.NullString{String: "default", Valid: true}
}),
field.String("null_str").
Optional().
Nillable().
GoType(&sql.NullString{}),
GoType(&sql.NullString{}).
DefaultFunc(func() sql.NullString {
return sql.NullString{String: "default", Valid: true}
}),
field.String("link").
Optional().
NotEmpty().
@@ -107,7 +113,10 @@ func (FieldType) Fields() []ent.Field {
GoType(&sql.NullTime{}),
field.Bytes("ip").
Optional().
GoType(net.IP("127.0.0.1")),
GoType(net.IP("127.0.0.1")).
DefaultFunc(func() net.IP {
return net.IP("127.0.0.1")
}),
field.Int("null_int64").
Optional().
GoType(&sql.NullInt64{}),

View File

@@ -7,7 +7,9 @@
package fieldtype
import (
"database/sql"
"fmt"
"net"
"github.com/facebook/ent/entc/integration/ent/role"
)
@@ -118,8 +120,14 @@ var (
ValidateOptionalInt32Validator func(int32) error
// NdirValidator is a validator for the "ndir" field. It is called by the builders before save.
NdirValidator func(string) error
// DefaultStr holds the default value on creation for the str field.
DefaultStr func() sql.NullString
// DefaultNullStr holds the default value on creation for the null_str field.
DefaultNullStr func() sql.NullString
// LinkValidator is a validator for the "link" field. It is called by the builders before save.
LinkValidator func(string) error
// DefaultIP holds the default value on creation for the ip field.
DefaultIP func() net.IP
// MACValidator is a validator for the "mac" field. It is called by the builders before save.
MACValidator func(string) error
)

View File

@@ -613,6 +613,18 @@ func (ftc *FieldTypeCreate) SaveX(ctx context.Context) *FieldType {
// defaults sets the default values of the builder before save.
func (ftc *FieldTypeCreate) defaults() {
if _, ok := ftc.mutation.Str(); !ok {
v := fieldtype.DefaultStr()
ftc.mutation.SetStr(v)
}
if _, ok := ftc.mutation.NullStr(); !ok {
v := fieldtype.DefaultNullStr()
ftc.mutation.SetNullStr(v)
}
if _, ok := ftc.mutation.IP(); !ok {
v := fieldtype.DefaultIP()
ftc.mutation.SetIP(v)
}
if _, ok := ftc.mutation.Role(); !ok {
v := fieldtype.DefaultRole
ftc.mutation.SetRole(v)

View File

@@ -7,6 +7,8 @@
package ent
import (
"database/sql"
"net"
"time"
"github.com/facebook/ent/entc/integration/ent/schema"
@@ -56,10 +58,22 @@ func init() {
fieldtypeDescNdir := fieldtypeFields[27].Descriptor()
// fieldtype.NdirValidator is a validator for the "ndir" field. It is called by the builders before save.
fieldtype.NdirValidator = fieldtypeDescNdir.Validators[0].(func(string) error)
// fieldtypeDescStr is the schema descriptor for str field.
fieldtypeDescStr := fieldtypeFields[28].Descriptor()
// fieldtype.DefaultStr holds the default value on creation for the str field.
fieldtype.DefaultStr = fieldtypeDescStr.Default.(func() sql.NullString)
// fieldtypeDescNullStr is the schema descriptor for null_str field.
fieldtypeDescNullStr := fieldtypeFields[29].Descriptor()
// fieldtype.DefaultNullStr holds the default value on creation for the null_str field.
fieldtype.DefaultNullStr = fieldtypeDescNullStr.Default.(func() sql.NullString)
// fieldtypeDescLink is the schema descriptor for link field.
fieldtypeDescLink := fieldtypeFields[30].Descriptor()
// fieldtype.LinkValidator is a validator for the "link" field. It is called by the builders before save.
fieldtype.LinkValidator = fieldtypeDescLink.Validators[0].(func(string) error)
// fieldtypeDescIP is the schema descriptor for ip field.
fieldtypeDescIP := fieldtypeFields[36].Descriptor()
// fieldtype.DefaultIP holds the default value on creation for the ip field.
fieldtype.DefaultIP = fieldtypeDescIP.Default.(func() net.IP)
// fieldtypeDescMAC is the schema descriptor for mac field.
fieldtypeDescMAC := fieldtypeFields[45].Descriptor()
// fieldtype.MACValidator is a validator for the "mac" field. It is called by the builders before save.

View File

@@ -61,8 +61,7 @@ func Types(t *testing.T, client *ent.Client) {
SetNillableInt64(math.MinInt64).
SetDir("dir").
SetNdir("ndir").
SetStr(sql.NullString{String: "str", Valid: true}).
SetNullStr(sql.NullString{String: "str", Valid: true}).
SetNullStr(sql.NullString{String: "not-default", Valid: true}).
SetLink(schema.Link{URL: link}).
SetNullLink(schema.Link{URL: link}).
SetRole(role.Admin).
@@ -79,10 +78,11 @@ func Types(t *testing.T, client *ent.Client) {
require.Equal(http.Dir("dir"), ft.Dir)
require.NotNil(*ft.Ndir)
require.Equal(http.Dir("ndir"), *ft.Ndir)
require.Equal("str", ft.Str.String)
require.Equal("str", ft.NullStr.String)
require.Equal("default", ft.Str.String)
require.Equal("not-default", ft.NullStr.String)
require.Equal("localhost", ft.Link.String())
require.Equal("localhost", ft.NullLink.String())
require.Equal(net.IP("127.0.0.1").String(), ft.IP.String())
mac, err := net.ParseMAC("3b:b3:6b:3c:10:79")
require.NoError(err)
dt, err := time.Parse(time.RFC3339, "1906-01-02T00:00:00+00:00")

View File

@@ -213,7 +213,7 @@ func (b *stringBuilder) Default(s string) *stringBuilder {
// field.String("cuid").
// DefaultFunc(cuid.New)
//
func (b *stringBuilder) DefaultFunc(fn func() string) *stringBuilder {
func (b *stringBuilder) DefaultFunc(fn interface{}) *stringBuilder {
b.desc.Default = fn
return b
}
@@ -295,6 +295,9 @@ func (b *stringBuilder) Annotations(annotations ...schema.Annotation) *stringBui
// Descriptor implements the ent.Field interface by returning its descriptor.
func (b *stringBuilder) Descriptor() *Descriptor {
if b.desc.Default != nil {
b.desc.checkDefaultFunc(stringType)
}
return b.desc
}
@@ -500,7 +503,7 @@ func (b *bytesBuilder) Default(v []byte) *bytesBuilder {
// field.Bytes("cuid").
// DefaultFunc(cuid.New)
//
func (b *bytesBuilder) DefaultFunc(fn func() []byte) *bytesBuilder {
func (b *bytesBuilder) DefaultFunc(fn interface{}) *bytesBuilder {
b.desc.Default = fn
return b
}
@@ -590,6 +593,9 @@ func (b *bytesBuilder) SchemaType(types map[string]string) *bytesBuilder {
// Descriptor implements the ent.Field interface by returning its descriptor.
func (b *bytesBuilder) Descriptor() *Descriptor {
if b.desc.Default != nil {
b.desc.checkDefaultFunc(bytesType)
}
return b.desc
}
@@ -911,6 +917,7 @@ func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) {
Ident: tv.String(),
PkgPath: tv.PkgPath(),
RType: &RType{
rtype: tv,
Name: tv.Name(),
Kind: tv.Kind(),
PkgPath: tv.PkgPath(),
@@ -948,6 +955,24 @@ func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) {
d.Info = info
}
func (d *Descriptor) checkDefaultFunc(expectType reflect.Type) {
typ := reflect.TypeOf(d.Default)
if typ.Kind() != reflect.Func || d.Err != nil {
return
}
err := fmt.Errorf("expect type (func() %s) for default value", d.Info)
if typ.NumIn() != 0 || typ.NumOut() != 1 {
d.Err = err
}
rtype := expectType
if d.Info.RType != nil {
rtype = d.Info.RType.rtype
}
if !typ.Out(0).AssignableTo(rtype) {
d.Err = err
}
}
var (
boolType = reflect.TypeOf(false)
bytesType = reflect.TypeOf([]byte(nil))

View File

@@ -185,9 +185,7 @@ func TestBytes(t *testing.T) {
fd = field.Bytes("uuid").
GoType(&uuid.UUID{}).
DefaultFunc(func() []byte {
return []byte("{}")
}).
DefaultFunc(uuid.New).
Descriptor()
assert.NoError(t, fd.Err)
assert.Equal(t, "uuid.UUID", fd.Info.Ident)
@@ -195,7 +193,7 @@ func TestBytes(t *testing.T) {
assert.Equal(t, "uuid.UUID", fd.Info.String())
assert.True(t, fd.Info.Nillable)
assert.True(t, fd.Info.ValueScanner())
assert.Equal(t, []byte("{}"), fd.Default.(func() []byte)())
assert.NotEmpty(t, fd.Default.(func() uuid.UUID)())
fd = field.Bytes("blob").GoType(1).Descriptor()
assert.Error(t, fd.Err)
@@ -205,6 +203,54 @@ func TestBytes(t *testing.T) {
assert.Error(t, fd.Err)
}
func TestBytes_DefaultFunc(t *testing.T) {
f1 := func() net.IP { return net.IP("0.0.0.0") }
fd := field.Bytes("ip").GoType(net.IP("127.0.0.1")).DefaultFunc(f1).Descriptor()
assert.NoError(t, fd.Err)
var _ []byte = f1()
fd = field.Bytes("ip").DefaultFunc(f1).Descriptor()
assert.NoError(t, fd.Err)
f2 := func() []byte { return []byte("0.0.0.0") }
var _ net.IP = f2()
fd = field.Bytes("ip").GoType(net.IP("127.0.0.1")).DefaultFunc(f2).Descriptor()
assert.NoError(t, fd.Err)
f3 := func() []uint8 { return []uint8("0.0.0.0") }
var _ net.IP = f3()
fd = field.Bytes("ip").GoType(net.IP("127.0.0.1")).DefaultFunc(f3).Descriptor()
assert.NoError(t, fd.Err)
fd = field.Bytes("ip").DefaultFunc(f3).Descriptor()
assert.NoError(t, fd.Err)
f4 := func() net.IPMask { return net.IPMask("ffff:ff80::") }
fd = field.Bytes("ip").GoType(net.IP("127.0.0.1")).DefaultFunc(f4).Descriptor()
assert.Error(t, fd.Err, "`var _ net.IP = f4()` should fail")
}
func TestString_DefaultFunc(t *testing.T) {
f1 := func() http.Dir { return "/tmp" }
fd := field.String("dir").GoType(http.Dir("/tmp")).DefaultFunc(f1).Descriptor()
assert.NoError(t, fd.Err)
fd = field.String("dir").DefaultFunc(f1).Descriptor()
assert.Error(t, fd.Err, "`var _ string = f1()` should fail")
f2 := func() string { return "/tmp" }
fd = field.String("dir").GoType(http.Dir("/tmp")).DefaultFunc(f2).Descriptor()
assert.Error(t, fd.Err, "`var _ http.Dir = f2()` should fail")
f3 := func() sql.NullString { return sql.NullString{} }
fd = field.String("str").GoType(&sql.NullString{}).DefaultFunc(f3).Descriptor()
assert.NoError(t, fd.Err)
type S string
f4 := func() S { return "" }
fd = field.String("str").GoType(http.Dir("/tmp")).DefaultFunc(f4).Descriptor()
assert.Error(t, fd.Err, "`var _ http.Dir = f4()` should fail")
}
func TestString(t *testing.T) {
fd := field.String("name").
DefaultFunc(func() string {

View File

@@ -179,6 +179,8 @@ type RType struct {
Kind reflect.Kind
PkgPath string
Methods map[string]struct{ In, Out []*RType }
// Used only for in-package checks.
rtype reflect.Type
}
// TypeEqual tests if the RType is equal to given reflect.Type.