Files
ent/entc/integration/privacy/rule/rule.go
2022-04-14 20:00:52 +03:00

212 lines
6.6 KiB
Go

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package rule
import (
"context"
"fmt"
"sync"
"entgo.io/ent/entc/integration/privacy/ent"
"entgo.io/ent/entc/integration/privacy/ent/hook"
"entgo.io/ent/entc/integration/privacy/ent/predicate"
"entgo.io/ent/entc/integration/privacy/ent/privacy"
"entgo.io/ent/entc/integration/privacy/ent/task"
"entgo.io/ent/entc/integration/privacy/ent/team"
"entgo.io/ent/entc/integration/privacy/ent/user"
"entgo.io/ent/entc/integration/privacy/viewer"
)
// DenyUpdateRule returns a mutation rule that rejects update-many
// (ent.OpUpdate) mutations.
func DenyUpdateRule() privacy.MutationRule {
	op := ent.OpUpdate
	return privacy.DenyMutationOperationRule(op)
}
// DenyIfNoViewer returns a query/mutation rule that denies the operation when
// the context carries no viewer, and skips to the next rule otherwise.
func DenyIfNoViewer() privacy.QueryMutationRule {
	check := func(ctx context.Context) error {
		if viewer.FromContext(ctx) != nil {
			return privacy.Skip
		}
		return privacy.Denyf("viewer-context is missing")
	}
	return privacy.ContextQueryMutationRule(check)
}
// DenyIfNotAdmin returns a query/mutation rule that denies the operation
// unless the viewer in the context is an admin. A context without a viewer is
// denied as well, rather than panicking on a nil Viewer interface.
func DenyIfNotAdmin() privacy.QueryMutationRule {
	return privacy.ContextQueryMutationRule(func(ctx context.Context) error {
		view := viewer.FromContext(ctx)
		switch {
		case view == nil:
			// FromContext may return nil (see DenyIfNoViewer); calling Admin
			// on a nil interface value would panic.
			return privacy.Denyf("viewer-context is missing")
		case !view.Admin():
			return privacy.Denyf("viewer-context is not admin")
		default:
			return privacy.Skip
		}
	})
}
// AllowIfAdmin returns a query/mutation rule that allows the operation when
// the viewer in the context is an admin, and skips to the next rule otherwise.
// A context without a viewer is skipped instead of panicking on a nil Viewer.
func AllowIfAdmin() privacy.QueryMutationRule {
	return privacy.ContextQueryMutationRule(func(ctx context.Context) error {
		view := viewer.FromContext(ctx)
		// Guard against a nil Viewer interface before calling Admin.
		if view != nil && view.Admin() {
			return privacy.Allow
		}
		return privacy.Skip
	})
}
// AllowUserCreateIfAdmin returns a mutation rule, applied to create operations
// only, that allows user creation when the viewer is an admin. In every other
// case, including a context without a viewer, it skips to the next rule.
func AllowUserCreateIfAdmin() privacy.MutationRule {
	rule := privacy.UserMutationRuleFunc(func(ctx context.Context, _ *ent.UserMutation) error {
		view := viewer.FromContext(ctx)
		// Guard against a nil Viewer interface before calling Admin.
		if view != nil && view.Admin() {
			return privacy.Allow
		}
		// Skip to the next privacy rule, that may accept or reject this operation.
		return privacy.Skip
	})
	return privacy.OnMutationOperation(rule, ent.OpCreate)
}
// AllowTaskCreateIfOwner returns a mutation rule, applied to create operations
// only, that allows creating a task when the viewer is a user viewer and is
// set as the task's owner. Otherwise it skips to the next rule.
func AllowTaskCreateIfOwner() privacy.MutationRule {
	rule := privacy.TaskMutationRuleFunc(func(ctx context.Context, m *ent.TaskMutation) error {
		uv, isUser := viewer.FromContext(ctx).(*viewer.UserViewer)
		if !isUser {
			return privacy.Skip
		}
		if ownerID, hasOwner := m.OwnerID(); hasOwner && uv.User.ID == ownerID {
			return privacy.Allow
		}
		// Let a subsequent privacy rule accept or reject this operation.
		return privacy.Skip
	})
	return privacy.OnMutationOperation(rule, ent.OpCreate)
}
// FilterTeamRule returns a query rule that restricts queries (of tasks and
// users) to entities connected to at least one of the viewer's teams.
//
// The query is denied in any of these cases:
//   - the context has no viewer;
//   - the viewer's team names cannot be loaded;
//   - the filter does not support team predicates.
func FilterTeamRule() privacy.QueryRule {
	// TeamsFilter matches filters that expose a WhereHasTeamsWith predicate.
	type TeamsFilter interface {
		WhereHasTeamsWith(...predicate.Team)
	}
	return privacy.FilterFunc(func(ctx context.Context, f privacy.Filter) error {
		view := viewer.FromContext(ctx)
		if view == nil {
			// Calling Teams on a nil Viewer interface would panic.
			return privacy.Denyf("viewer-context is missing")
		}
		teams, err := view.Teams(ctx)
		if err != nil {
			return privacy.Denyf("getting team names: %w", err)
		}
		tf, ok := f.(TeamsFilter)
		if !ok {
			return privacy.Denyf("unexpected filter type %T", f)
		}
		tf.WhereHasTeamsWith(team.NameIn(teams...))
		return privacy.Skip
	})
}
// FilterUsesDep returns a filter query rule that shows how to reach an injected
// dependency (HTTPClient) by type-asserting the filter to *ent.UserFilter.
// Any other filter type is denied.
func FilterUsesDep() privacy.QueryRule {
	return privacy.FilterFunc(func(ctx context.Context, f privacy.Filter) error {
		if uf, ok := f.(*ent.UserFilter); ok {
			// The dependency is reachable once the concrete filter type is known.
			_ = uf.HTTPClient
			return privacy.Skip
		}
		return privacy.Denyf("unexpected filter type %T", f)
	})
}
// DenyIfStatusChangedByOther returns a mutation rule, applied to update-one
// operations only, that denies a change of the task status made by a user who
// is neither the task's owner nor an admin. Mutations that do not touch the
// status, and non-user viewers, are skipped.
func DenyIfStatusChangedByOther() privacy.MutationRule {
	rule := privacy.TaskMutationRuleFunc(func(ctx context.Context, m *ent.TaskMutation) error {
		// Mutations that leave the status untouched are not this rule's concern.
		if _, changed := m.Status(); !changed {
			return privacy.Skip
		}
		// Non-user viewers (e.g. an app) and admins are not restricted here.
		uv, isUser := viewer.FromContext(ctx).(*viewer.UserViewer)
		if !isUser || uv.Admin() {
			return privacy.Skip
		}
		taskID, hasID := m.ID()
		if !hasID {
			return fmt.Errorf("missing task id")
		}
		owner, err := m.Client().User.
			Query().
			Where(user.HasTasksWith(task.ID(taskID))).
			Only(ctx)
		switch {
		case err != nil:
			return err
		case owner.ID != uv.User.ID:
			return privacy.Denyf("viewer %d is not allowed to change the task status", uv.User.ID)
		default:
			return privacy.Skip
		}
	})
	return privacy.OnMutationOperation(rule, ent.OpUpdateOne)
}
// AllowIfViewerInTheSameTeam returns a mutation rule, applied to update-one
// operations only, that allows the mutation when the viewer belongs to one of
// the task's teams. Admins and non-user viewers are skipped. Errors loading
// the teams are returned as deny decisions. Errors from the membership query
// are returned as-is.
func AllowIfViewerInTheSameTeam() privacy.MutationRule {
	rule := privacy.TaskMutationRuleFunc(func(ctx context.Context, m *ent.TaskMutation) error {
		// Admins and non-user viewers (e.g. an app) are left to other rules.
		uv, isUser := viewer.FromContext(ctx).(*viewer.UserViewer)
		if !isUser || uv.Admin() {
			return privacy.Skip
		}
		names, err := uv.Teams(ctx)
		if err != nil {
			return privacy.Denyf("getting team names: %w", err)
		}
		taskID, hasID := m.ID()
		if !hasID {
			return fmt.Errorf("missing task id")
		}
		// The query is expected to fail when the viewer does not belong to
		// the task namespace/team.
		inTeam := m.Client().Task.Query().Where(
			task.ID(taskID),
			task.HasTeamsWith(team.NameIn(names...)),
		)
		if _, err := inTeam.Only(ctx); err != nil {
			return err
		}
		return privacy.Allow
	})
	return privacy.OnMutationOperation(rule, ent.OpUpdateOne)
}
// logger holds the function used to log task mutations. The embedded RWMutex
// guards logf: SetMutationLogFunc takes the write lock, and the mutation hook
// takes the read lock. It starts out as a no-op.
var logger = struct {
	sync.RWMutex
	logf func(string, ...interface{})
}{
	logf: func(string, ...interface{}) {},
}
// SetMutationLogFunc replaces the function that LogTaskMutationHook uses to
// log task mutations. It returns the previously installed function, so
// callers can restore it afterwards.
//
// Passing nil installs a no-op logger, because the hook calls the installed
// function unconditionally and would otherwise panic.
func SetMutationLogFunc(f func(string, ...interface{})) func(string, ...interface{}) {
	if f == nil {
		f = func(string, ...interface{}) {}
	}
	logger.Lock()
	defer logger.Unlock()
	prev := logger.logf
	logger.logf = f
	return prev
}
// LogTaskMutationHook returns a hook logging planet mutations.
func LogTaskMutationHook() ent.Hook {
return func(next ent.Mutator) ent.Mutator {
return hook.TaskFunc(func(ctx context.Context, m *ent.TaskMutation) (ent.Value, error) {
value, err := next.Mutate(ctx, m)
logger.RLock()
defer logger.RUnlock()
logger.logf("task mutation: type %s, value %v, err %v", m.Op(), value, err)
return value, err
})
}
}