entc/gen/template: use a fixed receiver

This commit is contained in:
Giau. Tran Minh
2025-03-17 11:55:47 +07:00
parent 6813cdd337
commit 31bbe41c5d
24 changed files with 495 additions and 575 deletions

View File

@@ -17,8 +17,6 @@ in the LICENSE file in the root directory of this source tree.
{{ end }}
{{ $builder := $.CreateName }}
{{ $receiver := $.CreateReceiver }}
{{ $mutation := print $receiver ".mutation" }}
// {{ $builder }} is the builder for creating a {{ $.Name }} entity.
type {{ $builder }} struct {
@@ -32,27 +30,27 @@ type {{ $builder }} struct {
{{- end }}
}
{{ with extend $ "Receiver" $receiver "Builder" $builder }}
{{ with extend $ "Receiver" "c" "Builder" $builder }}
{{ template "setter" . }}
{{ end }}
// Save creates the {{ $.Name }} in the database.
func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) (*{{ $.Name }}, error) {
func (c *{{ $builder }}) Save(ctx context.Context) (*{{ $.Name }}, error) {
{{- if $.HasDefault }}
{{- if $runtimeRequired }}
if err := {{ $receiver }}.defaults(); err != nil {
if err := c.defaults(); err != nil {
return nil, err
}
{{- else }}
{{ $receiver }}.defaults()
c.defaults()
{{- end }}
{{- end }}
return withHooks(ctx, {{ $receiver }}.{{ $.Storage }}Save, {{ $mutation }}, {{ $receiver }}.hooks)
return withHooks(ctx, c.{{ $.Storage }}Save, c.mutation, c.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func ({{ $receiver }} *{{ $builder }}) SaveX(ctx context.Context) *{{ $.Name }} {
v, err := {{ $receiver }}.Save(ctx)
func (c *{{ $builder }}) SaveX(ctx context.Context) *{{ $.Name }} {
v, err := c.Save(ctx)
if err != nil {
panic(err)
}
@@ -60,14 +58,14 @@ func ({{ $receiver }} *{{ $builder }}) SaveX(ctx context.Context) *{{ $.Name }}
}
// Exec executes the query.
func ({{ $receiver }} *{{ $builder }}) Exec(ctx context.Context) error {
_, err := {{ $receiver }}.Save(ctx)
func (c *{{ $builder }}) Exec(ctx context.Context) error {
_, err := c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) ExecX(ctx context.Context) {
if err := {{ $receiver }}.Exec(ctx); err != nil {
func (c *{{ $builder }}) ExecX(ctx context.Context) {
if err := c.Exec(ctx); err != nil {
panic(err)
}
}
@@ -75,17 +73,17 @@ func ({{ $receiver }} *{{ $builder }}) ExecX(ctx context.Context) {
{{- $fields := $.Fields }}{{ if $.HasOneFieldID }}{{ if $.ID.UserDefined }}{{ $fields = append $fields $.ID }}{{ end }}{{ end }}
{{ if $.HasDefault }}
// defaults sets the default values of the builder before save.
func ({{ $receiver }} *{{ $builder }}) defaults() {{ if $runtimeRequired }}error{{ end }}{
func (c *{{ $builder }}) defaults() {{ if $runtimeRequired }}error{{ end }}{
{{- range $f := $fields }}
{{- if $f.Default }}
if _, ok := {{ $mutation }}.{{ $f.MutationGet }}(); !ok {
if _, ok := c.mutation.{{ $f.MutationGet }}(); !ok {
{{- if and $runtimeRequired $f.DefaultFunc }}
if {{ $.Package }}.{{ $f.DefaultName }} == nil {
return fmt.Errorf("{{ $pkg }}: uninitialized {{ $.Package }}.{{ $f.DefaultName }} (forgotten import {{ $pkg }}/runtime?)")
}
{{- end }}
v := {{ $.Package }}.{{ $f.DefaultName }}{{ if $f.DefaultFunc }}(){{ end }}
{{ $mutation }}.{{ $f.MutationSet }}(v)
c.mutation.{{ $f.MutationSet }}(v)
}
{{- end }}
{{- end }}
@@ -96,7 +94,7 @@ func ({{ $receiver }} *{{ $builder }}) ExecX(ctx context.Context) {
{{ end }}
// check runs all checks and user-defined validators on the builder.
func ({{ $receiver }} *{{ $builder }}) check() error {
func (c *{{ $builder }}) check() error {
{{- range $f := $fields }}
{{- $skip := false }}{{ if $.HasOneFieldID }}{{ if eq $f.Name $.ID.Name }}{{ $skip = true }}{{ end }}{{ end }}
{{- if and (not $f.Optional) (not $skip) }}
@@ -105,10 +103,10 @@ func ({{ $receiver }} *{{ $builder }}) check() error {
{{- if $n }}
{{- $partially := ne $n (len $.Config.Storage.Dialects) }}
{{- if $partially }}
switch {{ $receiver }}.driver.Dialect() {
switch c.driver.Dialect() {
case {{ join $dialects ", " }}:
{{- end }}
if _, ok := {{ $mutation }}.{{ $f.MutationGet }}(); !ok {
if _, ok := c.mutation.{{ $f.MutationGet }}(); !ok {
return &ValidationError{Name: "{{ $f.Name }}", err: errors.New(`{{ $pkg }}: missing required field "{{ $.Name }}.{{ $f.Name }}"`)}
}
{{- if $partially }}
@@ -118,7 +116,7 @@ func ({{ $receiver }} *{{ $builder }}) check() error {
{{- end }}
{{- $isValidator := and ($f.HasGoType) ($f.Type.Validator) }}
{{- with or $f.Validators $f.IsEnum $isValidator }}
if v, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok {
if v, ok := c.mutation.{{ $f.MutationGet }}(); ok {
if err := {{ if or $f.Validators $f.IsEnum }}{{ $.Package }}.{{ $f.Validator }}({{ $f.BasicType "v" }}){{ else }}v.Validate(){{ end }}; err != nil {
return &ValidationError{Name: "{{ $f.Name }}", err: fmt.Errorf(`{{ $pkg }}: validator failed for field "{{ $.Name }}.{{ $f.Name }}": %w`, err)}
}
@@ -127,7 +125,7 @@ func ({{ $receiver }} *{{ $builder }}) check() error {
{{- end }}
{{- range $e := $.EdgesWithID }}
{{- if not $e.Optional }}
if len({{ $mutation }}.{{ $e.StructField }}IDs()) == 0 {
if len(c.mutation.{{ $e.StructField }}IDs()) == 0 {
return &ValidationError{Name: "{{ $e.Name }}", err: errors.New(`{{ $pkg }}: missing required edge "{{ $.Name }}.{{ $e.Name }}"`)}
}
{{- end }}
@@ -135,7 +133,7 @@ func ({{ $receiver }} *{{ $builder }}) check() error {
return nil
}
{{ with extend $ "Receiver" $receiver "Builder" $builder }}
{{ with extend $ "Receiver" "c" "Builder" $builder }}
{{ $tmpl := printf "dialect/%s/create" $.Storage }}
{{ xtemplate $tmpl . }}
{{ end }}
@@ -148,7 +146,6 @@ func ({{ $receiver }} *{{ $builder }}) check() error {
{{- end }}
{{ $bulk := printf "%sCreateBulk" (pascal $.Name) }}
{{ $receiver = receiver $bulk }}
// {{ $bulk }} is the builder for creating many {{ $.Name }} entities in bulk.
type {{ $bulk }} struct {
@@ -165,7 +162,7 @@ type {{ $bulk }} struct {
{{/* If the storage driver supports bulk creation */}}
{{ $tmpl = printf "dialect/%s/create_bulk" $.Storage }}
{{ if hasTemplate $tmpl }}
{{ with extend $ "Builder" $bulk "Receiver" $receiver }}
{{ with extend $ "Builder" $bulk "Receiver" "c" }}
{{ xtemplate $tmpl . }}
{{ end }}
{{ end }}

View File

@@ -18,8 +18,6 @@ import (
)
{{ $builder := $.DeleteName }}
{{ $receiver := $.DeleteReceiver }}
{{ $mutation := print $receiver ".mutation" }}
// {{ $builder }} is the builder for deleting a {{ $.Name }} entity.
type {{ $builder }} struct {
@@ -29,27 +27,26 @@ type {{ $builder }} struct {
}
// Where appends a list predicates to the {{ $builder }} builder.
func ({{ $receiver }} *{{ $builder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $builder }} {
{{ $mutation }}.Where(ps...)
return {{ $receiver }}
func (d *{{ $builder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $builder }} {
d.mutation.Where(ps...)
return d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func ({{ $receiver }} *{{ $builder }}) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, {{ $receiver }}.{{ $.Storage }}Exec, {{ $mutation }}, {{ $receiver }}.hooks)
func (d *{{ $builder }}) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, d.{{ $.Storage }}Exec, d.mutation, d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) ExecX(ctx context.Context) int {
{{- $n := "n" }}{{ if eq $receiver $n }}{{ $n = "_n" }}{{ end }}
{{ $n }}, err := {{ $receiver }}.Exec(ctx)
func (d *{{ $builder }}) ExecX(ctx context.Context) int {
n, err := d.Exec(ctx)
if err != nil {
panic(err)
}
return {{ $n }}
return n
}
{{ with extend $ "Receiver" $receiver "Builder" $builder }}
{{ with extend $ "Receiver" "d" "Builder" $builder }}
{{ $tmpl := printf "dialect/%s/delete" $.Storage }}
{{ xtemplate $tmpl . }}
{{ end }}
@@ -62,22 +59,21 @@ func ({{ $receiver }} *{{ $builder }}) ExecX(ctx context.Context) int {
{{- end }}
{{ $onebuilder := $.DeleteOneName }}
{{ $oneReceiver := $.DeleteOneReceiver }}
// {{ $onebuilder }} is the builder for deleting a single {{ $.Name }} entity.
type {{ $onebuilder }} struct {
{{ $receiver }} *{{ $builder }}
d *{{ $builder }}
}
// Where appends a list predicates to the {{ $builder }} builder.
func ({{ $oneReceiver }} *{{ $onebuilder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $onebuilder }} {
{{ $oneReceiver }}.{{ $mutation }}.Where(ps...)
return {{ $oneReceiver }}
func (d *{{ $onebuilder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $onebuilder }} {
d.d.mutation.Where(ps...)
return d
}
// Exec executes the deletion query.
func ({{ $oneReceiver }} *{{ $onebuilder }}) Exec(ctx context.Context) error {
n, err := {{ $oneReceiver }}.{{ $receiver }}.Exec(ctx)
func (d *{{ $onebuilder }}) Exec(ctx context.Context) error {
n, err := d.d.Exec(ctx)
switch {
case err != nil:
return err
@@ -89,8 +85,8 @@ func ({{ $oneReceiver }} *{{ $onebuilder }}) Exec(ctx context.Context) error {
}
// ExecX is like Exec, but panics if an error occurs.
func ({{ $oneReceiver }} *{{ $onebuilder }}) ExecX(ctx context.Context) {
if err := {{ $oneReceiver }}.Exec(ctx); err != nil {
func (d *{{ $onebuilder }}) ExecX(ctx context.Context) {
if err := d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -16,7 +16,6 @@ in the LICENSE file in the root directory of this source tree.
{{ end }}
{{ $builder := $.QueryName }}
{{ $receiver := receiver $builder }}
// {{ $builder }} is the builder for querying {{ $.Name }} entities.
type {{ $builder }} struct {
@@ -40,47 +39,47 @@ type {{ $builder }} struct {
}
// Where adds a new predicate for the {{ $builder }} builder.
func ({{ $receiver }} *{{ $builder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $builder }} {
{{ $receiver}}.predicates = append({{ $receiver }}.predicates, ps...)
return {{ $receiver }}
func (q *{{ $builder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $builder }} {
q.predicates = append(q.predicates, ps...)
return q
}
// Limit the number of records to be returned by this query.
func ({{ $receiver }} *{{ $builder }}) Limit(limit int) *{{ $builder }} {
{{ $receiver }}.ctx.Limit = &limit
return {{ $receiver }}
func (q *{{ $builder }}) Limit(limit int) *{{ $builder }} {
q.ctx.Limit = &limit
return q
}
// Offset to start from.
func ({{ $receiver }} *{{ $builder }}) Offset(offset int) *{{ $builder }} {
{{ $receiver }}.ctx.Offset = &offset
return {{ $receiver }}
func (q *{{ $builder }}) Offset(offset int) *{{ $builder }} {
q.ctx.Offset = &offset
return q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func ({{ $receiver }} *{{ $builder }}) Unique(unique bool) *{{ $builder }} {
{{ $receiver }}.ctx.Unique = &unique
return {{ $receiver }}
func (q *{{ $builder }}) Unique(unique bool) *{{ $builder }} {
q.ctx.Unique = &unique
return q
}
// Order specifies how the records should be ordered.
func ({{ $receiver }} *{{ $builder }}) Order(o ...{{ $.Package }}.OrderOption) *{{ $builder }} {
{{ $receiver }}.order = append({{ $receiver }}.order, o...)
return {{ $receiver }}
func (q *{{ $builder }}) Order(o ...{{ $.Package }}.OrderOption) *{{ $builder }} {
q.order = append(q.order, o...)
return q
}
{{/* this code has similarity with edge queries in client.tmpl */}}
{{ range $e := $.Edges }}
{{ $edge_builder := print $e.Type.QueryName }}
// Query{{ pascal $e.Name }} chains the current query on the "{{ $e.Name }}" edge.
func ({{ $receiver }} *{{ $builder }}) Query{{ pascal $e.Name }}() *{{ $edge_builder }} {
query := (&{{ $e.Type.ClientName }}{config: {{ $receiver }}.config}).Query()
func (q *{{ $builder }}) Query{{ pascal $e.Name }}() *{{ $edge_builder }} {
query := (&{{ $e.Type.ClientName }}{config: q.config}).Query()
query.path = func(ctx context.Context) (fromU {{ $.Storage.Builder }}, err error) {
if err := {{ $receiver }}.prepareQuery(ctx); err != nil {
if err := q.prepareQuery(ctx); err != nil {
return nil, err
}
{{- with extend $ "Receiver" $receiver "Edge" $e "Ident" "fromU" -}}
{{ with extend $ "Receiver" "q" "Edge" $e "Ident" "fromU" -}}
{{ $tmpl := printf "dialect/%s/query/path" $.Storage }}
{{- xtemplate $tmpl . }}
{{- end -}}
@@ -92,8 +91,8 @@ func ({{ $receiver }} *{{ $builder }}) Order(o ...{{ $.Package }}.OrderOption) *
// First returns the first {{ $.Name }} entity from the query.
// Returns a *NotFoundError when no {{ $.Name }} was found.
func ({{ $receiver }} *{{ $builder }}) First(ctx context.Context) (*{{ $.Name }}, error) {
nodes, err := {{ $receiver }}.Limit(1).All(setContextOp(ctx, {{ $receiver }}.ctx, ent.OpQueryFirst))
func (q *{{ $builder }}) First(ctx context.Context) (*{{ $.Name }}, error) {
nodes, err := q.Limit(1).All(setContextOp(ctx, q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
@@ -104,8 +103,8 @@ func ({{ $receiver }} *{{ $builder }}) First(ctx context.Context) (*{{ $.Name }}
}
// FirstX is like First, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) FirstX(ctx context.Context) *{{ $.Name }} {
node, err := {{ $receiver }}.First(ctx)
func (q *{{ $builder }}) FirstX(ctx context.Context) *{{ $.Name }} {
node, err := q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
@@ -115,9 +114,9 @@ func ({{ $receiver }} *{{ $builder }}) FirstX(ctx context.Context) *{{ $.Name }}
{{ if $.HasOneFieldID }}
// FirstID returns the first {{ $.Name }} ID from the query.
// Returns a *NotFoundError when no {{ $.Name }} ID was found.
func ({{ $receiver }} *{{ $builder }}) FirstID(ctx context.Context) (id {{ $.ID.Type }}, err error) {
func (q *{{ $builder }}) FirstID(ctx context.Context) (id {{ $.ID.Type }}, err error) {
var ids []{{ $.ID.Type }}
if ids, err = {{ $receiver }}.Limit(1).IDs(setContextOp(ctx, {{ $receiver }}.ctx, ent.OpQueryFirstID)); err != nil {
if ids, err = q.Limit(1).IDs(setContextOp(ctx, q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
@@ -128,8 +127,8 @@ func ({{ $receiver }} *{{ $builder }}) FirstX(ctx context.Context) *{{ $.Name }}
}
// FirstIDX is like FirstID, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) FirstIDX(ctx context.Context) {{ $.ID.Type }} {
id, err := {{ $receiver }}.FirstID(ctx)
func (q *{{ $builder }}) FirstIDX(ctx context.Context) {{ $.ID.Type }} {
id, err := q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
@@ -140,8 +139,8 @@ func ({{ $receiver }} *{{ $builder }}) FirstX(ctx context.Context) *{{ $.Name }}
// Only returns a single {{ $.Name }} entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one {{ $.Name }} entity is found.
// Returns a *NotFoundError when no {{ $.Name }} entities are found.
func ({{ $receiver }} *{{ $builder }}) Only(ctx context.Context) (*{{ $.Name }}, error) {
nodes, err := {{ $receiver }}.Limit(2).All(setContextOp(ctx, {{ $receiver }}.ctx, ent.OpQueryOnly))
func (q *{{ $builder }}) Only(ctx context.Context) (*{{ $.Name }}, error) {
nodes, err := q.Limit(2).All(setContextOp(ctx, q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
@@ -156,8 +155,8 @@ func ({{ $receiver }} *{{ $builder }}) Only(ctx context.Context) (*{{ $.Name }},
}
// OnlyX is like Only, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) OnlyX(ctx context.Context) *{{ $.Name }} {
node, err := {{ $receiver }}.Only(ctx)
func (q *{{ $builder }}) OnlyX(ctx context.Context) *{{ $.Name }} {
node, err := q.Only(ctx)
if err != nil {
panic(err)
}
@@ -168,9 +167,9 @@ func ({{ $receiver }} *{{ $builder }}) OnlyX(ctx context.Context) *{{ $.Name }}
// OnlyID is like Only, but returns the only {{ $.Name }} ID in the query.
// Returns a *NotSingularError when more than one {{ $.Name }} ID is found.
// Returns a *NotFoundError when no entities are found.
func ({{ $receiver }} *{{ $builder }}) OnlyID(ctx context.Context) (id {{ $.ID.Type }}, err error) {
func (q *{{ $builder }}) OnlyID(ctx context.Context) (id {{ $.ID.Type }}, err error) {
var ids []{{ $.ID.Type }}
if ids, err = {{ $receiver }}.Limit(2).IDs(setContextOp(ctx, {{ $receiver }}.ctx, ent.OpQueryOnlyID)); err != nil {
if ids, err = q.Limit(2).IDs(setContextOp(ctx, q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
@@ -185,8 +184,8 @@ func ({{ $receiver }} *{{ $builder }}) OnlyX(ctx context.Context) *{{ $.Name }}
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) OnlyIDX(ctx context.Context) {{ $.ID.Type }} {
id, err := {{ $receiver }}.OnlyID(ctx)
func (q *{{ $builder }}) OnlyIDX(ctx context.Context) {{ $.ID.Type }} {
id, err := q.OnlyID(ctx)
if err != nil {
panic(err)
}
@@ -195,18 +194,18 @@ func ({{ $receiver }} *{{ $builder }}) OnlyX(ctx context.Context) *{{ $.Name }}
{{ end }}
// All executes the query and returns a list of {{ plural $.Name }}.
func ({{ $receiver }} *{{ $builder }}) All(ctx context.Context) ([]*{{ $.Name }}, error) {
ctx = setContextOp(ctx, {{ $receiver }}.ctx, ent.OpQueryAll)
if err := {{ $receiver }}.prepareQuery(ctx); err != nil {
func (q *{{ $builder }}) All(ctx context.Context) ([]*{{ $.Name }}, error) {
ctx = setContextOp(ctx, q.ctx, ent.OpQueryAll)
if err := q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*{{ $.Name }}, *{{ $builder }}]()
return withInterceptors[[]*{{ $.Name }}](ctx, {{ $receiver }}, qr, {{ $receiver }}.inters)
return withInterceptors[[]*{{ $.Name }}](ctx, q, qr, q.inters)
}
// AllX is like All, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) AllX(ctx context.Context) []*{{ $.Name }} {
nodes, err := {{ $receiver }}.All(ctx)
func (q *{{ $builder }}) AllX(ctx context.Context) []*{{ $.Name }} {
nodes, err := q.All(ctx)
if err != nil {
panic(err)
}
@@ -215,21 +214,21 @@ func ({{ $receiver }} *{{ $builder }}) AllX(ctx context.Context) []*{{ $.Name }}
{{ if $.HasOneFieldID }}
// IDs executes the query and returns a list of {{ $.Name }} IDs.
func ({{ $receiver }} *{{ $builder }}) IDs(ctx context.Context) (ids []{{ $.ID.Type }}, err error) {
func (q *{{ $builder }}) IDs(ctx context.Context) (ids []{{ $.ID.Type }}, err error) {
{{- /* Since a graph traversal such as JOINs can return duplicate IDs, set the Unique modifier unless specified otherwise. */}}
if {{ $receiver }}.ctx.Unique == nil && {{ $receiver }}.path != nil {
{{ $receiver }}.Unique(true)
if q.ctx.Unique == nil && q.path != nil {
q.Unique(true)
}
ctx = setContextOp(ctx, {{ $receiver }}.ctx, ent.OpQueryIDs)
if err = {{ $receiver }}.Select({{ $.Package }}.FieldID).Scan(ctx, &ids); err != nil {
ctx = setContextOp(ctx, q.ctx, ent.OpQueryIDs)
if err = q.Select({{ $.Package }}.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) IDsX(ctx context.Context) []{{ $.ID.Type }} {
ids, err := {{ $receiver }}.IDs(ctx)
func (q *{{ $builder }}) IDsX(ctx context.Context) []{{ $.ID.Type }} {
ids, err := q.IDs(ctx)
if err != nil {
panic(err)
}
@@ -238,17 +237,17 @@ func ({{ $receiver }} *{{ $builder }}) AllX(ctx context.Context) []*{{ $.Name }}
{{ end }}
// Count returns the count of the given query.
func ({{ $receiver }} *{{ $builder }}) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, {{ $receiver }}.ctx, ent.OpQueryCount)
if err := {{ $receiver }}.prepareQuery(ctx); err != nil {
func (q *{{ $builder }}) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, q.ctx, ent.OpQueryCount)
if err := q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, {{ $receiver }}, querierCount[*{{ $builder }}](), {{ $receiver }}.inters)
return withInterceptors[int](ctx, q, querierCount[*{{ $builder }}](), q.inters)
}
// CountX is like Count, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) CountX(ctx context.Context) int {
count, err := {{ $receiver }}.Count(ctx)
func (q *{{ $builder }}) CountX(ctx context.Context) int {
count, err := q.Count(ctx)
if err != nil {
panic(err)
}
@@ -256,9 +255,9 @@ func ({{ $receiver }} *{{ $builder }}) CountX(ctx context.Context) int {
}
// Exist returns true if the query has elements in the graph.
func ({{ $receiver }} *{{ $builder }}) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, {{ $receiver }}.ctx, ent.OpQueryExist)
switch _, err := {{ $receiver }}.First{{ if $.HasOneFieldID }}ID{{ end }}(ctx);{
func (q *{{ $builder }}) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, q.ctx, ent.OpQueryExist)
switch _, err := q.First{{ if $.HasOneFieldID }}ID{{ end }}(ctx);{
case IsNotFound(err):
return false, nil
case err != nil:
@@ -269,8 +268,8 @@ func ({{ $receiver }} *{{ $builder }}) Exist(ctx context.Context) (bool, error)
}
// ExistX is like Exist, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) ExistX(ctx context.Context) bool {
exist, err := {{ $receiver }}.Exist(ctx)
func (q *{{ $builder }}) ExistX(ctx context.Context) bool {
exist, err := q.Exist(ctx)
if err != nil {
panic(err)
}
@@ -279,24 +278,24 @@ func ({{ $receiver }} *{{ $builder }}) ExistX(ctx context.Context) bool {
// Clone returns a duplicate of the {{ $builder }} builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func ({{ $receiver }} *{{ $builder }}) Clone() *{{ $builder }} {
if {{ $receiver }} == nil {
func (q *{{ $builder }}) Clone() *{{ $builder }} {
if q == nil {
return nil
}
return &{{ $builder }}{
config: {{ $receiver }}.config,
ctx: {{ $receiver }}.ctx.Clone(),
order: append([]{{ $.Package }}.OrderOption{}, {{ $receiver }}.order...),
inters: append([]Interceptor{}, {{ $receiver }}.inters...),
predicates: append([]predicate.{{ $.Name }}{}, {{ $receiver }}.predicates...),
config: q.config,
ctx: q.ctx.Clone(),
order: append([]{{ $.Package }}.OrderOption{}, q.order...),
inters: append([]Interceptor{}, q.inters...),
predicates: append([]predicate.{{ $.Name }}{}, q.predicates...),
{{- range $e := $.Edges }}
{{ $e.EagerLoadField }}: {{ $receiver }}.{{ $e.EagerLoadField }}.Clone(),
{{ $e.EagerLoadField }}: q.{{ $e.EagerLoadField }}.Clone(),
{{- end }}
// clone intermediate query.
{{ $.Storage }}: {{ $receiver }}.{{ $.Storage }}.Clone(),
path: {{ $receiver }}.path,
{{ $.Storage }}: q.{{ $.Storage }}.Clone(),
path: q.path,
{{- if $.FeatureEnabled "sql/modifier" }}
modifiers: append([]func(*sql.Selector){}, {{ $receiver }}.modifiers...),
modifiers: append([]func(*sql.Selector){}, q.modifiers...),
{{- end }}
}
}
@@ -306,13 +305,13 @@ func ({{ $receiver }} *{{ $builder }}) Clone() *{{ $builder }} {
{{ $func := print "With" $e.StructField }}
// {{ $func }} tells the query-builder to eager-load the nodes that are connected to
// the "{{ $e.Name }}" edge. The optional arguments are used to configure the query builder of the edge.
func ({{ $receiver }} *{{ $builder }}) {{ $func }}(opts ...func(*{{ $ebuilder }})) *{{ $builder }} {
query := (&{{ $e.Type.ClientName }}{config: {{ $receiver }}.config}).Query()
func (q *{{ $builder }}) {{ $func }}(opts ...func(*{{ $ebuilder }})) *{{ $builder }} {
query := (&{{ $e.Type.ClientName }}{config: q.config}).Query()
for _, opt := range opts {
opt(query)
}
{{ $receiver }}.{{ $e.EagerLoadField }} = query
return {{ $receiver }}
q.{{ $e.EagerLoadField }} = query
return q
}
{{- end }}
@@ -336,10 +335,10 @@ func ({{ $receiver }} *{{ $builder }}) Clone() *{{ $builder }} {
// Scan(ctx, &v)
//
{{- end }}
func ({{ $receiver }} *{{ $builder }}) GroupBy(field string, fields ...string) *{{ $groupBuilder }} {
{{ $receiver }}.ctx.Fields = append([]string{field}, fields...)
grbuild := &{{ $groupBuilder }}{build: {{ $receiver }}}
grbuild.flds = &{{ $receiver }}.ctx.Fields
func (q *{{ $builder }}) GroupBy(field string, fields ...string) *{{ $groupBuilder }} {
q.ctx.Fields = append([]string{field}, fields...)
grbuild := &{{ $groupBuilder }}{build: q}
grbuild.flds = &q.ctx.Fields
grbuild.label = {{ $.Package }}.Label
grbuild.scan = grbuild.Scan
return grbuild
@@ -363,26 +362,26 @@ func ({{ $receiver }} *{{ $builder }}) GroupBy(field string, fields ...string) *
// Scan(ctx, &v)
//
{{- end }}
func ({{ $receiver }} *{{ $builder }}) Select(fields ...string) *{{ $selectBuilder }} {
{{ $receiver }}.ctx.Fields = append({{ $receiver }}.ctx.Fields, fields...)
sbuild := &{{ $selectBuilder }}{ {{ $builder }}: {{ $receiver }} }
func (q *{{ $builder }}) Select(fields ...string) *{{ $selectBuilder }} {
q.ctx.Fields = append(q.ctx.Fields, fields...)
sbuild := &{{ $selectBuilder }}{ {{ $builder }}: q }
sbuild.label = {{ $.Package }}.Label
sbuild.flds, sbuild.scan = &{{ $receiver }}.ctx.Fields, sbuild.Scan
sbuild.flds, sbuild.scan = &q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a {{ $selectBuilder }} configured with the given aggregations.
func ({{ $receiver }} *{{ $builder }}) Aggregate(fns ...AggregateFunc) *{{ $selectBuilder }} {
return {{ $receiver }}.Select().Aggregate(fns...)
func (q *{{ $builder }}) Aggregate(fns ...AggregateFunc) *{{ $selectBuilder }} {
return q.Select().Aggregate(fns...)
}
func ({{ $receiver }} *{{ $builder }}) prepareQuery(ctx context.Context) error {
for _, inter := range {{ $receiver }}.inters {
func (q *{{ $builder }}) prepareQuery(ctx context.Context) error {
for _, inter := range q.inters {
if inter == nil {
return fmt.Errorf("{{ $pkg }}: uninitialized interceptor (forgotten import {{ $pkg }}/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, {{ $receiver }}); err != nil {
if err := trv.Traverse(ctx, q); err != nil {
return err
}
}
@@ -390,22 +389,22 @@ func ({{ $receiver }} *{{ $builder }}) prepareQuery(ctx context.Context) error {
{{- /* Optional prepare checks per dialect. */}}
{{- $tmpl = printf "dialect/%s/query/preparecheck" $.Storage }}
{{- if hasTemplate $tmpl }}
{{- with extend $ "Receiver" $receiver "Package" $pkg }}
{{- with extend $ "Receiver" "q" "Package" $pkg }}
{{- xtemplate $tmpl . }}
{{- end }}
{{- end }}
if {{ $receiver }}.path != nil {
prev, err := {{ $receiver }}.path(ctx)
if q.path != nil {
prev, err := q.path(ctx)
if err != nil {
return err
}
{{ $receiver }}.{{ $.Storage }} = prev
q.{{ $.Storage }} = prev
}
{{- if $.NumPolicy }}
if {{ $.Package }}.Policy == nil {
return errors.New("{{ $pkg }}: uninitialized {{ $.Package }}.Policy (forgotten import {{ $pkg }}/runtime?)")
}
if err := {{ $.Package }}.Policy.EvalQuery(ctx, {{ $receiver }}); err != nil {
if err := {{ $.Package }}.Policy.EvalQuery(ctx, q); err != nil {
return err
}
{{- end }}

View File

@@ -8,7 +8,6 @@ in the LICENSE file in the root directory of this source tree.
{{ define "setter" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := $.Scope.Receiver }}
{{ $fields := $.Fields }}
{{ $updater := false }}
{{ $creator := true }}
@@ -22,16 +21,15 @@ in the LICENSE file in the root directory of this source tree.
{{- end }}
{{ range $f := $fields }}
{{ $p := receiver $f.Type.String }}{{ if eq $p $receiver }} {{ $p = "value" }} {{ end }}
{{ $func := print "Set" $f.StructField }}
// {{ $func }} sets the "{{ $f.Name }}" field.
func ({{ $receiver }} *{{ $builder }}) {{ $func }}({{ $p }} {{ $f.Type }}) *{{ $builder }} {
func (m *{{ $builder }}) {{ $func }}(v {{ $f.Type }}) *{{ $builder }} {
{{- /* setting numeric type override previous calls to Add. */}}
{{- if and $updater $f.SupportsMutationAdd }}
{{ $receiver }}.mutation.{{ $f.MutationReset }}()
m.mutation.{{ $f.MutationReset }}()
{{- end }}
{{ $receiver }}.mutation.{{ $f.MutationSet }}({{ $p }})
return {{ $receiver }}
m.mutation.{{ $f.MutationSet }}(v)
return m
}
{{/* Avoid generating nillable setters for nillable types. */}}
@@ -41,36 +39,36 @@ in the LICENSE file in the root directory of this source tree.
{{ $skipNillable := false }}{{ range $fields }}{{ if and (ne .Name $f.Name) (eq .MutationSet $nillableFunc) }}{{ $skipNillable = true }}{{ break }}{{ end }}{{ end }}
{{ if and (not $f.Type.Nillable) (not $skipNillable) (or $nillableC $nillableU) }}
// {{ $nillableFunc }} sets the "{{ $f.Name }}" field if the given value is not nil.
func ({{ $receiver }} *{{ $builder }}) {{ $nillableFunc }}({{ $p }} *{{ $f.Type }}) *{{ $builder }} {
if {{ $p }} != nil {
{{ $receiver }}.{{ $func }}(*{{ $p }})
func (m *{{ $builder }}) {{ $nillableFunc }}(v *{{ $f.Type }}) *{{ $builder }} {
if v != nil {
m.{{ $func }}(*v)
}
return {{ $receiver }}
return m
}
{{ end }}
{{ if and $updater $f.SupportsMutationAdd }}
// {{ $f.MutationAdd }} adds {{ $p }} to the "{{ $f.Name }}" field.
func ({{ $receiver }} *{{ $builder }}) {{ $f.MutationAdd }}({{ $p }} {{ $f.SignedType }}) *{{ $builder }} {
{{ $receiver }}.mutation.{{ $f.MutationAdd }}({{ $p }})
return {{ $receiver }}
// {{ $f.MutationAdd }} adds value to the "{{ $f.Name }}" field.
func (m *{{ $builder }}) {{ $f.MutationAdd }}(v {{ $f.SignedType }}) *{{ $builder }} {
m.mutation.{{ $f.MutationAdd }}(v)
return m
}
{{ end }}
{{ if and $updater $f.SupportsMutationAppend }}
// {{ $f.MutationAppend }} appends {{ $p }} to the "{{ $f.Name }}" field.
func ({{ $receiver }} *{{ $builder }}) {{ $f.MutationAppend }}({{ $p }} {{ $f.Type }}) *{{ $builder }} {
{{ $receiver }}.mutation.{{ $f.MutationAppend }}({{ $p }})
return {{ $receiver }}
// {{ $f.MutationAppend }} appends value to the "{{ $f.Name }}" field.
func (m *{{ $builder }}) {{ $f.MutationAppend }}(v {{ $f.Type }}) *{{ $builder }} {
m.mutation.{{ $f.MutationAppend }}(v)
return m
}
{{ end }}
{{ if and $f.Optional $updater }}
{{ $func := print "Clear" $f.StructField }}
// {{ $func }} clears the value of the "{{ $f.Name }}" field.
func ({{ $receiver }} *{{ $builder }}) {{ $func }}() *{{ $builder }} {
{{ $receiver }}.mutation.{{ $func }}()
return {{ $receiver }}
func (m *{{ $builder }}) {{ $func }}() *{{ $builder }} {
m.mutation.{{ $func }}()
return m
}
{{ end }}
{{ end }}
@@ -85,42 +83,39 @@ in the LICENSE file in the root directory of this source tree.
{{ $withSetter := not $e.HasFieldSetter }}
{{ if $withSetter }}
// {{ $idsFunc }} {{ $op }}s the "{{ $e.Name }}" edge to the {{ $e.Type.Name }} entity by ID{{ if not $e.Unique }}s{{ end }}.
func ({{ $receiver }} *{{ $builder }}) {{ $idsFunc }}({{ if $e.Unique }}id{{ else }}ids ...{{ end }} {{ $e.Type.ID.Type }}) *{{ $builder }} {
{{ $receiver }}.mutation.{{ $idsFunc }}({{ if $e.Unique }}id{{ else }}ids ...{{ end }})
return {{ $receiver }}
func (m *{{ $builder }}) {{ $idsFunc }}({{ if $e.Unique }}id{{ else }}ids ...{{ end }} {{ $e.Type.ID.Type }}) *{{ $builder }} {
m.mutation.{{ $idsFunc }}({{ if $e.Unique }}id{{ else }}ids ...{{ end }})
return m
}
{{ end }}
{{ if and $e.Unique $e.Optional $withSetter }}
{{ $nillableIDsFunc := print "SetNillable" $e.StructField "ID" }}
// {{ $nillableIDsFunc }} sets the "{{ $e.Name }}" edge to the {{ $e.Type.Name }} entity by ID if the given value is not nil.
func ({{ $receiver }} *{{ $builder }}) {{ $nillableIDsFunc }}(id *{{ $e.Type.ID.Type }}) *{{ $builder }} {
func (m *{{ $builder }}) {{ $nillableIDsFunc }}(id *{{ $e.Type.ID.Type }}) *{{ $builder }} {
if id != nil {
{{ $receiver}} = {{ $receiver }}.{{ $idsFunc }}(*id)
m = m.{{ $idsFunc }}(*id)
}
return {{ $receiver }}
return m
}
{{ end }}
{{ $p := lower (printf "%.1s" $e.Type.Name) }}
{{ if eq $p $receiver }} {{ $p = "v" }} {{ end }}
{{ $func := print (pascal $op) $e.StructField }}
// {{ $func }} {{ $op }}s the "{{ $e.Name }}" edge{{if not $e.Unique}}s{{ end }} to the {{ $e.Type.Name }} entity.
func ({{ $receiver }} *{{ $builder }}) {{ $func }}({{ $p }} {{ if not $e.Unique }}...{{ end }}*{{ $e.Type.Name}}) *{{ $builder }} {
func (m *{{ $builder }}) {{ $func }}(v {{ if not $e.Unique }}...{{ end }}*{{ $e.Type.Name}}) *{{ $builder }} {
{{ if $e.Unique -}}
return {{ $receiver }}.{{ $idsFunc }}({{ $p }}.ID)
return m.{{ $idsFunc }}(v.ID)
{{- else -}}
ids := make([]{{ $e.Type.ID.Type }}, len({{ $p }}))
{{ $i := "i" }}{{ if eq $i $p }}{{ $i = "j" }}{{ end -}}
for {{ $i }} := range {{ $p }} {
ids[{{ $i }}] = {{ $p }}[{{ $i }}].ID
ids := make([]{{ $e.Type.ID.Type }}, len(v))
for i := range v {
ids[i] = v[i].ID
}
return {{ $receiver }}.{{ $idsFunc }}(ids...)
return m.{{ $idsFunc }}(ids...)
{{- end }}
}
{{ end }}
// Mutation returns the {{ $.MutationName }} object of the builder.
func ({{ $receiver }} *{{ $builder }}) Mutation() *{{ $.MutationName }} {
return {{ $receiver }}.mutation
func (m *{{ $builder }}) Mutation() *{{ $.MutationName }} {
return m.mutation
}
{{ end }}

View File

@@ -16,8 +16,6 @@ in the LICENSE file in the root directory of this source tree.
{{ end }}
{{ $builder := $.UpdateName }}
{{ $receiver := $.UpdateReceiver }}
{{ $mutation := print $receiver ".mutation" }}
{{ $runtimeRequired := or $.NumHooks $.NumPolicy }}
// {{ $builder }} is the builder for updating {{ $.Name }} entities.
@@ -27,36 +25,36 @@ type {{ $builder }} struct {
}
// Where appends a list predicates to the {{ $builder }} builder.
func ({{ $receiver}} *{{ $builder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $builder }} {
{{ $mutation }}.Where(ps...)
return {{ $receiver }}
func (u *{{ $builder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $builder }} {
u.mutation.Where(ps...)
return u
}
{{ with extend $ "Receiver" $receiver "Builder" $builder }}
{{ with extend $ "Receiver" "u" "Builder" $builder }}
{{ template "setter" . }}
{{ end }}
{{ with extend $ "Receiver" $receiver "Builder" $builder }}
{{ with extend $ "Receiver" "u" "Builder" $builder }}
{{ template "update/edges" . }}
{{ end }}
// Save executes the query and returns the number of nodes affected by the update operation.
func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) (int, error) {
func (u *{{ $builder }}) Save(ctx context.Context) (int, error) {
{{- if $.HasUpdateDefault }}
{{- if $runtimeRequired }}
if err := {{ $receiver }}.defaults(); err != nil {
if err := u.defaults(); err != nil {
return 0, err
}
{{- else }}
{{ $receiver }}.defaults()
u.defaults()
{{- end }}
{{- end }}
return withHooks(ctx, {{ $receiver }}.{{ $.Storage }}Save, {{ $mutation }}, {{ $receiver }}.hooks)
return withHooks(ctx, u.{{ $.Storage }}Save, u.mutation, u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) SaveX(ctx context.Context) int {
affected, err := {{ $receiver }}.Save(ctx)
func (u *{{ $builder }}) SaveX(ctx context.Context) int {
affected, err := u.Save(ctx)
if err != nil {
panic(err)
}
@@ -64,30 +62,28 @@ func ({{ $receiver }} *{{ $builder }}) SaveX(ctx context.Context) int {
}
// Exec executes the query.
func ({{ $receiver }} *{{ $builder }}) Exec(ctx context.Context) error {
_, err := {{ $receiver }}.Save(ctx)
func (u *{{ $builder }}) Exec(ctx context.Context) error {
_, err := u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) ExecX(ctx context.Context) {
if err := {{ $receiver }}.Exec(ctx); err != nil {
func (u *{{ $builder }}) ExecX(ctx context.Context) {
if err := u.Exec(ctx); err != nil {
panic(err)
}
}
{{ with extend $ "Receiver" $receiver "Package" $pkg "Builder" $builder }}
{{ with extend $ "Receiver" "u" "Package" $pkg "Builder" $builder }}
{{ template "update/checks" . }}
{{ end }}
{{ with extend $ "Receiver" $receiver "Builder" $builder "Package" $pkg }}
{{ with extend $ "Receiver" "u" "Builder" $builder "Package" $pkg }}
{{ $tmpl := printf "dialect/%s/update" $.Storage }}
{{ xtemplate $tmpl . }}
{{ end }}
{{ $onebuilder := $.UpdateOneName }}
{{ $receiver = receiver $onebuilder }}
{{ $mutation = print $receiver ".mutation" }}
// {{ $onebuilder }} is the builder for updating a single {{ $.Name }} entity.
type {{ $onebuilder }} struct {
@@ -96,45 +92,45 @@ type {{ $onebuilder }} struct {
{{- template "update/fields" $ }}
}
{{ with extend $ "Receiver" $receiver "Builder" $onebuilder }}
{{ with extend $ "Receiver" "u" "Builder" $onebuilder }}
{{ template "setter" . }}
{{ end }}
{{ with extend $ "Receiver" $receiver "Builder" $onebuilder }}
{{ with extend $ "Receiver" "u" "Builder" $onebuilder }}
{{ template "update/edges" . }}
{{ end }}
// Where appends a list predicates to the {{ $builder }} builder.
func ({{ $receiver }} *{{ $onebuilder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $onebuilder }} {
{{ $mutation }}.Where(ps...)
return {{ $receiver }}
func (u *{{ $onebuilder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $onebuilder }} {
u.mutation.Where(ps...)
return u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func ({{ $receiver }} *{{ $onebuilder }}) Select(field string, fields ...string) *{{ $onebuilder }} {
{{ $receiver }}.fields = append([]string{field}, fields...)
return {{ $receiver }}
func (u *{{ $onebuilder }}) Select(field string, fields ...string) *{{ $onebuilder }} {
u.fields = append([]string{field}, fields...)
return u
}
// Save executes the query and returns the updated {{ $.Name }} entity.
func ({{ $receiver }} *{{ $onebuilder }} ) Save(ctx context.Context) (*{{ $.Name }}, error) {
func (u *{{ $onebuilder }} ) Save(ctx context.Context) (*{{ $.Name }}, error) {
{{- if $.HasUpdateDefault }}
{{- if $runtimeRequired }}
if err := {{ $receiver }}.defaults(); err != nil {
if err := u.defaults(); err != nil {
return nil, err
}
{{- else }}
{{ $receiver }}.defaults()
u.defaults()
{{- end }}
{{- end }}
return withHooks(ctx, {{ $receiver }}.{{ $.Storage }}Save, {{ $mutation }}, {{ $receiver }}.hooks)
return withHooks(ctx, u.{{ $.Storage }}Save, u.mutation, u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func ({{ $receiver }} *{{ $onebuilder }}) SaveX(ctx context.Context) *{{ $.Name }} {
node, err := {{ $receiver }}.Save(ctx)
func (u *{{ $onebuilder }}) SaveX(ctx context.Context) *{{ $.Name }} {
node, err := u.Save(ctx)
if err != nil {
panic(err)
}
@@ -142,23 +138,23 @@ func ({{ $receiver }} *{{ $onebuilder }}) SaveX(ctx context.Context) *{{ $.Name
}
// Exec executes the query on the entity.
func ({{ $receiver }} *{{ $onebuilder }}) Exec(ctx context.Context) error {
_, err := {{ $receiver }}.Save(ctx)
func (u *{{ $onebuilder }}) Exec(ctx context.Context) error {
_, err := u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func ({{ $receiver }} *{{ $onebuilder }}) ExecX(ctx context.Context) {
if err := {{ $receiver }}.Exec(ctx); err != nil {
func (u *{{ $onebuilder }}) ExecX(ctx context.Context) {
if err := u.Exec(ctx); err != nil {
panic(err)
}
}
{{ with extend $ "Receiver" $receiver "Package" $pkg "Builder" $onebuilder }}
{{ with extend $ "Receiver" "u" "Package" $pkg "Builder" $onebuilder }}
{{ template "update/checks" . }}
{{ end }}
{{ with extend $ "Receiver" $receiver "Builder" $onebuilder "Package" $pkg }}
{{ with extend $ "Receiver" "u" "Builder" $onebuilder "Package" $pkg }}
{{ $tmpl := printf "dialect/%s/update" $.Storage }}
{{ xtemplate $tmpl . }}
{{ end }}
@@ -186,8 +182,6 @@ func ({{ $receiver }} *{{ $onebuilder }}) ExecX(ctx context.Context) {
{{/* shared edges removal between the two updaters */}}
{{ define "update/edges" }}
{{ $builder := pascal .Scope.Builder }}
{{ $receiver := .Scope.Receiver }}
{{ $mutation := print $receiver ".mutation" }}
{{ range $e := $.EdgesWithID }}
{{ if $e.Immutable }}
@@ -196,29 +190,25 @@ func ({{ $receiver }} *{{ $onebuilder }}) ExecX(ctx context.Context) {
{{ end }}
{{ $func := $e.MutationClear }}
// {{ $func }} clears {{ if $e.Unique }}the "{{ $e.Name }}" edge{{ else }}all "{{ $e.Name }}" edges{{ end }} to the {{ $e.Type.Name }} entity.
func ({{ $receiver }} *{{ $builder }}) {{ $func }}() *{{ $builder }} {
{{ $mutation }}.{{ $func }}()
return {{ $receiver }}
func (u *{{ $builder }}) {{ $func }}() *{{ $builder }} {
u.mutation.{{ $func }}()
return u
}
{{ if not $e.Unique }}
{{ $p := lower (printf "%.1s" $e.Type.Name) }}
{{/* if the name of the parameter conflicts with the receiver name */}}
{{ if eq $p $receiver }} {{ $p = "v" }} {{ end }}
{{ $idsFunc := print "Remove" (singular $e.Name | pascal) "IDs" }}
// {{ $idsFunc }} removes the "{{ $e.Name }}" edge to {{ $e.Type.Name }} entities by IDs.
func ({{ $receiver }} *{{ $builder }}) {{ $idsFunc }}(ids ...{{ $e.Type.ID.Type }}) *{{ $builder }} {
{{ $mutation }}.{{ $idsFunc }}(ids...)
return {{ $receiver }}
func (u *{{ $builder }}) {{ $idsFunc }}(ids ...{{ $e.Type.ID.Type }}) *{{ $builder }} {
u.mutation.{{ $idsFunc }}(ids...)
return u
}
{{ $func := print "Remove" $e.StructField }}
// {{ $func }} removes "{{ $e.Name }}" edges to {{ $e.Type.Name }} entities.
func ({{ $receiver }} *{{ $builder }}) {{ $func }}({{ $p }} ...*{{ $e.Type.Name }}) *{{ $builder }} {
ids := make([]{{ $e.Type.ID.Type }}, len({{ $p }}))
{{ $i := "i" }}{{ if eq $i $p }}{{ $i = "j" }}{{ end -}}
for {{ $i }} := range {{ $p }} {
ids[{{ $i }}] = {{ $p }}[{{ $i }}].ID
func (u *{{ $builder }}) {{ $func }}(v ...*{{ $e.Type.Name }}) *{{ $builder }} {
ids := make([]{{ $e.Type.ID.Type }}, len(v))
for i := range v {
ids[i] = v[i].ID
}
return {{ $receiver }}.{{ $idsFunc }}(ids...)
return u.{{ $idsFunc }}(ids...)
}
{{ end }}
{{ end }}
@@ -227,24 +217,22 @@ func ({{ $receiver }} *{{ $onebuilder }}) ExecX(ctx context.Context) {
{{/* shared template for the 2 update builders */}}
{{ define "update/checks" }}
{{ $pkg := .Scope.Package }}
{{ $receiver := .Scope.Receiver }}
{{ $builder := pascal .Scope.Builder }}
{{ $mutation := print $receiver ".mutation" }}
{{ $runtimeRequired := or $.NumHooks $.NumPolicy }}
{{ if $.HasUpdateDefault }}
// defaults sets the default values of the builder before save.
func ({{ $receiver }} *{{ $builder }}) defaults() {{ if $runtimeRequired }}error{{ end }}{
func (u *{{ $builder }}) defaults() {{ if $runtimeRequired }}error{{ end }}{
{{- range $f := $.Fields }}
{{- if $f.UpdateDefault }}
if _, ok := {{ $mutation }}.{{ $f.MutationGet }}(); !ok {{ if $f.Optional }} && !{{ $mutation }}.{{ $f.StructField }}Cleared() {{ end }} {
if _, ok := u.mutation.{{ $f.MutationGet }}(); !ok {{ if $f.Optional }} && !u.mutation.{{ $f.StructField }}Cleared() {{ end }} {
{{- if $runtimeRequired }}
if {{ $.Package }}.{{ $f.UpdateDefaultName }} == nil {
return fmt.Errorf("{{ $pkg }}: uninitialized {{ $.Package }}.{{ $f.UpdateDefaultName }} (forgotten import {{ $pkg }}/runtime?)")
}
{{- end }}
v := {{ $.Package }}.{{ $f.UpdateDefaultName }}()
{{ $mutation }}.{{ $f.MutationSet }}(v)
u.mutation.{{ $f.MutationSet }}(v)
}
{{- end }}
{{- end }}
@@ -256,11 +244,11 @@ func ({{ $receiver }} *{{ $onebuilder }}) ExecX(ctx context.Context) {
{{ if $.HasUpdateCheckers }}
// check runs all checks and user-defined validators on the builder.
func ({{ $receiver }} *{{ $builder }}) check() error {
func (u *{{ $builder }}) check() error {
{{- range $f := $.Fields }}
{{- $isValidator := and ($f.HasGoType) ($f.Type.Validator) }}
{{- with and (not $f.Immutable) (or $f.Validators $f.IsEnum $isValidator) }}
if v, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok {
if v, ok := u.mutation.{{ $f.MutationGet }}(); ok {
if err := {{ if or $f.Validators $f.IsEnum }}{{ $.Package }}.{{ $f.Validator }}({{ $f.BasicType "v" }}){{ else }}v.Validate(){{ end }}; err != nil {
return &ValidationError{Name: "{{ $f.Name }}", err: fmt.Errorf(`{{ $pkg }}: validator failed for field "{{ $.Name }}.{{ $f.Name }}": %w`, err)}
}
@@ -269,7 +257,7 @@ func ({{ $receiver }} *{{ $onebuilder }}) ExecX(ctx context.Context) {
{{- end }}
{{- range $e := $.Edges }}
{{- if and $e.Unique (not $e.Optional) }}
if {{ $mutation }}.{{ $e.StructField }}Cleared() && len({{ $mutation }}.{{ $e.StructField }}IDs()) > 0 {
if u.mutation.{{ $e.StructField }}Cleared() && len(u.mutation.{{ $e.StructField }}IDs()) > 0 {
return errors.New(`{{ $pkg }}: clearing a required unique edge "{{ $.Name }}.{{ $e.Name }}"`)
}
{{- end }}

View File

@@ -284,8 +284,6 @@ type {{ $client }} struct {
config
}
{{ $rec := $n.Receiver }}{{ if eq $rec "c" }}{{ $rec = printf "%.2s" $n.Name | lower }}{{ end }}
// New{{ $client }} returns a client for the {{ $n.Name }} from the given config.
func New{{ $client }}(c config) *{{ $client }} {
return &{{ $client }}{config: c}
@@ -340,13 +338,13 @@ func (c *{{ $client }}) Update() *{{ $n.UpdateName }} {
}
// UpdateOne returns an update builder for the given entity.
func (c *{{ $client }}) UpdateOne({{ $rec }} *{{ $n.Name }}) *{{ $n.UpdateOneName }} {
func (c *{{ $client }}) UpdateOne(m *{{ $n.Name }}) *{{ $n.UpdateOneName }} {
{{- if $n.HasOneFieldID }}
mutation := new{{ $n.MutationName }}(c.config, OpUpdateOne, {{ print "with" $n.Name }}({{ $rec }}))
mutation := new{{ $n.MutationName }}(c.config, OpUpdateOne, {{ print "with" $n.Name }}(m))
{{- else }}
mutation := new{{ $n.MutationName }}(c.config, OpUpdateOne)
{{- range $id := $n.EdgeSchema.ID }}
mutation.{{ $id.BuilderField }} = &{{ $rec }}.{{ $id.StructField }}
mutation.{{ $id.BuilderField }} = &m.{{ $id.StructField }}
{{- end }}
{{- end }}
return &{{ $n.UpdateOneName }}{config: c.config, hooks: c.Hooks(), mutation: mutation}
@@ -368,8 +366,8 @@ func (c *{{ $client }}) Delete() *{{ $n.DeleteName }} {
{{ with $n.HasOneFieldID }}
// DeleteOne returns a builder for deleting the given entity.
func (c *{{ $client }}) DeleteOne({{ $rec }} *{{ $n.Name }}) *{{ $n.DeleteOneName }} {
return c.DeleteOneID({{ $rec }}.ID)
func (c *{{ $client }}) DeleteOne(m *{{ $n.Name }}) *{{ $n.DeleteOneName }} {
return c.DeleteOneID(m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
@@ -415,14 +413,13 @@ func (c *{{ $client }}) Query() *{{ $n.QueryName }} {
{{ range $e := $n.Edges }}
{{ $builder := $e.Type.QueryName }}
{{ $arg := $rec }}{{ if eq $arg "id" }}{{ $arg = "node" }}{{ end }}
{{ $func := print "Query" (pascal $e.Name) }}
// Query{{ pascal $e.Name }} queries the {{ $e.Name }} edge of a {{ $n.Name }}.
func (c *{{ $client }}) {{ $func }}({{ $arg }} *{{ $n.Name }}) *{{ $builder }} {
func (c *{{ $client }}) {{ $func }}(m *{{ $n.Name }}) *{{ $builder }} {
{{- if $n.HasOneFieldID }}
query := (&{{ $e.Type.ClientName }}{config: c.config}).Query()
query.path = func(context.Context) (fromV {{ $.Storage.Builder }}, _ error) {
{{- with extend $n "Receiver" $arg "Edge" $e "Ident" "fromV" }}
{{- with extend $n "Receiver" "m" "Edge" $e "Ident" "fromV" }}
{{ $tmpl := printf "dialect/%s/query/from" $.Storage }}
{{- xtemplate $tmpl . -}}
{{- end -}}
@@ -432,7 +429,7 @@ func (c *{{ $client }}) {{ $func }}({{ $arg }} *{{ $n.Name }}) *{{ $builder }} {
{{- else }}
{{- /* For edge schema, we use the predicate-based approach. */}}
return c.Query().
Where({{ range $id := $n.EdgeSchema.ID }}{{ $n.Package }}.{{ $id.StructField }}({{ $arg }}.{{ $id.StructField }}),{{ end }}).
Where({{ range $id := $n.EdgeSchema.ID }}{{ $n.Package }}.{{ $id.StructField }}(m.{{ $id.StructField }}),{{ end }}).
{{ $func }}()
{{- end }}
}

View File

@@ -8,33 +8,31 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/gremlin/create" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := $.Scope.Receiver }}
{{ $mutation := print $receiver ".mutation" }}
func ({{ $receiver }} *{{ $builder }}) gremlinSave(ctx context.Context) (*{{ $.Name }}, error) {
if err := {{ $receiver }}.check(); err != nil {
func (c *{{ $builder }}) gremlinSave(ctx context.Context) (*{{ $.Name }}, error) {
if err := c.check(); err != nil {
return nil, err
}
res := &gremlin.Response{}
query, bindings := {{ $receiver }}.gremlin().Query()
if err := {{ $receiver }}.driver.Exec(ctx, query, bindings, res); err != nil {
query, bindings := c.gremlin().Query()
if err := c.driver.Exec(ctx, query, bindings, res); err != nil {
return nil, err
}
if err, ok := isConstantError(res); ok {
return nil, err
}
rnode := &{{ $.Name }}{config: {{ $receiver }}.config}
rnode := &{{ $.Name }}{config: c.config}
if err := rnode.FromResponse(res); err != nil {
return nil, err
}
{{- if $.HasOneFieldID }}
{{ $mutation }}.{{ $.ID.BuilderField }} = &rnode.{{ $.ID.StructField }}
{{ $mutation }}.done = true
c.mutation.{{ $.ID.BuilderField }} = &rnode.{{ $.ID.StructField }}
c.mutation.done = true
{{- end }}
return rnode, nil
}
func ({{ $receiver }} *{{ $builder }}) gremlin() *dsl.Traversal {
func (c *{{ $builder }}) gremlin() *dsl.Traversal {
{{- with .NumConstraint }}
type constraint struct {
pred *dsl.Traversal // constraint predicate.
@@ -44,12 +42,12 @@ func ({{ $receiver }} *{{ $builder }}) gremlin() *dsl.Traversal {
{{- end }}
v := g.AddV({{ $.Package }}.Label)
{{- if $.ID.UserDefined }}
if id, ok := {{ $mutation }}.{{ $.ID.MutationGet }}(); ok {
if id, ok := c.mutation.{{ $.ID.MutationGet }}(); ok {
v.Property(dsl.ID, id)
}
{{- end }}
{{- range $f := $.MutationFields }}
if value, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok {
if value, ok := c.mutation.{{ $f.MutationGet }}(); ok {
{{- if $f.Unique }}
constraints = append(constraints, &constraint{
pred: g.V().Has({{ $.Package }}.Label, {{ $.Package }}.{{ $f.Constant }}, value).Count(),
@@ -62,7 +60,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlin() *dsl.Traversal {
{{- range $e := $.Edges }}
{{- $direction := "In" }}
{{- $name := printf "%s.%s" $.Package $e.LabelConstant }}
for _, id := range {{ $mutation }}.{{ $e.StructField }}IDs() {
for _, id := range c.mutation.{{ $e.StructField }}IDs() {
{{- if $e.IsInverse }}
{{- $direction = "Out" }}
{{- $name = printf "%s.%s" $e.Type.Package $e.LabelConstant }}

View File

@@ -8,23 +8,20 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/gremlin/delete" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := $.Scope.Receiver }}
{{ $mutation := print $receiver ".mutation" }}
func ({{ $receiver}} *{{ $builder }}) gremlinExec(ctx context.Context) (int, error) {
func (d *{{ $builder }}) gremlinExec(ctx context.Context) (int, error) {
res := &gremlin.Response{}
query, bindings := {{ $receiver }}.gremlin().Query()
if err := {{ $receiver }}.driver.Exec(ctx, query, bindings, res); err != nil {
query, bindings := d.gremlin().Query()
if err := d.driver.Exec(ctx, query, bindings, res); err != nil {
return 0, err
}
{{ $mutation }}.done = true
d.mutation.done = true
return res.ReadInt()
}
func ({{ $receiver }} *{{ $builder }}) gremlin() *dsl.Traversal {
func (d *{{ $builder }}) gremlin() *dsl.Traversal {
t := g.V().HasLabel({{ $.Package }}.Label)
for _, p := range {{ $mutation }}.predicates {
for _, p := range d.mutation.predicates {
p(t)
}
return t.SideEffect(__.Drop()).Count()

View File

@@ -8,14 +8,13 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/gremlin/query" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
func ({{ $receiver }} *{{ $builder }}) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
func (q *{{ $builder }}) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
res := &gremlin.Response{}
traversal := {{ $receiver }}.gremlinQuery(ctx)
if len({{ $receiver }}.ctx.Fields) > 0 {
fields := make([]any, len({{ $receiver }}.ctx.Fields))
for i, f := range {{ $receiver }}.ctx.Fields {
traversal := q.gremlinQuery(ctx)
if len(q.ctx.Fields) > 0 {
fields := make([]any, len(q.ctx.Fields))
for i, f := range q.ctx.Fields {
fields[i] = f
}
traversal.ValueMap(fields...)
@@ -23,43 +22,43 @@ func ({{ $receiver }} *{{ $builder }}) gremlinAll(ctx context.Context, hooks ...
traversal.ValueMap(true)
}
query, bindings := traversal.Query()
if err := {{ $receiver }}.driver.Exec(ctx, query, bindings, res); err != nil {
if err := q.driver.Exec(ctx, query, bindings, res); err != nil {
return nil, err
}
var {{ plural $.Receiver }} {{ plural $.Name }}
if err := {{ plural $.Receiver }}.FromResponse(res); err != nil {
var results {{ plural $.Name }}
if err := results.FromResponse(res); err != nil {
return nil, err
}
for i := range {{ plural $.Receiver }} {
{{ plural $.Receiver }}[i].config = {{ $receiver }}.config
for i := range results {
results[i].config = q.config
}
return {{ plural $.Receiver }}, nil
return results, nil
}
func ({{ $receiver }} *{{ $builder }}) gremlinCount(ctx context.Context) (int, error) {
func (q *{{ $builder }}) gremlinCount(ctx context.Context) (int, error) {
res := &gremlin.Response{}
query, bindings := {{ $receiver }}.gremlinQuery(ctx).Count().Query()
if err := {{ $receiver }}.driver.Exec(ctx, query, bindings, res); err != nil {
query, bindings := q.gremlinQuery(ctx).Count().Query()
if err := q.driver.Exec(ctx, query, bindings, res); err != nil {
return 0, err
}
return res.ReadInt()
}
func ({{ $receiver }} *{{ $builder }}) gremlinQuery(context.Context) *dsl.Traversal {
func (q *{{ $builder }}) gremlinQuery(context.Context) *dsl.Traversal {
v := g.V().HasLabel({{ $.Package }}.Label)
if {{ $receiver }}.gremlin != nil {
v = {{ $receiver }}.gremlin.Clone()
if q.gremlin != nil {
v = q.gremlin.Clone()
}
for _, p := range {{ $receiver }}.predicates {
for _, p := range q.predicates {
p(v)
}
if len({{ $receiver }}.order) > 0 {
if len(q.order) > 0 {
v.Order()
for _, p := range {{ $receiver }}.order {
for _, p := range q.order {
p(v)
}
}
switch limit, offset := {{ $receiver }}.ctx.Limit, {{ $receiver }}.ctx.Offset; {
switch limit, offset := q.ctx.Limit, q.ctx.Offset; {
case limit != nil && offset != nil:
v.Range(*offset, *offset + *limit)
case offset != nil:
@@ -67,7 +66,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlinQuery(context.Context) *dsl.Traver
case limit != nil:
v.Limit(*limit)
}
if unique := {{ $receiver }}.ctx.Unique; unique == nil || *unique {
if unique := q.ctx.Unique; unique == nil || *unique {
v.Dedup()
}
return v

View File

@@ -9,47 +9,44 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/gremlin/update" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := $.Scope.Receiver }}
{{ $mutation := print $receiver ".mutation" }}
{{ $one := hasSuffix $builder "One" }}
{{ $zero := 0 }}{{ if $one }}{{ $zero = "nil" }}{{ end }}
func ({{ $receiver }} *{{ $builder }}) gremlinSave(ctx context.Context) ({{- if $one }}*{{ $.Name }}{{ else }}int{{ end }}, error) {
func (u *{{ $builder }}) gremlinSave(ctx context.Context) ({{- if $one }}*{{ $.Name }}{{ else }}int{{ end }}, error) {
{{- if $.HasUpdateCheckers }}
if err := {{ $receiver }}.check(); err != nil {
if err := u.check(); err != nil {
return {{ $zero }}, err
}
{{- end }}
res := &gremlin.Response{}
{{- if $one }}
id, ok := {{ $mutation }}.{{ $.ID.MutationGet }}()
id, ok := u.mutation.{{ $.ID.MutationGet }}()
if !ok {
return {{ $zero }}, &ValidationError{Name: "{{ $.ID.Name }}", err: errors.New(`{{ $pkg }}: missing "{{ $.Name }}.{{ $.ID.Name }}" for update`)}
}
query, bindings := {{ $receiver }}.gremlin(id).Query()
query, bindings := u.gremlin(id).Query()
{{- else }}
query, bindings := {{ $receiver }}.gremlin().Query()
query, bindings := u.gremlin().Query()
{{- end }}
if err := {{ $receiver }}.driver.Exec(ctx, query, bindings, res); err != nil {
if err := u.driver.Exec(ctx, query, bindings, res); err != nil {
return {{ $zero }}, err
}
if err, ok := isConstantError(res); ok {
return {{ $zero }}, err
}
{{ $mutation }}.done = true
u.mutation.done = true
{{- if $one }}
{{- $r := $.Receiver }}
{{ $r }} := &{{ $.Name }}{config: {{ $receiver }}.config}
if err := {{ $r }}.FromResponse(res); err != nil {
m := &{{ $.Name }}{config: u.config}
if err := m.FromResponse(res); err != nil {
return nil, err
}
return {{ $r }}, nil
return m, nil
{{- else }}
return res.ReadInt()
{{- end }}
}
func ({{ $receiver }} *{{ $builder }}) gremlin({{ if $one }}id {{ $.ID.Type }}{{ end }}) *dsl.Traversal {
func (u *{{ $builder }}) gremlin({{ if $one }}id {{ $.ID.Type }}{{ end }}) *dsl.Traversal {
{{- with .NumConstraint }}
type constraint struct {
pred *dsl.Traversal // constraint predicate.
@@ -63,7 +60,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlin({{ if $one }}id {{ $.ID.Type }}{{
{{- /* general update for N vertices */}}
{{- else }}
v := g.V().HasLabel({{ $.Package }}.Label)
for _, p := range {{ $mutation }}.predicates {
for _, p := range u.mutation.predicates {
p(v)
}
{{- end }}
@@ -76,7 +73,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlin({{ if $one }}id {{ $.ID.Type }}{{
)
{{- range $f := $.MutationFields }}
{{- if or (not $f.Immutable) $f.UpdateDefault }}
if value, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok {
if value, ok := u.mutation.{{ $f.MutationGet }}(); ok {
{{- if $f.Unique }}
constraints = append(constraints, &constraint{
pred: g.V().Has({{ $.Package }}.Label, {{ $.Package }}.{{ $f.Constant }}, value).Count(),
@@ -86,7 +83,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlin({{ if $one }}id {{ $.ID.Type }}{{
v.Property(dsl.Single, {{ $.Package }}.{{ $f.Constant }}, value)
}
{{- if $f.SupportsMutationAdd }}
if value, ok := {{ $mutation }}.Added{{ $f.StructField }}(); ok {
if value, ok := u.mutation.Added{{ $f.StructField }}(); ok {
{{- if $f.Unique }}
addValue := rv.Clone().Union(__.Values({{ $.Package }}.{{ $f.Constant }}), __.Constant(value)).Sum().Next()
constraints = append(constraints, &constraint{
@@ -104,7 +101,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlin({{ if $one }}id {{ $.ID.Type }}{{
var properties []any
{{- range $f := $.MutationFields }}
{{- if $f.Optional }}
if {{ $mutation }}.{{ $f.StructField }}Cleared() {
if u.mutation.{{ $f.StructField }}Cleared() {
properties = append(properties, {{ $.Package }}.{{ $f.Constant }})
}
{{- end }}
@@ -122,9 +119,9 @@ func ({{ $receiver }} *{{ $builder }}) gremlin({{ if $one }}id {{ $.ID.Type }}{{
{{- end }}
{{- /* remove edges */}}
{{- if $e.Unique }}
if {{ $mutation }}.{{ $e.StructField }}Cleared() {
if u.mutation.{{ $e.StructField }}Cleared() {
{{- else }}
for _, id := range {{ $mutation }}.Removed{{ $e.StructField }}IDs() {
for _, id := range u.mutation.Removed{{ $e.StructField }}IDs() {
{{- end }}
{{- if $e.Bidi }}
tr := rv.Clone().BothE({{ $name }}){{ if not $e.Unique }}.Where(__.Or(__.InV().HasID(id), __.OutV().HasID(id))){{ end }}.Drop().Iterate()
@@ -136,7 +133,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlin({{ if $one }}id {{ $.ID.Type }}{{
trs = append(trs, tr)
}
{{- /* update edges */}}
for _, id := range {{ $mutation }}.{{ $e.StructField }}IDs() {
for _, id := range u.mutation.{{ $e.StructField }}IDs() {
{{- if $e.IsInverse }}
v.AddE({{ $name }}).From(g.V(id)).InV()
{{- else }}
@@ -162,10 +159,10 @@ func ({{ $receiver }} *{{ $builder }}) gremlin({{ if $one }}id {{ $.ID.Type }}{{
}
{{- end }}
{{- if $one }}
if len({{ $receiver }}.fields) > 0 {
fields := make([]any, 0, len({{ $receiver }}.fields)+1)
if len(u.fields) > 0 {
fields := make([]any, 0, len(u.fields)+1)
fields = append(fields, true)
for _, f := range {{ $receiver }}.fields {
for _, f := range u.fields {
fields = append(fields, f)
}
v.ValueMap(fields...)

View File

@@ -8,20 +8,18 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/sql/create" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := $.Scope.Receiver }}
{{ $mutation := print $receiver ".mutation" }}
func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (*{{ $.Name }}, error) {
if err := {{ $receiver }}.check(); err != nil {
func (c *{{ $builder }}) sqlSave(ctx context.Context) (*{{ $.Name }}, error) {
if err := c.check(); err != nil {
return nil, err
}
_node, _spec {{ if $.HasValueScanner }}, err {{ end }} := {{ $receiver }}.createSpec()
_node, _spec {{ if $.HasValueScanner }}, err {{ end }} := c.createSpec()
{{- if $.HasValueScanner }}
if err != nil {
return nil, err
}
{{- end }}
if err := sqlgraph.CreateNode(ctx, {{ $receiver }}.driver, _spec); err != nil {
if err := sqlgraph.CreateNode(ctx, c.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
@@ -58,15 +56,15 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (*{{ $.Name
{{- end }}
{{- end }}
{{- if $.HasOneFieldID }}
{{ $mutation }}.{{ $.ID.BuilderField }} = &_node.{{ $.ID.StructField }}
{{ $mutation }}.done = true
c.mutation.{{ $.ID.BuilderField }} = &_node.{{ $.ID.StructField }}
c.mutation.done = true
{{- end }}
return _node, nil
}
func ({{ $receiver }} *{{ $builder }}) createSpec() (*{{ $.Name }}, *sqlgraph.CreateSpec{{ if $.HasValueScanner }}, error{{ end }}) {
func (c *{{ $builder }}) createSpec() (*{{ $.Name }}, *sqlgraph.CreateSpec{{ if $.HasValueScanner }}, error{{ end }}) {
var (
_node = &{{ $.Name }}{config: {{ $receiver }}.config}
_node = &{{ $.Name }}{config: c.config}
_spec = sqlgraph.NewCreateSpec({{ $.Package }}.Table, {{ if $.HasOneFieldID }}sqlgraph.NewFieldSpec({{ $.Package }}.{{ $.ID.Constant }}, field.{{ $.ID.Type.ConstName }}){{ else }}nil{{ end }})
)
{{- /* Allow mutating the sqlgraph.CreateSpec by ent extensions or user templates.*/}}
@@ -76,13 +74,13 @@ func ({{ $receiver }} *{{ $builder }}) createSpec() (*{{ $.Name }}, *sqlgraph.Cr
{{- end }}
{{- end }}
{{- if and (not $.HasCompositeID) $.ID.UserDefined }}
if id, ok := {{ $mutation }}.{{ $.ID.MutationGet }}(); ok {
if id, ok := c.mutation.{{ $.ID.MutationGet }}(); ok {
_node.ID = id
_spec.ID.Value = {{ if and $.ID.Type.ValueScanner (not $.ID.Type.RType.IsPtr) }}&{{ end }}id
}
{{- end }}
{{- range $f := $.MutationFields }}
if value, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok {
if value, ok := c.mutation.{{ $f.MutationGet }}(); ok {
{{- if $f.HasValueScanner }}
vv, err := {{ $f.ValueFunc }}(value)
if err != nil {
@@ -96,7 +94,7 @@ func ({{ $receiver }} *{{ $builder }}) createSpec() (*{{ $.Name }}, *sqlgraph.Cr
}
{{- end }}
{{- range $e := $.EdgesWithID }}
if nodes := {{ $mutation }}.{{ $e.StructField }}IDs(); len(nodes) > 0 {
if nodes := c.mutation.{{ $e.StructField }}IDs(); len(nodes) > 0 {
{{- with extend $ "Edge" $e "Nodes" true "Zero" "nil" }}
{{ template "dialect/sql/defedge" . }}{{/* defined in sql/update.tmpl */}}
{{- end }}
@@ -139,20 +137,19 @@ func ({{ $receiver }} *{{ $builder }}) createSpec() (*{{ $.Name }}, *sqlgraph.Cr
{{ define "dialect/sql/create_bulk" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := $.Scope.Receiver }}
// Save creates the {{ $.Name }} entities in the database.
func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) ([]*{{ $.Name }}, error) {
func (c *{{ $builder }}) Save(ctx context.Context) ([]*{{ $.Name }}, error) {
{{- /* Initialization error was set by MapCreateBulk. */}}
if {{ $receiver }}.err != nil {
return nil, {{ $receiver }}.err
if c.err != nil {
return nil, c.err
}
specs := make([]*sqlgraph.CreateSpec, len({{ $receiver }}.builders))
nodes := make([]*{{ $.Name }}, len({{ $receiver }}.builders))
mutators := make([]Mutator, len({{ $receiver }}.builders))
for i := range {{ $receiver }}.builders {
specs := make([]*sqlgraph.CreateSpec, len(c.builders))
nodes := make([]*{{ $.Name }}, len(c.builders))
mutators := make([]Mutator, len(c.builders))
for i := range c.builders {
func(i int, root context.Context) {
builder := {{ $receiver }}.builders[i]
builder := c.builders[i]
{{- if $.HasDefault }}
builder.defaults()
{{- end }}
@@ -173,7 +170,7 @@ func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) ([]*{{ $.Name }
}
{{- end }}
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, {{ $receiver }}.builders[i+1].mutation)
_, err = mutators[i+1].Mutate(root, c.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
{{- /* Allow mutating the sqlgraph.BatchCreateSpec by ent extensions or user templates.*/}}
@@ -183,7 +180,7 @@ func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) ([]*{{ $.Name }
{{- end }}
{{- end }}
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, {{ $receiver }}.driver, spec); err != nil {
if err = sqlgraph.BatchCreate(ctx, c.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
@@ -219,7 +216,7 @@ func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) ([]*{{ $.Name }
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, {{ $receiver }}.builders[0].mutation); err != nil {
if _, err := mutators[0].Mutate(ctx, c.builders[0].mutation); err != nil {
return nil, err
}
}
@@ -227,8 +224,8 @@ func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) ([]*{{ $.Name }
}
// SaveX is like Save, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) SaveX(ctx context.Context) []*{{ $.Name }} {
v, err := {{ $receiver }}.Save(ctx)
func (c *{{ $builder }}) SaveX(ctx context.Context) []*{{ $.Name }} {
v, err := c.Save(ctx)
if err != nil {
panic(err)
}
@@ -236,14 +233,14 @@ func ({{ $receiver }} *{{ $builder }}) SaveX(ctx context.Context) []*{{ $.Name }
}
// Exec executes the query.
func ({{ $receiver }} *{{ $builder }}) Exec(ctx context.Context) error {
_, err := {{ $receiver }}.Save(ctx)
func (c *{{ $builder }}) Exec(ctx context.Context) error {
_, err := c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) ExecX(ctx context.Context) {
if err := {{ $receiver }}.Exec(ctx); err != nil {
func (c *{{ $builder }}) ExecX(ctx context.Context) {
if err := c.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -7,7 +7,6 @@ in the LICENSE file in the root directory of this source tree.
{{/* gotype: entgo.io/ent/entc/gen.typeScope */}}
{{ define "dialect/sql/decode/one" }}
{{ $receiver := $.Receiver }}
{{ $ctypes := dict }}
{{ if $.HasOneFieldID }}
@@ -56,30 +55,29 @@ func (*{{ $.Name }}) scanValues(columns []string) ([]any, error) {
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the {{ $.Name }} fields.
func ({{ $receiver }} *{{ $.Name }}) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
func (m *{{ $.Name }}) assignValues(columns []string, values []any) error {
if v, c := len(values), len(columns); v < c {
return fmt.Errorf("mismatch number of scan values: %d != %d", v, c)
}
{{- $idx := "i" }}{{ if eq $idx $receiver }}{{ $idx = "j" }}{{ end }}
for {{ $idx }} := range columns {
switch columns[{{ $idx }}] {
for i := range columns {
switch columns[i] {
{{- if $.HasOneFieldID }}
case {{ $.Package }}.{{ $.ID.Constant }}:
{{- if or $.ID.IsString $.ID.IsBytes $.ID.HasGoType }}
{{- with extend $ "Idx" $idx "Field" $.ID "Rec" $receiver }}
{{- with extend $ "Idx" "i" "Field" $.ID "Rec" "m" }}
{{ template "dialect/sql/decode/field" . }}
{{- end }}
{{- else }}
value, ok := values[{{ $idx }}].(*sql.NullInt64)
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
{{ $receiver }}.ID = {{ $.ID.Type }}(value.Int64)
m.ID = {{ $.ID.Type }}(value.Int64)
{{- end }}
{{- end }}
{{- range $f := $.Fields }}
case {{ $.Package }}.{{ $f.Constant }}:
{{- with extend $ "Idx" $idx "Field" $f "Rec" $receiver }}
{{- with extend $ "Idx" "i" "Field" $f "Rec" "m" }}
{{ template "dialect/sql/decode/field" . }}
{{- end }}
{{- end }}
@@ -87,21 +85,21 @@ func ({{ $receiver }} *{{ $.Name }}) assignValues(columns []string, values []any
{{- $f := $fk.Field }}
case {{ if $fk.UserDefined }}{{ $.Package }}.{{ $.ID.Constant }}{{ else }}{{ $.Package }}.ForeignKeys[{{ $i }}]{{ end }}:
{{- if or $fk.UserDefined (and $f.UserDefined (or $f.IsString $f.IsBytes $f.HasGoType)) }}
{{- with extend $ "Idx" $idx "Field" $f "Rec" $receiver "StructField" $fk.StructField }}
{{- with extend $ "Idx" "i" "Field" $f "Rec" "m" "StructField" $fk.StructField }}
{{ template "dialect/sql/decode/field" . }}
{{- end }}
{{- else }}
if value, ok := values[{{ $idx }}].(*sql.NullInt64); !ok {
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for edge-field {{ $f.Name}}", value)
} else if value.Valid {
{{ $receiver }}.{{ $fk.StructField }} = new({{ $f.Type }})
{{ if and $f.Nillable (not $f.Type.Nillable) }}*{{ end }}{{ $receiver }}.{{ $fk.StructField }} = {{ $f.Type }}(value.Int64)
m.{{ $fk.StructField }} = new({{ $f.Type }})
{{ if and $f.Nillable (not $f.Type.Nillable) }}*{{ end }}m.{{ $fk.StructField }} = {{ $f.Type }}(value.Int64)
}
{{- end }}
{{- end }}
default:
{{- /* In case of no match, allow getting this value by its name. */}}
{{ $receiver }}.selectValues.Set(columns[{{ $idx }}], values[{{ $idx }}])
m.selectValues.Set(columns[i], values[i])
}
}
return nil
@@ -109,8 +107,8 @@ func ({{ $receiver }} *{{ $.Name }}) assignValues(columns []string, values []any
// {{ $.ValueName }} returns the ent.Value that was dynamically selected and assigned to the {{ $.Name }}.
// This includes values selected through modifiers, order, etc.
func ({{ $receiver }} *{{ $.Name }}) {{ $.ValueName }}(name string) (ent.Value, error) {
return {{ $receiver }}.selectValues.Get(name)
func (m *{{ $.Name }}) {{ $.ValueName }}(name string) (ent.Value, error) {
return m.selectValues.Get(name)
}
{{ end }}

View File

@@ -8,10 +8,8 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/sql/delete" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := $.Scope.Receiver }}
{{ $mutation := print $receiver ".mutation" }}
func ({{ $receiver}} *{{ $builder }}) sqlExec(ctx context.Context) (int, error) {
func (d *{{ $builder }}) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec({{ $.Package }}.Table, {{ if $.HasOneFieldID }}sqlgraph.NewFieldSpec({{ $.Package }}.{{ $.ID.Constant }}, field.{{ $.ID.Type.ConstName }}){{ else }}nil{{ end }})
{{- /* Allow mutating the sqlgraph.DeleteSpec by ent extensions or user templates.*/}}
{{- with $tmpls := matchTemplate "dialect/sql/delete/spec/*" }}
@@ -19,18 +17,18 @@ func ({{ $receiver}} *{{ $builder }}) sqlExec(ctx context.Context) (int, error)
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
if ps := {{ $mutation }}.predicates; len(ps) > 0 {
if ps := d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, {{ $receiver}}.driver, _spec)
affected, err := sqlgraph.DeleteNodes(ctx, d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
{{ $mutation }}.done = true
d.mutation.done = true
return affected, err
}

View File

@@ -85,30 +85,29 @@ type predicateAdder interface {
{{ range $i, $n := $.Nodes }}
{{ $builder := $n.QueryName }}
{{ $receiver := receiver $builder }}
{{ $mutation := $n.MutationName }}
{{ $filter := print $n.FilterName }}
// addPredicate implements the predicateAdder interface.
func ({{ $receiver }} *{{ $builder }}) addPredicate(pred func(s *sql.Selector)) {
{{ $receiver }}.predicates = append({{ $receiver }}.predicates, pred)
func (q *{{ $builder }}) addPredicate(pred func(s *sql.Selector)) {
q.predicates = append(q.predicates, pred)
}
// Filter returns a Filter implementation to apply filters on the {{ $builder }} builder.
func ({{ $receiver }} *{{ $builder }}) Filter() *{{ $filter }} {
return &{{ $filter }}{config: {{ $receiver }}.config, predicateAdder: {{ $receiver}} }
func (q *{{ $builder }}) Filter() *{{ $filter }} {
return &{{ $filter }}{config: q.config, predicateAdder: q }
}
{{- if not $n.IsView }}
// addPredicate implements the predicateAdder interface.
func (m *{{ $mutation }}) addPredicate(pred func(s *sql.Selector)) {
m.predicates = append(m.predicates, pred)
}
// addPredicate implements the predicateAdder interface.
func (m *{{ $mutation }}) addPredicate(pred func(s *sql.Selector)) {
m.predicates = append(m.predicates, pred)
}
// Filter returns an entql.Where implementation to apply filters on the {{ $mutation }} builder.
func (m *{{ $mutation }}) Filter() *{{ $filter }} {
return &{{ $filter }}{config: m.config, predicateAdder: m}
}
// Filter returns an entql.Where implementation to apply filters on the {{ $mutation }} builder.
func (m *{{ $mutation }}) Filter() *{{ $filter }} {
return &{{ $filter }}{config: m.config, predicateAdder: m}
}
{{- end }}
// {{ $filter }} provides a generic filtering capability at runtime for {{ $builder }}.

View File

@@ -18,8 +18,7 @@ in the LICENSE file in the root directory of this source tree.
{{/* Template for adding the "executing" the list of modifiers on the sql.Selector. */}}
{{ define "dialect/sql/query/selector/modify" }}
{{- if or ($.FeatureEnabled "sql/lock") ($.FeatureEnabled "sql/modifier") }}
{{- $receiver := pascal $.Scope.Builder | receiver }}
for _, m := range {{ $receiver }}.modifiers {
for _, m := range q.modifiers {
m(selector)
}
{{- end }}
@@ -28,9 +27,8 @@ in the LICENSE file in the root directory of this source tree.
{{/* Template for passing the modifiers to the sqlgraph.QuerySpec. */}}
{{ define "dialect/sql/query/spec/modify" }}
{{- if or ($.FeatureEnabled "sql/lock") ($.FeatureEnabled "sql/modifier") }}
{{- $receiver := pascal $.Scope.Builder | receiver }}
if len({{ $receiver }}.modifiers) > 0 {
_spec.Modifiers = {{ $receiver }}.modifiers
if len(q.modifiers) > 0 {
_spec.Modifiers = q.modifiers
}
{{- end }}
{{- end -}}
@@ -39,12 +37,11 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/sql/query/additional/modify" }}
{{ if $.FeatureEnabled "sql/modifier" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
{{ $selectBuilder := pascal $.Name | printf "%sSelect" }}
// Modify adds a query modifier for attaching custom logic to queries.
func ({{ $receiver }} *{{ $builder }}) Modify(modifiers ...func(s *sql.Selector)) *{{ $selectBuilder }} {
{{ $receiver }}.modifiers = append({{ $receiver }}.modifiers, modifiers...)
return {{ $receiver }}.Select()
func (q *{{ $builder }}) Modify(modifiers ...func(s *sql.Selector)) *{{ $selectBuilder }} {
q.modifiers = append(q.modifiers, modifiers...)
return q.Select()
}
{{ end }}
{{ end }}
@@ -53,11 +50,10 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/sql/select/additional/modify" }}
{{ if $.FeatureEnabled "sql/modifier" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
// Modify adds a query modifier for attaching custom logic to queries.
func ({{ $receiver }} *{{ $builder }}) Modify(modifiers ...func(s *sql.Selector)) *{{ $builder }} {
{{ $receiver }}.modifiers = append({{ $receiver }}.modifiers, modifiers...)
return {{ $receiver }}
func (q *{{ $builder }}) Modify(modifiers ...func(s *sql.Selector)) *{{ $builder }} {
q.modifiers = append(q.modifiers, modifiers...)
return q
}
{{ end }}
{{ end }}
@@ -74,11 +70,10 @@ in the LICENSE file in the root directory of this source tree.
{{ if $.FeatureEnabled "sql/modifier" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
// Modify adds a statement modifier for attaching custom logic to the UPDATE statement.
func ({{ $receiver }} *{{ $builder }}) Modify(modifiers ...func(u *sql.UpdateBuilder)) *{{ $builder }} {
{{ $receiver }}.modifiers = append({{ $receiver }}.modifiers, modifiers...)
return {{ $receiver }}
func (u *{{ $builder }}) Modify(modifiers ...func(*sql.UpdateBuilder)) *{{ $builder }} {
u.modifiers = append(u.modifiers, modifiers...)
return u
}
{{ end }}
{{ end }}
@@ -86,7 +81,6 @@ in the LICENSE file in the root directory of this source tree.
{{/* Template for passing the modifiers to the sqlgraph.UpdateSpec. */}}
{{ define "dialect/sql/update/spec/modify" }}
{{- if $.FeatureEnabled "sql/modifier" }}
{{- $receiver := pascal $.Scope.Builder | receiver }}
_spec.AddModifiers({{ $receiver }}.modifiers...)
_spec.AddModifiers(u.modifiers...)
{{- end }}
{{- end -}}

View File

@@ -20,32 +20,31 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/sql/model/additional/namedges" }}
{{- if $.FeatureEnabled "namedges" }}
{{ $receiver := $.Receiver }}
{{- range $e := $.Edges }}
{{- if not $e.Unique }}
{{ $func := print "Named" $e.StructField }}
// {{ $func }} returns the {{ $e.StructField }} named value or an error if the edge was not
// loaded in eager-loading with this name.
func ({{ $receiver }} *{{ $.Name }}) Named{{ $e.StructField }}(name string) ([]*{{ $e.Type.Name }}, error) {
if {{ $receiver }}.Edges.named{{ $e.StructField }} == nil {
func (m *{{ $.Name }}) Named{{ $e.StructField }}(name string) ([]*{{ $e.Type.Name }}, error) {
if m.Edges.named{{ $e.StructField }} == nil {
return nil, &NotLoadedError{edge: name}
}
nodes, ok := {{ $receiver }}.Edges.named{{ $e.StructField }}[name]
nodes, ok := m.Edges.named{{ $e.StructField }}[name]
if !ok {
return nil, &NotLoadedError{edge: name}
}
return nodes, nil
}
func ({{ $receiver }} *{{ $.Name }}) appendNamed{{ $e.StructField }}(name string, edges ...*{{ $e.Type.Name }}) {
if {{ $receiver }}.Edges.named{{ $e.StructField }} == nil {
{{ $receiver }}.Edges.named{{ $e.StructField }} = make(map[string][]*{{ $e.Type.Name }})
func (m *{{ $.Name }}) appendNamed{{ $e.StructField }}(name string, edges ...*{{ $e.Type.Name }}) {
if m.Edges.named{{ $e.StructField }} == nil {
m.Edges.named{{ $e.StructField }} = make(map[string][]*{{ $e.Type.Name }})
}
if len(edges) == 0 {
{{- /* Prefer empty array over nil to stay consistent with the standard eager-loading API. */}}
{{ $receiver }}.Edges.named{{ $e.StructField }}[name] = []*{{ $e.Type.Name }}{}
m.Edges.named{{ $e.StructField }}[name] = []*{{ $e.Type.Name }}{}
} else {
{{ $receiver }}.Edges.named{{ $e.StructField }}[name] = append({{ $receiver }}.Edges.named{{ $e.StructField }}[name], edges...)
m.Edges.named{{ $e.StructField }}[name] = append(m.Edges.named{{ $e.StructField }}[name], edges...)
}
}
{{- end }}
@@ -73,16 +72,16 @@ in the LICENSE file in the root directory of this source tree.
{{ $func := print "WithNamed" $e.StructField }}
// {{ $func }} tells the query-builder to eager-load the nodes that are connected to the "{{ $e.Name }}"
// edge with the given name. The optional arguments are used to configure the query builder of the edge.
func ({{ $receiver }} *{{ $builder }}) {{ $func }}(name string, opts ...func(*{{ $ebuilder }})) *{{ $builder }} {
query := (&{{ $e.Type.ClientName }}{config: {{ $receiver }}.config}).Query()
func (m *{{ $builder }}) {{ $func }}(name string, opts ...func(*{{ $ebuilder }})) *{{ $builder }} {
query := (&{{ $e.Type.ClientName }}{config: m.config}).Query()
for _, opt := range opts {
opt(query)
}
if {{ $receiver }}.{{ $e.EagerLoadNamedField }} == nil {
{{ $receiver }}.{{ $e.EagerLoadNamedField }} = make(map[string]*{{ $e.Type.QueryName }})
if m.{{ $e.EagerLoadNamedField }} == nil {
m.{{ $e.EagerLoadNamedField }} = make(map[string]*{{ $e.Type.QueryName }})
}
{{ $receiver }}.{{ $e.EagerLoadNamedField }}[name] = query
return {{ $receiver }}
m.{{ $e.EagerLoadNamedField }}[name] = query
return m
}
{{- end }}
{{- end }}
@@ -93,11 +92,10 @@ in the LICENSE file in the root directory of this source tree.
{{- define "dialect/sql/query/all/nodes/namedges" }}
{{- if $.FeatureEnabled "namedges" }}
{{- $builder := pascal $.Scope.Builder }}
{{- $receiver := receiver $builder }}
{{- range $e := $.Edges }}
{{- if not $e.Unique }}
for name, query := range {{ $receiver }}.{{ $e.EagerLoadNamedField }} {
if err := {{ $receiver }}.load{{ $e.StructField }}(ctx, query, nodes,
for name, query := range q.{{ $e.EagerLoadNamedField }} {
if err := q.load{{ $e.StructField }}(ctx, query, nodes,
func(n *{{ $.Name }}) { n.appendNamed{{ $e.StructField }}(name) },
{{- if and ($.FeatureEnabled "bidiedges") $e.Ref $e.Ref.Unique }}
func(n *{{ $.Name }}, e *{{ $e.Type.Name }}){

View File

@@ -88,7 +88,9 @@ func NewSchemaConfigContext(parent context.Context, config SchemaConfig) context
{{- end }}
{{- define "dialect/sql/query/spec/ctxschemaconfig" }}
{{- template "dialect/sql/spec/ctxschemaconfig" . }}
{{- with extend $ "Receiver" "q" }}
{{- template "dialect/sql/spec/ctxschemaconfig" . }}
{{- end }}
{{- end }}
{{- define "dialect/sql/query/eagerloading/join/schemaconfig" }}
@@ -99,8 +101,7 @@ func NewSchemaConfigContext(parent context.Context, config SchemaConfig) context
{{- define "dialect/sql/defedge/spec/schemaconfig" }}
{{- $e := $.Scope.Edge }}
{{- $builder := pascal $.Scope.Builder }}
{{- $receiver := receiver $builder }}
{{- $receiver := $.Scope.Receiver }}
{{- $ident := "edge" }}{{ with $.Scope.Ident }}{{ $ident = . }}{{ end }}
{{- if $.FeatureEnabled "sql/schemaconfig" }}
{{- $schema := $e.Type.Name }}
@@ -125,8 +126,7 @@ func NewSchemaConfigContext(parent context.Context, config SchemaConfig) context
{{/* A template for injecting the SchemaConfig to the context. Should be executed before other templates. */}}
{{- define "dialect/sql/spec/ctxschemaconfig" -}}
{{- $builder := pascal $.Scope.Builder }}
{{- $receiver := receiver $builder }}
{{- $receiver := $.Scope.Receiver }}
{{- $ident := "_spec.Node" }}{{ with $.Scope.Ident }}{{ $ident = . }}{{ end }}
{{- if $.FeatureEnabled "sql/schemaconfig" }}
{{ $ident }}.Schema = {{ $receiver }}.schemaConfig.{{ $.Name }}
@@ -137,25 +137,25 @@ func NewSchemaConfigContext(parent context.Context, config SchemaConfig) context
{{- end -}}
{{- define "dialect/sql/query/selector/ctxschemaconfig" -}}
{{- $builder := pascal $.Scope.Builder }}
{{- $receiver := receiver $builder }}
{{- if $.FeatureEnabled "sql/schemaconfig" }}
t1.Schema({{ $receiver }}.schemaConfig.{{ $.Name }})
ctx = internal.NewSchemaConfigContext(ctx, {{ $receiver }}.schemaConfig)
t1.Schema(q.schemaConfig.{{ $.Name }})
ctx = internal.NewSchemaConfigContext(ctx, q.schemaConfig)
selector.WithContext(ctx)
{{- end }}
{{- end -}}
{{- define "dialect/sql/query/path/ctxschemaconfig" }}
{{- if $.FeatureEnabled "sql/schemaconfig" }}
schemaConfig := {{ $.Scope.Receiver }}.schemaConfig
{{- $receiver := $.Scope.Receiver }}
schemaConfig := {{ $receiver }}.schemaConfig
{{- template "dialect/sql/query/step/ctxschemaconfig" . }}
{{- end -}}
{{- end -}}
{{- define "dialect/sql/query/from/ctxschemaconfig" }}
{{- if $.FeatureEnabled "sql/schemaconfig" }}
schemaConfig := {{ $.Scope.Receiver }}.schemaConfig
{{- $receiver := $.Scope.Receiver }}
schemaConfig := {{ $receiver }}.schemaConfig
{{- template "dialect/sql/query/step/ctxschemaconfig" . }}
{{- end -}}
{{- end -}}
@@ -208,9 +208,9 @@ func NewSchemaConfigContext(parent context.Context, config SchemaConfig) context
{{- range $n := $.Nodes }}
{{ $n.Name }}: tableSchemas[{{ indexOf $all $n.TableSchema }}],
{{- range $e := $n.Edges }}
{{- if and $e.M2M (not $e.Inverse) (not $e.Through) }}
{{ $n.Name }}{{ $e.StructField }}: tableSchemas[{{ indexOf $all $e.TableSchema }}],
{{- end }}
{{- if and $e.M2M (not $e.Inverse) (not $e.Through) }}
{{ $n.Name }}{{ $e.StructField }}: tableSchemas[{{ indexOf $all $e.TableSchema }}],
{{- end }}
{{- end }}
{{- end }}
}

View File

@@ -26,14 +26,14 @@ in the LICENSE file in the root directory of this source tree.
{{/* Template for passing the "OnConflict" options to the sqlgraph.CreateSpec. */}}
{{- define "dialect/sql/create/spec/upsert" }}
{{- if $.FeatureEnabled "sql/upsert" }}
_spec.OnConflict = {{ $.Scope.Receiver }}.conflict
_spec.OnConflict = c.conflict
{{- end }}
{{- end }}
{{/* Template for passing the "OnConflict" options to the sqlgraph.BatchCreateSpec. */}}
{{- define "dialect/sql/create_bulk/spec/upsert" }}
{{- if $.FeatureEnabled "sql/upsert" }}
spec.OnConflict = {{ $.Scope.Receiver }}.conflict
spec.OnConflict = c.conflict
{{- end }}
{{- end }}
@@ -47,7 +47,6 @@ in the LICENSE file in the root directory of this source tree.
{{ define "helper/upsertone" }}
{{ $pkg := base $.Config.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := $.Scope.Receiver }}
{{ $upsertOne := print $.Name "UpsertOne" }}
{{ $upsertSet := print $.Name "Upsert" }}
@@ -74,11 +73,9 @@ in the LICENSE file in the root directory of this source tree.
{{- end }}
// Exec(ctx)
//
func ({{ $receiver }} *{{ $builder }}) OnConflict(opts ...sql.ConflictOption) *{{ $upsertOne }} {
{{ $receiver }}.conflict = opts
return &{{ $upsertOne }}{
create: {{ $receiver }},
}
func (c *{{ $builder }}) OnConflict(opts ...sql.ConflictOption) *{{ $upsertOne }} {
c.conflict = opts
return &{{ $upsertOne }}{create: c}
}
// OnConflictColumns calls `OnConflict` and configures the columns
@@ -88,11 +85,9 @@ func ({{ $receiver }} *{{ $builder }}) OnConflict(opts ...sql.ConflictOption) *{
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
//
func ({{ $receiver }} *{{ $builder }}) OnConflictColumns(columns ...string) *{{ $upsertOne }} {
{{ $receiver }}.conflict = append({{ $receiver }}.conflict, sql.ConflictColumns(columns...))
return &{{ $upsertOne }}{
create: {{ $receiver }},
}
func (c *{{ $builder }}) OnConflictColumns(columns ...string) *{{ $upsertOne }} {
c.conflict = append(c.conflict, sql.ConflictColumns(columns...))
return &{{ $upsertOne }}{create: c}
}
type (
@@ -264,7 +259,6 @@ func (u *{{ $upsertOne }}) ExecX(ctx context.Context) {
{{ define "helper/upsertbulk" }}
{{ $pkg := base $.Config.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := $.Scope.Receiver }}
{{ $upsertBulk := print $.Name "UpsertBulk" }}
{{ $upsertSet := print $.Name "Upsert" }}
{{ $udfID := false }}{{ if $.HasOneFieldID }}{{ $udfID = $.ID.UserDefined }}{{ end }}
@@ -288,11 +282,9 @@ func (u *{{ $upsertOne }}) ExecX(ctx context.Context) {
{{- end }}
// Exec(ctx)
//
func ({{ $receiver }} *{{ $builder }}) OnConflict(opts ...sql.ConflictOption) *{{ $upsertBulk }} {
{{ $receiver }}.conflict = opts
return &{{ $upsertBulk }}{
create: {{ $receiver }},
}
func (c *{{ $builder }}) OnConflict(opts ...sql.ConflictOption) *{{ $upsertBulk }} {
c.conflict = opts
return &{{ $upsertBulk }}{create: c}
}
// OnConflictColumns calls `OnConflict` and configures the columns
@@ -302,11 +294,9 @@ func ({{ $receiver }} *{{ $builder }}) OnConflict(opts ...sql.ConflictOption) *{
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
//
func ({{ $receiver }} *{{ $builder }}) OnConflictColumns(columns ...string) *{{ $upsertBulk }} {
{{ $receiver }}.conflict = append({{ $receiver }}.conflict, sql.ConflictColumns(columns...))
return &{{ $upsertBulk }}{
create: {{ $receiver }},
}
func (c *{{ $builder }}) OnConflictColumns(columns ...string) *{{ $upsertBulk }} {
c.conflict = append(c.conflict, sql.ConflictColumns(columns...))
return &{{ $upsertBulk }}{create: c}
}
// {{ $upsertBulk }} is the builder for "upsert"-ing

View File

@@ -8,30 +8,29 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/sql/group" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
func ({{ $receiver }} *{{ $builder }}) sqlScan(ctx context.Context, root *{{ $.QueryName }}, v any) error {
func (q *{{ $builder }}) sqlScan(ctx context.Context, root *{{ $.QueryName }}, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len({{ $receiver}}.fns))
for _, fn := range {{ $receiver }}.fns {
aggregation := make([]string, 0, len(q.fns))
for _, fn := range q.fns {
aggregation = append(aggregation, fn(selector))
}
{{- /* If no columns were selected, the default selection is the fields used for "group-by", and the aggregation functions.*/}}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*{{ $receiver }}.flds) + len({{ $receiver}}.fns))
for _, f := range *{{ $receiver }}.flds {
columns := make([]string, 0, len(*q.flds) + len(q.fns))
for _, f := range *q.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*{{ $receiver }}.flds...)...)
selector.GroupBy(selector.Columns(*q.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := {{ $receiver }}.build.driver.Query(ctx, query, args, rows); err != nil {
if err := q.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()

View File

@@ -21,19 +21,18 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/sql/query" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
func (q *{{ $builder }}) sqlAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
var (
nodes = []*{{ $.Name }}{}
{{- with $.UnexportedForeignKeys }}
withFKs = {{ $receiver }}.withFKs
withFKs = q.withFKs
{{- end }}
_spec = {{ $receiver }}.querySpec()
_spec = q.querySpec()
{{- with $.Edges }}
loadedTypes = [{{ len . }}]bool{
{{- range $e := . }}
{{ $receiver }}.{{ $e.EagerLoadField }} != nil,
q.{{ $e.EagerLoadField }} != nil,
{{- end }}
}
{{- end }}
@@ -42,7 +41,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
{{- $edgesWithoutField := list }}
{{- range $.FKEdges }}{{ if not .Field }}{{ $edgesWithoutField = append $edgesWithoutField . }}{{ end }}{{ end }}
{{- if $edgesWithoutField }}
if {{ range $i, $e := $edgesWithoutField }}{{ if $i }} || {{ end }}{{ $receiver }}.{{ $e.EagerLoadField }} != nil{{ end }} {
if {{ range $i, $e := $edgesWithoutField }}{{ if $i }} || {{ end }}q.{{ $e.EagerLoadField }} != nil{{ end }} {
withFKs = true
}
{{- end }}
@@ -54,7 +53,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
return (*{{ $.Name }}).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &{{ $.Name }}{config: {{ $receiver }}.config}
node := &{{ $.Name }}{config: q.config}
nodes = append(nodes, node)
{{- with $.Edges }}
node.Edges.loadedTypes = loadedTypes
@@ -70,15 +69,15 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, _spec); err != nil {
if err := sqlgraph.QueryNodes(ctx, q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
{{- range $e := $.Edges }}
if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
if err := {{ $receiver }}.load{{ $e.StructField }}(ctx, query, nodes, {{ if $e.Unique }}nil{{ else }}
if query := q.{{ $e.EagerLoadField }}; query != nil {
if err := q.load{{ $e.StructField }}(ctx, query, nodes, {{ if $e.Unique }}nil{{ else }}
func(n *{{ $.Name }}){ n.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{} }{{ end }},
{{- $lhs := printf "n.Edges.%s" $e.StructField }}
{{- $rhs := print "e" }}{{- if not $e.Unique }}{{ $rhs = printf "append(%s, e)" $lhs }}{{ end }}
@@ -110,7 +109,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
{{/* Generate a method to eager-load each edge. */}}
{{- range $e := $.Edges }}
func ({{ $receiver }} *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error {
func (q *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error {
{{- if $e.M2M }}
edgeIDs := make([]driver.Value, len(nodes))
byID := make(map[{{ $.ID.Type }}]*{{ $.Name }})
@@ -126,7 +125,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }})
{{- with $tmpls := matchTemplate "dialect/sql/query/eagerloading/join/*" }}
{{- range $tmpl := $tmpls }}
{{- with extend $ "Edge" $e }}
{{- with extend $ "Edge" $e "Receiver" "q" }}
{{- xtemplate $tmpl . }}
{{- end }}
{{- end }}
@@ -261,8 +260,8 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
}
{{- end }}
func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
_spec := {{ $receiver }}.querySpec()
func (q *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
_spec := q.querySpec()
{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates. */}}
{{- with $tmpls := matchTemplate "dialect/sql/query/spec/*" }}
{{- range $tmpl := $tmpls }}
@@ -274,25 +273,25 @@ func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error
_spec.Unique = false
_spec.Node.Columns = nil
{{- else }}
_spec.Node.Columns = {{ $receiver }}.ctx.Fields
if len({{ $receiver }}.ctx.Fields) > 0 {
_spec.Node.Columns = q.ctx.Fields
if len(q.ctx.Fields) > 0 {
{{- /* In case of field selection, configure query to unique only if was explicitly set to true. */}}
_spec.Unique = {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique
_spec.Unique = q.ctx.Unique != nil && *q.ctx.Unique
}
{{- end }}
return sqlgraph.CountNodes(ctx, {{ $receiver }}.driver, _spec)
return sqlgraph.CountNodes(ctx, q.driver, _spec)
}
func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
func (q *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec({{ $.Package }}.Table, {{ $.Package }}.Columns, {{ if $.HasOneFieldID }}sqlgraph.NewFieldSpec({{ $.Package }}.{{ $.ID.Constant }}, field.{{ $.ID.Type.ConstName }}){{ else }}nil{{ end }})
{{- /* Setup any intermediate queries if exist (traversal path). */}}
_spec.From = {{ $receiver }}.sql
if unique := {{ $receiver }}.ctx.Unique; unique != nil {
_spec.From = q.sql
if unique := q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if {{ $receiver }}.path != nil {
} else if q.path != nil {
_spec.Unique = true
}
if fields := {{ $receiver }}.ctx.Fields; len(fields) > 0 {
if fields := q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
{{- if $.HasOneFieldID }}
_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.{{ $.ID.Constant }})
@@ -312,25 +311,25 @@ func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
{{- if not $f }}
{{- continue }}
{{- end }}
if {{ $receiver }}.{{ .EagerLoadField }} != nil {
if q.{{ .EagerLoadField }} != nil {
_spec.Node.AddColumnOnce({{ $.Package }}.{{ $f.Constant }})
}
{{- end }}
}
if ps := {{ $receiver }}.predicates; len(ps) > 0 {
if ps := q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := {{ $receiver }}.ctx.Limit; limit != nil {
if limit := q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := {{ $receiver }}.ctx.Offset; offset != nil {
if offset := q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := {{ $receiver }}.order; len(ps) > 0 {
if ps := q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
@@ -353,22 +352,21 @@ func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
{{ define "dialect/sql/query/selector" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Selector {
func (q *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Selector {
{{- $builderV := "builder" }}{{ if eq $.Package $builderV }}{{ $builderV = "builderC" }}{{ end }}
{{ $builderV }} := sql.Dialect({{ $receiver }}.driver.Dialect())
{{ $builderV }} := sql.Dialect(q.driver.Dialect())
t1 := {{ $builderV }}.Table({{ $.Package }}.Table)
columns := {{ $receiver }}.ctx.Fields
columns := q.ctx.Fields
if len(columns) == 0 {
columns = {{ $.Package }}.Columns
}
selector := {{ $builderV }}.Select(t1.Columns(columns...)...).From(t1)
if {{ $receiver }}.sql != nil {
selector = {{ $receiver }}.sql
if q.sql != nil {
selector = q.sql
selector.Select(selector.Columns(columns...)...)
}
if {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique {
if q.ctx.Unique != nil && *q.ctx.Unique {
selector.Distinct()
}
{{- /* Allow mutating the sql.Selector by ent extensions or user templates.*/}}
@@ -377,18 +375,18 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
for _, p := range {{ $receiver }}.predicates {
for _, p := range q.predicates {
p(selector)
}
for _, p := range {{ $receiver }}.order {
for _, p := range q.order {
p(selector)
}
if offset := {{ $receiver }}.ctx.Offset; offset != nil {
if offset := q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := {{ $receiver }}.ctx.Limit; limit != nil {
if limit := q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
@@ -400,8 +398,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- $n := $ }} {{/* the node we start the query from. */}}
{{- $e := $.Scope.Edge }} {{/* the edge we need to generate the path to. */}}
{{- $ident := $.Scope.Ident -}}
{{- $receiver := $.Scope.Receiver }}
selector := {{ $receiver }}.sqlQuery(ctx)
selector := q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
@@ -422,7 +419,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
{{ $ident }} = sqlgraph.SetNeighbors({{ $receiver }}.driver.Dialect(), step)
{{ $ident }} = sqlgraph.SetNeighbors(q.driver.Dialect(), step)
{{ end }}
{{/* query/from defines the query generation for an edge query from a given node. */}}
@@ -430,8 +427,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- $n := $ }} {{/* the node we start the query from. */}}
{{- $e := $.Scope.Edge }} {{/* the edge we need to genegrate the path to. */}}
{{- $ident := $.Scope.Ident -}}
{{- $receiver := $.Scope.Receiver -}}
id := {{ $receiver }}.ID
id := m.ID
step := sqlgraph.NewStep(
sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, id),
sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ if $e.Type.HasCompositeID }}{{ $e.Ref.ColumnConstant }}{{ else }}{{ $e.Type.ID.Constant }}{{ end }}),
@@ -449,7 +445,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
{{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
{{ $ident }} = sqlgraph.Neighbors(m.driver.Dialect(), step)
{{ end }}
{{ define "dialect/sql/query/eagerloading/m2massign" }}
@@ -465,8 +461,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{ define "dialect/sql/query/preparecheck" }}
{{- $pkg := $.Scope.Package }}
{{- $receiver := $.Scope.Receiver }}
for _, f := range {{ $receiver }}.ctx.Fields {
for _, f := range q.ctx.Fields {
if !{{ $.Package }}.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("{{ $pkg }}: invalid field %q for query", f)}
}

View File

@@ -8,15 +8,14 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/sql/select" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
func ({{ $receiver }} *{{ $builder }}) sqlScan(ctx context.Context, root *{{ $.QueryName }}, v any) error {
func (q *{{ $builder }}) sqlScan(ctx context.Context, root *{{ $.QueryName }}, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len({{ $receiver}}.fns))
for _, fn := range {{ $receiver }}.fns {
aggregation := make([]string, 0, len(q.fns))
for _, fn := range q.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*{{ $receiver }}.selector.flds); {
switch n := len(*q.selector.flds); {
{{- /* If no columns were selected, the default selection is the aggregation.*/}}
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
@@ -25,7 +24,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlScan(ctx context.Context, root *{{ $.Q
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := {{ $receiver }}.driver.Query(ctx, query, args, rows); err != nil {
if err := q.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()

View File

@@ -18,11 +18,8 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/sql/update" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := $.Scope.Receiver }}
{{ $mutation := print $receiver ".mutation" }}
{{ $one := hasSuffix $builder "One" }}
{{- $zero := 0 }}{{ if $one }}{{ $zero = "nil" }}{{ end }}
{{- $ret := "n" }}{{ if eq $ret $receiver }}{{ $ret = "_n" }}{{ end }}{{ if $one }}{{ $ret = "_node" }}{{ end }}
{{- /* Allow adding methods to the update-builder by ent extensions or user templates.*/}}
{{- with $tmpls := matchTemplate "dialect/sql/update/additional/*" }}
@@ -31,10 +28,10 @@ in the LICENSE file in the root directory of this source tree.
{{- end }}
{{- end }}
func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }} {{ if $one }}*{{ $.Name }}{{ else }}int{{ end }}, err error) {
func (u *{{ $builder }}) sqlSave(ctx context.Context) (_n {{ if $one }}*{{ $.Name }}{{ else }}int{{ end }}, err error) {
{{- if $.HasUpdateCheckers }}
if err := {{ $receiver }}.check(); err != nil {
return {{ $ret }}, err
if err := u.check(); err != nil {
return _n, err
}
{{- end }}
_spec := sqlgraph.NewUpdateSpec({{ $.Package }}.Table, {{ $.Package }}.Columns,
@@ -47,12 +44,12 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }}
{{- end }})
{{- if $one }}
{{- if $.HasOneFieldID }}
id, ok := {{ $mutation }}.{{ $.ID.MutationGet }}()
id, ok := u.mutation.{{ $.ID.MutationGet }}()
if !ok {
return {{ $zero }}, &ValidationError{Name: "{{ $.ID.Name }}", err: errors.New(`{{ $pkg }}: missing "{{ $.Name }}.{{ $.ID.Name }}" for update`)}
}
_spec.Node.ID.Value = id
if fields := {{ $receiver }}.fields; len(fields) > 0 {
if fields := u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.{{ $.ID.Constant }})
for _, f := range fields {
@@ -66,13 +63,13 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }}
}
{{- else }}{{/* Composite ID. */}}
{{- range $i, $id := $.EdgeSchema.ID }}
if id, ok := {{ $mutation }}.{{ $id.MutationGet }}(); !ok {
if id, ok := u.mutation.{{ $id.MutationGet }}(); !ok {
return {{ $zero }}, &ValidationError{Name: "{{ $id.Name }}", err: errors.New(`{{ $pkg }}: missing "{{ $.Name }}.{{ $id.Name }}" for update`)}
} else {
_spec.Node.CompositeID[{{ $i }}].Value = id
}
{{- end }}
if fields := {{ $receiver }}.fields; len(fields) > 0 {
if fields := u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, len(fields))
for i, f := range fields {
if !{{ $.Package }}.ValidColumn(f) {
@@ -83,7 +80,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }}
}
{{- end }}
{{- end }}
if ps := {{ $mutation }}.predicates; len(ps) > 0 {
if ps := u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
@@ -92,7 +89,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }}
}
{{- range $f := $.MutationFields }}
{{- if or (not $f.Immutable) $f.UpdateDefault }}
if value, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok {
if value, ok := u.mutation.{{ $f.MutationGet }}(); ok {
{{- if $f.HasValueScanner }}
vv, err := {{ $f.ValueFunc }}(value)
if err != nil {
@@ -104,7 +101,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }}
{{- end }}
}
{{- if $f.SupportsMutationAdd }}
if value, ok := {{ $mutation }}.{{ $f.MutationAdded }}(); ok {
if value, ok := u.mutation.{{ $f.MutationAdded }}(); ok {
{{- if $f.HasValueScanner }}
vv, err := {{ $f.ValueFunc }}(value)
if err != nil {
@@ -117,7 +114,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }}
}
{{- end }}
{{- if $f.SupportsMutationAppend }}
if value, ok := {{ $mutation }}.{{ $f.MutationAppended }}(); ok {
if value, ok := u.mutation.{{ $f.MutationAppended }}(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, {{ $.Package }}.{{ $f.Constant }}, value)
})
@@ -125,7 +122,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }}
{{- end }}
{{- end }}
{{- if $f.Optional }}
if {{ $mutation }}.{{ $f.StructField }}Cleared() {
if u.mutation.{{ $f.StructField }}Cleared() {
_spec.ClearField({{ $.Package }}.{{ $f.Constant }}, field.{{ $f.Type.ConstName }})
}
{{- end }}
@@ -135,21 +132,21 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }}
{{- /* Skip to the next one as immutable edges cannot be updated. */}}
{{- continue }}
{{- end }}
if {{ $mutation }}.{{ $e.MutationCleared }}() {
if u.mutation.{{ $e.MutationCleared }}() {
{{- with extend $ "Edge" $e }}
{{ template "dialect/sql/defedge" . }}
{{- end }}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
{{- if not $e.Unique }}
if nodes := {{ $mutation }}.Removed{{ $e.StructField }}IDs(); len(nodes) > 0 && !{{ $mutation }}.{{ $e.MutationCleared }}() {
if nodes := u.mutation.Removed{{ $e.StructField }}IDs(); len(nodes) > 0 && !u.mutation.{{ $e.MutationCleared }}() {
{{- with extend $ "Edge" $e "Nodes" true "Zero" $zero }}
{{ template "dialect/sql/defedge" . }}
{{- end }}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
{{- end }}
if nodes := {{ $mutation }}.{{ $e.StructField }}IDs(); len(nodes) > 0 {
if nodes := u.mutation.{{ $e.StructField }}IDs(); len(nodes) > 0 {
{{- with extend $ "Edge" $e "Nodes" true "Zero" $zero }}
{{ template "dialect/sql/defedge" . }}
{{- end }}
@@ -163,14 +160,14 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }}
{{- end }}
{{- end }}
{{- if $one }}
{{ $ret }} = &{{ $.Name }}{config: {{ $receiver }}.config}
_spec.Assign = {{ $ret }}.assignValues
_spec.ScanValues = {{ $ret }}.scanValues
_n = &{{ $.Name }}{config: u.config}
_spec.Assign = _n.assignValues
_spec.ScanValues = _n.scanValues
{{- end }}
{{- if $one }}
if err = sqlgraph.UpdateNode(ctx, {{ $receiver }}.driver, _spec); err != nil {
if err = sqlgraph.UpdateNode(ctx, u.driver, _spec); err != nil {
{{- else }}
if {{ $ret }}, err = sqlgraph.UpdateNodes(ctx, {{ $receiver }}.driver, _spec); err != nil {
if _n, err = sqlgraph.UpdateNodes(ctx, u.driver, _spec); err != nil {
{{- end }}
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{ {{ $.Package }}.Label}
@@ -179,8 +176,8 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }}
}
return {{ $zero }}, err
}
{{ $mutation }}.done = true
return {{ $ret }}, nil
u.mutation.done = true
return _n, nil
}
{{ end }}

View File

@@ -90,13 +90,11 @@ type {{ $edgesType }} struct {
{{ $tmpl = printf "dialect/%s/decode/one" $.Storage }}
{{ xtemplate $tmpl $ }}
{{ $receiver := $.Receiver }}
{{ range $e := $.Edges }}
{{ $func := print "Query" $e.StructField }}
// {{ $func }} queries the "{{ $e.Name }}" edge of the {{ $.Name }} entity.
func ({{ $receiver }} *{{ $.Name }}) {{ $func }}() *{{ $e.Type.QueryName }} {
return New{{ $.ClientName }}({{ $receiver }}.config).{{ $func }}({{ $receiver }})
func (m *{{ $.Name }}) {{ $func }}() *{{ $e.Type.QueryName }} {
return New{{ $.ClientName }}(m.config).{{ $func }}(m)
}
{{ end }}
@@ -104,20 +102,20 @@ type {{ $edgesType }} struct {
// Update returns a builder for updating this {{ $.Name }}.
// Note that you need to call {{ $.Name }}.Unwrap() before calling this method if this {{ $.Name }}
// was returned from a transaction, and the transaction was committed or rolled back.
func ({{ $receiver }} *{{ $.Name }}) Update() *{{ $.UpdateOneName }} {
return New{{ $.ClientName }}({{ $receiver }}.config).UpdateOne({{ $receiver }})
func (m *{{ $.Name }}) Update() *{{ $.UpdateOneName }} {
return New{{ $.ClientName }}(m.config).UpdateOne(m)
}
{{- end }}
// Unwrap unwraps the {{ $.Name }} entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func ({{ $receiver }} *{{ $.Name }}) Unwrap() *{{ $.Name }} {
_tx, ok := {{ $receiver }}.config.driver.(*txDriver)
func (m *{{ $.Name }}) Unwrap() *{{ $.Name }} {
_tx, ok := m.config.driver.(*txDriver)
if !ok {
panic("{{ $pkg }}: {{ $.Name }} is not a transactional entity")
}
{{ $receiver }}.config.driver = _tx.drv
return {{ $receiver }}
m.config.driver = _tx.drv
return m
}
{{ template "model/stringer" $ }}
@@ -136,14 +134,12 @@ type {{ $slice }} []*{{ $.Name }}
{{/* A template to generate a fmt.Stringer implementation. */}}
{{ define "model/stringer" }}
{{ $receiver := $.Receiver }}
// String implements the fmt.Stringer.
func ({{ $receiver }} *{{ $.Name }}) String() string {
func (m *{{ $.Name }}) String() string {
var builder strings.Builder
builder.WriteString("{{ $.Name }}(")
{{- if $.HasOneFieldID }}
builder.WriteString(fmt.Sprintf("id=%v{{ if $.Fields }}, {{ end }}", {{ $receiver }}.ID))
builder.WriteString(fmt.Sprintf("id=%v{{ if $.Fields }}, {{ end }}", m.ID))
{{- end }}
{{- range $i, $f := $.Fields }}
{{- if ne $i 0 }}
@@ -152,7 +148,7 @@ type {{ $slice }} []*{{ $.Name }}
{{- if $f.Sensitive }}
builder.WriteString("{{ $f.Name }}={{ print "<sensitive>" }}")
{{- else }}
{{- $sf := printf "%s.%s" $receiver $f.StructField }}
{{- $sf := printf "m.%s" $f.StructField }}
{{- if $f.Nillable }}
if v := {{ $sf }}; v != nil {
builder.WriteString("{{ $f.Name }}=")

View File

@@ -4,10 +4,9 @@
{{/* A template for adding the Modify method to the query builder. */}}
{{ define "dialect/sql/query/additional/modify" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
func ({{ $receiver }} *{{ $builder }}) Modify(modifier func(s *sql.Selector)) *{{ $builder }} {
{{ $receiver }}.modifiers = append({{ $receiver }}.modifiers, modifier)
return {{ $receiver }}
func (q *{{ $builder }}) Modify(modifier func(*sql.Selector)) *{{ $builder }} {
q.modifiers = append(q.modifiers, modifier)
return q
}
{{ end }}
@@ -16,15 +15,13 @@
{{- end }}
{{ define "dialect/sql/query/selector/modify" }}
{{- $receiver := pascal $.Scope.Builder | receiver }}
for _, m := range {{ $receiver }}.modifiers {
for _, m := range q.modifiers {
m(selector)
}
{{- end -}}
{{ define "dialect/sql/query/spec/modify" }}
{{- $receiver := pascal $.Scope.Builder | receiver }}
if len({{ $receiver }}.modifiers) > 0 {
_spec.Modifiers = {{ $receiver }}.modifiers
if len(q.modifiers) > 0 {
_spec.Modifiers = q.modifiers
}
{{- end -}}