entc/gen: add eager-loading support (#263)

* entc/gen: add OwnFK indicator for type edges

* entc/gen: add Edges field for generated types

* entc/gen: add With<T> method to query-builder template

* entc/gen: scan and assign foreign-keys on eager-loading

* entc/gen: load fk-relations (wip)

* entc/integration: add o2m/m2o tests for eager-loading

* entc/gen: add m2m support for eager-loading

* entc/gen: add integration tests for m2m and subgraphs

* entc/gen/integration: add tests for o2o eager-loading

* all: generate all assets
This commit is contained in:
Ariel Mashraki
2020-01-13 17:21:26 +02:00
committed by GitHub
parent cd366c07e2
commit caf721df47
171 changed files with 6400 additions and 398 deletions

View File

@@ -10,17 +10,29 @@ in the LICENSE file in the root directory of this source tree.
// scanValues returns the types for scanning values from sql.Rows.
func (*{{ $.Name }}) scanValues() []interface{} {
return []interface{} {
&{{ if not $.ID.UserDefined }}sql.NullInt64{{ else }}{{ $.ID.NullType }}{{ end }}{},
&{{ if not $.ID.UserDefined }}sql.NullInt64{{ else }}{{ $.ID.NullType }}{{ end }}{}, // {{ $.ID.Name }}
{{- range $_, $f := $.Fields }}
&{{ $f.NullType }}{},
&{{ $f.NullType }}{}, // {{ $f.Name }}
{{- end }}
}
}
{{- with $.ForeignKeys }}
// fkValues returns the types for scanning foreign-keys values from sql.Rows.
func (*{{ $.Name }}) fkValues() []interface{} {
return []interface{} {
{{- range $fk := . }}
{{- $f := $fk.Field }}
&{{ if not $f.UserDefined }}sql.NullInt64{{ else }}{{ $f.NullType }}{{ end }}{}, // {{ $f.Name }}
{{- end }}
}
}
{{- end }}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the {{ $.Name }} fields.
func ({{ $receiver }} *{{ $.Name }}) assignValues(values ...interface{}) error {
if m, n := len(values), len({{ $.Package }}.Columns); m != n {
if m, n := len(values), len({{ $.Package }}.Columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
{{- if and $.ID.UserDefined (or $.ID.IsString $.ID.IsUUID) }}
@@ -40,6 +52,26 @@ func ({{ $receiver }} *{{ $.Name }}) assignValues(values ...interface{}) error {
{{ template "dialect/sql/decode/field" . }}
{{- end }}
{{- end }}
{{- with $.ForeignKeys }}
values = values[{{ len $.Fields }}:]
if len(values) == len({{ $.Package }}.ForeignKeys) {
{{- range $i, $fk := . }}
{{- $f := $fk.Field }}
{{- if and $f.UserDefined (or $f.IsString $f.IsUUID) }}
{{- with extend $ "Idx" 0 "Field" $f "Rec" $receiver "StructField" $f.Name }}
{{ template "dialect/sql/decode/field" . }}
{{- end }}
{{- else }}
if value, ok := values[{{ $i }}].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for edge-field {{ $f.Name}}", value)
} else if value.Valid {
{{ $receiver }}.{{ $f.Name }} = new({{ $f.Type }})
*{{ $receiver }}.{{ $f.Name }} = {{ if $f.IsString }}strconv.FormatInt(value.Int64, 10){{ else }}{{ $f.Type }}(value.Int64){{ end }}
}
{{- end }}
{{- end }}
}
{{- end }}
return nil
}
{{ end }}
@@ -48,11 +80,12 @@ func ({{ $receiver }} *{{ $.Name }}) assignValues(values ...interface{}) error {
{{- $i := $.Scope.Idx -}}
{{- $f := $.Scope.Field -}}
{{- $ret := $.Scope.Rec -}}
{{- $field := $f.StructField }}{{ with $.Scope.StructField }}{{ $field = . }}{{ end }}
{{- if $f.IsJSON }}
if value, ok := values[{{ $i }}].(*{{ $f.NullType }}); !ok {
return fmt.Errorf("unexpected type %T for field {{ $f.Name }}", values[{{ $i }}])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &{{ $ret }}.{{ $f.StructField }}); err != nil {
if err := json.Unmarshal(*value, &{{ $ret }}.{{ $field }}); err != nil {
return fmt.Errorf("unmarshal field {{ $f.Name }}: %v", err)
}
}
@@ -63,14 +96,14 @@ func ({{ $receiver }} *{{ $.Name }}) assignValues(values ...interface{}) error {
{{- if hasPrefix $nulltype "sql" }}
} else if value.Valid {
{{- if $f.Nillable }}
{{ $ret }}.{{ $f.StructField }} = new({{ $f.Type }})
*{{ $ret }}.{{ $f.StructField }} = {{ $f.NullTypeField "value" }}
{{ $ret }}.{{ $field }} = new({{ $f.Type }})
*{{ $ret }}.{{ $field }} = {{ $f.NullTypeField "value" }}
{{- else }}
{{ $ret }}.{{ $f.StructField }} = {{ $f.NullTypeField "value" }}
{{ $ret }}.{{ $field }} = {{ $f.NullTypeField "value" }}
{{- end }}
{{- else }}
} else if value != nil {
{{ $ret }}.{{ $f.StructField }} = *value
{{ $ret }}.{{ $field }} = *value
{{- end }}
}
{{- end }}
@@ -78,3 +111,11 @@ func ({{ $receiver }} *{{ $.Name }}) assignValues(values ...interface{}) error {
{{ define "dialect/sql/decode/many" }}
{{ end }}
{{/* Additional fields for the generated model for holding the foreign-keys */}}
{{ define "dialect/sql/model/fields" }}
{{- range $fk := $.ForeignKeys }}
{{- $f := $fk.Field }}
{{ $f.Name }} {{ if $f.Nillable }}*{{ end }}{{ $f.Type }}
{{- end }}
{{ end }}

View File

@@ -12,7 +12,7 @@ in the LICENSE file in the root directory of this source tree.
// {{ $e.TableConstant }} is the table the holds the {{ $e.Name }} relation/edge.
{{- if $e.M2M }} The primary key declared below.{{ end }}
{{ $e.TableConstant }} = "{{ $e.Rel.Table }}"
{{- if eq $.Table $e.Type.Table | not }}
{{- if ne $.Table $e.Type.Table }}
// {{ $e.InverseTableConstant }} is the table name for the {{ $e.Type.Name }} entity.
// It exists in this package in order to avoid circular dependency with the "{{ $e.Type.Package }}" package.
{{ $e.InverseTableConstant }} = "{{ $e.Type.Table }}"
@@ -26,13 +26,22 @@ in the LICENSE file in the root directory of this source tree.
{{/* variables needed for sql dialects. */}}
{{ define "dialect/sql/meta/variables" }}
// Columns holds all SQL columns are {{ lower $.Name }} fields.
// Columns holds all SQL columns for {{ lower $.Name }} fields.
var Columns = []string{
{{ $.ID.Constant }},
{{- range $_, $f := $.Fields }}
{{- range $f := $.Fields }}
{{ $f.Constant }},
{{- end }}
}
{{/* if any of the edges owns a foreign-key */}}
{{ with $.ForeignKeys }}
// ForeignKeys holds the SQL foreign-keys that are owned by the {{ $.Name }} type.
var ForeignKeys = []string{
{{- range $fk := . }}
"{{ $fk.Field.Name }}",
{{- end }}
}
{{ end }}
{{ with $.NumM2M }}
var (

View File

@@ -4,6 +4,13 @@ This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}
{{/* Additional fields for the builder. */}}
{{ define "dialect/sql/query/fields" }}
{{- with $.ForeignKeys }}
withFKs bool
{{- end }}
{{- end }}
{{ define "dialect/sql/query" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
@@ -12,12 +19,31 @@ in the LICENSE file in the root directory of this source tree.
func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name }}, error) {
var (
nodes []*{{ $.Name }}
{{- with $.ForeignKeys }}
withFKs = {{ $receiver }}.withFKs
{{- end }}
spec = {{ $receiver }}.querySpec()
)
{{- with $.ForeignKeys }}
{{- with $.FKEdges }}
if {{ range $i, $e := . }}{{ if gt $i 0 }} || {{ end }}{{ $receiver }}.with{{ pascal $e.Name }} != nil{{ end }} {
withFKs = true
}
{{- end }}
if withFKs {
spec.Node.Columns = append(spec.Node.Columns, {{ $.Package }}.ForeignKeys...)
}
{{- end }}
spec.ScanValues = func() []interface{} {
node := &{{ $.Name }}{config: {{ $receiver }}.config}
nodes = append(nodes, node)
return node.scanValues()
values := node.scanValues()
{{- with $.ForeignKeys }}
if withFKs {
values = append(values, node.fkValues()...)
}
{{- end }}
return values
}
spec.Assign = func(values ...interface{}) error {
if len(nodes) == 0 {
@@ -29,6 +55,11 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name
if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, spec); err != nil {
return nil, err
}
{{- range $e := $.Edges }}
{{- with extend $ "Rec" $receiver "Edge" $e }}
{{ template "dialect/sql/query/eagerloading" . }}
{{- end }}
{{- end }}
return nodes, nil
}
@@ -145,3 +176,145 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery() *sql.Selector {
)
query.sql = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}
{{ define "dialect/sql/query/eagerloading" }}
{{- $e := $.Scope.Edge }}
{{- $receiver := $.Scope.Rec }}
if query := {{ $receiver }}.with{{ pascal $e.Name }}; query != nil {
{{- if $e.M2M }}
fks := make([]driver.Value, 0, len(nodes))
ids := make(map[{{ $.ID.Type }}]*{{ $.Name }}, len(nodes))
for _, node := range nodes {
ids[node.ID] = node
fks = append(fks, node.ID)
}
var (
edgeids []{{ $e.Type.ID.Type }}
edges = make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
)
spec := &sqlgraph.EdgeQuerySpec{
Edge: &sqlgraph.EdgeSpec{
Inverse: {{ $e.IsInverse }},
Table: {{ $.Package }}.{{ $e.TableConstant }},
Columns: {{ $.Package }}.{{ $e.PKConstant }},
},
Predicate: func(s *sql.Selector) {
s.Where(sql.InValues({{ $.Package }}.{{ $e.PKConstant }}[{{ if $e.IsInverse }}1{{ else }}0{{ end }}], fks...))
},
{{ $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.NullType }}{{ end }}
{{ $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.NullType }}{{ end }}
ScanValues: func() [2]interface{}{
return [2]interface{}{&{{ $out }}{}, &{{ $in }}{}}
},
Assign: func(out, in interface{}) error {
eout, ok := out.(*{{ $out }})
if !ok || eout == nil {
return fmt.Errorf("unexpected id value for edge-out")
}
ein, ok := in.(*{{ $in }})
if !ok || ein == nil {
return fmt.Errorf("unexpected id value for edge-in")
}
outValue := {{ with extend $ "Arg" "eout" "Field" $.ID "NullType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
inValue := {{ with extend $ "Arg" "ein" "Field" $e.Type.ID "NullType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
node, ok := ids[outValue]
if !ok {
return fmt.Errorf("unexpected node id in edges: %v", outValue)
}
edgeids = append(edgeids, inValue)
edges[inValue] = append(edges[inValue], node)
return nil
},
}
if err := sqlgraph.QueryEdges(ctx, {{ $receiver }}.driver, spec); err != nil {
return nil, fmt.Errorf(`query edges "{{ $e.Name }}": %v`, err)
}
query.Where({{ $e.Type.Package }}.IDIn(edgeids...))
neighbors, err := query.All(ctx)
if err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := edges[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
}
for i := range nodes {
nodes[i].Edges.{{ $e.StructField }} = append(nodes[i].Edges.{{ $e.StructField }}, n)
}
}
{{- else if $e.OwnFK }}
ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes))
nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
for i := range nodes {
if fk := nodes[i].{{ $e.StructFKField }}; fk != nil {
ids = append(ids, *fk)
nodeids[*fk] = append(nodeids[*fk], nodes[i])
}
}
query.Where({{ $e.Type.Package }}.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected foreign-key "{{ $e.StructFKField }}" returned %v`, n.ID)
}
for i := range nodes {
nodes[i].Edges.{{ $e.StructField }} = n
}
}
{{- else }}
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }})
for i := range nodes {
{{- /* Convert string-ids that are stored as int in the database */ -}}
{{- if and (not $.ID.UserDefined) $.ID.IsString }}
id, err := strconv.Atoi(nodes[i].ID)
if err != nil {
return nil, err
}
fks = append(fks, id)
{{- else }}
fks = append(fks, nodes[i].ID)
{{- end }}
nodeids[nodes[i].ID] = nodes[i]
}
query.withFKs = true
query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) {
s.Where(sql.InValues({{ $.Package }}.{{ $e.ColumnConstant }}, fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return nil, err
}
for _, n := range neighbors {
fk := n.{{ $e.StructFKField }}
if fk == nil {
return nil, fmt.Errorf(`foreign-key "{{ $e.StructFKField }}" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok {
return nil, fmt.Errorf(`unexpected foreign-key "{{ $e.StructFKField }}" returned %v for node %v`, *fk, n.ID)
}
node.Edges.{{ $e.StructField }} = {{ if $e.Unique }}n{{ else }}append(node.Edges.{{ $e.StructField }}, n){{ end }}
}
{{- end }}
}
{{ end }}
{{- /* Convert string-ids that are stored as int in the database */ -}}
{{ define "dialect/sql/query/eagerloading/m2massign" }}
{{- $arg := $.Scope.Arg }}
{{- $field := $.Scope.Field }}
{{- $nulltype := $.Scope.NullType }}
{{- if and (not $field.UserDefined) $field.IsString -}}
strconv.FormatInt({{ $arg }}.Int64, 10)
{{- else if hasPrefix $nulltype "sql" -}}
{{ $field.NullTypeField "eout" -}}
{{- else -}}
{{ $arg }}
{{- end }}
{{- end }}