mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
Most queries are not graph traversals but regular table scans. For those, the DISTINCT clause is unnecessary because duplicate rows cannot be returned (unless the user modified the query).
443 lines
15 KiB
Cheetah
{{/*
|
|
Copyright 2019-present Facebook Inc. All rights reserved.
|
|
This source code is licensed under the Apache 2.0 license found
|
|
in the LICENSE file in the root directory of this source tree.
|
|
*/}}
|
|
|
|
{{/* gotype: entgo.io/ent/entc/gen.typeScope */}}
|
|
|
|
{{/* Additional fields for the builder. */}}
|
|
{{ define "dialect/sql/query/fields" }}
	{{- /*
		Additional fields for the query-builder struct.
		withFKs is declared only for types that have unexported foreign-key
		columns. The generated sqlAll reads it to decide whether to append those
		columns to the selection, and the eager-loaders of other types set it to
		true on the neighbor query when they need the foreign-key values.
		Templates matching "dialect/sql/query/fields/additional/*" may add more fields.
	*/}}
	{{- with $.UnexportedForeignKeys }}
		withFKs bool
	{{- end }}
	{{- with $tmpls := matchTemplate "dialect/sql/query/fields/additional/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
{{- end }}
|
|
|
|
{{ define "dialect/sql/query" }}
{{- /*
	Generates the SQL-specific methods of the query builder named by .Scope.Builder:
	  - sqlAll:      runs the query and eager-loads the requested edges.
	  - load<Edge>:  one eager-loader per edge of the type.
	  - sqlCount:    counts the matching nodes via sqlgraph.CountNodes.
	  - querySpec:   translates the builder state into a sqlgraph.QuerySpec.
	  - sqlQuery:    rendered by the "dialect/sql/query/selector" template.
	Templates matching "dialect/sql/query/additional/*" are rendered last.
*/}}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}

func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
	{{- /*
		sqlAll executes the spec built by querySpec (after the hooks had a chance
		to modify it), turns every scanned row into a node bound to the builder's
		config, and then eager-loads each edge whose EagerLoadField is set.
	*/}}
	var (
		nodes = []*{{ $.Name }}{}
		{{- with $.UnexportedForeignKeys }}
			withFKs = {{ $receiver }}.withFKs
		{{- end }}
		_spec = {{ $receiver }}.querySpec()
		{{- /* loadedTypes records, per edge in declaration order, whether it is eager-loaded; it is copied into every node. */}}
		{{- with $.Edges }}
			loadedTypes = [{{ len . }}]bool{
				{{- range $e := . }}
					{{ $receiver }}.{{ $e.EagerLoadField }} != nil,
				{{- end }}
			}
		{{- end }}
	)
	{{- with $.UnexportedForeignKeys }}
		{{- /*
			Unexported foreign-key columns are appended to the selection only when
			withFKs was set on the builder, or when an edge in FKEdges is being
			eager-loaded (presumably because its loader reads those FK values).
		*/}}
		{{- with $.FKEdges }}
			if {{ range $i, $e := . }}{{ if gt $i 0 }} || {{ end }}{{ $receiver }}.{{ $e.EagerLoadField }} != nil{{ end }} {
				withFKs = true
			}
		{{- end }}
		if withFKs {
			_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.ForeignKeys...)
		}
	{{- end }}
	_spec.ScanValues = func(columns []string) ([]any, error) {
		return (*{{ $.Name }}).scanValues(nil, columns)
	}
	_spec.Assign = func(columns []string, values []any) error {
		node := &{{ $.Name }}{config: {{ $receiver }}.config}
		nodes = append(nodes, node)
		{{- with $.Edges }}
			node.Edges.loadedTypes = loadedTypes
		{{- end }}
		return node.assignValues(columns, values)
	}
	{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/spec/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	{{- /* Hooks see the final spec; the M2M eager-loader below uses one to wrap ScanValues and Assign. */}}
	for i := range hooks {
		hooks[i](ctx, _spec)
	}
	if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, _spec); err != nil {
		return nil, err
	}
	if len(nodes) == 0 {
		return nodes, nil
	}
	{{- /*
		Eager-load the requested edges. Non-unique edges pass an init func that
		sets an empty, non-nil slice on the node; unique edges pass nil. The
		assign func sets (unique) or appends (non-unique) each loaded neighbor.
	*/}}
	{{- range $e := $.Edges }}
		if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
			if err := {{ $receiver }}.load{{ $e.StructField }}(ctx, query, nodes, {{ if $e.Unique }}nil{{ else }}
				func(n *{{ $.Name }}){ n.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{} }{{ end }},
				func(n *{{ $.Name }}, e *{{ $e.Type.Name }}){ n.Edges.{{ $e.StructField }} = {{ if $e.Unique }}e{{ else }}append(n.Edges.{{ $e.StructField }}, e){{ end }} }); err != nil {
				return nil, err
			}
		}
	{{- end }}
	{{- /* Allow extensions to inject code using templates to process nodes before they are returned. */}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/all/nodes/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	return nodes, nil
}

