mirror of
https://github.com/ent/ent.git
synced 2026-04-28 05:30:56 +03:00
entc/gen: use join for loading m2m relationship (#2417)
* entc/gen: use join for m2m relationship * entc/gen: add test for eager-load inverse-m2m
This commit is contained in:
@@ -23,7 +23,7 @@ in the LICENSE file in the root directory of this source tree.
|
||||
{{ $builder := pascal $.Scope.Builder }}
|
||||
{{ $receiver := receiver $builder }}
|
||||
|
||||
func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name }}, error) {
|
||||
func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
|
||||
var (
|
||||
nodes = []*{{ $.Name }}{}
|
||||
{{- with $.UnexportedForeignKeys }}
|
||||
@@ -49,15 +49,11 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name
|
||||
}
|
||||
{{- end }}
|
||||
_spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
node := &{{ $.Name }}{config: {{ $receiver }}.config}
|
||||
nodes = append(nodes, node)
|
||||
return node.scanValues(columns)
|
||||
return (*{{ $.Name }}).scanValues(nil, columns)
|
||||
}
|
||||
_spec.Assign = func(columns []string, values []interface{}) error {
|
||||
if len(nodes) == 0 {
|
||||
return fmt.Errorf("{{ $pkg }}: Assign called without calling ScanValues")
|
||||
}
|
||||
node := nodes[len(nodes)-1]
|
||||
node := &{{ $.Name }}{config: {{ $receiver }}.config}
|
||||
nodes = append(nodes, node)
|
||||
{{- with $.Edges }}
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
{{- end }}
|
||||
@@ -69,6 +65,9 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name
|
||||
{{- xtemplate $tmpl $ }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -272,74 +271,58 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
{{- $receiver := $.Scope.Rec }}
|
||||
if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
|
||||
{{- if $e.M2M }}
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
ids := make(map[{{ $.ID.Type }}]*{{ $.Name }}, len(nodes))
|
||||
for _, node := range nodes {
|
||||
ids[node.ID] = node
|
||||
fks = append(fks, node.ID)
|
||||
edgeids := make([]driver.Value, len(nodes))
|
||||
byid := make(map[{{ $.ID.Type }}]*{{ $.Name }})
|
||||
nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeids[i] = node.ID
|
||||
byid[node.ID] = node
|
||||
node.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{}
|
||||
}
|
||||
var (
|
||||
edgeids []{{ $e.Type.ID.Type }}
|
||||
edges = make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
|
||||
)
|
||||
_spec := &sqlgraph.EdgeQuerySpec{
|
||||
Edge: &sqlgraph.EdgeSpec{
|
||||
Inverse: {{ $e.IsInverse }},
|
||||
Table: {{ $.Package }}.{{ $e.TableConstant }},
|
||||
Columns: {{ $.Package }}.{{ $e.PKConstant }},
|
||||
},
|
||||
Predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.InValues({{ $.Package }}.{{ $e.PKConstant }}[{{ if $e.IsInverse }}1{{ else }}0{{ end }}], fks...))
|
||||
},
|
||||
query.Where(func(s *sql.Selector) {
|
||||
joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }})
|
||||
{{- $edgeid := print $e.Type.Package "." $e.Type.ID.Constant }}
|
||||
{{- $fk1idx := 1 }}{{- $fk2idx := 0 }}{{ if $e.IsInverse }}{{ $fk1idx = 0 }}{{ $fk2idx = 1 }}{{ end }}
|
||||
s.Join(joinT).On(s.C({{ $edgeid }}), joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk1idx }}]))
|
||||
s.Where(sql.InValues(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]), edgeids...))
|
||||
columns := s.SelectedColumns()
|
||||
s.Select(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]))
|
||||
s.AppendSelect(columns...)
|
||||
s.SetDistinct(false)
|
||||
})
|
||||
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
|
||||
assign := spec.Assign
|
||||
values := spec.ScanValues
|
||||
{{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }}
|
||||
{{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }}
|
||||
ScanValues: func() [2]interface{}{
|
||||
return [2]interface{}{new({{ $out }}), new({{ $in }})}
|
||||
},
|
||||
Assign: func(out, in interface{}) error {
|
||||
eout, ok := out.(*{{ $out }})
|
||||
if !ok || eout == nil {
|
||||
return fmt.Errorf("unexpected id value for edge-out")
|
||||
spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
values, err := values(columns[1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ein, ok := in.(*{{ $in }})
|
||||
if !ok || ein == nil {
|
||||
return fmt.Errorf("unexpected id value for edge-in")
|
||||
return append([]interface{}{new({{ $out }})}, values...), nil
|
||||
}
|
||||
spec.Assign = func(columns []string, values []interface{}) error {
|
||||
outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*{{ $.Name }}]struct{}{byid[outValue]: struct{}{}}
|
||||
return assign(columns[1:], values[1:])
|
||||
}
|
||||
outValue := {{ with extend $ "Arg" "eout" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
inValue := {{ with extend $ "Arg" "ein" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
node, ok := ids[outValue]
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected node id in edges: %v", outValue)
|
||||
}
|
||||
if _, ok := edges[inValue]; !ok {
|
||||
edgeids = append(edgeids, inValue)
|
||||
}
|
||||
edges[inValue] = append(edges[inValue], node)
|
||||
nids[inValue][byid[outValue]] = struct{}{}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
{{- /* Allow mutating the sqlgraph.EdgeQuerySpec by ent extensions or user templates.*/}}
|
||||
{{- with $tmpls := matchTemplate "dialect/sql/query/eagerloading/spec/*" }}
|
||||
{{- range $tmpl := $tmpls }}
|
||||
{{- xtemplate $tmpl $ }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
if err := sqlgraph.QueryEdges(ctx, {{ $receiver }}.driver, _spec); err != nil {
|
||||
return nil, fmt.Errorf(`query edges "{{ $e.Name }}": %w`, err)
|
||||
}
|
||||
query.Where({{ $e.Type.Package }}.IDIn(edgeids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := edges[n.ID]
|
||||
nodes, ok := nids[n.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
nodes[i].Edges.{{ $e.StructField }} = append(nodes[i].Edges.{{ $e.StructField }}, n)
|
||||
for kn := range nodes {
|
||||
kn.Edges.{{ $e.StructField }} = append(kn.Edges.{{ $e.StructField }}, n)
|
||||
}
|
||||
}
|
||||
{{- else if $e.OwnFK }}
|
||||
@@ -415,9 +398,9 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
{{- $field := $.Scope.Field }}
|
||||
{{- $scantype := $.Scope.ScanType }}
|
||||
{{- if hasPrefix $scantype "sql" -}}
|
||||
{{ $field.ScanTypeField $arg -}}
|
||||
{{ printf "%s.(*%s)" $arg $scantype | $field.ScanTypeField -}}
|
||||
{{- else -}}
|
||||
{{ if not $field.Nillable }}*{{ end }}{{ $arg }}
|
||||
{{ if not $field.Nillable }}*{{ end }}{{ printf "%s.(*%s)" $arg $scantype }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user