doc: explain how to use policies in migrations (#4141)

This commit is contained in:
Ariel Mashraki
2024-07-14 14:53:22 +03:00
committed by GitHub
parent 1073ce511e
commit 9f61938bcc
37 changed files with 5768 additions and 7 deletions

View File

@@ -8,7 +8,9 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"strconv"
"strings"
"entgo.io/ent/dialect"
@@ -81,6 +83,31 @@ type Tx struct {
driver.Tx
}
// ctyVarsKey is the key used for attaching and reading the context variables.
type ctxVarsKey struct{}
// sessionVars holds sessions/transactions variables to set before every statement.
type sessionVars struct {
vars []struct{ k, v string }
}
// WithVar returns a new context that holds the session variable to be executed before every query.
func WithVar(ctx context.Context, name, value string) context.Context {
sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
sv.vars = append(sv.vars, struct {
k, v string
}{
k: name,
v: value,
})
return context.WithValue(ctx, ctxVarsKey{}, sv)
}
// WithIntVar calls WithVar with the string representation of the value.
func WithIntVar(ctx context.Context, name string, value int) context.Context {
return WithVar(ctx, name, strconv.Itoa(value))
}
// ExecQuerier wraps the standard Exec and Query methods.
type ExecQuerier interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
@@ -93,18 +120,25 @@ type Conn struct {
}
// Exec implements the dialect.Exec method.
func (c Conn) Exec(ctx context.Context, query string, args, v any) error {
func (c Conn) Exec(ctx context.Context, query string, args, v any) (rerr error) {
argv, ok := args.([]any)
if !ok {
return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", v)
}
ex, cf, err := c.maySetVars(ctx)
if err != nil {
return err
}
if cf != nil {
defer func() { rerr = errors.Join(rerr, cf()) }()
}
switch v := v.(type) {
case nil:
if _, err := c.ExecContext(ctx, query, argv...); err != nil {
if _, err := ex.ExecContext(ctx, query, argv...); err != nil {
return err
}
case *sql.Result:
res, err := c.ExecContext(ctx, query, argv...)
res, err := ex.ExecContext(ctx, query, argv...)
if err != nil {
return err
}
@@ -125,14 +159,55 @@ func (c Conn) Query(ctx context.Context, query string, args, v any) error {
if !ok {
return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", args)
}
rows, err := c.QueryContext(ctx, query, argv...)
ex, cf, err := c.maySetVars(ctx)
if err != nil {
return err
}
rows, err := ex.QueryContext(ctx, query, argv...)
if err != nil {
if cf != nil {
err = errors.Join(err, cf())
}
return err
}
*vr = Rows{rows}
if cf != nil {
vr.ColumnScanner = rowsWithCloser{rows, cf}
}
return nil
}
// maySetVars sets the session variables before executing a query.
func (c Conn) maySetVars(ctx context.Context) (ExecQuerier, func() error, error) {
sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
if len(sv.vars) == 0 {
return c, nil, nil
}
var (
ex ExecQuerier // Underlying ExecQuerier.
cf func() error // Close function.
)
switch e := c.ExecQuerier.(type) {
case *sql.Tx:
ex = e
case *sql.DB:
conn, err := e.Conn(ctx)
if err != nil {
return nil, nil, err
}
ex, cf = conn, conn.Close
}
for _, s := range sv.vars {
if _, err := ex.ExecContext(ctx, fmt.Sprintf("SET %s = '%s'", s.k, s.v)); err != nil {
if cf != nil {
err = errors.Join(err, cf())
}
return nil, nil, err
}
}
return ex, cf, nil
}
var _ dialect.Driver = (*Driver)(nil)
type (
@@ -154,9 +229,8 @@ type (
TxOptions = sql.TxOptions
)
// NullScanner represents an sql.Scanner that may be null.
// NullScanner implements the sql.Scanner interface so it can
// be used as a scan destination, similar to the types above.
// NullScanner implements the sql.Scanner interface such that it
// can be used as a scan destination, similar to the types above.
type NullScanner struct {
S sql.Scanner
Valid bool // Valid is true if the Scan value is not NULL.
@@ -182,3 +256,15 @@ type ColumnScanner interface {
NextResultSet() bool
Scan(dest ...any) error
}
// rowsWithCloser wraps the ColumnScanner interface with a custom Close hook.
type rowsWithCloser struct {
ColumnScanner
closer func() error
}
// Close closes the underlying ColumnScanner and calls the custom closer.
func (r rowsWithCloser) Close() error {
err := r.ColumnScanner.Close()
return errors.Join(err, r.closer())
}

View File

@@ -0,0 +1,87 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package sql
import (
"context"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
func TestWithVars(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
db.SetMaxOpenConns(1)
drv := OpenDB("sqlite3", db)
mock.ExpectExec("SET foo = 'bar'").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectQuery("SELECT 1").WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
rows := &Rows{}
err = drv.Query(
WithVar(context.Background(), "foo", "bar"),
"SELECT 1",
[]any{},
rows,
)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
require.NoError(t, rows.Close(), "rows should be closed to release the connection")
mock.ExpectExec("SET foo = 'bar'").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("SET foo = 'baz'").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectQuery("SELECT 1").WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
err = drv.Query(
WithVar(WithVar(context.Background(), "foo", "bar"), "foo", "baz"),
"SELECT 1",
[]any{},
rows,
)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
require.NoError(t, rows.Close(), "rows should be closed to release the connection")
mock.ExpectBegin()
mock.ExpectExec("SET foo = 'bar'").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectQuery("SELECT 1").WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
mock.ExpectCommit()
tx, err := drv.Tx(context.Background())
require.NoError(t, err)
err = tx.Query(
WithVar(context.Background(), "foo", "bar"),
"SELECT 1",
[]any{},
rows,
)
require.NoError(t, err)
require.NoError(t, tx.Commit())
require.NoError(t, mock.ExpectationsWereMet())
// Rows should not be closed to release the session,
// as a transaction is always scoped to a single connection.
mock.ExpectExec("SET foo = 'qux'").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO users DEFAULT VALUES").WillReturnResult(sqlmock.NewResult(0, 0))
err = drv.Exec(
WithVar(context.Background(), "foo", "qux"),
"INSERT INTO users DEFAULT VALUES",
[]any{},
nil,
)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
// No rows are returned, so no need to close them.
mock.ExpectExec("SET foo = 'foo'").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO users DEFAULT VALUES").WillReturnResult(sqlmock.NewResult(0, 0))
err = drv.Exec(
WithVar(context.Background(), "foo", "foo"),
"INSERT INTO users DEFAULT VALUES",
[]any{},
nil,
)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
// No rows are returned, so no need to close them.
}

223
doc/md/migration/rls.mdx Normal file
View File

@@ -0,0 +1,223 @@
---
title: Using Row-Level Security in Ent Schema
id: rls
slug: row-level-security
---
import InstallationInstructions from '../components/_installation_instructions.mdx';
Row-level security (RLS) in PostgreSQL enables tables to implement policies that limit access or modification of rows
according to the user's role, enhancing the basic SQL-standard privileges provided by `GRANT`.
Once activated, every standard access to the table has to adhere to these policies. If no policies are defined on the table,
it defaults to a deny-all, meaning no rows can be seen or mutated. These policies can be tailored to specific commands,
roles, or both, allowing for detailed management of who can access or change data.
This guide explains how to attach Row-Level Security (RLS) Policies to your Ent types (objects) and configure the schema
migration to manage both the RLS and the Ent schema as a single migration unit using Atlas.
:::info [Atlas Pro Feature](https://atlasgo.io/features#pro-plan)
Atlas support for [Row-Level Security Policies](https://atlasgo.io/atlas-schema/hcl#row-level-security-policy) used in
this guide is available exclusively to Pro users. To use this feature, run:
```
atlas login
```
:::
## Install Atlas
<InstallationInstructions />
## Login to Atlas
```shell
$ atlas login a8m
//highlight-next-line-info
You are now connected to "a8m" on Atlas Cloud.
```
## Composite Schema
An `ent/schema` package is mostly used for defining Ent types (objects), their fields, edges and logic. Table policies
or any other database native objects do not have representation in Ent models.
In order to extend our PostgreSQL schema to include both our Ent types and their policies, we configure Atlas to
read the state of the schema from a [Composite Schema](https://atlasgo.io/atlas-schema/projects#data-source-composite_schema)
data source. Follow the steps below to configure this for your project:
1\. Let's define a simple schema with two types (tables): `users` and `tenants`:
```go title="ent/schema/tenant.go"
// Tenant holds the schema definition for the Tenant entity.
type Tenant struct {
ent.Schema
}
// Fields of the Tenant.
func (Tenant) Fields() []ent.Field {
return []ent.Field{
field.String("name"),
}
}
// User holds the schema definition for the User entity.
type User struct {
ent.Schema
}
// Fields of the User.
func (User) Fields() []ent.Field {
return []ent.Field{
field.String("name"),
field.Int("tenant_id"),
}
}
```
2\. Now, suppose we want to limit access to the `users` table based on the `tenant_id` field. We can achieve this by defining
a Row-Level Security (RLS) policy on the `users` table. Below is the SQL code that defines the RLS policy:
```sql title="schema.sql"
--- Enable row-level security on the users table.
ALTER TABLE "users" ENABLE ROW LEVEL SECURITY;
-- Create a policy that restricts access to rows in the users table based on the current tenant.
CREATE POLICY tenant_isolation ON "users"
USING ("tenant_id" = current_setting('app.current_tenant')::integer);
```
3\. Lastly, we create a simple `atlas.hcl` config file with a `composite_schema` that includes both our Ent schema and
the custom security policies defined in `schema.sql`:
```hcl title="atlas.hcl"
data "composite_schema" "app" {
# Load the ent schema first with all tables.
schema "public" {
url = "ent://ent/schema"
}
# Then, load the RLS schema.
schema "public" {
url = "file://schema.sql"
}
}
env "local" {
src = data.composite_schema.app.url
dev = "docker://postgres/15/dev?search_path=public"
}
```
## Usage
After setting up our composite schema, we can get its representation using the `atlas schema inspect` command, generate
schema migrations for it, apply them to a database, and more. Below are a few commands to get you started with Atlas:
#### Inspect the Schema
The `atlas schema inspect` command is commonly used to inspect databases. However, we can also use it to inspect our
`composite_schema` and print the SQL representation of it:
```shell
atlas schema inspect \
--env local \
--url env://src \
--format '{{ sql . }}'
```
The command above prints the following SQL. Note, the `tenant_isolation` policy is defined in the schema after the `users`
table:
```sql
-- Create "users" table
CREATE TABLE "users" ("id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY, "name" character varying NOT NULL, "tenant_id" bigint NOT NULL, PRIMARY KEY ("id"));
-- Enable row-level security for "users" table
ALTER TABLE "users" ENABLE ROW LEVEL SECURITY;
-- Create policy "tenant_isolation"
CREATE POLICY "tenant_isolation" ON "users" AS PERMISSIVE FOR ALL TO PUBLIC USING (tenant_id = (current_setting('app.current_tenant'::text))::integer);
-- Create "tenants" table
CREATE TABLE "tenants" ("id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY, "name" character varying NOT NULL, PRIMARY KEY ("id"));
```
#### Generate Migrations For the Schema
To generate a migration for the schema, run the following command:
```shell
atlas migrate diff \
--env local
```
Note that a new migration file is created with the following content:
```sql title="migrations/20240712090543.sql"
-- Create "users" table
CREATE TABLE "users" ("id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY, "name" character varying NOT NULL, "tenant_id" bigint NOT NULL, PRIMARY KEY ("id"));
-- Enable row-level security for "users" table
ALTER TABLE "users" ENABLE ROW LEVEL SECURITY;
-- Create policy "tenant_isolation"
CREATE POLICY "tenant_isolation" ON "users" AS PERMISSIVE FOR ALL TO PUBLIC USING (tenant_id = (current_setting('app.current_tenant'::text))::integer);
-- Create "tenants" table
CREATE TABLE "tenants" ("id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY, "name" character varying NOT NULL, PRIMARY KEY ("id"));
```
#### Apply the Migrations
To apply the migration generated above to a database, run the following command:
```
atlas migrate apply \
--env local \
--url "postgres://postgres:pass@localhost:5432/database?search_path=public&sslmode=disable"
```
:::info Apply the Schema Directly on the Database
Sometimes, there is a need to apply the schema directly to the database without generating a migration file. For example,
when experimenting with schema changes, spinning up a database for testing, etc. In such cases, you can use the command
below to apply the schema directly to the database:
```shell
atlas schema apply \
--env local \
--url "postgres://postgres:pass@localhost:5432/database?search_path=public&sslmode=disable"
```
Or, using the [Atlas Go SDK](https://github.com/ariga/atlas-go-sdk):
```go
ac, err := atlasexec.NewClient(".", "atlas")
if err != nil {
log.Fatalf("failed to initialize client: %w", err)
}
// Automatically update the database with the desired schema.
// Another option, is to use 'migrate apply' or 'schema apply' manually.
if _, err := ac.SchemaApply(ctx, &atlasexec.SchemaApplyParams{
Env: "local",
URL: "postgres://postgres:pass@localhost:5432/database?search_path=public&sslmode=disable",
}); err != nil {
log.Fatalf("failed to apply schema changes: %w", err)
}
```
:::
## Code Example
After setting up our Ent schema and the RLS policies, we can open an Ent client and pass the different mutations and
queries the relevant tenant ID we work on. This ensures that the database upholds our RLS policy:
```go
ctx1, ctx2 := sql.WithIntVar(ctx, "app.current_tenant", a8m.ID), sql.WithIntVar(ctx, "app.current_tenant", r3m.ID)
users1 := client.User.Query().AllX(ctx1)
// Users1 can only see users from tenant a8m.
users2 := client.User.Query().AllX(ctx2)
// Users2 can only see users from tenant r3m.
```
:::info Real World Example
In real applications, users can utilize [hooks](/docs/hooks) and [interceptors](/docs/interceptors) to set the `app.current_tenant`
variable based on the user's context.
:::
The code for this guide can be found in [GitHub](https://github.com/ent/ent/tree/master/examples/rls).

View File

@@ -52,6 +52,7 @@ module.exports = {
{type: 'doc', id: 'migration/domain', label: 'Domain Types'},
{type: 'doc', id: 'migration/enum', label: 'Enum Types'},
{type: 'doc', id: 'migration/extension', label: 'Extensions'},
{type: 'doc', id: 'migration/rls', label: 'Row-Level Security'},
{type: 'doc', id: 'migration/trigger', label: 'Triggers'},
],
collapsed: false,

3
examples/rls/README.md Normal file
View File

@@ -0,0 +1,3 @@
## Using PostgreSQL Triggers in Ent Schema
Read the full guide in: https://entgo.io/docs/migration/rls

15
examples/rls/atlas.hcl Normal file
View File

@@ -0,0 +1,15 @@
data "composite_schema" "app" {
# Load the ent schema first with all tables.
schema "public" {
url = "ent://ent/schema"
}
# Then, load the RLS schema.
schema "public" {
url = "file://schema.sql"
}
}
env "local" {
src = data.composite_schema.app.url
dev = "docker://postgres/15/dev?search_path=public"
}

483
examples/rls/ent/client.go Normal file
View File

@@ -0,0 +1,483 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"log"
"reflect"
"entgo.io/ent"
"entgo.io/ent/examples/rls/ent/migrate"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/examples/rls/ent/tenant"
"entgo.io/ent/examples/rls/ent/user"
)
// Client is the client that holds all ent builders.
type Client struct {
config
// Schema is the client for creating, migrating and dropping schema.
Schema *migrate.Schema
// Tenant is the client for interacting with the Tenant builders.
Tenant *TenantClient
// User is the client for interacting with the User builders.
User *UserClient
}
// NewClient creates a new client configured with the given options.
func NewClient(opts ...Option) *Client {
client := &Client{config: newConfig(opts...)}
client.init()
return client
}
func (c *Client) init() {
c.Schema = migrate.NewSchema(c.driver)
c.Tenant = NewTenantClient(c.config)
c.User = NewUserClient(c.config)
}
type (
// config is the configuration for the client and its builder.
config struct {
// driver used for executing database requests.
driver dialect.Driver
// debug enable a debug logging.
debug bool
// log used for logging on debug mode.
log func(...any)
// hooks to execute on mutations.
hooks *hooks
// interceptors to execute on queries.
inters *inters
}
// Option function to configure the client.
Option func(*config)
)
// newConfig creates a new config for the client.
func newConfig(opts ...Option) config {
cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}}
cfg.options(opts...)
return cfg
}
// options applies the options on the config object.
func (c *config) options(opts ...Option) {
for _, opt := range opts {
opt(c)
}
if c.debug {
c.driver = dialect.Debug(c.driver, c.log)
}
}
// Debug enables debug logging on the ent.Driver.
func Debug() Option {
return func(c *config) {
c.debug = true
}
}
// Log sets the logging function for debug mode.
func Log(fn func(...any)) Option {
return func(c *config) {
c.log = fn
}
}
// Driver configures the client driver.
func Driver(driver dialect.Driver) Option {
return func(c *config) {
c.driver = driver
}
}
// Open opens a database/sql.DB specified by the driver name and
// the data source name, and returns a new client attached to it.
// Optional parameters can be added for configuring the client.
func Open(driverName, dataSourceName string, options ...Option) (*Client, error) {
switch driverName {
case dialect.MySQL, dialect.Postgres, dialect.SQLite:
drv, err := sql.Open(driverName, dataSourceName)
if err != nil {
return nil, err
}
return NewClient(append(options, Driver(drv))...), nil
default:
return nil, fmt.Errorf("unsupported driver: %q", driverName)
}
}
// ErrTxStarted is returned when trying to start a new transaction from a transactional client.
var ErrTxStarted = errors.New("ent: cannot start a transaction within a transaction")
// Tx returns a new transactional client. The provided context
// is used until the transaction is committed or rolled back.
func (c *Client) Tx(ctx context.Context) (*Tx, error) {
if _, ok := c.driver.(*txDriver); ok {
return nil, ErrTxStarted
}
tx, err := newTx(ctx, c.driver)
if err != nil {
return nil, fmt.Errorf("ent: starting a transaction: %w", err)
}
cfg := c.config
cfg.driver = tx
return &Tx{
ctx: ctx,
config: cfg,
Tenant: NewTenantClient(cfg),
User: NewUserClient(cfg),
}, nil
}
// BeginTx returns a transactional client with specified options.
func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
if _, ok := c.driver.(*txDriver); ok {
return nil, errors.New("ent: cannot start a transaction within a transaction")
}
tx, err := c.driver.(interface {
BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error)
}).BeginTx(ctx, opts)
if err != nil {
return nil, fmt.Errorf("ent: starting a transaction: %w", err)
}
cfg := c.config
cfg.driver = &txDriver{tx: tx, drv: c.driver}
return &Tx{
ctx: ctx,
config: cfg,
Tenant: NewTenantClient(cfg),
User: NewUserClient(cfg),
}, nil
}
// Debug returns a new debug-client. It's used to get verbose logging on specific operations.
//
// client.Debug().
// Tenant.
// Query().
// Count(ctx)
func (c *Client) Debug() *Client {
if c.debug {
return c
}
cfg := c.config
cfg.driver = dialect.Debug(c.driver, c.log)
client := &Client{config: cfg}
client.init()
return client
}
// Close closes the database connection and prevents new queries from starting.
func (c *Client) Close() error {
return c.driver.Close()
}
// Use adds the mutation hooks to all the entity clients.
// In order to add hooks to a specific client, call: `client.Node.Use(...)`.
func (c *Client) Use(hooks ...Hook) {
c.Tenant.Use(hooks...)
c.User.Use(hooks...)
}
// Intercept adds the query interceptors to all the entity clients.
// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`.
func (c *Client) Intercept(interceptors ...Interceptor) {
c.Tenant.Intercept(interceptors...)
c.User.Intercept(interceptors...)
}
// Mutate implements the ent.Mutator interface.
func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
switch m := m.(type) {
case *TenantMutation:
return c.Tenant.mutate(ctx, m)
case *UserMutation:
return c.User.mutate(ctx, m)
default:
return nil, fmt.Errorf("ent: unknown mutation type %T", m)
}
}
// TenantClient is a client for the Tenant schema.
type TenantClient struct {
config
}
// NewTenantClient returns a client for the Tenant from the given config.
func NewTenantClient(c config) *TenantClient {
return &TenantClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `tenant.Hooks(f(g(h())))`.
func (c *TenantClient) Use(hooks ...Hook) {
c.hooks.Tenant = append(c.hooks.Tenant, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `tenant.Intercept(f(g(h())))`.
func (c *TenantClient) Intercept(interceptors ...Interceptor) {
c.inters.Tenant = append(c.inters.Tenant, interceptors...)
}
// Create returns a builder for creating a Tenant entity.
func (c *TenantClient) Create() *TenantCreate {
mutation := newTenantMutation(c.config, OpCreate)
return &TenantCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of Tenant entities.
func (c *TenantClient) CreateBulk(builders ...*TenantCreate) *TenantCreateBulk {
return &TenantCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *TenantClient) MapCreateBulk(slice any, setFunc func(*TenantCreate, int)) *TenantCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &TenantCreateBulk{err: fmt.Errorf("calling to TenantClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*TenantCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &TenantCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for Tenant.
func (c *TenantClient) Update() *TenantUpdate {
mutation := newTenantMutation(c.config, OpUpdate)
return &TenantUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *TenantClient) UpdateOne(t *Tenant) *TenantUpdateOne {
mutation := newTenantMutation(c.config, OpUpdateOne, withTenant(t))
return &TenantUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *TenantClient) UpdateOneID(id int) *TenantUpdateOne {
mutation := newTenantMutation(c.config, OpUpdateOne, withTenantID(id))
return &TenantUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for Tenant.
func (c *TenantClient) Delete() *TenantDelete {
mutation := newTenantMutation(c.config, OpDelete)
return &TenantDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *TenantClient) DeleteOne(t *Tenant) *TenantDeleteOne {
return c.DeleteOneID(t.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *TenantClient) DeleteOneID(id int) *TenantDeleteOne {
builder := c.Delete().Where(tenant.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &TenantDeleteOne{builder}
}
// Query returns a query builder for Tenant.
func (c *TenantClient) Query() *TenantQuery {
return &TenantQuery{
config: c.config,
ctx: &QueryContext{Type: TypeTenant},
inters: c.Interceptors(),
}
}
// Get returns a Tenant entity by its id.
func (c *TenantClient) Get(ctx context.Context, id int) (*Tenant, error) {
return c.Query().Where(tenant.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *TenantClient) GetX(ctx context.Context, id int) *Tenant {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// Hooks returns the client hooks.
func (c *TenantClient) Hooks() []Hook {
return c.hooks.Tenant
}
// Interceptors returns the client interceptors.
func (c *TenantClient) Interceptors() []Interceptor {
return c.inters.Tenant
}
func (c *TenantClient) mutate(ctx context.Context, m *TenantMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&TenantCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&TenantUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&TenantUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&TenantDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown Tenant mutation op: %q", m.Op())
}
}
// UserClient is a client for the User schema.
type UserClient struct {
config
}
// NewUserClient returns a client for the User from the given config.
func NewUserClient(c config) *UserClient {
return &UserClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `user.Hooks(f(g(h())))`.
func (c *UserClient) Use(hooks ...Hook) {
c.hooks.User = append(c.hooks.User, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `user.Intercept(f(g(h())))`.
func (c *UserClient) Intercept(interceptors ...Interceptor) {
c.inters.User = append(c.inters.User, interceptors...)
}
// Create returns a builder for creating a User entity.
func (c *UserClient) Create() *UserCreate {
mutation := newUserMutation(c.config, OpCreate)
return &UserCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of User entities.
func (c *UserClient) CreateBulk(builders ...*UserCreate) *UserCreateBulk {
return &UserCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *UserClient) MapCreateBulk(slice any, setFunc func(*UserCreate, int)) *UserCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &UserCreateBulk{err: fmt.Errorf("calling to UserClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*UserCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &UserCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for User.
func (c *UserClient) Update() *UserUpdate {
mutation := newUserMutation(c.config, OpUpdate)
return &UserUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *UserClient) UpdateOne(u *User) *UserUpdateOne {
mutation := newUserMutation(c.config, OpUpdateOne, withUser(u))
return &UserUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *UserClient) UpdateOneID(id int) *UserUpdateOne {
mutation := newUserMutation(c.config, OpUpdateOne, withUserID(id))
return &UserUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for User.
func (c *UserClient) Delete() *UserDelete {
mutation := newUserMutation(c.config, OpDelete)
return &UserDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *UserClient) DeleteOne(u *User) *UserDeleteOne {
return c.DeleteOneID(u.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *UserClient) DeleteOneID(id int) *UserDeleteOne {
builder := c.Delete().Where(user.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &UserDeleteOne{builder}
}
// Query returns a query builder for User.
func (c *UserClient) Query() *UserQuery {
return &UserQuery{
config: c.config,
ctx: &QueryContext{Type: TypeUser},
inters: c.Interceptors(),
}
}
// Get returns a User entity by its id.
func (c *UserClient) Get(ctx context.Context, id int) (*User, error) {
return c.Query().Where(user.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *UserClient) GetX(ctx context.Context, id int) *User {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// Hooks returns the client hooks.
func (c *UserClient) Hooks() []Hook {
return c.hooks.User
}
// Interceptors returns the client interceptors.
func (c *UserClient) Interceptors() []Interceptor {
return c.inters.User
}
func (c *UserClient) mutate(ctx context.Context, m *UserMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&UserCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&UserUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&UserUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&UserDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown User mutation op: %q", m.Op())
}
}
// hooks and interceptors per client, for fast access.
type (
hooks struct {
Tenant, User []ent.Hook
}
inters struct {
Tenant, User []ent.Interceptor
}
)

610
examples/rls/ent/ent.go Normal file
View File

@@ -0,0 +1,610 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"reflect"
"sync"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/examples/rls/ent/tenant"
"entgo.io/ent/examples/rls/ent/user"
)
// ent aliases to avoid import conflicts in user's code.
type (
Op = ent.Op
Hook = ent.Hook
Value = ent.Value
Query = ent.Query
QueryContext = ent.QueryContext
Querier = ent.Querier
QuerierFunc = ent.QuerierFunc
Interceptor = ent.Interceptor
InterceptFunc = ent.InterceptFunc
Traverser = ent.Traverser
TraverseFunc = ent.TraverseFunc
Policy = ent.Policy
Mutator = ent.Mutator
Mutation = ent.Mutation
MutateFunc = ent.MutateFunc
)
type clientCtxKey struct{}
// FromContext returns a Client stored inside a context, or nil if there isn't one.
func FromContext(ctx context.Context) *Client {
c, _ := ctx.Value(clientCtxKey{}).(*Client)
return c
}
// NewContext returns a new context with the given Client attached.
func NewContext(parent context.Context, c *Client) context.Context {
return context.WithValue(parent, clientCtxKey{}, c)
}
type txCtxKey struct{}
// TxFromContext returns a Tx stored inside a context, or nil if there isn't one.
func TxFromContext(ctx context.Context) *Tx {
tx, _ := ctx.Value(txCtxKey{}).(*Tx)
return tx
}
// NewTxContext returns a new context with the given Tx attached.
func NewTxContext(parent context.Context, tx *Tx) context.Context {
return context.WithValue(parent, txCtxKey{}, tx)
}
// OrderFunc applies an ordering on the sql selector.
// Deprecated: Use Asc/Desc functions or the package builders instead.
type OrderFunc func(*sql.Selector)
var (
initCheck sync.Once
columnCheck sql.ColumnCheck
)
// columnChecker checks if the column exists in the given table.
func checkColumn(table, column string) error {
initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
tenant.Table: tenant.ValidColumn,
user.Table: user.ValidColumn,
})
})
return columnCheck(table, column)
}
// Asc applies the given fields in ASC order.
func Asc(fields ...string) func(*sql.Selector) {
return func(s *sql.Selector) {
for _, f := range fields {
if err := checkColumn(s.TableName(), f); err != nil {
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)})
}
s.OrderBy(sql.Asc(s.C(f)))
}
}
}
// Desc applies the given fields in DESC order.
func Desc(fields ...string) func(*sql.Selector) {
return func(s *sql.Selector) {
for _, f := range fields {
if err := checkColumn(s.TableName(), f); err != nil {
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)})
}
s.OrderBy(sql.Desc(s.C(f)))
}
}
}
// AggregateFunc applies an aggregation step on the group-by traversal/selector.
type AggregateFunc func(*sql.Selector) string
// As is a pseudo aggregation function for renaming another other functions with custom names. For example:
//
// GroupBy(field1, field2).
// Aggregate(ent.As(ent.Sum(field1), "sum_field1"), (ent.As(ent.Sum(field2), "sum_field2")).
// Scan(ctx, &v)
func As(fn AggregateFunc, end string) AggregateFunc {
return func(s *sql.Selector) string {
return sql.As(fn(s), end)
}
}
// Count applies the "count" aggregation function on each group.
func Count() AggregateFunc {
return func(s *sql.Selector) string {
return sql.Count("*")
}
}
// Max applies the "max" aggregation function on the given field of each group.
func Max(field string) AggregateFunc {
return func(s *sql.Selector) string {
if err := checkColumn(s.TableName(), field); err != nil {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
return ""
}
return sql.Max(s.C(field))
}
}
// Mean applies the "mean" aggregation function on the given field of each group.
func Mean(field string) AggregateFunc {
return func(s *sql.Selector) string {
if err := checkColumn(s.TableName(), field); err != nil {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
return ""
}
return sql.Avg(s.C(field))
}
}
// Min applies the "min" aggregation function on the given field of each group.
func Min(field string) AggregateFunc {
return func(s *sql.Selector) string {
if err := checkColumn(s.TableName(), field); err != nil {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
return ""
}
return sql.Min(s.C(field))
}
}
// Sum applies the "sum" aggregation function on the given field of each group.
func Sum(field string) AggregateFunc {
return func(s *sql.Selector) string {
if err := checkColumn(s.TableName(), field); err != nil {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
return ""
}
return sql.Sum(s.C(field))
}
}
// ValidationError returns when validating a field or edge fails.
type ValidationError struct {
Name string // Field or edge name.
err error
}
// Error implements the error interface.
func (e *ValidationError) Error() string {
return e.err.Error()
}
// Unwrap implements the errors.Wrapper interface.
func (e *ValidationError) Unwrap() error {
return e.err
}
// IsValidationError returns a boolean indicating whether the error is a validation error.
func IsValidationError(err error) bool {
if err == nil {
return false
}
var e *ValidationError
return errors.As(err, &e)
}
// NotFoundError returns when trying to fetch a specific entity and it was not found in the database.
type NotFoundError struct {
label string
}
// Error implements the error interface.
func (e *NotFoundError) Error() string {
return "ent: " + e.label + " not found"
}
// IsNotFound returns a boolean indicating whether the error is a not found error.
func IsNotFound(err error) bool {
if err == nil {
return false
}
var e *NotFoundError
return errors.As(err, &e)
}
// MaskNotFound masks not found error.
func MaskNotFound(err error) error {
if IsNotFound(err) {
return nil
}
return err
}
// NotSingularError returns when trying to fetch a singular entity and more then one was found in the database.
type NotSingularError struct {
label string
}
// Error implements the error interface.
func (e *NotSingularError) Error() string {
return "ent: " + e.label + " not singular"
}
// IsNotSingular returns a boolean indicating whether the error is a not singular error.
func IsNotSingular(err error) bool {
if err == nil {
return false
}
var e *NotSingularError
return errors.As(err, &e)
}
// NotLoadedError returns when trying to get a node that was not loaded by the query.
type NotLoadedError struct {
edge string
}
// Error implements the error interface.
func (e *NotLoadedError) Error() string {
return "ent: " + e.edge + " edge was not loaded"
}
// IsNotLoaded returns a boolean indicating whether the error is a not loaded error.
func IsNotLoaded(err error) bool {
if err == nil {
return false
}
var e *NotLoadedError
return errors.As(err, &e)
}
// ConstraintError returns when trying to create/update one or more entities and
// one or more of their constraints failed. For example, violation of edge or
// field uniqueness.
type ConstraintError struct {
msg string
wrap error
}
// Error implements the error interface.
func (e ConstraintError) Error() string {
return "ent: constraint failed: " + e.msg
}
// Unwrap implements the errors.Wrapper interface.
func (e *ConstraintError) Unwrap() error {
return e.wrap
}
// IsConstraintError returns a boolean indicating whether the error is a constraint failure.
func IsConstraintError(err error) bool {
if err == nil {
return false
}
var e *ConstraintError
return errors.As(err, &e)
}
// selector embedded by the different Select/GroupBy builders.
type selector struct {
label string
flds *[]string
fns []AggregateFunc
scan func(context.Context, any) error
}
// ScanX is like Scan, but panics if an error occurs.
func (s *selector) ScanX(ctx context.Context, v any) {
if err := s.scan(ctx, v); err != nil {
panic(err)
}
}
// Strings returns list of strings from a selector. It is only allowed when selecting one field.
func (s *selector) Strings(ctx context.Context) ([]string, error) {
if len(*s.flds) > 1 {
return nil, errors.New("ent: Strings is not achievable when selecting more than 1 field")
}
var v []string
if err := s.scan(ctx, &v); err != nil {
return nil, err
}
return v, nil
}
// StringsX is like Strings, but panics if an error occurs.
func (s *selector) StringsX(ctx context.Context) []string {
v, err := s.Strings(ctx)
if err != nil {
panic(err)
}
return v
}
// String returns a single string from a selector. It is only allowed when selecting one field.
func (s *selector) String(ctx context.Context) (_ string, err error) {
var v []string
if v, err = s.Strings(ctx); err != nil {
return
}
switch len(v) {
case 1:
return v[0], nil
case 0:
err = &NotFoundError{s.label}
default:
err = fmt.Errorf("ent: Strings returned %d results when one was expected", len(v))
}
return
}
// StringX is like String, but panics if an error occurs.
func (s *selector) StringX(ctx context.Context) string {
v, err := s.String(ctx)
if err != nil {
panic(err)
}
return v
}
// Ints returns list of ints from a selector. It is only allowed when selecting one field.
func (s *selector) Ints(ctx context.Context) ([]int, error) {
if len(*s.flds) > 1 {
return nil, errors.New("ent: Ints is not achievable when selecting more than 1 field")
}
var v []int
if err := s.scan(ctx, &v); err != nil {
return nil, err
}
return v, nil
}
// IntsX is like Ints, but panics if an error occurs.
func (s *selector) IntsX(ctx context.Context) []int {
v, err := s.Ints(ctx)
if err != nil {
panic(err)
}
return v
}
// Int returns a single int from a selector. It is only allowed when selecting one field.
func (s *selector) Int(ctx context.Context) (_ int, err error) {
var v []int
if v, err = s.Ints(ctx); err != nil {
return
}
switch len(v) {
case 1:
return v[0], nil
case 0:
err = &NotFoundError{s.label}
default:
err = fmt.Errorf("ent: Ints returned %d results when one was expected", len(v))
}
return
}
// IntX is like Int, but panics if an error occurs.
func (s *selector) IntX(ctx context.Context) int {
v, err := s.Int(ctx)
if err != nil {
panic(err)
}
return v
}
// Float64s returns list of float64s from a selector. It is only allowed when selecting one field.
func (s *selector) Float64s(ctx context.Context) ([]float64, error) {
if len(*s.flds) > 1 {
return nil, errors.New("ent: Float64s is not achievable when selecting more than 1 field")
}
var v []float64
if err := s.scan(ctx, &v); err != nil {
return nil, err
}
return v, nil
}
// Float64sX is like Float64s, but panics if an error occurs.
func (s *selector) Float64sX(ctx context.Context) []float64 {
v, err := s.Float64s(ctx)
if err != nil {
panic(err)
}
return v
}
// Float64 returns a single float64 from a selector. It is only allowed when selecting one field.
func (s *selector) Float64(ctx context.Context) (_ float64, err error) {
var v []float64
if v, err = s.Float64s(ctx); err != nil {
return
}
switch len(v) {
case 1:
return v[0], nil
case 0:
err = &NotFoundError{s.label}
default:
err = fmt.Errorf("ent: Float64s returned %d results when one was expected", len(v))
}
return
}
// Float64X is like Float64, but panics if an error occurs.
func (s *selector) Float64X(ctx context.Context) float64 {
v, err := s.Float64(ctx)
if err != nil {
panic(err)
}
return v
}
// Bools returns list of bools from a selector. It is only allowed when selecting one field.
func (s *selector) Bools(ctx context.Context) ([]bool, error) {
if len(*s.flds) > 1 {
return nil, errors.New("ent: Bools is not achievable when selecting more than 1 field")
}
var v []bool
if err := s.scan(ctx, &v); err != nil {
return nil, err
}
return v, nil
}
// BoolsX is like Bools, but panics if an error occurs.
func (s *selector) BoolsX(ctx context.Context) []bool {
v, err := s.Bools(ctx)
if err != nil {
panic(err)
}
return v
}
// Bool returns a single bool from a selector. It is only allowed when selecting one field.
func (s *selector) Bool(ctx context.Context) (_ bool, err error) {
var v []bool
if v, err = s.Bools(ctx); err != nil {
return
}
switch len(v) {
case 1:
return v[0], nil
case 0:
err = &NotFoundError{s.label}
default:
err = fmt.Errorf("ent: Bools returned %d results when one was expected", len(v))
}
return
}
// BoolX is like Bool, but panics if an error occurs.
func (s *selector) BoolX(ctx context.Context) bool {
v, err := s.Bool(ctx)
if err != nil {
panic(err)
}
return v
}
// withHooks invokes the builder operation with the given hooks, if any.
func withHooks[V Value, M any, PM interface {
*M
Mutation
}](ctx context.Context, exec func(context.Context) (V, error), mutation PM, hooks []Hook) (value V, err error) {
if len(hooks) == 0 {
return exec(ctx)
}
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutationT, ok := any(m).(PM)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
// Set the mutation to the builder.
*mutation = *mutationT
return exec(ctx)
})
for i := len(hooks) - 1; i >= 0; i-- {
if hooks[i] == nil {
return value, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = hooks[i](mut)
}
v, err := mut.Mutate(ctx, mutation)
if err != nil {
return value, err
}
nv, ok := v.(V)
if !ok {
return value, fmt.Errorf("unexpected node type %T returned from %T", v, mutation)
}
return nv, nil
}
// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist.
func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context {
if ent.QueryFromContext(ctx) == nil {
qc.Op = op
ctx = ent.NewQueryContext(ctx, qc)
}
return ctx
}
func querierAll[V Value, Q interface {
sqlAll(context.Context, ...queryHook) (V, error)
}]() Querier {
return QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
query, ok := q.(Q)
if !ok {
return nil, fmt.Errorf("unexpected query type %T", q)
}
return query.sqlAll(ctx)
})
}
func querierCount[Q interface {
sqlCount(context.Context) (int, error)
}]() Querier {
return QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
query, ok := q.(Q)
if !ok {
return nil, fmt.Errorf("unexpected query type %T", q)
}
return query.sqlCount(ctx)
})
}
func withInterceptors[V Value](ctx context.Context, q Query, qr Querier, inters []Interceptor) (v V, err error) {
for i := len(inters) - 1; i >= 0; i-- {
qr = inters[i].Intercept(qr)
}
rv, err := qr.Query(ctx, q)
if err != nil {
return v, err
}
vt, ok := rv.(V)
if !ok {
return v, fmt.Errorf("unexpected type %T returned from %T. expected type: %T", vt, q, v)
}
return vt, nil
}
func scanWithInterceptors[Q1 ent.Query, Q2 interface {
sqlScan(context.Context, Q1, any) error
}](ctx context.Context, rootQuery Q1, selectOrGroup Q2, inters []Interceptor, v any) error {
rv := reflect.ValueOf(v)
var qr Querier = QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
query, ok := q.(Q1)
if !ok {
return nil, fmt.Errorf("unexpected query type %T", q)
}
if err := selectOrGroup.sqlScan(ctx, query, v); err != nil {
return nil, err
}
if k := rv.Kind(); k == reflect.Pointer && rv.Elem().CanInterface() {
return rv.Elem().Interface(), nil
}
return v, nil
})
for i := len(inters) - 1; i >= 0; i-- {
qr = inters[i].Intercept(qr)
}
vv, err := qr.Query(ctx, rootQuery)
if err != nil {
return err
}
switch rv2 := reflect.ValueOf(vv); {
case rv.IsNil(), rv2.IsNil(), rv.Kind() != reflect.Pointer:
case rv.Type() == rv2.Type():
rv.Elem().Set(rv2.Elem())
case rv.Elem().Type() == rv2.Type():
rv.Elem().Set(rv2)
}
return nil
}
// queryHook describes an internal hook for the different sqlAll methods.
type queryHook func(context.Context, *sqlgraph.QuerySpec)

View File

@@ -0,0 +1,84 @@
// Code generated by ent, DO NOT EDIT.
package enttest
import (
"context"
"entgo.io/ent/examples/rls/ent"
// required by schema hooks.
_ "entgo.io/ent/examples/rls/ent/runtime"
"entgo.io/ent/dialect/sql/schema"
"entgo.io/ent/examples/rls/ent/migrate"
)
type (
// TestingT is the interface that is shared between
// testing.T and testing.B and used by enttest.
TestingT interface {
FailNow()
Error(...any)
}
// Option configures client creation.
Option func(*options)
options struct {
opts []ent.Option
migrateOpts []schema.MigrateOption
}
)
// WithOptions forwards options to client creation.
func WithOptions(opts ...ent.Option) Option {
return func(o *options) {
o.opts = append(o.opts, opts...)
}
}
// WithMigrateOptions forwards options to auto migration.
func WithMigrateOptions(opts ...schema.MigrateOption) Option {
return func(o *options) {
o.migrateOpts = append(o.migrateOpts, opts...)
}
}
func newOptions(opts []Option) *options {
o := &options{}
for _, opt := range opts {
opt(o)
}
return o
}
// Open calls ent.Open and auto-run migration.
func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *ent.Client {
o := newOptions(opts)
c, err := ent.Open(driverName, dataSourceName, o.opts...)
if err != nil {
t.Error(err)
t.FailNow()
}
migrateSchema(t, c, o)
return c
}
// NewClient calls ent.NewClient and auto-run migration.
func NewClient(t TestingT, opts ...Option) *ent.Client {
o := newOptions(opts)
c := ent.NewClient(o.opts...)
migrateSchema(t, c, o)
return c
}
func migrateSchema(t TestingT, c *ent.Client, o *options) {
tables, err := schema.CopyTables(migrate.Tables)
if err != nil {
t.Error(err)
t.FailNow()
}
if err := migrate.Create(context.Background(), c.Schema, tables, o.migrateOpts...); err != nil {
t.Error(err)
t.FailNow()
}
}

View File

@@ -0,0 +1,3 @@
package ent
//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate ./schema

View File

@@ -0,0 +1,211 @@
// Code generated by ent, DO NOT EDIT.
package hook
import (
"context"
"fmt"
"entgo.io/ent/examples/rls/ent"
)
// The TenantFunc type is an adapter to allow the use of ordinary
// function as Tenant mutator.
type TenantFunc func(context.Context, *ent.TenantMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f TenantFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.TenantMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.TenantMutation", m)
}
// The UserFunc type is an adapter to allow the use of ordinary
// function as User mutator.
type UserFunc func(context.Context, *ent.UserMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f UserFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.UserMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserMutation", m)
}
// Condition is a hook condition function.
type Condition func(context.Context, ent.Mutation) bool
// And groups conditions with the AND operator.
func And(first, second Condition, rest ...Condition) Condition {
return func(ctx context.Context, m ent.Mutation) bool {
if !first(ctx, m) || !second(ctx, m) {
return false
}
for _, cond := range rest {
if !cond(ctx, m) {
return false
}
}
return true
}
}
// Or groups conditions with the OR operator.
func Or(first, second Condition, rest ...Condition) Condition {
return func(ctx context.Context, m ent.Mutation) bool {
if first(ctx, m) || second(ctx, m) {
return true
}
for _, cond := range rest {
if cond(ctx, m) {
return true
}
}
return false
}
}
// Not negates a given condition.
func Not(cond Condition) Condition {
return func(ctx context.Context, m ent.Mutation) bool {
return !cond(ctx, m)
}
}
// HasOp is a condition testing mutation operation.
func HasOp(op ent.Op) Condition {
return func(_ context.Context, m ent.Mutation) bool {
return m.Op().Is(op)
}
}
// HasAddedFields is a condition validating `.AddedField` on fields.
func HasAddedFields(field string, fields ...string) Condition {
return func(_ context.Context, m ent.Mutation) bool {
if _, exists := m.AddedField(field); !exists {
return false
}
for _, field := range fields {
if _, exists := m.AddedField(field); !exists {
return false
}
}
return true
}
}
// HasClearedFields is a condition validating `.FieldCleared` on fields.
func HasClearedFields(field string, fields ...string) Condition {
return func(_ context.Context, m ent.Mutation) bool {
if exists := m.FieldCleared(field); !exists {
return false
}
for _, field := range fields {
if exists := m.FieldCleared(field); !exists {
return false
}
}
return true
}
}
// HasFields is a condition validating `.Field` on fields.
func HasFields(field string, fields ...string) Condition {
return func(_ context.Context, m ent.Mutation) bool {
if _, exists := m.Field(field); !exists {
return false
}
for _, field := range fields {
if _, exists := m.Field(field); !exists {
return false
}
}
return true
}
}
// If executes the given hook under condition.
//
// hook.If(ComputeAverage, And(HasFields(...), HasAddedFields(...)))
func If(hk ent.Hook, cond Condition) ent.Hook {
return func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if cond(ctx, m) {
return hk(next).Mutate(ctx, m)
}
return next.Mutate(ctx, m)
})
}
}
// On executes the given hook only for the given operation.
//
// hook.On(Log, ent.Delete|ent.Create)
func On(hk ent.Hook, op ent.Op) ent.Hook {
return If(hk, HasOp(op))
}
// Unless skips the given hook only for the given operation.
//
// hook.Unless(Log, ent.Update|ent.UpdateOne)
func Unless(hk ent.Hook, op ent.Op) ent.Hook {
return If(hk, Not(HasOp(op)))
}
// FixedError is a hook returning a fixed error.
func FixedError(err error) ent.Hook {
return func(ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(context.Context, ent.Mutation) (ent.Value, error) {
return nil, err
})
}
}
// Reject returns a hook that rejects all operations that match op.
//
// func (T) Hooks() []ent.Hook {
// return []ent.Hook{
// Reject(ent.Delete|ent.Update),
// }
// }
func Reject(op ent.Op) ent.Hook {
hk := FixedError(fmt.Errorf("%s operation is not allowed", op))
return On(hk, op)
}
// Chain acts as a list of hooks and is effectively immutable.
// Once created, it will always hold the same set of hooks in the same order.
type Chain struct {
hooks []ent.Hook
}
// NewChain creates a new chain of hooks.
func NewChain(hooks ...ent.Hook) Chain {
return Chain{append([]ent.Hook(nil), hooks...)}
}
// Hook chains the list of hooks and returns the final hook.
func (c Chain) Hook() ent.Hook {
return func(mutator ent.Mutator) ent.Mutator {
for i := len(c.hooks) - 1; i >= 0; i-- {
mutator = c.hooks[i](mutator)
}
return mutator
}
}
// Append extends a chain, adding the specified hook
// as the last ones in the mutation flow.
func (c Chain) Append(hooks ...ent.Hook) Chain {
newHooks := make([]ent.Hook, 0, len(c.hooks)+len(hooks))
newHooks = append(newHooks, c.hooks...)
newHooks = append(newHooks, hooks...)
return Chain{newHooks}
}
// Extend extends a chain, adding the specified chain
// as the last ones in the mutation flow.
func (c Chain) Extend(chain Chain) Chain {
return c.Append(chain.hooks...)
}

View File

@@ -0,0 +1,64 @@
// Code generated by ent, DO NOT EDIT.
package migrate
import (
"context"
"fmt"
"io"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql/schema"
)
var (
// WithGlobalUniqueID sets the universal ids options to the migration.
// If this option is enabled, ent migration will allocate a 1<<32 range
// for the ids of each entity (table).
// Note that this option cannot be applied on tables that already exist.
WithGlobalUniqueID = schema.WithGlobalUniqueID
// WithDropColumn sets the drop column option to the migration.
// If this option is enabled, ent migration will drop old columns
// that were used for both fields and edges. This defaults to false.
WithDropColumn = schema.WithDropColumn
// WithDropIndex sets the drop index option to the migration.
// If this option is enabled, ent migration will drop old indexes
// that were defined in the schema. This defaults to false.
// Note that unique constraints are defined using `UNIQUE INDEX`,
// and therefore, it's recommended to enable this option to get more
// flexibility in the schema changes.
WithDropIndex = schema.WithDropIndex
// WithForeignKeys enables creating foreign-key in schema DDL. This defaults to true.
WithForeignKeys = schema.WithForeignKeys
)
// Schema is the API for creating, migrating and dropping a schema.
type Schema struct {
drv dialect.Driver
}
// NewSchema creates a new schema client.
func NewSchema(drv dialect.Driver) *Schema { return &Schema{drv: drv} }
// Create creates all schema resources.
func (s *Schema) Create(ctx context.Context, opts ...schema.MigrateOption) error {
return Create(ctx, s, Tables, opts...)
}
// Create creates all table resources using the given schema driver.
func Create(ctx context.Context, s *Schema, tables []*schema.Table, opts ...schema.MigrateOption) error {
migrate, err := schema.NewMigrate(s.drv, opts...)
if err != nil {
return fmt.Errorf("ent/migrate: %w", err)
}
return migrate.Create(ctx, tables...)
}
// WriteTo writes the schema changes to w instead of running them against the database.
//
// if err := client.Schema.WriteTo(context.Background(), os.Stdout); err != nil {
// log.Fatal(err)
// }
func (s *Schema) WriteTo(ctx context.Context, w io.Writer, opts ...schema.MigrateOption) error {
return Create(ctx, &Schema{drv: &schema.WriteDriver{Writer: w, Driver: s.drv}}, Tables, opts...)
}

View File

@@ -0,0 +1,42 @@
// Code generated by ent, DO NOT EDIT.
package migrate
import (
"entgo.io/ent/dialect/sql/schema"
"entgo.io/ent/schema/field"
)
var (
// TenantsColumns holds the columns for the "tenants" table.
TenantsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "name", Type: field.TypeString},
}
// TenantsTable holds the schema information for the "tenants" table.
TenantsTable = &schema.Table{
Name: "tenants",
Columns: TenantsColumns,
PrimaryKey: []*schema.Column{TenantsColumns[0]},
}
// UsersColumns holds the columns for the "users" table.
UsersColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "name", Type: field.TypeString},
{Name: "tenant_id", Type: field.TypeInt},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
Name: "users",
Columns: UsersColumns,
PrimaryKey: []*schema.Column{UsersColumns[0]},
}
// Tables holds all the tables in the schema.
Tables = []*schema.Table{
TenantsTable,
UsersTable,
}
)
func init() {
}

View File

@@ -0,0 +1,771 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"sync"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/examples/rls/ent/predicate"
"entgo.io/ent/examples/rls/ent/tenant"
"entgo.io/ent/examples/rls/ent/user"
)
const (
// Operation types.
OpCreate = ent.OpCreate
OpDelete = ent.OpDelete
OpDeleteOne = ent.OpDeleteOne
OpUpdate = ent.OpUpdate
OpUpdateOne = ent.OpUpdateOne
// Node types.
TypeTenant = "Tenant"
TypeUser = "User"
)
// TenantMutation represents an operation that mutates the Tenant nodes in the graph.
type TenantMutation struct {
config
op Op
typ string
id *int
name *string
clearedFields map[string]struct{}
done bool
oldValue func(context.Context) (*Tenant, error)
predicates []predicate.Tenant
}
var _ ent.Mutation = (*TenantMutation)(nil)
// tenantOption allows management of the mutation configuration using functional options.
type tenantOption func(*TenantMutation)
// newTenantMutation creates new mutation for the Tenant entity.
func newTenantMutation(c config, op Op, opts ...tenantOption) *TenantMutation {
m := &TenantMutation{
config: c,
op: op,
typ: TypeTenant,
clearedFields: make(map[string]struct{}),
}
for _, opt := range opts {
opt(m)
}
return m
}
// withTenantID sets the ID field of the mutation.
func withTenantID(id int) tenantOption {
return func(m *TenantMutation) {
var (
err error
once sync.Once
value *Tenant
)
m.oldValue = func(ctx context.Context) (*Tenant, error) {
once.Do(func() {
if m.done {
err = errors.New("querying old values post mutation is not allowed")
} else {
value, err = m.Client().Tenant.Get(ctx, id)
}
})
return value, err
}
m.id = &id
}
}
// withTenant sets the old Tenant of the mutation.
func withTenant(node *Tenant) tenantOption {
return func(m *TenantMutation) {
m.oldValue = func(context.Context) (*Tenant, error) {
return node, nil
}
m.id = &node.ID
}
}
// Client returns a new `ent.Client` from the mutation. If the mutation was
// executed in a transaction (ent.Tx), a transactional client is returned.
func (m TenantMutation) Client() *Client {
client := &Client{config: m.config}
client.init()
return client
}
// Tx returns an `ent.Tx` for mutations that were executed in transactions;
// it returns an error otherwise.
func (m TenantMutation) Tx() (*Tx, error) {
if _, ok := m.driver.(*txDriver); !ok {
return nil, errors.New("ent: mutation is not running in a transaction")
}
tx := &Tx{config: m.config}
tx.init()
return tx, nil
}
// ID returns the ID value in the mutation. Note that the ID is only available
// if it was provided to the builder or after it was returned from the database.
func (m *TenantMutation) ID() (id int, exists bool) {
if m.id == nil {
return
}
return *m.id, true
}
// IDs queries the database and returns the entity ids that match the mutation's predicate.
// That means, if the mutation is applied within a transaction with an isolation level such
// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
// or updated by the mutation.
func (m *TenantMutation) IDs(ctx context.Context) ([]int, error) {
switch {
case m.op.Is(OpUpdateOne | OpDeleteOne):
id, exists := m.ID()
if exists {
return []int{id}, nil
}
fallthrough
case m.op.Is(OpUpdate | OpDelete):
return m.Client().Tenant.Query().Where(m.predicates...).IDs(ctx)
default:
return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
}
}
// SetName sets the "name" field.
func (m *TenantMutation) SetName(s string) {
m.name = &s
}
// Name returns the value of the "name" field in the mutation.
func (m *TenantMutation) Name() (r string, exists bool) {
v := m.name
if v == nil {
return
}
return *v, true
}
// OldName returns the old "name" field's value of the Tenant entity.
// If the Tenant object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *TenantMutation) OldName(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldName is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldName requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldName: %w", err)
}
return oldValue.Name, nil
}
// ResetName resets all changes to the "name" field.
func (m *TenantMutation) ResetName() {
m.name = nil
}
// Where appends a list predicates to the TenantMutation builder.
func (m *TenantMutation) Where(ps ...predicate.Tenant) {
m.predicates = append(m.predicates, ps...)
}
// WhereP appends storage-level predicates to the TenantMutation builder. Using this method,
// users can use type-assertion to append predicates that do not depend on any generated package.
func (m *TenantMutation) WhereP(ps ...func(*sql.Selector)) {
p := make([]predicate.Tenant, len(ps))
for i := range ps {
p[i] = ps[i]
}
m.Where(p...)
}
// Op returns the operation name.
func (m *TenantMutation) Op() Op {
return m.op
}
// SetOp allows setting the mutation operation.
func (m *TenantMutation) SetOp(op Op) {
m.op = op
}
// Type returns the node type of this mutation (Tenant).
func (m *TenantMutation) Type() string {
return m.typ
}
// Fields returns all fields that were changed during this mutation. Note that in
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *TenantMutation) Fields() []string {
fields := make([]string, 0, 1)
if m.name != nil {
fields = append(fields, tenant.FieldName)
}
return fields
}
// Field returns the value of a field with the given name. The second boolean
// return value indicates that this field was not set, or was not defined in the
// schema.
func (m *TenantMutation) Field(name string) (ent.Value, bool) {
switch name {
case tenant.FieldName:
return m.Name()
}
return nil, false
}
// OldField returns the old value of the field from the database. An error is
// returned if the mutation operation is not UpdateOne, or the query to the
// database failed.
func (m *TenantMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
switch name {
case tenant.FieldName:
return m.OldName(ctx)
}
return nil, fmt.Errorf("unknown Tenant field %s", name)
}
// SetField sets the value of a field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
func (m *TenantMutation) SetField(name string, value ent.Value) error {
switch name {
case tenant.FieldName:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetName(v)
return nil
}
return fmt.Errorf("unknown Tenant field %s", name)
}
// AddedFields returns all numeric fields that were incremented/decremented during
// this mutation.
func (m *TenantMutation) AddedFields() []string {
return nil
}
// AddedField returns the numeric value that was incremented/decremented on a field
// with the given name. The second boolean return value indicates that this field
// was not set, or was not defined in the schema.
func (m *TenantMutation) AddedField(name string) (ent.Value, bool) {
return nil, false
}
// AddField adds the value to the field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
func (m *TenantMutation) AddField(name string, value ent.Value) error {
switch name {
}
return fmt.Errorf("unknown Tenant numeric field %s", name)
}
// ClearedFields returns all nullable fields that were cleared during this
// mutation.
func (m *TenantMutation) ClearedFields() []string {
return nil
}
// FieldCleared returns a boolean indicating if a field with the given name was
// cleared in this mutation.
func (m *TenantMutation) FieldCleared(name string) bool {
_, ok := m.clearedFields[name]
return ok
}
// ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema.
func (m *TenantMutation) ClearField(name string) error {
return fmt.Errorf("unknown Tenant nullable field %s", name)
}
// ResetField resets all changes in the mutation for the field with the given name.
// It returns an error if the field is not defined in the schema.
func (m *TenantMutation) ResetField(name string) error {
switch name {
case tenant.FieldName:
m.ResetName()
return nil
}
return fmt.Errorf("unknown Tenant field %s", name)
}
// AddedEdges returns all edge names that were set/added in this mutation.
func (m *TenantMutation) AddedEdges() []string {
edges := make([]string, 0, 0)
return edges
}
// AddedIDs returns all IDs (to other nodes) that were added for the given edge
// name in this mutation.
func (m *TenantMutation) AddedIDs(name string) []ent.Value {
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
func (m *TenantMutation) RemovedEdges() []string {
edges := make([]string, 0, 0)
return edges
}
// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
// the given name in this mutation.
func (m *TenantMutation) RemovedIDs(name string) []ent.Value {
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *TenantMutation) ClearedEdges() []string {
edges := make([]string, 0, 0)
return edges
}
// EdgeCleared returns a boolean which indicates if the edge with the given name
// was cleared in this mutation.
func (m *TenantMutation) EdgeCleared(name string) bool {
return false
}
// ClearEdge clears the value of the edge with the given name. It returns an error
// if that edge is not defined in the schema.
func (m *TenantMutation) ClearEdge(name string) error {
return fmt.Errorf("unknown Tenant unique edge %s", name)
}
// ResetEdge resets all changes to the edge with the given name in this mutation.
// It returns an error if the edge is not defined in the schema.
func (m *TenantMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown Tenant edge %s", name)
}
// UserMutation represents an operation that mutates the User nodes in the graph.
type UserMutation struct {
config
op Op
typ string
id *int
name *string
tenant_id *int
addtenant_id *int
clearedFields map[string]struct{}
done bool
oldValue func(context.Context) (*User, error)
predicates []predicate.User
}
var _ ent.Mutation = (*UserMutation)(nil)
// userOption allows management of the mutation configuration using functional options.
type userOption func(*UserMutation)
// newUserMutation creates new mutation for the User entity.
func newUserMutation(c config, op Op, opts ...userOption) *UserMutation {
m := &UserMutation{
config: c,
op: op,
typ: TypeUser,
clearedFields: make(map[string]struct{}),
}
for _, opt := range opts {
opt(m)
}
return m
}
// withUserID sets the ID field of the mutation.
func withUserID(id int) userOption {
return func(m *UserMutation) {
var (
err error
once sync.Once
value *User
)
m.oldValue = func(ctx context.Context) (*User, error) {
once.Do(func() {
if m.done {
err = errors.New("querying old values post mutation is not allowed")
} else {
value, err = m.Client().User.Get(ctx, id)
}
})
return value, err
}
m.id = &id
}
}
// withUser sets the old User of the mutation.
func withUser(node *User) userOption {
return func(m *UserMutation) {
m.oldValue = func(context.Context) (*User, error) {
return node, nil
}
m.id = &node.ID
}
}
// Client returns a new `ent.Client` from the mutation. If the mutation was
// executed in a transaction (ent.Tx), a transactional client is returned.
func (m UserMutation) Client() *Client {
client := &Client{config: m.config}
client.init()
return client
}
// Tx returns an `ent.Tx` for mutations that were executed in transactions;
// it returns an error otherwise.
func (m UserMutation) Tx() (*Tx, error) {
if _, ok := m.driver.(*txDriver); !ok {
return nil, errors.New("ent: mutation is not running in a transaction")
}
tx := &Tx{config: m.config}
tx.init()
return tx, nil
}
// ID returns the ID value in the mutation. Note that the ID is only available
// if it was provided to the builder or after it was returned from the database.
func (m *UserMutation) ID() (id int, exists bool) {
if m.id == nil {
return
}
return *m.id, true
}
// IDs queries the database and returns the entity ids that match the mutation's predicate.
// That means, if the mutation is applied within a transaction with an isolation level such
// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
// or updated by the mutation.
func (m *UserMutation) IDs(ctx context.Context) ([]int, error) {
switch {
case m.op.Is(OpUpdateOne | OpDeleteOne):
id, exists := m.ID()
if exists {
return []int{id}, nil
}
fallthrough
case m.op.Is(OpUpdate | OpDelete):
return m.Client().User.Query().Where(m.predicates...).IDs(ctx)
default:
return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
}
}
// SetName sets the "name" field.
func (m *UserMutation) SetName(s string) {
m.name = &s
}
// Name returns the value of the "name" field in the mutation.
func (m *UserMutation) Name() (r string, exists bool) {
v := m.name
if v == nil {
return
}
return *v, true
}
// OldName returns the old "name" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldName(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldName is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldName requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldName: %w", err)
}
return oldValue.Name, nil
}
// ResetName resets all changes to the "name" field.
func (m *UserMutation) ResetName() {
m.name = nil
}
// SetTenantID sets the "tenant_id" field.
func (m *UserMutation) SetTenantID(i int) {
m.tenant_id = &i
m.addtenant_id = nil
}
// TenantID returns the value of the "tenant_id" field in the mutation.
func (m *UserMutation) TenantID() (r int, exists bool) {
v := m.tenant_id
if v == nil {
return
}
return *v, true
}
// OldTenantID returns the old "tenant_id" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldTenantID(ctx context.Context) (v int, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldTenantID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldTenantID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldTenantID: %w", err)
}
return oldValue.TenantID, nil
}
// AddTenantID adds i to the "tenant_id" field.
func (m *UserMutation) AddTenantID(i int) {
if m.addtenant_id != nil {
*m.addtenant_id += i
} else {
m.addtenant_id = &i
}
}
// AddedTenantID returns the value that was added to the "tenant_id" field in this mutation.
func (m *UserMutation) AddedTenantID() (r int, exists bool) {
v := m.addtenant_id
if v == nil {
return
}
return *v, true
}
// ResetTenantID resets all changes to the "tenant_id" field.
func (m *UserMutation) ResetTenantID() {
m.tenant_id = nil
m.addtenant_id = nil
}
// Where appends a list predicates to the UserMutation builder.
func (m *UserMutation) Where(ps ...predicate.User) {
m.predicates = append(m.predicates, ps...)
}
// WhereP appends storage-level predicates to the UserMutation builder. Using this method,
// users can use type-assertion to append predicates that do not depend on any generated package.
func (m *UserMutation) WhereP(ps ...func(*sql.Selector)) {
p := make([]predicate.User, len(ps))
for i := range ps {
p[i] = ps[i]
}
m.Where(p...)
}
// Op returns the operation name.
func (m *UserMutation) Op() Op {
return m.op
}
// SetOp allows setting the mutation operation.
func (m *UserMutation) SetOp(op Op) {
m.op = op
}
// Type returns the node type of this mutation (User).
func (m *UserMutation) Type() string {
return m.typ
}
// Fields returns all fields that were changed during this mutation. Note that in
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
fields := make([]string, 0, 2)
if m.name != nil {
fields = append(fields, user.FieldName)
}
if m.tenant_id != nil {
fields = append(fields, user.FieldTenantID)
}
return fields
}
// Field returns the value of a field with the given name. The second boolean
// return value indicates that this field was not set, or was not defined in the
// schema.
func (m *UserMutation) Field(name string) (ent.Value, bool) {
switch name {
case user.FieldName:
return m.Name()
case user.FieldTenantID:
return m.TenantID()
}
return nil, false
}
// OldField returns the old value of the field from the database. An error is
// returned if the mutation operation is not UpdateOne, or the query to the
// database failed.
func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
switch name {
case user.FieldName:
return m.OldName(ctx)
case user.FieldTenantID:
return m.OldTenantID(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
// SetField sets the value of a field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
func (m *UserMutation) SetField(name string, value ent.Value) error {
switch name {
case user.FieldName:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetName(v)
return nil
case user.FieldTenantID:
v, ok := value.(int)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetTenantID(v)
return nil
}
return fmt.Errorf("unknown User field %s", name)
}
// AddedFields returns all numeric fields that were incremented/decremented during
// this mutation.
func (m *UserMutation) AddedFields() []string {
var fields []string
if m.addtenant_id != nil {
fields = append(fields, user.FieldTenantID)
}
return fields
}
// AddedField returns the numeric value that was incremented/decremented on a field
// with the given name. The second boolean return value indicates that this field
// was not set, or was not defined in the schema.
func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
switch name {
case user.FieldTenantID:
return m.AddedTenantID()
}
return nil, false
}
// AddField adds the value to the field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
func (m *UserMutation) AddField(name string, value ent.Value) error {
switch name {
case user.FieldTenantID:
v, ok := value.(int)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddTenantID(v)
return nil
}
return fmt.Errorf("unknown User numeric field %s", name)
}
// ClearedFields returns all nullable fields that were cleared during this
// mutation.
func (m *UserMutation) ClearedFields() []string {
return nil
}
// FieldCleared returns a boolean indicating if a field with the given name was
// cleared in this mutation.
func (m *UserMutation) FieldCleared(name string) bool {
_, ok := m.clearedFields[name]
return ok
}
// ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema.
func (m *UserMutation) ClearField(name string) error {
return fmt.Errorf("unknown User nullable field %s", name)
}
// ResetField resets all changes in the mutation for the field with the given name.
// It returns an error if the field is not defined in the schema.
func (m *UserMutation) ResetField(name string) error {
switch name {
case user.FieldName:
m.ResetName()
return nil
case user.FieldTenantID:
m.ResetTenantID()
return nil
}
return fmt.Errorf("unknown User field %s", name)
}
// AddedEdges returns all edge names that were set/added in this mutation.
func (m *UserMutation) AddedEdges() []string {
edges := make([]string, 0, 0)
return edges
}
// AddedIDs returns all IDs (to other nodes) that were added for the given edge
// name in this mutation.
func (m *UserMutation) AddedIDs(name string) []ent.Value {
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
func (m *UserMutation) RemovedEdges() []string {
edges := make([]string, 0, 0)
return edges
}
// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
// the given name in this mutation.
func (m *UserMutation) RemovedIDs(name string) []ent.Value {
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *UserMutation) ClearedEdges() []string {
edges := make([]string, 0, 0)
return edges
}
// EdgeCleared returns a boolean which indicates if the edge with the given name
// was cleared in this mutation.
func (m *UserMutation) EdgeCleared(name string) bool {
return false
}
// ClearEdge clears the value of the edge with the given name. It returns an error
// if that edge is not defined in the schema.
func (m *UserMutation) ClearEdge(name string) error {
return fmt.Errorf("unknown User unique edge %s", name)
}
// ResetEdge resets all changes to the edge with the given name in this mutation.
// It returns an error if the edge is not defined in the schema.
func (m *UserMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown User edge %s", name)
}

View File

@@ -0,0 +1,13 @@
// Code generated by ent, DO NOT EDIT.
package predicate
import (
"entgo.io/ent/dialect/sql"
)
// Tenant is the predicate function for tenant builders.
type Tenant func(*sql.Selector)
// User is the predicate function for user builders.
type User func(*sql.Selector)

View File

@@ -0,0 +1,9 @@
// Code generated by ent, DO NOT EDIT.
package ent
// The init function reads all schema descriptors with runtime code
// (default values, validators, hooks and policies) and stitches it
// to their package variables.
func init() {
}

View File

@@ -0,0 +1,9 @@
// Code generated by ent, DO NOT EDIT.
package runtime
// The schema-stitching logic is generated in entgo.io/ent/examples/rls/ent/runtime.go
const (
Version = "v0.0.0-00010101000000-000000000000" // Version of ent codegen.
)

View File

@@ -0,0 +1,35 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package schema
import (
"entgo.io/ent"
"entgo.io/ent/schema/field"
)
// Tenant holds the schema definition for the Tenant entity.
type Tenant struct {
ent.Schema
}
// Fields of the Tenant.
func (Tenant) Fields() []ent.Field {
return []ent.Field{
field.String("name"),
}
}
// User holds the schema definition for the User entity.
type User struct {
ent.Schema
}
// Fields of the User.
func (User) Fields() []ent.Field {
return []ent.Field{
field.String("name"),
field.Int("tenant_id"),
}
}

103
examples/rls/ent/tenant.go Normal file
View File

@@ -0,0 +1,103 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/examples/rls/ent/tenant"
)
// Tenant is the model entity for the Tenant schema.
type Tenant struct {
config `json:"-"`
// ID of the ent.
ID int `json:"id,omitempty"`
// Name holds the value of the "name" field.
Name string `json:"name,omitempty"`
selectValues sql.SelectValues
}
// scanValues returns the types for scanning values from sql.Rows.
func (*Tenant) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case tenant.FieldID:
values[i] = new(sql.NullInt64)
case tenant.FieldName:
values[i] = new(sql.NullString)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the Tenant fields.
func (t *Tenant) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case tenant.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
t.ID = int(value.Int64)
case tenant.FieldName:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field name", values[i])
} else if value.Valid {
t.Name = value.String
}
default:
t.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the Tenant.
// This includes values selected through modifiers, order, etc.
func (t *Tenant) Value(name string) (ent.Value, error) {
return t.selectValues.Get(name)
}
// Update returns a builder for updating this Tenant.
// Note that you need to call Tenant.Unwrap() before calling this method if this Tenant
// was returned from a transaction, and the transaction was committed or rolled back.
func (t *Tenant) Update() *TenantUpdateOne {
return NewTenantClient(t.config).UpdateOne(t)
}
// Unwrap unwraps the Tenant entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (t *Tenant) Unwrap() *Tenant {
_tx, ok := t.config.driver.(*txDriver)
if !ok {
panic("ent: Tenant is not a transactional entity")
}
t.config.driver = _tx.drv
return t
}
// String implements the fmt.Stringer.
func (t *Tenant) String() string {
var builder strings.Builder
builder.WriteString("Tenant(")
builder.WriteString(fmt.Sprintf("id=%v, ", t.ID))
builder.WriteString("name=")
builder.WriteString(t.Name)
builder.WriteByte(')')
return builder.String()
}
// Tenants is a parsable slice of Tenant.
type Tenants []*Tenant

