Files
ent/entc/gen/template/dialect/sql/query.tmpl
Ariel Mashraki 939c7cff1a entc/gen: reduce the usage of DISTINCT in queries (#3305)
Most queries are not graph traversals but rather regular table scans,
in which case the DISTINCT clause is not needed as duplicates cannot be
returned (unless query was modified by the user).
2023-02-06 22:40:50 +02:00

443 lines
15 KiB
Cheetah

{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}
{{/* gotype: entgo.io/ent/entc/gen.typeScope */}}
{{/* Extra struct fields that the SQL dialect adds to the query builder. */}}
{{ define "dialect/sql/query/fields" }}
	{{- /* The withFKs flag is emitted only for types that have unexported foreign-keys. */}}
	{{- if $.UnexportedForeignKeys }}
		withFKs bool
	{{- end }}
	{{- /* Render every template that matches the additional-fields pattern. */}}
	{{- range $tmpl := matchTemplate "dialect/sql/query/fields/additional/*" }}
		{{- xtemplate $tmpl $ }}
	{{- end }}
{{- end }}
{{ define "dialect/sql/query" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
// sqlAll executes the query using the SQL driver and returns the nodes that
// were scanned. Every hook is called with the query spec before execution,
// and each edge with a non-nil eager-loading query is loaded for the result.
func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
var (
nodes = []*{{ $.Name }}{}
{{- with $.UnexportedForeignKeys }}
withFKs = {{ $receiver }}.withFKs
{{- end }}
_spec = {{ $receiver }}.querySpec()
{{- with $.Edges }}
loadedTypes = [{{ len . }}]bool{
{{- range $e := . }}
{{ $receiver }}.{{ $e.EagerLoadField }} != nil,
{{- end }}
}
{{- end }}
)
{{- with $.UnexportedForeignKeys }}
{{- with $.FKEdges }}
if {{ range $i, $e := . }}{{ if gt $i 0 }} || {{ end }}{{ $receiver }}.{{ $e.EagerLoadField }} != nil{{ end }} {
withFKs = true
}
{{- end }}
// Select the foreign-key columns in addition to the columns of the spec.
if withFKs {
_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.ForeignKeys...)
}
{{- end }}
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*{{ $.Name }}).scanValues(nil, columns)
}
// Each call to Assign appends a new node to the result and populates it from the given values.
_spec.Assign = func(columns []string, values []any) error {
node := &{{ $.Name }}{config: {{ $receiver }}.config}
nodes = append(nodes, node)
{{- with $.Edges }}
node.Edges.loadedTypes = loadedTypes
{{- end }}
return node.assignValues(columns, values)
}
{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates.*/}}
{{- with $tmpls := matchTemplate "dialect/sql/query/spec/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
// Pass the spec to every hook before the query is executed.
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, _spec); err != nil {
return nil, err
}
// Nothing was scanned: return before the eager-loading and node-processing steps below.
if len(nodes) == 0 {
return nodes, nil
}
{{- range $e := $.Edges }}
if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
if err := {{ $receiver }}.load{{ $e.StructField }}(ctx, query, nodes, {{ if $e.Unique }}nil{{ else }}
func(n *{{ $.Name }}){ n.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{} }{{ end }},
func(n *{{ $.Name }}, e *{{ $e.Type.Name }}){ n.Edges.{{ $e.StructField }} = {{ if $e.Unique }}e{{ else }}append(n.Edges.{{ $e.StructField }}, e){{ end }} }); err != nil {
return nil, err
}
}
{{- end }}
{{- /* Allow extensions to inject code using templates to process nodes before they are returned. */}}
{{- with $tmpls := matchTemplate "dialect/sql/query/all/nodes/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
return nodes, nil
}
{{/* Generate a method to eager-load each edge. */}}
{{- range $e := $.Edges }}
	// load{{ $e.StructField }} loads the "{{ $e.Name }}" neighbors of the given nodes using the given query.
	// Where the edge kind uses it, a non-nil init is called on each node before the neighbors
	// are queried. The assign function is called for every pair of node and neighbor that was found.
	func ({{ $receiver }} *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error {
		{{- if $e.M2M }}
			edgeIDs := make([]driver.Value, len(nodes))
			// byID gets one entry per node; nids maps a neighbor ID to the set of nodes connected to it.
			byID := make(map[{{ $.ID.Type }}]*{{ $.Name }}, len(nodes))
			nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{})
			for i, node := range nodes {
				edgeIDs[i] = node.ID
				byID[node.ID] = node
				if init != nil {
					init(node)
				}
			}
			// Join the neighbors with the join-table, keep only the rows of the given nodes,
			// and select the join-table column that holds the node ID as the first column.
			query.Where(func(s *sql.Selector) {
				joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }})
				{{- with $tmpls := matchTemplate "dialect/sql/query/eagerloading/join/*" }}
					{{- range $tmpl := $tmpls }}
						{{- with extend $ "Edge" $e }}
							{{- xtemplate $tmpl . }}
						{{- end }}
					{{- end }}
				{{- end }}
				{{- $edgeid := print $e.Type.Package "." $e.Type.ID.Constant }}
				{{- $fk1idx := 1 }}{{- $fk2idx := 0 }}{{ if $e.IsInverse }}{{ $fk1idx = 0 }}{{ $fk2idx = 1 }}{{ end }}
				s.Join(joinT).On(s.C({{ $edgeid }}), joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk1idx }}]))
				s.Where(sql.InValues(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]), edgeIDs...))
				columns := s.SelectedColumns()
				s.Select(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]))
				s.AppendSelect(columns...)
				s.SetDistinct(false)
			})
			if err := query.prepareQuery(ctx); err != nil {
				return err
			}
			qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
				return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
					assign := spec.Assign
					values := spec.ScanValues
					{{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }}
					{{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }}
					// The leading join-table column is scanned separately from the neighbor columns.
					spec.ScanValues = func(columns []string) ([]any, error) {
						values, err := values(columns[1:])
						if err != nil {
							return nil, err
						}
						return append([]any{new({{ $out }})}, values...), nil
					}
					// values[0] holds the ID of the source node and values[1] the ID of the neighbor.
					// A neighbor is assigned only the first time its ID is seen; later rows
					// only record one more source node for it.
					spec.Assign = func(columns []string, values []any) error {
						outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
						inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
						if nids[inValue] == nil {
							nids[inValue] = map[*{{ $.Name }}]struct{}{byID[outValue]: {}}
							return assign(columns[1:], values[1:])
						}
						nids[inValue][byID[outValue]] = struct{}{}
						return nil
					}
				})
			})
			neighbors, err := withInterceptors[[]*{{ $e.Type.Name }}](ctx, query, qr, query.inters)
			if err != nil {
				return err
			}
			for _, n := range neighbors {
				nodes, ok := nids[n.ID]
				if !ok {
					return fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
				}
				for kn := range nodes {
					assign(kn, n)
				}
			}
		{{- else if $e.OwnFK }}
			// Collect the distinct foreign-keys and group the nodes by them.
			ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes))
			nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
			for i := range nodes {
				{{- $fk := $e.ForeignKey }}
				{{- if $fk.Field.Nillable }}
					if nodes[i].{{ $fk.StructField }} == nil {
						continue
					}
				{{- end }}
				fk := {{ if $fk.Field.Nillable }}*{{ end }}nodes[i].{{ $fk.StructField }}
				if _, ok := nodeids[fk]; !ok {
					ids = append(ids, fk)
				}
				nodeids[fk] = append(nodeids[fk], nodes[i])
			}
			// No foreign-key was collected, so there is nothing to query.
			if len(ids) == 0 {
				return nil
			}
			query.Where({{ $e.Type.Package }}.IDIn(ids...))
			neighbors, err := query.All(ctx)
			if err != nil {
				return err
			}
			for _, n := range neighbors {
				nodes, ok := nodeids[n.ID]
				if !ok {
					return fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v`, n.ID)
				}
				for i := range nodes {
					assign(nodes[i], n)
				}
			}
		{{- else }}
			// The foreign-key resides on the neighbors: query them by the IDs of the given nodes.
			fks := make([]driver.Value, 0, len(nodes))
			nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }}, len(nodes))
			for i := range nodes {
				fks = append(fks, nodes[i].ID)
				nodeids[nodes[i].ID] = nodes[i]
				{{- if $e.O2M }}
					if init != nil {
						init(nodes[i])
					}
				{{- end }}
			}
			{{- with $e.Type.UnexportedForeignKeys }}
				// Make the neighbors query select its foreign-key columns as well.
				query.withFKs = true
			{{- end }}
			query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) {
				s.Where(sql.InValues({{ $.Package }}.{{ $e.ColumnConstant }}, fks...))
			}))
			neighbors, err := query.All(ctx)
			if err != nil {
				return err
			}
			for _, n := range neighbors {
				{{- $fk := $e.ForeignKey }}
				fk := n.{{ $fk.StructField }}
				{{- /* The ID field is referenced only for neighbors that have a single-field ID. */}}
				{{- if $fk.Field.Nillable }}
					if fk == nil {
						return fmt.Errorf(`foreign-key "{{ $fk.Field.Name }}" is nil for node %v`, n{{ if $e.Type.HasOneFieldID }}.ID{{ end }})
					}
				{{- end }}
				node, ok := nodeids[{{ if $fk.Field.Nillable }}*{{ end }}fk]
				if !ok {
					return fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v for node %v`, {{ if $fk.Field.Nillable }}*{{ end }}fk, n{{ if $e.Type.HasOneFieldID }}.ID{{ end }})
				}
				assign(node, n)
			}
		{{- end }}
		return nil
	}
{{- end }}
// sqlCount builds the query spec, adjusts its columns and uniqueness for
// counting, and returns the result of sqlgraph.CountNodes.
func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
_spec := {{ $receiver }}.querySpec()
{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates. */}}
{{- with $tmpls := matchTemplate "dialect/sql/query/spec/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
{{- if $.HasCompositeID }}
{{- /* In case of an edge schema with composite primary-key, there is no need to SELECT DISTINCT. */}}
_spec.Unique = false
_spec.Node.Columns = nil
{{- else }}
_spec.Node.Columns = {{ $receiver }}.ctx.Fields
if len({{ $receiver }}.ctx.Fields) > 0 {
{{- /* In case of field selection, configure query to unique only if was explicitly set to true. */}}
_spec.Unique = {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique
}
{{- end }}
return sqlgraph.CountNodes(ctx, {{ $receiver }}.driver, _spec)
}
// querySpec translates the state of the builder (source selector, uniqueness,
// selected fields, predicates, limit, offset and ordering) into a sqlgraph.QuerySpec.
func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
	_spec := sqlgraph.NewQuerySpec({{ $.Package }}.Table, {{ $.Package }}.Columns, {{ if $.HasOneFieldID }}sqlgraph.NewFieldSpec({{ $.Package }}.{{ $.ID.Constant }}, field.{{ $.ID.Type.ConstName }}){{ else }}nil{{ end }})
	{{- /* Continue from the intermediate query, if one exists (traversal path). */}}
	_spec.From = {{ $receiver }}.sql
	// An explicit uniqueness setting wins; otherwise queries with a traversal path are unique.
	switch unique := {{ $receiver }}.ctx.Unique; {
	case unique != nil:
		_spec.Unique = *unique
	case {{ $receiver }}.path != nil:
		_spec.Unique = true
	}
	if fields := {{ $receiver }}.ctx.Fields; len(fields) > 0 {
		_spec.Node.Columns = make([]string, 0, len(fields))
		{{- if $.HasOneFieldID }}
			// The ID column always comes first, and is not repeated if it was selected explicitly.
			_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.{{ $.ID.Constant }})
			for i := range fields {
				if fields[i] == {{ $.Package }}.{{ $.ID.Constant }} {
					continue
				}
				_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
			}
		{{- else }}
			_spec.Node.Columns = append(_spec.Node.Columns, fields...)
		{{- end }}
	}
	if ps := {{ $receiver }}.predicates; len(ps) > 0 {
		_spec.Predicate = func(selector *sql.Selector) {
			for _, p := range ps {
				p(selector)
			}
		}
	}
	if {{ $receiver }}.ctx.Limit != nil {
		_spec.Limit = *{{ $receiver }}.ctx.Limit
	}
	if {{ $receiver }}.ctx.Offset != nil {
		_spec.Offset = *{{ $receiver }}.ctx.Offset
	}
	if ps := {{ $receiver }}.order; len(ps) > 0 {
		_spec.Order = func(selector *sql.Selector) {
			for _, p := range ps {
				p(selector)
			}
		}
	}
	return _spec
}
{{ template "dialect/sql/query/selector" $ }}
{{- /* Extension point: ent extensions or user templates may add methods to the query-builder. */}}
{{- range $tmpl := matchTemplate "dialect/sql/query/additional/*" }}
	{{- xtemplate $tmpl $ }}
{{- end }}
{{ end }}
{{ define "dialect/sql/query/selector" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
// sqlQuery builds the sql.Selector of the query from the builder state:
// selected columns, DISTINCT, predicates, ordering, offset and limit.
func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect({{ $receiver }}.driver.Dialect())
t1 := builder.Table({{ $.Package }}.Table)
// Fall back to all columns of the table if no fields were selected.
columns := {{ $receiver }}.ctx.Fields
if len(columns) == 0 {
columns = {{ $.Package }}.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
// If the builder already holds a selector, continue from it instead of the plain table selection.
if {{ $receiver }}.sql != nil {
selector = {{ $receiver }}.sql
selector.Select(selector.Columns(columns...)...)
}
// DISTINCT is added only if uniqueness was explicitly set to true.
if {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique {
selector.Distinct()
}
{{- /* Allow mutating the sql.Selector by ent extensions or user templates.*/}}
{{- with $tmpls := matchTemplate "dialect/sql/query/selector/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
for _, p := range {{ $receiver }}.predicates {
p(selector)
}
for _, p := range {{ $receiver }}.order {
p(selector)
}
if offset := {{ $receiver }}.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := {{ $receiver }}.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
{{ end }}
{{/* query/path renders the statements that build the traversal of the edge given in the scope, starting from the selector of the receiver's query. The result is assigned to the identifier given in the scope. */}}
{{ define "dialect/sql/query/path" }}
{{- $n := $ }} {{/* the node type the traversal starts from. */}}
{{- $e := $.Scope.Edge }} {{/* the edge to generate the traversal path for. */}}
{{- $ident := $.Scope.Ident -}}
{{- $receiver := $.Scope.Receiver }}
selector := {{ $receiver }}.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
// The step starts from the selector built above and ends at the table of the edge type.
step := sqlgraph.NewStep(
sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ if $n.HasCompositeID }}{{ $e.ColumnConstant }}{{ else }}{{ $n.ID.Constant }}{{ end }}, selector),
sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ if $e.Type.HasCompositeID }}{{ $e.Ref.ColumnConstant }}{{ else }}{{ $e.Type.ID.Constant }}{{ end }}),
sqlgraph.Edge(sqlgraph.{{ $e.Rel.Type }}, {{ $e.IsInverse }}, {{ $n.Package }}.{{ $e.TableConstant }},
{{- if $e.M2M -}}
{{ $n.Package }}.{{ $e.PKConstant }}...
{{- else -}}
{{ $n.Package }}.{{ $e.ColumnConstant }}
{{- end -}}
),
)
{{- /* Allow mutating the sqlgraph.Step by ent extensions or user templates.*/}}
{{- with $tmpls := matchTemplate "dialect/sql/query/path/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
{{ $ident }} = sqlgraph.SetNeighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}
{{/* query/from renders the statements that build the query of the edge given in the scope, starting from the ID of the receiver node. The result is assigned to the identifier given in the scope. */}}
{{ define "dialect/sql/query/from" }}
{{- $n := $ }} {{/* the node type the query starts from. */}}
{{- $e := $.Scope.Edge }} {{/* the edge to generate the path for. */}}
{{- $ident := $.Scope.Ident -}}
{{- $receiver := $.Scope.Receiver -}}
id := {{ $receiver }}.ID
// The step starts from the ID of the receiver and ends at the table of the edge type.
step := sqlgraph.NewStep(
sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, id),
sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ if $e.Type.HasCompositeID }}{{ $e.Ref.ColumnConstant }}{{ else }}{{ $e.Type.ID.Constant }}{{ end }}),
sqlgraph.Edge(sqlgraph.{{ $e.Rel.Type }}, {{ $e.IsInverse }}, {{ $n.Package }}.{{ $e.TableConstant }},
{{- if $e.M2M -}}
{{ $n.Package }}.{{ $e.PKConstant }}...
{{- else -}}
{{ $n.Package }}.{{ $e.ColumnConstant }}
{{- end -}}
),
)
{{- /* Allow mutating the sqlgraph.Step by ent extensions or user templates.*/}}
{{- with $tmpls := matchTemplate "dialect/sql/query/from/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
{{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}
{{/* m2massign renders an inline expression that reads the scanned value given in Scope.Arg, type-asserted to a pointer of Scope.ScanType. Scan types prefixed with "sql" are passed through the ScanTypeField method of Scope.Field; any other type is dereferenced unless the field is nillable. */}}
{{ define "dialect/sql/query/eagerloading/m2massign" }}
	{{- $field := $.Scope.Field }}
	{{- $cast := printf "%s.(*%s)" $.Scope.Arg $.Scope.ScanType }}
	{{- if hasPrefix $.Scope.ScanType "sql" }}
		{{- $field.ScanTypeField $cast }}
	{{- else }}
		{{- if not $field.Nillable }}*{{ end }}{{ $cast }}
	{{- end }}
{{- end }}
{{/* preparecheck renders a loop that returns a ValidationError for the first selected field that is not a valid column of the type. */}}
{{ define "dialect/sql/query/preparecheck" }}
	{{- $pkg := $.Scope.Package }}
	{{- $receiver := $.Scope.Receiver }}
	for _, f := range {{ $receiver }}.ctx.Fields {
		if {{ $.Package }}.ValidColumn(f) {
			continue
		}
		return &ValidationError{Name: f, err: fmt.Errorf("{{ $pkg }}: invalid field %q for query", f)}
	}
{{- end }}