entc/gen: add indexes, edges and hooks to mixin (#431)

This commit is contained in:
Ariel Mashraki
2020-04-20 13:40:56 +03:00
committed by GitHub
parent cec1dd1edf
commit 1c49159d18
22 changed files with 454 additions and 249 deletions

File diff suppressed because one or more lines are too long

View File

@@ -12,17 +12,18 @@ import (
"github.com/facebookincubator/ent"
"github.com/facebookincubator/ent/schema/edge"
"github.com/facebookincubator/ent/schema/field"
"github.com/facebookincubator/ent/schema/index"
)
// Schema represents an ent.Schema that was loaded from a complied user package.
type Schema struct {
Name string `json:"name,omitempty"`
Config ent.Config `json:"config,omitempty"`
Edges []*Edge `json:"edges,omitempty"`
Fields []*Field `json:"fields,omitempty"`
Indexes []*Index `json:"indexes,omitempty"`
Hooks int `json:"hooks,omitempty"`
Policy bool `json:"policy,omitempty"`
Name string `json:"name,omitempty"`
Config ent.Config `json:"config,omitempty"`
Edges []*Edge `json:"edges,omitempty"`
Fields []*Field `json:"fields,omitempty"`
Indexes []*Index `json:"indexes,omitempty"`
Hooks []*Position `json:"hooks,omitempty"`
Policy bool `json:"policy,omitempty"`
}
// Position describes a field position in the schema.
@@ -89,7 +90,7 @@ func NewEdge(ed *edge.Descriptor) *Edge {
return ne
}
// NewField creates an loaded field from edge descriptor.
// NewField creates an loaded field from field descriptor.
func NewField(fd *field.Descriptor) (*Field, error) {
sf := &Field{
Name: fd.Name,
@@ -126,6 +127,16 @@ func NewField(fd *field.Descriptor) (*Field, error) {
return sf, nil
}
// NewIndex creates an loaded index from index descriptor.
func NewIndex(idx *index.Descriptor) *Index {
return &Index{
Edges: idx.Edges,
Fields: idx.Fields,
Unique: idx.Unique,
StorageKey: idx.StorageKey,
}
}
// MarshalSchema encode the ent.Schema interface into a JSON
// that can be decoded into the Schema object object.
func MarshalSchema(schema ent.Interface) (b []byte, err error) {
@@ -133,6 +144,9 @@ func MarshalSchema(schema ent.Interface) (b []byte, err error) {
Config: schema.Config(),
Name: indirect(reflect.TypeOf(schema)).Name(),
}
if err := s.loadMixin(schema); err != nil {
return nil, fmt.Errorf("schema %q: %v", s.Name, err)
}
if err := s.loadFields(schema); err != nil {
return nil, fmt.Errorf("schema %q: %v", s.Name, err)
}
@@ -148,13 +162,7 @@ func MarshalSchema(schema ent.Interface) (b []byte, err error) {
return nil, fmt.Errorf("schema %q: %v", s.Name, err)
}
for _, idx := range indexes {
idx := idx.Descriptor()
s.Indexes = append(s.Indexes, &Index{
Edges: idx.Edges,
Fields: idx.Fields,
Unique: idx.Unique,
StorageKey: idx.StorageKey,
})
s.Indexes = append(s.Indexes, NewIndex(idx.Descriptor()))
}
if err := s.loadHooks(schema); err != nil {
return nil, fmt.Errorf("schema %q: %v", s.Name, err)
@@ -179,8 +187,8 @@ func UnmarshalSchema(buf []byte) (*Schema, error) {
return s, nil
}
// loadFields loads field to schema from ent.Interface.
func (s *Schema) loadFields(schema ent.Interface) error {
// loadMixin loads mixin to schema from ent.Interface.
func (s *Schema) loadMixin(schema ent.Interface) error {
mixin, err := safeMixin(schema)
if err != nil {
return err
@@ -202,7 +210,37 @@ func (s *Schema) loadFields(schema ent.Interface) error {
}
s.Fields = append(s.Fields, sf)
}
edges, err := safeEdges(mx)
if err != nil {
return err
}
for _, e := range edges {
s.Edges = append(s.Edges, NewEdge(e.Descriptor()))
}
indexes, err := safeIndexes(mx)
if err != nil {
return err
}
for _, idx := range indexes {
s.Indexes = append(s.Indexes, NewIndex(idx.Descriptor()))
}
hooks, err := safeHooks(mx)
if err != nil {
return err
}
for j := range hooks {
s.Hooks = append(s.Hooks, &Position{
Index: j,
MixedIn: true,
MixinIndex: i,
})
}
}
return nil
}
// loadFields loads field to schema from ent.Interface.
func (s *Schema) loadFields(schema ent.Interface) error {
fields, err := safeFields(schema)
if err != nil {
return err
@@ -223,7 +261,12 @@ func (s *Schema) loadHooks(schema ent.Interface) error {
if err != nil {
return err
}
s.Hooks = len(hooks)
for i := range hooks {
s.Hooks = append(s.Hooks, &Position{
Index: i,
MixedIn: false,
})
}
return nil
}
@@ -265,7 +308,7 @@ func safeFields(fd interface{ Fields() []ent.Field }) (fields []ent.Field, err e
}
// safeEdges wraps the schema.Edges method with recover to ensure no panics in marshaling.
func safeEdges(schema ent.Interface) (edges []ent.Edge, err error) {
func safeEdges(schema interface{ Edges() []ent.Edge }) (edges []ent.Edge, err error) {
defer func() {
if v := recover(); v != nil {
err = fmt.Errorf("schema.Edges panics: %v", v)
@@ -276,7 +319,7 @@ func safeEdges(schema ent.Interface) (edges []ent.Edge, err error) {
}
// safeIndexes wraps the schema.Indexes method with recover to ensure no panics in marshaling.
func safeIndexes(schema ent.Interface) (indexes []ent.Index, err error) {
func safeIndexes(schema interface{ Indexes() []ent.Index }) (indexes []ent.Index, err error) {
defer func() {
if v := recover(); v != nil {
err = fmt.Errorf("schema.Indexes panics: %v", v)
@@ -298,7 +341,7 @@ func safeMixin(schema ent.Interface) (mixin []ent.Mixin, err error) {
}
// safeHooks wraps the schema.Hooks method with recover to ensure no panics in marshaling.
func safeHooks(schema ent.Interface) (hooks []ent.Hook, err error) {
func safeHooks(schema interface{ Hooks() []ent.Hook }) (hooks []ent.Hook, err error) {
defer func() {
if v := recover(); v != nil {
err = fmt.Errorf("schema.Hooks panics: %v", v)

View File

@@ -14,6 +14,7 @@ import (
"github.com/facebookincubator/ent/schema/edge"
"github.com/facebookincubator/ent/schema/field"
"github.com/facebookincubator/ent/schema/index"
"github.com/facebookincubator/ent/schema/mixin"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
@@ -215,7 +216,9 @@ func TestMarshalDefaults(t *testing.T) {
require.True(t, schema.Fields[4].UpdateDefault)
}
type TimeMixin struct{}
type TimeMixin struct {
mixin.Schema
}
func (TimeMixin) Fields() []ent.Field {
return []ent.Field{
@@ -228,14 +231,37 @@ func (TimeMixin) Fields() []ent.Field {
}
}
type Mixin struct{}
type HooksMixin struct {
mixin.Schema
}
func (Mixin) Fields() []ent.Field {
func (HooksMixin) Fields() []ent.Field {
return []ent.Field{
field.String("boring"),
}
}
func (HooksMixin) Edges() []ent.Edge {
return []ent.Edge{
edge.To("user", User.Type).
Unique(),
}
}
func (HooksMixin) Indexes() []ent.Index {
return []ent.Index{
index.Fields("boring").
Edges("user"),
}
}
func (HooksMixin) Hooks() []ent.Hook {
return []ent.Hook{
func(ent.Mutator) ent.Mutator { return nil },
func(ent.Mutator) ent.Mutator { return nil },
}
}
type WithMixin struct {
ent.Schema
}
@@ -243,7 +269,7 @@ type WithMixin struct {
func (WithMixin) Mixin() []ent.Mixin {
return []ent.Mixin{
TimeMixin{},
Mixin{},
HooksMixin{},
}
}
@@ -253,6 +279,26 @@ func (WithMixin) Fields() []ent.Field {
}
}
func (WithMixin) Edges() []ent.Edge {
return []ent.Edge{
edge.To("owner", User.Type),
}
}
func (WithMixin) Indexes() []ent.Index {
return []ent.Index{
index.Fields("field").
Edges("owner").
Unique(),
}
}
func (WithMixin) Hooks() []ent.Hook {
return []ent.Hook{
func(ent.Mutator) ent.Mutator { return nil },
}
}
func TestMarshalMixin(t *testing.T) {
d := WithMixin{}
buf, err := MarshalSchema(d)
@@ -262,29 +308,67 @@ func TestMarshalMixin(t *testing.T) {
err = json.Unmarshal(buf, schema)
require.NoError(t, err)
require.Equal(t, "WithMixin", schema.Name)
require.Equal(t, "created_at", schema.Fields[0].Name)
require.True(t, schema.Fields[0].Default)
require.True(t, schema.Fields[0].Position.MixedIn)
require.Equal(t, 0, schema.Fields[0].Position.MixinIndex)
require.Equal(t, 0, schema.Fields[0].Position.Index)
t.Run("Fields", func(t *testing.T) {
require.Equal(t, "WithMixin", schema.Name)
require.Equal(t, "created_at", schema.Fields[0].Name)
require.True(t, schema.Fields[0].Default)
require.True(t, schema.Fields[0].Position.MixedIn)
require.Equal(t, 0, schema.Fields[0].Position.MixinIndex)
require.Equal(t, 0, schema.Fields[0].Position.Index)
require.Equal(t, "updated_at", schema.Fields[1].Name)
require.True(t, schema.Fields[1].Default)
require.True(t, schema.Fields[1].UpdateDefault)
require.True(t, schema.Fields[1].Position.MixedIn)
require.Equal(t, 0, schema.Fields[1].Position.MixinIndex)
require.Equal(t, 1, schema.Fields[1].Position.Index)
require.Equal(t, "updated_at", schema.Fields[1].Name)
require.True(t, schema.Fields[1].Default)
require.True(t, schema.Fields[1].UpdateDefault)
require.True(t, schema.Fields[1].Position.MixedIn)
require.Equal(t, 0, schema.Fields[1].Position.MixinIndex)
require.Equal(t, 1, schema.Fields[1].Position.Index)
require.Equal(t, "boring", schema.Fields[2].Name)
require.False(t, schema.Fields[2].Default)
require.False(t, schema.Fields[2].UpdateDefault)
require.True(t, schema.Fields[2].Position.MixedIn)
require.Equal(t, 1, schema.Fields[2].Position.MixinIndex)
require.Equal(t, 0, schema.Fields[2].Position.Index)
require.Equal(t, "boring", schema.Fields[2].Name)
require.False(t, schema.Fields[2].Default)
require.False(t, schema.Fields[2].UpdateDefault)
require.True(t, schema.Fields[2].Position.MixedIn)
require.Equal(t, 1, schema.Fields[2].Position.MixinIndex)
require.Equal(t, 0, schema.Fields[2].Position.Index)
require.Equal(t, "field", schema.Fields[3].Name)
require.False(t, schema.Fields[3].Default)
require.False(t, schema.Fields[3].Position.MixedIn)
require.Equal(t, 0, schema.Fields[3].Position.Index)
require.Equal(t, "field", schema.Fields[3].Name)
require.False(t, schema.Fields[3].Default)
require.False(t, schema.Fields[3].Position.MixedIn)
require.Equal(t, 0, schema.Fields[3].Position.Index)
})
t.Run("Hooks", func(t *testing.T) {
require.True(t, schema.Hooks[0].MixedIn)
require.True(t, schema.Hooks[1].MixedIn)
require.Equal(t, 1, schema.Hooks[0].MixinIndex)
require.Equal(t, 1, schema.Hooks[1].MixinIndex)
require.Equal(t, 0, schema.Hooks[0].Index)
require.Equal(t, 1, schema.Hooks[1].Index)
require.False(t, schema.Hooks[2].MixedIn)
require.Equal(t, 0, schema.Hooks[2].Index)
require.Equal(t, 0, schema.Hooks[2].MixinIndex)
})
t.Run("Edges", func(t *testing.T) {
require.Len(t, schema.Edges, 2)
require.Equal(t, "user", schema.Edges[0].Name)
require.Equal(t, "User", schema.Edges[0].Type)
require.True(t, schema.Edges[0].Unique)
require.Equal(t, "owner", schema.Edges[1].Name)
require.Equal(t, "User", schema.Edges[1].Type)
require.False(t, schema.Edges[1].Unique)
})
t.Run("Indexes", func(t *testing.T) {
require.Len(t, schema.Indexes, 2)
require.Equal(t, []string{"boring"}, schema.Indexes[0].Fields)
require.Equal(t, []string{"user"}, schema.Indexes[0].Edges)
require.False(t, schema.Indexes[0].Unique)
require.Equal(t, []string{"field"}, schema.Indexes[1].Fields)
require.Equal(t, []string{"owner"}, schema.Indexes[1].Edges)
require.True(t, schema.Indexes[1].Unique)
})
}