{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}

{{/* gotype: entgo.io/ent/entc/gen.typeScope */}}

{{/*
Additional fields for the builder.
Emits the "withFKs" field when the type has unexported foreign-keys, and then
executes every template that matches "dialect/sql/query/fields/additional/*".
*/}}
{{ define "dialect/sql/query/fields" }}
	{{- with $.UnexportedForeignKeys }}
		withFKs bool
	{{- end }}
	{{- with $tmpls := matchTemplate "dialect/sql/query/fields/additional/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
{{- end }}

{{/*
query generates the SQL methods of the query-builder: sqlAll, sqlCount, sqlExist
and querySpec. It also includes the "dialect/sql/query/selector" template and every
template that matches "dialect/sql/query/additional/*".
Scope values read here: Package and Builder.
*/}}
{{ define "dialect/sql/query" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}

func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name }}, error) {
	var (
		nodes = []*{{ $.Name }}{}
		{{- with $.UnexportedForeignKeys }}
			withFKs = {{ $receiver }}.withFKs
		{{- end }}
		_spec = {{ $receiver }}.querySpec()
		{{- with $.Edges }}
			loadedTypes = [{{ len . }}]bool{
				{{- range $e := . }}
					{{ $receiver }}.{{ $e.EagerLoadField }} != nil,
				{{- end }}
			}
		{{- end }}
	)
	{{- with $.UnexportedForeignKeys }}
		{{- /* Foreign-keys are selected when requested explicitly, or when one of the FK-edges is eager-loaded. */}}
		{{- with $.FKEdges }}
			if {{ range $i, $e := . }}{{ if gt $i 0 }} || {{ end }}{{ $receiver }}.{{ $e.EagerLoadField }} != nil{{ end }} {
				withFKs = true
			}
		{{- end }}
		if withFKs {
			_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.ForeignKeys...)
		}
	{{- end }}
	_spec.ScanValues = func(columns []string) ([]interface{}, error) {
		node := &{{ $.Name }}{config: {{ $receiver }}.config}
		nodes = append(nodes, node)
		return node.scanValues(columns)
	}
	_spec.Assign = func(columns []string, values []interface{}) error {
		if len(nodes) == 0 {
			return fmt.Errorf("{{ $pkg }}: Assign called without calling ScanValues")
		}
		{{- /* Assign operates on the node that was appended by the last ScanValues call. */}}
		node := nodes[len(nodes)-1]
		{{- with $.Edges }}
			node.Edges.loadedTypes = loadedTypes
		{{- end }}
		return node.assignValues(columns, values)
	}
	{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/spec/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, _spec); err != nil {
		return nil, err
	}
	if len(nodes) == 0 {
		return nodes, nil
	}
	{{- range $e := $.Edges }}
		{{- with extend $ "Rec" $receiver "Edge" $e }}
			{{ template "dialect/sql/query/eagerloading" . }}
		{{- end }}
	{{- end }}
	return nodes, nil
}

func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
	_spec := {{ $receiver }}.querySpec()
	{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/spec/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	return sqlgraph.CountNodes(ctx, {{ $receiver }}.driver, _spec)
}

func ({{ $receiver }} *{{ $builder }}) sqlExist(ctx context.Context) (bool, error) {
	n, err := {{ $receiver }}.sqlCount(ctx)
	if err != nil {
		return false, fmt.Errorf("{{ $pkg }}: check existence: %w", err)
	}
	return n > 0, nil
}

func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
	_spec := &sqlgraph.QuerySpec{
		Node: &sqlgraph.NodeSpec{
			Table: {{ $.Package }}.Table,
			Columns: {{ $.Package }}.Columns,
			ID: &sqlgraph.FieldSpec{
				Type: field.{{ $.ID.Type.ConstName }},
				Column: {{ $.Package }}.{{ $.ID.Constant }},
			},
		},
		From: {{ $receiver }}.sql,
		Unique: true,
	}
	if unique := {{ $receiver }}.unique; unique != nil {
		_spec.Unique = *unique
	}
	{{- /* When specific fields were selected, the ID column is always placed first and is not repeated. */}}
	if fields := {{ $receiver }}.fields; len(fields) > 0 {
		_spec.Node.Columns = make([]string, 0, len(fields))
		_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.{{ $.ID.Constant }})
		for i := range fields {
			if fields[i] != {{ $.Package }}.{{ $.ID.Constant }} {
				_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
			}
		}
	}
	if ps := {{ $receiver }}.predicates; len(ps) > 0 {
		_spec.Predicate = func(selector *sql.Selector) {
			for i := range ps {
				ps[i](selector)
			}
		}
	}
	if limit := {{ $receiver }}.limit; limit != nil {
		_spec.Limit = *limit
	}
	if offset := {{ $receiver }}.offset; offset != nil {
		_spec.Offset = *offset
	}
	if ps := {{ $receiver }}.order; len(ps) > 0 {
		_spec.Order = func(selector *sql.Selector) {
			for i := range ps {
				ps[i](selector)
			}
		}
	}
	return _spec
}

