From 1c263c7abd1070f31dd3d55d33ac58e7753bfa9d Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Thu, 16 Dec 2021 13:50:03 +0200 Subject: [PATCH] dialect/sql: add support for SelectExpr (#2220) --- dialect/sql/builder.go | 140 +++++++++++++++++++-------- dialect/sql/builder_test.go | 34 +++++++ entc/integration/integration_test.go | 10 ++ 3 files changed, 145 insertions(+), 39 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 68667e7dd..7672f57f5 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -2032,23 +2032,23 @@ type Selector struct { Builder // ctx stores contextual data typically from // generated code such as alternate table schemas. - ctx context.Context - as string - columns []string - from TableView - joins []join - where *Predicate - or bool - not bool - order []interface{} - group []string - having *Predicate - limit *int - offset *int - distinct bool - union []union - prefix Queries - lock *LockOptions + ctx context.Context + as string + selection []interface{} + from TableView + joins []join + where *Predicate + or bool + not bool + order []interface{} + group []string + having *Predicate + limit *int + offset *int + distinct bool + union []union + prefix Queries + lock *LockOptions } // WithContext sets the context into the *Selector. @@ -2082,22 +2082,57 @@ func Select(columns ...string) *Selector { return (&Selector{}).Select(columns...) } +// SelectExpr is like Select, but supports passing arbitrary +// expressions for SELECT clause. +func SelectExpr(exprs ...Querier) *Selector { + return (&Selector{}).SelectExpr(exprs...) +} + // Select changes the columns selection of the SELECT statement. // Empty selection means all columns *. func (s *Selector) Select(columns ...string) *Selector { - s.columns = columns + s.selection = make([]interface{}, len(columns)) + for i := range columns { + s.selection[i] = columns[i] + } return s } -// AppendSelect appends additional columns/expressions to the SELECT statement. +// AppendSelect appends additional columns to the SELECT statement. func (s *Selector) AppendSelect(columns ...string) *Selector { - s.columns = append(s.columns, columns...) + for i := range columns { + s.selection = append(s.selection, columns[i]) + } + return s +} + +// SelectExpr changes the columns selection of the SELECT statement +// with custom list of expressions. +func (s *Selector) SelectExpr(exprs ...Querier) *Selector { + s.selection = make([]interface{}, len(exprs)) + for i := range exprs { + s.selection[i] = exprs[i] + } + return s +} + +// AppendSelectExpr appends additional expressions to the SELECT statement. +func (s *Selector) AppendSelectExpr(exprs ...Querier) *Selector { + for i := range exprs { + s.selection = append(s.selection, exprs[i]) + } return s } // SelectedColumns returns the selected columns of the Selector. func (s *Selector) SelectedColumns() []string { - return s.columns + columns := make([]string, 0, len(s.selection)) + for i := range s.selection { + if c, ok := s.selection[i].(string); ok { + columns = append(columns, c) + } + } + return columns } // From sets the source of `FROM` clause. @@ -2339,7 +2374,7 @@ func (s *Selector) Count(columns ...string) *Selector { b.IdentComma(columns...) column = b.String() } - s.columns = []string{Count(column)} + s.Select(Count(column)) return s } @@ -2447,21 +2482,21 @@ func (s *Selector) Clone() *Selector { joins[i] = s.joins[i].clone() } return &Selector{ - Builder: s.Builder.clone(), - ctx: s.ctx, - as: s.as, - or: s.or, - not: s.not, - from: s.from, - limit: s.limit, - offset: s.offset, - distinct: s.distinct, - where: s.where.clone(), - having: s.having.clone(), - joins: append([]join{}, joins...), - group: append([]string{}, s.group...), - order: append([]interface{}{}, s.order...), - columns: append([]string{}, s.columns...), + Builder: s.Builder.clone(), + ctx: s.ctx, + as: s.as, + or: s.or, + not: s.not, + from: s.from, + limit: s.limit, + offset: s.offset, + distinct: s.distinct, + where: s.where.clone(), + having: s.having.clone(), + joins: append([]join{}, joins...), + group: append([]string{}, s.group...), + order: append([]interface{}{}, s.order...), + selection: append([]interface{}{}, s.selection...), } } @@ -2516,8 +2551,8 @@ func (s *Selector) Query() (string, []interface{}) { if s.distinct { b.WriteString("DISTINCT ") } - if len(s.columns) > 0 { - b.IdentComma(s.columns...) + if len(s.selection) > 0 { + s.joinSelect(&b) } else { b.WriteString("*") } @@ -2652,6 +2687,20 @@ func (s *Selector) joinOrder(b *Builder) { } } +func (s *Selector) joinSelect(b *Builder) { + for i := range s.selection { + if i > 0 { + b.Comma() + } + switch s := s.selection[i].(type) { + case string: + b.Ident(s) + case Querier: + b.Join(s) + } + } +} + // implement the table view interface. func (*Selector) view() {} @@ -3323,6 +3372,19 @@ func (d *DialectBuilder) Select(columns ...string) *Selector { return b } +// SelectExpr is like Select, but supports passing arbitrary +// expressions for SELECT clause. +// +// Dialect(dialect.Postgres). +// SelectExpr(expr...). +// From(Table("users")) +// +func (d *DialectBuilder) SelectExpr(exprs ...Querier) *Selector { + b := SelectExpr(exprs...) + b.SetDialect(d.dialect) + return b +} + // Table creates a SelectTable for the configured dialect. // // Dialect(dialect.Postgres). diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index fd328c792..5285ed388 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1552,6 +1552,40 @@ func TestSelector_OrderByExpr(t *testing.T) { require.Equal(t, []interface{}{28, 1, 2}, args) } +func TestSelector_SelectExpr(t *testing.T) { + query, args := SelectExpr( + Expr("?", "a"), + ExprFunc(func(b *Builder) { + b.Ident("first_name").WriteOp(OpAdd).Ident("last_name") + }), + ExprFunc(func(b *Builder) { + b.WriteString("COALESCE(").Ident("age").Comma().Arg(0).WriteByte(')') + }), + Expr("?", "b"), + ).From(Table("users")).Query() + require.Equal(t, "SELECT ?, `first_name` + `last_name`, COALESCE(`age`, ?), ? FROM `users`", query) + require.Equal(t, []interface{}{"a", 0, "b"}, args) + + query, args = Dialect(dialect.Postgres). + Select("name"). + AppendSelectExpr( + Expr("age + $1", 1), + ExprFunc(func(b *Builder) { + b.Nested(func(b *Builder) { + b.WriteString("similarity(").Ident("name").Comma().Arg("A").WriteByte(')') + b.WriteOp(OpAdd) + b.WriteString("similarity(").Ident("desc").Comma().Arg("D").WriteByte(')') + }) + b.WriteString(" AS s") + }), + Expr("rank + $4", 10), + ). + From(Table("users")). + Query() + require.Equal(t, `SELECT "name", age + $1, (similarity("name", $2) + similarity("desc", $3)) AS s, rank + $4 FROM "users"`, query) + require.Equal(t, []interface{}{1, "A", "D", 10}, args) +} + func TestSelector_Union(t *testing.T) { query, args := Dialect(dialect.Postgres). Select("*"). diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index d5b78fb73..8d806e3c0 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -600,6 +600,16 @@ func Select(t *testing.T, client *ent.Client) { }). IntsX(ctx) require.Equal([]int{1, 1, 1, 1}, lens) + + dlen := client.Pet.Query(). + Modify(func(s *sql.Selector) { + s.SelectExpr(sql.ExprFunc(func(b *sql.Builder) { + b.WriteString("LENGTH(name)").WriteOp(sql.OpMul).Arg(2) + })) + }). + IntsX(ctx) + require.Equal([]int{2, 2, 2, 2}, dlen) + for i := range pets { pets[i].Update().SetName(pets[i].Name + pets[i].Name).ExecX(ctx) }