mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
entc/gen: move eager-loading to method (#2790)
This is a preparation work for 'WithNamed<E>' API
This commit is contained in:
@@ -388,37 +388,43 @@ func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, err
|
||||
if len(nodes) == 0 {
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
if query := cq.withOwner; query != nil {
|
||||
ids := make([]int, 0, len(nodes))
|
||||
nodeids := make(map[int][]*Car)
|
||||
for i := range nodes {
|
||||
if nodes[i].user_cars == nil {
|
||||
continue
|
||||
}
|
||||
fk := *nodes[i].user_cars
|
||||
if _, ok := nodeids[fk]; !ok {
|
||||
ids = append(ids, fk)
|
||||
}
|
||||
nodeids[fk] = append(nodeids[fk], nodes[i])
|
||||
}
|
||||
query.Where(user.IDIn(ids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
if err := cq.loadOwner(ctx, query, nodes, nil,
|
||||
func(n *Car, e *User) { n.Edges.Owner = e }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nodeids[n.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected foreign-key "user_cars" returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
nodes[i].Edges.Owner = n
|
||||
}
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (cq *CarQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Car, init func(*Car), assign func(*Car, *User)) error {
|
||||
ids := make([]int, 0, len(nodes))
|
||||
nodeids := make(map[int][]*Car)
|
||||
for i := range nodes {
|
||||
if nodes[i].user_cars == nil {
|
||||
continue
|
||||
}
|
||||
fk := *nodes[i].user_cars
|
||||
if _, ok := nodeids[fk]; !ok {
|
||||
ids = append(ids, fk)
|
||||
}
|
||||
nodeids[fk] = append(nodeids[fk], nodes[i])
|
||||
}
|
||||
query.Where(user.IDIn(ids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nodeids[n.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected foreign-key "user_cars" returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
assign(nodes[i], n)
|
||||
}
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cq *CarQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
|
||||
@@ -381,61 +381,70 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group,
|
||||
if len(nodes) == 0 {
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
if query := gq.withUsers; query != nil {
|
||||
edgeids := make([]driver.Value, len(nodes))
|
||||
byid := make(map[int]*Group)
|
||||
nids := make(map[int]map[*Group]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeids[i] = node.ID
|
||||
byid[node.ID] = node
|
||||
node.Edges.Users = []*User{}
|
||||
}
|
||||
query.Where(func(s *sql.Selector) {
|
||||
joinT := sql.Table(group.UsersTable)
|
||||
s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1]))
|
||||
s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...))
|
||||
columns := s.SelectedColumns()
|
||||
s.Select(joinT.C(group.UsersPrimaryKey[0]))
|
||||
s.AppendSelect(columns...)
|
||||
s.SetDistinct(false)
|
||||
})
|
||||
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
|
||||
assign := spec.Assign
|
||||
values := spec.ScanValues
|
||||
spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
values, err := values(columns[1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append([]interface{}{new(sql.NullInt64)}, values...), nil
|
||||
}
|
||||
spec.Assign = func(columns []string, values []interface{}) error {
|
||||
outValue := int(values[0].(*sql.NullInt64).Int64)
|
||||
inValue := int(values[1].(*sql.NullInt64).Int64)
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}}
|
||||
return assign(columns[1:], values[1:])
|
||||
}
|
||||
nids[inValue][byid[outValue]] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
if err := gq.loadUsers(ctx, query, nodes,
|
||||
func(n *Group) { n.Edges.Users = []*User{} },
|
||||
func(n *Group, e *User) { n.Edges.Users = append(n.Edges.Users, e) }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nids[n.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID)
|
||||
}
|
||||
for kn := range nodes {
|
||||
kn.Edges.Users = append(kn.Edges.Users, n)
|
||||
}
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (gq *GroupQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*Group, init func(*Group), assign func(*Group, *User)) error {
|
||||
edgeIDs := make([]driver.Value, len(nodes))
|
||||
byID := make(map[int]*Group)
|
||||
nids := make(map[int]map[*Group]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeIDs[i] = node.ID
|
||||
byID[node.ID] = node
|
||||
if init != nil {
|
||||
init(node)
|
||||
}
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
query.Where(func(s *sql.Selector) {
|
||||
joinT := sql.Table(group.UsersTable)
|
||||
s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1]))
|
||||
s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeIDs...))
|
||||
columns := s.SelectedColumns()
|
||||
s.Select(joinT.C(group.UsersPrimaryKey[0]))
|
||||
s.AppendSelect(columns...)
|
||||
s.SetDistinct(false)
|
||||
})
|
||||
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
|
||||
assign := spec.Assign
|
||||
values := spec.ScanValues
|
||||
spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
values, err := values(columns[1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append([]interface{}{new(sql.NullInt64)}, values...), nil
|
||||
}
|
||||
spec.Assign = func(columns []string, values []interface{}) error {
|
||||
outValue := int(values[0].(*sql.NullInt64).Int64)
|
||||
inValue := int(values[1].(*sql.NullInt64).Int64)
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*Group]struct{}{byID[outValue]: struct{}{}}
|
||||
return assign(columns[1:], values[1:])
|
||||
}
|
||||
nids[inValue][byID[outValue]] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nids[n.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected "users" node returned %v`, n.ID)
|
||||
}
|
||||
for kn := range nodes {
|
||||
assign(kn, n)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
|
||||
@@ -418,90 +418,108 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
if len(nodes) == 0 {
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
if query := uq.withCars; query != nil {
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[int]*User)
|
||||
for i := range nodes {
|
||||
fks = append(fks, nodes[i].ID)
|
||||
nodeids[nodes[i].ID] = nodes[i]
|
||||
nodes[i].Edges.Cars = []*Car{}
|
||||
}
|
||||
query.withFKs = true
|
||||
query.Where(predicate.Car(func(s *sql.Selector) {
|
||||
s.Where(sql.InValues(user.CarsColumn, fks...))
|
||||
}))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
if err := uq.loadCars(ctx, query, nodes,
|
||||
func(n *User) { n.Edges.Cars = []*Car{} },
|
||||
func(n *User, e *Car) { n.Edges.Cars = append(n.Edges.Cars, e) }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
fk := n.user_cars
|
||||
if fk == nil {
|
||||
return nil, fmt.Errorf(`foreign-key "user_cars" is nil for node %v`, n.ID)
|
||||
}
|
||||
node, ok := nodeids[*fk]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected foreign-key "user_cars" returned %v for node %v`, *fk, n.ID)
|
||||
}
|
||||
node.Edges.Cars = append(node.Edges.Cars, n)
|
||||
}
|
||||
}
|
||||
|
||||
if query := uq.withGroups; query != nil {
|
||||
edgeids := make([]driver.Value, len(nodes))
|
||||
byid := make(map[int]*User)
|
||||
nids := make(map[int]map[*User]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeids[i] = node.ID
|
||||
byid[node.ID] = node
|
||||
node.Edges.Groups = []*Group{}
|
||||
}
|
||||
query.Where(func(s *sql.Selector) {
|
||||
joinT := sql.Table(user.GroupsTable)
|
||||
s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0]))
|
||||
s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeids...))
|
||||
columns := s.SelectedColumns()
|
||||
s.Select(joinT.C(user.GroupsPrimaryKey[1]))
|
||||
s.AppendSelect(columns...)
|
||||
s.SetDistinct(false)
|
||||
})
|
||||
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
|
||||
assign := spec.Assign
|
||||
values := spec.ScanValues
|
||||
spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
values, err := values(columns[1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append([]interface{}{new(sql.NullInt64)}, values...), nil
|
||||
}
|
||||
spec.Assign = func(columns []string, values []interface{}) error {
|
||||
outValue := int(values[0].(*sql.NullInt64).Int64)
|
||||
inValue := int(values[1].(*sql.NullInt64).Int64)
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}}
|
||||
return assign(columns[1:], values[1:])
|
||||
}
|
||||
nids[inValue][byid[outValue]] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
if err := uq.loadGroups(ctx, query, nodes,
|
||||
func(n *User) { n.Edges.Groups = []*Group{} },
|
||||
func(n *User, e *Group) { n.Edges.Groups = append(n.Edges.Groups, e) }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nids[n.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID)
|
||||
}
|
||||
for kn := range nodes {
|
||||
kn.Edges.Groups = append(kn.Edges.Groups, n)
|
||||
}
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (uq *UserQuery) loadCars(ctx context.Context, query *CarQuery, nodes []*User, init func(*User), assign func(*User, *Car)) error {
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[int]*User)
|
||||
for i := range nodes {
|
||||
fks = append(fks, nodes[i].ID)
|
||||
nodeids[nodes[i].ID] = nodes[i]
|
||||
if init != nil {
|
||||
init(nodes[i])
|
||||
}
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
query.withFKs = true
|
||||
query.Where(predicate.Car(func(s *sql.Selector) {
|
||||
s.Where(sql.InValues(user.CarsColumn, fks...))
|
||||
}))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
fk := n.user_cars
|
||||
if fk == nil {
|
||||
return fmt.Errorf(`foreign-key "user_cars" is nil for node %v`, n.ID)
|
||||
}
|
||||
node, ok := nodeids[*fk]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected foreign-key "user_cars" returned %v for node %v`, *fk, n.ID)
|
||||
}
|
||||
assign(node, n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (uq *UserQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error {
|
||||
edgeIDs := make([]driver.Value, len(nodes))
|
||||
byID := make(map[int]*User)
|
||||
nids := make(map[int]map[*User]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeIDs[i] = node.ID
|
||||
byID[node.ID] = node
|
||||
if init != nil {
|
||||
init(node)
|
||||
}
|
||||
}
|
||||
query.Where(func(s *sql.Selector) {
|
||||
joinT := sql.Table(user.GroupsTable)
|
||||
s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0]))
|
||||
s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeIDs...))
|
||||
columns := s.SelectedColumns()
|
||||
s.Select(joinT.C(user.GroupsPrimaryKey[1]))
|
||||
s.AppendSelect(columns...)
|
||||
s.SetDistinct(false)
|
||||
})
|
||||
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
|
||||
assign := spec.Assign
|
||||
values := spec.ScanValues
|
||||
spec.ScanValues = func(columns []string) ([]interface{}, error) {
|
||||
values, err := values(columns[1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append([]interface{}{new(sql.NullInt64)}, values...), nil
|
||||
}
|
||||
spec.Assign = func(columns []string, values []interface{}) error {
|
||||
outValue := int(values[0].(*sql.NullInt64).Int64)
|
||||
inValue := int(values[1].(*sql.NullInt64).Int64)
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}}
|
||||
return assign(columns[1:], values[1:])
|
||||
}
|
||||
nids[inValue][byID[outValue]] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nids[n.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected "groups" node returned %v`, n.ID)
|
||||
}
|
||||
for kn := range nodes {
|
||||
assign(kn, n)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
|
||||
Reference in New Issue
Block a user