dialect/sql/sqlgraph: add function to order by edge terms (#3426)

This commit is contained in:
Ariel Mashraki
2023-04-01 20:55:00 +03:00
committed by GitHub
parent 6f847a3492
commit 60bb939fc2
4 changed files with 399 additions and 41 deletions

View File

@@ -82,8 +82,6 @@ type Step struct {
Columns []string
// Inverse indicates if the edge is an inverse edge.
Inverse bool
// Name allows giving this edge a name for making queries more readable.
Name string
}
// To is the dest of the path (the neighbors).
To struct {
@@ -292,12 +290,14 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
type (
// OrderByOptions holds the information needed to order a query by an edge.
OrderByOptions struct {
// Step to get the edge to order by.
// Step to get the edge to order by its count.
Step *Step
// Desc indicates if the ordering should be descending.
// When false, nulls are ordered first. When true, nulls
// are ordered last.
// Desc indicates if the ordering should be in descending order. When false, NULL values
// are ordered first (default in MySQL and SQLite). When true, NULLs are ordered last.
Desc bool
// Terms used for non-aggregation ordering.
// See, OrderByNeighborTerms for more info.
Terms []OrderByTerm
}
// OrderByInfo holds the information done by the OrderBy functions.
OrderByInfo struct {
@@ -307,21 +307,90 @@ type (
OrderByTerm struct {
Column string // Column name. If empty, an expression is used.
Expr sql.Querier // Expression. If nil, the column is used.
As string // Optional alias.
Type field.Type // Term type.
Desc bool // Descending order.
}
// OrderByOption allows configuring OrderByOptions using functional options.
OrderByOption func(*OrderByOptions)
)
// OrderDesc sets the order to be descending order.
// This option is valid only for a single count order.
func OrderDesc() OrderByOption {
return func(opts *OrderByOptions) {
opts.Desc = true
}
}
// OrderByExpr appends an expression to the order by clause.
func OrderByExpr(x sql.Querier, as string) OrderByOption {
return func(opts *OrderByOptions) {
opts.Terms = append(opts.Terms, OrderByTerm{
Expr: x,
As: as,
})
}
}
// OrderByExprDesc appends an expression to the order by clause in descending order.
func OrderByExprDesc(x sql.Querier, as string) OrderByOption {
return func(opts *OrderByOptions) {
opts.Terms = append(opts.Terms, OrderByTerm{
Expr: x,
As: as,
Desc: true,
})
}
}
// OrderByColumn appends a column to the order by clause.
func OrderByColumn(c string) OrderByOption {
return func(opts *OrderByOptions) {
opts.Terms = append(opts.Terms, OrderByTerm{
Column: c,
})
}
}
// OrderByColumnDesc appends a column to the order by clause in descending order.
func OrderByColumnDesc(c string) OrderByOption {
return func(opts *OrderByOptions) {
opts.Terms = append(opts.Terms, OrderByTerm{
Column: c,
Desc: true,
})
}
}
// NewOrderBy gets list of options and returns a configured order-by.
//
// NewOrderBy(
// sqlgraph.NewStep(
// sqlgraph.From(user.Table, user.FieldID),
// sqlgraph.To(group.Table, group.FieldID),
// sqlgraph.Edge(sqlgraph.M2M, false, user.GroupsTable, user.GroupsPrimaryKey...),
// ),
// OrderByExpr(
// sql.Expr("SUM(age)"),
// "sum_age",
// ),
// )
func NewOrderBy(s *Step, opts ...OrderByOption) *OrderByOptions {
r := &OrderByOptions{Step: s}
for _, opt := range opts {
opt(r)
}
return r
}
// countAlias returns the alias to use for the count column.
func countAlias(q *sql.Selector, s *Step) string {
eName := s.Edge.Name
if eName == "" {
eName = s.To.Table
}
selected := make(map[string]struct{})
for _, c := range q.SelectedColumns() {
selected[c] = struct{}{}
}
column := fmt.Sprintf("count_%s", eName)
column := fmt.Sprintf("count_%s", s.To.Table)
// If the column was already selected,
// try to find a free alias.
if _, ok := selected[column]; ok {
@@ -335,12 +404,12 @@ func countAlias(q *sql.Selector, s *Step) string {
return column
}
// OrderByCountNeighbors appends ordering based on the number of neighbors.
// OrderByNeighborsCount appends ordering based on the number of neighbors.
// For example, order users by their number of posts.
// HasNeighbors applies on the given Selector a neighbors check.
func OrderByCountNeighbors(q *sql.Selector, opts *OrderByOptions) *OrderByInfo {
func OrderByNeighborsCount(q *sql.Selector, opts *OrderByOptions) *OrderByInfo {
var (
countC string
join *sql.Selector
build = sql.Dialect(q.Dialect())
)
switch s, r := opts.Step, opts.Step.Edge.Rel; {
@@ -368,47 +437,120 @@ func OrderByCountNeighbors(q *sql.Selector, opts *OrderByOptions) *OrderByInfo {
pk1 = s.Edge.Columns[1]
}
joinT := build.Table(s.Edge.Table).Schema(s.Edge.Schema)
to := build.Select(
join = build.Select(
joinT.C(pk1),
build.String(func(b *sql.Builder) {
b.WriteString("COUNT(*) AS ").Ident(countC)
}),
).From(joinT).GroupBy(joinT.C(pk1))
q.LeftJoin(to).
q.LeftJoin(join).
On(
q.C(s.From.Column),
to.C(pk1),
join.C(pk1),
)
case r == O2M || (r == O2O && !s.Edge.Inverse):
countC = countAlias(q, s)
edgeT := build.Table(s.Edge.Table).Schema(s.Edge.Schema)
to := build.Select(
join = build.Select(
edgeT.C(s.Edge.Columns[0]),
build.String(func(b *sql.Builder) {
b.WriteString("COUNT(*) AS ").Ident(countC)
}),
).From(edgeT).GroupBy(edgeT.C(s.Edge.Columns[0]))
q.LeftJoin(to).
q.LeftJoin(join).
On(
q.C(s.From.Column),
to.C(s.Edge.Columns[0]),
join.C(s.Edge.Columns[0]),
)
}
q.OrderExpr(
build.Expr(func(b *sql.Builder) {
b.WriteString("COALESCE(").Ident(countC).WriteString(", 0)")
if opts.Desc {
terms := []OrderByTerm{
{Column: countC, Type: field.TypeInt, Desc: opts.Desc},
}
orderTerms(q, join, terms)
return &OrderByInfo{Terms: terms}
}
func orderTerms(q, join *sql.Selector, ts []OrderByTerm) {
for _, t := range ts {
t := t
q.OrderExprFunc(func(b *sql.Builder) {
switch {
case t.As != "":
b.WriteString(join.C(t.As))
case t.Column != "":
b.WriteString(join.C(t.Column))
case t.Expr != nil:
b.Join(t.Expr)
}
// Unlike MySQL and SQLite, NULL values sort as if larger than any other value.
// Therefore, we need to explicitly order NULLs first on ASC and last on DESC.
switch pg := b.Dialect() == dialect.Postgres; {
case pg && t.Desc:
b.WriteString(" DESC NULLS LAST")
case pg:
b.WriteString(" NULLS FIRST")
case t.Desc:
b.WriteString(" DESC")
}
}),
)
return &OrderByInfo{
Terms: []OrderByTerm{
{Column: countC, Type: field.TypeInt},
},
})
}
}
func selectTerms(q *sql.Selector, ts []OrderByTerm) {
for _, t := range ts {
switch {
case t.Column != "" && t.As != "":
q.AppendSelect(q.C(t.Column), t.As)
case t.Column != "":
q.AppendSelect(q.C(t.Column))
case t.Expr != nil:
q.AppendSelectExprAs(t.Expr, t.As)
}
}
}
// OrderByNeighborTerms appends ordering based on the number of neighbors.
// For example, order users by their number of posts.
func OrderByNeighborTerms(q *sql.Selector, opts *OrderByOptions) {
var (
join *sql.Selector
build = sql.Dialect(q.Dialect())
)
switch s, r := opts.Step, opts.Step.Edge.Rel; {
case r == M2O || (r == O2O && s.Edge.Inverse):
toT := build.Table(s.To.Table).Schema(s.To.Schema)
join = build.Select(toT.C(s.To.Column)).
From(toT)
selectTerms(join, opts.Terms)
q.LeftJoin(join).
On(q.C(s.Edge.Columns[0]), join.C(s.To.Column))
case r == M2M:
pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0]
if s.Edge.Inverse {
pk1, pk2 = pk2, pk1
}
toT := build.Table(s.To.Table).Schema(s.To.Schema)
joinT := build.Table(s.Edge.Table).Schema(s.Edge.Schema)
join = build.Select(pk2).
From(toT).
Join(joinT).
On(toT.C(s.To.Column), joinT.C(pk1)).
GroupBy(pk2)
selectTerms(join, opts.Terms)
q.LeftJoin(join).
On(q.C(s.From.Column), join.C(pk2))
case r == O2M || (r == O2O && !s.Edge.Inverse):
toT := build.Table(s.Edge.Table).Schema(s.Edge.Schema)
join = build.Select(toT.C(s.Edge.Columns[0])).
From(toT).
GroupBy(toT.C(s.Edge.Columns[0]))
selectTerms(join, opts.Terms)
q.LeftJoin(join).
On(q.C(s.From.Column), join.C(s.Edge.Columns[0]))
}
orderTerms(q, join, opts.Terms)
}
type (
// FieldSpec holds the information for updating a field
// column in the database.

View File

@@ -911,14 +911,14 @@ func TestHasNeighborsWithContext(t *testing.T) {
}
}
func TestOrderByCountNeighbors(t *testing.T) {
func TestOrderByNeighborsCount(t *testing.T) {
build := sql.Dialect(dialect.Postgres)
t1 := build.Table("users")
s := build.Select(t1.C("name")).
From(t1)
t.Run("O2M", func(t *testing.T) {
s := s.Clone()
OrderByCountNeighbors(s, &OrderByOptions{
OrderByNeighborsCount(s, &OrderByOptions{
Step: NewStep(
From("users", "id"),
To("pets", "owner_id"),
@@ -928,11 +928,11 @@ func TestOrderByCountNeighbors(t *testing.T) {
})
query, args := s.Query()
require.Empty(t, args)
require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "pets"."owner_id", COUNT(*) AS "count_pets" FROM "pets" GROUP BY "pets"."owner_id") AS "t1" ON "users"."id" = "t1"."owner_id" ORDER BY COALESCE("count_pets", 0) DESC`, query)
require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "pets"."owner_id", COUNT(*) AS "count_pets" FROM "pets" GROUP BY "pets"."owner_id") AS "t1" ON "users"."id" = "t1"."owner_id" ORDER BY "t1"."count_pets" DESC NULLS LAST`, query)
})
t.Run("M2M", func(t *testing.T) {
s := s.Clone()
OrderByCountNeighbors(s, &OrderByOptions{
OrderByNeighborsCount(s, &OrderByOptions{
Step: NewStep(
From("users", "id"),
To("groups", "id"),
@@ -941,12 +941,12 @@ func TestOrderByCountNeighbors(t *testing.T) {
})
query, args := s.Query()
require.Empty(t, args)
require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "user_groups"."user_id", COUNT(*) AS "count_groups" FROM "user_groups" GROUP BY "user_groups"."user_id") AS "t1" ON "users"."id" = "t1"."user_id" ORDER BY COALESCE("count_groups", 0)`, query)
require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "user_groups"."user_id", COUNT(*) AS "count_groups" FROM "user_groups" GROUP BY "user_groups"."user_id") AS "t1" ON "users"."id" = "t1"."user_id" ORDER BY "t1"."count_groups" NULLS FIRST`, query)
})
// Zero or one.
t.Run("M2O", func(t *testing.T) {
s1, s2 := s.Clone(), s.Clone()
OrderByCountNeighbors(s1, &OrderByOptions{
OrderByNeighborsCount(s1, &OrderByOptions{
Step: NewStep(
From("pets", "owner_id"),
To("users", "id"),
@@ -957,7 +957,7 @@ func TestOrderByCountNeighbors(t *testing.T) {
require.Empty(t, args)
require.Equal(t, `SELECT "users"."name" FROM "users" ORDER BY "owner_id" IS NULL`, query)
OrderByCountNeighbors(s2, &OrderByOptions{
OrderByNeighborsCount(s2, &OrderByOptions{
Step: NewStep(
From("pets", "owner_id"),
To("users", "id"),
@@ -971,6 +971,65 @@ func TestOrderByCountNeighbors(t *testing.T) {
})
}
func TestOrderByNeighborTerms(t *testing.T) {
build := sql.Dialect(dialect.Postgres)
t1 := build.Table("users")
s := build.Select(t1.C("name")).
From(t1)
t.Run("M2O", func(t *testing.T) {
s := s.Clone()
OrderByNeighborTerms(s, NewOrderBy(
NewStep(
From("users", "id"),
To("workplace", "id"),
Edge(M2O, true, "users", "workplace_id"),
),
OrderByColumn("name"),
))
query, args := s.Query()
require.Empty(t, args)
require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "workplace"."id", "workplace"."name" FROM "workplace") AS "t1" ON "users"."workplace_id" = "t1"."id" ORDER BY "t1"."name" NULLS FIRST`, query)
})
t.Run("O2M", func(t *testing.T) {
s := s.Clone()
OrderByNeighborTerms(s, NewOrderBy(
NewStep(
From("users", "id"),
To("repos", "id"),
Edge(O2M, false, "repo", "owner_id"),
),
OrderByExpr(
sql.ExprFunc(func(b *sql.Builder) {
b.S("SUM(").Ident("num_stars").S(")")
}),
"total_stars",
),
))
query, args := s.Query()
require.Empty(t, args)
require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "repo"."owner_id", (SUM("num_stars")) AS "total_stars" FROM "repo" GROUP BY "repo"."owner_id") AS "t1" ON "users"."id" = "t1"."owner_id" ORDER BY "t1"."total_stars" NULLS FIRST`, query)
})
t.Run("M2M", func(t *testing.T) {
s := s.Clone()
OrderByNeighborTerms(s, NewOrderBy(
NewStep(
From("users", "id"),
To("group", "id"),
Edge(M2M, false, "user_groups", "user_id", "group_id"),
),
OrderByExpr(
sql.ExprFunc(func(b *sql.Builder) {
b.S("SUM(").Ident("num_users").S(")")
}),
"total_users",
),
))
query, args := s.Query()
require.Empty(t, args)
require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "user_id", (SUM("num_users")) AS "total_users" FROM "group" JOIN "user_groups" AS "t1" ON "group"."id" = "t1"."group_id" GROUP BY "user_id") AS "t1" ON "users"."id" = "t1"."user_id" ORDER BY "t1"."total_users" NULLS FIRST`, query)
})
}
func TestCreateNode(t *testing.T) {
tests := []struct {
name string