{{/* Generate a method to eager-load each edge. */}}
{{- range $e := $.Edges }}
	func ({{ $receiver }} *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error {
		{{- /*
			Fetches the neighbors of the given owner nodes through query and
			attaches each of them with assign. When non-nil, init is applied to
			the owner nodes before matching (M2M edges, and O2M edges in the last
			branch). The matching strategy depends on where the edge is stored:
			a join table (M2M), a foreign key on this type's table (OwnFK), or a
			foreign key on the neighbor's table (all other edges).
		*/}}
		{{- if $e.M2M }}
			edgeIDs := make([]driver.Value, len(nodes))
			byID := make(map[{{ $.ID.Type }}]*{{ $.Name }})
			nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{})
			for i, node := range nodes {
				edgeIDs[i] = node.ID
				byID[node.ID] = node
				if init != nil {
					init(node)
				}
			}
			{{- /*
				Join the neighbor table with the edge's join table, keep only rows
				whose owner-side column holds one of the loaded IDs, and prepend that
				column to the selection so each row also carries its owner ID.
			*/}}
			query.Where(func(s *sql.Selector) {
				joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }})
				{{- with $tmpls := matchTemplate "dialect/sql/query/eagerloading/join/*" }}
					{{- range $tmpl := $tmpls }}
						{{- with extend $ "Edge" $e }}
							{{- xtemplate $tmpl . }}
						{{- end }}
					{{- end }}
				{{- end }}
				{{- $edgeid := print $e.Type.Package "." $e.Type.ID.Constant }}
				{{- $fk1idx := 1 }}{{- $fk2idx := 0 }}{{ if $e.IsInverse }}{{ $fk1idx = 0 }}{{ $fk2idx = 1 }}{{ end }}
				s.Join(joinT).On(s.C({{ $edgeid }}), joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk1idx }}]))
				s.Where(sql.InValues(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]), edgeIDs...))
				columns := s.SelectedColumns()
				s.Select(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]))
				s.AppendSelect(columns...)
				{{- /* DISTINCT is turned off; duplicate neighbor rows are collapsed in Go through nids in the Assign hook below. */}}
				s.SetDistinct(false)
			})
			if err := query.prepareQuery(ctx); err != nil {
				return err
			}
			qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
				return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
					assign := spec.Assign
					values := spec.ScanValues
					{{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }}
					{{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }}
					{{- /* The prepended owner-ID column is scanned separately; the rest is delegated to the neighbor's ScanValues. */}}
					spec.ScanValues = func(columns []string) ([]any, error) {
						values, err := values(columns[1:])
						if err != nil {
							return nil, err
						}
						return append([]any{new({{ $out }})}, values...), nil
					}
					{{- /*
						values[0] is the owner ID; values[1] is the neighbor's first selected
						column (expected to be its ID). A neighbor row is assigned (creating a
						node) only the first time its ID is seen; later rows only record the
						additional owner in nids.
					*/}}
					spec.Assign = func(columns []string, values []any) error {
						outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
						inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
						if nids[inValue] == nil {
							nids[inValue] = map[*{{ $.Name }}]struct{}{byID[outValue]: {}}
							return assign(columns[1:], values[1:])
						}
						nids[inValue][byID[outValue]] = struct{}{}
						return nil
					}
				})
			})
			neighbors, err := withInterceptors[[]*{{ $e.Type.Name }}](ctx, query, qr, query.inters)
			if err != nil {
				return err
			}
			{{- /* Attach every loaded neighbor to all owners recorded for it. */}}
			for _, n := range neighbors {
				nodes, ok := nids[n.ID]
				if !ok {
					return fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
				}
				for kn := range nodes {
					assign(kn, n)
				}
			}
		{{- else if $e.OwnFK }}
			{{- /*
				The foreign key is stored on this type's table: collect the distinct,
				non-nil FK values of the owners and fetch the neighbors by ID.
			*/}}
			ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes))
			nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
			for i := range nodes {
				{{- $fk := $e.ForeignKey }}
				{{- if $fk.Field.Nillable }}
					if nodes[i].{{ $fk.StructField }} == nil {
						continue
					}
				{{- end }}
				fk := {{ if $fk.Field.Nillable }}*{{ end }}nodes[i].{{ $fk.StructField }}
				if _, ok := nodeids[fk]; !ok {
					ids = append(ids, fk)
				}
				nodeids[fk] = append(nodeids[fk], nodes[i])
			}
			if len(ids) == 0 {
				return nil
			}
			query.Where({{ $e.Type.Package }}.IDIn(ids...))
			neighbors, err := query.All(ctx)
			if err != nil {
				return err
			}
			for _, n := range neighbors {
				nodes, ok := nodeids[n.ID]
				if !ok {
					return fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v`, n.ID)
				}
				for i := range nodes {
					assign(nodes[i], n)
				}
			}
		{{- else }}
			{{- /*
				The foreign key is stored on the neighbor's table: fetch the neighbors
				whose FK column holds one of the owners' IDs, then match each neighbor
				back to its owner by that FK value.
			*/}}
			fks := make([]driver.Value, 0, len(nodes))
			nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }})
			for i := range nodes {
				fks = append(fks, nodes[i].ID)
				nodeids[nodes[i].ID] = nodes[i]
				{{- if $e.O2M }}
					if init != nil {
						init(nodes[i])
					}
				{{- end }}
			}
			{{- /* The neighbor's FK column is unexported, so it must be selected explicitly to be scanned. */}}
			{{- with $e.Type.UnexportedForeignKeys }}
				query.withFKs = true
			{{- end }}
			query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) {
				s.Where(sql.InValues({{ $.Package }}.{{ $e.ColumnConstant }}, fks...))
			}))
			neighbors, err := query.All(ctx)
			if err != nil {
				return err
			}
			for _, n := range neighbors {
				{{- $fk := $e.ForeignKey }}
				fk := n.{{ $fk.StructField }}
				{{- if $fk.Field.Nillable }}
					if fk == nil {
						return fmt.Errorf(`foreign-key "{{ $fk.Field.Name }}" is nil for node %v`, n.ID)
					}
				{{- end }}
				node, ok := nodeids[{{ if $fk.Field.Nillable }}*{{ end }}fk]
				if !ok {
					return fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v for node %v`, {{ if $fk.Field.Nillable }}*{{ end }}fk, n{{ if $e.Type.HasOneFieldID }}.ID{{ end }})
				}
				assign(node, n)
			}
		{{- end }}
		return nil
	}
{{- end }}

