mirror of
https://github.com/ent/ent.git
synced 2026-04-28 05:30:56 +03:00
dialect/sql: support defining multiple FROM in query (#2933)
This commit is contained in:
@@ -1095,8 +1095,8 @@ func (u *UpdateBuilder) Where(p *Predicate) *UpdateBuilder {
|
|||||||
// FromSelect makes it possible to update entities that match the sub-query.
|
// FromSelect makes it possible to update entities that match the sub-query.
|
||||||
func (u *UpdateBuilder) FromSelect(s *Selector) *UpdateBuilder {
|
func (u *UpdateBuilder) FromSelect(s *Selector) *UpdateBuilder {
|
||||||
u.Where(s.where)
|
u.Where(s.where)
|
||||||
if table, _ := s.from.(*SelectTable); table != nil {
|
if t := s.Table(); t != nil {
|
||||||
u.table = table.name
|
u.table = t.name
|
||||||
}
|
}
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
@@ -1211,8 +1211,8 @@ func (d *DeleteBuilder) Where(p *Predicate) *DeleteBuilder {
|
|||||||
// FromSelect makes it possible to delete a sub query.
|
// FromSelect makes it possible to delete a sub query.
|
||||||
func (d *DeleteBuilder) FromSelect(s *Selector) *DeleteBuilder {
|
func (d *DeleteBuilder) FromSelect(s *Selector) *DeleteBuilder {
|
||||||
d.Where(s.where)
|
d.Where(s.where)
|
||||||
if table, _ := s.from.(*SelectTable); table != nil {
|
if t := s.Table(); t != nil {
|
||||||
d.table = table.name
|
d.table = t.name
|
||||||
}
|
}
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
@@ -1994,6 +1994,11 @@ type TableView interface {
|
|||||||
view()
|
view()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// queryView allows using Querier (expressions) in the FROM clause.
|
||||||
|
type queryView struct{ Querier }
|
||||||
|
|
||||||
|
func (*queryView) view() {}
|
||||||
|
|
||||||
// SelectTable is a table selector.
|
// SelectTable is a table selector.
|
||||||
type SelectTable struct {
|
type SelectTable struct {
|
||||||
Builder
|
Builder
|
||||||
@@ -2096,7 +2101,7 @@ type Selector struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
as string
|
as string
|
||||||
selection []any
|
selection []any
|
||||||
from TableView
|
from []TableView
|
||||||
joins []join
|
joins []join
|
||||||
where *Predicate
|
where *Predicate
|
||||||
or bool
|
or bool
|
||||||
@@ -2230,13 +2235,34 @@ func (s *Selector) UnqualifiedColumns() []string {
|
|||||||
|
|
||||||
// From sets the source of `FROM` clause.
|
// From sets the source of `FROM` clause.
|
||||||
func (s *Selector) From(t TableView) *Selector {
|
func (s *Selector) From(t TableView) *Selector {
|
||||||
s.from = t
|
s.from = nil
|
||||||
|
return s.AppendFrom(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendFrom appends a new TableView to the `FROM` clause.
|
||||||
|
func (s *Selector) AppendFrom(t TableView) *Selector {
|
||||||
|
s.from = append(s.from, t)
|
||||||
if st, ok := t.(state); ok {
|
if st, ok := t.(state); ok {
|
||||||
st.SetDialect(s.dialect)
|
st.SetDialect(s.dialect)
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FromExpr sets the expression of `FROM` clause.
|
||||||
|
func (s *Selector) FromExpr(x Querier) *Selector {
|
||||||
|
s.from = nil
|
||||||
|
return s.AppendFromExpr(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendFromExpr appends an expression (Queries) to the `FROM` clause.
|
||||||
|
func (s *Selector) AppendFromExpr(x Querier) *Selector {
|
||||||
|
s.from = append(s.from, &queryView{Querier: x})
|
||||||
|
if st, ok := x.(state); ok {
|
||||||
|
st.SetDialect(s.dialect)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// Distinct adds the DISTINCT keyword to the `SELECT` statement.
|
// Distinct adds the DISTINCT keyword to the `SELECT` statement.
|
||||||
func (s *Selector) Distinct() *Selector {
|
func (s *Selector) Distinct() *Selector {
|
||||||
s.distinct = true
|
s.distinct = true
|
||||||
@@ -2312,12 +2338,15 @@ func (s *Selector) Or() *Selector {
|
|||||||
|
|
||||||
// Table returns the selected table.
|
// Table returns the selected table.
|
||||||
func (s *Selector) Table() *SelectTable {
|
func (s *Selector) Table() *SelectTable {
|
||||||
return s.from.(*SelectTable)
|
if len(s.from) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.from[0].(*SelectTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TableName returns the name of the selected table or alias of selector.
|
// TableName returns the name of the selected table or alias of selector.
|
||||||
func (s *Selector) TableName() string {
|
func (s *Selector) TableName() string {
|
||||||
switch view := s.from.(type) {
|
switch view := s.from[0].(type) {
|
||||||
case *SelectTable:
|
case *SelectTable:
|
||||||
return view.name
|
return view.name
|
||||||
case *Selector:
|
case *Selector:
|
||||||
@@ -2665,13 +2694,18 @@ func (s *Selector) Query() (string, []any) {
|
|||||||
} else {
|
} else {
|
||||||
b.WriteString("*")
|
b.WriteString("*")
|
||||||
}
|
}
|
||||||
switch t := s.from.(type) {
|
if len(s.from) > 0 {
|
||||||
case *SelectTable:
|
|
||||||
b.WriteString(" FROM ")
|
b.WriteString(" FROM ")
|
||||||
|
}
|
||||||
|
for i, from := range s.from {
|
||||||
|
if i > 0 {
|
||||||
|
b.Comma()
|
||||||
|
}
|
||||||
|
switch t := from.(type) {
|
||||||
|
case *SelectTable:
|
||||||
t.SetDialect(s.dialect)
|
t.SetDialect(s.dialect)
|
||||||
b.WriteString(t.ref())
|
b.WriteString(t.ref())
|
||||||
case *Selector:
|
case *Selector:
|
||||||
b.WriteString(" FROM ")
|
|
||||||
t.SetDialect(s.dialect)
|
t.SetDialect(s.dialect)
|
||||||
b.Nested(func(b *Builder) {
|
b.Nested(func(b *Builder) {
|
||||||
b.Join(t)
|
b.Join(t)
|
||||||
@@ -2679,9 +2713,11 @@ func (s *Selector) Query() (string, []any) {
|
|||||||
b.WriteString(" AS ")
|
b.WriteString(" AS ")
|
||||||
b.Ident(t.as)
|
b.Ident(t.as)
|
||||||
case *WithBuilder:
|
case *WithBuilder:
|
||||||
b.WriteString(" FROM ")
|
|
||||||
t.SetDialect(s.dialect)
|
t.SetDialect(s.dialect)
|
||||||
b.Ident(t.Name())
|
b.Ident(t.Name())
|
||||||
|
case *queryView:
|
||||||
|
b.Join(t.Querier)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for _, join := range s.joins {
|
for _, join := range s.joins {
|
||||||
b.WriteString(" " + join.kind + " ")
|
b.WriteString(" " + join.kind + " ")
|
||||||
@@ -3109,7 +3145,7 @@ func (b *Builder) Quote(ident string) string {
|
|||||||
func (b *Builder) Ident(s string) *Builder {
|
func (b *Builder) Ident(s string) *Builder {
|
||||||
switch {
|
switch {
|
||||||
case len(s) == 0:
|
case len(s) == 0:
|
||||||
case s != "*" && !b.isIdent(s) && !isFunc(s) && !isModifier(s):
|
case !strings.HasSuffix(s, "*") && !b.isIdent(s) && !isFunc(s) && !isModifier(s):
|
||||||
if b.qualifier != "" {
|
if b.qualifier != "" {
|
||||||
b.WriteString(b.Quote(b.qualifier)).WriteByte('.')
|
b.WriteString(b.Quote(b.qualifier)).WriteByte('.')
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2193,3 +2193,42 @@ func TestUpdateBuilder_WithPrefix(t *testing.T) {
|
|||||||
require.Empty(t, args)
|
require.Empty(t, args)
|
||||||
require.Equal(t, "SET @i = 1; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query)
|
require.Equal(t, "SET @i = 1; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMultipleFrom(t *testing.T) {
|
||||||
|
query, args := Dialect(dialect.Postgres).
|
||||||
|
Select("items.*", As("ts_rank_cd(search, search_query)", "rank")).
|
||||||
|
From(Table("items")).
|
||||||
|
AppendFrom(Table("to_tsquery('neutrino|(dark & matter)')").As("search_query")).
|
||||||
|
Where(P(func(b *Builder) {
|
||||||
|
b.WriteString("search @@ search_query")
|
||||||
|
})).
|
||||||
|
OrderBy(Desc("rank")).
|
||||||
|
Query()
|
||||||
|
require.Empty(t, args)
|
||||||
|
require.Equal(t, `SELECT items.*, ts_rank_cd(search, search_query) AS "rank" FROM "items", to_tsquery('neutrino|(dark & matter)') AS "search_query" WHERE search @@ search_query ORDER BY "rank" DESC`, query)
|
||||||
|
|
||||||
|
query, args = Dialect(dialect.Postgres).
|
||||||
|
Select("items.*", As("ts_rank_cd(search, search_query)", "rank")).
|
||||||
|
From(Table("items")).
|
||||||
|
AppendFromExpr(Expr("to_tsquery($1) AS search_query", "neutrino|(dark & matter)")).
|
||||||
|
Where(P(func(b *Builder) {
|
||||||
|
b.WriteString("search @@ search_query")
|
||||||
|
})).
|
||||||
|
Query()
|
||||||
|
require.Equal(t, []any{"neutrino|(dark & matter)"}, args)
|
||||||
|
require.Equal(t, `SELECT items.*, ts_rank_cd(search, search_query) AS "rank" FROM "items", to_tsquery($1) AS search_query WHERE search @@ search_query`, query)
|
||||||
|
|
||||||
|
query, args = Dialect(dialect.Postgres).
|
||||||
|
Select("items.*", As("ts_rank_cd(search, search_query)", "rank")).
|
||||||
|
From(Table("items")).
|
||||||
|
Where(EQ("value", 10)).
|
||||||
|
AppendFromExpr(ExprFunc(func(b *Builder) {
|
||||||
|
b.WriteString("to_tsquery(").Arg("neutrino|(dark & matter)").WriteString(") AS search_query")
|
||||||
|
})).
|
||||||
|
Where(P(func(b *Builder) {
|
||||||
|
b.WriteString("search @@ search_query")
|
||||||
|
})).
|
||||||
|
Query()
|
||||||
|
require.Equal(t, []any{"neutrino|(dark & matter)", 10}, args)
|
||||||
|
require.Equal(t, `SELECT items.*, ts_rank_cd(search, search_query) AS "rank" FROM "items", to_tsquery($1) AS search_query WHERE "value" = $2 AND search @@ search_query`, query)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user