entc/gen/template: use a fixed receiver

This commit is contained in:
Giau. Tran Minh
2025-03-17 11:55:47 +07:00
parent 6813cdd337
commit 31bbe41c5d
24 changed files with 495 additions and 575 deletions

View File

@@ -21,19 +21,18 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/sql/query" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
func (q *{{ $builder }}) sqlAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
var (
nodes = []*{{ $.Name }}{}
{{- with $.UnexportedForeignKeys }}
withFKs = {{ $receiver }}.withFKs
withFKs = q.withFKs
{{- end }}
_spec = {{ $receiver }}.querySpec()
_spec = q.querySpec()
{{- with $.Edges }}
loadedTypes = [{{ len . }}]bool{
{{- range $e := . }}
{{ $receiver }}.{{ $e.EagerLoadField }} != nil,
q.{{ $e.EagerLoadField }} != nil,
{{- end }}
}
{{- end }}
@@ -42,7 +41,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
{{- $edgesWithoutField := list }}
{{- range $.FKEdges }}{{ if not .Field }}{{ $edgesWithoutField = append $edgesWithoutField . }}{{ end }}{{ end }}
{{- if $edgesWithoutField }}
if {{ range $i, $e := $edgesWithoutField }}{{ if $i }} || {{ end }}{{ $receiver }}.{{ $e.EagerLoadField }} != nil{{ end }} {
if {{ range $i, $e := $edgesWithoutField }}{{ if $i }} || {{ end }}q.{{ $e.EagerLoadField }} != nil{{ end }} {
withFKs = true
}
{{- end }}
@@ -54,7 +53,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
return (*{{ $.Name }}).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &{{ $.Name }}{config: {{ $receiver }}.config}
node := &{{ $.Name }}{config: q.config}
nodes = append(nodes, node)
{{- with $.Edges }}
node.Edges.loadedTypes = loadedTypes
@@ -70,15 +69,15 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, _spec); err != nil {
if err := sqlgraph.QueryNodes(ctx, q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
{{- range $e := $.Edges }}
if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
if err := {{ $receiver }}.load{{ $e.StructField }}(ctx, query, nodes, {{ if $e.Unique }}nil{{ else }}
if query := q.{{ $e.EagerLoadField }}; query != nil {
if err := q.load{{ $e.StructField }}(ctx, query, nodes, {{ if $e.Unique }}nil{{ else }}
func(n *{{ $.Name }}){ n.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{} }{{ end }},
{{- $lhs := printf "n.Edges.%s" $e.StructField }}
{{- $rhs := print "e" }}{{- if not $e.Unique }}{{ $rhs = printf "append(%s, e)" $lhs }}{{ end }}
@@ -110,7 +109,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
{{/* Generate a method to eager-load each edge. */}}
{{- range $e := $.Edges }}
func ({{ $receiver }} *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error {
func (q *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error {
{{- if $e.M2M }}
edgeIDs := make([]driver.Value, len(nodes))
byID := make(map[{{ $.ID.Type }}]*{{ $.Name }})
@@ -126,7 +125,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }})
{{- with $tmpls := matchTemplate "dialect/sql/query/eagerloading/join/*" }}
{{- range $tmpl := $tmpls }}
{{- with extend $ "Edge" $e }}
{{- with extend $ "Edge" $e "Receiver" "q" }}
{{- xtemplate $tmpl . }}
{{- end }}
{{- end }}
@@ -261,8 +260,8 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
}
{{- end }}
func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
_spec := {{ $receiver }}.querySpec()
func (q *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
_spec := q.querySpec()
{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates. */}}
{{- with $tmpls := matchTemplate "dialect/sql/query/spec/*" }}
{{- range $tmpl := $tmpls }}
@@ -274,25 +273,25 @@ func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error
_spec.Unique = false
_spec.Node.Columns = nil
{{- else }}
_spec.Node.Columns = {{ $receiver }}.ctx.Fields
if len({{ $receiver }}.ctx.Fields) > 0 {
_spec.Node.Columns = q.ctx.Fields
if len(q.ctx.Fields) > 0 {
{{- /* In case of field selection, configure query to unique only if was explicitly set to true. */}}
_spec.Unique = {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique
_spec.Unique = q.ctx.Unique != nil && *q.ctx.Unique
}
{{- end }}
return sqlgraph.CountNodes(ctx, {{ $receiver }}.driver, _spec)
return sqlgraph.CountNodes(ctx, q.driver, _spec)
}
func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
func (q *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec({{ $.Package }}.Table, {{ $.Package }}.Columns, {{ if $.HasOneFieldID }}sqlgraph.NewFieldSpec({{ $.Package }}.{{ $.ID.Constant }}, field.{{ $.ID.Type.ConstName }}){{ else }}nil{{ end }})
{{- /* Setup any intermediate queries if exist (traversal path). */}}
_spec.From = {{ $receiver }}.sql
if unique := {{ $receiver }}.ctx.Unique; unique != nil {
_spec.From = q.sql
if unique := q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if {{ $receiver }}.path != nil {
} else if q.path != nil {
_spec.Unique = true
}
if fields := {{ $receiver }}.ctx.Fields; len(fields) > 0 {
if fields := q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
{{- if $.HasOneFieldID }}
_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.{{ $.ID.Constant }})
@@ -312,25 +311,25 @@ func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
{{- if not $f }}
{{- continue }}
{{- end }}
if {{ $receiver }}.{{ .EagerLoadField }} != nil {
if q.{{ .EagerLoadField }} != nil {
_spec.Node.AddColumnOnce({{ $.Package }}.{{ $f.Constant }})
}
{{- end }}
}
if ps := {{ $receiver }}.predicates; len(ps) > 0 {
if ps := q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := {{ $receiver }}.ctx.Limit; limit != nil {
if limit := q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := {{ $receiver }}.ctx.Offset; offset != nil {
if offset := q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := {{ $receiver }}.order; len(ps) > 0 {
if ps := q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
@@ -353,22 +352,21 @@ func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
{{ define "dialect/sql/query/selector" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Selector {
func (q *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Selector {
{{- $builderV := "builder" }}{{ if eq $.Package $builderV }}{{ $builderV = "builderC" }}{{ end }}
{{ $builderV }} := sql.Dialect({{ $receiver }}.driver.Dialect())
{{ $builderV }} := sql.Dialect(q.driver.Dialect())
t1 := {{ $builderV }}.Table({{ $.Package }}.Table)
columns := {{ $receiver }}.ctx.Fields
columns := q.ctx.Fields
if len(columns) == 0 {
columns = {{ $.Package }}.Columns
}
selector := {{ $builderV }}.Select(t1.Columns(columns...)...).From(t1)
if {{ $receiver }}.sql != nil {
selector = {{ $receiver }}.sql
if q.sql != nil {
selector = q.sql
selector.Select(selector.Columns(columns...)...)
}
if {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique {
if q.ctx.Unique != nil && *q.ctx.Unique {
selector.Distinct()
}
{{- /* Allow mutating the sql.Selector by ent extensions or user templates.*/}}
@@ -377,18 +375,18 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
for _, p := range {{ $receiver }}.predicates {
for _, p := range q.predicates {
p(selector)
}
for _, p := range {{ $receiver }}.order {
for _, p := range q.order {
p(selector)
}
if offset := {{ $receiver }}.ctx.Offset; offset != nil {
if offset := q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := {{ $receiver }}.ctx.Limit; limit != nil {
if limit := q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
@@ -400,8 +398,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- $n := $ }} {{/* the node we start the query from. */}}
{{- $e := $.Scope.Edge }} {{/* the edge we need to generate the path to. */}}
{{- $ident := $.Scope.Ident -}}
{{- $receiver := $.Scope.Receiver }}
selector := {{ $receiver }}.sqlQuery(ctx)
selector := q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
@@ -422,7 +419,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
{{ $ident }} = sqlgraph.SetNeighbors({{ $receiver }}.driver.Dialect(), step)
{{ $ident }} = sqlgraph.SetNeighbors(q.driver.Dialect(), step)
{{ end }}
{{/* query/from defines the query generation for an edge query from a given node. */}}
@@ -430,8 +427,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- $n := $ }} {{/* the node we start the query from. */}}
{{- $e := $.Scope.Edge }} {{/* the edge we need to genegrate the path to. */}}
{{- $ident := $.Scope.Ident -}}
{{- $receiver := $.Scope.Receiver -}}
id := {{ $receiver }}.ID
id := m.ID
step := sqlgraph.NewStep(
sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, id),
sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ if $e.Type.HasCompositeID }}{{ $e.Ref.ColumnConstant }}{{ else }}{{ $e.Type.ID.Constant }}{{ end }}),
@@ -449,7 +445,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
{{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
{{ $ident }} = sqlgraph.Neighbors(m.driver.Dialect(), step)
{{ end }}
{{ define "dialect/sql/query/eagerloading/m2massign" }}
@@ -465,8 +461,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{ define "dialect/sql/query/preparecheck" }}
{{- $pkg := $.Scope.Package }}
{{- $receiver := $.Scope.Receiver }}
for _, f := range {{ $receiver }}.ctx.Fields {
for _, f := range q.ctx.Fields {
if !{{ $.Package }}.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("{{ $pkg }}: invalid field %q for query", f)}
}