ent/privacy: initial privacy package (#836)

This commit is contained in:
Ariel Mashraki
2020-10-11 14:27:29 +03:00
committed by GitHub
parent 13b379d07c
commit dfc4dee8a5
8 changed files with 260 additions and 199 deletions

View File

@@ -8,24 +8,25 @@ package privacy
import (
"context"
"errors"
"fmt"
"github.com/facebook/ent/entc/integration/privacy/ent"
"github.com/facebook/ent/privacy"
)
var (
// Allow may be returned by rules to indicate that the policy
// evaluation should terminate with an allow decision.
Allow = errors.New("ent/privacy: allow rule")
Allow = privacy.Allow
// Deny may be returned by rules to indicate that the policy
// evaluation should terminate with an deny decision.
Deny = errors.New("ent/privacy: deny rule")
Deny = privacy.Deny
// Skip may be returned by rules to indicate that the policy
// evaluation should continue to the next rule.
Skip = errors.New("ent/privacy: skip rule")
Skip = privacy.Skip
)
// Allowf returns an formatted wrapped Allow decision.
@@ -43,52 +44,25 @@ func Skipf(format string, a ...interface{}) error {
return fmt.Errorf(format+": %w", append(a, Skip)...)
}
type decisionCtxKey struct{}
// DecisionContext creates a decision context.
// DecisionContext creates a new context from the given parent context with
// a policy decision attach to it.
func DecisionContext(parent context.Context, decision error) context.Context {
if decision == nil || errors.Is(decision, Skip) {
return parent
}
return context.WithValue(parent, decisionCtxKey{}, decision)
return privacy.DecisionContext(parent, decision)
}
func decisionFromContext(ctx context.Context) (error, bool) {
decision, ok := ctx.Value(decisionCtxKey{}).(error)
if ok && errors.Is(decision, Allow) {
decision = nil
}
return decision, ok
// DecisionFromContext retrieves the policy decision from the context.
func DecisionFromContext(ctx context.Context) (error, bool) {
return privacy.DecisionFromContext(ctx)
}
type (
// QueryPolicy combines multiple query rules into a single policy.
QueryPolicy []QueryRule
// QueryRule defines the interface deciding whether a
// query is allowed and optionally modify it.
QueryRule interface {
EvalQuery(context.Context, ent.Query) error
}
QueryRule = privacy.QueryRule
// QueryPolicy combines multiple query rules into a single policy.
QueryPolicy = privacy.QueryPolicy
)
// EvalQuery evaluates a query against a query policy.
func (policy QueryPolicy) EvalQuery(ctx context.Context, q ent.Query) error {
if decision, ok := decisionFromContext(ctx); ok {
return decision
}
for _, rule := range policy {
switch decision := rule.EvalQuery(ctx, q); {
case decision == nil || errors.Is(decision, Skip):
case errors.Is(decision, Allow):
return nil
default:
return decision
}
}
return nil
}
// QueryRuleFunc type is an adapter to allow the use of
// ordinary functions as query rules.
type QueryRuleFunc func(context.Context, ent.Query) error
@@ -99,33 +73,13 @@ func (f QueryRuleFunc) EvalQuery(ctx context.Context, q ent.Query) error {
}
type (
// MutationPolicy combines multiple mutation rules into a single policy.
MutationPolicy []MutationRule
// MutationRule defines the interface deciding whether a
// mutation is allowed and optionally modify it.
MutationRule interface {
EvalMutation(context.Context, ent.Mutation) error
}
MutationRule = privacy.MutationRule
// MutationPolicy combines multiple mutation rules into a single policy.
MutationPolicy = privacy.MutationPolicy
)
// EvalMutation evaluates a mutation against a mutation policy.
func (policy MutationPolicy) EvalMutation(ctx context.Context, m ent.Mutation) error {
if decision, ok := decisionFromContext(ctx); ok {
return decision
}
for _, rule := range policy {
switch decision := rule.EvalMutation(ctx, m); {
case decision == nil || errors.Is(decision, Skip):
case errors.Is(decision, Allow):
return nil
default:
return decision
}
}
return nil
}
// MutationRuleFunc type is an adapter to allow the use of
// ordinary functions as mutation rules.
type MutationRuleFunc func(context.Context, ent.Mutation) error

View File

@@ -15,6 +15,7 @@ import (
"github.com/facebook/ent/entc/integration/privacy/ent/user"
"github.com/facebook/ent"
"github.com/facebook/ent/privacy"
)
// The init function reads all schema descriptors with runtime code
@@ -22,7 +23,7 @@ import (
// to their package variables.
func init() {
taskMixin := schema.Task{}.Mixin()
task.Policy = newPolicy(taskMixin[0], taskMixin[1], schema.Task{})
task.Policy = privacy.NewPolicies(taskMixin[0], taskMixin[1], schema.Task{})
task.Hooks[0] = func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if err := task.Policy.EvalMutation(ctx, m); err != nil {
@@ -41,7 +42,7 @@ func init() {
// task.TitleValidator is a validator for the "title" field. It is called by the builders before save.
task.TitleValidator = taskDescTitle.Validators[0].(func(string) error)
teamMixin := schema.Team{}.Mixin()
team.Policy = newPolicy(teamMixin[0], schema.Team{})
team.Policy = privacy.NewPolicies(teamMixin[0], schema.Team{})
team.Hooks[0] = func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if err := team.Policy.EvalMutation(ctx, m); err != nil {
@@ -57,7 +58,7 @@ func init() {
// team.NameValidator is a validator for the "name" field. It is called by the builders before save.
team.NameValidator = teamDescName.Validators[0].(func(string) error)
userMixin := schema.User{}.Mixin()
user.Policy = newPolicy(userMixin[0], userMixin[1], schema.User{})
user.Policy = privacy.NewPolicies(userMixin[0], userMixin[1], schema.User{})
user.Hooks[0] = func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if err := user.Policy.EvalMutation(ctx, m); err != nil {
@@ -74,37 +75,6 @@ func init() {
user.NameValidator = userDescName.Validators[0].(func(string) error)
}
type policies []ent.Policy
// newPolicy creates a policy from list of mixin and ent.Schema.
func newPolicy(ps ...interface{ Policy() ent.Policy }) ent.Policy {
pocs := make(policies, 0, len(ps))
for _, p := range ps {
if policy := p.Policy(); policy != nil {
pocs = append(pocs, policy)
}
}
return pocs
}
func (p policies) EvalMutation(ctx context.Context, m ent.Mutation) error {
for i := range p {
if err := p[i].EvalMutation(ctx, m); err != nil {
return err
}
}
return nil
}
func (p policies) EvalQuery(ctx context.Context, q ent.Query) error {
for i := range p {
if err := p[i].EvalQuery(ctx, q); err != nil {
return err
}
}
return nil
}
const (
Version = "(devel)" // Version of ent codegen.
)