{{ template "dialect/sql/query/selector" $ }}

{{- /* Allow adding methods to the query-builder by ent extensions or user templates.*/}}
{{- with $tmpls := matchTemplate "dialect/sql/query/additional/*" }}
	{{- range $tmpl := $tmpls }}
		{{- xtemplate $tmpl $ }}
	{{- end }}
{{- end }}

{{ end }}

{{/*
query/selector generates the sqlQuery method, which translates the builder state
(selected fields, predicates, ordering, offset and limit) into an sql.Selector.
Scope values read here: Builder.
*/}}
{{ define "dialect/sql/query/selector" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}

func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Selector {
	builder := sql.Dialect({{ $receiver }}.driver.Dialect())
	t1 := builder.Table({{ $.Package }}.Table)
	columns := {{ $receiver }}.fields
	if len(columns) == 0 {
		columns = {{ $.Package }}.Columns
	}
	selector := builder.Select(t1.Columns(columns...)...).From(t1)
	if {{ $receiver }}.sql != nil {
		selector = {{ $receiver }}.sql
		selector.Select(selector.Columns(columns...)...)
	}
	{{- /*
		querySpec reads the builder's "unique" field, but the selector did not. Apply DISTINCT
		when the field was explicitly set to true, so both query paths agree. A nil value keeps
		the previous output unchanged.
	*/}}
	if {{ $receiver }}.unique != nil && *{{ $receiver }}.unique {
		selector.Distinct()
	}
	{{- /* Allow mutating the sql.Selector by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/selector/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	for _, p := range {{ $receiver }}.predicates {
		p(selector)
	}
	for _, p := range {{ $receiver }}.order {
		p(selector)
	}
	if offset := {{ $receiver }}.offset; offset != nil {
		// limit is mandatory for offset clause. We start
		// with default value, and override it below if needed.
		selector.Offset(*offset).Limit(math.MaxInt32)
	}
	if limit := {{ $receiver }}.limit; limit != nil {
		selector.Limit(*limit)
	}
	return selector
}
{{ end }}

{{/*
query/path defines the query generation for path of a given edge.
The step starts from the selector returned by sqlQuery, and the result of
sqlgraph.SetNeighbors is assigned to the identifier given in Scope.Ident.
Scope values read here: Edge, Ident and Receiver.
*/}}
{{ define "dialect/sql/query/path" }}
	{{- $n := $ }} {{/* the node we start the query from. */}}
	{{- $e := $.Scope.Edge }} {{/* the edge we need to generate the path to. */}}
	{{- $ident := $.Scope.Ident -}}
	{{- $receiver := $.Scope.Receiver }}
	selector := {{ $receiver }}.sqlQuery(ctx)
	if err := selector.Err(); err != nil {
		return nil, err
	}
	step := sqlgraph.NewStep(
		sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, selector),
		sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}),
		sqlgraph.Edge(sqlgraph.{{ $e.Rel.Type }}, {{ $e.IsInverse }}, {{ $n.Package }}.{{ $e.TableConstant }},
			{{- if $e.M2M -}}
				{{ $n.Package }}.{{ $e.PKConstant }}...
			{{- else -}}
				{{ $n.Package }}.{{ $e.ColumnConstant }}
			{{- end -}}
		),
	)
	{{- /* Allow mutating the sqlgraph.Step by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/path/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	{{ $ident }} = sqlgraph.SetNeighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}

{{/*
query/from defines the query generation for an edge query from a given node.
The step starts from the ID of the receiver, and the result of
sqlgraph.Neighbors is assigned to the identifier given in Scope.Ident.
Scope values read here: Edge, Ident and Receiver.
*/}}
{{ define "dialect/sql/query/from" }}
	{{- $n := $ }} {{/* the node we start the query from. */}}
	{{- $e := $.Scope.Edge }} {{/* the edge we need to generate the path to. */}}
	{{- $ident := $.Scope.Ident -}}
	{{- $receiver := $.Scope.Receiver -}}
	id := {{ $receiver }}.ID
	step := sqlgraph.NewStep(
		sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, id),
		sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}),
		sqlgraph.Edge(sqlgraph.{{ $e.Rel.Type }}, {{ $e.IsInverse }}, {{ $n.Package }}.{{ $e.TableConstant }},
			{{- if $e.M2M -}}
				{{ $n.Package }}.{{ $e.PKConstant }}...
			{{- else -}}
				{{ $n.Package }}.{{ $e.ColumnConstant }}
			{{- end -}}
		),
	)
	{{- /* Allow mutating the sqlgraph.Step by ent extensions or user templates.*/}}
	{{- with $tmpls := matchTemplate "dialect/sql/query/from/*" }}
		{{- range $tmpl := $tmpls }}
			{{- xtemplate $tmpl $ }}
		{{- end }}
	{{- end }}
	{{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}

{{/*
query/eagerloading generates the code that loads the neighbors of a single edge
for the "nodes" that were already scanned by sqlAll. It is executed inside sqlAll,
and therefore relies on its "nodes" and "ctx" variables. One of 3 strategies is used:

  - M2M: the join-table is queried using sqlgraph.QueryEdges to map neighbor-ids
    to nodes, and then the neighbors are loaded using the IDIn predicate.
  - OwnFK: the foreign-key values are collected from the nodes, and the neighbors
    are loaded using the IDIn predicate.
  - Otherwise: the neighbors hold the foreign-key, and they are loaded by matching
    their foreign-key column against the ids of the nodes.

Scope values read here: Edge and Rec (the receiver name).
*/}}
{{ define "dialect/sql/query/eagerloading" }}
	{{- $e := $.Scope.Edge }}
	{{- $receiver := $.Scope.Rec }}
	if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
		{{- if $e.M2M }}
			fks := make([]driver.Value, 0, len(nodes))
			ids := make(map[{{ $.ID.Type }}]*{{ $.Name }}, len(nodes))
			for _, node := range nodes {
				ids[node.ID] = node
				fks = append(fks, node.ID)
				node.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{}
			}
			var (
				edgeids []{{ $e.Type.ID.Type }}
				edges = make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
			)
			_spec := &sqlgraph.EdgeQuerySpec{
				Edge: &sqlgraph.EdgeSpec{
					Inverse: {{ $e.IsInverse }},
					Table:   {{ $.Package }}.{{ $e.TableConstant }},
					Columns: {{ $.Package }}.{{ $e.PKConstant }},
				},
				Predicate: func(s *sql.Selector) {
					s.Where(sql.InValues({{ $.Package }}.{{ $e.PKConstant }}[{{ if $e.IsInverse }}1{{ else }}0{{ end }}], fks...))
				},
				{{- /* sql.NullInt64 is the scan type of both columns, unless the id of the type is user-defined. */}}
				{{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }}
				{{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }}
				ScanValues: func() [2]interface{}{
					return [2]interface{}{new({{ $out }}), new({{ $in }})}
				},
				Assign: func(out, in interface{}) error {
					eout, ok := out.(*{{ $out }})
					if !ok || eout == nil {
						return fmt.Errorf("unexpected id value for edge-out")
					}
					ein, ok := in.(*{{ $in }})
					if !ok || ein == nil {
						return fmt.Errorf("unexpected id value for edge-in")
					}
					outValue := {{ with extend $ "Arg" "eout" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
					inValue := {{ with extend $ "Arg" "ein" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
					node, ok := ids[outValue]
					if !ok {
						return fmt.Errorf("unexpected node id in edges: %v", outValue)
					}
					if _, ok := edges[inValue]; !ok {
						edgeids = append(edgeids, inValue)
					}
					edges[inValue] = append(edges[inValue], node)
					return nil
				},
			}
			{{- /* Allow mutating the sqlgraph.EdgeQuerySpec by ent extensions or user templates.*/}}
			{{- with $tmpls := matchTemplate "dialect/sql/query/eagerloading/spec/*" }}
				{{- range $tmpl := $tmpls }}
					{{- xtemplate $tmpl $ }}
				{{- end }}
			{{- end }}
			if err := sqlgraph.QueryEdges(ctx, {{ $receiver }}.driver, _spec); err != nil {
				return nil, fmt.Errorf(`query edges "{{ $e.Name }}": %w`, err)
			}
			query.Where({{ $e.Type.Package }}.IDIn(edgeids...))
			neighbors, err := query.All(ctx)
			if err != nil {
				return nil, err
			}
			for _, n := range neighbors {
				nodes, ok := edges[n.ID]
				if !ok {
					return nil, fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
				}
				for i := range nodes {
					nodes[i].Edges.{{ $e.StructField }} = append(nodes[i].Edges.{{ $e.StructField }}, n)
				}
			}
		{{- else if $e.OwnFK }}
			ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes))
			nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
			for i := range nodes {
				{{- $fk := $e.ForeignKey }}
				{{- if $fk.Field.Nillable }}
					if nodes[i].{{ $fk.StructField }} == nil {
						continue
					}
				{{- end }}
				fk := {{ if $fk.Field.Nillable }}*{{ end }}nodes[i].{{ $fk.StructField }}
				if _, ok := nodeids[fk]; !ok {
					ids = append(ids, fk)
				}
				nodeids[fk] = append(nodeids[fk], nodes[i])
			}
			query.Where({{ $e.Type.Package }}.IDIn(ids...))
			neighbors, err := query.All(ctx)
			if err != nil {
				return nil, err
			}
			for _, n := range neighbors {
				nodes, ok := nodeids[n.ID]
				if !ok {
					return nil, fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v`, n.ID)
				}
				for i := range nodes {
					nodes[i].Edges.{{ $e.StructField }} = n
				}
			}
		{{- else }}
			fks := make([]driver.Value, 0, len(nodes))
			{{- /* One entry is inserted per node, so the map is pre-sized like the one in the M2M branch. */}}
			nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }}, len(nodes))
			for i := range nodes {
				fks = append(fks, nodes[i].ID)
				nodeids[nodes[i].ID] = nodes[i]
				{{- if $e.O2M }}
					nodes[i].Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{}
				{{- end }}
			}
			{{- with $e.Type.UnexportedForeignKeys }}
				query.withFKs = true
			{{- end }}
			query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) {
				s.Where(sql.InValues({{ $.Package }}.{{ $e.ColumnConstant }}, fks...))
			}))
			neighbors, err := query.All(ctx)
			if err != nil {
				return nil, err
			}
			for _, n := range neighbors {
				{{- $fk := $e.ForeignKey }}
				fk := n.{{ $fk.StructField }}
				{{- if $fk.Field.Nillable }}
					if fk == nil {
						return nil, fmt.Errorf(`foreign-key "{{ $fk.Field.Name }}" is nil for node %v`, n.ID)
					}
				{{- end }}
				node, ok := nodeids[{{ if $fk.Field.Nillable }}*{{ end }}fk]
				if !ok {
					return nil, fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v for node %v`, {{ if $fk.Field.Nillable }}*{{ end }}fk, n.ID)
				}
				node.Edges.{{ $e.StructField }} = {{ if $e.Unique }}n{{ else }}append(node.Edges.{{ $e.StructField }}, n){{ end }}
			}
		{{- end }}
	}
{{ end }}

{{/*
query/eagerloading/m2massign emits the expression that extracts an id value from a
scanned join-table column. If the scan type starts with "sql", the expression is
generated by the ScanTypeField method of the field. Otherwise, the argument itself
is emitted, and it is dereferenced if the field is not nillable.
Scope values read here: Arg, Field and ScanType.
*/}}
{{ define "dialect/sql/query/eagerloading/m2massign" }}
	{{- $arg := $.Scope.Arg }}
	{{- $field := $.Scope.Field }}
	{{- $scantype := $.Scope.ScanType }}
	{{- if hasPrefix $scantype "sql" -}}
		{{ $field.ScanTypeField $arg -}}
	{{- else -}}
		{{ if not $field.Nillable }}*{{ end }}{{ $arg }}
	{{- end }}
{{- end }}

{{/*
query/preparecheck generates the validation of the selected fields: a ValidationError
is returned for the first field that is rejected by the ValidColumn function of the package.
Scope values read here: Package and Receiver.
*/}}
{{ define "dialect/sql/query/preparecheck" }}
	{{- $pkg := $.Scope.Package }}
	{{- $receiver := $.Scope.Receiver }}
	for _, f := range {{ $receiver }}.fields {
		if !{{ $.Package }}.ValidColumn(f) {
			return &ValidationError{Name: f, err: fmt.Errorf("{{ $pkg }}: invalid field %q for query", f)}
		}
	}
{{- end }}