{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}

{{/* gotype: entgo.io/ent/entc/gen.typeScope */}}

{{/*
Additional fields for the builder.
Emits a "withFKs bool" field when the type has unexported foreign-keys, and then
renders every template that matches "dialect/sql/query/fields/additional/*".
*/}}
{{ define "dialect/sql/query/fields" }}
	{{- with $.UnexportedForeignKeys }}
		withFKs bool
	{{- end }}
	{{- with $tmpls := matchTemplate "dialect/sql/query/fields/additional/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
{{- end }}

{{/*
dialect/sql/query generates the SQL-backed methods of the query builder: sqlAll,
one load<Edge> method per edge, sqlCount and querySpec. It also renders the
"dialect/sql/query/selector" template (sqlQuery) and every template that matches
"dialect/sql/query/additional/*".
*/}}
{{ define "dialect/sql/query" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}

{{/*
sqlAll builds the query spec, lets the given hooks mutate it, executes it using
sqlgraph.QueryNodes, and then eager-loads every edge whose eager-load query is
set on the builder.
*/}}
func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
	var (
		nodes = []*{{ $.Name }}{}
		{{- with $.UnexportedForeignKeys }}
			withFKs = {{ $receiver }}.withFKs
		{{- end }}
		_spec = {{ $receiver }}.querySpec()
		{{- with $.Edges }}
			{{- /* One flag per edge: true if the eager-load query of this edge is set on the builder. */}}
			loadedTypes = [{{ len . }}]bool{
				{{- range $e := . }}
					{{ $receiver }}.{{ $e.EagerLoadField }} != nil,
				{{- end }}
			}
		{{- end }}
	)
	{{- with $.UnexportedForeignKeys }}
		{{- /* Collect the FK edges that have no Field; eager-loading any of them forces withFKs to true. */}}
		{{- $edgesWithoutField := list }}
		{{- range $.FKEdges }}{{ if not .Field }}{{ $edgesWithoutField = append $edgesWithoutField . }}{{ end }}{{ end }}
		{{- if $edgesWithoutField }}
			if {{ range $i, $e := $edgesWithoutField }}{{ if $i }} || {{ end }}{{ $receiver }}.{{ $e.EagerLoadField }} != nil{{ end }} {
				withFKs = true
			}
		{{- end }}
		{{- /* When withFKs is set, the foreign-key columns are appended to the selected columns. */}}
		if withFKs {
			_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.ForeignKeys...)
		}
	{{- end }}
	_spec.ScanValues = func(columns []string) ([]any, error) {
		return (*{{ $.Name }}).scanValues(nil, columns)
	}
	{{- /* Assign is invoked with the scanned values; it creates a new node, appends it to nodes and assigns the values to it. */}}
	_spec.Assign = func(columns []string, values []any) error {
		node := &{{ $.Name }}{config: {{ $receiver }}.config}
		nodes = append(nodes, node)
		{{- with $.Edges }}
			node.Edges.loadedTypes = loadedTypes
		{{- end }}
		return node.assignValues(columns, values)
	}
	{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/spec/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	{{- /* The hooks run after the spec templates above and right before the query is executed. */}}
	for i := range hooks {
		hooks[i](ctx, _spec)
	}
	if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, _spec); err != nil {
		return nil, err
	}
	{{- /* No nodes were returned, therefore there is nothing to eager-load. */}}
	if len(nodes) == 0 {
		return nodes, nil
	}
	{{- /*
		For unique edges, init is nil and assign overwrites the edge field. For non-unique
		edges, init sets the edge to an empty slice and assign appends to it.
	*/}}
	{{- range $e := $.Edges }}
		if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
			if err := {{ $receiver }}.load{{ $e.StructField }}(ctx, query, nodes, {{ if $e.Unique }}nil{{ else }}
				func(n *{{ $.Name }}){ n.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{} }{{ end }},
				func(n *{{ $.Name }}, e *{{ $e.Type.Name }}){ n.Edges.{{ $e.StructField }} = {{ if $e.Unique }}e{{ else }}append(n.Edges.{{ $e.StructField }}, e){{ end }} }); err != nil {
				return nil, err
			}
		}
	{{- end }}
	{{- /* Allow extensions to inject code using templates to process nodes before they are returned. */}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/all/nodes/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	return nodes, nil
}

{{/*
Generate a method to eager-load each edge.
The loading strategy is chosen at generation time, per edge:
  - M2M edges: the neighbors are queried with a join on the edge table.
  - OwnFK edges: the neighbors are queried by the foreign-keys held by the given nodes.
  - Otherwise: the neighbors are queried by their edge column, matched against the IDs of the given nodes.
If init is not nil, it is called for the nodes in the M2M and O2M cases. assign is called
for each (node, neighbor) pair that was found.
*/}}
{{- range $e := $.Edges }}
func ({{ $receiver }} *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error {
	{{- if $e.M2M }}
		{{- /* byID maps a node ID to its node. nids maps a neighbor ID to the set of nodes that are connected to it. */}}
		edgeIDs := make([]driver.Value, len(nodes))
		byID := make(map[{{ $.ID.Type }}]*{{ $.Name }})
		nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{})
		for i, node := range nodes {
			edgeIDs[i] = node.ID
			byID[node.ID] = node
			if init != nil {
				init(node)
			}
		}
		query.Where(func(s *sql.Selector) {
			joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }})
			{{- with $tmpls := matchTemplate "dialect/sql/query/eagerloading/join/*" }}
				{{- range $tmpl := $tmpls }}
					{{- with extend $ "Edge" $e }}
						{{- xtemplate $tmpl . }}
					{{- end }}
				{{- end }}
			{{- end }}
			{{- /*
				$fk1idx is the index of the join-table column that is joined with the neighbor ID, and
				$fk2idx is the index of the column that is matched against edgeIDs (the IDs of the given
				nodes). The two indexes are swapped for inverse edges.
			*/}}
			{{- $edgeid := print $e.Type.Package "." $e.Type.ID.Constant }}
			{{- $fk1idx := 1 }}{{- $fk2idx := 0 }}{{ if $e.IsInverse }}{{ $fk1idx = 0 }}{{ $fk2idx = 1 }}{{ end }}
			s.Join(joinT).On(s.C({{ $edgeid }}), joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk1idx }}]))
			s.Where(sql.InValues(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]), edgeIDs...))
			{{- /* Select the join-table column, and then append the columns that were already selected. */}}
			columns := s.SelectedColumns()
			s.Select(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]))
			s.AppendSelect(columns...)
			s.SetDistinct(false)
		})
		if err := query.prepareQuery(ctx); err != nil {
			return err
		}
		qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
			return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
				{{- /* Wrap the original ScanValues and Assign functions to handle the additional first column. */}}
				assign := spec.Assign
				values := spec.ScanValues
				{{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }}
				{{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }}
				spec.ScanValues = func(columns []string) ([]any, error) {
					values, err := values(columns[1:])
					if err != nil {
						return nil, err
					}
					return append([]any{new({{ $out }})}, values...), nil
				}
				spec.Assign = func(columns []string, values []any) error {
					{{- /*
						values[0] is converted using the ID field of this type, and values[1] using the ID
						field of the neighbor type. The original assign is called only for the first row of
						each neighbor ID; additional rows only record another connected node in nids.
					*/}}
					outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
					inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
					if nids[inValue] == nil {
						nids[inValue] = map[*{{ $.Name }}]struct{}{byID[outValue]: {}}
						return assign(columns[1:], values[1:])
					}
					nids[inValue][byID[outValue]] = struct{}{}
					return nil
				}
			})
		})
		neighbors, err := withInterceptors[[]*{{ $e.Type.Name }}](ctx, query, qr, query.inters)
		if err != nil {
			return err
		}
		{{- /* Assign each neighbor to all nodes that were recorded for its ID. */}}
		for _, n := range neighbors {
			nodes, ok := nids[n.ID]
			if !ok {
				return fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
			}
			for kn := range nodes {
				assign(kn, n)
			}
		}
	{{- else if $e.OwnFK }}
		{{- /* ids holds each distinct foreign-key once. nodeids maps a foreign-key to the nodes that hold it. */}}
		ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes))
		nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
		for i := range nodes {
			{{- $fk := $e.ForeignKey }}
			{{- if $fk.Field.Nillable }}
				{{- /* Nodes with a nil foreign-key are skipped. */}}
				if nodes[i].{{ $fk.StructField }} == nil {
					continue
				}
			{{- end }}
			fk := {{ if $fk.Field.Nillable }}*{{ end }}nodes[i].{{ $fk.StructField }}
			if _, ok := nodeids[fk]; !ok {
				ids = append(ids, fk)
			}
			nodeids[fk] = append(nodeids[fk], nodes[i])
		}
		{{- /* No foreign-keys were collected, therefore the neighbors query is skipped. */}}
		if len(ids) == 0 {
			return nil
		}
		query.Where({{ $e.Type.Package }}.IDIn(ids...))
		neighbors, err := query.All(ctx)
		if err != nil {
			return err
		}
		for _, n := range neighbors {
			nodes, ok := nodeids[n.ID]
			if !ok {
				return fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v`, n.ID)
			}
			for i := range nodes {
				assign(nodes[i], n)
			}
		}
	{{- else }}
		{{- /* fks holds the IDs of the given nodes. nodeids maps a node ID to its node. */}}
		fks := make([]driver.Value, 0, len(nodes))
		nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }})
		for i := range nodes {
			fks = append(fks, nodes[i].ID)
			nodeids[nodes[i].ID] = nodes[i]
			{{- if $e.O2M }}
				if init != nil {
					init(nodes[i])
				}
			{{- end }}
		}
		{{- with $e.Type.UnexportedForeignKeys }}
			query.withFKs = true
		{{- end }}
		query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) {
			s.Where(sql.InValues(s.C({{ $.Package }}.{{ $e.ColumnConstant }}), fks...))
		}))
		neighbors, err := query.All(ctx)
		if err != nil {
			return err
		}
		{{- /* Map each neighbor back to its node using the foreign-key it holds. */}}
		for _, n := range neighbors {
			{{- $fk := $e.ForeignKey }}
			fk := n.{{ $fk.StructField }}
			{{- if $fk.Field.Nillable }}
				if fk == nil {
					return fmt.Errorf(`foreign-key "{{ $fk.Field.Name }}" is nil for node %v`, n.ID)
				}
			{{- end }}
			node, ok := nodeids[{{ if $fk.Field.Nillable }}*{{ end }}fk]
			if !ok {
				return fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v for node %v`, {{ if $fk.Field.Nillable }}*{{ end }}fk, n{{ if $e.Type.HasOneFieldID }}.ID{{ end }})
			}
			assign(node, n)
		}
	{{- end }}
	return nil
}
{{- end }}

