dialect/sql: support defining multiple FROM in query (#2933)

This commit is contained in:
Ariel Mashraki
2022-09-16 00:10:11 +03:00
committed by GitHub
parent 4425d1a6e1
commit 0adfb94c30
2 changed files with 100 additions and 25 deletions

View File

@@ -1095,8 +1095,8 @@ func (u *UpdateBuilder) Where(p *Predicate) *UpdateBuilder {
// FromSelect makes it possible to update entities that match the sub-query. // FromSelect makes it possible to update entities that match the sub-query.
func (u *UpdateBuilder) FromSelect(s *Selector) *UpdateBuilder { func (u *UpdateBuilder) FromSelect(s *Selector) *UpdateBuilder {
u.Where(s.where) u.Where(s.where)
if table, _ := s.from.(*SelectTable); table != nil { if t := s.Table(); t != nil {
u.table = table.name u.table = t.name
} }
return u return u
} }
@@ -1211,8 +1211,8 @@ func (d *DeleteBuilder) Where(p *Predicate) *DeleteBuilder {
// FromSelect makes it possible to delete a sub query. // FromSelect makes it possible to delete a sub query.
func (d *DeleteBuilder) FromSelect(s *Selector) *DeleteBuilder { func (d *DeleteBuilder) FromSelect(s *Selector) *DeleteBuilder {
d.Where(s.where) d.Where(s.where)
if table, _ := s.from.(*SelectTable); table != nil { if t := s.Table(); t != nil {
d.table = table.name d.table = t.name
} }
return d return d
} }
@@ -1994,6 +1994,11 @@ type TableView interface {
view() view()
} }
// queryView allows using Querier (expressions) in the FROM clause.
type queryView struct{ Querier }
func (*queryView) view() {}
// SelectTable is a table selector. // SelectTable is a table selector.
type SelectTable struct { type SelectTable struct {
Builder Builder
@@ -2096,7 +2101,7 @@ type Selector struct {
ctx context.Context ctx context.Context
as string as string
selection []any selection []any
from TableView from []TableView
joins []join joins []join
where *Predicate where *Predicate
or bool or bool
@@ -2230,13 +2235,34 @@ func (s *Selector) UnqualifiedColumns() []string {
// From sets the source of `FROM` clause. // From sets the source of `FROM` clause.
func (s *Selector) From(t TableView) *Selector { func (s *Selector) From(t TableView) *Selector {
s.from = t s.from = nil
return s.AppendFrom(t)
}
// AppendFrom appends a new TableView to the `FROM` clause.
func (s *Selector) AppendFrom(t TableView) *Selector {
s.from = append(s.from, t)
if st, ok := t.(state); ok { if st, ok := t.(state); ok {
st.SetDialect(s.dialect) st.SetDialect(s.dialect)
} }
return s return s
} }
// FromExpr sets the expression of `FROM` clause.
func (s *Selector) FromExpr(x Querier) *Selector {
s.from = nil
return s.AppendFromExpr(x)
}
// AppendFromExpr appends an expression (Queries) to the `FROM` clause.
func (s *Selector) AppendFromExpr(x Querier) *Selector {
s.from = append(s.from, &queryView{Querier: x})
if st, ok := x.(state); ok {
st.SetDialect(s.dialect)
}
return s
}
// Distinct adds the DISTINCT keyword to the `SELECT` statement. // Distinct adds the DISTINCT keyword to the `SELECT` statement.
func (s *Selector) Distinct() *Selector { func (s *Selector) Distinct() *Selector {
s.distinct = true s.distinct = true
@@ -2312,12 +2338,15 @@ func (s *Selector) Or() *Selector {
// Table returns the selected table. // Table returns the selected table.
func (s *Selector) Table() *SelectTable { func (s *Selector) Table() *SelectTable {
return s.from.(*SelectTable) if len(s.from) == 0 {
return nil
}
return s.from[0].(*SelectTable)
} }
// TableName returns the name of the selected table or alias of selector. // TableName returns the name of the selected table or alias of selector.
func (s *Selector) TableName() string { func (s *Selector) TableName() string {
switch view := s.from.(type) { switch view := s.from[0].(type) {
case *SelectTable: case *SelectTable:
return view.name return view.name
case *Selector: case *Selector:
@@ -2665,13 +2694,18 @@ func (s *Selector) Query() (string, []any) {
} else { } else {
b.WriteString("*") b.WriteString("*")
} }
switch t := s.from.(type) { if len(s.from) > 0 {
case *SelectTable:
b.WriteString(" FROM ") b.WriteString(" FROM ")
}
for i, from := range s.from {
if i > 0 {
b.Comma()
}
switch t := from.(type) {
case *SelectTable:
t.SetDialect(s.dialect) t.SetDialect(s.dialect)
b.WriteString(t.ref()) b.WriteString(t.ref())
case *Selector: case *Selector:
b.WriteString(" FROM ")
t.SetDialect(s.dialect) t.SetDialect(s.dialect)
b.Nested(func(b *Builder) { b.Nested(func(b *Builder) {
b.Join(t) b.Join(t)
@@ -2679,9 +2713,11 @@ func (s *Selector) Query() (string, []any) {
b.WriteString(" AS ") b.WriteString(" AS ")
b.Ident(t.as) b.Ident(t.as)
case *WithBuilder: case *WithBuilder:
b.WriteString(" FROM ")
t.SetDialect(s.dialect) t.SetDialect(s.dialect)
b.Ident(t.Name()) b.Ident(t.Name())
case *queryView:
b.Join(t.Querier)
}
} }
for _, join := range s.joins { for _, join := range s.joins {
b.WriteString(" " + join.kind + " ") b.WriteString(" " + join.kind + " ")
@@ -3109,7 +3145,7 @@ func (b *Builder) Quote(ident string) string {
func (b *Builder) Ident(s string) *Builder { func (b *Builder) Ident(s string) *Builder {
switch { switch {
case len(s) == 0: case len(s) == 0:
case s != "*" && !b.isIdent(s) && !isFunc(s) && !isModifier(s): case !strings.HasSuffix(s, "*") && !b.isIdent(s) && !isFunc(s) && !isModifier(s):
if b.qualifier != "" { if b.qualifier != "" {
b.WriteString(b.Quote(b.qualifier)).WriteByte('.') b.WriteString(b.Quote(b.qualifier)).WriteByte('.')
} }

View File

@@ -2193,3 +2193,42 @@ func TestUpdateBuilder_WithPrefix(t *testing.T) {
require.Empty(t, args) require.Empty(t, args)
require.Equal(t, "SET @i = 1; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query) require.Equal(t, "SET @i = 1; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query)
} }
func TestMultipleFrom(t *testing.T) {
query, args := Dialect(dialect.Postgres).
Select("items.*", As("ts_rank_cd(search, search_query)", "rank")).
From(Table("items")).
AppendFrom(Table("to_tsquery('neutrino|(dark & matter)')").As("search_query")).
Where(P(func(b *Builder) {
b.WriteString("search @@ search_query")
})).
OrderBy(Desc("rank")).
Query()
require.Empty(t, args)
require.Equal(t, `SELECT items.*, ts_rank_cd(search, search_query) AS "rank" FROM "items", to_tsquery('neutrino|(dark & matter)') AS "search_query" WHERE search @@ search_query ORDER BY "rank" DESC`, query)
query, args = Dialect(dialect.Postgres).
Select("items.*", As("ts_rank_cd(search, search_query)", "rank")).
From(Table("items")).
AppendFromExpr(Expr("to_tsquery($1) AS search_query", "neutrino|(dark & matter)")).
Where(P(func(b *Builder) {
b.WriteString("search @@ search_query")
})).
Query()
require.Equal(t, []any{"neutrino|(dark & matter)"}, args)
require.Equal(t, `SELECT items.*, ts_rank_cd(search, search_query) AS "rank" FROM "items", to_tsquery($1) AS search_query WHERE search @@ search_query`, query)
query, args = Dialect(dialect.Postgres).
Select("items.*", As("ts_rank_cd(search, search_query)", "rank")).
From(Table("items")).
Where(EQ("value", 10)).
AppendFromExpr(ExprFunc(func(b *Builder) {
b.WriteString("to_tsquery(").Arg("neutrino|(dark & matter)").WriteString(") AS search_query")
})).
Where(P(func(b *Builder) {
b.WriteString("search @@ search_query")
})).
Query()
require.Equal(t, []any{"neutrino|(dark & matter)", 10}, args)
require.Equal(t, `SELECT items.*, ts_rank_cd(search, search_query) AS "rank" FROM "items", to_tsquery($1) AS search_query WHERE "value" = $2 AND search @@ search_query`, query)
}