Files
ent/entc/gen/template/dialect/sql/query.tmpl
2020-03-07 22:11:25 +02:00

334 lines
11 KiB
Cheetah

{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}
{{/* Extra fields the SQL dialect adds to the query builder: the withFKs flag, emitted only when the type has foreign-keys. */}}
{{ define "dialect/sql/query/fields" }}
{{- if $.ForeignKeys }}
withFKs bool
{{- end }}
{{- end }}
{{ define "dialect/sql/query" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
// sqlAll builds the query spec, runs it through sqlgraph.QueryNodes and returns
// the nodes collected by the scan callbacks. Every edge whose with<Edge> query is
// set on the builder is then loaded onto the returned nodes.
func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name }}, error) {
var (
// Starts as an empty, non-nil slice: a query without results returns
// an empty slice rather than nil.
nodes = []*{{ $.Name }}{}
{{- with $.ForeignKeys }}
withFKs = {{ $receiver }}.withFKs
{{- end }}
_spec = {{ $receiver }}.querySpec()
{{- with $.Edges }}
// One flag per edge: true when the matching with<Edge> query is set.
loadedTypes = [{{ len . }}]bool{
{{- range $e := . }}
{{ $receiver }}.with{{ pascal $e.Name }} != nil,
{{- end }}
}
{{- end }}
)
{{- with $.ForeignKeys }}
{{- with $.FKEdges }}
// Requesting any of these edges forces the foreign-key columns to be selected.
if {{ range $i, $e := . }}{{ if gt $i 0 }} || {{ end }}{{ $receiver }}.with{{ pascal $e.Name }} != nil{{ end }} {
withFKs = true
}
{{- end }}
if withFKs {
_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.ForeignKeys...)
}
{{- end }}
// ScanValues allocates the next node, records it in nodes and returns its
// scan values, followed by the foreign-key values when those are selected.
_spec.ScanValues = func() []interface{} {
node := &{{ $.Name }}{config: {{ $receiver }}.config}
nodes = append(nodes, node)
values := node.scanValues()
{{- with $.ForeignKeys }}
if withFKs {
values = append(values, node.fkValues()...)
}
{{- end }}
return values
}
// Assign populates the node most recently created by ScanValues.
_spec.Assign = func(values ...interface{}) error {
if len(nodes) == 0 {
return fmt.Errorf("{{ $pkg }}: Assign called without calling ScanValues")
}
node := nodes[len(nodes)-1]
{{- with $.Edges }}
node.Edges.loadedTypes = loadedTypes
{{- end }}
return node.assignValues(values...)
}
if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, _spec); err != nil {
return nil, err
}
// No rows matched, so there is nothing to eager-load.
if len(nodes) == 0 {
return nodes, nil
}
{{- range $e := $.Edges }}
{{- with extend $ "Rec" $receiver "Edge" $e }}
{{ template "dialect/sql/query/eagerloading" . }}
{{- end }}
{{- end }}
return nodes, nil
}
// sqlCount counts the nodes matching the query by passing the builder's
// query spec to sqlgraph.CountNodes.
func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
return sqlgraph.CountNodes(ctx, {{ $receiver }}.driver, {{ $receiver }}.querySpec())
}
// sqlExist reports whether at least one node matches the query.
// A failure of the underlying count is wrapped with %w, so the cause stays in
// the error chain and callers can test it with errors.Is / errors.As
// (for example against context.Canceled). The message text is unchanged.
func ({{ $receiver }} *{{ $builder }}) sqlExist(ctx context.Context) (bool, error) {
n, err := {{ $receiver }}.sqlCount(ctx)
if err != nil {
return false, fmt.Errorf("{{ $pkg }}: check existence: %w", err)
}
return n > 0, nil
}
// querySpec translates the builder state into a sqlgraph.QuerySpec: the node
// table, columns and id field of this type, the builder's source selector, and
// its limit, offset, predicates and ordering when they are set.
func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
_spec := &sqlgraph.QuerySpec{
Node: &sqlgraph.NodeSpec{
Table: {{ $.Package }}.Table,
Columns: {{ $.Package }}.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.{{ $.ID.Type.ConstName }},
Column: {{ $.Package }}.{{ $.ID.Constant }},
},
},
From: {{ $receiver }}.sql,
Unique: true,
}
if {{ $receiver }}.limit != nil {
_spec.Limit = *{{ $receiver }}.limit
}
if {{ $receiver }}.offset != nil {
_spec.Offset = *{{ $receiver }}.offset
}
// The callbacks close over the slices as they are now and apply them in order.
if wherePreds := {{ $receiver }}.predicates; len(wherePreds) != 0 {
_spec.Predicate = func(s *sql.Selector) {
for _, p := range wherePreds {
p(s)
}
}
}
if orderPreds := {{ $receiver }}.order; len(orderPreds) != 0 {
_spec.Order = func(s *sql.Selector) {
for _, p := range orderPreds {
p(s)
}
}
}
return _spec
}
// sqlQuery returns the selector for this query: either a fresh SELECT of this
// type's columns from its table, or the selector already stored on the builder,
// with the builder's predicates, ordering, offset and limit applied to it.
func ({{ $receiver }} *{{ $builder }}) sqlQuery() *sql.Selector {
builder := sql.Dialect({{ $receiver }}.driver.Dialect())
t1 := builder.Table({{ $.Package }}.Table)
selector := builder.Select(t1.Columns({{ $.Package }}.Columns...)...).From(t1)
if fromSel := {{ $receiver }}.sql; fromSel != nil {
// The builder already carries a selector: use it instead of the fresh
// one and call Select on it with this type's columns.
fromSel.Select(fromSel.Columns({{ $.Package }}.Columns...)...)
selector = fromSel
}
wherePreds := {{ $receiver }}.predicates
for i := range wherePreds {
wherePreds[i](selector)
}
orderPreds := {{ $receiver }}.order
for i := range orderPreds {
orderPreds[i](selector)
}
if {{ $receiver }}.offset != nil {
// An offset clause requires a limit. Start with a default one;
// it is overridden below when the builder has its own limit.
selector.Offset(*{{ $receiver }}.offset).Limit(math.MaxInt32)
}
if {{ $receiver }}.limit != nil {
selector.Limit(*{{ $receiver }}.limit)
}
return selector
}
{{ end }}
{{/*
query/path defines the query generation for the path of a given edge.
It emits a sqlgraph step that starts from this node's table and id column, seeded with
the selector returned by the receiver's sqlQuery, and ends at the table and id column of
the edge type. For M2M edges the primary-key columns of the edge are passed variadically,
otherwise its single column constant. The selector returned by sqlgraph.SetNeighbors is
stored in query.sql; the `query` variable is not declared here and must be in scope in the
code that includes this template. Expects $.Scope.Edge and $.Scope.Receiver.
*/}}
{{ define "dialect/sql/query/path" }}
{{- $n := $ }} {{/* the node the query starts from. */}}
{{- $e := $.Scope.Edge }} {{/* the edge to generate the path to. */}}
{{- $receiver := $.Scope.Receiver }}
step := sqlgraph.NewStep(
sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, {{ $receiver }}.sqlQuery()),
sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}),
sqlgraph.Edge(sqlgraph.{{ $e.Rel.Type }}, {{ $e.IsInverse }}, {{ $n.Package }}.{{ $e.TableConstant }},
{{- if $e.M2M -}}
{{ $n.Package }}.{{ $e.PKConstant }}...
{{- else -}}
{{ $n.Package }}.{{ $e.ColumnConstant }}
{{- end -}}
),
)
query.sql = sqlgraph.SetNeighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}
{{/*
query/from defines the query generation for an edge query from a given node.
It emits a sqlgraph step that starts from the id of the receiver node and ends at the
table and id column of the edge type. The id is read through the receiver's id() method
when the id is a string that is not user-defined, and from its ID field otherwise.
For M2M edges the primary-key columns of the edge are passed variadically, otherwise its
single column constant. The selector returned by sqlgraph.Neighbors is stored in
query.sql; the `query` variable is not declared here and must be in scope in the code
that includes this template. Expects $.Scope.Edge and $.Scope.Receiver.
*/}}
{{ define "dialect/sql/query/from" }}
{{- $n := $ }} {{/* the node the query starts from. */}}
{{- $e := $.Scope.Edge }} {{/* the edge to generate the path to. */}}
{{- $receiver := $.Scope.Receiver -}}
id := {{ $receiver }}.{{ if and $.ID.IsString (not $.ID.UserDefined) }}id(){{ else }}ID{{ end }}
step := sqlgraph.NewStep(
sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, id),
sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}),
sqlgraph.Edge(sqlgraph.{{ $e.Rel.Type }}, {{ $e.IsInverse }}, {{ $n.Package }}.{{ $e.TableConstant }},
{{- if $e.M2M -}}
{{ $n.Package }}.{{ $e.PKConstant }}...
{{- else -}}
{{ $n.Package }}.{{ $e.ColumnConstant }}
{{- end -}}
),
)
query.sql = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}
{{/*
query/eagerloading emits, for the edge in $.Scope.Edge, the code that loads the edge's
neighbors onto the already-fetched `nodes` when the builder's with<Edge> query is set
($.Scope.Rec is the builder receiver). Three layouts are handled:
  - M2M: the join table is queried for the (node, neighbor) id pairs of the fetched nodes,
    then the neighbors are fetched by id and appended to every node that references them.
  - the node holds the foreign-key ($e.OwnFK): the neighbors are fetched by the foreign-keys
    found on the nodes and assigned to every node that holds that key.
  - otherwise: the neighbors hold the foreign-key; they are fetched by the ids of the nodes
    (with their foreign-key columns selected) and attached to the node they point to.
In the first two layouts every neighbor id is put into the IDIn predicate only once, even
when several nodes share it. The emitted code uses `ctx` and `nodes` from the enclosing
function and returns (nil, err) on failure.
*/}}
{{ define "dialect/sql/query/eagerloading" }}
{{- $e := $.Scope.Edge }}
{{- $receiver := $.Scope.Rec }}
if query := {{ $receiver }}.with{{ pascal $e.Name }}; query != nil {
{{- if $e.M2M }}
fks := make([]driver.Value, 0, len(nodes))
ids := make(map[{{ $.ID.Type }}]*{{ $.Name }}, len(nodes))
for _, node := range nodes {
ids[node.ID] = node
fks = append(fks, node.ID)
}
var (
edgeids []{{ $e.Type.ID.Type }}
edges = make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
)
_spec := &sqlgraph.EdgeQuerySpec{
Edge: &sqlgraph.EdgeSpec{
Inverse: {{ $e.IsInverse }},
Table: {{ $.Package }}.{{ $e.TableConstant }},
Columns: {{ $.Package }}.{{ $e.PKConstant }},
},
Predicate: func(s *sql.Selector) {
s.Where(sql.InValues({{ $.Package }}.{{ $e.PKConstant }}[{{ if $e.IsInverse }}1{{ else }}0{{ end }}], fks...))
},
{{ $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.NullType }}{{ end }}
{{ $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.NullType }}{{ end }}
ScanValues: func() [2]interface{}{
return [2]interface{}{&{{ $out }}{}, &{{ $in }}{}}
},
Assign: func(out, in interface{}) error {
eout, ok := out.(*{{ $out }})
if !ok || eout == nil {
return fmt.Errorf("unexpected id value for edge-out")
}
ein, ok := in.(*{{ $in }})
if !ok || ein == nil {
return fmt.Errorf("unexpected id value for edge-in")
}
outValue := {{ with extend $ "Arg" "eout" "Field" $.ID "NullType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
inValue := {{ with extend $ "Arg" "ein" "Field" $e.Type.ID "NullType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
node, ok := ids[outValue]
if !ok {
return fmt.Errorf("unexpected node id in edges: %v", outValue)
}
// Several nodes may reference the same neighbor: list its id once for
// the IDIn predicate, while "edges" keeps every node that references it.
if _, ok := edges[inValue]; !ok {
edgeids = append(edgeids, inValue)
}
edges[inValue] = append(edges[inValue], node)
return nil
},
}
if err := sqlgraph.QueryEdges(ctx, {{ $receiver }}.driver, _spec); err != nil {
return nil, fmt.Errorf(`query edges "{{ $e.Name }}": %v`, err)
}
query.Where({{ $e.Type.Package }}.IDIn(edgeids...))
neighbors, err := query.All(ctx)
if err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := edges[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
}
for i := range nodes {
nodes[i].Edges.{{ $e.StructField }} = append(nodes[i].Edges.{{ $e.StructField }}, n)
}
}
{{- else if $e.OwnFK }}
ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes))
nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
for i := range nodes {
if fk := nodes[i].{{ $e.StructFKField }}; fk != nil {
// Many nodes may share a foreign-key: query each referenced id only once.
if _, ok := nodeids[*fk]; !ok {
ids = append(ids, *fk)
}
nodeids[*fk] = append(nodeids[*fk], nodes[i])
}
}
query.Where({{ $e.Type.Package }}.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected foreign-key "{{ $e.StructFKField }}" returned %v`, n.ID)
}
for i := range nodes {
nodes[i].Edges.{{ $e.StructField }} = n
}
}
{{- else }}
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }}, len(nodes))
for i := range nodes {
{{- /* Convert string-ids that are stored as int in the database */ -}}
{{- if and (not $.ID.UserDefined) $.ID.IsString }}
id, err := strconv.Atoi(nodes[i].ID)
if err != nil {
return nil, err
}
fks = append(fks, id)
{{- else }}
fks = append(fks, nodes[i].ID)
{{- end }}
nodeids[nodes[i].ID] = nodes[i]
}
query.withFKs = true
query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) {
s.Where(sql.InValues({{ $.Package }}.{{ $e.ColumnConstant }}, fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return nil, err
}
for _, n := range neighbors {
fk := n.{{ $e.StructFKField }}
if fk == nil {
return nil, fmt.Errorf(`foreign-key "{{ $e.StructFKField }}" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok {
return nil, fmt.Errorf(`unexpected foreign-key "{{ $e.StructFKField }}" returned %v for node %v`, *fk, n.ID)
}
node.Edges.{{ $e.StructField }} = {{ if $e.Unique }}n{{ else }}append(node.Edges.{{ $e.StructField }}, n){{ end }}
}
{{- end }}
}
{{ end }}
{{- /*
m2massign renders the Go expression that turns a scanned M2M edge id into the id value of
the field in $.Scope.Field. $.Scope.Arg is the name of the scanned variable and
$.Scope.NullType the name of its type:
  - string ids that are stored as int in the database are formatted from the Int64 value,
  - for other null types whose name starts with "sql", the expression comes from Field.NullTypeField,
  - for any other type the scanned variable itself is used, dereferenced unless the field is nillable.
*/ -}}
{{ define "dialect/sql/query/eagerloading/m2massign" }}
{{- $id := $.Scope.Field }}
{{- $scanned := $.Scope.Arg }}
{{- $null := $.Scope.NullType }}
{{- if and $id.IsString (not $id.UserDefined) -}}
strconv.FormatInt({{ $scanned }}.Int64, 10)
{{- else if hasPrefix $null "sql" -}}
{{ $id.NullTypeField $scanned -}}
{{- else -}}
{{ if not $id.Nillable }}*{{ end }}{{ $scanned }}
{{- end }}
{{- end }}