diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 0e5c1b122..1ae4eea42 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -2951,7 +2951,7 @@ func (*WithBuilder) view() {} // only to query rows-limited edges in pagination. type WindowBuilder struct { Builder - fn string // e.g. ROW_NUMBER(), RANK(). + fn func(*Builder) // e.g. ROW_NUMBER(), RANK() partition func(*Builder) order []any } @@ -2960,7 +2960,19 @@ type WindowBuilder struct { // Using this function will assign a each row a number, from 1 to N, in the // order defined by the ORDER BY clause in the window spec. func RowNumber() *WindowBuilder { - return &WindowBuilder{fn: "ROW_NUMBER"} + return Window(func(b *Builder) { + b.WriteString("ROW_NUMBER()") + }) +} + +// Window returns a new window clause with a custom selector allowing +// for custom windown functions. +// +// Window(func(b *Builder) { +// b.WriteString(Sum(posts.C("duration"))) +// }).PartitionBy("author_id").OrderBy("id"), "duration"). +func Window(fn func(*Builder)) *WindowBuilder { + return &WindowBuilder{fn: fn} } // PartitionBy indicates to divide the query rows into groups by the given columns. @@ -3000,8 +3012,8 @@ func (w *WindowBuilder) OrderExpr(exprs ...Querier) *WindowBuilder { // Query returns query representation of the window function. func (w *WindowBuilder) Query() (string, []any) { - w.WriteString(w.fn) - w.WriteString("() OVER ") + w.fn(&w.Builder) + w.WriteString(" OVER ") w.Wrap(func(b *Builder) { if w.partition != nil { b.WriteString("PARTITION BY ") diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 451db0f2d..813132578 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -2146,6 +2146,21 @@ func TestWindowFunction(t *testing.T) { require.Equal(t, []any{2}, args) } +func TestWindowFunction_Select(t *testing.T) { + posts := Table("posts") + q := Select(). + AppendSelect("*"). + AppendSelectExprAs( + Window(func(b *Builder) { + b.WriteString(Sum(posts.C("duration"))) + }).PartitionBy("author_id").OrderBy("id"), "duration"). + From(posts) + + query, args := q.Query() + require.Equal(t, "SELECT *, (SUM(`posts`.`duration`) OVER (PARTITION BY `author_id` ORDER BY `id`)) AS `duration` FROM `posts`", query) + require.Nil(t, args) +} + func TestSelector_UnqualifiedColumns(t *testing.T) { t1, t2 := Table("t1"), Table("t2") s := Select(t1.C("a"), t2.C("b"))