mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql: add support for window functions (#2431)
This commit is contained in:
@@ -2172,7 +2172,7 @@ func (s *Selector) SelectExpr(exprs ...Querier) *Selector {
|
||||
return s
|
||||
}
|
||||
|
||||
// AppendSelectExpr appends additional expressions to the SELECT statement.
|
||||
// AppendSelectExpr appends additional expressions to the SELECT statement.
|
||||
func (s *Selector) AppendSelectExpr(exprs ...Querier) *Selector {
|
||||
for i := range exprs {
|
||||
s.selection = append(s.selection, exprs[i])
|
||||
@@ -2180,7 +2180,18 @@ func (s *Selector) AppendSelectExpr(exprs ...Querier) *Selector {
|
||||
return s
|
||||
}
|
||||
|
||||
// SelectedColumns returns the selected columns of the Selector.
|
||||
// AppendSelectExprAs appends additional expressions to the SELECT statement with the given name.
|
||||
func (s *Selector) AppendSelectExprAs(expr Querier, as string) *Selector {
|
||||
s.selection = append(s.selection, ExprFunc(func(b *Builder) {
|
||||
b.WriteByte('(')
|
||||
b.Join(expr)
|
||||
b.WriteString(") AS ")
|
||||
b.Ident(as)
|
||||
}))
|
||||
return s
|
||||
}
|
||||
|
||||
// SelectedColumns returns the selected columns in the Selector.
|
||||
func (s *Selector) SelectedColumns() []string {
|
||||
columns := make([]string, 0, len(s.selection))
|
||||
for i := range s.selection {
|
||||
@@ -2191,6 +2202,28 @@ func (s *Selector) SelectedColumns() []string {
|
||||
return columns
|
||||
}
|
||||
|
||||
// UnqualifiedColumns returns the an unqualified version of the
|
||||
// selected columns in the Selector. e.g. "t1"."c" => "c".
|
||||
func (s *Selector) UnqualifiedColumns() []string {
|
||||
columns := make([]string, 0, len(s.selection))
|
||||
for i := range s.selection {
|
||||
c, ok := s.selection[i].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if s.isIdent(c) {
|
||||
parts := strings.FieldsFunc(c, func(r rune) bool {
|
||||
return r == '`' || r == '"'
|
||||
})
|
||||
if n := len(parts); n > 0 && parts[n-1] != "" {
|
||||
c = parts[n-1]
|
||||
}
|
||||
}
|
||||
columns = append(columns, c)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
// From sets the source of `FROM` clause.
|
||||
func (s *Selector) From(t TableView) *Selector {
|
||||
s.from = t
|
||||
@@ -2578,6 +2611,18 @@ func (s *Selector) OrderBy(columns ...string) *Selector {
|
||||
return s
|
||||
}
|
||||
|
||||
// OrderColumns returns the ordered columns in the Selector.
|
||||
// Note, this function skips columns selected with expressions.
|
||||
func (s *Selector) OrderColumns() []string {
|
||||
columns := make([]string, 0, len(s.order))
|
||||
for i := range s.order {
|
||||
if c, ok := s.order[i].(string); ok {
|
||||
columns = append(columns, c)
|
||||
}
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
// OrderExpr appends the `ORDER BY` clause to the `SELECT`
|
||||
// statement with custom list of expressions.
|
||||
func (s *Selector) OrderExpr(exprs ...Querier) *Selector {
|
||||
@@ -2626,7 +2671,7 @@ func (s *Selector) Query() (string, []interface{}) {
|
||||
b.Ident(t.as)
|
||||
case *WithBuilder:
|
||||
t.SetDialect(s.dialect)
|
||||
b.Ident(t.name)
|
||||
b.Ident(t.Name())
|
||||
}
|
||||
for _, join := range s.joins {
|
||||
b.WriteString(" " + join.kind + " ")
|
||||
@@ -2643,7 +2688,7 @@ func (s *Selector) Query() (string, []interface{}) {
|
||||
b.Ident(view.as)
|
||||
case *WithBuilder:
|
||||
view.SetDialect(s.dialect)
|
||||
b.Ident(view.name)
|
||||
b.Ident(view.Name())
|
||||
}
|
||||
if join.on != nil {
|
||||
b.WriteString(" ON ")
|
||||
@@ -2764,9 +2809,11 @@ func (*Selector) view() {}
|
||||
type WithBuilder struct {
|
||||
Builder
|
||||
recursive bool
|
||||
name string
|
||||
columns []string
|
||||
s *Selector
|
||||
ctes []struct {
|
||||
name string
|
||||
columns []string
|
||||
s *Selector
|
||||
}
|
||||
}
|
||||
|
||||
// With returns a new builder for the `WITH` statement.
|
||||
@@ -2778,7 +2825,15 @@ type WithBuilder struct {
|
||||
// return n.Query()
|
||||
//
|
||||
func With(name string, columns ...string) *WithBuilder {
|
||||
return &WithBuilder{name: name, columns: columns}
|
||||
return &WithBuilder{
|
||||
ctes: []struct {
|
||||
name string
|
||||
columns []string
|
||||
s *Selector
|
||||
}{
|
||||
{name: name, columns: columns},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// WithRecursive returns a new builder for the `WITH RECURSIVE` statement.
|
||||
@@ -2790,22 +2845,32 @@ func With(name string, columns ...string) *WithBuilder {
|
||||
// return n.Query()
|
||||
//
|
||||
func WithRecursive(name string, columns ...string) *WithBuilder {
|
||||
return &WithBuilder{name: name, columns: columns, recursive: true}
|
||||
w := With(name, columns...)
|
||||
w.recursive = true
|
||||
return w
|
||||
}
|
||||
|
||||
// Name returns the name of the view.
|
||||
func (w *WithBuilder) Name() string { return w.name }
|
||||
func (w *WithBuilder) Name() string {
|
||||
return w.ctes[0].name
|
||||
}
|
||||
|
||||
// As sets the view sub query.
|
||||
func (w *WithBuilder) As(s *Selector) *WithBuilder {
|
||||
w.s = s
|
||||
w.ctes[len(w.ctes)-1].s = s
|
||||
return w
|
||||
}
|
||||
|
||||
// With appends another named CTE to the statement.
|
||||
func (w *WithBuilder) With(name string, columns ...string) *WithBuilder {
|
||||
w.ctes = append(w.ctes, With(name, columns...).ctes...)
|
||||
return w
|
||||
}
|
||||
|
||||
// C returns a formatted string for the WITH column.
|
||||
func (w *WithBuilder) C(column string) string {
|
||||
b := &Builder{dialect: w.dialect}
|
||||
b.Ident(w.name).WriteByte('.').Ident(column)
|
||||
b.Ident(w.Name()).WriteByte('.').Ident(column)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
@@ -2815,22 +2880,87 @@ func (w *WithBuilder) Query() (string, []interface{}) {
|
||||
if w.recursive {
|
||||
w.WriteString("RECURSIVE ")
|
||||
}
|
||||
w.Ident(w.name)
|
||||
if len(w.columns) > 0 {
|
||||
w.WriteByte('(')
|
||||
w.IdentComma(w.columns...)
|
||||
w.WriteByte(')')
|
||||
for i, cte := range w.ctes {
|
||||
if i > 0 {
|
||||
w.Comma()
|
||||
}
|
||||
w.Ident(cte.name)
|
||||
if len(cte.columns) > 0 {
|
||||
w.WriteByte('(')
|
||||
w.IdentComma(cte.columns...)
|
||||
w.WriteByte(')')
|
||||
}
|
||||
w.WriteString(" AS ")
|
||||
w.Nested(func(b *Builder) {
|
||||
b.Join(cte.s)
|
||||
})
|
||||
}
|
||||
w.WriteString(" AS ")
|
||||
w.Nested(func(b *Builder) {
|
||||
b.Join(w.s)
|
||||
})
|
||||
return w.String(), w.args
|
||||
}
|
||||
|
||||
// implement the table view interface.
|
||||
func (*WithBuilder) view() {}
|
||||
|
||||
// WindowBuilder represents a builder for a window clause.
|
||||
// Note that window functions support is limited and used
|
||||
// only to query rows-limited edges in pagination.
|
||||
type WindowBuilder struct {
|
||||
Builder
|
||||
fn string // e.g. ROW_NUMBER(), RANK().
|
||||
partition func(*Builder)
|
||||
order func(*Builder)
|
||||
}
|
||||
|
||||
// RowNumber returns a new window clause with the ROW_NUMBER() as a function.
|
||||
// Using this function will assign a each row a number, from 1 to N, in the
|
||||
// order defined by the ORDER BY clause in the window spec.
|
||||
func RowNumber() *WindowBuilder {
|
||||
return &WindowBuilder{fn: "ROW_NUMBER"}
|
||||
}
|
||||
|
||||
// PartitionBy indicates to divide the query rows into groups by the given columns.
|
||||
// Note that, standard SQL spec allows partition only by columns, and in order to
|
||||
// use the "expression" version, use the PartitionByExpr.
|
||||
func (w *WindowBuilder) PartitionBy(columns ...string) *WindowBuilder {
|
||||
w.partition = func(b *Builder) {
|
||||
b.IdentComma(columns...)
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
||||
// PartitionExpr indicates to divide the query rows into groups by the given expression.
|
||||
func (w *WindowBuilder) PartitionExpr(x Querier) *WindowBuilder {
|
||||
w.partition = func(b *Builder) {
|
||||
b.Join(x)
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
||||
// OrderBy indicates how to sort rows in each partition.
|
||||
func (w *WindowBuilder) OrderBy(columns ...string) *WindowBuilder {
|
||||
w.order = func(b *Builder) {
|
||||
b.IdentComma(columns...)
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
||||
// Query returns query representation of the window function.
|
||||
func (w *WindowBuilder) Query() (string, []interface{}) {
|
||||
w.WriteString(w.fn)
|
||||
w.WriteString("() OVER ")
|
||||
w.Nested(func(b *Builder) {
|
||||
if w.partition != nil {
|
||||
b.WriteString("PARTITION BY ")
|
||||
w.partition(b)
|
||||
}
|
||||
if w.order != nil {
|
||||
b.WriteString(" ORDER BY ")
|
||||
w.order(b)
|
||||
}
|
||||
})
|
||||
return w.Builder.String(), w.args
|
||||
}
|
||||
|
||||
// Wrapper wraps a given Querier with different format.
|
||||
// Used to prefix/suffix other queries.
|
||||
type Wrapper struct {
|
||||
|
||||
@@ -2026,3 +2026,38 @@ func TestBoolPredicates(t *testing.T) {
|
||||
require.Nil(t, args)
|
||||
require.Equal(t, "SELECT * FROM `users` JOIN `posts` AS `t1` ON `users`.`id` = `t1`.`author_id` WHERE `users`.`active` AND NOT `t1`.`deleted`", query)
|
||||
}
|
||||
|
||||
func TestWindowFunction(t *testing.T) {
|
||||
posts := Table("posts")
|
||||
base := Select(posts.Columns("id", "content", "author_id")...).
|
||||
From(posts).
|
||||
Where(EQ("active", true))
|
||||
with := With("active_posts").
|
||||
As(base).
|
||||
With("selected_posts").
|
||||
As(
|
||||
Select().
|
||||
AppendSelect("*").
|
||||
AppendSelectExprAs(
|
||||
RowNumber().PartitionBy("author_id").OrderBy("id"),
|
||||
"row_number",
|
||||
).
|
||||
From(Table("active_posts")),
|
||||
)
|
||||
query, args := Select("*").From(Table("selected_posts")).Where(LTE("row_number", 2)).Prefix(with).Query()
|
||||
require.Equal(t, "WITH `active_posts` AS (SELECT `posts`.`id`, `posts`.`content`, `posts`.`author_id` FROM `posts` WHERE `active`), `selected_posts` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `author_id` ORDER BY `id`)) AS `row_number` FROM `active_posts`) SELECT * FROM `selected_posts` WHERE `row_number` <= ?", query)
|
||||
require.Equal(t, []interface{}{2}, args)
|
||||
}
|
||||
|
||||
func TestSelector_UnqualifiedColumns(t *testing.T) {
|
||||
t1, t2 := Table("t1"), Table("t2")
|
||||
s := Select(t1.C("a"), t2.C("b"))
|
||||
require.Equal(t, []string{"`t1`.`a`", "`t2`.`b`"}, s.SelectedColumns())
|
||||
require.Equal(t, []string{"a", "b"}, s.UnqualifiedColumns())
|
||||
|
||||
d := Dialect(dialect.Postgres)
|
||||
t1, t2 = d.Table("t1"), d.Table("t2")
|
||||
s = d.Select(t1.C("a"), t2.C("b"))
|
||||
require.Equal(t, []string{`"t1"."a"`, `"t2"."b"`}, s.SelectedColumns())
|
||||
require.Equal(t, []string{"a", "b"}, s.UnqualifiedColumns())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user