Files
ent/entc/gen/template/client.tmpl

521 lines
16 KiB
Cheetah

{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}
{{/* gotype: entgo.io/ent/entc/gen.Graph */}}
{{ define "client/init" }}
// Client is the client that holds all ent builders.
type Client struct {
config
{{- /* The Schema field is generated only when $.SupportMigrate is set. */}}
{{- if $.SupportMigrate }}
// Schema is the client for creating, migrating and dropping schema.
Schema *migrate.Schema
{{- end }}
{{- /* One exported field per graph node, typed with that node's client. */}}
{{- range $n := $.Nodes }}
// {{ $n.Name }} is the client for interacting with the {{ $n.Name }} builders.
{{ $n.Name }} *{{ $n.ClientName }}
{{- end }}
{{- /* Extension points for extra struct fields: first the overridable "client/fields/additional" template, then every template registered under the "client/fields/additional/" prefix. */}}
{{- template "client/fields/additional" $ }}
{{- with $tmpls := matchTemplate "client/fields/additional/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
}
// NewClient builds a Client whose config is produced by applying opts,
// runs init on it and returns it.
func NewClient(opts ...Option) *Client {
	c := &Client{config: newConfig(opts...)}
	c.init()
	return c
}
// init assigns the Schema client (when it is generated) from the client's
// driver, and one entity client per node built from the client's config.
func (c *Client) init() {
{{- if $.SupportMigrate }}
c.Schema = migrate.NewSchema(c.driver)
{{- end }}
{{- range $n := $.Nodes }}
c.{{ $n.Name }} = New{{ $n.ClientName }}(c.config)
{{- end }}
}
{{ end }}
{{ define "client" }}
{{/* $pkg is the result of "base" applied to the configured package path; it prefixes the error messages generated below. */}}
{{ $pkg := base $.Config.Package }}
{{ template "header" $ }}
{{/* Additional dependencies injected to config. $deps starts as an empty list and is replaced by the annotation-provided dependencies when $.Config.Annotations is set. */}}
{{ $deps := list }}{{ with $.Config.Annotations }}{{ $deps = $.Config.Annotations.Dependencies }}{{ end }}
import (
"context"
"log"
"{{ $.Config.Package }}/migrate"
{{ range $n := $.Nodes }}
{{ $n.PackageAlias }} "{{ $n.Config.Package }}/{{ $n.PackageDir }}"
{{- end }}
{{- range $dep := $deps }}
{{ $dep.Type.PkgName }} "{{ $dep.Type.PkgPath }}"
{{- end }}
"entgo.io/ent/dialect"
{{ range $import := $.Storage.Imports -}}
"{{ $import }}"
{{ end -}}
{{- template "import/additional" $ }}
)
{{/* Emit the output of the "client/init" template before the config types. */}}
{{ template "client/init" $ }}
type (
// config is the configuration for the client and its builder.
config struct {
// driver used for executing database requests.
driver dialect.Driver
// debug enables debug logging.
debug bool
// log used for logging on debug mode.
log func(...any)
// hooks to execute on mutations.
hooks *hooks
// interceptors to execute on queries.
inters *inters
{{- /* Additional dependency fields. */}}
{{- range $dep := $deps }}
{{ $dep.Field }} {{ $dep.Type }}
{{- end }}
{{- /* Support adding config fields from both global or dialect-specific templates. */}}
{{- range $prefix := list "" (printf "dialect/%s/" $.Storage) }}
{{- with $tmpls := matchTemplate (print $prefix "config/fields/*") }}
{{- range $tmpl := $tmpls }}
{{ xtemplate $tmpl $ }}
{{- end }}
{{- end }}
{{- end }}
}
// Option function to configure the client.
Option func(*config)
)
// newConfig creates a new config for the client. It starts from the defaults
// (log.Println as logger, empty hooks and interceptors) and then applies opts.
func newConfig(opts ...Option) config {
cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}}
{{- /* Templates registered under the "config/init/fields/" prefix are rendered here, after cfg is created and before the options are applied. */}}
{{- with $tmpls := matchTemplate "config/init/fields/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
cfg.options(opts...)
return cfg
}
// options runs every given Option against the config. If debug mode is
// enabled afterwards, the driver is wrapped with dialect.Debug using the
// configured log function.
func (c *config) options(opts ...Option) {
	for i := range opts {
		opts[i](c)
	}
	if !c.debug {
		return
	}
	c.driver = dialect.Debug(c.driver, c.log)
}
// Debug returns an Option that switches on debug logging on the ent.Driver.
func Debug() Option {
	return func(cfg *config) {
		cfg.debug = true
	}
}
// Log returns an Option that installs fn as the logging function used in debug mode.
func Log(fn func(...any)) Option {
	return func(cfg *config) {
		cfg.log = fn
	}
}
// Driver returns an Option that sets the driver used by the client.
func Driver(driver dialect.Driver) Option {
	return func(cfg *config) {
		cfg.driver = driver
	}
}
{{- /* Additional dependency options reside on the config object: for every entry in $deps, generate an Option constructor that stores its argument in the matching config field. */}}
{{- range $dep := $deps }}
// {{ $dep.Option }} configures the {{ $dep.Field }}.
func {{ $dep.Option }}(v {{ $dep.Type }}) Option {
return func(c *config) {
c.{{ $dep.Field }} = v
}
}
{{- end }}
// Open opens a database/sql.DB specified by the driver name and
// the data source name, and returns a new client attached to it.
// Optional parameters can be added for configuring the client.
// A driver name that is not one of the storage dialects yields an error.
func Open(driverName, dataSourceName string, options ...Option) (*Client, error) {
switch driverName {
case {{ join $.Storage.Dialects ", " }}:
{{- /* The body of this case is rendered by the storage-specific "client/open" template. */ -}}
{{- $tmpl := printf "dialect/%s/client/open" $.Storage -}}
{{- xtemplate $tmpl . -}}
default:
return nil, fmt.Errorf("unsupported driver: %q", driverName)
}
}
// ErrTxStarted is the error reported by Tx when the client is already bound to a transaction.
var ErrTxStarted = errors.New("{{ $pkg }}: cannot start a transaction within a transaction")

// Tx starts a transaction on the client's driver and returns a transactional
// client bound to it. The given context is stored on the returned Tx.
// It fails with ErrTxStarted when the client's driver is already a transaction driver.
func (c *Client) Tx(ctx context.Context) (*Tx, error) {
	if _, isTx := c.driver.(*txDriver); isTx {
		return nil, ErrTxStarted
	}
	txDrv, err := newTx(ctx, c.driver)
	if err != nil {
		return nil, fmt.Errorf("{{ $pkg }}: starting a transaction: %w", err)
	}
	// Copy the config so the original client keeps its own driver.
	txCfg := c.config
	txCfg.driver = txDrv
	return &Tx{
		ctx:    ctx,
		config: txCfg,
		{{- range $n := $.Nodes }}
		{{ $n.Name }}: New{{ $n.ClientName }}(txCfg),
		{{- end }}
	}, nil
}
{{- /* If the storage driver supports TxOptions (like SQL), i.e. a "txoptions" template exists for the current storage, render it here. $tmpl is reused from its earlier declaration in this template. */}}
{{- $tmpl = printf "dialect/%s/txoptions" $.Storage }}
{{- if hasTemplate $tmpl }}
{{- xtemplate $tmpl . }}
{{- end }}
// Debug returns a new debug-client. It's used to get verbose logging on specific operations.
{{- /* The usage example names the first node of the graph; it is skipped for a graph without nodes, because indexing an empty slice would abort template execution. */}}
{{- with $.Nodes }}
//
//	client.Debug().
//		{{ (index . 0).Name }}.
//		Query().
//		Count(ctx)
{{- end }}
func (c *Client) Debug() *Client {
	// Already in debug mode: the driver is wrapped, reuse the client as is.
	if c.debug {
		return c
	}
	// Work on a copy of the config so the receiver keeps its plain driver.
	cfg := c.config
	cfg.driver = dialect.Debug(c.driver, c.log)
	client := &Client{config: cfg}
	client.init()
	return client
}
// Close calls Close on the client's driver and returns the error it reports, if any.
func (c *Client) Close() error {
	drv := c.driver
	return drv.Close()
}
// Use adds the mutation hooks to all the entity clients.
// In order to add hooks to a specific client, call: `client.Node.Use(...)`.
func (c *Client) Use(hooks ...Hook) {
{{- /* Fewer than 6 nodes in the graph: emit one explicit Use call per mutable node. */}}
{{- if lt (len $.Nodes) 6 }}
{{- range $n := $.MutableNodes }}
c.{{ $n.Name }}.Use(hooks...)
{{- end }}
{{- /* Use a for-loop for setting hooks in case there are too many nodes. */}}
{{- else }}
{{- /* $hooks collects one "c.<Node>," entry per mutable node; joinWords renders them inside the slice literal (80 is presumably the line width, confirm against the helper). */}}
{{- $hooks := slist }}
{{- range $n := $.MutableNodes }}{{ $hooks = append $hooks (printf "c.%s," $n.Name) }}{{ end }}
for _, n := range []interface{ Use(...Hook) }{
{{ joinWords $hooks 80 }}
}{
n.Use(hooks...)
}
{{- end }}
}
// Intercept adds the query interceptors to all the entity clients.
// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`.
func (c *Client) Intercept(interceptors ...Interceptor) {
{{- /* Fewer than 6 nodes in the graph: emit one explicit Intercept call per node. */}}
{{- if lt (len $.Nodes) 6 }}
{{- range $n := $.Nodes }}
c.{{ $n.Name }}.Intercept(interceptors...)
{{- end }}
{{- /* Use a for-loop for setting interceptors in case there are too many nodes. */}}
{{- else }}
{{- /* $inters collects one "c.<Node>," entry per node; joinWords renders them inside the slice literal. */}}
{{- $inters := slist }}
{{- range $n := $.Nodes }}{{ $inters = append $inters (printf "c.%s," $n.Name) }}{{ end }}
for _, n := range []interface{ Intercept(...Interceptor) }{
{{ joinWords $inters 80 }}
}{
n.Intercept(interceptors...)
}
{{- end }}
}
{{- /* Extension point: render every template registered under the "client/additional/" prefix (one or two path levels deep), passing the graph as data. */}}
{{- with $tmpls := matchTemplate "client/additional/*" "client/additional/*/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
// Mutate implements the ent.Mutator interface by forwarding the mutation to
// the entity client that matches its concrete type. An unrecognized mutation
// type results in an error.
func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
	switch mut := m.(type) {
	{{- range $n := $.MutableNodes }}
	case *{{ $n.MutationName }}:
		return c.{{ $n.Name }}.mutate(ctx, mut)
	{{- end }}
	default:
		return nil, fmt.Errorf("{{ $pkg }}: unknown mutation type %T", mut)
	}
}
{{ range $n := $.Nodes }}
{{/* $client is the name of the generated client type for the current node. */}}
{{ $client := $n.ClientName }}
// {{ $client }} is a client for the {{ $n.Name }} schema.
type {{ $client }} struct {
config
}
{{/* $rec names entity parameters in the methods below. The client methods use "c" as their receiver, so a node whose receiver is "c" gets the first two characters of its name (passed through "lower") instead. */}}
{{ $rec := $n.Receiver }}{{ if eq $rec "c" }}{{ $rec = printf "%.2s" $n.Name | lower }}{{ end }}
// New{{ $client }} returns a client for the {{ $n.Name }} from the given config.
func New{{ $client }}(c config) *{{ $client }} {
return &{{ $client }}{config: c}
}
{{- /* Use is generated only for nodes that are not views. */}}
{{- if not $n.IsView }}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` is equivalent to `{{ $n.Package }}.Hooks(f(g(h())))`.
func (c *{{ $client }}) Use(hooks ...Hook) {
c.hooks.{{ $n.Name }} = append(c.hooks.{{ $n.Name }}, hooks...)
}
{{- end }}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` is equivalent to `{{ $n.Package }}.Intercept(f(g(h())))`.
func (c *{{ $client }}) Intercept(interceptors ...Interceptor) {
c.inters.{{ $n.Name }} = append(c.inters.{{ $n.Name }}, interceptors...)
}
{{- if not $n.IsView }}
// Create returns a builder for creating a {{ $n.Name }} entity. The builder
// carries the client's config and hooks and a fresh OpCreate mutation.
func (c *{{ $client }}) Create() *{{ $n.CreateName }} {
	m := new{{ $n.MutationName }}(c.config, OpCreate)
	return &{{ $n.CreateName }}{
		config:   c.config,
		hooks:    c.Hooks(),
		mutation: m,
	}
}
// CreateBulk wraps the given create builders in a bulk builder for {{ $n.Name }} entities.
func (c *{{ $client }}) CreateBulk(builders ...*{{ $n.CreateName }}) *{{ $n.CreateBulkName }} {
	return &{{ $n.CreateBulkName }}{
		config:   c.config,
		builders: builders,
	}
}
// MapCreateBulk returns a bulk creation builder with one create builder per
// element of slice. setFunc is invoked with each new builder and the index of
// the element it corresponds to. If slice is not a slice value, the returned
// builder carries an error instead.
func (c *{{ $client }}) MapCreateBulk(slice any, setFunc func(*{{ $n.CreateName }}, int)) *{{ $n.CreateBulkName }} {
	{{- /* This check should be replaced with type-parameter method, if it will be supported by Go. */}}
	items := reflect.ValueOf(slice)
	if items.Kind() != reflect.Slice {
		return &{{ $n.CreateBulkName }}{err: fmt.Errorf("calling to {{ $client }}.MapCreateBulk with wrong type %T, need slice", slice)}
	}
	size := items.Len()
	builders := make([]*{{ $n.CreateName }}, size)
	for i := range builders {
		b := c.Create()
		builders[i] = b
		setFunc(b, i)
	}
	return &{{ $n.CreateBulkName }}{
		config:   c.config,
		builders: builders,
	}
}
// Update returns an update builder for {{ $n.Name }}, carrying the client's
// config and hooks and a fresh OpUpdate mutation.
func (c *{{ $client }}) Update() *{{ $n.UpdateName }} {
	m := new{{ $n.MutationName }}(c.config, OpUpdate)
	return &{{ $n.UpdateName }}{
		config:   c.config,
		hooks:    c.Hooks(),
		mutation: m,
	}
}
// UpdateOne returns an update builder for the given entity.
func (c *{{ $client }}) UpdateOne({{ $rec }} *{{ $n.Name }}) *{{ $n.UpdateOneName }} {
{{- /* With a single-field ID the mutation is seeded through the with<Node> option; otherwise each ID field of the edge schema is copied from the entity onto the mutation. */}}
{{- if $n.HasOneFieldID }}
mutation := new{{ $n.MutationName }}(c.config, OpUpdateOne, {{ print "with" $n.Name }}({{ $rec }}))
{{- else }}
mutation := new{{ $n.MutationName }}(c.config, OpUpdateOne)
{{- range $id := $n.EdgeSchema.ID }}
mutation.{{ $id.MutationSet }}({{ $rec }}.{{ $id.StructField }})
{{- end }}
{{- end }}
return &{{ $n.UpdateOneName }}{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
{{ with $n.HasOneFieldID }}
// UpdateOneID returns an update builder for the entity with the given id.
// It is generated only for nodes with a single-field ID.
func (c *{{ $client }}) UpdateOneID(id {{ $n.ID.Type }}) *{{ $n.UpdateOneName }} {
	m := new{{ $n.MutationName }}(c.config, OpUpdateOne, {{ print "with" $n.Name "ID" }}(id))
	return &{{ $n.UpdateOneName }}{
		config:   c.config,
		hooks:    c.Hooks(),
		mutation: m,
	}
}
{{ end }}
// Delete returns a delete builder for {{ $n.Name }}, carrying the client's
// config and hooks and a fresh OpDelete mutation.
func (c *{{ $client }}) Delete() *{{ $n.DeleteName }} {
	m := new{{ $n.MutationName }}(c.config, OpDelete)
	return &{{ $n.DeleteName }}{
		config:   c.config,
		hooks:    c.Hooks(),
		mutation: m,
	}
}
{{ with $n.HasOneFieldID }}
// DeleteOne returns a builder for deleting the given entity.
func (c *{{ $client }}) DeleteOne({{ $rec }} *{{ $n.Name }}) *{{ $n.DeleteOneName }} {
return c.DeleteOneID({{ $rec }}.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *{{ $client }}) DeleteOneID(id {{ $n.ID.Type }}) *{{ $n.DeleteOneName }} {
{{- /* $builder is the name of the local variable below; it switches to "builderC" when the entity package is itself named "builder" (presumably to keep the local from clashing with the package identifier). */}}
{{- $builder := "builder" }}{{ if eq $n.Package $builder }}{{ $builder = "builderC" }}{{ end }}
{{ $builder }} := c.Delete().Where({{ $n.Package }}.ID(id))
{{ $builder }}.mutation.{{ $n.ID.BuilderField }} = &id
{{ $builder }}.mutation.SetOp(OpDeleteOne)
return &{{ $n.DeleteOneName }}{ {{ $builder }} }
}
{{ end }}
{{- end }} {{/* End of if not IsView. */}}
// Query returns a query builder for {{ $n.Name }}.
// The builder carries the client's config and interceptors, and a query
// context whose Type is {{ $n.TypeName }}.
func (c *{{ $client }}) Query() *{{ $n.QueryName }} {
return &{{ $n.QueryName }}{
config: c.config,
ctx: &QueryContext{Type: {{ $n.TypeName }} },
inters: c.Interceptors(),
{{- /* Storage-specific templates registered for query field initialization are rendered here with the current node as data. */}}
{{- with $tmpls := matchTemplate (printf "dialect/%s/query/fields/init/*" $.Storage) }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $n }}
{{- end }}
{{- end }}
}
}
{{ with $n.HasOneFieldID }}
// Get looks up the {{ $n.Name }} entity with the given id by querying
// for exactly one match on the ID predicate.
func (c *{{ $client }}) Get(ctx context.Context, id {{ $n.ID.Type }}) (*{{ $n.Name }}, error) {
	return c.Query().Where({{ $n.Package }}.ID(id)).Only(ctx)
}

// GetX behaves like Get, except that it panics with the error instead of returning it.
func (c *{{ $client }}) GetX(ctx context.Context, id {{ $n.ID.Type }}) *{{ $n.Name }} {
	found, err := c.Get(ctx, id)
	if err == nil {
		return found
	}
	panic(err)
}
{{ end }}
{{ range $e := $n.Edges }}
{{/* Per edge: $builder is the query type of the edge's target, $arg is the name of the entity parameter (the receiver name, or "node" when that name is "id"), and $func is the generated method name. */}}
{{ $builder := $e.Type.QueryName }}
{{ $arg := $rec }}{{ if eq $arg "id" }}{{ $arg = "node" }}{{ end }}
{{ $func := print "Query" (pascal $e.Name) }}
// Query{{ pascal $e.Name }} queries the {{ $e.Name }} edge of a {{ $n.Name }}.
func (c *{{ $client }}) {{ $func }}({{ $arg }} *{{ $n.Name }}) *{{ $builder }} {
{{- /* Single-field ID: start from the target type's query and set its path function; the storage-specific "query/from" template renders the code that fills fromV. */}}
{{- if $n.HasOneFieldID }}
query := (&{{ $e.Type.ClientName }}{config: c.config}).Query()
query.path = func(context.Context) (fromV {{ $.Storage.Builder }}, _ error) {
{{- with extend $n "Receiver" $arg "Edge" $e "Ident" "fromV" }}
{{ $tmpl := printf "dialect/%s/query/from" $.Storage }}
{{- xtemplate $tmpl . -}}
{{- end -}}
return fromV, nil
}
return query
{{- else }}
{{- /* For edge schema, we use the predicate-based approach. */}}
return c.Query().
Where({{ range $id := $n.EdgeSchema.ID }}{{ $n.Package }}.{{ $id.StructField }}({{ $arg }}.{{ $id.StructField }}),{{ end }}).
{{ $func }}()
{{- end }}
}
{{ end }}
{{- /* Hooks is generated only for nodes that are not views. */}}
{{- if not $n.IsView }}
// Hooks returns the client hooks.
func (c *{{ $client }}) Hooks() []Hook {
{{- /* When $n.NumHooks or $n.NumPolicy is non-zero, the package-level Hooks are appended after the client's hooks. The full slice expression caps the capacity at the length, so append allocates a new array instead of writing into the client's slice. */}}
{{- if or $n.NumHooks $n.NumPolicy }}
hooks := c.hooks.{{ $n.Name }}
return append(hooks[:len(hooks):len(hooks)], {{ $n.Package }}.Hooks[:]...)
{{- else }}
return c.hooks.{{ $n.Name }}
{{- end }}
}
{{- end }}
// Interceptors returns the client interceptors.
func (c *{{ $client }}) Interceptors() []Interceptor {
{{- /* When $n.NumInterceptors is non-zero, the package-level Interceptors are appended after the client's interceptors. The full slice expression caps the capacity at the length, so append allocates a new array instead of writing into the client's slice. */}}
{{- if $n.NumInterceptors }}
inters := c.inters.{{ $n.Name }}
return append(inters[:len(inters):len(inters)], {{ $n.Package }}.Interceptors[:]...)
{{- else }}
return c.inters.{{ $n.Name }}
{{- end }}
}
{{ if not $n.IsView }}
// mutate executes m with the builder that matches its operation: create and
// update builders are saved, delete builders are executed. An operation that
// is none of these results in an error.
func (c *{{ $client }}) mutate(ctx context.Context, m *{{ $n.MutationName }}) (Value, error) {
	switch op := m.Op(); op {
	case OpCreate:
		builder := &{{ $n.CreateName }}{config: c.config, hooks: c.Hooks(), mutation: m}
		return builder.Save(ctx)
	case OpUpdate:
		builder := &{{ $n.UpdateName }}{config: c.config, hooks: c.Hooks(), mutation: m}
		return builder.Save(ctx)
	case OpUpdateOne:
		builder := &{{ $n.UpdateOneName }}{config: c.config, hooks: c.Hooks(), mutation: m}
		return builder.Save(ctx)
	case OpDelete, OpDeleteOne:
		builder := &{{ $n.DeleteName }}{config: c.config, hooks: c.Hooks(), mutation: m}
		return builder.Exec(ctx)
	default:
		return nil, fmt.Errorf("{{ $pkg }}: unknown {{ $n.Name }} mutation op: %q", m.Op())
	}
}
{{ end }}
{{- end }}
// hooks and interceptors per client, for fast access.
type (
	{{- /*
		Both structs are emitted as one grouped field declaration ("A, B, C []T"):
		every name but the last carries a trailing comma and the last one carries
		the type. Views get no hooks field, so the type has to be attached to the
		last node that is not a view, which is not necessarily the last node of
		the graph. $lastHook is that index (-1 when every node is a view, leaving
		the hooks struct empty); $lastInter is the index of the last node.
	*/}}
	{{- $lastInter := len $.Nodes | add -1 }}
	{{- $lastHook := -1 }}
	{{- range $i, $n := $.Nodes }}
		{{- if not $n.IsView }}{{ $lastHook = $i }}{{ end }}
	{{- end }}
	{{- $hooks := slist }}
	{{- $inters := slist }}
	{{- range $i, $n := $.Nodes }}
		{{- if not $n.IsView }}
			{{- if eq $i $lastHook }}
				{{- $hooks = append $hooks (printf "%s []ent.Hook" $n.Name) }}
			{{- else }}
				{{- $hooks = append $hooks (print $n.Name ",") }}
			{{- end }}
		{{- end }}
		{{- if eq $i $lastInter }}
			{{- $inters = append $inters (printf "%s []ent.Interceptor" $n.Name) }}
		{{- else }}
			{{- $inters = append $inters (print $n.Name ",") }}
		{{- end }}
	{{- end }}
	hooks struct {
		{{ joinWords $hooks 80 }}
	}
	inters struct {
		{{ joinWords $inters 80 }}
	}
)
{{- /* Support adding config options from both global or dialect-specific templates. */}}
{{- range $prefix := list "" (printf "dialect/%s/" $.Storage) }}
{{- with $tmpls := matchTemplate (print $prefix "config/options/*") }}
{{- range $tmpl := $tmpls }}
{{ xtemplate $tmpl $ }}
{{- end }}
{{- end }}
{{- end }}
{{- /* Extension point: render every template registered under the "config/additional/" prefix (one or two path levels deep), passing the graph as data. */}}
{{- with $tmpls := matchTemplate "config/additional/*" "config/additional/*/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
{{ end }}
{{/* A template that can be overridden in order to add additional fields to the client. It is empty by default and is rendered inside the generated Client struct. */}}
{{ define "client/fields/additional" }}{{ end }}