entc/gen: add eager-loading support (#263)

* entc/gen: add OwnFK indicator for type edges

* entc/gen: add Edges field for generated types

* entc/gen: add With<T> method to query-builder template

* entc/gen: scan and assign foreign-keys on eager-loading

* entc/gen: load fk-relations (wip)

* entc/integration: add o2m/m2o tests for eager-loading

* entc/gen: add m2m support for eager-loading

* entc/gen: add integration tests for m2m and subgraphs

* entc/gen/integration: add tests for o2o eager-loading

* all: generate all assets
This commit is contained in:
Ariel Mashraki
2020-01-13 17:21:26 +02:00
committed by GitHub
parent cd366c07e2
commit caf721df47
171 changed files with 6400 additions and 398 deletions

View File

@@ -24,21 +24,35 @@ type Card struct {
Expired time.Time `json:"expired,omitempty"`
// Number holds the value of the "number" field.
Number string `json:"number,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the CardQuery when eager-loading is set.
Edges struct {
// Owner holds the value of the owner edge.
Owner *User
}
owner_id *int
}
// scanValues returns the types for scanning values from sql.Rows.
func (*Card) scanValues() []interface{} {
return []interface{}{
&sql.NullInt64{},
&sql.NullTime{},
&sql.NullString{},
&sql.NullInt64{}, // id
&sql.NullTime{}, // expired
&sql.NullString{}, // number
}
}
// fkValues returns the types for scanning foreign-keys values from sql.Rows.
func (*Card) fkValues() []interface{} {
return []interface{}{
&sql.NullInt64{}, // owner_id
}
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the Card fields.
func (c *Card) assignValues(values ...interface{}) error {
if m, n := len(values), len(card.Columns); m != n {
if m, n := len(values), len(card.Columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
value, ok := values[0].(*sql.NullInt64)
@@ -57,6 +71,15 @@ func (c *Card) assignValues(values ...interface{}) error {
} else if value.Valid {
c.Number = value.String
}
values = values[2:]
if len(values) == len(card.ForeignKeys) {
if value, ok := values[0].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for edge-field owner_id", value)
} else if value.Valid {
c.owner_id = new(int)
*c.owner_id = int(value.Int64)
}
}
return nil
}

View File

@@ -27,9 +27,14 @@ const (
OwnerColumn = "owner_id"
)
// Columns holds all SQL columns are card fields.
// Columns holds all SQL columns for card fields.
var Columns = []string{
FieldID,
FieldExpired,
FieldNumber,
}
// ForeignKeys holds the SQL foreign-keys that are owned by the Card type.
var ForeignKeys = []string{
"owner_id",
}

View File

@@ -28,6 +28,9 @@ type CardQuery struct {
order []Order
unique []string
predicates []predicate.Card
// eager-loading edges.
withOwner *UserQuery
withFKs bool
// intermediate query.
sql *sql.Selector
}
@@ -237,6 +240,17 @@ func (cq *CardQuery) Clone() *CardQuery {
}
}
// WithOwner tells the query-builder to eager-loads the nodes that are connected to
// the "owner" edge. The optional arguments used to configure the query builder of the edge.
func (cq *CardQuery) WithOwner(opts ...func(*UserQuery)) *CardQuery {
query := &UserQuery{config: cq.config}
for _, opt := range opts {
opt(query)
}
cq.withOwner = query
return cq
}
// GroupBy used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
@@ -280,13 +294,24 @@ func (cq *CardQuery) Select(field string, fields ...string) *CardSelect {
func (cq *CardQuery) sqlAll(ctx context.Context) ([]*Card, error) {
var (
nodes []*Card
spec = cq.querySpec()
nodes []*Card
withFKs = cq.withFKs
spec = cq.querySpec()
)
if cq.withOwner != nil {
withFKs = true
}
if withFKs {
spec.Node.Columns = append(spec.Node.Columns, card.ForeignKeys...)
}
spec.ScanValues = func() []interface{} {
node := &Card{config: cq.config}
nodes = append(nodes, node)
return node.scanValues()
values := node.scanValues()
if withFKs {
values = append(values, node.fkValues()...)
}
return values
}
spec.Assign = func(values ...interface{}) error {
if len(nodes) == 0 {
@@ -298,6 +323,32 @@ func (cq *CardQuery) sqlAll(ctx context.Context) ([]*Card, error) {
if err := sqlgraph.QueryNodes(ctx, cq.driver, spec); err != nil {
return nil, err
}
if query := cq.withOwner; query != nil {
ids := make([]int, 0, len(nodes))
nodeids := make(map[int][]*Card)
for i := range nodes {
if fk := nodes[i].owner_id; fk != nil {
ids = append(ids, *fk)
nodeids[*fk] = append(nodeids[*fk], nodes[i])
}
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID)
}
for i := range nodes {
nodes[i].Edges.Owner = n
}
}
}
return nodes, nil
}

View File

@@ -23,21 +23,27 @@ type User struct {
Age int `json:"age,omitempty"`
// Name holds the value of the "name" field.
Name string `json:"name,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges struct {
// Card holds the value of the card edge.
Card *Card
}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*User) scanValues() []interface{} {
return []interface{}{
&sql.NullInt64{},
&sql.NullInt64{},
&sql.NullString{},
&sql.NullInt64{}, // id
&sql.NullInt64{}, // age
&sql.NullString{}, // name
}
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the User fields.
func (u *User) assignValues(values ...interface{}) error {
if m, n := len(values), len(user.Columns); m != n {
if m, n := len(values), len(user.Columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
value, ok := values[0].(*sql.NullInt64)

View File

@@ -27,7 +27,7 @@ const (
CardColumn = "owner_id"
)
// Columns holds all SQL columns are user fields.
// Columns holds all SQL columns for user fields.
var Columns = []string{
FieldID,
FieldAge,

View File

@@ -8,6 +8,7 @@ package ent
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"math"
@@ -28,6 +29,8 @@ type UserQuery struct {
order []Order
unique []string
predicates []predicate.User
// eager-loading edges.
withCard *CardQuery
// intermediate query.
sql *sql.Selector
}
@@ -237,6 +240,17 @@ func (uq *UserQuery) Clone() *UserQuery {
}
}
// WithCard tells the query-builder to eager-loads the nodes that are connected to
// the "card" edge. The optional arguments used to configure the query builder of the edge.
func (uq *UserQuery) WithCard(opts ...func(*CardQuery)) *UserQuery {
query := &CardQuery{config: uq.config}
for _, opt := range opts {
opt(query)
}
uq.withCard = query
return uq
}
// GroupBy used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
@@ -286,7 +300,8 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
spec.ScanValues = func() []interface{} {
node := &User{config: uq.config}
nodes = append(nodes, node)
return node.scanValues()
values := node.scanValues()
return values
}
spec.Assign = func(values ...interface{}) error {
if len(nodes) == 0 {
@@ -298,6 +313,35 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
if err := sqlgraph.QueryNodes(ctx, uq.driver, spec); err != nil {
return nil, err
}
if query := uq.withCard; query != nil {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
}
query.withFKs = true
query.Where(predicate.Card(func(s *sql.Selector) {
s.Where(sql.InValues(user.CardColumn, fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return nil, err
}
for _, n := range neighbors {
fk := n.owner_id
if fk == nil {
return nil, fmt.Errorf(`foreign-key "owner_id" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok {
return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v for node %v`, *fk, n.ID)
}
node.Edges.Card = n
}
}
return nodes, nil
}