{{/* sqlCount builds the query spec, adjusts its Unique and Node.Columns settings, and executes it using sqlgraph.CountNodes. */}}
func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
	_spec := {{ $receiver }}.querySpec()
	{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates. */}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/spec/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	{{- if $.HasCompositeID }}
		{{- /* In case of an edge schema with composite primary-key, there is no need to SELECT DISTINCT. */}}
		_spec.Unique = false
		_spec.Node.Columns = nil
	{{- else }}
		_spec.Node.Columns = {{ $receiver }}.ctx.Fields
		if len({{ $receiver }}.ctx.Fields) > 0 {
			{{- /* In case of field selection, configure the query to be unique only if it was explicitly set to true. */}}
			_spec.Unique = {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique
		}
	{{- end }}
	return sqlgraph.CountNodes(ctx, {{ $receiver }}.driver, _spec)
}

{{/* querySpec creates a sqlgraph.QuerySpec from the builder state: source selector, uniqueness, columns, predicates, limit, offset and order. */}}
func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
	_spec := sqlgraph.NewQuerySpec({{ $.Package }}.Table, {{ $.Package }}.Columns, {{ if $.HasOneFieldID }}sqlgraph.NewFieldSpec({{ $.Package }}.{{ $.ID.Constant }}, field.{{ $.ID.Type.ConstName }}){{ else }}nil{{ end }})
	{{- /* Setup any intermediate queries if exist (traversal path). */}}
	_spec.From = {{ $receiver }}.sql
	{{- /* An explicit ctx.Unique value is used as-is. If it is not set, Unique is set to true only when path is not nil. */}}
	if unique := {{ $receiver }}.ctx.Unique; unique != nil {
		_spec.Unique = *unique
	} else if {{ $receiver }}.path != nil {
		_spec.Unique = true
	}
	if fields := {{ $receiver }}.ctx.Fields; len(fields) > 0 {
		_spec.Node.Columns = make([]string, 0, len(fields))
		{{- if $.HasOneFieldID }}
			{{- /* The ID column is always the first column, and it is not added again if it was selected. */}}
			_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.{{ $.ID.Constant }})
			for i := range fields {
				if fields[i] != {{ $.Package }}.{{ $.ID.Constant }} {
					_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
				}
			}
		{{- else }}
			for i := range fields {
				_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
			}
		{{- end }}
		{{- /* In case of custom field selection, ensure edge-fields are loaded in case their edges were loaded. */}}
		{{- range $.FKEdges }}
			{{- $f := .Field }}
			{{- if not $f }}
				{{- continue }}
			{{- end }}
			if {{ $receiver }}.{{ .EagerLoadField }} != nil {
				_spec.Node.AddColumnOnce({{ $.Package }}.{{ $f.Constant }})
			}
		{{- end }}
	}
	if ps := {{ $receiver }}.predicates; len(ps) > 0 {
		_spec.Predicate = func(selector *sql.Selector) {
			for i := range ps {
				ps[i](selector)
			}
		}
	}
	if limit := {{ $receiver }}.ctx.Limit; limit != nil {
		_spec.Limit = *limit
	}
	if offset := {{ $receiver }}.ctx.Offset; offset != nil {
		_spec.Offset = *offset
	}
	if ps := {{ $receiver }}.order; len(ps) > 0 {
		_spec.Order = func(selector *sql.Selector) {
			for i := range ps {
				ps[i](selector)
			}
		}
	}
	return _spec
}

