mirror of
https://github.com/ent/ent.git
synced 2026-03-05 19:35:23 +03:00
dialect/sql: support defining multiple FROM in query (#2933)
This commit is contained in:
@@ -1095,8 +1095,8 @@ func (u *UpdateBuilder) Where(p *Predicate) *UpdateBuilder {
|
||||
// FromSelect makes it possible to update entities that match the sub-query.
|
||||
func (u *UpdateBuilder) FromSelect(s *Selector) *UpdateBuilder {
|
||||
u.Where(s.where)
|
||||
if table, _ := s.from.(*SelectTable); table != nil {
|
||||
u.table = table.name
|
||||
if t := s.Table(); t != nil {
|
||||
u.table = t.name
|
||||
}
|
||||
return u
|
||||
}
|
||||
@@ -1211,8 +1211,8 @@ func (d *DeleteBuilder) Where(p *Predicate) *DeleteBuilder {
|
||||
// FromSelect makes it possible to delete a sub query.
|
||||
func (d *DeleteBuilder) FromSelect(s *Selector) *DeleteBuilder {
|
||||
d.Where(s.where)
|
||||
if table, _ := s.from.(*SelectTable); table != nil {
|
||||
d.table = table.name
|
||||
if t := s.Table(); t != nil {
|
||||
d.table = t.name
|
||||
}
|
||||
return d
|
||||
}
|
||||
@@ -1994,6 +1994,11 @@ type TableView interface {
|
||||
view()
|
||||
}
|
||||
|
||||
// queryView allows using Querier (expressions) in the FROM clause.
|
||||
type queryView struct{ Querier }
|
||||
|
||||
func (*queryView) view() {}
|
||||
|
||||
// SelectTable is a table selector.
|
||||
type SelectTable struct {
|
||||
Builder
|
||||
@@ -2096,7 +2101,7 @@ type Selector struct {
|
||||
ctx context.Context
|
||||
as string
|
||||
selection []any
|
||||
from TableView
|
||||
from []TableView
|
||||
joins []join
|
||||
where *Predicate
|
||||
or bool
|
||||
@@ -2230,13 +2235,34 @@ func (s *Selector) UnqualifiedColumns() []string {
|
||||
|
||||
// From sets the source of `FROM` clause.
|
||||
func (s *Selector) From(t TableView) *Selector {
|
||||
s.from = t
|
||||
s.from = nil
|
||||
return s.AppendFrom(t)
|
||||
}
|
||||
|
||||
// AppendFrom appends a new TableView to the `FROM` clause.
|
||||
func (s *Selector) AppendFrom(t TableView) *Selector {
|
||||
s.from = append(s.from, t)
|
||||
if st, ok := t.(state); ok {
|
||||
st.SetDialect(s.dialect)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// FromExpr sets the expression of `FROM` clause.
|
||||
func (s *Selector) FromExpr(x Querier) *Selector {
|
||||
s.from = nil
|
||||
return s.AppendFromExpr(x)
|
||||
}
|
||||
|
||||
// AppendFromExpr appends an expression (Queries) to the `FROM` clause.
|
||||
func (s *Selector) AppendFromExpr(x Querier) *Selector {
|
||||
s.from = append(s.from, &queryView{Querier: x})
|
||||
if st, ok := x.(state); ok {
|
||||
st.SetDialect(s.dialect)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Distinct adds the DISTINCT keyword to the `SELECT` statement.
|
||||
func (s *Selector) Distinct() *Selector {
|
||||
s.distinct = true
|
||||
@@ -2312,12 +2338,15 @@ func (s *Selector) Or() *Selector {
|
||||
|
||||
// Table returns the selected table.
|
||||
func (s *Selector) Table() *SelectTable {
|
||||
return s.from.(*SelectTable)
|
||||
if len(s.from) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.from[0].(*SelectTable)
|
||||
}
|
||||
|
||||
// TableName returns the name of the selected table or alias of selector.
|
||||
func (s *Selector) TableName() string {
|
||||
switch view := s.from.(type) {
|
||||
switch view := s.from[0].(type) {
|
||||
case *SelectTable:
|
||||
return view.name
|
||||
case *Selector:
|
||||
@@ -2665,23 +2694,30 @@ func (s *Selector) Query() (string, []any) {
|
||||
} else {
|
||||
b.WriteString("*")
|
||||
}
|
||||
switch t := s.from.(type) {
|
||||
case *SelectTable:
|
||||
if len(s.from) > 0 {
|
||||
b.WriteString(" FROM ")
|
||||
t.SetDialect(s.dialect)
|
||||
b.WriteString(t.ref())
|
||||
case *Selector:
|
||||
b.WriteString(" FROM ")
|
||||
t.SetDialect(s.dialect)
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Join(t)
|
||||
})
|
||||
b.WriteString(" AS ")
|
||||
b.Ident(t.as)
|
||||
case *WithBuilder:
|
||||
b.WriteString(" FROM ")
|
||||
t.SetDialect(s.dialect)
|
||||
b.Ident(t.Name())
|
||||
}
|
||||
for i, from := range s.from {
|
||||
if i > 0 {
|
||||
b.Comma()
|
||||
}
|
||||
switch t := from.(type) {
|
||||
case *SelectTable:
|
||||
t.SetDialect(s.dialect)
|
||||
b.WriteString(t.ref())
|
||||
case *Selector:
|
||||
t.SetDialect(s.dialect)
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Join(t)
|
||||
})
|
||||
b.WriteString(" AS ")
|
||||
b.Ident(t.as)
|
||||
case *WithBuilder:
|
||||
t.SetDialect(s.dialect)
|
||||
b.Ident(t.Name())
|
||||
case *queryView:
|
||||
b.Join(t.Querier)
|
||||
}
|
||||
}
|
||||
for _, join := range s.joins {
|
||||
b.WriteString(" " + join.kind + " ")
|
||||
@@ -3109,7 +3145,7 @@ func (b *Builder) Quote(ident string) string {
|
||||
func (b *Builder) Ident(s string) *Builder {
|
||||
switch {
|
||||
case len(s) == 0:
|
||||
case s != "*" && !b.isIdent(s) && !isFunc(s) && !isModifier(s):
|
||||
case !strings.HasSuffix(s, "*") && !b.isIdent(s) && !isFunc(s) && !isModifier(s):
|
||||
if b.qualifier != "" {
|
||||
b.WriteString(b.Quote(b.qualifier)).WriteByte('.')
|
||||
}
|
||||
|
||||
@@ -2193,3 +2193,42 @@ func TestUpdateBuilder_WithPrefix(t *testing.T) {
|
||||
require.Empty(t, args)
|
||||
require.Equal(t, "SET @i = 1; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query)
|
||||
}
|
||||
|
||||
func TestMultipleFrom(t *testing.T) {
|
||||
query, args := Dialect(dialect.Postgres).
|
||||
Select("items.*", As("ts_rank_cd(search, search_query)", "rank")).
|
||||
From(Table("items")).
|
||||
AppendFrom(Table("to_tsquery('neutrino|(dark & matter)')").As("search_query")).
|
||||
Where(P(func(b *Builder) {
|
||||
b.WriteString("search @@ search_query")
|
||||
})).
|
||||
OrderBy(Desc("rank")).
|
||||
Query()
|
||||
require.Empty(t, args)
|
||||
require.Equal(t, `SELECT items.*, ts_rank_cd(search, search_query) AS "rank" FROM "items", to_tsquery('neutrino|(dark & matter)') AS "search_query" WHERE search @@ search_query ORDER BY "rank" DESC`, query)
|
||||
|
||||
query, args = Dialect(dialect.Postgres).
|
||||
Select("items.*", As("ts_rank_cd(search, search_query)", "rank")).
|
||||
From(Table("items")).
|
||||
AppendFromExpr(Expr("to_tsquery($1) AS search_query", "neutrino|(dark & matter)")).
|
||||
Where(P(func(b *Builder) {
|
||||
b.WriteString("search @@ search_query")
|
||||
})).
|
||||
Query()
|
||||
require.Equal(t, []any{"neutrino|(dark & matter)"}, args)
|
||||
require.Equal(t, `SELECT items.*, ts_rank_cd(search, search_query) AS "rank" FROM "items", to_tsquery($1) AS search_query WHERE search @@ search_query`, query)
|
||||
|
||||
query, args = Dialect(dialect.Postgres).
|
||||
Select("items.*", As("ts_rank_cd(search, search_query)", "rank")).
|
||||
From(Table("items")).
|
||||
Where(EQ("value", 10)).
|
||||
AppendFromExpr(ExprFunc(func(b *Builder) {
|
||||
b.WriteString("to_tsquery(").Arg("neutrino|(dark & matter)").WriteString(") AS search_query")
|
||||
})).
|
||||
Where(P(func(b *Builder) {
|
||||
b.WriteString("search @@ search_query")
|
||||
})).
|
||||
Query()
|
||||
require.Equal(t, []any{"neutrino|(dark & matter)", 10}, args)
|
||||
require.Equal(t, `SELECT items.*, ts_rank_cd(search, search_query) AS "rank" FROM "items", to_tsquery($1) AS search_query WHERE "value" = $2 AND search @@ search_query`, query)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user