entc/gen: use join for loading m2m relationship (#2417)

* entc/gen: use join for m2m relationship

* entc/gen: add test for eager-load inverse-m2m
This commit is contained in:
Ariel Mashraki
2022-03-21 11:37:54 +02:00
committed by GitHub
parent 938233b191
commit edd968490e
125 changed files with 2093 additions and 2483 deletions

View File

@@ -355,7 +355,7 @@ func (cq *CarQuery) prepareQuery(ctx context.Context) error {
return nil
}
func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) {
func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, error) {
var (
nodes = []*Car{}
withFKs = cq.withFKs
@@ -371,18 +371,17 @@ func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) {
_spec.Node.Columns = append(_spec.Node.Columns, car.ForeignKeys...)
}
_spec.ScanValues = func(columns []string) ([]interface{}, error) {
node := &Car{config: cq.config}
nodes = append(nodes, node)
return node.scanValues(columns)
return (*Car).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []interface{}) error {
if len(nodes) == 0 {
return fmt.Errorf("ent: Assign called without calling ScanValues")
}
node := nodes[len(nodes)-1]
node := &Car{config: cq.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil {
return nil, err
}

View File

@@ -13,6 +13,7 @@ import (
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/examples/start/ent/car"
"entgo.io/ent/examples/start/ent/group"
"entgo.io/ent/examples/start/ent/user"
@@ -468,3 +469,6 @@ func (s *selector) BoolX(ctx context.Context) bool {
}
return v
}
// queryHook describes an internal hook for the different sqlAll methods.
type queryHook func(context.Context, *sqlgraph.QuerySpec)

View File

@@ -355,7 +355,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error {
return nil
}
func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) {
func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) {
var (
nodes = []*Group{}
_spec = gq.querySpec()
@@ -364,18 +364,17 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) {
}
)
_spec.ScanValues = func(columns []string) ([]interface{}, error) {
node := &Group{config: gq.config}
nodes = append(nodes, node)
return node.scanValues(columns)
return (*Group).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []interface{}) error {
if len(nodes) == 0 {
return fmt.Errorf("ent: Assign called without calling ScanValues")
}
node := nodes[len(nodes)-1]
node := &Group{config: gq.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil {
return nil, err
}
@@ -384,66 +383,54 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) {
}
if query := gq.withUsers; query != nil {
fks := make([]driver.Value, 0, len(nodes))
ids := make(map[int]*Group, len(nodes))
for _, node := range nodes {
ids[node.ID] = node
fks = append(fks, node.ID)
edgeids := make([]driver.Value, len(nodes))
byid := make(map[int]*Group)
nids := make(map[int]map[*Group]struct{})
for i, node := range nodes {
edgeids[i] = node.ID
byid[node.ID] = node
node.Edges.Users = []*User{}
}
var (
edgeids []int
edges = make(map[int][]*Group)
)
_spec := &sqlgraph.EdgeQuerySpec{
Edge: &sqlgraph.EdgeSpec{
Inverse: false,
Table: group.UsersTable,
Columns: group.UsersPrimaryKey,
},
Predicate: func(s *sql.Selector) {
s.Where(sql.InValues(group.UsersPrimaryKey[0], fks...))
},
ScanValues: func() [2]interface{} {
return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)}
},
Assign: func(out, in interface{}) error {
eout, ok := out.(*sql.NullInt64)
if !ok || eout == nil {
return fmt.Errorf("unexpected id value for edge-out")
query.Where(func(s *sql.Selector) {
joinT := sql.Table(group.UsersTable)
s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1]))
s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...))
columns := s.SelectedColumns()
s.Select(joinT.C(group.UsersPrimaryKey[0]))
s.AppendSelect(columns...)
s.SetDistinct(false)
})
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
assign := spec.Assign
values := spec.ScanValues
spec.ScanValues = func(columns []string) ([]interface{}, error) {
values, err := values(columns[1:])
if err != nil {
return nil, err
}
ein, ok := in.(*sql.NullInt64)
if !ok || ein == nil {
return fmt.Errorf("unexpected id value for edge-in")
return append([]interface{}{new(sql.NullInt64)}, values...), nil
}
spec.Assign = func(columns []string, values []interface{}) error {
outValue := int(values[0].(*sql.NullInt64).Int64)
inValue := int(values[1].(*sql.NullInt64).Int64)
if nids[inValue] == nil {
nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}}
return assign(columns[1:], values[1:])
}
outValue := int(eout.Int64)
inValue := int(ein.Int64)
node, ok := ids[outValue]
if !ok {
return fmt.Errorf("unexpected node id in edges: %v", outValue)
}
if _, ok := edges[inValue]; !ok {
edgeids = append(edgeids, inValue)
}
edges[inValue] = append(edges[inValue], node)
nids[inValue][byid[outValue]] = struct{}{}
return nil
},
}
if err := sqlgraph.QueryEdges(ctx, gq.driver, _spec); err != nil {
return nil, fmt.Errorf(`query edges "users": %w`, err)
}
query.Where(user.IDIn(edgeids...))
neighbors, err := query.All(ctx)
}
})
if err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := edges[n.ID]
nodes, ok := nids[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID)
}
for i := range nodes {
nodes[i].Edges.Users = append(nodes[i].Edges.Users, n)
for kn := range nodes {
kn.Edges.Users = append(kn.Edges.Users, n)
}
}
}

