dialect/sql/sqlgraph: expose standard modifier to eager-load N neighbors (#3603)

This commit is contained in:
Ariel Mashraki
2023-06-17 12:23:57 +03:00
committed by GitHub
parent ee7a50bc48
commit b49d5f5924
3 changed files with 92 additions and 2 deletions

View File

@@ -10,6 +10,7 @@ import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"math"
"sort"
@@ -520,6 +521,68 @@ func OrderByNeighborTerms(q *sql.Selector, s *Step, opts ...sql.OrderTerm) {
orderTerms(q, join, opts)
}
// NeighborsLimit provides a modifier function that limits the
// number of neighbors (rows) loaded per parent row (node).
type NeighborsLimit struct {
// SrcCTE, LimitCTE and RowNumber hold the identifier names
// to src query, new limited one (using window function) and
// the column for counting rows.
SrcCTE, LimitCTE, RowNumber string
// DefaultOrderField sets the default ordering for
// sub-queries in case no order terms were provided.
DefaultOrderField string
}
// LimitNeighbors returns a modifier that limits the number of neighbors (rows) loaded per parent
// row (node). The "partitionBy" is the foreign-key column (edge) to partition the window function
// by, the "limit" is the maximum number of rows per parent, and the "orderBy" defines the order of
// how neighbors (connected by the edge) are returned.
//
// This function is useful for non-unique edges, such as O2M and M2M, where the same parent can
// have multiple children.
func LimitNeighbors(partitionBy string, limit int, orderBy ...sql.Querier) func(*sql.Selector) {
l := &NeighborsLimit{
SrcCTE: "src_query",
LimitCTE: "limited_query",
RowNumber: "row_number",
DefaultOrderField: "id",
}
return l.Modifier(partitionBy, limit, orderBy...)
}
// Modifier returns a modifier function that limits the number of rows of the eager load query.
func (l *NeighborsLimit) Modifier(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sql.Selector) {
return func(s *sql.Selector) {
var (
d = sql.Dialect(s.Dialect())
rn = sql.RowNumber().PartitionBy(partitionBy)
)
switch {
case len(orderBy) > 0:
rn.OrderExpr(orderBy...)
case l.DefaultOrderField != "":
rn.OrderBy(l.DefaultOrderField)
default:
s.AddError(errors.New("no order terms provided for window function"))
return
}
s.SetDistinct(false)
with := d.With(l.SrcCTE).
As(s.Clone()).
With(l.LimitCTE).
As(
d.Select("*").
AppendSelectExprAs(rn, l.RowNumber).
From(d.Table(l.SrcCTE)),
)
t := d.Table(l.LimitCTE).As(s.TableName())
*s = *d.Select(s.UnqualifiedColumns()...).
From(t).
Where(sql.LTE(t.C(l.RowNumber), limit)).
Prefix(with)
}
}
type (
// FieldSpec holds the information for updating a field
// column in the database.