mirror of
https://github.com/ent/ent.git
synced 2026-05-05 09:00:57 +03:00
entc/gen/template: use a fixed receiver
This commit is contained in:
@@ -21,19 +21,18 @@ in the LICENSE file in the root directory of this source tree.
|
||||
{{ define "dialect/sql/query" }}
|
||||
{{ $pkg := $.Scope.Package }}
|
||||
{{ $builder := pascal $.Scope.Builder }}
|
||||
{{ $receiver := receiver $builder }}
|
||||
|
||||
func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
|
||||
func (q *{{ $builder }}) sqlAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
|
||||
var (
|
||||
nodes = []*{{ $.Name }}{}
|
||||
{{- with $.UnexportedForeignKeys }}
|
||||
withFKs = {{ $receiver }}.withFKs
|
||||
withFKs = q.withFKs
|
||||
{{- end }}
|
||||
_spec = {{ $receiver }}.querySpec()
|
||||
_spec = q.querySpec()
|
||||
{{- with $.Edges }}
|
||||
loadedTypes = [{{ len . }}]bool{
|
||||
{{- range $e := . }}
|
||||
{{ $receiver }}.{{ $e.EagerLoadField }} != nil,
|
||||
q.{{ $e.EagerLoadField }} != nil,
|
||||
{{- end }}
|
||||
}
|
||||
{{- end }}
|
||||
@@ -42,7 +41,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
{{- $edgesWithoutField := list }}
|
||||
{{- range $.FKEdges }}{{ if not .Field }}{{ $edgesWithoutField = append $edgesWithoutField . }}{{ end }}{{ end }}
|
||||
{{- if $edgesWithoutField }}
|
||||
if {{ range $i, $e := $edgesWithoutField }}{{ if $i }} || {{ end }}{{ $receiver }}.{{ $e.EagerLoadField }} != nil{{ end }} {
|
||||
if {{ range $i, $e := $edgesWithoutField }}{{ if $i }} || {{ end }}q.{{ $e.EagerLoadField }} != nil{{ end }} {
|
||||
withFKs = true
|
||||
}
|
||||
{{- end }}
|
||||
@@ -54,7 +53,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
return (*{{ $.Name }}).scanValues(nil, columns)
|
||||
}
|
||||
_spec.Assign = func(columns []string, values []any) error {
|
||||
node := &{{ $.Name }}{config: {{ $receiver }}.config}
|
||||
node := &{{ $.Name }}{config: q.config}
|
||||
nodes = append(nodes, node)
|
||||
{{- with $.Edges }}
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
@@ -70,15 +69,15 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, _spec); err != nil {
|
||||
if err := sqlgraph.QueryNodes(ctx, q.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nodes, nil
|
||||
}
|
||||
{{- range $e := $.Edges }}
|
||||
if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
|
||||
if err := {{ $receiver }}.load{{ $e.StructField }}(ctx, query, nodes, {{ if $e.Unique }}nil{{ else }}
|
||||
if query := q.{{ $e.EagerLoadField }}; query != nil {
|
||||
if err := q.load{{ $e.StructField }}(ctx, query, nodes, {{ if $e.Unique }}nil{{ else }}
|
||||
func(n *{{ $.Name }}){ n.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{} }{{ end }},
|
||||
{{- $lhs := printf "n.Edges.%s" $e.StructField }}
|
||||
{{- $rhs := print "e" }}{{- if not $e.Unique }}{{ $rhs = printf "append(%s, e)" $lhs }}{{ end }}
|
||||
@@ -110,7 +109,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
|
||||
{{/* Generate a method to eager-load each edge. */}}
|
||||
{{- range $e := $.Edges }}
|
||||
func ({{ $receiver }} *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error {
|
||||
func (q *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error {
|
||||
{{- if $e.M2M }}
|
||||
edgeIDs := make([]driver.Value, len(nodes))
|
||||
byID := make(map[{{ $.ID.Type }}]*{{ $.Name }})
|
||||
@@ -126,7 +125,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }})
|
||||
{{- with $tmpls := matchTemplate "dialect/sql/query/eagerloading/join/*" }}
|
||||
{{- range $tmpl := $tmpls }}
|
||||
{{- with extend $ "Edge" $e }}
|
||||
{{- with extend $ "Edge" $e "Receiver" "q" }}
|
||||
{{- xtemplate $tmpl . }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -261,8 +260,8 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := {{ $receiver }}.querySpec()
|
||||
func (q *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := q.querySpec()
|
||||
{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates. */}}
|
||||
{{- with $tmpls := matchTemplate "dialect/sql/query/spec/*" }}
|
||||
{{- range $tmpl := $tmpls }}
|
||||
@@ -274,25 +273,25 @@ func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error
|
||||
_spec.Unique = false
|
||||
_spec.Node.Columns = nil
|
||||
{{- else }}
|
||||
_spec.Node.Columns = {{ $receiver }}.ctx.Fields
|
||||
if len({{ $receiver }}.ctx.Fields) > 0 {
|
||||
_spec.Node.Columns = q.ctx.Fields
|
||||
if len(q.ctx.Fields) > 0 {
|
||||
{{- /* In case of field selection, configure query to unique only if was explicitly set to true. */}}
|
||||
_spec.Unique = {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique
|
||||
_spec.Unique = q.ctx.Unique != nil && *q.ctx.Unique
|
||||
}
|
||||
{{- end }}
|
||||
return sqlgraph.CountNodes(ctx, {{ $receiver }}.driver, _spec)
|
||||
return sqlgraph.CountNodes(ctx, q.driver, _spec)
|
||||
}
|
||||
|
||||
func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
|
||||
func (q *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
|
||||
_spec := sqlgraph.NewQuerySpec({{ $.Package }}.Table, {{ $.Package }}.Columns, {{ if $.HasOneFieldID }}sqlgraph.NewFieldSpec({{ $.Package }}.{{ $.ID.Constant }}, field.{{ $.ID.Type.ConstName }}){{ else }}nil{{ end }})
|
||||
{{- /* Setup any intermediate queries if exist (traversal path). */}}
|
||||
_spec.From = {{ $receiver }}.sql
|
||||
if unique := {{ $receiver }}.ctx.Unique; unique != nil {
|
||||
_spec.From = q.sql
|
||||
if unique := q.ctx.Unique; unique != nil {
|
||||
_spec.Unique = *unique
|
||||
} else if {{ $receiver }}.path != nil {
|
||||
} else if q.path != nil {
|
||||
_spec.Unique = true
|
||||
}
|
||||
if fields := {{ $receiver }}.ctx.Fields; len(fields) > 0 {
|
||||
if fields := q.ctx.Fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
{{- if $.HasOneFieldID }}
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.{{ $.ID.Constant }})
|
||||
@@ -312,25 +311,25 @@ func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
|
||||
{{- if not $f }}
|
||||
{{- continue }}
|
||||
{{- end }}
|
||||
if {{ $receiver }}.{{ .EagerLoadField }} != nil {
|
||||
if q.{{ .EagerLoadField }} != nil {
|
||||
_spec.Node.AddColumnOnce({{ $.Package }}.{{ $f.Constant }})
|
||||
}
|
||||
{{- end }}
|
||||
}
|
||||
if ps := {{ $receiver }}.predicates; len(ps) > 0 {
|
||||
if ps := q.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if limit := {{ $receiver }}.ctx.Limit; limit != nil {
|
||||
if limit := q.ctx.Limit; limit != nil {
|
||||
_spec.Limit = *limit
|
||||
}
|
||||
if offset := {{ $receiver }}.ctx.Offset; offset != nil {
|
||||
if offset := q.ctx.Offset; offset != nil {
|
||||
_spec.Offset = *offset
|
||||
}
|
||||
if ps := {{ $receiver }}.order; len(ps) > 0 {
|
||||
if ps := q.order; len(ps) > 0 {
|
||||
_spec.Order = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
@@ -353,22 +352,21 @@ func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
|
||||
|
||||
{{ define "dialect/sql/query/selector" }}
|
||||
{{ $builder := pascal $.Scope.Builder }}
|
||||
{{ $receiver := receiver $builder }}
|
||||
|
||||
func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
func (q *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
{{- $builderV := "builder" }}{{ if eq $.Package $builderV }}{{ $builderV = "builderC" }}{{ end }}
|
||||
{{ $builderV }} := sql.Dialect({{ $receiver }}.driver.Dialect())
|
||||
{{ $builderV }} := sql.Dialect(q.driver.Dialect())
|
||||
t1 := {{ $builderV }}.Table({{ $.Package }}.Table)
|
||||
columns := {{ $receiver }}.ctx.Fields
|
||||
columns := q.ctx.Fields
|
||||
if len(columns) == 0 {
|
||||
columns = {{ $.Package }}.Columns
|
||||
}
|
||||
selector := {{ $builderV }}.Select(t1.Columns(columns...)...).From(t1)
|
||||
if {{ $receiver }}.sql != nil {
|
||||
selector = {{ $receiver }}.sql
|
||||
if q.sql != nil {
|
||||
selector = q.sql
|
||||
selector.Select(selector.Columns(columns...)...)
|
||||
}
|
||||
if {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique {
|
||||
if q.ctx.Unique != nil && *q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
{{- /* Allow mutating the sql.Selector by ent extensions or user templates.*/}}
|
||||
@@ -377,18 +375,18 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
{{- xtemplate $tmpl $ }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
for _, p := range {{ $receiver }}.predicates {
|
||||
for _, p := range q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
for _, p := range {{ $receiver }}.order {
|
||||
for _, p := range q.order {
|
||||
p(selector)
|
||||
}
|
||||
if offset := {{ $receiver }}.ctx.Offset; offset != nil {
|
||||
if offset := q.ctx.Offset; offset != nil {
|
||||
// limit is mandatory for offset clause. We start
|
||||
// with default value, and override it below if needed.
|
||||
selector.Offset(*offset).Limit(math.MaxInt32)
|
||||
}
|
||||
if limit := {{ $receiver }}.ctx.Limit; limit != nil {
|
||||
if limit := q.ctx.Limit; limit != nil {
|
||||
selector.Limit(*limit)
|
||||
}
|
||||
return selector
|
||||
@@ -400,8 +398,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
{{- $n := $ }} {{/* the node we start the query from. */}}
|
||||
{{- $e := $.Scope.Edge }} {{/* the edge we need to generate the path to. */}}
|
||||
{{- $ident := $.Scope.Ident -}}
|
||||
{{- $receiver := $.Scope.Receiver }}
|
||||
selector := {{ $receiver }}.sqlQuery(ctx)
|
||||
selector := q.sqlQuery(ctx)
|
||||
if err := selector.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -422,7 +419,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
{{- xtemplate $tmpl $ }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{ $ident }} = sqlgraph.SetNeighbors({{ $receiver }}.driver.Dialect(), step)
|
||||
{{ $ident }} = sqlgraph.SetNeighbors(q.driver.Dialect(), step)
|
||||
{{ end }}
|
||||
|
||||
{{/* query/from defines the query generation for an edge query from a given node. */}}
|
||||
@@ -430,8 +427,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
{{- $n := $ }} {{/* the node we start the query from. */}}
|
||||
{{- $e := $.Scope.Edge }} {{/* the edge we need to genegrate the path to. */}}
|
||||
{{- $ident := $.Scope.Ident -}}
|
||||
{{- $receiver := $.Scope.Receiver -}}
|
||||
id := {{ $receiver }}.ID
|
||||
id := m.ID
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, id),
|
||||
sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ if $e.Type.HasCompositeID }}{{ $e.Ref.ColumnConstant }}{{ else }}{{ $e.Type.ID.Constant }}{{ end }}),
|
||||
@@ -449,7 +445,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
{{- xtemplate $tmpl $ }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
|
||||
{{ $ident }} = sqlgraph.Neighbors(m.driver.Dialect(), step)
|
||||
{{ end }}
|
||||
|
||||
{{ define "dialect/sql/query/eagerloading/m2massign" }}
|
||||
@@ -465,8 +461,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
|
||||
{{ define "dialect/sql/query/preparecheck" }}
|
||||
{{- $pkg := $.Scope.Package }}
|
||||
{{- $receiver := $.Scope.Receiver }}
|
||||
for _, f := range {{ $receiver }}.ctx.Fields {
|
||||
for _, f := range q.ctx.Fields {
|
||||
if !{{ $.Package }}.ValidColumn(f) {
|
||||
return &ValidationError{Name: f, err: fmt.Errorf("{{ $pkg }}: invalid field %q for query", f)}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user