entc/gen: add indexes, edges and hooks to mixin (#431)

This commit is contained in:
Ariel Mashraki
2020-04-20 13:40:56 +03:00
committed by GitHub
parent cec1dd1edf
commit 1c49159d18
22 changed files with 454 additions and 249 deletions

File diff suppressed because one or more lines are too long

View File

@@ -84,6 +84,9 @@ func init() {
{{- range $n := $.Nodes }}
{{- $pkg := $n.Package }}
{{- $schema := base $.Config.Schema }}
{{- with $n.RuntimeMixin }}
{{ $pkg }}Mixin := {{ $schema }}.{{ $n.Name }}{}.Mixin()
{{- end }}
{{- if $n.HasPolicy }}
{{ print $pkg ".Policy" }} = {{ $schema }}.{{ $n.Name }}{}.Policy()
{{ print $pkg ".Hooks" }}[0] = func(next ent.Mutator) ent.Mutator {
@@ -95,20 +98,34 @@ func init() {
})
}
{{- end }}
{{- if $n.NumHooks }}
{{ print $pkg "Hooks" }} := {{ $schema }}.{{ $n.Name }}{}.Hooks()
for i, h := range {{ print $pkg "Hooks" }} {
{{ print $pkg ".Hooks" }}[i{{ if $n.HasPolicy }}+1{{ end }}] = h
}
{{- with $hooks := $n.HookPositions }}
{{- /* Hooks defined in schema mixins. */}}
{{- with $idx := $n.MixedInHooks }}
{{- range $i := $idx }}
{{ print $pkg "MixinHooks" $i }} := {{ $pkg }}Mixin[{{ $i }}].(interface{ Hooks() []ent.Hook }).Hooks()
{{- end }}
{{- end }}
{{- /* If there are hooks defined in the schema. */}}
{{- $schemaHooks := false }}{{ range $p := $hooks }}{{ if not $p.MixedIn }}{{ $schemaHooks = true }}{{ end }}{{ end }}
{{- if $schemaHooks }}
{{ print $pkg "Hooks" }} := {{ $schema }}.{{ $n.Name }}{}.Hooks()
{{- end }}
{{- range $i, $p := $hooks }}
{{- if $n.HasPolicy }}
{{ $i = add $i 1 }}
{{- end }}
{{- if $p.MixedIn }}
{{ print $pkg ".Hooks" }}[{{ $i }}] = {{ print $pkg "MixinHooks" $p.MixinIndex }}[{{ $p.Index }}]
{{- else }}
{{ print $pkg ".Hooks" }}[{{ $i }}] = {{ print $pkg "Hooks" }}[{{ $p.Index }}]
{{- end }}
{{- end }}
{{- end }}
{{- if or $n.HasDefault $n.HasValidators }}
{{- with $n.MixedInWithDefaultOrValidator }}
{{ $pkg }}Mixin := {{ $schema }}.{{ $n.Name }}{}.Mixin()
{{ $pkg }}MixinFields := [...][]ent.Field{
{{- range $i, $_ := xrange $n.NumMixin }}
{{ $pkg }}Mixin[{{ $i }}].Fields(),
{{- end }}
}
{{- with $idx := $n.MixedInFields }}
{{- range $i := $idx }}
{{ print $pkg "MixinFields" $i }} := {{ $pkg }}Mixin[{{ $i }}].Fields()
{{- end }}
{{- end }}
{{- $fields := $n.Fields }}{{ if $n.ID.UserDefined }}{{ $fields = append $fields $n.ID }}{{ end }}
{{- with $fields }}
@@ -119,11 +136,10 @@ func init() {
{{- $desc := print $pkg "Desc" $f.StructField }}
{{- /* enum default values handled near their declarations (in type package). */}}
{{- if or (and $f.Default (not $f.IsEnum)) $f.UpdateDefault $f.Validators }}
// {{ $desc }} is the schema descriptor for {{ $f.Name }} field.
{{- if $f.Position.MixedIn }}
// {{ $desc }} is the schema descriptor for {{ $f.Name }} field.
{{ $desc }} := {{ $pkg }}MixinFields[{{ $f.Position.MixinIndex }}][{{ $f.Position.Index }}].Descriptor()
{{ $desc }} := {{ print $pkg "MixinFields" $f.Position.MixinIndex }}[{{ $f.Position.Index }}].Descriptor()
{{- else }}
// {{ $desc }} is the schema descriptor for {{ $f.Name }} field.
{{ $desc }} := {{ $pkg }}Fields[{{ $f.Position.Index }}].Descriptor()
{{- end }}
{{- end }}

View File

@@ -21,12 +21,13 @@ import (
"github.com/facebookincubator/ent/schema/field"
)
// The following types and their exported methods used by the codegen
// to generate the assets.
type (
// Type represents one node-type in the graph, its relations and
// the information it holds.
Type struct {
*Config
// schema definition.
schema *load.Schema
// Name holds the type/ent name.
Name string
@@ -46,7 +47,6 @@ type (
// Field holds the information of a type field used for the templates.
Field struct {
// field definition.
def *load.Field
// Name is the name of this field in the database schema.
Name string
@@ -288,14 +288,35 @@ func (t Type) FKEdges() (edges []*Edge) {
return
}
// MixedInWithDefaultOrValidator returns all mixed-in fields with default values for creation or update.
func (t Type) MixedInWithDefaultOrValidator() (fields []*Field) {
// RuntimeMixin returns schema mixin that needs to be loaded at
// runtime. For example, for default values, validators or hooks.
func (t Type) RuntimeMixin() bool {
return len(t.MixedInFields()) > 0 || len(t.MixedInHooks()) > 0
}
// MixedInFields returns the indices of mixin holds runtime code.
func (t Type) MixedInFields() []int {
idx := make(map[int]struct{})
for _, f := range t.Fields {
if f.Position != nil && f.Position.MixedIn && (f.Default || f.UpdateDefault || f.Validators > 0) {
fields = append(fields, f)
idx[f.Position.MixinIndex] = struct{}{}
}
}
return
return sortedKeys(idx)
}
// MixedInHooks returns the indices of mixin with hooks.
func (t Type) MixedInHooks() []int {
if t.schema == nil {
return nil
}
idx := make(map[int]struct{})
for _, h := range t.schema.Hooks {
if h.MixedIn {
idx[h.MixinIndex] = struct{}{}
}
}
return sortedKeys(idx)
}
// NumMixin returns the type's mixin count.
@@ -484,11 +505,19 @@ func (t Type) SiblingImports() []string {
// NumHooks returns the number of hooks declared in the type schema.
func (t Type) NumHooks() int {
if t.schema != nil {
return t.schema.Hooks
return len(t.schema.Hooks)
}
return 0
}
// HookPositions returns the position information of hooks declared in the type schema.
func (t Type) HookPositions() []*load.Position {
if t.schema != nil {
return t.schema.Hooks
}
return nil
}
// HasPolicy returns whether a privacy policy was declared in the type schema.
func (t Type) HasPolicy() bool {
if t.schema != nil {
@@ -920,3 +949,12 @@ func names(ids ...string) map[string]struct{} {
}
return m
}
func sortedKeys(m map[int]struct{}) []int {
s := make([]int, 0, len(m))
for k := range m {
s = append(s, k)
}
sort.Ints(s)
return s
}

View File

@@ -134,10 +134,8 @@ func TestType_Receiver(t *testing.T) {
}
}
func TestType_MixedInWithDefaultOrValidator(t *testing.T) {
position := &load.Position{
MixedIn: true,
}
func TestType_WithRuntimeMixin(t *testing.T) {
position := &load.Position{MixedIn: true}
typ := &Type{
Fields: []*Field{
{Default: true, Position: position},
@@ -145,8 +143,7 @@ func TestType_MixedInWithDefaultOrValidator(t *testing.T) {
{Validators: 1, Position: position},
},
}
fields := typ.MixedInWithDefaultOrValidator()
require.Equal(t, 3, len(fields))
require.True(t, typ.RuntimeMixin())
}
func TestType_TagTypes(t *testing.T) {

View File

@@ -16,8 +16,6 @@ import (
"github.com/facebookincubator/ent/entc/integration/ent/groupinfo"
"github.com/facebookincubator/ent/entc/integration/ent/schema"
"github.com/facebookincubator/ent/entc/integration/ent/user"
"github.com/facebookincubator/ent"
)
// The init function reads all schema descriptors with runtime
@@ -25,17 +23,15 @@ import (
// to their package variables.
func init() {
cardMixin := schema.Card{}.Mixin()
cardMixinFields := [...][]ent.Field{
cardMixin[0].Fields(),
}
cardMixinFields0 := cardMixin[0].Fields()
cardFields := schema.Card{}.Fields()
_ = cardFields
// cardDescCreateTime is the schema descriptor for create_time field.
cardDescCreateTime := cardMixinFields[0][0].Descriptor()
cardDescCreateTime := cardMixinFields0[0].Descriptor()
// card.DefaultCreateTime holds the default value on creation for the create_time field.
card.DefaultCreateTime = cardDescCreateTime.Default.(func() time.Time)
// cardDescUpdateTime is the schema descriptor for update_time field.
cardDescUpdateTime := cardMixinFields[0][1].Descriptor()
cardDescUpdateTime := cardMixinFields0[1].Descriptor()
// card.DefaultUpdateTime holds the default value on creation for the update_time field.
card.DefaultUpdateTime = cardDescUpdateTime.Default.(func() time.Time)
// card.UpdateDefaultUpdateTime holds the default value on update for the update_time field.
@@ -117,13 +113,11 @@ func init() {
// groupinfo.DefaultMaxUsers holds the default value on creation for the max_users field.
groupinfo.DefaultMaxUsers = groupinfoDescMaxUsers.Default.(int)
userMixin := schema.User{}.Mixin()
userMixinFields := [...][]ent.Field{
userMixin[0].Fields(),
}
userMixinFields0 := userMixin[0].Fields()
userFields := schema.User{}.Fields()
_ = userFields
// userDescOptionalInt is the schema descriptor for optional_int field.
userDescOptionalInt := userMixinFields[0][0].Descriptor()
userDescOptionalInt := userMixinFields0[0].Descriptor()
// user.OptionalIntValidator is a validator for the "optional_int" field. It is called by the builders before save.
user.OptionalIntValidator = userDescOptionalInt.Validators[0].(func(int) error)
// userDescLast is the schema descriptor for last field.

View File

@@ -8,7 +8,7 @@ import (
"github.com/facebookincubator/ent"
"github.com/facebookincubator/ent/schema/edge"
"github.com/facebookincubator/ent/schema/field"
"github.com/facebookincubator/ent/schema/schemautil"
"github.com/facebookincubator/ent/schema/mixin"
)
// Card holds the schema definition for the CreditCard entity.
@@ -18,7 +18,7 @@ type Card struct {
func (Card) Mixin() []ent.Mixin {
return []ent.Mixin{
schemautil.TimeMixin{},
mixin.Time{},
}
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/facebookincubator/ent"
"github.com/facebookincubator/ent/schema/edge"
"github.com/facebookincubator/ent/schema/field"
"github.com/facebookincubator/ent/schema/mixin"
)
// User holds the schema for the user entity.
@@ -63,7 +64,9 @@ func (User) Edges() []ent.Edge {
}
// UserMixin composes create/update time mixin.
type UserMixin struct{}
type UserMixin struct {
mixin.Schema
}
// Fields of the time mixin.
func (UserMixin) Fields() []ent.Field {

View File

@@ -16,8 +16,6 @@ import (
"github.com/facebookincubator/ent/entc/integration/gremlin/ent/group"
"github.com/facebookincubator/ent/entc/integration/gremlin/ent/groupinfo"
"github.com/facebookincubator/ent/entc/integration/gremlin/ent/user"
"github.com/facebookincubator/ent"
)
// The init function reads all schema descriptors with runtime
@@ -25,17 +23,15 @@ import (
// to their package variables.
func init() {
cardMixin := schema.Card{}.Mixin()
cardMixinFields := [...][]ent.Field{
cardMixin[0].Fields(),
}
cardMixinFields0 := cardMixin[0].Fields()
cardFields := schema.Card{}.Fields()
_ = cardFields
// cardDescCreateTime is the schema descriptor for create_time field.
cardDescCreateTime := cardMixinFields[0][0].Descriptor()
cardDescCreateTime := cardMixinFields0[0].Descriptor()
// card.DefaultCreateTime holds the default value on creation for the create_time field.
card.DefaultCreateTime = cardDescCreateTime.Default.(func() time.Time)
// cardDescUpdateTime is the schema descriptor for update_time field.
cardDescUpdateTime := cardMixinFields[0][1].Descriptor()
cardDescUpdateTime := cardMixinFields0[1].Descriptor()
// card.DefaultUpdateTime holds the default value on creation for the update_time field.
card.DefaultUpdateTime = cardDescUpdateTime.Default.(func() time.Time)
// card.UpdateDefaultUpdateTime holds the default value on update for the update_time field.
@@ -117,13 +113,11 @@ func init() {
// groupinfo.DefaultMaxUsers holds the default value on creation for the max_users field.
groupinfo.DefaultMaxUsers = groupinfoDescMaxUsers.Default.(int)
userMixin := schema.User{}.Mixin()
userMixinFields := [...][]ent.Field{
userMixin[0].Fields(),
}
userMixinFields0 := userMixin[0].Fields()
userFields := schema.User{}.Fields()
_ = userFields
// userDescOptionalInt is the schema descriptor for optional_int field.
userDescOptionalInt := userMixinFields[0][0].Descriptor()
userDescOptionalInt := userMixinFields0[0].Descriptor()
// user.OptionalIntValidator is a validator for the "optional_int" field. It is called by the builders before save.
user.OptionalIntValidator = userDescOptionalInt.Validators[0].(func(int) error)
// userDescLast is the schema descriptor for last field.

View File

@@ -55,7 +55,7 @@ var ForeignKeys = []string{
// import _ "github.com/facebookincubator/ent/entc/integration/hooks/ent/runtime"
//
var (
Hooks [2]ent.Hook
Hooks [3]ent.Hook
// DefaultNumber holds the default value on creation for the number field.
DefaultNumber string
// NumberValidator is a validator for the "number" field. It is called by the builders before save.

View File

@@ -11,16 +11,20 @@ import (
"github.com/facebookincubator/ent/entc/integration/hooks/ent/card"
"github.com/facebookincubator/ent/entc/integration/hooks/ent/schema"
"github.com/facebookincubator/ent"
)
// The init function reads all schema descriptors with runtime
// code (default values, validators or hooks) and stitches it
// to their package variables.
func init() {
cardMixin := schema.Card{}.Mixin()
cardMixinHooks0 := cardMixin[0].(interface{ Hooks() []ent.Hook }).Hooks()
cardHooks := schema.Card{}.Hooks()
for i, h := range cardHooks {
card.Hooks[i] = h
}
card.Hooks[0] = cardMixinHooks0[0]
card.Hooks[1] = cardHooks[0]
card.Hooks[2] = cardHooks[1]
cardFields := schema.Card{}.Fields()
_ = cardFields
// cardDescNumber is the schema descriptor for number field.

View File

@@ -9,21 +9,38 @@ import (
"fmt"
"time"
"github.com/facebookincubator/ent/entc/integration/hooks/ent/card"
"github.com/facebookincubator/ent"
gen "github.com/facebookincubator/ent/entc/integration/hooks/ent"
"github.com/facebookincubator/ent/entc/integration/hooks/ent/card"
"github.com/facebookincubator/ent/entc/integration/hooks/ent/hook"
"github.com/facebookincubator/ent/schema/edge"
"github.com/facebookincubator/ent/schema/field"
"github.com/facebookincubator/ent/schema/mixin"
)
// RejectMany rejects all update operations
// that are not on a specific entity.
type RejectUpdate struct {
mixin.Schema
}
func (RejectUpdate) Hooks() []ent.Hook {
return []ent.Hook{
hook.Reject(ent.OpUpdate),
}
}
// Card holds the schema definition for the CreditCard entity.
type Card struct {
ent.Schema
}
func (Card) Mixin() []ent.Mixin {
return []ent.Mixin{
RejectUpdate{},
}
}
func (Card) Hooks() []ent.Hook {
return []ent.Hook{
hook.On(

View File

@@ -35,6 +35,8 @@ func TestSchemaHooks(t *testing.T) {
})
})
client.Card.Create().SetNumber("1234").SaveX(ctx)
_, err = client.Card.Update().Save(ctx)
require.EqualError(t, err, "OpUpdate operation is not allowed")
}
func TestRuntimeHooks(t *testing.T) {

View File

@@ -1,69 +0,0 @@
//+build tests
package main
import (
"context"
"fmt"
"io"
"os"
"github.com/facebookincubator/ent/entc/integration/hooks/ent"
"github.com/facebookincubator/ent/entc/integration/hooks/ent/hook"
_ "github.com/facebookincubator/ent/entc/integration/hooks/ent/runtime"
_ "github.com/mattn/go-sqlite3"
)
func main() {
ctx := context.Background()
client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
if err != nil {
panic(err)
}
if err := client.Schema.Create(ctx); err != nil {
panic(err)
}
client.Use(func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
fmt.Println("start")
defer fmt.Println("end")
return next.Mutate(ctx, m)
})
})
client.Card.Use(func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
fmt.Printf("Before Hook\tOp: %s\tType: %s\tConcreteType: %T\n", m.Op(), m.Type(), m)
defer fmt.Println("Done!")
if ns, ok := m.(interface{ SetName(string) }); ok {
ns.SetName("hook name")
}
return next.Mutate(ctx, m)
})
})
client.Card.Use(func(next ent.Mutator) ent.Mutator {
return hook.CardFunc(func(ctx context.Context, m *ent.CardMutation) (ent.Value, error) {
fmt.Println("Concrete hook\t", m.Op())
return next.Mutate(ctx, m)
})
})
client.Use(hook.On(LogWithConfig(os.Stdout), ent.OpUpdate|ent.OpCreate))
u := client.Card.Create().SetNumber("A").SaveX(ctx)
u.Update().SetName("Boring2").SaveX(ctx)
client.Card.Update().SetName("foo").SaveX(ctx)
client.Card.DeleteOneID(u.ID).ExecX(ctx)
client.Card.Delete().ExecX(ctx)
}
func LogWithConfig(w io.Writer) ent.Hook {
if w == nil {
w = os.Stdout
}
return func(next ent.Mutator) ent.Mutator {
return hook.CardFunc(func(ctx context.Context, m *ent.CardMutation) (ent.Value, error) {
fmt.Fprintln(w, "Logging Hook:\t", m.Op())
return next.Mutate(ctx, m)
})
}
}

View File

@@ -45,9 +45,8 @@ func init() {
})
}
planetHooks := schema.Planet{}.Hooks()
for i, h := range planetHooks {
planet.Hooks[i+1] = h
}
planet.Hooks[1] = planetHooks[0]
planetFields := schema.Planet{}.Fields()
_ = planetFields
// planetDescName is the schema descriptor for name field.

File diff suppressed because one or more lines are too long

View File

@@ -12,17 +12,18 @@ import (
"github.com/facebookincubator/ent"
"github.com/facebookincubator/ent/schema/edge"
"github.com/facebookincubator/ent/schema/field"
"github.com/facebookincubator/ent/schema/index"
)
// Schema represents an ent.Schema that was loaded from a complied user package.
type Schema struct {
Name string `json:"name,omitempty"`
Config ent.Config `json:"config,omitempty"`
Edges []*Edge `json:"edges,omitempty"`
Fields []*Field `json:"fields,omitempty"`
Indexes []*Index `json:"indexes,omitempty"`
Hooks int `json:"hooks,omitempty"`
Policy bool `json:"policy,omitempty"`
Name string `json:"name,omitempty"`
Config ent.Config `json:"config,omitempty"`
Edges []*Edge `json:"edges,omitempty"`
Fields []*Field `json:"fields,omitempty"`
Indexes []*Index `json:"indexes,omitempty"`
Hooks []*Position `json:"hooks,omitempty"`
Policy bool `json:"policy,omitempty"`
}
// Position describes a field position in the schema.
@@ -89,7 +90,7 @@ func NewEdge(ed *edge.Descriptor) *Edge {
return ne
}
// NewField creates an loaded field from edge descriptor.
// NewField creates an loaded field from field descriptor.
func NewField(fd *field.Descriptor) (*Field, error) {
sf := &Field{
Name: fd.Name,
@@ -126,6 +127,16 @@ func NewField(fd *field.Descriptor) (*Field, error) {
return sf, nil
}
// NewIndex creates an loaded index from index descriptor.
func NewIndex(idx *index.Descriptor) *Index {
return &Index{
Edges: idx.Edges,
Fields: idx.Fields,
Unique: idx.Unique,
StorageKey: idx.StorageKey,
}
}
// MarshalSchema encode the ent.Schema interface into a JSON
// that can be decoded into the Schema object object.
func MarshalSchema(schema ent.Interface) (b []byte, err error) {
@@ -133,6 +144,9 @@ func MarshalSchema(schema ent.Interface) (b []byte, err error) {
Config: schema.Config(),
Name: indirect(reflect.TypeOf(schema)).Name(),
}
if err := s.loadMixin(schema); err != nil {
return nil, fmt.Errorf("schema %q: %v", s.Name, err)
}
if err := s.loadFields(schema); err != nil {
return nil, fmt.Errorf("schema %q: %v", s.Name, err)
}
@@ -148,13 +162,7 @@ func MarshalSchema(schema ent.Interface) (b []byte, err error) {
return nil, fmt.Errorf("schema %q: %v", s.Name, err)
}
for _, idx := range indexes {
idx := idx.Descriptor()
s.Indexes = append(s.Indexes, &Index{
Edges: idx.Edges,
Fields: idx.Fields,
Unique: idx.Unique,
StorageKey: idx.StorageKey,
})
s.Indexes = append(s.Indexes, NewIndex(idx.Descriptor()))
}
if err := s.loadHooks(schema); err != nil {
return nil, fmt.Errorf("schema %q: %v", s.Name, err)
@@ -179,8 +187,8 @@ func UnmarshalSchema(buf []byte) (*Schema, error) {
return s, nil
}
// loadFields loads field to schema from ent.Interface.
func (s *Schema) loadFields(schema ent.Interface) error {
// loadMixin loads mixin to schema from ent.Interface.
func (s *Schema) loadMixin(schema ent.Interface) error {
mixin, err := safeMixin(schema)
if err != nil {
return err
@@ -202,7 +210,37 @@ func (s *Schema) loadFields(schema ent.Interface) error {
}
s.Fields = append(s.Fields, sf)
}
edges, err := safeEdges(mx)
if err != nil {
return err
}
for _, e := range edges {
s.Edges = append(s.Edges, NewEdge(e.Descriptor()))
}
indexes, err := safeIndexes(mx)
if err != nil {
return err
}
for _, idx := range indexes {
s.Indexes = append(s.Indexes, NewIndex(idx.Descriptor()))
}
hooks, err := safeHooks(mx)
if err != nil {
return err
}
for j := range hooks {
s.Hooks = append(s.Hooks, &Position{
Index: j,
MixedIn: true,
MixinIndex: i,
})
}
}
return nil
}
// loadFields loads field to schema from ent.Interface.
func (s *Schema) loadFields(schema ent.Interface) error {
fields, err := safeFields(schema)
if err != nil {
return err
@@ -223,7 +261,12 @@ func (s *Schema) loadHooks(schema ent.Interface) error {
if err != nil {
return err
}
s.Hooks = len(hooks)
for i := range hooks {
s.Hooks = append(s.Hooks, &Position{
Index: i,
MixedIn: false,
})
}
return nil
}
@@ -265,7 +308,7 @@ func safeFields(fd interface{ Fields() []ent.Field }) (fields []ent.Field, err e
}
// safeEdges wraps the schema.Edges method with recover to ensure no panics in marshaling.
func safeEdges(schema ent.Interface) (edges []ent.Edge, err error) {
func safeEdges(schema interface{ Edges() []ent.Edge }) (edges []ent.Edge, err error) {
defer func() {
if v := recover(); v != nil {
err = fmt.Errorf("schema.Edges panics: %v", v)
@@ -276,7 +319,7 @@ func safeEdges(schema ent.Interface) (edges []ent.Edge, err error) {
}
// safeIndexes wraps the schema.Indexes method with recover to ensure no panics in marshaling.
func safeIndexes(schema ent.Interface) (indexes []ent.Index, err error) {
func safeIndexes(schema interface{ Indexes() []ent.Index }) (indexes []ent.Index, err error) {
defer func() {
if v := recover(); v != nil {
err = fmt.Errorf("schema.Indexes panics: %v", v)
@@ -298,7 +341,7 @@ func safeMixin(schema ent.Interface) (mixin []ent.Mixin, err error) {
}
// safeHooks wraps the schema.Hooks method with recover to ensure no panics in marshaling.
func safeHooks(schema ent.Interface) (hooks []ent.Hook, err error) {
func safeHooks(schema interface{ Hooks() []ent.Hook }) (hooks []ent.Hook, err error) {
defer func() {
if v := recover(); v != nil {
err = fmt.Errorf("schema.Hooks panics: %v", v)

View File

@@ -14,6 +14,7 @@ import (
"github.com/facebookincubator/ent/schema/edge"
"github.com/facebookincubator/ent/schema/field"
"github.com/facebookincubator/ent/schema/index"
"github.com/facebookincubator/ent/schema/mixin"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
@@ -215,7 +216,9 @@ func TestMarshalDefaults(t *testing.T) {
require.True(t, schema.Fields[4].UpdateDefault)
}
type TimeMixin struct{}
type TimeMixin struct {
mixin.Schema
}
func (TimeMixin) Fields() []ent.Field {
return []ent.Field{
@@ -228,14 +231,37 @@ func (TimeMixin) Fields() []ent.Field {
}
}
type Mixin struct{}
type HooksMixin struct {
mixin.Schema
}
func (Mixin) Fields() []ent.Field {
func (HooksMixin) Fields() []ent.Field {
return []ent.Field{
field.String("boring"),
}
}
func (HooksMixin) Edges() []ent.Edge {
return []ent.Edge{
edge.To("user", User.Type).
Unique(),
}
}
func (HooksMixin) Indexes() []ent.Index {
return []ent.Index{
index.Fields("boring").
Edges("user"),
}
}
func (HooksMixin) Hooks() []ent.Hook {
return []ent.Hook{
func(ent.Mutator) ent.Mutator { return nil },
func(ent.Mutator) ent.Mutator { return nil },
}
}
type WithMixin struct {
ent.Schema
}
@@ -243,7 +269,7 @@ type WithMixin struct {
func (WithMixin) Mixin() []ent.Mixin {
return []ent.Mixin{
TimeMixin{},
Mixin{},
HooksMixin{},
}
}
@@ -253,6 +279,26 @@ func (WithMixin) Fields() []ent.Field {
}
}
func (WithMixin) Edges() []ent.Edge {
return []ent.Edge{
edge.To("owner", User.Type),
}
}
func (WithMixin) Indexes() []ent.Index {
return []ent.Index{
index.Fields("field").
Edges("owner").
Unique(),
}
}
func (WithMixin) Hooks() []ent.Hook {
return []ent.Hook{
func(ent.Mutator) ent.Mutator { return nil },
}
}
func TestMarshalMixin(t *testing.T) {
d := WithMixin{}
buf, err := MarshalSchema(d)
@@ -262,29 +308,67 @@ func TestMarshalMixin(t *testing.T) {
err = json.Unmarshal(buf, schema)
require.NoError(t, err)
require.Equal(t, "WithMixin", schema.Name)
require.Equal(t, "created_at", schema.Fields[0].Name)
require.True(t, schema.Fields[0].Default)
require.True(t, schema.Fields[0].Position.MixedIn)
require.Equal(t, 0, schema.Fields[0].Position.MixinIndex)
require.Equal(t, 0, schema.Fields[0].Position.Index)
t.Run("Fields", func(t *testing.T) {
require.Equal(t, "WithMixin", schema.Name)
require.Equal(t, "created_at", schema.Fields[0].Name)
require.True(t, schema.Fields[0].Default)
require.True(t, schema.Fields[0].Position.MixedIn)
require.Equal(t, 0, schema.Fields[0].Position.MixinIndex)
require.Equal(t, 0, schema.Fields[0].Position.Index)
require.Equal(t, "updated_at", schema.Fields[1].Name)
require.True(t, schema.Fields[1].Default)
require.True(t, schema.Fields[1].UpdateDefault)
require.True(t, schema.Fields[1].Position.MixedIn)
require.Equal(t, 0, schema.Fields[1].Position.MixinIndex)
require.Equal(t, 1, schema.Fields[1].Position.Index)
require.Equal(t, "updated_at", schema.Fields[1].Name)
require.True(t, schema.Fields[1].Default)
require.True(t, schema.Fields[1].UpdateDefault)
require.True(t, schema.Fields[1].Position.MixedIn)
require.Equal(t, 0, schema.Fields[1].Position.MixinIndex)
require.Equal(t, 1, schema.Fields[1].Position.Index)
require.Equal(t, "boring", schema.Fields[2].Name)
require.False(t, schema.Fields[2].Default)
require.False(t, schema.Fields[2].UpdateDefault)
require.True(t, schema.Fields[2].Position.MixedIn)
require.Equal(t, 1, schema.Fields[2].Position.MixinIndex)
require.Equal(t, 0, schema.Fields[2].Position.Index)
require.Equal(t, "boring", schema.Fields[2].Name)
require.False(t, schema.Fields[2].Default)
require.False(t, schema.Fields[2].UpdateDefault)
require.True(t, schema.Fields[2].Position.MixedIn)
require.Equal(t, 1, schema.Fields[2].Position.MixinIndex)
require.Equal(t, 0, schema.Fields[2].Position.Index)
require.Equal(t, "field", schema.Fields[3].Name)
require.False(t, schema.Fields[3].Default)
require.False(t, schema.Fields[3].Position.MixedIn)
require.Equal(t, 0, schema.Fields[3].Position.Index)
require.Equal(t, "field", schema.Fields[3].Name)
require.False(t, schema.Fields[3].Default)
require.False(t, schema.Fields[3].Position.MixedIn)
require.Equal(t, 0, schema.Fields[3].Position.Index)
})
t.Run("Hooks", func(t *testing.T) {
require.True(t, schema.Hooks[0].MixedIn)
require.True(t, schema.Hooks[1].MixedIn)
require.Equal(t, 1, schema.Hooks[0].MixinIndex)
require.Equal(t, 1, schema.Hooks[1].MixinIndex)
require.Equal(t, 0, schema.Hooks[0].Index)
require.Equal(t, 1, schema.Hooks[1].Index)
require.False(t, schema.Hooks[2].MixedIn)
require.Equal(t, 0, schema.Hooks[2].Index)
require.Equal(t, 0, schema.Hooks[2].MixinIndex)
})
t.Run("Edges", func(t *testing.T) {
require.Len(t, schema.Edges, 2)
require.Equal(t, "user", schema.Edges[0].Name)
require.Equal(t, "User", schema.Edges[0].Type)
require.True(t, schema.Edges[0].Unique)
require.Equal(t, "owner", schema.Edges[1].Name)
require.Equal(t, "User", schema.Edges[1].Type)
require.False(t, schema.Edges[1].Unique)
})
t.Run("Indexes", func(t *testing.T) {
require.Len(t, schema.Indexes, 2)
require.Equal(t, []string{"boring"}, schema.Indexes[0].Fields)
require.Equal(t, []string{"user"}, schema.Indexes[0].Edges)
require.False(t, schema.Indexes[0].Unique)
require.Equal(t, []string{"field"}, schema.Indexes[1].Fields)
require.Equal(t, []string{"owner"}, schema.Indexes[1].Edges)
require.True(t, schema.Indexes[1].Unique)
})
}