View File

@@ -0,0 +1,47 @@
// Code generated by ent, DO NOT EDIT.
package tenant
import (
"entgo.io/ent/dialect/sql"
)
const (
// Label holds the string label denoting the tenant type in the database.
Label = "tenant"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldName holds the string denoting the name field in the database.
FieldName = "name"
// Table holds the table name of the tenant in the database.
Table = "tenants"
)
// Columns holds all SQL columns for tenant fields.
var Columns = []string{
FieldID,
FieldName,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
// OrderOption defines the ordering options for the Tenant queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByName orders the results by the name field.
func ByName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldName, opts...).ToFunc()
}

View File

@@ -0,0 +1,138 @@
// Code generated by ent, DO NOT EDIT.
package tenant
import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/examples/rls/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int) predicate.Tenant {
return predicate.Tenant(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int) predicate.Tenant {
return predicate.Tenant(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int) predicate.Tenant {
return predicate.Tenant(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int) predicate.Tenant {
return predicate.Tenant(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int) predicate.Tenant {
return predicate.Tenant(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int) predicate.Tenant {
return predicate.Tenant(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int) predicate.Tenant {
return predicate.Tenant(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int) predicate.Tenant {
return predicate.Tenant(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int) predicate.Tenant {
return predicate.Tenant(sql.FieldLTE(FieldID, id))
}
// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
func Name(v string) predicate.Tenant {
return predicate.Tenant(sql.FieldEQ(FieldName, v))
}
// NameEQ applies the EQ predicate on the "name" field.
func NameEQ(v string) predicate.Tenant {
return predicate.Tenant(sql.FieldEQ(FieldName, v))
}
// NameNEQ applies the NEQ predicate on the "name" field.
func NameNEQ(v string) predicate.Tenant {
return predicate.Tenant(sql.FieldNEQ(FieldName, v))
}
// NameIn applies the In predicate on the "name" field.
func NameIn(vs ...string) predicate.Tenant {
return predicate.Tenant(sql.FieldIn(FieldName, vs...))
}
// NameNotIn applies the NotIn predicate on the "name" field.
func NameNotIn(vs ...string) predicate.Tenant {
return predicate.Tenant(sql.FieldNotIn(FieldName, vs...))
}
// NameGT applies the GT predicate on the "name" field.
func NameGT(v string) predicate.Tenant {
return predicate.Tenant(sql.FieldGT(FieldName, v))
}
// NameGTE applies the GTE predicate on the "name" field.
func NameGTE(v string) predicate.Tenant {
return predicate.Tenant(sql.FieldGTE(FieldName, v))
}
// NameLT applies the LT predicate on the "name" field.
func NameLT(v string) predicate.Tenant {
return predicate.Tenant(sql.FieldLT(FieldName, v))
}
// NameLTE applies the LTE predicate on the "name" field.
func NameLTE(v string) predicate.Tenant {
return predicate.Tenant(sql.FieldLTE(FieldName, v))
}
// NameContains applies the Contains predicate on the "name" field.
func NameContains(v string) predicate.Tenant {
return predicate.Tenant(sql.FieldContains(FieldName, v))
}
// NameHasPrefix applies the HasPrefix predicate on the "name" field.
func NameHasPrefix(v string) predicate.Tenant {
return predicate.Tenant(sql.FieldHasPrefix(FieldName, v))
}
// NameHasSuffix applies the HasSuffix predicate on the "name" field.
func NameHasSuffix(v string) predicate.Tenant {
return predicate.Tenant(sql.FieldHasSuffix(FieldName, v))
}
// NameEqualFold applies the EqualFold predicate on the "name" field.
func NameEqualFold(v string) predicate.Tenant {
return predicate.Tenant(sql.FieldEqualFold(FieldName, v))
}
// NameContainsFold applies the ContainsFold predicate on the "name" field.
func NameContainsFold(v string) predicate.Tenant {
return predicate.Tenant(sql.FieldContainsFold(FieldName, v))
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.Tenant) predicate.Tenant {
return predicate.Tenant(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.Tenant) predicate.Tenant {
return predicate.Tenant(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.Tenant) predicate.Tenant {
return predicate.Tenant(sql.NotPredicates(p))
}

View File

@@ -0,0 +1,183 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/examples/rls/ent/tenant"
"entgo.io/ent/schema/field"
)
// TenantCreate is the builder for creating a Tenant entity.
type TenantCreate struct {
config
mutation *TenantMutation
hooks []Hook
}
// SetName sets the "name" field.
func (tc *TenantCreate) SetName(s string) *TenantCreate {
tc.mutation.SetName(s)
return tc
}
// Mutation returns the TenantMutation object of the builder.
func (tc *TenantCreate) Mutation() *TenantMutation {
return tc.mutation
}
// Save creates the Tenant in the database.
func (tc *TenantCreate) Save(ctx context.Context) (*Tenant, error) {
return withHooks(ctx, tc.sqlSave, tc.mutation, tc.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func (tc *TenantCreate) SaveX(ctx context.Context) *Tenant {
v, err := tc.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (tc *TenantCreate) Exec(ctx context.Context) error {
_, err := tc.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (tc *TenantCreate) ExecX(ctx context.Context) {
if err := tc.Exec(ctx); err != nil {
panic(err)
}
}
// check runs all checks and user-defined validators on the builder.
func (tc *TenantCreate) check() error {
if _, ok := tc.mutation.Name(); !ok {
return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Tenant.name"`)}
}
return nil
}
func (tc *TenantCreate) sqlSave(ctx context.Context) (*Tenant, error) {
if err := tc.check(); err != nil {
return nil, err
}
_node, _spec := tc.createSpec()
if err := sqlgraph.CreateNode(ctx, tc.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
id := _spec.ID.Value.(int64)
_node.ID = int(id)
tc.mutation.id = &_node.ID
tc.mutation.done = true
return _node, nil
}
func (tc *TenantCreate) createSpec() (*Tenant, *sqlgraph.CreateSpec) {
var (
_node = &Tenant{config: tc.config}
_spec = sqlgraph.NewCreateSpec(tenant.Table, sqlgraph.NewFieldSpec(tenant.FieldID, field.TypeInt))
)
if value, ok := tc.mutation.Name(); ok {
_spec.SetField(tenant.FieldName, field.TypeString, value)
_node.Name = value
}
return _node, _spec
}
// TenantCreateBulk is the builder for creating many Tenant entities in bulk.
type TenantCreateBulk struct {
config
err error
builders []*TenantCreate
}
// Save creates the Tenant entities in the database.
func (tcb *TenantCreateBulk) Save(ctx context.Context) ([]*Tenant, error) {
if tcb.err != nil {
return nil, tcb.err
}
specs := make([]*sqlgraph.CreateSpec, len(tcb.builders))
nodes := make([]*Tenant, len(tcb.builders))
mutators := make([]Mutator, len(tcb.builders))
for i := range tcb.builders {
func(i int, root context.Context) {
builder := tcb.builders[i]
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*TenantMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err := builder.check(); err != nil {
return nil, err
}
builder.mutation = mutation
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, tcb.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, tcb.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
}
}
if err != nil {
return nil, err
}
mutation.id = &nodes[i].ID
if specs[i].ID.Value != nil {
id := specs[i].ID.Value.(int64)
nodes[i].ID = int(id)
}
mutation.done = true
return nodes[i], nil
})
for i := len(builder.hooks) - 1; i >= 0; i-- {
mut = builder.hooks[i](mut)
}
mutators[i] = mut
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, tcb.builders[0].mutation); err != nil {
return nil, err
}
}
return nodes, nil
}
// SaveX is like Save, but panics if an error occurs.
func (tcb *TenantCreateBulk) SaveX(ctx context.Context) []*Tenant {
v, err := tcb.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (tcb *TenantCreateBulk) Exec(ctx context.Context) error {
_, err := tcb.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (tcb *TenantCreateBulk) ExecX(ctx context.Context) {
if err := tcb.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/examples/rls/ent/predicate"
"entgo.io/ent/examples/rls/ent/tenant"
"entgo.io/ent/schema/field"
)
// TenantDelete is the builder for deleting a Tenant entity.
type TenantDelete struct {
config
hooks []Hook
mutation *TenantMutation
}
// Where appends a list predicates to the TenantDelete builder.
func (td *TenantDelete) Where(ps ...predicate.Tenant) *TenantDelete {
td.mutation.Where(ps...)
return td
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (td *TenantDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, td.sqlExec, td.mutation, td.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (td *TenantDelete) ExecX(ctx context.Context) int {
n, err := td.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (td *TenantDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(tenant.Table, sqlgraph.NewFieldSpec(tenant.FieldID, field.TypeInt))
if ps := td.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, td.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
td.mutation.done = true
return affected, err
}
// TenantDeleteOne is the builder for deleting a single Tenant entity.
type TenantDeleteOne struct {
td *TenantDelete
}
// Where appends a list predicates to the TenantDelete builder.
func (tdo *TenantDeleteOne) Where(ps ...predicate.Tenant) *TenantDeleteOne {
tdo.td.mutation.Where(ps...)
return tdo
}
// Exec executes the deletion query.
func (tdo *TenantDeleteOne) Exec(ctx context.Context) error {
n, err := tdo.td.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{tenant.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (tdo *TenantDeleteOne) ExecX(ctx context.Context) {
if err := tdo.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,527 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/examples/rls/ent/predicate"
"entgo.io/ent/examples/rls/ent/tenant"
"entgo.io/ent/schema/field"
)
// TenantQuery is the builder for querying Tenant entities.
type TenantQuery struct {
config
ctx *QueryContext
order []tenant.OrderOption
inters []Interceptor
predicates []predicate.Tenant
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the TenantQuery builder.
func (tq *TenantQuery) Where(ps ...predicate.Tenant) *TenantQuery {
tq.predicates = append(tq.predicates, ps...)
return tq
}
// Limit the number of records to be returned by this query.
func (tq *TenantQuery) Limit(limit int) *TenantQuery {
tq.ctx.Limit = &limit
return tq
}
// Offset to start from.
func (tq *TenantQuery) Offset(offset int) *TenantQuery {
tq.ctx.Offset = &offset
return tq
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (tq *TenantQuery) Unique(unique bool) *TenantQuery {
tq.ctx.Unique = &unique
return tq
}
// Order specifies how the records should be ordered.
func (tq *TenantQuery) Order(o ...tenant.OrderOption) *TenantQuery {
tq.order = append(tq.order, o...)
return tq
}
// First returns the first Tenant entity from the query.
// Returns a *NotFoundError when no Tenant was found.
func (tq *TenantQuery) First(ctx context.Context) (*Tenant, error) {
nodes, err := tq.Limit(1).All(setContextOp(ctx, tq.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{tenant.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (tq *TenantQuery) FirstX(ctx context.Context) *Tenant {
node, err := tq.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first Tenant ID from the query.
// Returns a *NotFoundError when no Tenant ID was found.
func (tq *TenantQuery) FirstID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = tq.Limit(1).IDs(setContextOp(ctx, tq.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{tenant.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (tq *TenantQuery) FirstIDX(ctx context.Context) int {
id, err := tq.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single Tenant entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one Tenant entity is found.
// Returns a *NotFoundError when no Tenant entities are found.
func (tq *TenantQuery) Only(ctx context.Context) (*Tenant, error) {
nodes, err := tq.Limit(2).All(setContextOp(ctx, tq.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{tenant.Label}
default:
return nil, &NotSingularError{tenant.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (tq *TenantQuery) OnlyX(ctx context.Context) *Tenant {
node, err := tq.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only Tenant ID in the query.
// Returns a *NotSingularError when more than one Tenant ID is found.
// Returns a *NotFoundError when no entities are found.
func (tq *TenantQuery) OnlyID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = tq.Limit(2).IDs(setContextOp(ctx, tq.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{tenant.Label}
default:
err = &NotSingularError{tenant.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (tq *TenantQuery) OnlyIDX(ctx context.Context) int {
id, err := tq.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of Tenants.
func (tq *TenantQuery) All(ctx context.Context) ([]*Tenant, error) {
ctx = setContextOp(ctx, tq.ctx, ent.OpQueryAll)
if err := tq.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*Tenant, *TenantQuery]()
return withInterceptors[[]*Tenant](ctx, tq, qr, tq.inters)
}
// AllX is like All, but panics if an error occurs.
func (tq *TenantQuery) AllX(ctx context.Context) []*Tenant {
nodes, err := tq.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of Tenant IDs.
func (tq *TenantQuery) IDs(ctx context.Context) (ids []int, err error) {
if tq.ctx.Unique == nil && tq.path != nil {
tq.Unique(true)
}
ctx = setContextOp(ctx, tq.ctx, ent.OpQueryIDs)
if err = tq.Select(tenant.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (tq *TenantQuery) IDsX(ctx context.Context) []int {
ids, err := tq.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (tq *TenantQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, tq.ctx, ent.OpQueryCount)
if err := tq.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, tq, querierCount[*TenantQuery](), tq.inters)
}
// CountX is like Count, but panics if an error occurs.
func (tq *TenantQuery) CountX(ctx context.Context) int {
count, err := tq.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (tq *TenantQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, tq.ctx, ent.OpQueryExist)
switch _, err := tq.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (tq *TenantQuery) ExistX(ctx context.Context) bool {
exist, err := tq.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the TenantQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (tq *TenantQuery) Clone() *TenantQuery {
if tq == nil {
return nil
}
return &TenantQuery{
config: tq.config,
ctx: tq.ctx.Clone(),
order: append([]tenant.OrderOption{}, tq.order...),
inters: append([]Interceptor{}, tq.inters...),
predicates: append([]predicate.Tenant{}, tq.predicates...),
// clone intermediate query.
sql: tq.sql.Clone(),
path: tq.path,
}
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// Name string `json:"name,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.Tenant.Query().
// GroupBy(tenant.FieldName).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (tq *TenantQuery) GroupBy(field string, fields ...string) *TenantGroupBy {
tq.ctx.Fields = append([]string{field}, fields...)
grbuild := &TenantGroupBy{build: tq}
grbuild.flds = &tq.ctx.Fields
grbuild.label = tenant.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// Name string `json:"name,omitempty"`
// }
//
// client.Tenant.Query().
// Select(tenant.FieldName).
// Scan(ctx, &v)
func (tq *TenantQuery) Select(fields ...string) *TenantSelect {
tq.ctx.Fields = append(tq.ctx.Fields, fields...)
sbuild := &TenantSelect{TenantQuery: tq}
sbuild.label = tenant.Label
sbuild.flds, sbuild.scan = &tq.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a TenantSelect configured with the given aggregations.
func (tq *TenantQuery) Aggregate(fns ...AggregateFunc) *TenantSelect {
return tq.Select().Aggregate(fns...)
}
func (tq *TenantQuery) prepareQuery(ctx context.Context) error {
for _, inter := range tq.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, tq); err != nil {
return err
}
}
}
for _, f := range tq.ctx.Fields {
if !tenant.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if tq.path != nil {
prev, err := tq.path(ctx)
if err != nil {
return err
}
tq.sql = prev
}
return nil
}
func (tq *TenantQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Tenant, error) {
var (
nodes = []*Tenant{}
_spec = tq.querySpec()
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*Tenant).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &Tenant{config: tq.config}
nodes = append(nodes, node)
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, tq.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
return nodes, nil
}
func (tq *TenantQuery) sqlCount(ctx context.Context) (int, error) {
_spec := tq.querySpec()
_spec.Node.Columns = tq.ctx.Fields
if len(tq.ctx.Fields) > 0 {
_spec.Unique = tq.ctx.Unique != nil && *tq.ctx.Unique
}
return sqlgraph.CountNodes(ctx, tq.driver, _spec)
}
func (tq *TenantQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(tenant.Table, tenant.Columns, sqlgraph.NewFieldSpec(tenant.FieldID, field.TypeInt))
_spec.From = tq.sql
if unique := tq.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if tq.path != nil {
_spec.Unique = true
}
if fields := tq.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, tenant.FieldID)
for i := range fields {
if fields[i] != tenant.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := tq.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := tq.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := tq.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := tq.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (tq *TenantQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(tq.driver.Dialect())
t1 := builder.Table(tenant.Table)
columns := tq.ctx.Fields
if len(columns) == 0 {
columns = tenant.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if tq.sql != nil {
selector = tq.sql
selector.Select(selector.Columns(columns...)...)
}
if tq.ctx.Unique != nil && *tq.ctx.Unique {
selector.Distinct()
}
for _, p := range tq.predicates {
p(selector)
}
for _, p := range tq.order {
p(selector)
}
if offset := tq.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := tq.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// TenantGroupBy is the group-by builder for Tenant entities.
type TenantGroupBy struct {
selector
build *TenantQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (tgb *TenantGroupBy) Aggregate(fns ...AggregateFunc) *TenantGroupBy {
tgb.fns = append(tgb.fns, fns...)
return tgb
}
// Scan applies the selector query and scans the result into the given value.
func (tgb *TenantGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, tgb.build.ctx, ent.OpQueryGroupBy)
if err := tgb.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*TenantQuery, *TenantGroupBy](ctx, tgb.build, tgb, tgb.build.inters, v)
}
func (tgb *TenantGroupBy) sqlScan(ctx context.Context, root *TenantQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(tgb.fns))
for _, fn := range tgb.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*tgb.flds)+len(tgb.fns))
for _, f := range *tgb.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*tgb.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := tgb.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// TenantSelect is the builder for selecting fields of Tenant entities.
type TenantSelect struct {
*TenantQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (ts *TenantSelect) Aggregate(fns ...AggregateFunc) *TenantSelect {
ts.fns = append(ts.fns, fns...)
return ts
}
// Scan applies the selector query and scans the result into the given value.
func (ts *TenantSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, ts.ctx, ent.OpQuerySelect)
if err := ts.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*TenantQuery, *TenantSelect](ctx, ts.TenantQuery, ts, ts.inters, v)
}
func (ts *TenantSelect) sqlScan(ctx context.Context, root *TenantQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(ts.fns))
for _, fn := range ts.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*ts.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := ts.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,209 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/examples/rls/ent/predicate"
"entgo.io/ent/examples/rls/ent/tenant"
"entgo.io/ent/schema/field"
)
// TenantUpdate is the builder for updating Tenant entities.
type TenantUpdate struct {
config
hooks []Hook
mutation *TenantMutation
}
// Where appends a list predicates to the TenantUpdate builder.
func (tu *TenantUpdate) Where(ps ...predicate.Tenant) *TenantUpdate {
tu.mutation.Where(ps...)
return tu
}
// SetName sets the "name" field.
func (tu *TenantUpdate) SetName(s string) *TenantUpdate {
tu.mutation.SetName(s)
return tu
}
// SetNillableName sets the "name" field if the given value is not nil.
func (tu *TenantUpdate) SetNillableName(s *string) *TenantUpdate {
if s != nil {
tu.SetName(*s)
}
return tu
}
// Mutation returns the TenantMutation object of the builder.
func (tu *TenantUpdate) Mutation() *TenantMutation {
return tu.mutation
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (tu *TenantUpdate) Save(ctx context.Context) (int, error) {
return withHooks(ctx, tu.sqlSave, tu.mutation, tu.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (tu *TenantUpdate) SaveX(ctx context.Context) int {
affected, err := tu.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (tu *TenantUpdate) Exec(ctx context.Context) error {
_, err := tu.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (tu *TenantUpdate) ExecX(ctx context.Context) {
if err := tu.Exec(ctx); err != nil {
panic(err)
}
}
func (tu *TenantUpdate) sqlSave(ctx context.Context) (n int, err error) {
_spec := sqlgraph.NewUpdateSpec(tenant.Table, tenant.Columns, sqlgraph.NewFieldSpec(tenant.FieldID, field.TypeInt))
if ps := tu.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := tu.mutation.Name(); ok {
_spec.SetField(tenant.FieldName, field.TypeString, value)
}
if n, err = sqlgraph.UpdateNodes(ctx, tu.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{tenant.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
tu.mutation.done = true
return n, nil
}
// TenantUpdateOne is the builder for updating a single Tenant entity.
type TenantUpdateOne struct {
config
fields []string
hooks []Hook
mutation *TenantMutation
}
// SetName sets the "name" field.
func (tuo *TenantUpdateOne) SetName(s string) *TenantUpdateOne {
tuo.mutation.SetName(s)
return tuo
}
// SetNillableName sets the "name" field if the given value is not nil.
func (tuo *TenantUpdateOne) SetNillableName(s *string) *TenantUpdateOne {
if s != nil {
tuo.SetName(*s)
}
return tuo
}
// Mutation returns the TenantMutation object of the builder.
func (tuo *TenantUpdateOne) Mutation() *TenantMutation {
return tuo.mutation
}
// Where appends a list predicates to the TenantUpdate builder.
func (tuo *TenantUpdateOne) Where(ps ...predicate.Tenant) *TenantUpdateOne {
tuo.mutation.Where(ps...)
return tuo
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (tuo *TenantUpdateOne) Select(field string, fields ...string) *TenantUpdateOne {
tuo.fields = append([]string{field}, fields...)
return tuo
}
// Save executes the query and returns the updated Tenant entity.
func (tuo *TenantUpdateOne) Save(ctx context.Context) (*Tenant, error) {
return withHooks(ctx, tuo.sqlSave, tuo.mutation, tuo.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (tuo *TenantUpdateOne) SaveX(ctx context.Context) *Tenant {
node, err := tuo.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (tuo *TenantUpdateOne) Exec(ctx context.Context) error {
_, err := tuo.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (tuo *TenantUpdateOne) ExecX(ctx context.Context) {
if err := tuo.Exec(ctx); err != nil {
panic(err)
}
}
func (tuo *TenantUpdateOne) sqlSave(ctx context.Context) (_node *Tenant, err error) {
_spec := sqlgraph.NewUpdateSpec(tenant.Table, tenant.Columns, sqlgraph.NewFieldSpec(tenant.FieldID, field.TypeInt))
id, ok := tuo.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Tenant.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := tuo.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, tenant.FieldID)
for _, f := range fields {
if !tenant.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != tenant.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := tuo.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := tuo.mutation.Name(); ok {
_spec.SetField(tenant.FieldName, field.TypeString, value)
}
_node = &Tenant{config: tuo.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, tuo.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{tenant.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
tuo.mutation.done = true
return _node, nil
}

213
examples/rls/ent/tx.go Normal file
View File

@@ -0,0 +1,213 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"sync"
"entgo.io/ent/dialect"
)
// Tx is a transactional client that is created by calling Client.Tx().
type Tx struct {
config
// Tenant is the client for interacting with the Tenant builders.
Tenant *TenantClient
// User is the client for interacting with the User builders.
User *UserClient
// lazily loaded.
client *Client
clientOnce sync.Once
// ctx lives for the life of the transaction. It is
// the same context used by the underlying connection.
ctx context.Context
}
type (
// Committer is the interface that wraps the Commit method.
Committer interface {
Commit(context.Context, *Tx) error
}
// The CommitFunc type is an adapter to allow the use of ordinary
// function as a Committer. If f is a function with the appropriate
// signature, CommitFunc(f) is a Committer that calls f.
CommitFunc func(context.Context, *Tx) error
// CommitHook defines the "commit middleware". A function that gets a Committer
// and returns a Committer. For example:
//
// hook := func(next ent.Committer) ent.Committer {
// return ent.CommitFunc(func(ctx context.Context, tx *ent.Tx) error {
// // Do some stuff before.
// if err := next.Commit(ctx, tx); err != nil {
// return err
// }
// // Do some stuff after.
// return nil
// })
// }
//
CommitHook func(Committer) Committer
)
// Commit calls f(ctx, m).
func (f CommitFunc) Commit(ctx context.Context, tx *Tx) error {
return f(ctx, tx)
}
// Commit commits the transaction.
func (tx *Tx) Commit() error {
txDriver := tx.config.driver.(*txDriver)
var fn Committer = CommitFunc(func(context.Context, *Tx) error {
return txDriver.tx.Commit()
})
txDriver.mu.Lock()
hooks := append([]CommitHook(nil), txDriver.onCommit...)
txDriver.mu.Unlock()
for i := len(hooks) - 1; i >= 0; i-- {
fn = hooks[i](fn)
}
return fn.Commit(tx.ctx, tx)
}
// OnCommit adds a hook to call on commit.
func (tx *Tx) OnCommit(f CommitHook) {
txDriver := tx.config.driver.(*txDriver)
txDriver.mu.Lock()
txDriver.onCommit = append(txDriver.onCommit, f)
txDriver.mu.Unlock()
}
type (
// Rollbacker is the interface that wraps the Rollback method.
Rollbacker interface {
Rollback(context.Context, *Tx) error
}
// The RollbackFunc type is an adapter to allow the use of ordinary
// function as a Rollbacker. If f is a function with the appropriate
// signature, RollbackFunc(f) is a Rollbacker that calls f.
RollbackFunc func(context.Context, *Tx) error
// RollbackHook defines the "rollback middleware". A function that gets a Rollbacker
// and returns a Rollbacker. For example:
//
// hook := func(next ent.Rollbacker) ent.Rollbacker {
// return ent.RollbackFunc(func(ctx context.Context, tx *ent.Tx) error {
// // Do some stuff before.
// if err := next.Rollback(ctx, tx); err != nil {
// return err
// }
// // Do some stuff after.
// return nil
// })
// }
//
RollbackHook func(Rollbacker) Rollbacker
)
// Rollback calls f(ctx, m).
func (f RollbackFunc) Rollback(ctx context.Context, tx *Tx) error {
return f(ctx, tx)
}
// Rollback rollbacks the transaction.
func (tx *Tx) Rollback() error {
txDriver := tx.config.driver.(*txDriver)
var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error {
return txDriver.tx.Rollback()
})
txDriver.mu.Lock()
hooks := append([]RollbackHook(nil), txDriver.onRollback...)
txDriver.mu.Unlock()
for i := len(hooks) - 1; i >= 0; i-- {
fn = hooks[i](fn)
}
return fn.Rollback(tx.ctx, tx)
}
// OnRollback adds a hook to call on rollback.
func (tx *Tx) OnRollback(f RollbackHook) {
txDriver := tx.config.driver.(*txDriver)
txDriver.mu.Lock()
txDriver.onRollback = append(txDriver.onRollback, f)
txDriver.mu.Unlock()
}
// Client returns a Client that binds to current transaction.
func (tx *Tx) Client() *Client {
tx.clientOnce.Do(func() {
tx.client = &Client{config: tx.config}
tx.client.init()
})
return tx.client
}
func (tx *Tx) init() {
tx.Tenant = NewTenantClient(tx.config)
tx.User = NewUserClient(tx.config)
}
// txDriver wraps the given dialect.Tx with a nop dialect.Driver implementation.
// The idea is to support transactions without adding any extra code to the builders.
// When a builder calls to driver.Tx(), it gets the same dialect.Tx instance.
// Commit and Rollback are nop for the internal builders and the user must call one
// of them in order to commit or rollback the transaction.
//
// If a closed transaction is embedded in one of the generated entities, and the entity
// applies a query, for example: Tenant.QueryXXX(), the query will be executed
// through the driver which created this transaction.
//
// Note that txDriver is not goroutine safe.
type txDriver struct {
// the driver we started the transaction from.
drv dialect.Driver
// tx is the underlying transaction.
tx dialect.Tx
// completion hooks.
mu sync.Mutex
onCommit []CommitHook
onRollback []RollbackHook
}
// newTx creates a new transactional driver.
func newTx(ctx context.Context, drv dialect.Driver) (*txDriver, error) {
tx, err := drv.Tx(ctx)
if err != nil {
return nil, err
}
return &txDriver{tx: tx, drv: drv}, nil
}
// Tx returns the transaction wrapper (txDriver) to avoid Commit or Rollback calls
// from the internal builders. Should be called only by the internal builders.
func (tx *txDriver) Tx(context.Context) (dialect.Tx, error) { return tx, nil }
// Dialect returns the dialect of the driver we started the transaction from.
func (tx *txDriver) Dialect() string { return tx.drv.Dialect() }
// Close is a nop close.
func (*txDriver) Close() error { return nil }
// Commit is a nop commit for the internal builders.
// User must call `Tx.Commit` in order to commit the transaction.
func (*txDriver) Commit() error { return nil }
// Rollback is a nop rollback for the internal builders.
// User must call `Tx.Rollback` in order to rollback the transaction.
func (*txDriver) Rollback() error { return nil }
// Exec calls tx.Exec.
func (tx *txDriver) Exec(ctx context.Context, query string, args, v any) error {
return tx.tx.Exec(ctx, query, args, v)
}
// Query calls tx.Query.
func (tx *txDriver) Query(ctx context.Context, query string, args, v any) error {
return tx.tx.Query(ctx, query, args, v)
}
var _ dialect.Driver = (*txDriver)(nil)

114
examples/rls/ent/user.go Normal file
View File

@@ -0,0 +1,114 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/examples/rls/ent/user"
)
// User is the model entity for the User schema.
type User struct {
config `json:"-"`
// ID of the ent.
ID int `json:"id,omitempty"`
// Name holds the value of the "name" field.
Name string `json:"name,omitempty"`
// TenantID holds the value of the "tenant_id" field.
TenantID int `json:"tenant_id,omitempty"`
selectValues sql.SelectValues
}
// scanValues returns the types for scanning values from sql.Rows.
func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case user.FieldID, user.FieldTenantID:
values[i] = new(sql.NullInt64)
case user.FieldName:
values[i] = new(sql.NullString)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the User fields.
func (u *User) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case user.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
u.ID = int(value.Int64)
case user.FieldName:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field name", values[i])
} else if value.Valid {
u.Name = value.String
}
case user.FieldTenantID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field tenant_id", values[i])
} else if value.Valid {
u.TenantID = int(value.Int64)
}
default:
u.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the User.
// This includes values selected through modifiers, order, etc.
func (u *User) Value(name string) (ent.Value, error) {
return u.selectValues.Get(name)
}
// Update returns a builder for updating this User.
// Note that you need to call User.Unwrap() before calling this method if this User
// was returned from a transaction, and the transaction was committed or rolled back.
func (u *User) Update() *UserUpdateOne {
return NewUserClient(u.config).UpdateOne(u)
}
// Unwrap unwraps the User entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (u *User) Unwrap() *User {
_tx, ok := u.config.driver.(*txDriver)
if !ok {
panic("ent: User is not a transactional entity")
}
u.config.driver = _tx.drv
return u
}
// String implements the fmt.Stringer.
func (u *User) String() string {
var builder strings.Builder
builder.WriteString("User(")
builder.WriteString(fmt.Sprintf("id=%v, ", u.ID))
builder.WriteString("name=")
builder.WriteString(u.Name)
builder.WriteString(", ")
builder.WriteString("tenant_id=")
builder.WriteString(fmt.Sprintf("%v", u.TenantID))
builder.WriteByte(')')
return builder.String()
}
// Users is a parsable slice of User.
type Users []*User

View File

@@ -0,0 +1,55 @@
// Code generated by ent, DO NOT EDIT.
package user
import (
"entgo.io/ent/dialect/sql"
)
const (
// Label holds the string label denoting the user type in the database.
Label = "user"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldName holds the string denoting the name field in the database.
FieldName = "name"
// FieldTenantID holds the string denoting the tenant_id field in the database.
FieldTenantID = "tenant_id"
// Table holds the table name of the user in the database.
Table = "users"
)
// Columns holds all SQL columns for user fields.
var Columns = []string{
FieldID,
FieldName,
FieldTenantID,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
// OrderOption defines the ordering options for the User queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByName orders the results by the name field.
func ByName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldName, opts...).ToFunc()
}
// ByTenantID orders the results by the tenant_id field.
func ByTenantID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTenantID, opts...).ToFunc()
}

View File

@@ -0,0 +1,183 @@
// Code generated by ent, DO NOT EDIT.
package user
import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/examples/rls/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int) predicate.User {
return predicate.User(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int) predicate.User {
return predicate.User(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int) predicate.User {
return predicate.User(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int) predicate.User {
return predicate.User(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int) predicate.User {
return predicate.User(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int) predicate.User {
return predicate.User(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int) predicate.User {
return predicate.User(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int) predicate.User {
return predicate.User(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int) predicate.User {
return predicate.User(sql.FieldLTE(FieldID, id))
}
// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
func Name(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldName, v))
}
// TenantID applies equality check predicate on the "tenant_id" field. It's identical to TenantIDEQ.
func TenantID(v int) predicate.User {
return predicate.User(sql.FieldEQ(FieldTenantID, v))
}
// NameEQ applies the EQ predicate on the "name" field.
func NameEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldName, v))
}
// NameNEQ applies the NEQ predicate on the "name" field.
func NameNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldName, v))
}
// NameIn applies the In predicate on the "name" field.
func NameIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldName, vs...))
}
// NameNotIn applies the NotIn predicate on the "name" field.
func NameNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldName, vs...))
}
// NameGT applies the GT predicate on the "name" field.
func NameGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldName, v))
}
// NameGTE applies the GTE predicate on the "name" field.
func NameGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldName, v))
}
// NameLT applies the LT predicate on the "name" field.
func NameLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldName, v))
}
// NameLTE applies the LTE predicate on the "name" field.
func NameLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldName, v))
}
// NameContains applies the Contains predicate on the "name" field.
func NameContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldName, v))
}
// NameHasPrefix applies the HasPrefix predicate on the "name" field.
func NameHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldName, v))
}
// NameHasSuffix applies the HasSuffix predicate on the "name" field.
func NameHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldName, v))
}
// NameEqualFold applies the EqualFold predicate on the "name" field.
func NameEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldName, v))
}
// NameContainsFold applies the ContainsFold predicate on the "name" field.
func NameContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldName, v))
}
// TenantIDEQ applies the EQ predicate on the "tenant_id" field.
func TenantIDEQ(v int) predicate.User {
return predicate.User(sql.FieldEQ(FieldTenantID, v))
}
// TenantIDNEQ applies the NEQ predicate on the "tenant_id" field.
func TenantIDNEQ(v int) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTenantID, v))
}
// TenantIDIn applies the In predicate on the "tenant_id" field.
func TenantIDIn(vs ...int) predicate.User {
return predicate.User(sql.FieldIn(FieldTenantID, vs...))
}
// TenantIDNotIn applies the NotIn predicate on the "tenant_id" field.
func TenantIDNotIn(vs ...int) predicate.User {
return predicate.User(sql.FieldNotIn(FieldTenantID, vs...))
}
// TenantIDGT applies the GT predicate on the "tenant_id" field.
func TenantIDGT(v int) predicate.User {
return predicate.User(sql.FieldGT(FieldTenantID, v))
}
// TenantIDGTE applies the GTE predicate on the "tenant_id" field.
func TenantIDGTE(v int) predicate.User {
return predicate.User(sql.FieldGTE(FieldTenantID, v))
}
// TenantIDLT applies the LT predicate on the "tenant_id" field.
func TenantIDLT(v int) predicate.User {
return predicate.User(sql.FieldLT(FieldTenantID, v))
}
// TenantIDLTE applies the LTE predicate on the "tenant_id" field.
func TenantIDLTE(v int) predicate.User {
return predicate.User(sql.FieldLTE(FieldTenantID, v))
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.User) predicate.User {
return predicate.User(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.User) predicate.User {
return predicate.User(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.User) predicate.User {
return predicate.User(sql.NotPredicates(p))
}

View File

@@ -0,0 +1,196 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/examples/rls/ent/user"
"entgo.io/ent/schema/field"
)
// UserCreate is the builder for creating a User entity.
type UserCreate struct {
config
mutation *UserMutation
hooks []Hook
}
// SetName sets the "name" field.
func (uc *UserCreate) SetName(s string) *UserCreate {
uc.mutation.SetName(s)
return uc
}
// SetTenantID sets the "tenant_id" field.
func (uc *UserCreate) SetTenantID(i int) *UserCreate {
uc.mutation.SetTenantID(i)
return uc
}
// Mutation returns the UserMutation object of the builder.
func (uc *UserCreate) Mutation() *UserMutation {
return uc.mutation
}
// Save creates the User in the database.
func (uc *UserCreate) Save(ctx context.Context) (*User, error) {
return withHooks(ctx, uc.sqlSave, uc.mutation, uc.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func (uc *UserCreate) SaveX(ctx context.Context) *User {
v, err := uc.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (uc *UserCreate) Exec(ctx context.Context) error {
_, err := uc.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (uc *UserCreate) ExecX(ctx context.Context) {
if err := uc.Exec(ctx); err != nil {
panic(err)
}
}
// check runs all checks and user-defined validators on the builder.
func (uc *UserCreate) check() error {
if _, ok := uc.mutation.Name(); !ok {
return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "User.name"`)}
}
if _, ok := uc.mutation.TenantID(); !ok {
return &ValidationError{Name: "tenant_id", err: errors.New(`ent: missing required field "User.tenant_id"`)}
}
return nil
}
func (uc *UserCreate) sqlSave(ctx context.Context) (*User, error) {
if err := uc.check(); err != nil {
return nil, err
}
_node, _spec := uc.createSpec()
if err := sqlgraph.CreateNode(ctx, uc.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
id := _spec.ID.Value.(int64)
_node.ID = int(id)
uc.mutation.id = &_node.ID
uc.mutation.done = true
return _node, nil
}
func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
var (
_node = &User{config: uc.config}
_spec = sqlgraph.NewCreateSpec(user.Table, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt))
)
if value, ok := uc.mutation.Name(); ok {
_spec.SetField(user.FieldName, field.TypeString, value)
_node.Name = value
}
if value, ok := uc.mutation.TenantID(); ok {
_spec.SetField(user.FieldTenantID, field.TypeInt, value)
_node.TenantID = value
}
return _node, _spec
}
// UserCreateBulk is the builder for creating many User entities in bulk.
type UserCreateBulk struct {
config
err error
builders []*UserCreate
}
// Save creates the User entities in the database.
func (ucb *UserCreateBulk) Save(ctx context.Context) ([]*User, error) {
if ucb.err != nil {
return nil, ucb.err
}
specs := make([]*sqlgraph.CreateSpec, len(ucb.builders))
nodes := make([]*User, len(ucb.builders))
mutators := make([]Mutator, len(ucb.builders))
for i := range ucb.builders {
func(i int, root context.Context) {
builder := ucb.builders[i]
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*UserMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err := builder.check(); err != nil {
return nil, err
}
builder.mutation = mutation
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, ucb.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, ucb.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
}
}
if err != nil {
return nil, err
}
mutation.id = &nodes[i].ID
if specs[i].ID.Value != nil {
id := specs[i].ID.Value.(int64)
nodes[i].ID = int(id)
}
mutation.done = true
return nodes[i], nil
})
for i := len(builder.hooks) - 1; i >= 0; i-- {
mut = builder.hooks[i](mut)
}
mutators[i] = mut
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, ucb.builders[0].mutation); err != nil {
return nil, err
}
}
return nodes, nil
}
// SaveX is like Save, but panics if an error occurs.
func (ucb *UserCreateBulk) SaveX(ctx context.Context) []*User {
v, err := ucb.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (ucb *UserCreateBulk) Exec(ctx context.Context) error {
_, err := ucb.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (ucb *UserCreateBulk) ExecX(ctx context.Context) {
if err := ucb.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/examples/rls/ent/predicate"
"entgo.io/ent/examples/rls/ent/user"
"entgo.io/ent/schema/field"
)
// UserDelete is the builder for deleting a User entity.
type UserDelete struct {
config
hooks []Hook
mutation *UserMutation
}
// Where appends a list predicates to the UserDelete builder.
func (ud *UserDelete) Where(ps ...predicate.User) *UserDelete {
ud.mutation.Where(ps...)
return ud
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (ud *UserDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, ud.sqlExec, ud.mutation, ud.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (ud *UserDelete) ExecX(ctx context.Context) int {
n, err := ud.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (ud *UserDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(user.Table, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt))
if ps := ud.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, ud.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
ud.mutation.done = true
return affected, err
}
// UserDeleteOne is the builder for deleting a single User entity.
type UserDeleteOne struct {
ud *UserDelete
}
// Where appends a list predicates to the UserDelete builder.
func (udo *UserDeleteOne) Where(ps ...predicate.User) *UserDeleteOne {
udo.ud.mutation.Where(ps...)
return udo
}
// Exec executes the deletion query.
func (udo *UserDeleteOne) Exec(ctx context.Context) error {
n, err := udo.ud.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{user.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (udo *UserDeleteOne) ExecX(ctx context.Context) {
if err := udo.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,527 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/examples/rls/ent/predicate"
"entgo.io/ent/examples/rls/ent/user"
"entgo.io/ent/schema/field"
)
// UserQuery is the builder for querying User entities.
type UserQuery struct {
config
ctx *QueryContext
order []user.OrderOption
inters []Interceptor
predicates []predicate.User
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the UserQuery builder.
func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery {
uq.predicates = append(uq.predicates, ps...)
return uq
}
// Limit the number of records to be returned by this query.
func (uq *UserQuery) Limit(limit int) *UserQuery {
uq.ctx.Limit = &limit
return uq
}
// Offset to start from.
func (uq *UserQuery) Offset(offset int) *UserQuery {
uq.ctx.Offset = &offset
return uq
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (uq *UserQuery) Unique(unique bool) *UserQuery {
uq.ctx.Unique = &unique
return uq
}
// Order specifies how the records should be ordered.
func (uq *UserQuery) Order(o ...user.OrderOption) *UserQuery {
uq.order = append(uq.order, o...)
return uq
}
// First returns the first User entity from the query.
// Returns a *NotFoundError when no User was found.
func (uq *UserQuery) First(ctx context.Context) (*User, error) {
nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{user.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (uq *UserQuery) FirstX(ctx context.Context) *User {
node, err := uq.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first User ID from the query.
// Returns a *NotFoundError when no User ID was found.
func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{user.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (uq *UserQuery) FirstIDX(ctx context.Context) int {
id, err := uq.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single User entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one User entity is found.
// Returns a *NotFoundError when no User entities are found.
func (uq *UserQuery) Only(ctx context.Context) (*User, error) {
nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{user.Label}
default:
return nil, &NotSingularError{user.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (uq *UserQuery) OnlyX(ctx context.Context) *User {
node, err := uq.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only User ID in the query.
// Returns a *NotSingularError when more than one User ID is found.
// Returns a *NotFoundError when no entities are found.
func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{user.Label}
default:
err = &NotSingularError{user.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (uq *UserQuery) OnlyIDX(ctx context.Context) int {
id, err := uq.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of Users.
func (uq *UserQuery) All(ctx context.Context) ([]*User, error) {
ctx = setContextOp(ctx, uq.ctx, ent.OpQueryAll)
if err := uq.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*User, *UserQuery]()
return withInterceptors[[]*User](ctx, uq, qr, uq.inters)
}
// AllX is like All, but panics if an error occurs.
func (uq *UserQuery) AllX(ctx context.Context) []*User {
nodes, err := uq.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of User IDs.
func (uq *UserQuery) IDs(ctx context.Context) (ids []int, err error) {
if uq.ctx.Unique == nil && uq.path != nil {
uq.Unique(true)
}
ctx = setContextOp(ctx, uq.ctx, ent.OpQueryIDs)
if err = uq.Select(user.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (uq *UserQuery) IDsX(ctx context.Context) []int {
ids, err := uq.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (uq *UserQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, uq.ctx, ent.OpQueryCount)
if err := uq.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, uq, querierCount[*UserQuery](), uq.inters)
}
// CountX is like Count, but panics if an error occurs.
func (uq *UserQuery) CountX(ctx context.Context) int {
count, err := uq.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (uq *UserQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, uq.ctx, ent.OpQueryExist)
switch _, err := uq.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (uq *UserQuery) ExistX(ctx context.Context) bool {
exist, err := uq.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the UserQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (uq *UserQuery) Clone() *UserQuery {
if uq == nil {
return nil
}
return &UserQuery{
config: uq.config,
ctx: uq.ctx.Clone(),
order: append([]user.OrderOption{}, uq.order...),
inters: append([]Interceptor{}, uq.inters...),
predicates: append([]predicate.User{}, uq.predicates...),
// clone intermediate query.
sql: uq.sql.Clone(),
path: uq.path,
}
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// Name string `json:"name,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.User.Query().
// GroupBy(user.FieldName).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy {
uq.ctx.Fields = append([]string{field}, fields...)
grbuild := &UserGroupBy{build: uq}
grbuild.flds = &uq.ctx.Fields
grbuild.label = user.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// Name string `json:"name,omitempty"`
// }
//
// client.User.Query().
// Select(user.FieldName).
// Scan(ctx, &v)
func (uq *UserQuery) Select(fields ...string) *UserSelect {
uq.ctx.Fields = append(uq.ctx.Fields, fields...)
sbuild := &UserSelect{UserQuery: uq}
sbuild.label = user.Label
sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a UserSelect configured with the given aggregations.
func (uq *UserQuery) Aggregate(fns ...AggregateFunc) *UserSelect {
return uq.Select().Aggregate(fns...)
}
func (uq *UserQuery) prepareQuery(ctx context.Context) error {
for _, inter := range uq.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, uq); err != nil {
return err
}
}
}
for _, f := range uq.ctx.Fields {
if !user.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if uq.path != nil {
prev, err := uq.path(ctx)
if err != nil {
return err
}
uq.sql = prev
}
return nil
}
func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) {
var (
nodes = []*User{}
_spec = uq.querySpec()
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*User).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &User{config: uq.config}
nodes = append(nodes, node)
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
return nodes, nil
}
func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) {
_spec := uq.querySpec()
_spec.Node.Columns = uq.ctx.Fields
if len(uq.ctx.Fields) > 0 {
_spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique
}
return sqlgraph.CountNodes(ctx, uq.driver, _spec)
}
func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(user.Table, user.Columns, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt))
_spec.From = uq.sql
if unique := uq.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if uq.path != nil {
_spec.Unique = true
}
if fields := uq.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, user.FieldID)
for i := range fields {
if fields[i] != user.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := uq.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := uq.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := uq.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := uq.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(uq.driver.Dialect())
t1 := builder.Table(user.Table)
columns := uq.ctx.Fields
if len(columns) == 0 {
columns = user.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if uq.sql != nil {
selector = uq.sql
selector.Select(selector.Columns(columns...)...)
}
if uq.ctx.Unique != nil && *uq.ctx.Unique {
selector.Distinct()
}
for _, p := range uq.predicates {
p(selector)
}
for _, p := range uq.order {
p(selector)
}
if offset := uq.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := uq.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// UserGroupBy is the group-by builder for User entities.
type UserGroupBy struct {
selector
build *UserQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy {
ugb.fns = append(ugb.fns, fns...)
return ugb
}
// Scan applies the selector query and scans the result into the given value.
func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, ugb.build.ctx, ent.OpQueryGroupBy)
if err := ugb.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserQuery, *UserGroupBy](ctx, ugb.build, ugb, ugb.build.inters, v)
}
func (ugb *UserGroupBy) sqlScan(ctx context.Context, root *UserQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(ugb.fns))
for _, fn := range ugb.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*ugb.flds)+len(ugb.fns))
for _, f := range *ugb.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*ugb.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := ugb.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// UserSelect is the builder for selecting fields of User entities.
type UserSelect struct {
*UserQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect {
us.fns = append(us.fns, fns...)
return us
}
// Scan applies the selector query and scans the result into the given value.
func (us *UserSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, us.ctx, ent.OpQuerySelect)
if err := us.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserQuery, *UserSelect](ctx, us.UserQuery, us, us.inters, v)
}
func (us *UserSelect) sqlScan(ctx context.Context, root *UserQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(us.fns))
for _, fn := range us.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*us.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := us.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,263 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/examples/rls/ent/predicate"
"entgo.io/ent/examples/rls/ent/user"
"entgo.io/ent/schema/field"
)
// UserUpdate is the builder for updating User entities.
type UserUpdate struct {
config
hooks []Hook
mutation *UserMutation
}
// Where appends a list predicates to the UserUpdate builder.
func (uu *UserUpdate) Where(ps ...predicate.User) *UserUpdate {
uu.mutation.Where(ps...)
return uu
}
// SetName sets the "name" field.
func (uu *UserUpdate) SetName(s string) *UserUpdate {
uu.mutation.SetName(s)
return uu
}
// SetNillableName sets the "name" field if the given value is not nil.
func (uu *UserUpdate) SetNillableName(s *string) *UserUpdate {
if s != nil {
uu.SetName(*s)
}
return uu
}
// SetTenantID sets the "tenant_id" field.
func (uu *UserUpdate) SetTenantID(i int) *UserUpdate {
uu.mutation.ResetTenantID()
uu.mutation.SetTenantID(i)
return uu
}
// SetNillableTenantID sets the "tenant_id" field if the given value is not nil.
func (uu *UserUpdate) SetNillableTenantID(i *int) *UserUpdate {
if i != nil {
uu.SetTenantID(*i)
}
return uu
}
// AddTenantID adds i to the "tenant_id" field.
func (uu *UserUpdate) AddTenantID(i int) *UserUpdate {
uu.mutation.AddTenantID(i)
return uu
}
// Mutation returns the UserMutation object of the builder.
func (uu *UserUpdate) Mutation() *UserMutation {
return uu.mutation
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (uu *UserUpdate) Save(ctx context.Context) (int, error) {
return withHooks(ctx, uu.sqlSave, uu.mutation, uu.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (uu *UserUpdate) SaveX(ctx context.Context) int {
affected, err := uu.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (uu *UserUpdate) Exec(ctx context.Context) error {
_, err := uu.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (uu *UserUpdate) ExecX(ctx context.Context) {
if err := uu.Exec(ctx); err != nil {
panic(err)
}
}
func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
_spec := sqlgraph.NewUpdateSpec(user.Table, user.Columns, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt))
if ps := uu.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := uu.mutation.Name(); ok {
_spec.SetField(user.FieldName, field.TypeString, value)
}
if value, ok := uu.mutation.TenantID(); ok {
_spec.SetField(user.FieldTenantID, field.TypeInt, value)
}
if value, ok := uu.mutation.AddedTenantID(); ok {
_spec.AddField(user.FieldTenantID, field.TypeInt, value)
}
if n, err = sqlgraph.UpdateNodes(ctx, uu.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
uu.mutation.done = true
return n, nil
}
// UserUpdateOne is the builder for updating a single User entity.
type UserUpdateOne struct {
config
fields []string
hooks []Hook
mutation *UserMutation
}
// SetName sets the "name" field.
func (uuo *UserUpdateOne) SetName(s string) *UserUpdateOne {
uuo.mutation.SetName(s)
return uuo
}
// SetNillableName sets the "name" field if the given value is not nil.
func (uuo *UserUpdateOne) SetNillableName(s *string) *UserUpdateOne {
if s != nil {
uuo.SetName(*s)
}
return uuo
}
// SetTenantID sets the "tenant_id" field.
func (uuo *UserUpdateOne) SetTenantID(i int) *UserUpdateOne {
uuo.mutation.ResetTenantID()
uuo.mutation.SetTenantID(i)
return uuo
}
// SetNillableTenantID sets the "tenant_id" field if the given value is not nil.
func (uuo *UserUpdateOne) SetNillableTenantID(i *int) *UserUpdateOne {
if i != nil {
uuo.SetTenantID(*i)
}
return uuo
}
// AddTenantID adds i to the "tenant_id" field.
func (uuo *UserUpdateOne) AddTenantID(i int) *UserUpdateOne {
uuo.mutation.AddTenantID(i)
return uuo
}
// Mutation returns the UserMutation object of the builder.
func (uuo *UserUpdateOne) Mutation() *UserMutation {
return uuo.mutation
}
// Where appends a list predicates to the UserUpdate builder.
func (uuo *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
uuo.mutation.Where(ps...)
return uuo
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (uuo *UserUpdateOne) Select(field string, fields ...string) *UserUpdateOne {
uuo.fields = append([]string{field}, fields...)
return uuo
}
// Save executes the query and returns the updated User entity.
func (uuo *UserUpdateOne) Save(ctx context.Context) (*User, error) {
return withHooks(ctx, uuo.sqlSave, uuo.mutation, uuo.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (uuo *UserUpdateOne) SaveX(ctx context.Context) *User {
node, err := uuo.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (uuo *UserUpdateOne) Exec(ctx context.Context) error {
_, err := uuo.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (uuo *UserUpdateOne) ExecX(ctx context.Context) {
if err := uuo.Exec(ctx); err != nil {
panic(err)
}
}
func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
_spec := sqlgraph.NewUpdateSpec(user.Table, user.Columns, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt))
id, ok := uuo.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "User.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := uuo.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, user.FieldID)
for _, f := range fields {
if !user.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != user.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := uuo.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := uuo.mutation.Name(); ok {
_spec.SetField(user.FieldName, field.TypeString, value)
}
if value, ok := uuo.mutation.TenantID(); ok {
_spec.SetField(user.FieldTenantID, field.TypeInt, value)
}
if value, ok := uuo.mutation.AddedTenantID(); ok {
_spec.AddField(user.FieldTenantID, field.TypeInt, value)
}
_node = &User{config: uuo.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, uuo.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
uuo.mutation.done = true
return _node, nil
}

View File

@@ -0,0 +1,62 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package main
import (
"context"
"log"
"os"
"testing"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect"
"entgo.io/ent/examples/rls/ent"
"ariga.io/atlas-go-sdk/atlasexec"
_ "github.com/lib/pq"
"github.com/stretchr/testify/require"
)
func TestRowLevelSecurity(t *testing.T) {
if os.Getenv("CI") != "" {
t.Skip()
}
ctx := context.Background()
// Note that APP_URL is used for the ent client, and ATLAS_URL is used for the atlas client.
// Two different roles. The app role has access to the specific tables, and the atlas role in
// this example is the default superuser role.
client, err := ent.Open(dialect.Postgres, os.Getenv("APP_URL"))
if err != nil {
log.Fatalln(err)
}
ac, err := atlasexec.NewClient(".", "atlas")
if err != nil {
log.Fatalf("failed to initialize client: %v", err)
}
// Automatically update the database with the desired schema.
// Another option, is to use 'migrate apply' or 'schema apply' manually.
_, err = ac.SchemaApply(ctx, &atlasexec.SchemaApplyParams{
// URL to your database. For example:
// postgres://postgres:pass@localhost:5432/database?search_path=public&sslmode=disable
URL: os.Getenv("ATLAS_URL"),
Env: "local",
})
require.NoError(t, err)
t.Cleanup(func() {
client.User.Delete().ExecX(ctx)
client.Tenant.Delete().ExecX(ctx)
})
a8m, r3m := client.Tenant.Create().SetName("a8m").SaveX(ctx), client.Tenant.Create().SetName("r3m").SaveX(ctx)
ctx1, ctx2 := sql.WithIntVar(ctx, "app.current_tenant", a8m.ID), sql.WithIntVar(ctx, "app.current_tenant", r3m.ID)
u1 := client.User.Create().SetName("User: a8m").SetTenantID(a8m.ID).SaveX(ctx1)
u2 := client.User.Create().SetName("User: r3m").SetTenantID(r3m.ID).SaveX(ctx2)
users1 := client.User.Query().AllX(ctx1)
require.Len(t, users1, 1)
require.Equal(t, u1.ID, users1[0].ID)
users2 := client.User.Query().AllX(ctx2)
require.Len(t, users2, 1)
require.Equal(t, u2.ID, users2[0].ID)
}

View File

@@ -0,0 +1,8 @@
-- Create "users" table
CREATE TABLE "users" ("id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY, "name" character varying NOT NULL, "tenant_id" bigint NOT NULL, PRIMARY KEY ("id"));
-- Enable row-level security for "users" table
ALTER TABLE "users" ENABLE ROW LEVEL SECURITY;
-- Create policy "tenant_isolation"
CREATE POLICY "tenant_isolation" ON "users" AS PERMISSIVE FOR ALL TO PUBLIC USING (tenant_id = (current_setting('app.current_tenant'::text))::integer);
-- Create "tenants" table
CREATE TABLE "tenants" ("id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY, "name" character varying NOT NULL, PRIMARY KEY ("id"));

View File

@@ -0,0 +1,2 @@
h1:UgEhzp8eL2yI01uMPYx3rblLHLfUMm8/ZMhIjom3yp0=
20240714101101.sql h1:hKQHX9Nr7MaGr39gRl6UmyIx4nfe2GGg2csbk+3VRuo=

6
examples/rls/schema.sql Normal file
View File

@@ -0,0 +1,6 @@
-- Enable row-level security on the users table.
ALTER TABLE "users" ENABLE ROW LEVEL SECURITY;
-- Create a policy that restricts access to rows in the users table based on the current tenant.
CREATE POLICY tenant_isolation ON "users"
USING ("tenant_id" = current_setting('app.current_tenant')::integer);