mirror of
https://github.com/ent/ent.git
synced 2026-03-05 19:35:23 +03:00
dialect/sql/sqlgraph: expose standard modifier to eager-load N neighbors (#3603)
This commit is contained in:
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@@ -143,7 +143,7 @@ jobs:
|
||||
--health-timeout 5s
|
||||
--health-retries 10
|
||||
maria:
|
||||
image: mariadb
|
||||
image: mariadb:10.4 # Temporary to unblock PRs from failing.
|
||||
env:
|
||||
MYSQL_DATABASE: test
|
||||
MYSQL_ROOT_PASSWORD: pass
|
||||
@@ -321,7 +321,7 @@ jobs:
|
||||
--health-timeout 5s
|
||||
--health-retries 10
|
||||
maria:
|
||||
image: mariadb
|
||||
image: mariadb:10.4 # Temporary to unblock PRs from failing.
|
||||
env:
|
||||
MYSQL_DATABASE: test
|
||||
MYSQL_ROOT_PASSWORD: pass
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
@@ -520,6 +521,68 @@ func OrderByNeighborTerms(q *sql.Selector, s *Step, opts ...sql.OrderTerm) {
|
||||
orderTerms(q, join, opts)
|
||||
}
|
||||
|
||||
// NeighborsLimit provides a modifier function that limits the
|
||||
// number of neighbors (rows) loaded per parent row (node).
|
||||
type NeighborsLimit struct {
|
||||
// SrcCTE, LimitCTE and RowNumber hold the identifier names
|
||||
// to src query, new limited one (using window function) and
|
||||
// the column for counting rows.
|
||||
SrcCTE, LimitCTE, RowNumber string
|
||||
// DefaultOrderField sets the default ordering for
|
||||
// sub-queries in case no order terms were provided.
|
||||
DefaultOrderField string
|
||||
}
|
||||
|
||||
// LimitNeighbors returns a modifier that limits the number of neighbors (rows) loaded per parent
|
||||
// row (node). The "partitionBy" is the foreign-key column (edge) to partition the window function
|
||||
// by, the "limit" is the maximum number of rows per parent, and the "orderBy" defines the order of
|
||||
// how neighbors (connected by the edge) are returned.
|
||||
//
|
||||
// This function is useful for non-unique edges, such as O2M and M2M, where the same parent can
|
||||
// have multiple children.
|
||||
func LimitNeighbors(partitionBy string, limit int, orderBy ...sql.Querier) func(*sql.Selector) {
|
||||
l := &NeighborsLimit{
|
||||
SrcCTE: "src_query",
|
||||
LimitCTE: "limited_query",
|
||||
RowNumber: "row_number",
|
||||
DefaultOrderField: "id",
|
||||
}
|
||||
return l.Modifier(partitionBy, limit, orderBy...)
|
||||
}
|
||||
|
||||
// Modifier returns a modifier function that limits the number of rows of the eager load query.
|
||||
func (l *NeighborsLimit) Modifier(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sql.Selector) {
|
||||
return func(s *sql.Selector) {
|
||||
var (
|
||||
d = sql.Dialect(s.Dialect())
|
||||
rn = sql.RowNumber().PartitionBy(partitionBy)
|
||||
)
|
||||
switch {
|
||||
case len(orderBy) > 0:
|
||||
rn.OrderExpr(orderBy...)
|
||||
case l.DefaultOrderField != "":
|
||||
rn.OrderBy(l.DefaultOrderField)
|
||||
default:
|
||||
s.AddError(errors.New("no order terms provided for window function"))
|
||||
return
|
||||
}
|
||||
s.SetDistinct(false)
|
||||
with := d.With(l.SrcCTE).
|
||||
As(s.Clone()).
|
||||
With(l.LimitCTE).
|
||||
As(
|
||||
d.Select("*").
|
||||
AppendSelectExprAs(rn, l.RowNumber).
|
||||
From(d.Table(l.SrcCTE)),
|
||||
)
|
||||
t := d.Table(l.LimitCTE).As(s.TableName())
|
||||
*s = *d.Select(s.UnqualifiedColumns()...).
|
||||
From(t).
|
||||
Where(sql.LTE(t.C(l.RowNumber), limit)).
|
||||
Prefix(with)
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
// FieldSpec holds the information for updating a field
|
||||
// column in the database.
|
||||
|
||||
@@ -2648,6 +2648,33 @@ func TestIsConstraintError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitNeighbors(t *testing.T) {
|
||||
t.Run("O2M", func(t *testing.T) {
|
||||
const fk = "author_id"
|
||||
// Authors load their posts.
|
||||
s := sql.Select(fk, "id").From(sql.Table("posts"))
|
||||
LimitNeighbors(fk, 2)(s)
|
||||
query, args := s.Query()
|
||||
require.Equal(t,
|
||||
"WITH `src_query` AS (SELECT `author_id`, `id` FROM `posts`), `limited_query` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `author_id` ORDER BY `id`)) AS `row_number` FROM `src_query`) SELECT `author_id`, `id` FROM `limited_query` AS `posts` WHERE `posts`.`row_number` <= ?",
|
||||
query,
|
||||
)
|
||||
require.Equal(t, []any{2}, args)
|
||||
})
|
||||
t.Run("M2M", func(t *testing.T) {
|
||||
const fk = "user_id"
|
||||
edgeT, neighborsT := sql.Table("user_groups"), sql.Table("groups")
|
||||
s := sql.Select(fk, "id", "name").From(neighborsT).Join(edgeT).On(neighborsT.C("id"), edgeT.C("group_id"))
|
||||
LimitNeighbors(fk, 1, sql.ExprFunc(func(b *sql.Builder) { b.Ident("updated_at") }))(s)
|
||||
query, args := s.Query()
|
||||
require.Equal(t,
|
||||
"WITH `src_query` AS (SELECT `user_id`, `id`, `name` FROM `groups` JOIN `user_groups` AS `t1` ON `groups`.`id` = `t1`.`group_id`), `limited_query` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `user_id` ORDER BY `updated_at`)) AS `row_number` FROM `src_query`) SELECT `user_id`, `id`, `name` FROM `limited_query` AS `groups` WHERE `groups`.`row_number` <= ?",
|
||||
query,
|
||||
)
|
||||
require.Equal(t, []any{1}, args)
|
||||
})
|
||||
}
|
||||
|
||||
func escape(query string) string {
|
||||
rows := strings.Split(query, "\n")
|
||||
for i := range rows {
|
||||
|
||||
Reference in New Issue
Block a user