entc/gen: simplify policy execution (#822)

This commit is contained in:
Ariel Mashraki
2020-10-06 12:16:31 +03:00
committed by GitHub
parent d5ae1b018e
commit 3f3debbe97
11 changed files with 81 additions and 57 deletions

File diff suppressed because one or more lines are too long

View File

@@ -352,10 +352,8 @@ func ({{ $receiver }} *{{ $builder }}) prepareQuery(ctx context.Context) error {
{{ $receiver }}.{{ $.Storage }} = prev
}
{{- if $.NumPolicy }}
for _, policy := range {{ $.Package }}.Policy {
if err := policy.EvalQuery(ctx, {{ $receiver }}); err != nil {
return err
}
if err := {{ $.Package }}.Policy.EvalQuery(ctx, {{ $receiver }}); err != nil {
return err
}
{{- end }}
return nil

View File

@@ -66,7 +66,7 @@ const (
Hooks [{{ $numHooks }}]ent.Hook
{{- end }}
{{- if $.NumPolicy }}
Policy [{{ $.NumPolicy }}]ent.Policy
Policy ent.Policy
{{- end }}
{{- $fields := $.Fields }}{{ if $.ID.UserDefined }}{{ $fields = append $fields $.ID }}{{ end }}
{{- range $f := $fields }}

View File

@@ -90,14 +90,12 @@ func init() {
{{ $pkg }}Mixin := {{ $schema }}.{{ $n.Name }}{}.Mixin()
{{- end }}
{{- with $policies := $n.PolicyPositions }}
{{- /* policies defined in schema and mixins. */}}
setPolicies({{ $pkg }}.Policy[:], {{ range $idx := $n.MixedInPolicies }}{{ $pkg }}Mixin[{{ $idx }}],{{ end }}{{ $schema }}.{{ $n.Name }}{})
{{- /* policies defined in schema and mixins. */}}
{{ $pkg }}.Policy = newPolicy({{ range $idx := $n.MixedInPolicies }}{{ $pkg }}Mixin[{{ $idx }}],{{ end }}{{ $schema }}.{{ $n.Name }}{})
{{ $pkg }}.Hooks[0] = func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
for _, policy := range {{ $pkg }}.Policy {
if err := policy.EvalMutation(ctx, m); err != nil {
return nil, err
}
if err := {{ $pkg }}.Policy.EvalMutation(ctx, m); err != nil {
return nil, err
}
return next.Mutate(ctx, m)
})
@@ -195,15 +193,35 @@ func init() {
{{ $hasPolicy := false }}{{ range $n := $.Nodes }}{{ if $n.NumPolicy }}{{ $hasPolicy = true }}{{ end }}{{ end }}
{{ if $hasPolicy }}
// setPolicies sets the policies from the given mixin and ent.Schema.
func setPolicies(policies []ent.Policy, ps ...interface{ Policy() ent.Policy }) {
var i int
type policies []ent.Policy
// newPolicy creates a policy from list of mixin and ent.Schema.
func newPolicy(ps ...interface{ Policy() ent.Policy }) ent.Policy {
pocs := make(policies, 0, len(ps))
for _, p := range ps {
if policy := p.Policy(); policy != nil {
policies[i] = policy
i++
pocs = append(pocs, policy)
}
}
return pocs
}
func (p policies) EvalMutation(ctx context.Context, m ent.Mutation) error {
for i := range p {
if err := p[i].EvalMutation(ctx, m); err != nil {
return err
}
}
return nil
}
func (p policies) EvalQuery(ctx context.Context, q ent.Query) error {
for i := range p {
if err := p[i].EvalQuery(ctx, q); err != nil {
return err
}
}
return nil
}
{{ end }}

View File

@@ -22,13 +22,11 @@ import (
// to their package variables.
func init() {
taskMixin := schema.Task{}.Mixin()
setPolicies(task.Policy[:], taskMixin[0], taskMixin[1], schema.Task{})
task.Policy = newPolicy(taskMixin[0], taskMixin[1], schema.Task{})
task.Hooks[0] = func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
for _, policy := range task.Policy {
if err := policy.EvalMutation(ctx, m); err != nil {
return nil, err
}
if err := task.Policy.EvalMutation(ctx, m); err != nil {
return nil, err
}
return next.Mutate(ctx, m)
})
@@ -43,13 +41,11 @@ func init() {
// task.TitleValidator is a validator for the "title" field. It is called by the builders before save.
task.TitleValidator = taskDescTitle.Validators[0].(func(string) error)
teamMixin := schema.Team{}.Mixin()
setPolicies(team.Policy[:], teamMixin[0], schema.Team{})
team.Policy = newPolicy(teamMixin[0], schema.Team{})
team.Hooks[0] = func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
for _, policy := range team.Policy {
if err := policy.EvalMutation(ctx, m); err != nil {
return nil, err
}
if err := team.Policy.EvalMutation(ctx, m); err != nil {
return nil, err
}
return next.Mutate(ctx, m)
})
@@ -61,13 +57,11 @@ func init() {
// team.NameValidator is a validator for the "name" field. It is called by the builders before save.
team.NameValidator = teamDescName.Validators[0].(func(string) error)
userMixin := schema.User{}.Mixin()
setPolicies(user.Policy[:], userMixin[0], userMixin[1], schema.User{})
user.Policy = newPolicy(userMixin[0], userMixin[1], schema.User{})
user.Hooks[0] = func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
for _, policy := range user.Policy {
if err := policy.EvalMutation(ctx, m); err != nil {
return nil, err
}
if err := user.Policy.EvalMutation(ctx, m); err != nil {
return nil, err
}
return next.Mutate(ctx, m)
})
@@ -80,15 +74,35 @@ func init() {
user.NameValidator = userDescName.Validators[0].(func(string) error)
}
// setPolicies sets the policies from the given mixin and ent.Schema.
func setPolicies(policies []ent.Policy, ps ...interface{ Policy() ent.Policy }) {
var i int
type policies []ent.Policy
// newPolicy creates a policy from list of mixin and ent.Schema.
func newPolicy(ps ...interface{ Policy() ent.Policy }) ent.Policy {
pocs := make(policies, 0, len(ps))
for _, p := range ps {
if policy := p.Policy(); policy != nil {
policies[i] = policy
i++
pocs = append(pocs, policy)
}
}
return pocs
}
func (p policies) EvalMutation(ctx context.Context, m ent.Mutation) error {
for i := range p {
if err := p[i].EvalMutation(ctx, m); err != nil {
return err
}
}
return nil
}
func (p policies) EvalQuery(ctx context.Context, q ent.Query) error {
for i := range p {
if err := p[i].EvalQuery(ctx, q); err != nil {
return err
}
}
return nil
}
const (

View File

@@ -87,7 +87,7 @@ func ValidColumn(column string) bool {
//
var (
Hooks [2]ent.Hook
Policy [3]ent.Policy
Policy ent.Policy
// TitleValidator is a validator for the "title" field. It is called by the builders before save.
TitleValidator func(string) error
)

View File

@@ -367,10 +367,8 @@ func (tq *TaskQuery) prepareQuery(ctx context.Context) error {
}
tq.sql = prev
}
for _, policy := range task.Policy {
if err := policy.EvalQuery(ctx, tq); err != nil {
return err
}
if err := task.Policy.EvalQuery(ctx, tq); err != nil {
return err
}
return nil
}

View File

@@ -70,7 +70,7 @@ func ValidColumn(column string) bool {
//
var (
Hooks [1]ent.Hook
Policy [2]ent.Policy
Policy ent.Policy
// NameValidator is a validator for the "name" field. It is called by the builders before save.
NameValidator func(string) error
)

View File

@@ -366,10 +366,8 @@ func (tq *TeamQuery) prepareQuery(ctx context.Context) error {
}
tq.sql = prev
}
for _, policy := range team.Policy {
if err := policy.EvalQuery(ctx, tq); err != nil {
return err
}
if err := team.Policy.EvalQuery(ctx, tq); err != nil {
return err
}
return nil
}

View File

@@ -72,7 +72,7 @@ func ValidColumn(column string) bool {
//
var (
Hooks [1]ent.Hook
Policy [3]ent.Policy
Policy ent.Policy
// NameValidator is a validator for the "name" field. It is called by the builders before save.
NameValidator func(string) error
)

View File

@@ -366,10 +366,8 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error {
}
uq.sql = prev
}
for _, policy := range user.Policy {
if err := policy.EvalQuery(ctx, uq); err != nil {
return err
}
if err := user.Policy.EvalQuery(ctx, uq); err != nil {
return err
}
return nil
}