func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
	{{- /*
		sqlCount reuses querySpec and then adjusts the selected columns and the
		DISTINCT flag for the COUNT query. Without a field selection on a type
		with a single ID field, Unique keeps the value set by querySpec.
	*/}}
	_spec := {{ $receiver }}.querySpec()
	{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates. */}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/spec/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	{{- if $.HasCompositeID }}
		{{- /* In case of an edge schema with composite primary-key, there is no need to SELECT DISTINCT. */}}
		_spec.Unique = false
		_spec.Node.Columns = nil
	{{- else }}
		_spec.Node.Columns = {{ $receiver }}.ctx.Fields
		if len({{ $receiver }}.ctx.Fields) > 0 {
			{{- /* In case of field selection, configure query to unique only if was explicitly set to true. */}}
			_spec.Unique = {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique
		}
	{{- end }}
	return sqlgraph.CountNodes(ctx, {{ $receiver }}.driver, _spec)
}

func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
	{{- /* querySpec translates the builder state (selector, fields, predicates, limit, offset, order) into a sqlgraph.QuerySpec. */}}
	_spec := sqlgraph.NewQuerySpec({{ $.Package }}.Table, {{ $.Package }}.Columns, {{ if $.HasOneFieldID }}sqlgraph.NewFieldSpec({{ $.Package }}.{{ $.ID.Constant }}, field.{{ $.ID.Type.ConstName }}){{ else }}nil{{ end }})
	{{- /* Setup any intermediate queries if exist (traversal path). */}}
	_spec.From = {{ $receiver }}.sql
	{{- /*
		An explicit Unique setting always wins. Otherwise DISTINCT is enabled only
		for graph traversals (path != nil): plain table scans cannot return
		duplicate rows unless the user modified the query.
	*/}}
	if unique := {{ $receiver }}.ctx.Unique; unique != nil {
		_spec.Unique = *unique
	} else if {{ $receiver }}.path != nil {
		_spec.Unique = true
	}
	if fields := {{ $receiver }}.ctx.Fields; len(fields) > 0 {
		_spec.Node.Columns = make([]string, 0, len(fields))
		{{- if $.HasOneFieldID }}
			{{- /* The ID column is always selected first, and skipped if the user also listed it. */}}
			_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.{{ $.ID.Constant }})
			for i := range fields {
				if fields[i] != {{ $.Package }}.{{ $.ID.Constant }} {
					_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
				}
			}
		{{- else }}
			for i := range fields {
				_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
			}
		{{- end }}
	}
	if ps := {{ $receiver }}.predicates; len(ps) > 0 {
		_spec.Predicate = func(selector *sql.Selector) {
			for i := range ps {
				ps[i](selector)
			}
		}
	}
	if limit := {{ $receiver }}.ctx.Limit; limit != nil {
		_spec.Limit = *limit
	}
	if offset := {{ $receiver }}.ctx.Offset; offset != nil {
		_spec.Offset = *offset
	}
	if ps := {{ $receiver }}.order; len(ps) > 0 {
		_spec.Order = func(selector *sql.Selector) {
			for i := range ps {
				ps[i](selector)
			}
		}
	}
	return _spec
}

