entc/gen: move eager-loading to method (#2790)

This is a preparation work for 'WithNamed<E>' API
This commit is contained in:
Ariel Mashraki
2022-07-23 23:46:02 +03:00
committed by GitHub
parent 43ceed9b6f
commit a2b18f24f0
85 changed files with 7066 additions and 5678 deletions

View File

@@ -380,61 +380,70 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
if len(nodes) == 0 {
return nodes, nil
}
if query := uq.withFriends; query != nil {
edgeids := make([]driver.Value, len(nodes))
byid := make(map[int]*User)
nids := make(map[int]map[*User]struct{})
for i, node := range nodes {
edgeids[i] = node.ID
byid[node.ID] = node
node.Edges.Friends = []*User{}
}
query.Where(func(s *sql.Selector) {
joinT := sql.Table(user.FriendsTable)
s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1]))
s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...))
columns := s.SelectedColumns()
s.Select(joinT.C(user.FriendsPrimaryKey[0]))
s.AppendSelect(columns...)
s.SetDistinct(false)
})
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
assign := spec.Assign
values := spec.ScanValues
spec.ScanValues = func(columns []string) ([]interface{}, error) {
values, err := values(columns[1:])
if err != nil {
return nil, err
}
return append([]interface{}{new(sql.NullInt64)}, values...), nil
}
spec.Assign = func(columns []string, values []interface{}) error {
outValue := int(values[0].(*sql.NullInt64).Int64)
inValue := int(values[1].(*sql.NullInt64).Int64)
if nids[inValue] == nil {
nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}}
return assign(columns[1:], values[1:])
}
nids[inValue][byid[outValue]] = struct{}{}
return nil
}
})
if err != nil {
if err := uq.loadFriends(ctx, query, nodes,
func(n *User) { n.Edges.Friends = []*User{} },
func(n *User, e *User) { n.Edges.Friends = append(n.Edges.Friends, e) }); err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := nids[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID)
}
for kn := range nodes {
kn.Edges.Friends = append(kn.Edges.Friends, n)
}
}
return nodes, nil
}
func (uq *UserQuery) loadFriends(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error {
edgeIDs := make([]driver.Value, len(nodes))
byID := make(map[int]*User)
nids := make(map[int]map[*User]struct{})
for i, node := range nodes {
edgeIDs[i] = node.ID
byID[node.ID] = node
if init != nil {
init(node)
}
}
return nodes, nil
query.Where(func(s *sql.Selector) {
joinT := sql.Table(user.FriendsTable)
s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1]))
s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeIDs...))
columns := s.SelectedColumns()
s.Select(joinT.C(user.FriendsPrimaryKey[0]))
s.AppendSelect(columns...)
s.SetDistinct(false)
})
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
assign := spec.Assign
values := spec.ScanValues
spec.ScanValues = func(columns []string) ([]interface{}, error) {
values, err := values(columns[1:])
if err != nil {
return nil, err
}
return append([]interface{}{new(sql.NullInt64)}, values...), nil
}
spec.Assign = func(columns []string, values []interface{}) error {
outValue := int(values[0].(*sql.NullInt64).Int64)
inValue := int(values[1].(*sql.NullInt64).Int64)
if nids[inValue] == nil {
nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}}
return assign(columns[1:], values[1:])
}
nids[inValue][byID[outValue]] = struct{}{}
return nil
}
})
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nids[n.ID]
if !ok {
return fmt.Errorf(`unexpected "friends" node returned %v`, n.ID)
}
for kn := range nodes {
assign(kn, n)
}
}
return nil
}
func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) {