Mirror of https://github.com/ent/ent.git, synced 2026-05-22 09:31:45 +03:00 (518 lines, 13 KiB, Go).
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package load
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"math"
|
|
"net/http"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"entgo.io/ent"
|
|
"entgo.io/ent/schema"
|
|
"entgo.io/ent/schema/edge"
|
|
"entgo.io/ent/schema/field"
|
|
"entgo.io/ent/schema/index"
|
|
"entgo.io/ent/schema/mixin"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type OrderConfig struct {
|
|
FieldName string
|
|
}
|
|
|
|
func (OrderConfig) Name() string {
|
|
return "order_config"
|
|
}
|
|
|
|
func (o OrderConfig) Merge(ant schema.Annotation) schema.Annotation {
|
|
o.FieldName = ant.(OrderConfig).FieldName
|
|
return o
|
|
}
|
|
|
|
// IDConfig is a plain (non-mergeable) test annotation contributed by
// AnnotationMixin.
type IDConfig struct{ TagName string }

// Name returns "id_config", the key that identifies this annotation.
func (IDConfig) Name() string {
	const key = "id_config"
	return key
}
|
|
|
|
type PartialIndex struct {
|
|
WhereClause string
|
|
}
|
|
|
|
func (PartialIndex) Name() string {
|
|
return "partial_index"
|
|
}
|
|
|
|
func (p PartialIndex) Merge(ant schema.Annotation) schema.Annotation {
|
|
p.WhereClause = ant.(PartialIndex).WhereClause
|
|
return p
|
|
}
|
|
|
|
type AnnotationMixin struct {
|
|
mixin.Schema
|
|
}
|
|
|
|
func (AnnotationMixin) Annotations() []schema.Annotation {
|
|
return []schema.Annotation{
|
|
IDConfig{TagName: "id tag"},
|
|
OrderConfig{FieldName: "mixin annotations"},
|
|
}
|
|
}
|
|
|
|
type User struct {
|
|
ent.Schema
|
|
}
|
|
|
|
func (User) Mixin() []ent.Mixin {
|
|
return []ent.Mixin{
|
|
AnnotationMixin{},
|
|
}
|
|
}
|
|
|
|
func (User) Annotations() []schema.Annotation {
|
|
return []schema.Annotation{
|
|
OrderConfig{FieldName: "type annotations"},
|
|
}
|
|
}
|
|
|
|
func (User) Fields() []ent.Field {
|
|
return []ent.Field{
|
|
field.Int("age").
|
|
Comment("some comment"),
|
|
field.String("name").
|
|
Default("unknown").
|
|
Annotations(&OrderConfig{FieldName: "name"}),
|
|
field.String("nillable").
|
|
Nillable(),
|
|
field.String("optional").
|
|
Optional(),
|
|
field.Enum("state").
|
|
Values("on", "off").
|
|
Optional(),
|
|
field.String("sensitive").
|
|
Sensitive(),
|
|
field.Time("creation_time").
|
|
Default(time.Now),
|
|
field.UUID("uuid", uuid.UUID{}).
|
|
Default(uuid.New),
|
|
field.Int("parent_id").
|
|
Optional(),
|
|
}
|
|
}
|
|
|
|
func (User) Edges() []ent.Edge {
|
|
return []ent.Edge{
|
|
edge.To("groups", Group.Type).
|
|
Annotations(&OrderConfig{FieldName: "name"}),
|
|
edge.To("parent", User.Type).
|
|
Unique().
|
|
Field("parent_id").
|
|
StorageKey(edge.Column("parent_id")).
|
|
From("children"),
|
|
edge.To("following", User.Type).
|
|
Annotations(&OrderConfig{FieldName: "following"}).
|
|
From("followers").
|
|
Annotations(&OrderConfig{FieldName: "followers"}),
|
|
}
|
|
}
|
|
|
|
func (User) Indexes() []ent.Index {
|
|
return []ent.Index{
|
|
index.Fields("name", "address").
|
|
Unique(),
|
|
index.Fields("name").
|
|
Edges("parent").
|
|
StorageKey("user_parent_name").
|
|
Annotations(&PartialIndex{
|
|
WhereClause: "age > 20",
|
|
}).
|
|
Unique(),
|
|
}
|
|
}
|
|
|
|
type Group struct{ ent.Schema }
|
|
|
|
func (Group) Fields() []ent.Field { return nil }
|
|
|
|
func (Group) Edges() []ent.Edge {
|
|
return []ent.Edge{
|
|
edge.From("users", User.Type),
|
|
}
|
|
}
|
|
|
|
func TestMarshalSchema(t *testing.T) {
|
|
for _, u := range []ent.Interface{User{}, &User{}} {
|
|
buf, err := MarshalSchema(u)
|
|
require.NoError(t, err)
|
|
|
|
schema, err := UnmarshalSchema(buf)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "User", schema.Name)
|
|
require.Len(t, schema.Annotations, 2)
|
|
ant := schema.Annotations["order_config"].(map[string]interface{})
|
|
require.Equal(t, ant["FieldName"], "type annotations")
|
|
|
|
require.Len(t, schema.Fields, 9)
|
|
require.Equal(t, "age", schema.Fields[0].Name)
|
|
require.Equal(t, field.TypeInt, schema.Fields[0].Info.Type)
|
|
|
|
require.Equal(t, "name", schema.Fields[1].Name)
|
|
require.Equal(t, field.TypeString, schema.Fields[1].Info.Type)
|
|
require.Equal(t, "unknown", schema.Fields[1].DefaultValue)
|
|
require.NotEmpty(t, schema.Fields[1].Annotations)
|
|
ant = schema.Fields[1].Annotations["order_config"].(map[string]interface{})
|
|
require.Equal(t, ant["FieldName"], "name")
|
|
|
|
require.Equal(t, "nillable", schema.Fields[2].Name)
|
|
require.Equal(t, field.TypeString, schema.Fields[2].Info.Type)
|
|
require.True(t, schema.Fields[2].Nillable)
|
|
require.False(t, schema.Fields[2].Optional)
|
|
require.False(t, schema.Fields[2].Sensitive)
|
|
|
|
require.Equal(t, "optional", schema.Fields[3].Name)
|
|
require.Equal(t, field.TypeString, schema.Fields[3].Info.Type)
|
|
require.False(t, schema.Fields[3].Nillable)
|
|
require.True(t, schema.Fields[3].Optional)
|
|
|
|
require.Equal(t, "state", schema.Fields[4].Name)
|
|
require.Equal(t, field.TypeEnum, schema.Fields[4].Info.Type)
|
|
require.Equal(t, "on", schema.Fields[4].Enums[0].V)
|
|
require.Equal(t, "off", schema.Fields[4].Enums[1].V)
|
|
|
|
require.Equal(t, "sensitive", schema.Fields[5].Name)
|
|
require.Equal(t, field.TypeString, schema.Fields[5].Info.Type)
|
|
require.True(t, schema.Fields[5].Sensitive)
|
|
require.Equal(t, reflect.Invalid, schema.Fields[5].DefaultKind)
|
|
|
|
require.Equal(t, "creation_time", schema.Fields[6].Name)
|
|
require.Equal(t, field.TypeTime, schema.Fields[6].Info.Type)
|
|
require.Nil(t, schema.Fields[6].DefaultValue)
|
|
require.Equal(t, reflect.Func, schema.Fields[6].DefaultKind)
|
|
|
|
require.Equal(t, "uuid", schema.Fields[7].Name)
|
|
require.Equal(t, field.TypeUUID, schema.Fields[7].Info.Type)
|
|
require.True(t, schema.Fields[7].Default)
|
|
require.Equal(t, "github.com/google/uuid", schema.Fields[7].Info.PkgPath)
|
|
|
|
require.Equal(t, "parent_id", schema.Fields[8].Name)
|
|
require.Equal(t, field.TypeInt, schema.Fields[8].Info.Type)
|
|
require.True(t, schema.Fields[8].Optional)
|
|
|
|
require.Len(t, schema.Edges, 3)
|
|
require.Equal(t, "groups", schema.Edges[0].Name)
|
|
require.Equal(t, "Group", schema.Edges[0].Type)
|
|
require.False(t, schema.Edges[0].Inverse)
|
|
require.NotEmpty(t, schema.Edges[0].Annotations)
|
|
ant = schema.Edges[0].Annotations["order_config"].(map[string]interface{})
|
|
require.Equal(t, ant["FieldName"], "name")
|
|
|
|
require.Equal(t, "children", schema.Edges[1].Name)
|
|
require.Equal(t, "parent_id", schema.Edges[1].StorageKey.Columns[0])
|
|
require.Equal(t, "User", schema.Edges[1].Type)
|
|
require.True(t, schema.Edges[1].Inverse)
|
|
require.Equal(t, "parent", schema.Edges[1].Ref.Name)
|
|
require.True(t, schema.Edges[1].Ref.Unique)
|
|
require.Equal(t, "parent_id", schema.Edges[1].Ref.StorageKey.Columns[0])
|
|
|
|
ant = schema.Edges[2].Annotations["order_config"].(map[string]interface{})
|
|
require.Equal(t, ant["FieldName"], "followers")
|
|
ant = schema.Edges[2].Ref.Annotations["order_config"].(map[string]interface{})
|
|
require.Equal(t, ant["FieldName"], "following")
|
|
|
|
require.Equal(t, []string{"name", "address"}, schema.Indexes[0].Fields)
|
|
require.True(t, schema.Indexes[0].Unique)
|
|
require.Equal(t, []string{"name"}, schema.Indexes[1].Fields)
|
|
require.Equal(t, []string{"parent"}, schema.Indexes[1].Edges)
|
|
require.Equal(t, "user_parent_name", schema.Indexes[1].StorageKey)
|
|
require.True(t, schema.Indexes[1].Unique)
|
|
ant = schema.Indexes[1].Annotations["partial_index"].(map[string]interface{})
|
|
require.Equal(t, "age > 20", ant["WhereClause"])
|
|
|
|
require.Equal(t, "some comment", schema.Fields[0].Comment)
|
|
require.Empty(t, schema.Fields[1].Comment)
|
|
}
|
|
}
|
|
|
|
type InvalidEdge struct {
|
|
ent.Schema
|
|
}
|
|
|
|
// Edge panics because the edge declaration is invalid.
|
|
func (InvalidEdge) Edges() []ent.Edge {
|
|
return []ent.Edge{
|
|
edge.From("invalid", InvalidEdge{}.Type),
|
|
}
|
|
}
|
|
|
|
type InvalidUUID struct {
|
|
ent.Schema
|
|
}
|
|
|
|
func (InvalidUUID) Fields() []ent.Field {
|
|
return []ent.Field{
|
|
field.UUID("invalid", uuid.New()).
|
|
Default(time.Now),
|
|
}
|
|
}
|
|
|
|
func TestMarshalFails(t *testing.T) {
|
|
i1 := InvalidEdge{}
|
|
buf, err := MarshalSchema(i1)
|
|
require.Error(t, err)
|
|
require.Nil(t, buf)
|
|
|
|
i2 := InvalidUUID{}
|
|
buf, err = MarshalSchema(i2)
|
|
require.Nil(t, buf)
|
|
require.EqualError(t, err, `schema "InvalidUUID": field "invalid": expect type (func() uuid.UUID) for uuid default value`)
|
|
}
|
|
|
|
type WithDefaults struct {
|
|
ent.Schema
|
|
}
|
|
|
|
func (WithDefaults) Fields() []ent.Field {
|
|
return []ent.Field{
|
|
field.Int("int").
|
|
Default(1),
|
|
field.Float("float").
|
|
Default(math.Pi),
|
|
field.String("string").
|
|
Default("foo"),
|
|
field.Bool("string").
|
|
Default(true),
|
|
field.Time("updated_at").
|
|
UpdateDefault(time.Now),
|
|
// see issue #1146
|
|
field.Int("int_default_func").
|
|
DefaultFunc(func() int {
|
|
return 1e9
|
|
}),
|
|
field.Float("balance").
|
|
Default(0),
|
|
field.JSON("dirs", []http.Dir{}).
|
|
Default([]http.Dir{"/tmp"}),
|
|
}
|
|
}
|
|
|
|
func (WithDefaults) Edges() []ent.Edge {
|
|
return nil
|
|
}
|
|
|
|
func (WithDefaults) Indexes() []ent.Index {
|
|
return nil
|
|
}
|
|
|
|
func TestMarshalDefaults(t *testing.T) {
|
|
d := WithDefaults{}
|
|
buf, err := MarshalSchema(d)
|
|
require.NoError(t, err)
|
|
|
|
schema := &Schema{}
|
|
err = json.Unmarshal(buf, schema)
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, "WithDefaults", schema.Name)
|
|
require.True(t, schema.Fields[0].Default)
|
|
require.True(t, schema.Fields[1].Default)
|
|
require.True(t, schema.Fields[2].Default)
|
|
require.True(t, schema.Fields[3].Default)
|
|
require.False(t, schema.Fields[4].Default)
|
|
require.True(t, schema.Fields[4].UpdateDefault)
|
|
require.True(t, schema.Fields[5].Default)
|
|
require.Equal(t, schema.Fields[5].DefaultKind, reflect.Func)
|
|
require.True(t, schema.Fields[6].Default)
|
|
require.True(t, schema.Fields[7].Default)
|
|
}
|
|
|
|
type TimeMixin struct {
|
|
mixin.Schema
|
|
}
|
|
|
|
func (TimeMixin) Fields() []ent.Field {
|
|
return []ent.Field{
|
|
field.Time("created_at").
|
|
Immutable().
|
|
Default(time.Now),
|
|
field.Time("updated_at").
|
|
Default(time.Now).
|
|
UpdateDefault(time.Now),
|
|
}
|
|
}
|
|
|
|
type HooksMixin struct {
|
|
mixin.Schema
|
|
}
|
|
|
|
func (HooksMixin) Fields() []ent.Field {
|
|
return []ent.Field{
|
|
field.String("boring"),
|
|
}
|
|
}
|
|
|
|
func (HooksMixin) Edges() []ent.Edge {
|
|
return []ent.Edge{
|
|
edge.To("user", User.Type).
|
|
Unique(),
|
|
}
|
|
}
|
|
|
|
func (HooksMixin) Indexes() []ent.Index {
|
|
return []ent.Index{
|
|
index.Fields("boring").
|
|
Edges("user"),
|
|
}
|
|
}
|
|
|
|
func (HooksMixin) Hooks() []ent.Hook {
|
|
return []ent.Hook{
|
|
func(ent.Mutator) ent.Mutator { return nil },
|
|
func(ent.Mutator) ent.Mutator { return nil },
|
|
}
|
|
}
|
|
|
|
type BoringPolicy struct{}
|
|
|
|
func (BoringPolicy) EvalMutation(context.Context, ent.Mutation) error { return nil }
|
|
func (BoringPolicy) EvalQuery(context.Context, ent.Query) error { return nil }
|
|
|
|
type PrivacyMixin struct {
|
|
mixin.Schema
|
|
}
|
|
|
|
func (PrivacyMixin) Policy() ent.Policy {
|
|
return BoringPolicy{}
|
|
}
|
|
|
|
type WithMixin struct {
|
|
ent.Schema
|
|
}
|
|
|
|
func (WithMixin) Mixin() []ent.Mixin {
|
|
return []ent.Mixin{
|
|
TimeMixin{},
|
|
HooksMixin{},
|
|
PrivacyMixin{},
|
|
}
|
|
}
|
|
|
|
func (WithMixin) Fields() []ent.Field {
|
|
return []ent.Field{
|
|
field.Int("field"),
|
|
}
|
|
}
|
|
|
|
func (WithMixin) Edges() []ent.Edge {
|
|
return []ent.Edge{
|
|
edge.To("owner", User.Type),
|
|
}
|
|
}
|
|
|
|
func (WithMixin) Indexes() []ent.Index {
|
|
return []ent.Index{
|
|
index.Fields("field").
|
|
Edges("owner").
|
|
Unique(),
|
|
}
|
|
}
|
|
|
|
func (WithMixin) Hooks() []ent.Hook {
|
|
return []ent.Hook{
|
|
func(ent.Mutator) ent.Mutator { return nil },
|
|
}
|
|
}
|
|
|
|
func (WithMixin) Policy() ent.Policy {
|
|
return BoringPolicy{}
|
|
}
|
|
|
|
func TestMarshalMixin(t *testing.T) {
|
|
d := WithMixin{}
|
|
buf, err := MarshalSchema(d)
|
|
require.NoError(t, err)
|
|
|
|
schema := &Schema{}
|
|
err = json.Unmarshal(buf, schema)
|
|
require.NoError(t, err)
|
|
|
|
t.Run("Fields", func(t *testing.T) {
|
|
require.Equal(t, "WithMixin", schema.Name)
|
|
require.Equal(t, "created_at", schema.Fields[0].Name)
|
|
require.True(t, schema.Fields[0].Default)
|
|
require.True(t, schema.Fields[0].Position.MixedIn)
|
|
require.Equal(t, 0, schema.Fields[0].Position.MixinIndex)
|
|
require.Equal(t, 0, schema.Fields[0].Position.Index)
|
|
|
|
require.Equal(t, "updated_at", schema.Fields[1].Name)
|
|
require.True(t, schema.Fields[1].Default)
|
|
require.True(t, schema.Fields[1].UpdateDefault)
|
|
require.True(t, schema.Fields[1].Position.MixedIn)
|
|
require.Equal(t, 0, schema.Fields[1].Position.MixinIndex)
|
|
require.Equal(t, 1, schema.Fields[1].Position.Index)
|
|
|
|
require.Equal(t, "boring", schema.Fields[2].Name)
|
|
require.False(t, schema.Fields[2].Default)
|
|
require.False(t, schema.Fields[2].UpdateDefault)
|
|
require.True(t, schema.Fields[2].Position.MixedIn)
|
|
require.Equal(t, 1, schema.Fields[2].Position.MixinIndex)
|
|
require.Equal(t, 0, schema.Fields[2].Position.Index)
|
|
|
|
require.Equal(t, "field", schema.Fields[3].Name)
|
|
require.False(t, schema.Fields[3].Default)
|
|
require.False(t, schema.Fields[3].Position.MixedIn)
|
|
require.Equal(t, 0, schema.Fields[3].Position.Index)
|
|
})
|
|
|
|
t.Run("Hooks", func(t *testing.T) {
|
|
require.True(t, schema.Hooks[0].MixedIn)
|
|
require.True(t, schema.Hooks[1].MixedIn)
|
|
|
|
require.Equal(t, 1, schema.Hooks[0].MixinIndex)
|
|
require.Equal(t, 1, schema.Hooks[1].MixinIndex)
|
|
require.Equal(t, 0, schema.Hooks[0].Index)
|
|
require.Equal(t, 1, schema.Hooks[1].Index)
|
|
|
|
require.False(t, schema.Hooks[2].MixedIn)
|
|
require.Equal(t, 0, schema.Hooks[2].Index)
|
|
require.Equal(t, 0, schema.Hooks[2].MixinIndex)
|
|
})
|
|
|
|
t.Run("Edges", func(t *testing.T) {
|
|
require.Len(t, schema.Edges, 2)
|
|
require.Equal(t, "user", schema.Edges[0].Name)
|
|
require.Equal(t, "User", schema.Edges[0].Type)
|
|
require.True(t, schema.Edges[0].Unique)
|
|
|
|
require.Equal(t, "owner", schema.Edges[1].Name)
|
|
require.Equal(t, "User", schema.Edges[1].Type)
|
|
require.False(t, schema.Edges[1].Unique)
|
|
})
|
|
|
|
t.Run("Indexes", func(t *testing.T) {
|
|
require.Len(t, schema.Indexes, 2)
|
|
require.Equal(t, []string{"boring"}, schema.Indexes[0].Fields)
|
|
require.Equal(t, []string{"user"}, schema.Indexes[0].Edges)
|
|
require.False(t, schema.Indexes[0].Unique)
|
|
|
|
require.Equal(t, []string{"field"}, schema.Indexes[1].Fields)
|
|
require.Equal(t, []string{"owner"}, schema.Indexes[1].Edges)
|
|
require.True(t, schema.Indexes[1].Unique)
|
|
})
|
|
|
|
t.Run("Policy", func(t *testing.T) {
|
|
require.Len(t, schema.Policy, 2)
|
|
require.True(t, schema.Policy[0].MixedIn)
|
|
require.False(t, schema.Policy[1].MixedIn)
|
|
})
|
|
}
|