mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
entc/gen: add fluent-api for order options (#3449)
This commit is contained in:
@@ -2249,10 +2249,13 @@ func (s *Selector) AppendSelectExpr(exprs ...Querier) *Selector {
|
||||
// AppendSelectExprAs appends additional expressions to the SELECT statement with the given name.
|
||||
func (s *Selector) AppendSelectExprAs(expr Querier, as string) *Selector {
|
||||
s.selection = append(s.selection, ExprFunc(func(b *Builder) {
|
||||
b.WriteByte('(')
|
||||
b.Join(expr)
|
||||
b.WriteString(") AS ")
|
||||
b.Ident(as)
|
||||
switch expr.(type) {
|
||||
case *raw:
|
||||
// Raw expressions are not wrapped in parentheses.
|
||||
b.Join(expr).S(" AS ").Ident(as)
|
||||
default:
|
||||
b.S("(").Join(expr).S(") AS ").Ident(as)
|
||||
}
|
||||
}))
|
||||
return s
|
||||
}
|
||||
@@ -3298,8 +3301,9 @@ type exprFunc struct {
|
||||
}
|
||||
|
||||
func (e *exprFunc) Query() (string, []any) {
|
||||
e.fn(&e.Builder)
|
||||
return e.Builder.Query()
|
||||
b := e.Builder.clone()
|
||||
e.fn(&b)
|
||||
return b.Query()
|
||||
}
|
||||
|
||||
// Queries are list of queries join with space between them.
|
||||
|
||||
@@ -6,6 +6,7 @@ package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// The following helpers exist to simplify the way raw predicates
|
||||
@@ -196,7 +197,7 @@ type (
|
||||
// OrderExprTerm represents an ordering by an expression.
|
||||
OrderExprTerm struct {
|
||||
OrderTermOptions
|
||||
Expr Querier // Expression.
|
||||
Expr func(*Selector) Querier // Expression.
|
||||
}
|
||||
// OrderTerm represents an ordering by a term.
|
||||
OrderTerm interface {
|
||||
@@ -204,9 +205,11 @@ type (
|
||||
}
|
||||
// OrderTermOptions represents options for ordering by a term.
|
||||
OrderTermOptions struct {
|
||||
Desc bool // Whether to sort in descending order.
|
||||
As string // Optional alias.
|
||||
Selected bool // Whether the term should be selected.
|
||||
Desc bool // Whether to sort in descending order.
|
||||
As string // Optional alias.
|
||||
Selected bool // Whether the term should be selected.
|
||||
NullsFirst bool // Whether to sort nulls first.
|
||||
NullsLast bool // Whether to sort nulls last.
|
||||
}
|
||||
// OrderTermOption is an option for ordering by a term.
|
||||
OrderTermOption func(*OrderTermOptions)
|
||||
@@ -241,6 +244,20 @@ func OrderSelectAs(as string) OrderTermOption {
|
||||
}
|
||||
}
|
||||
|
||||
// OrderNullsFirst returns an option to sort nulls first.
|
||||
func OrderNullsFirst() OrderTermOption {
|
||||
return func(o *OrderTermOptions) {
|
||||
o.NullsFirst = true
|
||||
}
|
||||
}
|
||||
|
||||
// OrderNullsLast returns an option to sort nulls last.
|
||||
func OrderNullsLast() OrderTermOption {
|
||||
return func(o *OrderTermOptions) {
|
||||
o.NullsLast = true
|
||||
}
|
||||
}
|
||||
|
||||
// NewOrderTermOptions returns a new OrderTermOptions from the given options.
|
||||
func NewOrderTermOptions(opts ...OrderTermOption) *OrderTermOptions {
|
||||
o := &OrderTermOptions{}
|
||||
@@ -251,13 +268,59 @@ func NewOrderTermOptions(opts ...OrderTermOption) *OrderTermOptions {
|
||||
}
|
||||
|
||||
// OrderByField returns an ordering by the given field.
|
||||
func OrderByField(name string, opts ...OrderTermOption) *OrderFieldTerm {
|
||||
return &OrderFieldTerm{Field: name, OrderTermOptions: *NewOrderTermOptions(opts...)}
|
||||
func OrderByField(field string, opts ...OrderTermOption) *OrderFieldTerm {
|
||||
return &OrderFieldTerm{Field: field, OrderTermOptions: *NewOrderTermOptions(opts...)}
|
||||
}
|
||||
|
||||
// OrderByExpr returns an ordering by the given expression.
|
||||
func OrderByExpr(x Querier, opts ...OrderTermOption) *OrderExprTerm {
|
||||
return &OrderExprTerm{Expr: x, OrderTermOptions: *NewOrderTermOptions(opts...)}
|
||||
// OrderBySum returns an ordering by the sum of the given field.
|
||||
func OrderBySum(field string, opts ...OrderTermOption) *OrderExprTerm {
|
||||
return orderByAgg("SUM", field, opts...)
|
||||
}
|
||||
|
||||
// OrderByCount returns an ordering by the count of the given field.
|
||||
func OrderByCount(field string, opts ...OrderTermOption) *OrderExprTerm {
|
||||
return orderByAgg("COUNT", field, opts...)
|
||||
}
|
||||
|
||||
// orderByAgg returns an ordering by the aggregation of the given field.
|
||||
func orderByAgg(fn, field string, opts ...OrderTermOption) *OrderExprTerm {
|
||||
return &OrderExprTerm{
|
||||
OrderTermOptions: *NewOrderTermOptions(
|
||||
append(
|
||||
// Default alias is "<func>_<field>".
|
||||
[]OrderTermOption{OrderAs(fmt.Sprintf("%s_%s", strings.ToLower(fn), field))},
|
||||
opts...,
|
||||
)...,
|
||||
),
|
||||
Expr: func(s *Selector) Querier {
|
||||
var c string
|
||||
switch {
|
||||
case field == "*", isFunc(field):
|
||||
c = field
|
||||
default:
|
||||
c = s.C(field)
|
||||
}
|
||||
return Raw(fmt.Sprintf("%s(%s)", fn, c))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ToFunc returns a function that sets the ordering on the given selector.
|
||||
// This is used by the generated code.
|
||||
func (f *OrderFieldTerm) ToFunc() func(*Selector) {
|
||||
return func(s *Selector) {
|
||||
s.OrderExprFunc(func(b *Builder) {
|
||||
b.WriteString(s.C(f.Field))
|
||||
if f.Desc {
|
||||
b.WriteString(" DESC")
|
||||
}
|
||||
if f.NullsFirst {
|
||||
b.WriteString(" NULLS FIRST")
|
||||
} else if f.NullsLast {
|
||||
b.WriteString(" NULLS LAST")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (OrderFieldTerm) term() {}
|
||||
|
||||
@@ -349,7 +349,10 @@ func OrderByNeighborsCount(q *sql.Selector, s *Step, opts ...sql.OrderTermOption
|
||||
}
|
||||
q.OrderExpr(build.Expr(x))
|
||||
case s.ThroughEdgeTable():
|
||||
countC := countAlias(q, s, opt)
|
||||
countAs := countAlias(q, s, opt)
|
||||
terms := []sql.OrderTerm{
|
||||
sql.OrderByCount("*", append([]sql.OrderTermOption{sql.OrderAs(countAs)}, opts...)...),
|
||||
}
|
||||
pk1 := s.Edge.Columns[0]
|
||||
if s.Edge.Inverse {
|
||||
pk1 = s.Edge.Columns[1]
|
||||
@@ -357,42 +360,38 @@ func OrderByNeighborsCount(q *sql.Selector, s *Step, opts ...sql.OrderTermOption
|
||||
joinT := build.Table(s.Edge.Table).Schema(s.Edge.Schema)
|
||||
join = build.Select(
|
||||
joinT.C(pk1),
|
||||
build.String(func(b *sql.Builder) {
|
||||
b.WriteString("COUNT(*) AS ").Ident(countC)
|
||||
}),
|
||||
).From(joinT).GroupBy(joinT.C(pk1))
|
||||
selectTerms(join, terms)
|
||||
q.LeftJoin(join).
|
||||
On(
|
||||
q.C(s.From.Column),
|
||||
join.C(pk1),
|
||||
)
|
||||
orderTerms(q, join, []sql.OrderTerm{
|
||||
sql.OrderByExpr(nil, append(opts, sql.OrderAs(countC))...),
|
||||
})
|
||||
orderTerms(q, join, terms)
|
||||
case s.ToEdgeOwner():
|
||||
countC := countAlias(q, s, opt)
|
||||
countAs := countAlias(q, s, opt)
|
||||
terms := []sql.OrderTerm{
|
||||
sql.OrderByCount("*", append([]sql.OrderTermOption{sql.OrderAs(countAs)}, opts...)...),
|
||||
}
|
||||
edgeT := build.Table(s.Edge.Table).Schema(s.Edge.Schema)
|
||||
join = build.Select(
|
||||
edgeT.C(s.Edge.Columns[0]),
|
||||
build.String(func(b *sql.Builder) {
|
||||
b.WriteString("COUNT(*) AS ").Ident(countC)
|
||||
}),
|
||||
).From(edgeT).GroupBy(edgeT.C(s.Edge.Columns[0]))
|
||||
selectTerms(join, terms)
|
||||
q.LeftJoin(join).
|
||||
On(
|
||||
q.C(s.From.Column),
|
||||
join.C(s.Edge.Columns[0]),
|
||||
)
|
||||
orderTerms(q, join, []sql.OrderTerm{
|
||||
sql.OrderByExpr(nil, append(opts, sql.OrderAs(countC))...),
|
||||
})
|
||||
orderTerms(q, join, terms)
|
||||
}
|
||||
}
|
||||
|
||||
func orderTerms(q, join *sql.Selector, ts []sql.OrderTerm) {
|
||||
for _, t := range ts {
|
||||
t := t
|
||||
q.OrderExprFunc(func(b *sql.Builder) {
|
||||
var desc bool
|
||||
var desc, nullsfirst, nullslast bool
|
||||
switch t := t.(type) {
|
||||
case *sql.OrderFieldTerm:
|
||||
f := t.Field
|
||||
@@ -405,6 +404,8 @@ func orderTerms(q, join *sql.Selector, ts []sql.OrderTerm) {
|
||||
q.AppendSelect(c)
|
||||
}
|
||||
desc = t.Desc
|
||||
nullsfirst = t.NullsFirst
|
||||
nullslast = t.NullsLast
|
||||
case *sql.OrderExprTerm:
|
||||
if t.As != "" {
|
||||
c := join.C(t.As)
|
||||
@@ -413,33 +414,42 @@ func orderTerms(q, join *sql.Selector, ts []sql.OrderTerm) {
|
||||
q.AppendSelect(c)
|
||||
}
|
||||
} else {
|
||||
b.Join(t.Expr)
|
||||
b.Join(t.Expr(join))
|
||||
}
|
||||
desc = t.Desc
|
||||
nullsfirst = t.NullsFirst
|
||||
nullslast = t.NullsLast
|
||||
default:
|
||||
return
|
||||
}
|
||||
// Unlike MySQL and SQLite, NULL values sort as if larger than any other value.
|
||||
// Therefore, we need to explicitly order NULLs first on ASC and last on DESC.
|
||||
switch pg := b.Dialect() == dialect.Postgres; {
|
||||
case pg && desc:
|
||||
// Unlike MySQL and SQLite, NULL values sort as if larger than any other value. Therefore,
|
||||
// we need to explicitly order NULLs first on ASC and last on DESC unless specified otherwise.
|
||||
switch normalizePG := b.Dialect() == dialect.Postgres && !nullsfirst && !nullslast; {
|
||||
case normalizePG && desc:
|
||||
b.WriteString(" DESC NULLS LAST")
|
||||
case pg:
|
||||
case normalizePG:
|
||||
b.WriteString(" NULLS FIRST")
|
||||
case desc:
|
||||
b.WriteString(" DESC")
|
||||
}
|
||||
if nullsfirst {
|
||||
b.WriteString(" NULLS FIRST")
|
||||
} else if nullslast {
|
||||
b.WriteString(" NULLS LAST")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// selectTerms appends the select terms to the joined query.
|
||||
// Afterward, the term aliases are utilized to order the root query.
|
||||
func selectTerms(q *sql.Selector, ts []sql.OrderTerm) {
|
||||
for _, t := range ts {
|
||||
switch t := t.(type) {
|
||||
case *sql.OrderFieldTerm:
|
||||
q.AppendSelect(q.C(t.Field))
|
||||
case *sql.OrderExprTerm:
|
||||
q.AppendSelectExprAs(t.Expr, t.As)
|
||||
q.AppendSelectExprAs(t.Expr(q), t.As)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1006,6 +1006,23 @@ func TestOrderByNeighborTerms(t *testing.T) {
|
||||
require.Empty(t, args)
|
||||
require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "workplace"."id", "workplace"."name" FROM "workplace") AS "t1" ON "users"."workplace_id" = "t1"."id" ORDER BY "t1"."name" NULLS FIRST`, query)
|
||||
})
|
||||
t.Run("M2O/NullsLast", func(t *testing.T) {
|
||||
s := s.Clone()
|
||||
OrderByNeighborTerms(s,
|
||||
NewStep(
|
||||
From("users", "id"),
|
||||
To("workplace", "id"),
|
||||
Edge(M2O, true, "users", "workplace_id"),
|
||||
),
|
||||
sql.OrderByField(
|
||||
"name",
|
||||
sql.OrderNullsLast(),
|
||||
),
|
||||
)
|
||||
query, args := s.Query()
|
||||
require.Empty(t, args)
|
||||
require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "workplace"."id", "workplace"."name" FROM "workplace") AS "t1" ON "users"."workplace_id" = "t1"."id" ORDER BY "t1"."name" NULLS LAST`, query)
|
||||
})
|
||||
t.Run("O2M", func(t *testing.T) {
|
||||
s := s.Clone()
|
||||
OrderByNeighborTerms(s,
|
||||
@@ -1014,16 +1031,14 @@ func TestOrderByNeighborTerms(t *testing.T) {
|
||||
To("repos", "id"),
|
||||
Edge(O2M, false, "repo", "owner_id"),
|
||||
),
|
||||
sql.OrderByExpr(
|
||||
sql.ExprFunc(func(b *sql.Builder) {
|
||||
b.S("SUM(").Ident("num_stars").S(")")
|
||||
}),
|
||||
sql.OrderBySum(
|
||||
"num_stars",
|
||||
sql.OrderSelectAs("total_stars"),
|
||||
),
|
||||
)
|
||||
query, args := s.Query()
|
||||
require.Empty(t, args)
|
||||
require.Equal(t, `SELECT "users"."name", "t1"."total_stars" FROM "users" LEFT JOIN (SELECT "repo"."owner_id", (SUM("num_stars")) AS "total_stars" FROM "repo" GROUP BY "repo"."owner_id") AS "t1" ON "users"."id" = "t1"."owner_id" ORDER BY "t1"."total_stars" NULLS FIRST`, query)
|
||||
require.Equal(t, `SELECT "users"."name", "t1"."total_stars" FROM "users" LEFT JOIN (SELECT "repo"."owner_id", SUM("repo"."num_stars") AS "total_stars" FROM "repo" GROUP BY "repo"."owner_id") AS "t1" ON "users"."id" = "t1"."owner_id" ORDER BY "t1"."total_stars" NULLS FIRST`, query)
|
||||
})
|
||||
t.Run("M2M", func(t *testing.T) {
|
||||
s := s.Clone()
|
||||
@@ -1033,16 +1048,32 @@ func TestOrderByNeighborTerms(t *testing.T) {
|
||||
To("group", "id"),
|
||||
Edge(M2M, false, "user_groups", "user_id", "group_id"),
|
||||
),
|
||||
sql.OrderByExpr(
|
||||
sql.ExprFunc(func(b *sql.Builder) {
|
||||
b.S("SUM(").Ident("num_users").S(")")
|
||||
}),
|
||||
sql.OrderBySum(
|
||||
"num_users",
|
||||
sql.OrderSelectAs("total_users"),
|
||||
),
|
||||
)
|
||||
query, args := s.Query()
|
||||
require.Empty(t, args)
|
||||
require.Equal(t, `SELECT "users"."name", "t1"."total_users" FROM "users" LEFT JOIN (SELECT "user_id", (SUM("num_users")) AS "total_users" FROM "group" JOIN "user_groups" AS "t1" ON "group"."id" = "t1"."group_id" GROUP BY "user_id") AS "t1" ON "users"."id" = "t1"."user_id" ORDER BY "t1"."total_users" NULLS FIRST`, query)
|
||||
require.Equal(t, `SELECT "users"."name", "t1"."total_users" FROM "users" LEFT JOIN (SELECT "user_id", SUM("group"."num_users") AS "total_users" FROM "group" JOIN "user_groups" AS "t1" ON "group"."id" = "t1"."group_id" GROUP BY "user_id") AS "t1" ON "users"."id" = "t1"."user_id" ORDER BY "t1"."total_users" NULLS FIRST`, query)
|
||||
})
|
||||
t.Run("M2M/NullsLast", func(t *testing.T) {
|
||||
s := s.Clone()
|
||||
OrderByNeighborTerms(s,
|
||||
NewStep(
|
||||
From("users", "id"),
|
||||
To("group", "id"),
|
||||
Edge(M2M, false, "user_groups", "user_id", "group_id"),
|
||||
),
|
||||
sql.OrderBySum(
|
||||
"num_users",
|
||||
sql.OrderAs("total_users"),
|
||||
sql.OrderNullsLast(),
|
||||
),
|
||||
)
|
||||
query, args := s.Query()
|
||||
require.Empty(t, args)
|
||||
require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "user_id", SUM("group"."num_users") AS "total_users" FROM "group" JOIN "user_groups" AS "t1" ON "group"."id" = "t1"."group_id" GROUP BY "user_id") AS "t1" ON "users"."id" = "t1"."user_id" ORDER BY "t1"."total_users" NULLS LAST`, query)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user