entc/gen: add default-funcs and validators for userdefined id (#436)

Fixes #432
This commit is contained in:
Ariel Mashraki
2020-04-18 12:28:50 +03:00
committed by GitHub
parent 3342b85580
commit 3c6a04f884
12 changed files with 74 additions and 20 deletions

File diff suppressed because one or more lines are too long

View File

@@ -34,8 +34,9 @@ type {{ $builder }} struct {
// Save creates the {{ $.Name }} in the database.
func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) (*{{ $.Name }}, error) {
{{- $mutation := print $receiver ".mutation" }}
{{- range $_, $f := $.Fields }}
{{- if or $f.Default (not $f.Optional) }}
{{- $fields := $.Fields }}{{ if $.ID.UserDefined }}{{ $fields = append $fields $.ID }}{{ end }}
{{- range $f := $fields }}
{{- if or $f.Default (and (not $f.Optional) (ne $f $.ID)) }}
if _, ok := {{ $mutation }}.{{ $f.MutationGet }}(); !ok {
{{- if $f.Default }}
v := {{ $.Package }}.{{ $f.DefaultName }}{{ if or $f.IsTime $f.IsUUID }}(){{ end }}
@@ -53,7 +54,7 @@ func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) (*{{ $.Name }},
}
{{- end }}
{{- end }}
{{- range $_, $e := $.Edges }}
{{- range $e := $.Edges }}
{{- if not $e.Optional }}
{{- if $e.Unique }}
if _, ok := {{ $mutation }}.{{ $e.StructField }}ID(); !ok {

View File

@@ -57,7 +57,8 @@ const (
{{- if $.HasPolicy }}
Policy ent.Policy
{{- end }}
{{- range $f := $.Fields }}
{{- $fields := $.Fields }}{{ if $.ID.UserDefined }}{{ $fields = append $fields $.ID }}{{ end }}
{{- range $f := $fields }}
{{- if and $f.Default (not $f.IsEnum) }}
{{- $default := $f.DefaultName }}
// {{ $default }} holds the default value on creation for the {{ $f.Name }} field.
@@ -79,7 +80,7 @@ const (
{{ end }}
{{/* define custom type for enum fields */}}
{{ range $_, $f := $.Fields -}}
{{ range $f := $.Fields -}}
{{ if $f.IsEnum }}
{{/* omit the package name from the type. */}}
{{ $enum := trimPackage $f.Type.String $.Package }}

View File

@@ -110,11 +110,12 @@ func init() {
{{- end }}
}
{{- end }}
{{- with $n.Fields }}
{{- $fields := $n.Fields }}{{ if $n.ID.UserDefined }}{{ $fields = append $fields $n.ID }}{{ end }}
{{- with $fields }}
{{ $pkg }}Fields := {{ $schema }}.{{ $n.Name }}{}.Fields()
_ = {{ $pkg }}Fields
{{- end }}
{{- range $i, $f := $n.Fields }}
{{- range $i, $f := $fields }}
{{- $desc := print $pkg "Desc" $f.StructField }}
{{- /* enum default values handled near their declarations (in type package). */}}
{{- if or (and $f.Default (not $f.IsEnum)) $f.UpdateDefault $f.Validators }}

View File

@@ -222,7 +222,11 @@ func (t Type) HasAssoc(name string) (*Edge, bool) {
// HasValidators reports if any of the type's field has validators.
func (t Type) HasValidators() bool {
for _, f := range t.Fields {
fields := t.Fields
if t.ID.UserDefined {
fields = append(fields, t.ID)
}
for _, f := range fields {
if f.Validators > 0 {
return true
}
@@ -232,7 +236,11 @@ func (t Type) HasValidators() bool {
// HasDefault reports if any of this type's fields has default value on creation.
func (t Type) HasDefault() bool {
for _, f := range t.Fields {
fields := t.Fields
if t.ID.UserDefined {
fields = append(fields, t.ID)
}
for _, f := range fields {
if f.Default {
return true
}

View File

@@ -91,10 +91,11 @@ func CustomID(t *testing.T, client *ent.Client) {
require.Equal(t, 3, hub.ID)
require.Equal(t, []int{1, 5}, hub.QueryUsers().Order(ent.Asc(user.FieldID)).IDsX(ctx))
blb := client.Blob.Create().SetID(uuid.New()).SaveX(ctx)
require.NotEmpty(t, blb.ID)
require.NotEmpty(t, blb.UUID)
chd := client.Blob.Create().SetID(uuid.New()).SetParent(blb).SaveX(ctx)
blb := client.Blob.Create().SaveX(ctx)
require.NotEmpty(t, blb.ID, "use default value")
id := uuid.New()
chd := client.Blob.Create().SetID(id).SetParent(blb).SaveX(ctx)
require.Equal(t, id, chd.ID, "use provided id")
require.Equal(t, blb.ID, chd.QueryParent().OnlyX(ctx).ID)
lnk := client.Blob.Create().SetID(uuid.New()).AddLinks(chd, blb).SaveX(ctx)
require.Equal(t, 2, lnk.QueryLinks().CountX(ctx))

View File

@@ -52,4 +52,6 @@ var (
var (
// DefaultUUID holds the default value on creation for the uuid field.
DefaultUUID func() uuid.UUID
// DefaultID holds the default value on creation for the id field.
DefaultID func() uuid.UUID
)

View File

@@ -75,6 +75,10 @@ func (bc *BlobCreate) Save(ctx context.Context) (*Blob, error) {
v := blob.DefaultUUID()
bc.mutation.SetUUID(v)
}
if _, ok := bc.mutation.ID(); !ok {
v := blob.DefaultID()
bc.mutation.SetID(v)
}
var (
err error
node *Blob

View File

@@ -61,3 +61,8 @@ var (
// primary key for the friends relation (M2M).
FriendsPrimaryKey = []string{"pet_id", "friend_id"}
)
var (
// IDValidator is a validator for the "id" field. It is called by the builders before save.
IDValidator func(string) error
)

View File

@@ -100,6 +100,11 @@ func (pc *PetCreate) SetBestFriend(p *Pet) *PetCreate {
// Save creates the Pet in the database.
func (pc *PetCreate) Save(ctx context.Context) (*Pet, error) {
if v, ok := pc.mutation.ID(); ok {
if err := pet.IDValidator(v); err != nil {
return nil, fmt.Errorf("ent: validator failed for field \"id\": %v", err)
}
}
var (
err error
node *Pet

View File

@@ -8,6 +8,7 @@ package ent
import (
"github.com/facebookincubator/ent/entc/integration/customid/ent/blob"
"github.com/facebookincubator/ent/entc/integration/customid/ent/pet"
"github.com/facebookincubator/ent/entc/integration/customid/ent/schema"
"github.com/google/uuid"
)
@@ -22,4 +23,28 @@ func init() {
blobDescUUID := blobFields[1].Descriptor()
// blob.DefaultUUID holds the default value on creation for the uuid field.
blob.DefaultUUID = blobDescUUID.Default.(func() uuid.UUID)
// blobDescID is the schema descriptor for id field.
blobDescID := blobFields[0].Descriptor()
// blob.DefaultID holds the default value on creation for the id field.
blob.DefaultID = blobDescID.Default.(func() uuid.UUID)
petFields := schema.Pet{}.Fields()
_ = petFields
// petDescID is the schema descriptor for id field.
petDescID := petFields[0].Descriptor()
// pet.IDValidator is a validator for the "id" field. It is called by the builders before save.
pet.IDValidator = func() func(string) error {
validators := petDescID.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(id string) error {
for _, fn := range fns {
if err := fn(id); err != nil {
return err
}
}
return nil
}
}()
}

View File

@@ -20,7 +20,8 @@ type Blob struct {
// Fields of the Blob.
func (Blob) Fields() []ent.Field {
return []ent.Field{
field.UUID("id", uuid.UUID{}),
field.UUID("id", uuid.UUID{}).
Default(uuid.New),
field.UUID("uuid", uuid.UUID{}).
Default(uuid.New),
}