mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
entc/gen: move eager-loading to method (#2790)
This is a preparation work for 'WithNamed<E>' API
This commit is contained in:
@@ -75,9 +75,13 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
return nodes, nil
|
||||
}
|
||||
{{- range $e := $.Edges }}
|
||||
{{- with extend $ "Rec" $receiver "Edge" $e }}
|
||||
{{ template "dialect/sql/query/eagerloading" . }}
|
||||
{{- end }}
|
||||
if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
|
||||
if err := {{ $receiver }}.load{{ $e.StructField }}(ctx, query, nodes, {{ if and (not $e.M2M) (not $e.O2M) }}nil{{ else }}
|
||||
func(n *{{ $.Name }}){ n.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{} }{{ end }},
|
||||
func(n *{{ $.Name }}, e *{{ $e.Type.Name }}){ n.Edges.{{ $e.StructField }} = {{ if or $e.OwnFK $e.Unique }}e{{ else }}append(n.Edges.{{ $e.StructField }}, e){{ end }} }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
{{- end }}
|
||||
{{- /* Allow extensions to inject code using templates to process nodes before they are returned. */}}
|
||||
{{- with $tmpls := matchTemplate "dialect/sql/query/all/nodes/*" }}
|
||||
@@ -88,6 +92,137 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
{{/* Generate a method to eager-load each edge. */}}
|
||||
{{- range $e := $.Edges }}
|
||||
func ({{ $receiver }} *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error {
|
||||
{{- if $e.M2M }}
|
||||
edgeIDs := make([]driver.Value, len(nodes))
|
||||
byID := make(map[{{ $.ID.Type }}]*{{ $.Name }})
|
||||
nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeIDs[i] = node.ID
|
||||
byID[node.ID] = node
|
||||
if init != nil {
|
||||
init(node)
|
||||
}
|
||||
}
|
||||
query.Where(func(s *sql.Selector) {
|
||||
joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }})
|
||||
{{- $edgeid := print $e.Type.Package "." $e.Type.ID.Constant }}
|
||||
{{- $fk1idx := 1 }}{{- $fk2idx := 0 }}{{ if $e.IsInverse }}{{ $fk1idx = 0 }}{{ $fk2idx = 1 }}{{ end }}
|
||||
s.Join(joinT).On(s.C({{ $edgeid }}), joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk1idx }}]))
|
||||
s.Where(sql.InValues(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]), edgeIDs...))
|
||||
columns := s.SelectedColumns()
|
||||
s.Select(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]))
|
||||
s.AppendSelect(columns...)
|
||||
s.SetDistinct(false)
|
||||
})
|
||||
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
|
||||
assign := spec.Assign
|
||||
values := spec.ScanValues
|
||||
{{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }}
|
||||
{{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }}
|
||||
spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
values, err := values(columns[1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append([]interface{}{new({{ $out }})}, values...), nil
|
||||
}
|
||||
spec.Assign = func(columns []string, values []interface{}) error {
|
||||
outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*{{ $.Name }}]struct{}{byID[outValue]: struct{}{}}
|
||||
return assign(columns[1:], values[1:])
|
||||
}
|
||||
nids[inValue][byID[outValue]] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nids[n.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
|
||||
}
|
||||
for kn := range nodes {
|
||||
assign(kn, n)
|
||||
}
|
||||
}
|
||||
{{- else if $e.OwnFK }}
|
||||
ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes))
|
||||
nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
|
||||
for i := range nodes {
|
||||
{{- $fk := $e.ForeignKey }}
|
||||
{{- if $fk.Field.Nillable }}
|
||||
if nodes[i].{{ $fk.StructField }} == nil {
|
||||
continue
|
||||
}
|
||||
{{- end }}
|
||||
fk := {{ if $fk.Field.Nillable }}*{{ end }}nodes[i].{{ $fk.StructField }}
|
||||
if _, ok := nodeids[fk]; !ok {
|
||||
ids = append(ids, fk)
|
||||
}
|
||||
nodeids[fk] = append(nodeids[fk], nodes[i])
|
||||
}
|
||||
query.Where({{ $e.Type.Package }}.IDIn(ids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nodeids[n.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
assign(nodes[i], n)
|
||||
}
|
||||
}
|
||||
{{- else }}
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }})
|
||||
for i := range nodes {
|
||||
fks = append(fks, nodes[i].ID)
|
||||
nodeids[nodes[i].ID] = nodes[i]
|
||||
{{- if $e.O2M }}
|
||||
if init != nil {
|
||||
init(nodes[i])
|
||||
}
|
||||
{{- end }}
|
||||
}
|
||||
{{- with $e.Type.UnexportedForeignKeys }}
|
||||
query.withFKs = true
|
||||
{{- end }}
|
||||
query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) {
|
||||
s.Where(sql.InValues({{ $.Package }}.{{ $e.ColumnConstant }}, fks...))
|
||||
}))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
{{- $fk := $e.ForeignKey }}
|
||||
fk := n.{{ $fk.StructField }}
|
||||
{{- if $fk.Field.Nillable }}
|
||||
if fk == nil {
|
||||
return fmt.Errorf(`foreign-key "{{ $fk.Field.Name }}" is nil for node %v`, n.ID)
|
||||
}
|
||||
{{- end }}
|
||||
node, ok := nodeids[{{ if $fk.Field.Nillable }}*{{ end }}fk]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v for node %v`, {{ if $fk.Field.Nillable }}*{{ end }}fk, n{{ if $e.Type.HasOneFieldID }}.ID{{ end }})
|
||||
}
|
||||
assign(node, n)
|
||||
}
|
||||
{{- end }}
|
||||
return nil
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := {{ $receiver }}.querySpec()
|
||||
{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates. */}}
|
||||
@@ -286,133 +421,6 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
{{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
|
||||
{{ end }}
|
||||
|
||||
{{ define "dialect/sql/query/eagerloading" }}
|
||||
{{- $e := $.Scope.Edge }}
|
||||
{{- $receiver := $.Scope.Rec }}
|
||||
if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
|
||||
{{- if $e.M2M }}
|
||||
edgeids := make([]driver.Value, len(nodes))
|
||||
byid := make(map[{{ $.ID.Type }}]*{{ $.Name }})
|
||||
nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeids[i] = node.ID
|
||||
byid[node.ID] = node
|
||||
node.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{}
|
||||
}
|
||||
query.Where(func(s *sql.Selector) {
|
||||
joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }})
|
||||
{{- $edgeid := print $e.Type.Package "." $e.Type.ID.Constant }}
|
||||
{{- $fk1idx := 1 }}{{- $fk2idx := 0 }}{{ if $e.IsInverse }}{{ $fk1idx = 0 }}{{ $fk2idx = 1 }}{{ end }}
|
||||
s.Join(joinT).On(s.C({{ $edgeid }}), joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk1idx }}]))
|
||||
s.Where(sql.InValues(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]), edgeids...))
|
||||
columns := s.SelectedColumns()
|
||||
s.Select(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]))
|
||||
s.AppendSelect(columns...)
|
||||
s.SetDistinct(false)
|
||||
})
|
||||
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
|
||||
assign := spec.Assign
|
||||
values := spec.ScanValues
|
||||
{{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }}
|
||||
{{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }}
|
||||
spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
values, err := values(columns[1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append([]interface{}{new({{ $out }})}, values...), nil
|
||||
}
|
||||
spec.Assign = func(columns []string, values []interface{}) error {
|
||||
outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*{{ $.Name }}]struct{}{byid[outValue]: struct{}{}}
|
||||
return assign(columns[1:], values[1:])
|
||||
}
|
||||
nids[inValue][byid[outValue]] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nids[n.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
|
||||
}
|
||||
for kn := range nodes {
|
||||
kn.Edges.{{ $e.StructField }} = append(kn.Edges.{{ $e.StructField }}, n)
|
||||
}
|
||||
}
|
||||
{{- else if $e.OwnFK }}
|
||||
ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes))
|
||||
nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
|
||||
for i := range nodes {
|
||||
{{- $fk := $e.ForeignKey }}
|
||||
{{- if $fk.Field.Nillable }}
|
||||
if nodes[i].{{ $fk.StructField }} == nil {
|
||||
continue
|
||||
}
|
||||
{{- end }}
|
||||
fk := {{ if $fk.Field.Nillable }}*{{ end }}nodes[i].{{ $fk.StructField }}
|
||||
if _, ok := nodeids[fk]; !ok {
|
||||
ids = append(ids, fk)
|
||||
}
|
||||
nodeids[fk] = append(nodeids[fk], nodes[i])
|
||||
}
|
||||
query.Where({{ $e.Type.Package }}.IDIn(ids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nodeids[n.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
nodes[i].Edges.{{ $e.StructField }} = n
|
||||
}
|
||||
}
|
||||
{{- else }}
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }})
|
||||
for i := range nodes {
|
||||
fks = append(fks, nodes[i].ID)
|
||||
nodeids[nodes[i].ID] = nodes[i]
|
||||
{{- if $e.O2M }}
|
||||
nodes[i].Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{}
|
||||
{{- end }}
|
||||
}
|
||||
{{- with $e.Type.UnexportedForeignKeys }}
|
||||
query.withFKs = true
|
||||
{{- end }}
|
||||
query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) {
|
||||
s.Where(sql.InValues({{ $.Package }}.{{ $e.ColumnConstant }}, fks...))
|
||||
}))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
{{- $fk := $e.ForeignKey }}
|
||||
fk := n.{{ $fk.StructField }}
|
||||
{{- if $fk.Field.Nillable }}
|
||||
if fk == nil {
|
||||
return nil, fmt.Errorf(`foreign-key "{{ $fk.Field.Name }}" is nil for node %v`, n.ID)
|
||||
}
|
||||
{{- end }}
|
||||
node, ok := nodeids[{{ if $fk.Field.Nillable }}*{{ end }}fk]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v for node %v`, {{ if $fk.Field.Nillable }}*{{ end }}fk, n{{ if $e.Type.HasOneFieldID }}.ID{{ end }})
|
||||
}
|
||||
node.Edges.{{ $e.StructField }} = {{ if $e.Unique }}n{{ else }}append(node.Edges.{{ $e.StructField }}, n){{ end }}
|
||||
}
|
||||
{{- end }}
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
{{ define "dialect/sql/query/eagerloading/m2massign" }}
|
||||
{{- $arg := $.Scope.Arg }}
|
||||
{{- $field := $.Scope.Field }}
|
||||
|
||||
Reference in New Issue
Block a user