{{ template "dialect/sql/query/selector" $ }}

{{- /* Allow adding methods to the query-builder by ent extensions or user templates.*/}}
{{- with $tmpls := matchTemplate "dialect/sql/query/additional/*" }}
	{{- range $tmpl := $tmpls }}
		{{- xtemplate $tmpl $ }}
	{{- end }}
{{- end }}

{{ end }}

{{/*
query/selector generates the sqlQuery method, which builds a *sql.Selector from the
builder state. If the builder holds a selector (the sql field), it is used as the base
selector instead of a plain SELECT from the table of the type.
*/}}
{{ define "dialect/sql/query/selector" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Selector {
	builder := sql.Dialect({{ $receiver }}.driver.Dialect())
	t1 := builder.Table({{ $.Package }}.Table)
	{{- /* Default to all columns of the type if there is no field selection. */}}
	columns := {{ $receiver }}.ctx.Fields
	if len(columns) == 0 {
		columns = {{ $.Package }}.Columns
	}
	selector := builder.Select(t1.Columns(columns...)...).From(t1)
	if {{ $receiver }}.sql != nil {
		selector = {{ $receiver }}.sql
		selector.Select(selector.Columns(columns...)...)
	}
	if {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique {
		selector.Distinct()
	}
	{{- /* Allow mutating the sql.Selector by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/selector/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	for _, p := range {{ $receiver }}.predicates {
		p(selector)
	}
	for _, p := range {{ $receiver }}.order {
		p(selector)
	}
	if offset := {{ $receiver }}.ctx.Offset; offset != nil {
		// limit is mandatory for offset clause. We start
		// with default value, and override it below if needed.
		selector.Offset(*offset).Limit(math.MaxInt32)
	}
	if limit := {{ $receiver }}.ctx.Limit; limit != nil {
		selector.Limit(*limit)
	}
	return selector
}
{{ end }}

{{/*
query/path defines the query generation for path of a given edge.
The generated code builds a sqlgraph.Step from the sqlQuery selector of the receiver
to the edge type, and assigns the result of sqlgraph.SetNeighbors to $.Scope.Ident.
*/}}
{{ define "dialect/sql/query/path" }}
	{{- $n := $ }} {{/* the node we start the query from. */}}
	{{- $e := $.Scope.Edge }} {{/* the edge we need to generate the path to. */}}
	{{- $ident := $.Scope.Ident -}}
	{{- $receiver := $.Scope.Receiver }}
	selector := {{ $receiver }}.sqlQuery(ctx)
	if err := selector.Err(); err != nil {
		return nil, err
	}
	step := sqlgraph.NewStep(
		sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ if $n.HasCompositeID }}{{ $e.ColumnConstant }}{{ else }}{{ $n.ID.Constant }}{{ end }}, selector),
		sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ if $e.Type.HasCompositeID }}{{ $e.Ref.ColumnConstant }}{{ else }}{{ $e.Type.ID.Constant }}{{ end }}),
		sqlgraph.Edge(sqlgraph.{{ $e.Rel.Type }}, {{ $e.IsInverse }}, {{ $n.Package }}.{{ $e.TableConstant }},
			{{- if $e.M2M -}}
				{{ $n.Package }}.{{ $e.PKConstant }}...
			{{- else -}}
				{{ $n.Package }}.{{ $e.ColumnConstant }}
			{{- end -}}
		),
	)
	{{- /* Allow mutating the sqlgraph.Step by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/path/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	{{ $ident }} = sqlgraph.SetNeighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}

{{/*
query/from defines the query generation for an edge query from a given node.
The generated code builds a sqlgraph.Step that starts from the ID of the receiver,
and assigns the result of sqlgraph.Neighbors to $.Scope.Ident.
*/}}
{{ define "dialect/sql/query/from" }}
	{{- $n := $ }} {{/* the node we start the query from. */}}
	{{- $e := $.Scope.Edge }} {{/* the edge we need to generate the path to. */}}
	{{- $ident := $.Scope.Ident -}}
	{{- $receiver := $.Scope.Receiver -}}
	id := {{ $receiver }}.ID
	step := sqlgraph.NewStep(
		sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, id),
		sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ if $e.Type.HasCompositeID }}{{ $e.Ref.ColumnConstant }}{{ else }}{{ $e.Type.ID.Constant }}{{ end }}),
		sqlgraph.Edge(sqlgraph.{{ $e.Rel.Type }}, {{ $e.IsInverse }}, {{ $n.Package }}.{{ $e.TableConstant }},
			{{- if $e.M2M -}}
				{{ $n.Package }}.{{ $e.PKConstant }}...
			{{- else -}}
				{{ $n.Package }}.{{ $e.ColumnConstant }}
			{{- end -}}
		),
	)
	{{- /* Allow mutating the sqlgraph.Step by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/from/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	{{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}

{{/*
eagerloading/m2massign renders the expression that extracts a scanned ID value from
$.Scope.Arg. If the scan type starts with "sql", the type assertion is piped through
$field.ScanTypeField. Otherwise, the asserted pointer is used, and it is dereferenced
unless the field is nillable.
*/}}
{{ define "dialect/sql/query/eagerloading/m2massign" }}
	{{- $arg := $.Scope.Arg }}
	{{- $field := $.Scope.Field }}
	{{- $scantype := $.Scope.ScanType }}
	{{- if hasPrefix $scantype "sql" -}}
		{{ printf "%s.(*%s)" $arg $scantype | $field.ScanTypeField -}}
	{{- else -}}
		{{ if not $field.Nillable }}*{{ end }}{{ printf "%s.(*%s)" $arg $scantype }}
	{{- end }}
{{- end }}

{{/*
query/preparecheck generates a loop over the selected fields (ctx.Fields) that returns
a ValidationError for the first field that is not a valid column of the type.
*/}}
{{ define "dialect/sql/query/preparecheck" }}
	{{- $pkg := $.Scope.Package }}
	{{- $receiver := $.Scope.Receiver }}
	for _, f := range {{ $receiver }}.ctx.Fields {
		if !{{ $.Package }}.ValidColumn(f) {
			return &ValidationError{Name: f, err: fmt.Errorf("{{ $pkg }}: invalid field %q for query", f)}
		}
	}
{{- end }}