entc/gen: add eager-loading support (#263)

* entc/gen: add OwnFK indicator for type edges

* entc/gen: add Edges field for generated types

* entc/gen: add With<T> method to query-builder template

* entc/gen: scan and assign foreign-keys on eager-loading

* entc/gen: load fk-relations (wip)

* entc/integration: add o2m/m2o tests for eager-loading

* entc/gen: add m2m support for eager-loading

* entc/gen: add integration tests for m2m and subgraphs

* entc/gen/integration: add tests for o2o eager-loading

* all: generate all assets
This commit is contained in:
Ariel Mashraki
2020-01-13 17:21:26 +02:00
committed by GitHub
parent cd366c07e2
commit caf721df47
171 changed files with 6400 additions and 398 deletions

View File

@@ -21,20 +21,26 @@ type Group struct {
ID int `json:"id,omitempty"`
// Name holds the value of the "name" field.
Name string `json:"name,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges struct {
// Users holds the value of the users edge.
Users []*User
}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*Group) scanValues() []interface{} {
return []interface{}{
&sql.NullInt64{},
&sql.NullString{},
&sql.NullInt64{}, // id
&sql.NullString{}, // name
}
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the Group fields.
func (gr *Group) assignValues(values ...interface{}) error {
if m, n := len(values), len(group.Columns); m != n {
if m, n := len(values), len(group.Columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
value, ok := values[0].(*sql.NullInt64)

View File

@@ -23,7 +23,7 @@ const (
UsersInverseTable = "users"
)
// Columns holds all SQL columns are group fields.
// Columns holds all SQL columns for group fields.
var Columns = []string{
FieldID,
FieldName,

View File

@@ -8,6 +8,7 @@ package ent
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"math"
@@ -28,6 +29,8 @@ type GroupQuery struct {
order []Order
unique []string
predicates []predicate.Group
// eager-loading edges.
withUsers *UserQuery
// intermediate query.
sql *sql.Selector
}
@@ -237,6 +240,17 @@ func (gq *GroupQuery) Clone() *GroupQuery {
}
}
// WithUsers tells the query-builder to eager-loads the nodes that are connected to
// the "users" edge. The optional arguments used to configure the query builder of the edge.
func (gq *GroupQuery) WithUsers(opts ...func(*UserQuery)) *GroupQuery {
query := &UserQuery{config: gq.config}
for _, opt := range opts {
opt(query)
}
gq.withUsers = query
return gq
}
// GroupBy used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
@@ -286,7 +300,8 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) {
spec.ScanValues = func() []interface{} {
node := &Group{config: gq.config}
nodes = append(nodes, node)
return node.scanValues()
values := node.scanValues()
return values
}
spec.Assign = func(values ...interface{}) error {
if len(nodes) == 0 {
@@ -298,6 +313,70 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) {
if err := sqlgraph.QueryNodes(ctx, gq.driver, spec); err != nil {
return nil, err
}
if query := gq.withUsers; query != nil {
fks := make([]driver.Value, 0, len(nodes))
ids := make(map[int]*Group, len(nodes))
for _, node := range nodes {
ids[node.ID] = node
fks = append(fks, node.ID)
}
var (
edgeids []int
edges = make(map[int][]*Group)
)
spec := &sqlgraph.EdgeQuerySpec{
Edge: &sqlgraph.EdgeSpec{
Inverse: false,
Table: group.UsersTable,
Columns: group.UsersPrimaryKey,
},
Predicate: func(s *sql.Selector) {
s.Where(sql.InValues(group.UsersPrimaryKey[0], fks...))
},
ScanValues: func() [2]interface{} {
return [2]interface{}{&sql.NullInt64{}, &sql.NullInt64{}}
},
Assign: func(out, in interface{}) error {
eout, ok := out.(*sql.NullInt64)
if !ok || eout == nil {
return fmt.Errorf("unexpected id value for edge-out")
}
ein, ok := in.(*sql.NullInt64)
if !ok || ein == nil {
return fmt.Errorf("unexpected id value for edge-in")
}
outValue := int(eout.Int64)
inValue := int(eout.Int64)
node, ok := ids[outValue]
if !ok {
return fmt.Errorf("unexpected node id in edges: %v", outValue)
}
edgeids = append(edgeids, inValue)
edges[inValue] = append(edges[inValue], node)
return nil
},
}
if err := sqlgraph.QueryEdges(ctx, gq.driver, spec); err != nil {
return nil, fmt.Errorf(`query edges "users": %v`, err)
}
query.Where(user.IDIn(edgeids...))
neighbors, err := query.All(ctx)
if err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := edges[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID)
}
for i := range nodes {
nodes[i].Edges.Users = append(nodes[i].Edges.Users, n)
}
}
}
return nodes, nil
}

View File

@@ -23,21 +23,27 @@ type User struct {
Age int `json:"age,omitempty"`
// Name holds the value of the "name" field.
Name string `json:"name,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges struct {
// Groups holds the value of the groups edge.
Groups []*Group
}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*User) scanValues() []interface{} {
return []interface{}{
&sql.NullInt64{},
&sql.NullInt64{},
&sql.NullString{},
&sql.NullInt64{}, // id
&sql.NullInt64{}, // age
&sql.NullString{}, // name
}
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the User fields.
func (u *User) assignValues(values ...interface{}) error {
if m, n := len(values), len(user.Columns); m != n {
if m, n := len(values), len(user.Columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
value, ok := values[0].(*sql.NullInt64)

View File

@@ -25,7 +25,7 @@ const (
GroupsInverseTable = "groups"
)
// Columns holds all SQL columns are user fields.
// Columns holds all SQL columns for user fields.
var Columns = []string{
FieldID,
FieldAge,

View File

@@ -8,6 +8,7 @@ package ent
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"math"
@@ -28,6 +29,8 @@ type UserQuery struct {
order []Order
unique []string
predicates []predicate.User
// eager-loading edges.
withGroups *GroupQuery
// intermediate query.
sql *sql.Selector
}
@@ -237,6 +240,17 @@ func (uq *UserQuery) Clone() *UserQuery {
}
}
// WithGroups tells the query-builder to eager-loads the nodes that are connected to
// the "groups" edge. The optional arguments used to configure the query builder of the edge.
func (uq *UserQuery) WithGroups(opts ...func(*GroupQuery)) *UserQuery {
query := &GroupQuery{config: uq.config}
for _, opt := range opts {
opt(query)
}
uq.withGroups = query
return uq
}
// GroupBy used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
@@ -286,7 +300,8 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
spec.ScanValues = func() []interface{} {
node := &User{config: uq.config}
nodes = append(nodes, node)
return node.scanValues()
values := node.scanValues()
return values
}
spec.Assign = func(values ...interface{}) error {
if len(nodes) == 0 {
@@ -298,6 +313,70 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
if err := sqlgraph.QueryNodes(ctx, uq.driver, spec); err != nil {
return nil, err
}
if query := uq.withGroups; query != nil {
fks := make([]driver.Value, 0, len(nodes))
ids := make(map[int]*User, len(nodes))
for _, node := range nodes {
ids[node.ID] = node
fks = append(fks, node.ID)
}
var (
edgeids []int
edges = make(map[int][]*User)
)
spec := &sqlgraph.EdgeQuerySpec{
Edge: &sqlgraph.EdgeSpec{
Inverse: true,
Table: user.GroupsTable,
Columns: user.GroupsPrimaryKey,
},
Predicate: func(s *sql.Selector) {
s.Where(sql.InValues(user.GroupsPrimaryKey[1], fks...))
},
ScanValues: func() [2]interface{} {
return [2]interface{}{&sql.NullInt64{}, &sql.NullInt64{}}
},
Assign: func(out, in interface{}) error {
eout, ok := out.(*sql.NullInt64)
if !ok || eout == nil {
return fmt.Errorf("unexpected id value for edge-out")
}
ein, ok := in.(*sql.NullInt64)
if !ok || ein == nil {
return fmt.Errorf("unexpected id value for edge-in")
}
outValue := int(eout.Int64)
inValue := int(eout.Int64)
node, ok := ids[outValue]
if !ok {
return fmt.Errorf("unexpected node id in edges: %v", outValue)
}
edgeids = append(edgeids, inValue)
edges[inValue] = append(edges[inValue], node)
return nil
},
}
if err := sqlgraph.QueryEdges(ctx, uq.driver, spec); err != nil {
return nil, fmt.Errorf(`query edges "groups": %v`, err)
}
query.Where(group.IDIn(edgeids...))
neighbors, err := query.All(ctx)
if err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := edges[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID)
}
for i := range nodes {
nodes[i].Edges.Groups = append(nodes[i].Edges.Groups, n)
}
}
}
return nodes, nil
}