mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
entc/gen: use join for loading m2m relationship (#2417)
* entc/gen: use join for m2m relationship * entc/gen: add test for eager-load inverse-m2m
This commit is contained in:
@@ -355,7 +355,7 @@ func (cq *CarQuery) prepareQuery(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) {
|
||||
func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, error) {
|
||||
var (
|
||||
nodes = []*Car{}
|
||||
withFKs = cq.withFKs
|
||||
@@ -371,18 +371,17 @@ func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, car.ForeignKeys...)
|
||||
}
|
||||
_spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
node := &Car{config: cq.config}
|
||||
nodes = append(nodes, node)
|
||||
return node.scanValues(columns)
|
||||
return (*Car).scanValues(nil, columns)
|
||||
}
|
||||
_spec.Assign = func(columns []string, values []interface{}) error {
|
||||
if len(nodes) == 0 {
|
||||
return fmt.Errorf("ent: Assign called without calling ScanValues")
|
||||
}
|
||||
node := nodes[len(nodes)-1]
|
||||
node := &Car{config: cq.config}
|
||||
nodes = append(nodes, node)
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/examples/start/ent/car"
|
||||
"entgo.io/ent/examples/start/ent/group"
|
||||
"entgo.io/ent/examples/start/ent/user"
|
||||
@@ -468,3 +469,6 @@ func (s *selector) BoolX(ctx context.Context) bool {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// queryHook describes an internal hook for the different sqlAll methods.
|
||||
type queryHook func(context.Context, *sqlgraph.QuerySpec)
|
||||
|
||||
@@ -355,7 +355,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) {
|
||||
func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) {
|
||||
var (
|
||||
nodes = []*Group{}
|
||||
_spec = gq.querySpec()
|
||||
@@ -364,18 +364,17 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) {
|
||||
}
|
||||
)
|
||||
_spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
node := &Group{config: gq.config}
|
||||
nodes = append(nodes, node)
|
||||
return node.scanValues(columns)
|
||||
return (*Group).scanValues(nil, columns)
|
||||
}
|
||||
_spec.Assign = func(columns []string, values []interface{}) error {
|
||||
if len(nodes) == 0 {
|
||||
return fmt.Errorf("ent: Assign called without calling ScanValues")
|
||||
}
|
||||
node := nodes[len(nodes)-1]
|
||||
node := &Group{config: gq.config}
|
||||
nodes = append(nodes, node)
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -384,66 +383,54 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) {
|
||||
}
|
||||
|
||||
if query := gq.withUsers; query != nil {
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
ids := make(map[int]*Group, len(nodes))
|
||||
for _, node := range nodes {
|
||||
ids[node.ID] = node
|
||||
fks = append(fks, node.ID)
|
||||
edgeids := make([]driver.Value, len(nodes))
|
||||
byid := make(map[int]*Group)
|
||||
nids := make(map[int]map[*Group]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeids[i] = node.ID
|
||||
byid[node.ID] = node
|
||||
node.Edges.Users = []*User{}
|
||||
}
|
||||
var (
|
||||
edgeids []int
|
||||
edges = make(map[int][]*Group)
|
||||
)
|
||||
_spec := &sqlgraph.EdgeQuerySpec{
|
||||
Edge: &sqlgraph.EdgeSpec{
|
||||
Inverse: false,
|
||||
Table: group.UsersTable,
|
||||
Columns: group.UsersPrimaryKey,
|
||||
},
|
||||
Predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.InValues(group.UsersPrimaryKey[0], fks...))
|
||||
},
|
||||
ScanValues: func() [2]interface{} {
|
||||
return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)}
|
||||
},
|
||||
Assign: func(out, in interface{}) error {
|
||||
eout, ok := out.(*sql.NullInt64)
|
||||
if !ok || eout == nil {
|
||||
return fmt.Errorf("unexpected id value for edge-out")
|
||||
query.Where(func(s *sql.Selector) {
|
||||
joinT := sql.Table(group.UsersTable)
|
||||
s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1]))
|
||||
s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...))
|
||||
columns := s.SelectedColumns()
|
||||
s.Select(joinT.C(group.UsersPrimaryKey[0]))
|
||||
s.AppendSelect(columns...)
|
||||
s.SetDistinct(false)
|
||||
})
|
||||
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
|
||||
assign := spec.Assign
|
||||
values := spec.ScanValues
|
||||
spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
values, err := values(columns[1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ein, ok := in.(*sql.NullInt64)
|
||||
if !ok || ein == nil {
|
||||
return fmt.Errorf("unexpected id value for edge-in")
|
||||
return append([]interface{}{new(sql.NullInt64)}, values...), nil
|
||||
}
|
||||
spec.Assign = func(columns []string, values []interface{}) error {
|
||||
outValue := int(values[0].(*sql.NullInt64).Int64)
|
||||
inValue := int(values[1].(*sql.NullInt64).Int64)
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}}
|
||||
return assign(columns[1:], values[1:])
|
||||
}
|
||||
outValue := int(eout.Int64)
|
||||
inValue := int(ein.Int64)
|
||||
node, ok := ids[outValue]
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected node id in edges: %v", outValue)
|
||||
}
|
||||
if _, ok := edges[inValue]; !ok {
|
||||
edgeids = append(edgeids, inValue)
|
||||
}
|
||||
edges[inValue] = append(edges[inValue], node)
|
||||
nids[inValue][byid[outValue]] = struct{}{}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if err := sqlgraph.QueryEdges(ctx, gq.driver, _spec); err != nil {
|
||||
return nil, fmt.Errorf(`query edges "users": %w`, err)
|
||||
}
|
||||
query.Where(user.IDIn(edgeids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := edges[n.ID]
|
||||
nodes, ok := nids[n.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
nodes[i].Edges.Users = append(nodes[i].Edges.Users, n)
|
||||
for kn := range nodes {
|
||||
kn.Edges.Users = append(kn.Edges.Users, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,7 +391,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
|
||||
func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) {
|
||||
var (
|
||||
nodes = []*User{}
|
||||
_spec = uq.querySpec()
|
||||
@@ -401,18 +401,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
|
||||
}
|
||||
)
|
||||
_spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
node := &User{config: uq.config}
|
||||
nodes = append(nodes, node)
|
||||
return node.scanValues(columns)
|
||||
return (*User).scanValues(nil, columns)
|
||||
}
|
||||
_spec.Assign = func(columns []string, values []interface{}) error {
|
||||
if len(nodes) == 0 {
|
||||
return fmt.Errorf("ent: Assign called without calling ScanValues")
|
||||
}
|
||||
node := nodes[len(nodes)-1]
|
||||
node := &User{config: uq.config}
|
||||
nodes = append(nodes, node)
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -450,66 +449,54 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
|
||||
}
|
||||
|
||||
if query := uq.withGroups; query != nil {
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
ids := make(map[int]*User, len(nodes))
|
||||
for _, node := range nodes {
|
||||
ids[node.ID] = node
|
||||
fks = append(fks, node.ID)
|
||||
edgeids := make([]driver.Value, len(nodes))
|
||||
byid := make(map[int]*User)
|
||||
nids := make(map[int]map[*User]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeids[i] = node.ID
|
||||
byid[node.ID] = node
|
||||
node.Edges.Groups = []*Group{}
|
||||
}
|
||||
var (
|
||||
edgeids []int
|
||||
edges = make(map[int][]*User)
|
||||
)
|
||||
_spec := &sqlgraph.EdgeQuerySpec{
|
||||
Edge: &sqlgraph.EdgeSpec{
|
||||
Inverse: true,
|
||||
Table: user.GroupsTable,
|
||||
Columns: user.GroupsPrimaryKey,
|
||||
},
|
||||
Predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.InValues(user.GroupsPrimaryKey[1], fks...))
|
||||
},
|
||||
ScanValues: func() [2]interface{} {
|
||||
return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)}
|
||||
},
|
||||
Assign: func(out, in interface{}) error {
|
||||
eout, ok := out.(*sql.NullInt64)
|
||||
if !ok || eout == nil {
|
||||
return fmt.Errorf("unexpected id value for edge-out")
|
||||
query.Where(func(s *sql.Selector) {
|
||||
joinT := sql.Table(user.GroupsTable)
|
||||
s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0]))
|
||||
s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeids...))
|
||||
columns := s.SelectedColumns()
|
||||
s.Select(joinT.C(user.GroupsPrimaryKey[1]))
|
||||
s.AppendSelect(columns...)
|
||||
s.SetDistinct(false)
|
||||
})
|
||||
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
|
||||
assign := spec.Assign
|
||||
values := spec.ScanValues
|
||||
spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
values, err := values(columns[1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ein, ok := in.(*sql.NullInt64)
|
||||
if !ok || ein == nil {
|
||||
return fmt.Errorf("unexpected id value for edge-in")
|
||||
return append([]interface{}{new(sql.NullInt64)}, values...), nil
|
||||
}
|
||||
spec.Assign = func(columns []string, values []interface{}) error {
|
||||
outValue := int(values[0].(*sql.NullInt64).Int64)
|
||||
inValue := int(values[1].(*sql.NullInt64).Int64)
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}}
|
||||
return assign(columns[1:], values[1:])
|
||||
}
|
||||
outValue := int(eout.Int64)
|
||||
inValue := int(ein.Int64)
|
||||
node, ok := ids[outValue]
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected node id in edges: %v", outValue)
|
||||
}
|
||||
if _, ok := edges[inValue]; !ok {
|
||||
edgeids = append(edgeids, inValue)
|
||||
}
|
||||
edges[inValue] = append(edges[inValue], node)
|
||||
nids[inValue][byid[outValue]] = struct{}{}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil {
|
||||
return nil, fmt.Errorf(`query edges "groups": %w`, err)
|
||||
}
|
||||
query.Where(group.IDIn(edgeids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := edges[n.ID]
|
||||
nodes, ok := nids[n.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
nodes[i].Edges.Groups = append(nodes[i].Edges.Groups, n)
|
||||
for kn := range nodes {
|
||||
kn.Edges.Groups = append(kn.Edges.Groups, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user