diff --git a/doc/md/predicates.md b/doc/md/predicates.md index 74957b03e..049039e80 100755 --- a/doc/md/predicates.md +++ b/doc/md/predicates.md @@ -89,41 +89,6 @@ client.Pet. Custom predicates can be useful if you want to write your own dialect-specific logic or to control the executed queries. -For example, in order to use built-in SQL functions such as `DATE()`, use one of the following options: - -1. Pass a dialect-aware predicate function using the `sql.P` option: - -```go -users := client.User.Query(). - Select(user.FieldID). - Where(sql.P(func(b *sql.Builder) { - b.WriteString("DATE(").Ident("last_login_at").WriteByte(')').WriteOp(OpGTE).Arg(value) - })). - AllX(ctx) -``` - -The above code will produce the following SQL query: - -```sql -SELECT `id` FROM `users` WHERE DATE(`last_login_at`) >= ? -``` - -2. Inline a predicate expression using the `ExprP()` option: - -```go -users := client.User.Query(). - Select(user.FieldID). - Where(func(s *sql.Selector) { - s.Where(sql.ExprP("DATE(last_login_at >= ?", value)) - }). - AllX(ctx) -``` - -The above code will produce the same SQL query: - -```sql -SELECT `id` FROM `users` WHERE DATE(`last_login_at`) >= ? -``` #### Get all pets of users 1, 2 and 3 ```go @@ -240,6 +205,44 @@ The above code will produce the following SQL query: SELECT DISTINCT `pets`.`id`, `pets`.`owner_id`, `pets`.`name`, `pets`.`age`, `pets`.`species` FROM `pets` WHERE `name` LIKE '_B%' ``` +#### Custom SQL functions + +In order to use built-in SQL functions such as `DATE()`, use one of the following options: + +1\. Pass a dialect-aware predicate function using the `sql.P` option: + +```go +users := client.User.Query(). + Select(user.FieldID). + Where(sql.P(func(b *sql.Builder) { + b.WriteString("DATE(").Ident("last_login_at").WriteByte(')').WriteOp(OpGTE).Arg(value) + })). + AllX(ctx) +``` + +The above code will produce the following SQL query: + +```sql +SELECT `id` FROM `users` WHERE DATE(`last_login_at`) >= ? +``` + +2\. Inline a predicate expression using the `ExprP()` option: + +```go +users := client.User.Query(). + Select(user.FieldID). + Where(func(s *sql.Selector) { + s.Where(sql.ExprP("DATE(last_login_at >= ?", value)) + }). + AllX(ctx) +``` + +The above code will produce the same SQL query: + +```sql +SELECT `id` FROM `users` WHERE DATE(`last_login_at`) >= ? +``` + ## JSON predicates JSON predicates are not generated by default as part of the code generation. However, ent provides an official package diff --git a/doc/md/privacy.md b/doc/md/privacy.md index 39117a3fb..f2bbea00a 100644 --- a/doc/md/privacy.md +++ b/doc/md/privacy.md @@ -287,10 +287,16 @@ type BaseMixin struct { // Policy defines the privacy policy of the BaseMixin. func (BaseMixin) Policy() ent.Policy { return privacy.Policy{ - Mutation: privacy.MutationPolicy{ - rule.DenyIfNoViewer(), - }, Query: privacy.QueryPolicy{ + // Deny any query operation in case + // there is no "viewer context". + rule.DenyIfNoViewer(), + // Allow admins to query any information. + rule.AllowIfAdmin(), + }, + Mutation: privacy.MutationPolicy{ + // Deny any mutation operation in case + // there is no "viewer context". rule.DenyIfNoViewer(), }, } @@ -329,30 +335,39 @@ func (Tenant) Policy() ent.Policy { Then, we expect the following code to run successfully: ```go title="examples/privacytenant/example_test.go" -func Do(ctx context.Context, client *ent.Client) error { - // Expect operation to fail, because viewer-context - // is missing (first mutation rule check). + +func Example_CreateTenants(ctx context.Context, client *ent.Client) { + // Expect operation to fail in case viewer-context is missing. + // First mutation privacy policy rule defined in BaseMixin. if err := client.Tenant.Create().Exec(ctx); !errors.Is(err, privacy.Deny) { - return fmt.Errorf("expect operation to fail, but got %w", err) + log.Fatal("expect tenant creation to fail, but got:", err) } - // Deny tenant creation if the viewer is not admin. + + // Expect operation to fail in case the ent.User in the viewer-context + // is not an admin user. Privacy policy defined in the Tenant schema. viewCtx := viewer.NewContext(ctx, viewer.UserViewer{Role: viewer.View}) if err := client.Tenant.Create().Exec(viewCtx); !errors.Is(err, privacy.Deny) { - return fmt.Errorf("expect operation to fail, but got %w", err) + log.Fatal("expect tenant creation to fail, but got:", err) } - // Apply the same operation with "Admin" role, expect it to pass. + + // Operations should pass successfully as the user in the viewer-context + // is an admin user. First mutation privacy policy in Tenant schema. adminCtx := viewer.NewContext(ctx, viewer.UserViewer{Role: viewer.Admin}) hub, err := client.Tenant.Create().SetName("GitHub").Save(adminCtx) if err != nil { - return fmt.Errorf("expect operation to pass, but got %w", err) + log.Fatal("expect tenant creation to pass, but got:", err) } fmt.Println(hub) + lab, err := client.Tenant.Create().SetName("GitLab").Save(adminCtx) if err != nil { - return fmt.Errorf("expect operation to pass, but got %w", err) + log.Fatal("expect tenant creation to pass, but got:", err) } fmt.Println(lab) - return nil + + // Output: + // Tenant(id=1, name=GitHub) + // Tenant(id=2, name=GitLab) } ``` @@ -365,10 +380,18 @@ type TenantMixin struct { mixin.Schema } +// Fields for all schemas that embed TenantMixin. +func (TenantMixin) Fields() []ent.Field { + return []ent.Field{ + field.Int("tenant_id"), + } +} + // Edges for all schemas that embed TenantMixin. func (TenantMixin) Edges() []ent.Edge { return []ent.Edge{ edge.To("tenant", Tenant.Type). + Field("tenant_id"). Unique(). Required(), } @@ -382,7 +405,9 @@ For use cases like this, Ent has an additional type of privacy rule named `Filte We can use `Filter` rules to filter out entities based on the identity of the viewer. Unlike the rules we previously discussed, `Filter` rules can limit the scope of the queries a viewer can make, in addition to returning privacy decisions. -> Note, the privacy filtering option needs to be enabled using the [`entql`](features.md#entql-filtering) feature-flag (see instructions [above](#configuration)). +:::info Note +The privacy filtering option needs to be enabled using the [`entql`](features.md#entql-filtering) feature-flag (see instructions [above](#configuration)). +::: ```go title="examples/privacytenant/rule/rule.go" // FilterTenantRule is a query/mutation rule that filters out entities that are not in the tenant. @@ -390,19 +415,20 @@ func FilterTenantRule() privacy.QueryMutationRule { // TenantsFilter is an interface to wrap WhereHasTenantWith() // predicate that is used by both `Group` and `User` schemas. type TenantsFilter interface { - WhereHasTenantWith(...predicate.Tenant) + WhereTenantID(entql.IntP) } return privacy.FilterFunc(func(ctx context.Context, f privacy.Filter) error { view := viewer.FromContext(ctx) - if view.Tenant() == "" { + tid, ok := view.Tenant() + if !ok { return privacy.Denyf("missing tenant information in viewer") } tf, ok := f.(TenantsFilter) if !ok { return privacy.Denyf("unexpected filter type %T", f) } - // Make sure that a tenant reads only entities that has an edge to it. - tf.WhereHasTenantWith(tenant.Name(view.Tenant())) + // Make sure that a tenant reads only entities that have an edge to it. + tf.WhereTenantID(entql.IntEQ(tid)) // Skip to the next privacy rule (equivalent to return nil). return privacy.Skip }) @@ -415,48 +441,94 @@ that use this mixin, will also have this privacy rule. ```go title="examples/privacytenant/ent/schema/mixin.go" // Policy for all schemas that embed TenantMixin. func (TenantMixin) Policy() ent.Policy { - return privacy.Policy{ - Query: privacy.QueryPolicy{ - rule.AllowIfAdmin(), - // Filter out entities that are not connected to the tenant. - // If the viewer is admin, this policy rule is skipped above. - rule.FilterTenantRule(), - }, - } + return rule.FilterTenantRule() } ``` Then, after running the code-generation, we expect the privacy-rules to take effect on the client operations. ```go title="examples/privacytenant/example_test.go" -func Do(ctx context.Context, client *ent.Client) error { - // A continuation of the code-block above. - // Create 2 users connected to the 2 tenants we created above - hubUser := client.User.Create().SetName("a8m").SetTenant(hub).SaveX(adminCtx) - labUser := client.User.Create().SetName("nati").SetTenant(lab).SaveX(adminCtx) +func Example_TenantView(ctx context.Context, client *ent.Client) { + // Operations should pass successfully as the user in the viewer-context + // is an admin user. First mutation privacy policy in Tenant schema. + adminCtx := viewer.NewContext(ctx, viewer.UserViewer{Role: viewer.Admin}) + hub := client.Tenant.Create().SetName("GitHub").SaveX(adminCtx) + lab := client.Tenant.Create().SetName("GitLab").SaveX(adminCtx) + // Create 2 tenant-specific viewer contexts. hubView := viewer.NewContext(ctx, viewer.UserViewer{T: hub}) - out := client.User.Query().OnlyX(hubView) - // Expect that "GitHub" tenant to read only its users (i.e. a8m). - if out.ID != hubUser.ID { - return fmt.Errorf("expect result for user query, got %v", out) - } - fmt.Println(out) - labView := viewer.NewContext(ctx, viewer.UserViewer{T: lab}) - out = client.User.Query().OnlyX(labView) - // Expect that "GitLab" tenant to read only its users (i.e. nati). - if out.ID != labUser.ID { - return fmt.Errorf("expect result for user query, got %v", out) + + // Create 2 users in each tenant. + hubUsers := client.User.CreateBulk( + client.User.Create().SetName("a8m").SetTenant(hub), + client.User.Create().SetName("nati").SetTenant(hub), + ).SaveX(hubView) + fmt.Println(hubUsers) + + labUsers := client.User.CreateBulk( + client.User.Create().SetName("foo").SetTenant(lab), + client.User.Create().SetName("bar").SetTenant(lab), + ).SaveX(labView) + fmt.Println(labUsers) + + // Query users should fail in case viewer-context is missing. + if _, err := client.User.Query().Count(ctx); !errors.Is(err, privacy.Deny) { + log.Fatal("expect user query to fail, but got:", err) } - fmt.Println(out) - return nil + + // Ensure each tenant can see only its users. + // First and only rule in TenantMixin. + fmt.Println(client.User.Query().Select(user.FieldName).StringsX(hubView)) + fmt.Println(client.User.Query().CountX(hubView)) + fmt.Println(client.User.Query().Select(user.FieldName).StringsX(labView)) + fmt.Println(client.User.Query().CountX(labView)) + + // Expect admin users to see everything. First + // query privacy policy defined in BaseMixin. + fmt.Println(client.User.Query().CountX(adminCtx)) // 4 + + // Update operation with specific tenant-view should update + // only the tenant in the viewer-context. + client.User.Update().SetFoods([]string{"pizza"}).SaveX(hubView) + fmt.Println(client.User.Query().AllX(hubView)) + fmt.Println(client.User.Query().AllX(labView)) + + // Delete operation with specific tenant-view should delete + // only the tenant in the viewer-context. + client.User.Delete().ExecX(labView) + fmt.Println( + client.User.Query().CountX(hubView), // 2 + client.User.Query().CountX(labView), // 0 + ) + + // DeleteOne with wrong viewer-context is nop. + client.User.DeleteOne(hubUsers[0]).ExecX(labView) + fmt.Println(client.User.Query().CountX(hubView)) // 2 + + // Unlike queries, admin users are not allowed to mutate tenant specific data. + if err := client.User.DeleteOne(hubUsers[0]).Exec(adminCtx); !errors.Is(err, privacy.Deny) { + log.Fatal("expect user deletion to fail, but got:", err) + } + + // Output: + // [User(id=1, tenant_id=1, name=a8m, foods=[]) User(id=2, tenant_id=1, name=nati, foods=[])] + // [User(id=3, tenant_id=2, name=foo, foods=[]) User(id=4, tenant_id=2, name=bar, foods=[])] + // [a8m nati] + // 2 + // [foo bar] + // 2 + // 4 + // [User(id=1, tenant_id=1, name=a8m, foods=[pizza]) User(id=2, tenant_id=1, name=nati, foods=[pizza])] + // [User(id=3, tenant_id=2, name=foo, foods=[]) User(id=4, tenant_id=2, name=bar, foods=[])] + // 2 0 + // 2 } ``` We finish our example with another privacy-rule named `DenyMismatchedTenants` on the `Group` schema. -The `DenyMismatchedTenants` rule rejects group creation if the associated users don't belong to +The `DenyMismatchedTenants` rule rejects group creation if the associated users do not belong to the same tenant as the group. ```go title="examples/privacytenant/rule/rule.go" @@ -473,14 +545,19 @@ func DenyMismatchedTenants() privacy.MutationRule { if len(users) == 0 { return privacy.Skip } - // Query the tenant-id of all users. Expect to have exact 1 result, - // and it matches the tenant-id of the group above. - id, err := m.Client().User.Query().Where(user.IDIn(users...)).QueryTenant().OnlyID(ctx) + // Query the tenant-ids of all attached users. Expect all users to be connected to the same tenant + // as the group. Note, we use privacy.DecisionContext to skip the FilterTenantRule defined above. + ids, err := m.Client().User.Query().Where(user.IDIn(users...)).Select(user.FieldTenantID).Ints(privacy.DecisionContext(ctx, privacy.Allow)) if err != nil { - return privacy.Denyf("querying the tenant-id %v", err) + return privacy.Denyf("querying the tenant-ids %v", err) } - if id != tid { - return privacy.Denyf("mismatch tenant-ids for group/users %d != %d", tid, id) + if len(ids) != len(users) { + return privacy.Denyf("one the attached users is not connected to a tenant %v", err) + } + for _, id := range ids { + if id != tid { + return privacy.Denyf("mismatch tenant-ids for group/users %d != %d", tid, id) + } } // Skip to the next privacy rule (equivalent to return nil). return privacy.Skip @@ -509,72 +586,71 @@ func (Group) Policy() ent.Policy { Again, we expect the privacy-rules to take effect on the client operations. ```go title="examples/privacytenant/example_test.go" -func Do(ctx context.Context, client *ent.Client) error { - // A continuation of the code-block above. +func Example_DenyMismatchedTenants(ctx context.Context, client *ent.Client) { + // Operation should pass successfully as the user in the viewer-context + // is an admin user. First mutation privacy policy in Tenant schema. + adminCtx := viewer.NewContext(ctx, viewer.UserViewer{Role: viewer.Admin}) + hub := client.Tenant.Create().SetName("GitHub").SaveX(adminCtx) + lab := client.Tenant.Create().SetName("GitLab").SaveX(adminCtx) - // Expect operation to fail because the DenyMismatchedTenants rule - // makes sure the group and the users are connected to the same tenant. - err = client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(labUser).Exec(adminCtx) - if !errors.Is(err, privacy.Deny) { - return fmt.Errorf("expect operation to fail, since user (nati) is not connected to the same tenant") + // Create 2 tenant-specific viewer contexts. + hubView := viewer.NewContext(ctx, viewer.UserViewer{T: hub}) + labView := viewer.NewContext(ctx, viewer.UserViewer{T: lab}) + + // Create 2 users in each tenant. + hubUsers := client.User.CreateBulk( + client.User.Create().SetName("a8m").SetTenant(hub), + client.User.Create().SetName("nati").SetTenant(hub), + ).SaveX(hubView) + fmt.Println(hubUsers) + + labUsers := client.User.CreateBulk( + client.User.Create().SetName("foo").SetTenant(lab), + client.User.Create().SetName("bar").SetTenant(lab), + ).SaveX(labView) + fmt.Println(labUsers) + + // Expect operation to fail as the DenyMismatchedTenants rule makes + // sure the group and the users are connected to the same tenant. + if err := client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(labUsers...).Exec(hubView); !errors.Is(err, privacy.Deny) { + log.Fatal("expect operation to fail, since labUsers are not connected to the same tenant") } - err = client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(labUser, hubUser).Exec(adminCtx) - if !errors.Is(err, privacy.Deny) { - return fmt.Errorf("expect operation to fail, since some users (nati) are not connected to the same tenant") - } - entgo, err := client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(hubUser).Save(adminCtx) - if err != nil { - return fmt.Errorf("expect operation to pass, but got %w", err) + if err := client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(hubUsers[0], labUsers[0]).Exec(hubView); !errors.Is(err, privacy.Deny) { + log.Fatal("expect operation to fail, since labUsers[0] is not connected to the same tenant") } + // Expect mutation to pass as all users belong to the same tenant as the group. + entgo := client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(hubUsers...).SaveX(hubView) fmt.Println(entgo) - return nil + + // Output: + // [User(id=1, tenant_id=1, name=a8m, foods=[]) User(id=2, tenant_id=1, name=nati, foods=[])] + // [User(id=3, tenant_id=2, name=foo, foods=[]) User(id=4, tenant_id=2, name=bar, foods=[])] + // Group(id=1, tenant_id=1, name=entgo.io) } ``` -In some cases, we want to reject user operations on entities that don't belong to their tenant **without loading -these entities from the database** (unlike the `DenyMismatchedTenants` example above). -To achieve this, we can use the `FilterTenantRule` rule for mutations as well, but limit it to specific operations as follows: - -```go title="examples/privacytenant/ent/schema/group.go" -// Policy defines the privacy policy of the Group. -func (Group) Policy() ent.Policy { - return privacy.Policy{ - Mutation: privacy.MutationPolicy{ - // Limit DenyMismatchedTenants only for - // Create operations - privacy.OnMutationOperation( - rule.DenyMismatchedTenants(), - ent.OpCreate, - ), - // Limit the FilterTenantRule only for - // UpdateOne and DeleteOne operations. - privacy.OnMutationOperation( - rule.FilterTenantRule(), - ent.OpUpdateOne|ent.OpDeleteOne, - ), - }, - } -} -``` - -Then, we expect the privacy-rules to take effect on the client operations. +In some cases, we want to reject user operations on entities that do not belong to their tenant **without loading +these entities from the database** (unlike the `DenyMismatchedTenants` example above). +To achieve this, we rely on the `FilterTenantRule` rule to add its filtering on mutations as well, and expect +operations to fail with `NotFoundError` in case the `tenant_id` column does not match the one stored in the +viewer-context. ```go title="examples/privacytenant/example_test.go" -func Do(ctx context.Context, client *ent.Client) error { - // A continuation of the code-block above. - +func Example_DenyMismatchedView(ctx context.Context, client *ent.Client) { + // Continuation of the code above. + // Expect operation to fail, because the FilterTenantRule rule makes sure // that tenants can update and delete only their groups. - err = entgo.Update().SetName("fail.go").Exec(labView) - if !ent.IsNotFound(err) { - return fmt.Errorf("expect operation to fail, since the group (entgo) is managed by a different tenant (hub), but got %w", err) - } - entgo, err = entgo.Update().SetName("entgo").Save(hubView) - if err != nil { - return fmt.Errorf("expect operation to pass, but got %w", err) + if err := entgo.Update().SetName("fail.go").Exec(labView); !ent.IsNotFound(err) { + log.Fatal("expect operation to fail, since the group (entgo) is managed by a different tenant (hub), but got:", err) } + + // Operation should pass in case it was applied with the right viewer-context. + entgo = entgo.Update().SetName("entgo").SaveX(hubView) fmt.Println(entgo) - return nil + + // Output: + // Group(id=1, tenant_id=1, name=entgo) } ``` diff --git a/doc/md/tutorial-todo-gql-filter-input.md b/doc/md/tutorial-todo-gql-filter-input.md index 38d46d931..a1d3f0759 100755 --- a/doc/md/tutorial-todo-gql-filter-input.md +++ b/doc/md/tutorial-todo-gql-filter-input.md @@ -311,9 +311,9 @@ We can use this new filtering as any other predicate: } ``` -### Usage of filter inputs as predicates +### Usage as predicates -The `Filter` option lets use the generated `WhereInput`s as regular repdicates on any type of query: +The `Filter` option lets use the generated `WhereInput`s as regular predicates on any type of query: ```go query := ent.Todo.Query() diff --git a/entc/gen/template.go b/entc/gen/template.go index 0b3e3ce44..39a973fa3 100644 --- a/entc/gen/template.go +++ b/entc/gen/template.go @@ -199,6 +199,8 @@ var ( "tx/additional/*/*", "update/additional/*", "query/additional/*", + "privacy/additional/*", + "privacy/additional/*/*", } // importPkg are the import packages used for code generation. importPkg = make(map[string]string) diff --git a/entc/gen/template/privacy/privacy.tmpl b/entc/gen/template/privacy/privacy.tmpl index 4a3cff083..17e3f4dc7 100644 --- a/entc/gen/template/privacy/privacy.tmpl +++ b/entc/gen/template/privacy/privacy.tmpl @@ -22,11 +22,11 @@ import ( var ( // Allow may be returned by rules to indicate that the policy - // evaluation should terminate with an allow decision. + // evaluation should terminate with allow decision. Allow = privacy.Allow // Deny may be returned by rules to indicate that the policy - // evaluation should terminate with an deny decision. + // evaluation should terminate with deny decision. Deny = privacy.Deny // Skip may be returned by rules to indicate that the policy @@ -53,11 +53,20 @@ func DecisionFromContext(ctx context.Context) (error, bool) { } type ( + // Policy groups query and mutation policies. + Policy = privacy.Policy + // QueryRule defines the interface deciding whether a // query is allowed and optionally modify it. QueryRule = privacy.QueryRule // QueryPolicy combines multiple query rules into a single policy. QueryPolicy = privacy.QueryPolicy + + // MutationRule defines the interface which decides whether a + // mutation is allowed and optionally modifies it. + MutationRule = privacy.MutationRule + // MutationPolicy combines multiple mutation rules into a single policy. + MutationPolicy = privacy.MutationPolicy ) // QueryRuleFunc type is an adapter to allow the use of @@ -69,14 +78,6 @@ func (f QueryRuleFunc) EvalQuery(ctx context.Context, q {{ $pkg }}.Query) error return f(ctx, q) } -type ( - // MutationRule defines the interface which decides whether a - // mutation is allowed and optionally modifies it. - MutationRule = privacy.MutationRule - // MutationPolicy combines multiple mutation rules into a single policy. - MutationPolicy = privacy.MutationPolicy -) - // MutationRuleFunc type is an adapter which allows the use of // ordinary functions as mutation rules. type MutationRuleFunc func(context.Context, {{ $pkg }}.Mutation) error @@ -86,22 +87,6 @@ func (f MutationRuleFunc) EvalMutation(ctx context.Context, m {{ $pkg }}.Mutatio return f(ctx, m) } -// Policy groups query and mutation policies. -type Policy struct { - Query QueryPolicy - Mutation MutationPolicy -} - -// EvalQuery forwards evaluation to query a policy. -func (policy Policy) EvalQuery(ctx context.Context, q {{ $pkg }}.Query) error { - return policy.Query.EvalQuery(ctx, q) -} - -// EvalMutation forwards evaluation to mutate a policy. -func (policy Policy) EvalMutation(ctx context.Context, m {{ $pkg }}.Mutation) error { - return policy.Mutation.EvalMutation(ctx, m) -} - // QueryMutationRule is an interface which groups query and mutation rules. type QueryMutationRule interface { QueryRule @@ -199,4 +184,10 @@ func DenyMutationOperationRule(op {{ $pkg }}.Op) MutationRule { {{ template "privacy/filter" $ }} {{- end }} +{{- with $tmpls := matchTemplate "privacy/additional/*" "privacy/additional/*/*" }} + {{- range $tmpl := $tmpls }} + {{- xtemplate $tmpl $ }} + {{- end }} +{{- end }} + {{ end }} diff --git a/entc/integration/privacy/ent/privacy/privacy.go b/entc/integration/privacy/ent/privacy/privacy.go index 822a55626..2a223a5d4 100644 --- a/entc/integration/privacy/ent/privacy/privacy.go +++ b/entc/integration/privacy/ent/privacy/privacy.go @@ -18,11 +18,11 @@ import ( var ( // Allow may be returned by rules to indicate that the policy - // evaluation should terminate with an allow decision. + // evaluation should terminate with allow decision. Allow = privacy.Allow // Deny may be returned by rules to indicate that the policy - // evaluation should terminate with an deny decision. + // evaluation should terminate with deny decision. Deny = privacy.Deny // Skip may be returned by rules to indicate that the policy @@ -57,11 +57,20 @@ func DecisionFromContext(ctx context.Context) (error, bool) { } type ( + // Policy groups query and mutation policies. + Policy = privacy.Policy + // QueryRule defines the interface deciding whether a // query is allowed and optionally modify it. QueryRule = privacy.QueryRule // QueryPolicy combines multiple query rules into a single policy. QueryPolicy = privacy.QueryPolicy + + // MutationRule defines the interface which decides whether a + // mutation is allowed and optionally modifies it. + MutationRule = privacy.MutationRule + // MutationPolicy combines multiple mutation rules into a single policy. + MutationPolicy = privacy.MutationPolicy ) // QueryRuleFunc type is an adapter to allow the use of @@ -73,14 +82,6 @@ func (f QueryRuleFunc) EvalQuery(ctx context.Context, q ent.Query) error { return f(ctx, q) } -type ( - // MutationRule defines the interface which decides whether a - // mutation is allowed and optionally modifies it. - MutationRule = privacy.MutationRule - // MutationPolicy combines multiple mutation rules into a single policy. - MutationPolicy = privacy.MutationPolicy -) - // MutationRuleFunc type is an adapter which allows the use of // ordinary functions as mutation rules. type MutationRuleFunc func(context.Context, ent.Mutation) error @@ -90,22 +91,6 @@ func (f MutationRuleFunc) EvalMutation(ctx context.Context, m ent.Mutation) erro return f(ctx, m) } -// Policy groups query and mutation policies. -type Policy struct { - Query QueryPolicy - Mutation MutationPolicy -} - -// EvalQuery forwards evaluation to query a policy. -func (policy Policy) EvalQuery(ctx context.Context, q ent.Query) error { - return policy.Query.EvalQuery(ctx, q) -} - -// EvalMutation forwards evaluation to mutate a policy. -func (policy Policy) EvalMutation(ctx context.Context, m ent.Mutation) error { - return policy.Mutation.EvalMutation(ctx, m) -} - // QueryMutationRule is an interface which groups query and mutation rules. type QueryMutationRule interface { QueryRule diff --git a/examples/privacyadmin/ent/privacy/privacy.go b/examples/privacyadmin/ent/privacy/privacy.go index 5fd9909a7..edf065bca 100644 --- a/examples/privacyadmin/ent/privacy/privacy.go +++ b/examples/privacyadmin/ent/privacy/privacy.go @@ -17,11 +17,11 @@ import ( var ( // Allow may be returned by rules to indicate that the policy - // evaluation should terminate with an allow decision. + // evaluation should terminate with allow decision. Allow = privacy.Allow // Deny may be returned by rules to indicate that the policy - // evaluation should terminate with an deny decision. + // evaluation should terminate with deny decision. Deny = privacy.Deny // Skip may be returned by rules to indicate that the policy @@ -56,11 +56,20 @@ func DecisionFromContext(ctx context.Context) (error, bool) { } type ( + // Policy groups query and mutation policies. + Policy = privacy.Policy + // QueryRule defines the interface deciding whether a // query is allowed and optionally modify it. QueryRule = privacy.QueryRule // QueryPolicy combines multiple query rules into a single policy. QueryPolicy = privacy.QueryPolicy + + // MutationRule defines the interface which decides whether a + // mutation is allowed and optionally modifies it. + MutationRule = privacy.MutationRule + // MutationPolicy combines multiple mutation rules into a single policy. + MutationPolicy = privacy.MutationPolicy ) // QueryRuleFunc type is an adapter to allow the use of @@ -72,14 +81,6 @@ func (f QueryRuleFunc) EvalQuery(ctx context.Context, q ent.Query) error { return f(ctx, q) } -type ( - // MutationRule defines the interface which decides whether a - // mutation is allowed and optionally modifies it. - MutationRule = privacy.MutationRule - // MutationPolicy combines multiple mutation rules into a single policy. - MutationPolicy = privacy.MutationPolicy -) - // MutationRuleFunc type is an adapter which allows the use of // ordinary functions as mutation rules. type MutationRuleFunc func(context.Context, ent.Mutation) error @@ -89,22 +90,6 @@ func (f MutationRuleFunc) EvalMutation(ctx context.Context, m ent.Mutation) erro return f(ctx, m) } -// Policy groups query and mutation policies. -type Policy struct { - Query QueryPolicy - Mutation MutationPolicy -} - -// EvalQuery forwards evaluation to query a policy. -func (policy Policy) EvalQuery(ctx context.Context, q ent.Query) error { - return policy.Query.EvalQuery(ctx, q) -} - -// EvalMutation forwards evaluation to mutate a policy. -func (policy Policy) EvalMutation(ctx context.Context, m ent.Mutation) error { - return policy.Mutation.EvalMutation(ctx, m) -} - // QueryMutationRule is an interface which groups query and mutation rules. type QueryMutationRule interface { QueryRule diff --git a/examples/privacytenant/ent/entql.go b/examples/privacytenant/ent/entql.go index 5fe6bab2d..8ed39f76b 100644 --- a/examples/privacytenant/ent/entql.go +++ b/examples/privacytenant/ent/entql.go @@ -32,7 +32,8 @@ var schemaGraph = func() *sqlgraph.Schema { }, Type: "Group", Fields: map[string]*sqlgraph.FieldSpec{ - group.FieldName: {Type: field.TypeString, Column: group.FieldName}, + group.FieldTenantID: {Type: field.TypeInt, Column: group.FieldTenantID}, + group.FieldName: {Type: field.TypeString, Column: group.FieldName}, }, } graph.Nodes[1] = &sqlgraph.Node{ @@ -60,8 +61,9 @@ var schemaGraph = func() *sqlgraph.Schema { }, Type: "User", Fields: map[string]*sqlgraph.FieldSpec{ - user.FieldName: {Type: field.TypeString, Column: user.FieldName}, - user.FieldFoods: {Type: field.TypeJSON, Column: user.FieldFoods}, + user.FieldTenantID: {Type: field.TypeInt, Column: user.FieldTenantID}, + user.FieldName: {Type: field.TypeString, Column: user.FieldName}, + user.FieldFoods: {Type: field.TypeJSON, Column: user.FieldFoods}, }, } graph.MustAddE( @@ -161,6 +163,11 @@ func (f *GroupFilter) WhereID(p entql.IntP) { f.Where(p.Field(group.FieldID)) } +// WhereTenantID applies the entql int predicate on the tenant_id field. +func (f *GroupFilter) WhereTenantID(p entql.IntP) { + f.Where(p.Field(group.FieldTenantID)) +} + // WhereName applies the entql string predicate on the name field. func (f *GroupFilter) WhereName(p entql.StringP) { f.Where(p.Field(group.FieldName)) @@ -279,6 +286,11 @@ func (f *UserFilter) WhereID(p entql.IntP) { f.Where(p.Field(user.FieldID)) } +// WhereTenantID applies the entql int predicate on the tenant_id field. +func (f *UserFilter) WhereTenantID(p entql.IntP) { + f.Where(p.Field(user.FieldTenantID)) +} + // WhereName applies the entql string predicate on the name field. func (f *UserFilter) WhereName(p entql.StringP) { f.Where(p.Field(user.FieldName)) diff --git a/examples/privacytenant/ent/group.go b/examples/privacytenant/ent/group.go index 4753499c3..3115314d6 100644 --- a/examples/privacytenant/ent/group.go +++ b/examples/privacytenant/ent/group.go @@ -20,12 +20,13 @@ type Group struct { config `json:"-"` // ID of the ent. ID int `json:"id,omitempty"` + // TenantID holds the value of the "tenant_id" field. + TenantID int `json:"tenant_id,omitempty"` // Name holds the value of the "name" field. Name string `json:"name,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. - Edges GroupEdges `json:"edges"` - group_tenant *int + Edges GroupEdges `json:"edges"` } // GroupEdges holds the relations/edges for other nodes in the graph. @@ -67,12 +68,10 @@ func (*Group) scanValues(columns []string) ([]interface{}, error) { values := make([]interface{}, len(columns)) for i := range columns { switch columns[i] { - case group.FieldID: + case group.FieldID, group.FieldTenantID: values[i] = new(sql.NullInt64) case group.FieldName: values[i] = new(sql.NullString) - case group.ForeignKeys[0]: // group_tenant - values[i] = new(sql.NullInt64) default: return nil, fmt.Errorf("unexpected column %q for type Group", columns[i]) } @@ -94,19 +93,18 @@ func (gr *Group) assignValues(columns []string, values []interface{}) error { return fmt.Errorf("unexpected type %T for field id", value) } gr.ID = int(value.Int64) + case group.FieldTenantID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field tenant_id", values[i]) + } else if value.Valid { + gr.TenantID = int(value.Int64) + } case group.FieldName: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field name", values[i]) } else if value.Valid { gr.Name = value.String } - case group.ForeignKeys[0]: - if value, ok := values[i].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for edge-field group_tenant", value) - } else if value.Valid { - gr.group_tenant = new(int) - *gr.group_tenant = int(value.Int64) - } } } return nil @@ -145,6 +143,9 @@ func (gr *Group) String() string { var builder strings.Builder builder.WriteString("Group(") builder.WriteString(fmt.Sprintf("id=%v, ", gr.ID)) + builder.WriteString("tenant_id=") + builder.WriteString(fmt.Sprintf("%v", gr.TenantID)) + builder.WriteString(", ") builder.WriteString("name=") builder.WriteString(gr.Name) builder.WriteByte(')') diff --git a/examples/privacytenant/ent/group/group.go b/examples/privacytenant/ent/group/group.go index 21b61333c..326cb6a4d 100644 --- a/examples/privacytenant/ent/group/group.go +++ b/examples/privacytenant/ent/group/group.go @@ -15,6 +15,8 @@ const ( Label = "group" // FieldID holds the string denoting the id field in the database. FieldID = "id" + // FieldTenantID holds the string denoting the tenant_id field in the database. + FieldTenantID = "tenant_id" // FieldName holds the string denoting the name field in the database. FieldName = "name" // EdgeTenant holds the string denoting the tenant edge name in mutations. @@ -29,7 +31,7 @@ const ( // It exists in this package in order to avoid circular dependency with the "tenant" package. TenantInverseTable = "tenants" // TenantColumn is the table column denoting the tenant relation/edge. - TenantColumn = "group_tenant" + TenantColumn = "tenant_id" // UsersTable is the table that holds the users relation/edge. The primary key declared below. UsersTable = "user_groups" // UsersInverseTable is the table name for the User entity. @@ -40,15 +42,10 @@ const ( // Columns holds all SQL columns for group fields. var Columns = []string{ FieldID, + FieldTenantID, FieldName, } -// ForeignKeys holds the SQL foreign-keys that are owned by the "groups" -// table and are not defined as standalone fields in the schema. -var ForeignKeys = []string{ - "group_tenant", -} - var ( // UsersPrimaryKey and UsersColumn2 are the table columns denoting the // primary key for the users relation (M2M). @@ -62,11 +59,6 @@ func ValidColumn(column string) bool { return true } } - for i := range ForeignKeys { - if column == ForeignKeys[i] { - return true - } - } return false } diff --git a/examples/privacytenant/ent/group/where.go b/examples/privacytenant/ent/group/where.go index bab0ecc5c..a2af9f33d 100644 --- a/examples/privacytenant/ent/group/where.go +++ b/examples/privacytenant/ent/group/where.go @@ -95,6 +95,13 @@ func IDLTE(id int) predicate.Group { }) } +// TenantID applies equality check predicate on the "tenant_id" field. It's identical to TenantIDEQ. +func TenantID(v int) predicate.Group { + return predicate.Group(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldTenantID), v)) + }) +} + // Name applies equality check predicate on the "name" field. It's identical to NameEQ. func Name(v string) predicate.Group { return predicate.Group(func(s *sql.Selector) { @@ -102,6 +109,54 @@ func Name(v string) predicate.Group { }) } +// TenantIDEQ applies the EQ predicate on the "tenant_id" field. +func TenantIDEQ(v int) predicate.Group { + return predicate.Group(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldTenantID), v)) + }) +} + +// TenantIDNEQ applies the NEQ predicate on the "tenant_id" field. +func TenantIDNEQ(v int) predicate.Group { + return predicate.Group(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldTenantID), v)) + }) +} + +// TenantIDIn applies the In predicate on the "tenant_id" field. +func TenantIDIn(vs ...int) predicate.Group { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Group(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldTenantID), v...)) + }) +} + +// TenantIDNotIn applies the NotIn predicate on the "tenant_id" field. +func TenantIDNotIn(vs ...int) predicate.Group { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Group(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldTenantID), v...)) + }) +} + // NameEQ applies the EQ predicate on the "name" field. func NameEQ(v string) predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/examples/privacytenant/ent/group_create.go b/examples/privacytenant/ent/group_create.go index 893a41d1c..a9624ec00 100644 --- a/examples/privacytenant/ent/group_create.go +++ b/examples/privacytenant/ent/group_create.go @@ -25,6 +25,12 @@ type GroupCreate struct { hooks []Hook } +// SetTenantID sets the "tenant_id" field. +func (gc *GroupCreate) SetTenantID(i int) *GroupCreate { + gc.mutation.SetTenantID(i) + return gc +} + // SetName sets the "name" field. func (gc *GroupCreate) SetName(s string) *GroupCreate { gc.mutation.SetName(s) @@ -39,12 +45,6 @@ func (gc *GroupCreate) SetNillableName(s *string) *GroupCreate { return gc } -// SetTenantID sets the "tenant" edge to the Tenant entity by ID. -func (gc *GroupCreate) SetTenantID(id int) *GroupCreate { - gc.mutation.SetTenantID(id) - return gc -} - // SetTenant sets the "tenant" edge to the Tenant entity. func (gc *GroupCreate) SetTenant(t *Tenant) *GroupCreate { return gc.SetTenantID(t.ID) @@ -153,6 +153,9 @@ func (gc *GroupCreate) defaults() error { // check runs all checks and user-defined validators on the builder. func (gc *GroupCreate) check() error { + if _, ok := gc.mutation.TenantID(); !ok { + return &ValidationError{Name: "tenant_id", err: errors.New(`ent: missing required field "Group.tenant_id"`)} + } if _, ok := gc.mutation.Name(); !ok { return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Group.name"`)} } @@ -211,7 +214,7 @@ func (gc *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { for _, k := range nodes { edge.Target.Nodes = append(edge.Target.Nodes, k) } - _node.group_tenant = &nodes[0] + _node.TenantID = nodes[0] _spec.Edges = append(_spec.Edges, edge) } if nodes := gc.mutation.UsersIDs(); len(nodes) > 0 { diff --git a/examples/privacytenant/ent/group_query.go b/examples/privacytenant/ent/group_query.go index 563d9b491..916376993 100644 --- a/examples/privacytenant/ent/group_query.go +++ b/examples/privacytenant/ent/group_query.go @@ -34,7 +34,6 @@ type GroupQuery struct { // eager-loading edges. withTenant *TenantQuery withUsers *UserQuery - withFKs bool // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -333,12 +332,12 @@ func (gq *GroupQuery) WithUsers(opts ...func(*UserQuery)) *GroupQuery { // Example: // // var v []struct { -// Name string `json:"name,omitempty"` +// TenantID int `json:"tenant_id,omitempty"` // Count int `json:"count,omitempty"` // } // // client.Group.Query(). -// GroupBy(group.FieldName). +// GroupBy(group.FieldTenantID). // Aggregate(ent.Count()). // Scan(ctx, &v) // @@ -362,11 +361,11 @@ func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { // Example: // // var v []struct { -// Name string `json:"name,omitempty"` +// TenantID int `json:"tenant_id,omitempty"` // } // // client.Group.Query(). -// Select(group.FieldName). +// Select(group.FieldTenantID). // Scan(ctx, &v) // func (gq *GroupQuery) Select(fields ...string) *GroupSelect { @@ -402,19 +401,12 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { var ( nodes = []*Group{} - withFKs = gq.withFKs _spec = gq.querySpec() loadedTypes = [2]bool{ gq.withTenant != nil, gq.withUsers != nil, } ) - if gq.withTenant != nil { - withFKs = true - } - if withFKs { - _spec.Node.Columns = append(_spec.Node.Columns, group.ForeignKeys...) - } _spec.ScanValues = func(columns []string) ([]interface{}, error) { return (*Group).scanValues(nil, columns) } @@ -438,10 +430,7 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, ids := make([]int, 0, len(nodes)) nodeids := make(map[int][]*Group) for i := range nodes { - if nodes[i].group_tenant == nil { - continue - } - fk := *nodes[i].group_tenant + fk := nodes[i].TenantID if _, ok := nodeids[fk]; !ok { ids = append(ids, fk) } @@ -455,7 +444,7 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, for _, n := range neighbors { nodes, ok := nodeids[n.ID] if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "group_tenant" returned %v`, n.ID) + return nil, fmt.Errorf(`unexpected foreign-key "tenant_id" returned %v`, n.ID) } for i := range nodes { nodes[i].Edges.Tenant = n diff --git a/examples/privacytenant/ent/group_update.go b/examples/privacytenant/ent/group_update.go index 3fc823c6e..0b16b523e 100644 --- a/examples/privacytenant/ent/group_update.go +++ b/examples/privacytenant/ent/group_update.go @@ -33,6 +33,12 @@ func (gu *GroupUpdate) Where(ps ...predicate.Group) *GroupUpdate { return gu } +// SetTenantID sets the "tenant_id" field. +func (gu *GroupUpdate) SetTenantID(i int) *GroupUpdate { + gu.mutation.SetTenantID(i) + return gu +} + // SetName sets the "name" field. func (gu *GroupUpdate) SetName(s string) *GroupUpdate { gu.mutation.SetName(s) @@ -47,12 +53,6 @@ func (gu *GroupUpdate) SetNillableName(s *string) *GroupUpdate { return gu } -// SetTenantID sets the "tenant" edge to the Tenant entity by ID. -func (gu *GroupUpdate) SetTenantID(id int) *GroupUpdate { - gu.mutation.SetTenantID(id) - return gu -} - // SetTenant sets the "tenant" edge to the Tenant entity. func (gu *GroupUpdate) SetTenant(t *Tenant) *GroupUpdate { return gu.SetTenantID(t.ID) @@ -306,6 +306,12 @@ type GroupUpdateOne struct { mutation *GroupMutation } +// SetTenantID sets the "tenant_id" field. +func (guo *GroupUpdateOne) SetTenantID(i int) *GroupUpdateOne { + guo.mutation.SetTenantID(i) + return guo +} + // SetName sets the "name" field. func (guo *GroupUpdateOne) SetName(s string) *GroupUpdateOne { guo.mutation.SetName(s) @@ -320,12 +326,6 @@ func (guo *GroupUpdateOne) SetNillableName(s *string) *GroupUpdateOne { return guo } -// SetTenantID sets the "tenant" edge to the Tenant entity by ID. -func (guo *GroupUpdateOne) SetTenantID(id int) *GroupUpdateOne { - guo.mutation.SetTenantID(id) - return guo -} - // SetTenant sets the "tenant" edge to the Tenant entity. func (guo *GroupUpdateOne) SetTenant(t *Tenant) *GroupUpdateOne { return guo.SetTenantID(t.ID) diff --git a/examples/privacytenant/ent/internal/schema.go b/examples/privacytenant/ent/internal/schema.go index f1ba62b29..abfe3d4a0 100644 --- a/examples/privacytenant/ent/internal/schema.go +++ b/examples/privacytenant/ent/internal/schema.go @@ -10,4 +10,4 @@ // Package internal holds a loadable version of the latest schema. package internal -const Schema = `{"Schema":"entgo.io/ent/examples/privacytenant/ent/schema","Package":"entgo.io/ent/examples/privacytenant/ent","Schemas":[{"name":"Group","config":{"Table":""},"edges":[{"name":"tenant","type":"Tenant","unique":true,"required":true},{"name":"users","type":"User","ref_name":"groups","inverse":true}],"fields":[{"name":"name","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"default":true,"default_value":"Unknown","default_kind":24,"position":{"Index":0,"MixedIn":false,"MixinIndex":0}}],"policy":[{"Index":0,"MixedIn":true,"MixinIndex":0},{"Index":0,"MixedIn":true,"MixinIndex":1},{"Index":0,"MixedIn":false,"MixinIndex":0}]},{"name":"Tenant","config":{"Table":""},"fields":[{"name":"name","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"validators":1,"position":{"Index":0,"MixedIn":false,"MixinIndex":0}}],"policy":[{"Index":0,"MixedIn":true,"MixinIndex":0},{"Index":0,"MixedIn":false,"MixinIndex":0}]},{"name":"User","config":{"Table":""},"edges":[{"name":"tenant","type":"Tenant","unique":true,"required":true},{"name":"groups","type":"Group"}],"fields":[{"name":"name","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"default":true,"default_value":"Unknown","default_kind":24,"position":{"Index":0,"MixedIn":false,"MixinIndex":0}},{"name":"foods","type":{"Type":3,"Ident":"[]string","PkgPath":"","PkgName":"","Nillable":true,"RType":{"Name":"","Ident":"[]string","Kind":23,"PkgPath":"","Methods":{}}},"optional":true,"position":{"Index":1,"MixedIn":false,"MixinIndex":0}}],"policy":[{"Index":0,"MixedIn":true,"MixinIndex":0},{"Index":0,"MixedIn":true,"MixinIndex":1}]}],"Features":["privacy","entql","schema/snapshot"]}` +const Schema = `{"Schema":"entgo.io/ent/examples/privacytenant/ent/schema","Package":"entgo.io/ent/examples/privacytenant/ent","Schemas":[{"name":"Group","config":{"Table":""},"edges":[{"name":"tenant","type":"Tenant","field":"tenant_id","unique":true,"required":true},{"name":"users","type":"User","ref_name":"groups","inverse":true}],"fields":[{"name":"tenant_id","type":{"Type":12,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"position":{"Index":0,"MixedIn":true,"MixinIndex":1}},{"name":"name","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"default":true,"default_value":"Unknown","default_kind":24,"position":{"Index":0,"MixedIn":false,"MixinIndex":0}}],"policy":[{"Index":0,"MixedIn":true,"MixinIndex":0},{"Index":0,"MixedIn":true,"MixinIndex":1},{"Index":0,"MixedIn":false,"MixinIndex":0}]},{"name":"Tenant","config":{"Table":""},"fields":[{"name":"name","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"validators":1,"position":{"Index":0,"MixedIn":false,"MixinIndex":0}}],"policy":[{"Index":0,"MixedIn":true,"MixinIndex":0},{"Index":0,"MixedIn":false,"MixinIndex":0}]},{"name":"User","config":{"Table":""},"edges":[{"name":"tenant","type":"Tenant","field":"tenant_id","unique":true,"required":true},{"name":"groups","type":"Group"}],"fields":[{"name":"tenant_id","type":{"Type":12,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"position":{"Index":0,"MixedIn":true,"MixinIndex":1}},{"name":"name","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"default":true,"default_value":"Unknown","default_kind":24,"position":{"Index":0,"MixedIn":false,"MixinIndex":0}},{"name":"foods","type":{"Type":3,"Ident":"[]string","PkgPath":"","PkgName":"","Nillable":true,"RType":{"Name":"","Ident":"[]string","Kind":23,"PkgPath":"","Methods":{}}},"optional":true,"position":{"Index":1,"MixedIn":false,"MixinIndex":0}}],"policy":[{"Index":0,"MixedIn":true,"MixinIndex":0},{"Index":0,"MixedIn":true,"MixinIndex":1}]}],"Features":["privacy","entql","schema/snapshot"]}` diff --git a/examples/privacytenant/ent/migrate/schema.go b/examples/privacytenant/ent/migrate/schema.go index 8b07d0611..6219ab2bf 100644 --- a/examples/privacytenant/ent/migrate/schema.go +++ b/examples/privacytenant/ent/migrate/schema.go @@ -16,7 +16,7 @@ var ( GroupsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Default: "Unknown"}, - {Name: "group_tenant", Type: field.TypeInt}, + {Name: "tenant_id", Type: field.TypeInt}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ @@ -48,7 +48,7 @@ var ( {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Default: "Unknown"}, {Name: "foods", Type: field.TypeJSON, Nullable: true}, - {Name: "user_tenant", Type: field.TypeInt}, + {Name: "tenant_id", Type: field.TypeInt}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ diff --git a/examples/privacytenant/ent/mutation.go b/examples/privacytenant/ent/mutation.go index adf1feec0..8ac5aedaa 100644 --- a/examples/privacytenant/ent/mutation.go +++ b/examples/privacytenant/ent/mutation.go @@ -150,6 +150,42 @@ func (m *GroupMutation) IDs(ctx context.Context) ([]int, error) { } } +// SetTenantID sets the "tenant_id" field. +func (m *GroupMutation) SetTenantID(i int) { + m.tenant = &i +} + +// TenantID returns the value of the "tenant_id" field in the mutation. +func (m *GroupMutation) TenantID() (r int, exists bool) { + v := m.tenant + if v == nil { + return + } + return *v, true +} + +// OldTenantID returns the old "tenant_id" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldTenantID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTenantID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTenantID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTenantID: %w", err) + } + return oldValue.TenantID, nil +} + +// ResetTenantID resets all changes to the "tenant_id" field. +func (m *GroupMutation) ResetTenantID() { + m.tenant = nil +} + // SetName sets the "name" field. func (m *GroupMutation) SetName(s string) { m.name = &s @@ -186,11 +222,6 @@ func (m *GroupMutation) ResetName() { m.name = nil } -// SetTenantID sets the "tenant" edge to the Tenant entity by id. -func (m *GroupMutation) SetTenantID(id int) { - m.tenant = &id -} - // ClearTenant clears the "tenant" edge to the Tenant entity. func (m *GroupMutation) ClearTenant() { m.clearedtenant = true @@ -201,14 +232,6 @@ func (m *GroupMutation) TenantCleared() bool { return m.clearedtenant } -// TenantID returns the "tenant" edge ID in the mutation. -func (m *GroupMutation) TenantID() (id int, exists bool) { - if m.tenant != nil { - return *m.tenant, true - } - return -} - // TenantIDs returns the "tenant" edge IDs in the mutation. // Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use // TenantID instead. It exists only for internal usage by the builders. @@ -298,7 +321,10 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 1) + fields := make([]string, 0, 2) + if m.tenant != nil { + fields = append(fields, group.FieldTenantID) + } if m.name != nil { fields = append(fields, group.FieldName) } @@ -310,6 +336,8 @@ func (m *GroupMutation) Fields() []string { // schema. func (m *GroupMutation) Field(name string) (ent.Value, bool) { switch name { + case group.FieldTenantID: + return m.TenantID() case group.FieldName: return m.Name() } @@ -321,6 +349,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { // database failed. func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { + case group.FieldTenantID: + return m.OldTenantID(ctx) case group.FieldName: return m.OldName(ctx) } @@ -332,6 +362,13 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e // type. func (m *GroupMutation) SetField(name string, value ent.Value) error { switch name { + case group.FieldTenantID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTenantID(v) + return nil case group.FieldName: v, ok := value.(string) if !ok { @@ -346,13 +383,16 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. func (m *GroupMutation) AddedFields() []string { - return nil + var fields []string + return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } return nil, false } @@ -388,6 +428,9 @@ func (m *GroupMutation) ClearField(name string) error { // It returns an error if the field is not defined in the schema. func (m *GroupMutation) ResetField(name string) error { switch name { + case group.FieldTenantID: + m.ResetTenantID() + return nil case group.FieldName: m.ResetName() return nil @@ -925,6 +968,42 @@ func (m *UserMutation) IDs(ctx context.Context) ([]int, error) { } } +// SetTenantID sets the "tenant_id" field. +func (m *UserMutation) SetTenantID(i int) { + m.tenant = &i +} + +// TenantID returns the value of the "tenant_id" field in the mutation. +func (m *UserMutation) TenantID() (r int, exists bool) { + v := m.tenant + if v == nil { + return + } + return *v, true +} + +// OldTenantID returns the old "tenant_id" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldTenantID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTenantID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTenantID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTenantID: %w", err) + } + return oldValue.TenantID, nil +} + +// ResetTenantID resets all changes to the "tenant_id" field. +func (m *UserMutation) ResetTenantID() { + m.tenant = nil +} + // SetName sets the "name" field. func (m *UserMutation) SetName(s string) { m.name = &s @@ -1010,11 +1089,6 @@ func (m *UserMutation) ResetFoods() { delete(m.clearedFields, user.FieldFoods) } -// SetTenantID sets the "tenant" edge to the Tenant entity by id. -func (m *UserMutation) SetTenantID(id int) { - m.tenant = &id -} - // ClearTenant clears the "tenant" edge to the Tenant entity. func (m *UserMutation) ClearTenant() { m.clearedtenant = true @@ -1025,14 +1099,6 @@ func (m *UserMutation) TenantCleared() bool { return m.clearedtenant } -// TenantID returns the "tenant" edge ID in the mutation. -func (m *UserMutation) TenantID() (id int, exists bool) { - if m.tenant != nil { - return *m.tenant, true - } - return -} - // TenantIDs returns the "tenant" edge IDs in the mutation. // Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use // TenantID instead. It exists only for internal usage by the builders. @@ -1122,7 +1188,10 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 2) + fields := make([]string, 0, 3) + if m.tenant != nil { + fields = append(fields, user.FieldTenantID) + } if m.name != nil { fields = append(fields, user.FieldName) } @@ -1137,6 +1206,8 @@ func (m *UserMutation) Fields() []string { // schema. func (m *UserMutation) Field(name string) (ent.Value, bool) { switch name { + case user.FieldTenantID: + return m.TenantID() case user.FieldName: return m.Name() case user.FieldFoods: @@ -1150,6 +1221,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { // database failed. func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { + case user.FieldTenantID: + return m.OldTenantID(ctx) case user.FieldName: return m.OldName(ctx) case user.FieldFoods: @@ -1163,6 +1236,13 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er // type. func (m *UserMutation) SetField(name string, value ent.Value) error { switch name { + case user.FieldTenantID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTenantID(v) + return nil case user.FieldName: v, ok := value.(string) if !ok { @@ -1184,13 +1264,16 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. func (m *UserMutation) AddedFields() []string { - return nil + var fields []string + return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. func (m *UserMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } return nil, false } @@ -1235,6 +1318,9 @@ func (m *UserMutation) ClearField(name string) error { // It returns an error if the field is not defined in the schema. func (m *UserMutation) ResetField(name string) error { switch name { + case user.FieldTenantID: + m.ResetTenantID() + return nil case user.FieldName: m.ResetName() return nil diff --git a/examples/privacytenant/ent/privacy/privacy.go b/examples/privacytenant/ent/privacy/privacy.go index 2efc12ae8..c43295aba 100644 --- a/examples/privacytenant/ent/privacy/privacy.go +++ b/examples/privacytenant/ent/privacy/privacy.go @@ -18,11 +18,11 @@ import ( var ( // Allow may be returned by rules to indicate that the policy - // evaluation should terminate with an allow decision. + // evaluation should terminate with allow decision. Allow = privacy.Allow // Deny may be returned by rules to indicate that the policy - // evaluation should terminate with an deny decision. + // evaluation should terminate with deny decision. Deny = privacy.Deny // Skip may be returned by rules to indicate that the policy @@ -57,11 +57,20 @@ func DecisionFromContext(ctx context.Context) (error, bool) { } type ( + // Policy groups query and mutation policies. + Policy = privacy.Policy + // QueryRule defines the interface deciding whether a // query is allowed and optionally modify it. QueryRule = privacy.QueryRule // QueryPolicy combines multiple query rules into a single policy. QueryPolicy = privacy.QueryPolicy + + // MutationRule defines the interface which decides whether a + // mutation is allowed and optionally modifies it. + MutationRule = privacy.MutationRule + // MutationPolicy combines multiple mutation rules into a single policy. + MutationPolicy = privacy.MutationPolicy ) // QueryRuleFunc type is an adapter to allow the use of @@ -73,14 +82,6 @@ func (f QueryRuleFunc) EvalQuery(ctx context.Context, q ent.Query) error { return f(ctx, q) } -type ( - // MutationRule defines the interface which decides whether a - // mutation is allowed and optionally modifies it. - MutationRule = privacy.MutationRule - // MutationPolicy combines multiple mutation rules into a single policy. - MutationPolicy = privacy.MutationPolicy -) - // MutationRuleFunc type is an adapter which allows the use of // ordinary functions as mutation rules. type MutationRuleFunc func(context.Context, ent.Mutation) error @@ -90,22 +91,6 @@ func (f MutationRuleFunc) EvalMutation(ctx context.Context, m ent.Mutation) erro return f(ctx, m) } -// Policy groups query and mutation policies. -type Policy struct { - Query QueryPolicy - Mutation MutationPolicy -} - -// EvalQuery forwards evaluation to query a policy. -func (policy Policy) EvalQuery(ctx context.Context, q ent.Query) error { - return policy.Query.EvalQuery(ctx, q) -} - -// EvalMutation forwards evaluation to mutate a policy. -func (policy Policy) EvalMutation(ctx context.Context, m ent.Mutation) error { - return policy.Mutation.EvalMutation(ctx, m) -} - // QueryMutationRule is an interface which groups query and mutation rules. type QueryMutationRule interface { QueryRule diff --git a/examples/privacytenant/ent/schema/group.go b/examples/privacytenant/ent/schema/group.go index 3e165b094..fbdf0dbc0 100644 --- a/examples/privacytenant/ent/schema/group.go +++ b/examples/privacytenant/ent/schema/group.go @@ -12,7 +12,7 @@ import ( "entgo.io/ent/schema/field" ) -// User holds the schema definition for the Group entity. +// Group holds the schema definition for the Group entity. type Group struct { ent.Schema } @@ -51,12 +51,6 @@ func (Group) Policy() ent.Policy { rule.DenyMismatchedTenants(), ent.OpCreate, ), - // Limit the FilterTenantRule only for - // UpdateOne and DeleteOne operations. - privacy.OnMutationOperation( - rule.FilterTenantRule(), - ent.OpUpdateOne|ent.OpDeleteOne, - ), }, } } diff --git a/examples/privacytenant/ent/schema/mixin.go b/examples/privacytenant/ent/schema/mixin.go index 287e25c08..fb9a7ba11 100644 --- a/examples/privacytenant/ent/schema/mixin.go +++ b/examples/privacytenant/ent/schema/mixin.go @@ -9,6 +9,7 @@ import ( "entgo.io/ent/examples/privacytenant/ent/privacy" "entgo.io/ent/examples/privacytenant/rule" "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" "entgo.io/ent/schema/mixin" ) @@ -20,10 +21,14 @@ type BaseMixin struct { // Policy defines the privacy policy of the BaseMixin. func (BaseMixin) Policy() ent.Policy { return privacy.Policy{ - Mutation: privacy.MutationPolicy{ - rule.DenyIfNoViewer(), - }, Query: privacy.QueryPolicy{ + // Deny any operation in case there is no "viewer context". + rule.DenyIfNoViewer(), + // Allow admins to query any information. + rule.AllowIfAdmin(), + }, + Mutation: privacy.MutationPolicy{ + // Deny any operation in case there is no "viewer context". rule.DenyIfNoViewer(), }, } @@ -34,10 +39,18 @@ type TenantMixin struct { mixin.Schema } +// Fields for all schemas that embed TenantMixin. +func (TenantMixin) Fields() []ent.Field { + return []ent.Field{ + field.Int("tenant_id"), + } +} + // Edges for all schemas that embed TenantMixin. func (TenantMixin) Edges() []ent.Edge { return []ent.Edge{ edge.To("tenant", Tenant.Type). + Field("tenant_id"). Unique(). Required(), } @@ -45,12 +58,5 @@ func (TenantMixin) Edges() []ent.Edge { // Policy for all schemas that embed TenantMixin. func (TenantMixin) Policy() ent.Policy { - return privacy.Policy{ - Query: privacy.QueryPolicy{ - rule.AllowIfAdmin(), - // Filter out entities that are not connected to the tenant. - // If the viewer is admin, this policy rule is skipped above. - rule.FilterTenantRule(), - }, - } + return rule.FilterTenantRule() } diff --git a/examples/privacytenant/ent/schema/user.go b/examples/privacytenant/ent/schema/user.go index 0a187d145..e8841118d 100644 --- a/examples/privacytenant/ent/schema/user.go +++ b/examples/privacytenant/ent/schema/user.go @@ -42,6 +42,6 @@ func (User) Edges() []ent.Edge { // Policy defines the privacy policy of the User. func (User) Policy() ent.Policy { - // Privacy policy defined in the TenantMixin. + // Privacy policy defined in the BaseMixin and TenantMixin. return nil } diff --git a/examples/privacytenant/ent/template/privacy.tmpl b/examples/privacytenant/ent/template/privacy.tmpl new file mode 100644 index 000000000..fae6f54b8 --- /dev/null +++ b/examples/privacytenant/ent/template/privacy.tmpl @@ -0,0 +1,29 @@ +{{/* The line below tells Intellij/GoLand to enable the autocompletion based on the *gen.Graph type. */}} +{{/* gotype: entgo.io/ent/entc/gen.Graph */}} + +{{/* An example privacy rule that is appeneded to the generated privacy package. For more info: https://entgo.io/docs/privacy#template-based */}} +{{ define "privacy/additional/privacy" }} +// FilterTenantRule is a query/mutation rule that filters out entities that are not in the tenant. +func FilterTenantRule() QueryMutationRule { + // TenantsFilter is an interface to wrap WhereTenantID() + // predicate that is used by both `Group` and `User` schemas. + type TenantsFilter interface { + WhereTenantID(entql.IntP) + } + return FilterFunc(func(ctx context.Context, f Filter) error { + view := viewer.FromContext(ctx) + tid, ok := view.Tenant() + if !ok { + return Denyf("missing tenant information in viewer") + } + tf, ok := f.(TenantsFilter) + if !ok { + return Denyf("unexpected filter type %T", f) + } + // Make sure that a tenant reads only entities that have an edge to it. + tf.WhereTenantID(entql.IntEQ(tid)) + // Skip to the next privacy rule (equivalent to return nil). + return Skip + }) +} +{{ end }} \ No newline at end of file diff --git a/examples/privacytenant/ent/user.go b/examples/privacytenant/ent/user.go index 0c0aca715..45b32e009 100644 --- a/examples/privacytenant/ent/user.go +++ b/examples/privacytenant/ent/user.go @@ -21,14 +21,15 @@ type User struct { config `json:"-"` // ID of the ent. ID int `json:"id,omitempty"` + // TenantID holds the value of the "tenant_id" field. + TenantID int `json:"tenant_id,omitempty"` // Name holds the value of the "name" field. Name string `json:"name,omitempty"` // Foods holds the value of the "foods" field. Foods []string `json:"foods,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. - Edges UserEdges `json:"edges"` - user_tenant *int + Edges UserEdges `json:"edges"` } // UserEdges holds the relations/edges for other nodes in the graph. @@ -72,12 +73,10 @@ func (*User) scanValues(columns []string) ([]interface{}, error) { switch columns[i] { case user.FieldFoods: values[i] = new([]byte) - case user.FieldID: + case user.FieldID, user.FieldTenantID: values[i] = new(sql.NullInt64) case user.FieldName: values[i] = new(sql.NullString) - case user.ForeignKeys[0]: // user_tenant - values[i] = new(sql.NullInt64) default: return nil, fmt.Errorf("unexpected column %q for type User", columns[i]) } @@ -99,6 +98,12 @@ func (u *User) assignValues(columns []string, values []interface{}) error { return fmt.Errorf("unexpected type %T for field id", value) } u.ID = int(value.Int64) + case user.FieldTenantID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field tenant_id", values[i]) + } else if value.Valid { + u.TenantID = int(value.Int64) + } case user.FieldName: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field name", values[i]) @@ -113,13 +118,6 @@ func (u *User) assignValues(columns []string, values []interface{}) error { return fmt.Errorf("unmarshal field foods: %w", err) } } - case user.ForeignKeys[0]: - if value, ok := values[i].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for edge-field user_tenant", value) - } else if value.Valid { - u.user_tenant = new(int) - *u.user_tenant = int(value.Int64) - } } } return nil @@ -158,6 +156,9 @@ func (u *User) String() string { var builder strings.Builder builder.WriteString("User(") builder.WriteString(fmt.Sprintf("id=%v, ", u.ID)) + builder.WriteString("tenant_id=") + builder.WriteString(fmt.Sprintf("%v", u.TenantID)) + builder.WriteString(", ") builder.WriteString("name=") builder.WriteString(u.Name) builder.WriteString(", ") diff --git a/examples/privacytenant/ent/user/user.go b/examples/privacytenant/ent/user/user.go index 8f949290d..a18df5c37 100644 --- a/examples/privacytenant/ent/user/user.go +++ b/examples/privacytenant/ent/user/user.go @@ -15,6 +15,8 @@ const ( Label = "user" // FieldID holds the string denoting the id field in the database. FieldID = "id" + // FieldTenantID holds the string denoting the tenant_id field in the database. + FieldTenantID = "tenant_id" // FieldName holds the string denoting the name field in the database. FieldName = "name" // FieldFoods holds the string denoting the foods field in the database. @@ -31,7 +33,7 @@ const ( // It exists in this package in order to avoid circular dependency with the "tenant" package. TenantInverseTable = "tenants" // TenantColumn is the table column denoting the tenant relation/edge. - TenantColumn = "user_tenant" + TenantColumn = "tenant_id" // GroupsTable is the table that holds the groups relation/edge. The primary key declared below. GroupsTable = "user_groups" // GroupsInverseTable is the table name for the Group entity. @@ -42,16 +44,11 @@ const ( // Columns holds all SQL columns for user fields. var Columns = []string{ FieldID, + FieldTenantID, FieldName, FieldFoods, } -// ForeignKeys holds the SQL foreign-keys that are owned by the "users" -// table and are not defined as standalone fields in the schema. -var ForeignKeys = []string{ - "user_tenant", -} - var ( // GroupsPrimaryKey and GroupsColumn2 are the table columns denoting the // primary key for the groups relation (M2M). @@ -65,11 +62,6 @@ func ValidColumn(column string) bool { return true } } - for i := range ForeignKeys { - if column == ForeignKeys[i] { - return true - } - } return false } diff --git a/examples/privacytenant/ent/user/where.go b/examples/privacytenant/ent/user/where.go index 6a9f37f50..0cfa3dbd5 100644 --- a/examples/privacytenant/ent/user/where.go +++ b/examples/privacytenant/ent/user/where.go @@ -95,6 +95,13 @@ func IDLTE(id int) predicate.User { }) } +// TenantID applies equality check predicate on the "tenant_id" field. It's identical to TenantIDEQ. +func TenantID(v int) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldTenantID), v)) + }) +} + // Name applies equality check predicate on the "name" field. It's identical to NameEQ. func Name(v string) predicate.User { return predicate.User(func(s *sql.Selector) { @@ -102,6 +109,54 @@ func Name(v string) predicate.User { }) } +// TenantIDEQ applies the EQ predicate on the "tenant_id" field. +func TenantIDEQ(v int) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldTenantID), v)) + }) +} + +// TenantIDNEQ applies the NEQ predicate on the "tenant_id" field. +func TenantIDNEQ(v int) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldTenantID), v)) + }) +} + +// TenantIDIn applies the In predicate on the "tenant_id" field. +func TenantIDIn(vs ...int) predicate.User { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldTenantID), v...)) + }) +} + +// TenantIDNotIn applies the NotIn predicate on the "tenant_id" field. +func TenantIDNotIn(vs ...int) predicate.User { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldTenantID), v...)) + }) +} + // NameEQ applies the EQ predicate on the "name" field. func NameEQ(v string) predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/examples/privacytenant/ent/user_create.go b/examples/privacytenant/ent/user_create.go index eda5afd4d..848127896 100644 --- a/examples/privacytenant/ent/user_create.go +++ b/examples/privacytenant/ent/user_create.go @@ -25,6 +25,12 @@ type UserCreate struct { hooks []Hook } +// SetTenantID sets the "tenant_id" field. +func (uc *UserCreate) SetTenantID(i int) *UserCreate { + uc.mutation.SetTenantID(i) + return uc +} + // SetName sets the "name" field. func (uc *UserCreate) SetName(s string) *UserCreate { uc.mutation.SetName(s) @@ -45,12 +51,6 @@ func (uc *UserCreate) SetFoods(s []string) *UserCreate { return uc } -// SetTenantID sets the "tenant" edge to the Tenant entity by ID. -func (uc *UserCreate) SetTenantID(id int) *UserCreate { - uc.mutation.SetTenantID(id) - return uc -} - // SetTenant sets the "tenant" edge to the Tenant entity. func (uc *UserCreate) SetTenant(t *Tenant) *UserCreate { return uc.SetTenantID(t.ID) @@ -159,6 +159,9 @@ func (uc *UserCreate) defaults() error { // check runs all checks and user-defined validators on the builder. func (uc *UserCreate) check() error { + if _, ok := uc.mutation.TenantID(); !ok { + return &ValidationError{Name: "tenant_id", err: errors.New(`ent: missing required field "User.tenant_id"`)} + } if _, ok := uc.mutation.Name(); !ok { return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "User.name"`)} } @@ -225,7 +228,7 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { for _, k := range nodes { edge.Target.Nodes = append(edge.Target.Nodes, k) } - _node.user_tenant = &nodes[0] + _node.TenantID = nodes[0] _spec.Edges = append(_spec.Edges, edge) } if nodes := uc.mutation.GroupsIDs(); len(nodes) > 0 { diff --git a/examples/privacytenant/ent/user_query.go b/examples/privacytenant/ent/user_query.go index e04f8e700..2a9b7502d 100644 --- a/examples/privacytenant/ent/user_query.go +++ b/examples/privacytenant/ent/user_query.go @@ -34,7 +34,6 @@ type UserQuery struct { // eager-loading edges. withTenant *TenantQuery withGroups *GroupQuery - withFKs bool // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -333,12 +332,12 @@ func (uq *UserQuery) WithGroups(opts ...func(*GroupQuery)) *UserQuery { // Example: // // var v []struct { -// Name string `json:"name,omitempty"` +// TenantID int `json:"tenant_id,omitempty"` // Count int `json:"count,omitempty"` // } // // client.User.Query(). -// GroupBy(user.FieldName). +// GroupBy(user.FieldTenantID). // Aggregate(ent.Count()). // Scan(ctx, &v) // @@ -362,11 +361,11 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Example: // // var v []struct { -// Name string `json:"name,omitempty"` +// TenantID int `json:"tenant_id,omitempty"` // } // // client.User.Query(). -// Select(user.FieldName). +// Select(user.FieldTenantID). // Scan(ctx, &v) // func (uq *UserQuery) Select(fields ...string) *UserSelect { @@ -402,19 +401,12 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} - withFKs = uq.withFKs _spec = uq.querySpec() loadedTypes = [2]bool{ uq.withTenant != nil, uq.withGroups != nil, } ) - if uq.withTenant != nil { - withFKs = true - } - if withFKs { - _spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...) - } _spec.ScanValues = func(columns []string) ([]interface{}, error) { return (*User).scanValues(nil, columns) } @@ -438,10 +430,7 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e ids := make([]int, 0, len(nodes)) nodeids := make(map[int][]*User) for i := range nodes { - if nodes[i].user_tenant == nil { - continue - } - fk := *nodes[i].user_tenant + fk := nodes[i].TenantID if _, ok := nodeids[fk]; !ok { ids = append(ids, fk) } @@ -455,7 +444,7 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e for _, n := range neighbors { nodes, ok := nodeids[n.ID] if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_tenant" returned %v`, n.ID) + return nil, fmt.Errorf(`unexpected foreign-key "tenant_id" returned %v`, n.ID) } for i := range nodes { nodes[i].Edges.Tenant = n diff --git a/examples/privacytenant/ent/user_update.go b/examples/privacytenant/ent/user_update.go index 02043a40e..daeccb896 100644 --- a/examples/privacytenant/ent/user_update.go +++ b/examples/privacytenant/ent/user_update.go @@ -33,6 +33,12 @@ func (uu *UserUpdate) Where(ps ...predicate.User) *UserUpdate { return uu } +// SetTenantID sets the "tenant_id" field. +func (uu *UserUpdate) SetTenantID(i int) *UserUpdate { + uu.mutation.SetTenantID(i) + return uu +} + // SetName sets the "name" field. func (uu *UserUpdate) SetName(s string) *UserUpdate { uu.mutation.SetName(s) @@ -59,12 +65,6 @@ func (uu *UserUpdate) ClearFoods() *UserUpdate { return uu } -// SetTenantID sets the "tenant" edge to the Tenant entity by ID. -func (uu *UserUpdate) SetTenantID(id int) *UserUpdate { - uu.mutation.SetTenantID(id) - return uu -} - // SetTenant sets the "tenant" edge to the Tenant entity. func (uu *UserUpdate) SetTenant(t *Tenant) *UserUpdate { return uu.SetTenantID(t.ID) @@ -331,6 +331,12 @@ type UserUpdateOne struct { mutation *UserMutation } +// SetTenantID sets the "tenant_id" field. +func (uuo *UserUpdateOne) SetTenantID(i int) *UserUpdateOne { + uuo.mutation.SetTenantID(i) + return uuo +} + // SetName sets the "name" field. func (uuo *UserUpdateOne) SetName(s string) *UserUpdateOne { uuo.mutation.SetName(s) @@ -357,12 +363,6 @@ func (uuo *UserUpdateOne) ClearFoods() *UserUpdateOne { return uuo } -// SetTenantID sets the "tenant" edge to the Tenant entity by ID. -func (uuo *UserUpdateOne) SetTenantID(id int) *UserUpdateOne { - uuo.mutation.SetTenantID(id) - return uuo -} - // SetTenant sets the "tenant" edge to the Tenant entity. func (uuo *UserUpdateOne) SetTenant(t *Tenant) *UserUpdateOne { return uuo.SetTenantID(t.ID) diff --git a/examples/privacytenant/example_test.go b/examples/privacytenant/example_test.go index 971d5019b..159911c27 100644 --- a/examples/privacytenant/example_test.go +++ b/examples/privacytenant/example_test.go @@ -13,105 +13,196 @@ import ( "entgo.io/ent/examples/privacytenant/ent" "entgo.io/ent/examples/privacytenant/ent/privacy" _ "entgo.io/ent/examples/privacytenant/ent/runtime" + "entgo.io/ent/examples/privacytenant/ent/user" "entgo.io/ent/examples/privacytenant/viewer" _ "github.com/mattn/go-sqlite3" ) -func Example_PrivacyTenant() { - client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") - if err != nil { - log.Fatalf("failed opening connection to sqlite: %v", err) - } - defer client.Close() +func Example_CreateTenants() { ctx := context.Background() - // Run the auto migration tool. - if err := client.Schema.Create(ctx); err != nil { - log.Fatalf("failed creating schema resources: %v", err) - } - if err := Do(ctx, client); err != nil { - log.Fatal(err) - } - // Output: - // Tenant(id=1, name=GitHub) - // Tenant(id=2, name=GitLab) - // User(id=1, name=a8m, foods=[]) - // User(id=2, name=nati, foods=[]) - // Group(id=1, name=entgo.io) - // Group(id=1, name=entgo) -} + client := open(ctx) + defer client.Close() -func Do(ctx context.Context, client *ent.Client) error { - // Expect operation to fail, because viewer-context - // is missing (first mutation rule check). + // Expect operation to fail in case viewer-context is missing. + // First mutation privacy policy rule defined in BaseMixin. if err := client.Tenant.Create().Exec(ctx); !errors.Is(err, privacy.Deny) { - return fmt.Errorf("expect operation to fail, but got %w", err) + log.Fatal("expect tenant creation to fail, but got:", err) } - // Deny tenant creation if the viewer is not admin. + + // Expect operation to fail in case the ent.User in the viewer-context + // is not an admin user. Privacy policy defined in the Tenant schema. viewCtx := viewer.NewContext(ctx, viewer.UserViewer{Role: viewer.View}) if err := client.Tenant.Create().Exec(viewCtx); !errors.Is(err, privacy.Deny) { - return fmt.Errorf("expect operation to fail, but got %w", err) + log.Fatal("expect tenant creation to fail, but got:", err) } - // Apply the same operation with "Admin" role, expect it to pass. + + // Operation should pass successfully as the user in the viewer-context + // is an admin user. First mutation privacy policy in Tenant schema. adminCtx := viewer.NewContext(ctx, viewer.UserViewer{Role: viewer.Admin}) hub, err := client.Tenant.Create().SetName("GitHub").Save(adminCtx) if err != nil { - return fmt.Errorf("expect operation to pass, but got %w", err) + log.Fatal("expect tenant creation to pass, but got:", err) } fmt.Println(hub) + lab, err := client.Tenant.Create().SetName("GitLab").Save(adminCtx) if err != nil { - return fmt.Errorf("expect operation to pass, but got %w", err) + log.Fatal("expect tenant creation to pass, but got:", err) } fmt.Println(lab) - // Create 2 users connected to the 2 tenants we created above - hubUser := client.User.Create().SetName("a8m").SetTenant(hub).SaveX(adminCtx) - labUser := client.User.Create().SetName("nati").SetTenant(lab).SaveX(adminCtx) + // Output: + // Tenant(id=1, name=GitHub) + // Tenant(id=2, name=GitLab) +} +func Example_TenantView() { + ctx := context.Background() + client := open(ctx) + defer client.Close() + + // Operation should pass successfully as the user in the viewer-context + // is an admin user. First mutation privacy policy in Tenant schema. + adminCtx := viewer.NewContext(ctx, viewer.UserViewer{Role: viewer.Admin}) + hub := client.Tenant.Create().SetName("GitHub").SaveX(adminCtx) + lab := client.Tenant.Create().SetName("GitLab").SaveX(adminCtx) + + // Create 2 tenant-specific viewer contexts. hubView := viewer.NewContext(ctx, viewer.UserViewer{T: hub}) - out := client.User.Query().OnlyX(hubView) - // Expect that "GitHub" tenant to read only its users (i.e. a8m). - if out.ID != hubUser.ID { - return fmt.Errorf("expect result for user query, got %v", out) - } - fmt.Println(out) - labView := viewer.NewContext(ctx, viewer.UserViewer{T: lab}) - out = client.User.Query().OnlyX(labView) - // Expect that "GitLab" tenant to read only its users (i.e. nati). - if out.ID != labUser.ID { - return fmt.Errorf("expect result for user query, got %v", out) - } - fmt.Println(out) - // Expect operation to fail, because the DenyMismatchedTenants rule makes sure - // the group and the users are connected to the same tenant. - err = client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(labUser).Exec(adminCtx) - if !errors.Is(err, privacy.Deny) { - return fmt.Errorf("expect operation to fail, since user (nati) is not connected to the same tenant") + // Create 2 users in each tenant. + hubUsers := client.User.CreateBulk( + client.User.Create().SetName("a8m").SetTenant(hub), + client.User.Create().SetName("nati").SetTenant(hub), + ).SaveX(hubView) + fmt.Println(hubUsers) + + labUsers := client.User.CreateBulk( + client.User.Create().SetName("foo").SetTenant(lab), + client.User.Create().SetName("bar").SetTenant(lab), + ).SaveX(labView) + fmt.Println(labUsers) + + // Query users should fail in case viewer-context is missing. + if _, err := client.User.Query().Count(ctx); !errors.Is(err, privacy.Deny) { + log.Fatal("expect user query to fail, but got:", err) } - err = client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(labUser, hubUser).Exec(adminCtx) - if !errors.Is(err, privacy.Deny) { - return fmt.Errorf("expect operation to fail, since some users (nati) are not connected to the same tenant") + + // Ensure each tenant can see only its users. + // First and only rule in TenantMixin. + fmt.Println(client.User.Query().Select(user.FieldName).StringsX(hubView)) + fmt.Println(client.User.Query().CountX(hubView)) + fmt.Println(client.User.Query().Select(user.FieldName).StringsX(labView)) + fmt.Println(client.User.Query().CountX(labView)) + + // Expect admin users to see everything. First + // query privacy policy defined in BaseMixin. + fmt.Println(client.User.Query().CountX(adminCtx)) // 4 + + // Update operation with specific tenant-view should update + // only the tenant in the viewer-context. + client.User.Update().SetFoods([]string{"pizza"}).SaveX(hubView) + fmt.Println(client.User.Query().AllX(hubView)) + fmt.Println(client.User.Query().AllX(labView)) + + // Delete operation with specific tenant-view should delete + // only the tenant in the viewer-context. + client.User.Delete().ExecX(labView) + fmt.Println( + client.User.Query().CountX(hubView), // 2 + client.User.Query().CountX(labView), // 0 + ) + + // DeleteOne with wrong viewer-context is nop. + client.User.DeleteOne(hubUsers[0]).ExecX(labView) + fmt.Println(client.User.Query().CountX(hubView)) // 2 + + // Unlike queries, admin users are not allowed to mutate tenant specific data. + if err := client.User.DeleteOne(hubUsers[0]).Exec(adminCtx); !errors.Is(err, privacy.Deny) { + log.Fatal("expect user deletion to fail, but got:", err) } - entgo, err := client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(hubUser).Save(adminCtx) - if err != nil { - return fmt.Errorf("expect operation to pass, but got %w", err) + + // Output: + // [User(id=1, tenant_id=1, name=a8m, foods=[]) User(id=2, tenant_id=1, name=nati, foods=[])] + // [User(id=3, tenant_id=2, name=foo, foods=[]) User(id=4, tenant_id=2, name=bar, foods=[])] + // [a8m nati] + // 2 + // [foo bar] + // 2 + // 4 + // [User(id=1, tenant_id=1, name=a8m, foods=[pizza]) User(id=2, tenant_id=1, name=nati, foods=[pizza])] + // [User(id=3, tenant_id=2, name=foo, foods=[]) User(id=4, tenant_id=2, name=bar, foods=[])] + // 2 0 + // 2 +} + +func Example_DenyMismatchedTenants() { + ctx := context.Background() + client := open(ctx) + defer client.Close() + + // Operation should pass successfully as the user in the viewer-context + // is an admin user. First mutation privacy policy in Tenant schema. + adminCtx := viewer.NewContext(ctx, viewer.UserViewer{Role: viewer.Admin}) + hub := client.Tenant.Create().SetName("GitHub").SaveX(adminCtx) + lab := client.Tenant.Create().SetName("GitLab").SaveX(adminCtx) + + // Create 2 tenant-specific viewer contexts. + hubView := viewer.NewContext(ctx, viewer.UserViewer{T: hub}) + labView := viewer.NewContext(ctx, viewer.UserViewer{T: lab}) + + // Create 2 users in each tenant. + hubUsers := client.User.CreateBulk( + client.User.Create().SetName("a8m").SetTenant(hub), + client.User.Create().SetName("nati").SetTenant(hub), + ).SaveX(hubView) + fmt.Println(hubUsers) + + labUsers := client.User.CreateBulk( + client.User.Create().SetName("foo").SetTenant(lab), + client.User.Create().SetName("bar").SetTenant(lab), + ).SaveX(labView) + fmt.Println(labUsers) + + // Expect operation to fail as the DenyMismatchedTenants rule makes + // sure the group and the users are connected to the same tenant. + if err := client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(labUsers...).Exec(hubView); !errors.Is(err, privacy.Deny) { + log.Fatal("expect operation to fail, since labUsers are not connected to the same tenant") } + if err := client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(hubUsers[0], labUsers[0]).Exec(hubView); !errors.Is(err, privacy.Deny) { + log.Fatal("expect operation to fail, since labUsers[0] is not connected to the same tenant") + } + // Expect mutation to pass as all users belong to the same tenant as the group. + entgo := client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(hubUsers...).SaveX(hubView) fmt.Println(entgo) // Expect operation to fail, because the FilterTenantRule rule makes sure // that tenants can update and delete only their groups. - err = entgo.Update().SetName("fail.go").Exec(labView) - if !ent.IsNotFound(err) { - return fmt.Errorf("expect operation to fail, since the group (entgo) is managed by a different tenant (hub), but got %w", err) - } - entgo, err = entgo.Update().SetName("entgo").Save(hubView) - if err != nil { - return fmt.Errorf("expect operation to pass, but got %w", err) + if err := entgo.Update().SetName("fail.go").Exec(labView); !ent.IsNotFound(err) { + log.Fatal("expect operation to fail, since the group (entgo) is managed by a different tenant (hub), but got:", err) } + + // Operation should pass in case it was applied with the right viewer-context. + entgo = entgo.Update().SetName("entgo").SaveX(hubView) fmt.Println(entgo) - return nil + // Output: + // [User(id=1, tenant_id=1, name=a8m, foods=[]) User(id=2, tenant_id=1, name=nati, foods=[])] + // [User(id=3, tenant_id=2, name=foo, foods=[]) User(id=4, tenant_id=2, name=bar, foods=[])] + // Group(id=1, tenant_id=1, name=entgo.io) + // Group(id=1, tenant_id=1, name=entgo) +} + +func open(ctx context.Context) *ent.Client { + client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + if err != nil { + log.Fatalf("failed opening connection to sqlite: %v", err) + } + // Run the auto migration tool. + if err := client.Schema.Create(ctx); err != nil { + log.Fatalf("failed creating schema resources: %v", err) + } + return client } diff --git a/examples/privacytenant/rule/rule.go b/examples/privacytenant/rule/rule.go index e22c145df..c33c84b57 100644 --- a/examples/privacytenant/rule/rule.go +++ b/examples/privacytenant/rule/rule.go @@ -7,10 +7,9 @@ package rule import ( "context" + "entgo.io/ent/entql" "entgo.io/ent/examples/privacytenant/ent" - "entgo.io/ent/examples/privacytenant/ent/predicate" "entgo.io/ent/examples/privacytenant/ent/privacy" - "entgo.io/ent/examples/privacytenant/ent/tenant" "entgo.io/ent/examples/privacytenant/ent/user" "entgo.io/ent/examples/privacytenant/viewer" ) @@ -41,22 +40,23 @@ func AllowIfAdmin() privacy.QueryMutationRule { // FilterTenantRule is a query/mutation rule that filters out entities that are not in the tenant. func FilterTenantRule() privacy.QueryMutationRule { - // TenantsFilter is an interface to wrap WhereHasTenantWith() + // TenantsFilter is an interface to wrap WhereTenantID() // predicate that is used by both `Group` and `User` schemas. type TenantsFilter interface { - WhereHasTenantWith(...predicate.Tenant) + WhereTenantID(entql.IntP) } return privacy.FilterFunc(func(ctx context.Context, f privacy.Filter) error { view := viewer.FromContext(ctx) - if view.Tenant() == "" { + tid, ok := view.Tenant() + if !ok { return privacy.Denyf("missing tenant information in viewer") } tf, ok := f.(TenantsFilter) if !ok { return privacy.Denyf("unexpected filter type %T", f) } - // Make sure that a tenant reads only entities that has an edge to it. - tf.WhereHasTenantWith(tenant.Name(view.Tenant())) + // Make sure that a tenant reads only entities that have an edge to it. + tf.WhereTenantID(entql.IntEQ(tid)) // Skip to the next privacy rule (equivalent to return nil). return privacy.Skip }) @@ -75,14 +75,19 @@ func DenyMismatchedTenants() privacy.MutationRule { if len(users) == 0 { return privacy.Skip } - // Query the tenant-id of all users. Expect to have exact 1 result, - // and it matches the tenant-id of the group above. - id, err := m.Client().User.Query().Where(user.IDIn(users...)).QueryTenant().OnlyID(ctx) + // Query the tenant-ids of all attached users. Expect all users to be connected to the same tenant + // as the group. Note, we use privacy.DecisionContext to skip the FilterTenantRule defined above. + ids, err := m.Client().User.Query().Where(user.IDIn(users...)).Select(user.FieldTenantID).Ints(privacy.DecisionContext(ctx, privacy.Allow)) if err != nil { - return privacy.Denyf("querying the tenant-id %v", err) + return privacy.Denyf("querying the tenant-ids %v", err) } - if id != tid { - return privacy.Denyf("mismatch tenant-ids for group/users %d != %d", tid, id) + if len(ids) != len(users) { + return privacy.Denyf("one the attached users is not connected to a tenant %v", err) + } + for _, id := range ids { + if id != tid { + return privacy.Denyf("mismatch tenant-ids for group/users %d != %d", tid, id) + } } // Skip to the next privacy rule (equivalent to return nil). return privacy.Skip diff --git a/examples/privacytenant/viewer/viewer.go b/examples/privacytenant/viewer/viewer.go index 36137783e..7dd7e4030 100644 --- a/examples/privacytenant/viewer/viewer.go +++ b/examples/privacytenant/viewer/viewer.go @@ -22,8 +22,8 @@ const ( // Viewer describes the query/mutation viewer-context. type Viewer interface { - Admin() bool // If viewer is admin. - Tenant() string // Tenant name. + Admin() bool // If viewer is admin. + Tenant() (int, bool) // Tenant identifier. } // UserViewer describes a user-viewer. @@ -36,11 +36,11 @@ func (v UserViewer) Admin() bool { return v.Role&Admin != 0 } -func (v UserViewer) Tenant() string { +func (v UserViewer) Tenant() (int, bool) { if v.T != nil { - return v.T.Name + return v.T.ID, true } - return "" + return 0, false } type ctxKey struct{} diff --git a/privacy/privacy.go b/privacy/privacy.go index 87ac98b6d..6494a5d45 100644 --- a/privacy/privacy.go +++ b/privacy/privacy.go @@ -29,9 +29,6 @@ var ( ) type ( - // Policies combines multiple policies into a single policy. - Policies []ent.Policy - // QueryRule defines the interface deciding whether a // query is allowed and optionally modify it. QueryRule interface { @@ -49,10 +46,29 @@ type ( // MutationPolicy combines multiple mutation rules into a single policy. MutationPolicy []MutationRule + + // Policy groups query and mutation policies. + Policy struct { + Query QueryPolicy + Mutation MutationPolicy + } ) +// EvalQuery forwards evaluation to query a policy. +func (p Policy) EvalQuery(ctx context.Context, q ent.Query) error { + return p.Query.EvalQuery(ctx, q) +} + +// EvalMutation forwards evaluation to mutate a policy. +func (p Policy) EvalMutation(ctx context.Context, m ent.Mutation) error { + return p.Mutation.EvalMutation(ctx, m) +} + // NewPolicies creates an ent.Policy from list of mixin.Schema // and ent.Schema that implement the ent.Policy interface. +// +// Note that, this is a runtime function used by the ent generated +// code and should not be used in ent/schemas as a privacy rule. func NewPolicies(schemas ...interface{ Policy() ent.Policy }) ent.Policy { policies := make(Policies, 0, len(schemas)) for i := range schemas { @@ -63,6 +79,12 @@ func NewPolicies(schemas ...interface{ Policy() ent.Policy }) ent.Policy { return policies } +// Policies combines multiple policies into a single policy. +// +// Note that, this is a runtime type used by the ent generated +// code and should not be used in ent/schemas as a privacy rule. +type Policies []ent.Policy + // EvalQuery evaluates the query policies. If the Allow error is returned // from one of the policies, it stops the evaluation with a nil error. func (policies Policies) EvalQuery(ctx context.Context, q ent.Query) error {