{{ template "dialect/sql/query/selector" $ }}


{{- /* Allow adding methods to the query-builder by ent extensions or user templates.*/}}
{{- with $tmpls := matchTemplate "dialect/sql/query/additional/*" }}
	{{- range $tmpl := $tmpls }}
		{{- xtemplate $tmpl $ }}
	{{- end }}
{{- end }}

{{ end }}
|
|
|
|
{{ define "dialect/sql/query/selector" }}
{{- /*
	Generates sqlQuery, which builds the *sql.Selector for the query builder
	named by .Scope.Builder. It applies the selected columns, DISTINCT,
	extension templates, predicates, ordering, offset and limit.
*/}}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}

func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Selector {
	builder := sql.Dialect({{ $receiver }}.driver.Dialect())
	t1 := builder.Table({{ $.Package }}.Table)
	columns := {{ $receiver }}.ctx.Fields
	if len(columns) == 0 {
		columns = {{ $.Package }}.Columns
	}
	selector := builder.Select(t1.Columns(columns...)...).From(t1)
	{{- /* An intermediate (traversal) selector, when set, replaces the plain table scan; only its projection is updated. */}}
	if {{ $receiver }}.sql != nil {
		selector = {{ $receiver }}.sql
		selector.Select(selector.Columns(columns...)...)
	}
	{{- /* DISTINCT is applied only when ctx.Unique is explicitly true. */}}
	if {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique {
		selector.Distinct()
	}
	{{- /* Allow mutating the sql.Selector by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/selector/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	for _, p := range {{ $receiver }}.predicates {
		p(selector)
	}
	for _, p := range {{ $receiver }}.order {
		p(selector)
	}
	if offset := {{ $receiver }}.ctx.Offset; offset != nil {
		// limit is mandatory for offset clause. We start
		// with default value, and override it below if needed.
		selector.Offset(*offset).Limit(math.MaxInt32)
	}
	if limit := {{ $receiver }}.ctx.Limit; limit != nil {
		selector.Limit(*limit)
	}
	return selector
}
{{ end }}
|
|
|
|
{{/* query/path generates the traversal (path) query from the current query to the neighbors of a given edge. */}}
|
|
{{ define "dialect/sql/query/path" }}
	{{- /*
		Renders statements that build the traversal query for .Scope.Edge.
		The current builder's selector (sqlQuery) is the starting point, and the
		neighbors selector is assigned to the variable named by .Scope.Ident.
		Because of the `return nil, err` below, it must be rendered inside a
		function body that has ctx in scope and returns (value, error).
		For types with a composite ID (edge schemas), the edge/ref columns are
		used as step keys instead of an ID column.
	*/}}
	{{- $n := $ }} {{/* the node we start the query from. */}}
	{{- $e := $.Scope.Edge }} {{/* the edge we need to generate the path to. */}}
	{{- $ident := $.Scope.Ident -}}
	{{- $receiver := $.Scope.Receiver }}
	selector := {{ $receiver }}.sqlQuery(ctx)
	if err := selector.Err(); err != nil {
		return nil, err
	}
	step := sqlgraph.NewStep(
		sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ if $n.HasCompositeID }}{{ $e.ColumnConstant }}{{ else }}{{ $n.ID.Constant }}{{ end }}, selector),
		sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ if $e.Type.HasCompositeID }}{{ $e.Ref.ColumnConstant }}{{ else }}{{ $e.Type.ID.Constant }}{{ end }}),
		sqlgraph.Edge(sqlgraph.{{ $e.Rel.Type }}, {{ $e.IsInverse }}, {{ $n.Package }}.{{ $e.TableConstant }},
			{{- if $e.M2M -}}
				{{ $n.Package }}.{{ $e.PKConstant }}...
			{{- else -}}
				{{ $n.Package }}.{{ $e.ColumnConstant }}
			{{- end -}}
		),
	)
	{{- /* Allow mutating the sqlgraph.Step by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/path/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	{{ $ident }} = sqlgraph.SetNeighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}
|
|
|
|
{{/* query/from defines the query generation for an edge query from a given node. */}}
|
|
{{ define "dialect/sql/query/from" }}
	{{- /*
		Renders statements that build the neighbors query for .Scope.Edge,
		starting from a single node identified by the .ID of .Scope.Receiver.
		The resulting selector is assigned to the variable named by .Scope.Ident.
	*/}}
	{{- $n := $ }} {{/* the node we start the query from. */}}
	{{- $e := $.Scope.Edge }} {{/* the edge we need to generate the path to. */}}
	{{- $ident := $.Scope.Ident -}}
	{{- $receiver := $.Scope.Receiver -}}
	id := {{ $receiver }}.ID
	step := sqlgraph.NewStep(
		sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, id),
		sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ if $e.Type.HasCompositeID }}{{ $e.Ref.ColumnConstant }}{{ else }}{{ $e.Type.ID.Constant }}{{ end }}),
		sqlgraph.Edge(sqlgraph.{{ $e.Rel.Type }}, {{ $e.IsInverse }}, {{ $n.Package }}.{{ $e.TableConstant }},
			{{- if $e.M2M -}}
				{{ $n.Package }}.{{ $e.PKConstant }}...
			{{- else -}}
				{{ $n.Package }}.{{ $e.ColumnConstant }}
			{{- end -}}
		),
	)
	{{- /* Allow mutating the sqlgraph.Step by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/from/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	{{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}
|
|
|
|
{{ define "dialect/sql/query/eagerloading/m2massign" }}
	{{- /*
		Renders a single-line Go expression that extracts the scanned value of
		.Scope.Field from .Scope.Arg, which is type-asserted to *.Scope.ScanType.
		- If the scan type name starts with "sql" (e.g. sql.NullInt64), the
		  field's ScanTypeField conversion is applied to the assertion.
		- Otherwise the pointer is dereferenced, unless the field is Nillable.
		Whitespace is trimmed everywhere because the output is used inline.
	*/}}
	{{- $arg := $.Scope.Arg }}
	{{- $field := $.Scope.Field }}
	{{- $scantype := $.Scope.ScanType }}
	{{- if hasPrefix $scantype "sql" -}}
		{{ printf "%s.(*%s)" $arg $scantype | $field.ScanTypeField -}}
	{{- else -}}
		{{ if not $field.Nillable }}*{{ end }}{{ printf "%s.(*%s)" $arg $scantype }}
	{{- end }}
{{- end }}
|
|
|
|
{{ define "dialect/sql/query/preparecheck" }}
	{{- /*
		Renders a loop that rejects selected field names that are not valid
		columns of the type, returning a *ValidationError. It must be rendered
		inside a function that returns error; .Scope.Receiver names the builder.
	*/}}
	{{- $pkg := $.Scope.Package }}
	{{- $receiver := $.Scope.Receiver }}
	for _, f := range {{ $receiver }}.ctx.Fields {
		if !{{ $.Package }}.ValidColumn(f) {
			return &ValidationError{Name: f, err: fmt.Errorf("{{ $pkg }}: invalid field %q for query", f)}
		}
	}
{{- end }}
|