mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
entc/gen: add eager-loading support (#263)
* entc/gen: add OwnFK indicator for type edges * entc/gen: add Edges field for generated types * entc/gen: add With<T> method to query-builder template * entc/gen: scan and assign foreign-keys on eager-loading * entc/gen: load fk-relations (wip) * entc/integration: add o2m/m2o tests for eager-loading * entc/gen: add m2m support for eager-loading * entc/gen: add integration tests for m2m and subgraphs * entc/gen/integration: add tests for o2o eager-loading * all: generate all assets
This commit is contained in:
@@ -4,6 +4,13 @@ This source code is licensed under the Apache 2.0 license found
|
||||
in the LICENSE file in the root directory of this source tree.
|
||||
*/}}
|
||||
|
||||
{{/* Additional fields for the builder. */}}
|
||||
{{ define "dialect/sql/query/fields" }}
|
||||
{{- with $.ForeignKeys }}
|
||||
withFKs bool
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
{{ define "dialect/sql/query" }}
|
||||
{{ $pkg := $.Scope.Package }}
|
||||
{{ $builder := pascal $.Scope.Builder }}
|
||||
@@ -12,12 +19,31 @@ in the LICENSE file in the root directory of this source tree.
|
||||
func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name }}, error) {
|
||||
var (
|
||||
nodes []*{{ $.Name }}
|
||||
{{- with $.ForeignKeys }}
|
||||
withFKs = {{ $receiver }}.withFKs
|
||||
{{- end }}
|
||||
spec = {{ $receiver }}.querySpec()
|
||||
)
|
||||
{{- with $.ForeignKeys }}
|
||||
{{- with $.FKEdges }}
|
||||
if {{ range $i, $e := . }}{{ if gt $i 0 }} || {{ end }}{{ $receiver }}.with{{ pascal $e.Name }} != nil{{ end }} {
|
||||
withFKs = true
|
||||
}
|
||||
{{- end }}
|
||||
if withFKs {
|
||||
spec.Node.Columns = append(spec.Node.Columns, {{ $.Package }}.ForeignKeys...)
|
||||
}
|
||||
{{- end }}
|
||||
spec.ScanValues = func() []interface{} {
|
||||
node := &{{ $.Name }}{config: {{ $receiver }}.config}
|
||||
nodes = append(nodes, node)
|
||||
return node.scanValues()
|
||||
values := node.scanValues()
|
||||
{{- with $.ForeignKeys }}
|
||||
if withFKs {
|
||||
values = append(values, node.fkValues()...)
|
||||
}
|
||||
{{- end }}
|
||||
return values
|
||||
}
|
||||
spec.Assign = func(values ...interface{}) error {
|
||||
if len(nodes) == 0 {
|
||||
@@ -29,6 +55,11 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name
|
||||
if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
{{- range $e := $.Edges }}
|
||||
{{- with extend $ "Rec" $receiver "Edge" $e }}
|
||||
{{ template "dialect/sql/query/eagerloading" . }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
@@ -145,3 +176,145 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery() *sql.Selector {
|
||||
)
|
||||
query.sql = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
|
||||
{{ end }}
|
||||
|
||||
{{ define "dialect/sql/query/eagerloading" }}
|
||||
{{- $e := $.Scope.Edge }}
|
||||
{{- $receiver := $.Scope.Rec }}
|
||||
if query := {{ $receiver }}.with{{ pascal $e.Name }}; query != nil {
|
||||
{{- if $e.M2M }}
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
ids := make(map[{{ $.ID.Type }}]*{{ $.Name }}, len(nodes))
|
||||
for _, node := range nodes {
|
||||
ids[node.ID] = node
|
||||
fks = append(fks, node.ID)
|
||||
}
|
||||
var (
|
||||
edgeids []{{ $e.Type.ID.Type }}
|
||||
edges = make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
|
||||
)
|
||||
spec := &sqlgraph.EdgeQuerySpec{
|
||||
Edge: &sqlgraph.EdgeSpec{
|
||||
Inverse: {{ $e.IsInverse }},
|
||||
Table: {{ $.Package }}.{{ $e.TableConstant }},
|
||||
Columns: {{ $.Package }}.{{ $e.PKConstant }},
|
||||
},
|
||||
Predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.InValues({{ $.Package }}.{{ $e.PKConstant }}[{{ if $e.IsInverse }}1{{ else }}0{{ end }}], fks...))
|
||||
},
|
||||
{{ $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.NullType }}{{ end }}
|
||||
{{ $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.NullType }}{{ end }}
|
||||
ScanValues: func() [2]interface{}{
|
||||
return [2]interface{}{&{{ $out }}{}, &{{ $in }}{}}
|
||||
},
|
||||
Assign: func(out, in interface{}) error {
|
||||
eout, ok := out.(*{{ $out }})
|
||||
if !ok || eout == nil {
|
||||
return fmt.Errorf("unexpected id value for edge-out")
|
||||
}
|
||||
ein, ok := in.(*{{ $in }})
|
||||
if !ok || ein == nil {
|
||||
return fmt.Errorf("unexpected id value for edge-in")
|
||||
}
|
||||
outValue := {{ with extend $ "Arg" "eout" "Field" $.ID "NullType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
inValue := {{ with extend $ "Arg" "ein" "Field" $e.Type.ID "NullType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
node, ok := ids[outValue]
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected node id in edges: %v", outValue)
|
||||
}
|
||||
edgeids = append(edgeids, inValue)
|
||||
edges[inValue] = append(edges[inValue], node)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if err := sqlgraph.QueryEdges(ctx, {{ $receiver }}.driver, spec); err != nil {
|
||||
return nil, fmt.Errorf(`query edges "{{ $e.Name }}": %v`, err)
|
||||
}
|
||||
query.Where({{ $e.Type.Package }}.IDIn(edgeids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := edges[n.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
nodes[i].Edges.{{ $e.StructField }} = append(nodes[i].Edges.{{ $e.StructField }}, n)
|
||||
}
|
||||
}
|
||||
{{- else if $e.OwnFK }}
|
||||
ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes))
|
||||
nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
|
||||
for i := range nodes {
|
||||
if fk := nodes[i].{{ $e.StructFKField }}; fk != nil {
|
||||
ids = append(ids, *fk)
|
||||
nodeids[*fk] = append(nodeids[*fk], nodes[i])
|
||||
}
|
||||
}
|
||||
query.Where({{ $e.Type.Package }}.IDIn(ids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nodeids[n.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected foreign-key "{{ $e.StructFKField }}" returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
nodes[i].Edges.{{ $e.StructField }} = n
|
||||
}
|
||||
}
|
||||
{{- else }}
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }})
|
||||
for i := range nodes {
|
||||
{{- /* Convert string-ids that are stored as int in the database */ -}}
|
||||
{{- if and (not $.ID.UserDefined) $.ID.IsString }}
|
||||
id, err := strconv.Atoi(nodes[i].ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fks = append(fks, id)
|
||||
{{- else }}
|
||||
fks = append(fks, nodes[i].ID)
|
||||
{{- end }}
|
||||
nodeids[nodes[i].ID] = nodes[i]
|
||||
}
|
||||
query.withFKs = true
|
||||
query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) {
|
||||
s.Where(sql.InValues({{ $.Package }}.{{ $e.ColumnConstant }}, fks...))
|
||||
}))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
fk := n.{{ $e.StructFKField }}
|
||||
if fk == nil {
|
||||
return nil, fmt.Errorf(`foreign-key "{{ $e.StructFKField }}" is nil for node %v`, n.ID)
|
||||
}
|
||||
node, ok := nodeids[*fk]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected foreign-key "{{ $e.StructFKField }}" returned %v for node %v`, *fk, n.ID)
|
||||
}
|
||||
node.Edges.{{ $e.StructField }} = {{ if $e.Unique }}n{{ else }}append(node.Edges.{{ $e.StructField }}, n){{ end }}
|
||||
}
|
||||
{{- end }}
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
{{- /* Convert string-ids that are stored as int in the database */ -}}
|
||||
{{ define "dialect/sql/query/eagerloading/m2massign" }}
|
||||
{{- $arg := $.Scope.Arg }}
|
||||
{{- $field := $.Scope.Field }}
|
||||
{{- $nulltype := $.Scope.NullType }}
|
||||
{{- if and (not $field.UserDefined) $field.IsString -}}
|
||||
strconv.FormatInt({{ $arg }}.Int64, 10)
|
||||
{{- else if hasPrefix $nulltype "sql" -}}
|
||||
{{ $field.NullTypeField "eout" -}}
|
||||
{{- else -}}
|
||||
{{ $arg }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
Reference in New Issue
Block a user