View File

@@ -391,7 +391,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error {
return nil
}
func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) {
var (
nodes = []*User{}
_spec = uq.querySpec()
@@ -401,18 +401,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
}
)
_spec.ScanValues = func(columns []string) ([]interface{}, error) {
node := &User{config: uq.config}
nodes = append(nodes, node)
return node.scanValues(columns)
return (*User).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []interface{}) error {
if len(nodes) == 0 {
return fmt.Errorf("ent: Assign called without calling ScanValues")
}
node := nodes[len(nodes)-1]
node := &User{config: uq.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil {
return nil, err
}
@@ -450,66 +449,54 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
}
if query := uq.withGroups; query != nil {
fks := make([]driver.Value, 0, len(nodes))
ids := make(map[int]*User, len(nodes))
for _, node := range nodes {
ids[node.ID] = node
fks = append(fks, node.ID)
edgeids := make([]driver.Value, len(nodes))
byid := make(map[int]*User)
nids := make(map[int]map[*User]struct{})
for i, node := range nodes {
edgeids[i] = node.ID
byid[node.ID] = node
node.Edges.Groups = []*Group{}
}
var (
edgeids []int
edges = make(map[int][]*User)
)
_spec := &sqlgraph.EdgeQuerySpec{
Edge: &sqlgraph.EdgeSpec{
Inverse: true,
Table: user.GroupsTable,
Columns: user.GroupsPrimaryKey,
},
Predicate: func(s *sql.Selector) {
s.Where(sql.InValues(user.GroupsPrimaryKey[1], fks...))
},
ScanValues: func() [2]interface{} {
return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)}
},
Assign: func(out, in interface{}) error {
eout, ok := out.(*sql.NullInt64)
if !ok || eout == nil {
return fmt.Errorf("unexpected id value for edge-out")
query.Where(func(s *sql.Selector) {
joinT := sql.Table(user.GroupsTable)
s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0]))
s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeids...))
columns := s.SelectedColumns()
s.Select(joinT.C(user.GroupsPrimaryKey[1]))
s.AppendSelect(columns...)
s.SetDistinct(false)
})
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
assign := spec.Assign
values := spec.ScanValues
spec.ScanValues = func(columns []string) ([]interface{}, error) {
values, err := values(columns[1:])
if err != nil {
return nil, err
}
ein, ok := in.(*sql.NullInt64)
if !ok || ein == nil {
return fmt.Errorf("unexpected id value for edge-in")
return append([]interface{}{new(sql.NullInt64)}, values...), nil
}
spec.Assign = func(columns []string, values []interface{}) error {
outValue := int(values[0].(*sql.NullInt64).Int64)
inValue := int(values[1].(*sql.NullInt64).Int64)
if nids[inValue] == nil {
nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}}
return assign(columns[1:], values[1:])
}
outValue := int(eout.Int64)
inValue := int(ein.Int64)
node, ok := ids[outValue]
if !ok {
return fmt.Errorf("unexpected node id in edges: %v", outValue)
}
if _, ok := edges[inValue]; !ok {
edgeids = append(edgeids, inValue)
}
edges[inValue] = append(edges[inValue], node)
nids[inValue][byid[outValue]] = struct{}{}
return nil
},
}
if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil {
return nil, fmt.Errorf(`query edges "groups": %w`, err)
}
query.Where(group.IDIn(edgeids...))
neighbors, err := query.All(ctx)
}
})
if err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := edges[n.ID]
nodes, ok := nids[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID)
}
for i := range nodes {
nodes[i].Edges.Groups = append(nodes[i].Edges.Groups, n)
for kn := range nodes {
kn.Edges.Groups = append(kn.Edges.Groups, n)
}
}
}