mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql: add except and intersect to builder (#3127)
* dialect/sql: add except and intersect for builder * dialect/sql: report an error in case INTERSECT/EXCEPT ALL were set in SQLite Co-authored-by: Liooo <ryoyamada3@gmail.com>
This commit is contained in:
@@ -2138,7 +2138,7 @@ type Selector struct {
|
||||
limit *int
|
||||
offset *int
|
||||
distinct bool
|
||||
union []union
|
||||
setOps []setOp
|
||||
prefix Queries
|
||||
lock *LockOptions
|
||||
}
|
||||
@@ -2424,23 +2424,27 @@ func (s *Selector) join(kind string, t TableView) *Selector {
|
||||
return s
|
||||
}
|
||||
|
||||
// unionType describes an UNION type.
|
||||
type unionType string
|
||||
|
||||
const (
|
||||
unionAll unionType = "ALL"
|
||||
unionDistinct unionType = "DISTINCT"
|
||||
type (
|
||||
// setOp represents a set/compound operation.
|
||||
setOp struct {
|
||||
Type setOpType // Set operation type.
|
||||
All bool // Quantifier was set to ALL (defaults to DISTINCT).
|
||||
TableView // Query or table to operate on.
|
||||
}
|
||||
// setOpType is a set operation type.
|
||||
setOpType string
|
||||
)
|
||||
|
||||
// union query option.
|
||||
type union struct {
|
||||
unionType
|
||||
TableView
|
||||
}
|
||||
const (
|
||||
setOpTypeUnion setOpType = "UNION"
|
||||
setOpTypeExcept setOpType = "EXCEPT"
|
||||
setOpTypeIntersect setOpType = "INTERSECT"
|
||||
)
|
||||
|
||||
// Union appends the UNION clause to the query.
|
||||
// Union appends the UNION (DISTINCT) clause to the query.
|
||||
func (s *Selector) Union(t TableView) *Selector {
|
||||
s.union = append(s.union, union{
|
||||
s.setOps = append(s.setOps, setOp{
|
||||
Type: setOpTypeUnion,
|
||||
TableView: t,
|
||||
})
|
||||
return s
|
||||
@@ -2448,22 +2452,67 @@ func (s *Selector) Union(t TableView) *Selector {
|
||||
|
||||
// UnionAll appends the UNION ALL clause to the query.
|
||||
func (s *Selector) UnionAll(t TableView) *Selector {
|
||||
s.union = append(s.union, union{
|
||||
unionType: unionAll,
|
||||
s.setOps = append(s.setOps, setOp{
|
||||
Type: setOpTypeUnion,
|
||||
All: true,
|
||||
TableView: t,
|
||||
})
|
||||
return s
|
||||
}
|
||||
|
||||
// UnionDistinct appends the UNION DISTINCT clause to the query.
|
||||
// Deprecated: use Union instead as by default, duplicate rows
|
||||
// are eliminated unless ALL is specified.
|
||||
func (s *Selector) UnionDistinct(t TableView) *Selector {
|
||||
s.union = append(s.union, union{
|
||||
unionType: unionDistinct,
|
||||
return s.Union(t)
|
||||
}
|
||||
|
||||
// Except appends the EXCEPT clause to the query.
|
||||
func (s *Selector) Except(t TableView) *Selector {
|
||||
s.setOps = append(s.setOps, setOp{
|
||||
Type: setOpTypeExcept,
|
||||
TableView: t,
|
||||
})
|
||||
return s
|
||||
}
|
||||
|
||||
// ExceptAll appends the EXCEPT ALL clause to the query.
|
||||
func (s *Selector) ExceptAll(t TableView) *Selector {
|
||||
if s.sqlite() {
|
||||
s.AddError(errors.New("EXCEPT ALL is not supported by SQLite"))
|
||||
} else {
|
||||
s.setOps = append(s.setOps, setOp{
|
||||
Type: setOpTypeExcept,
|
||||
All: true,
|
||||
TableView: t,
|
||||
})
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Intersect appends the INTERSECT clause to the query.
|
||||
func (s *Selector) Intersect(t TableView) *Selector {
|
||||
s.setOps = append(s.setOps, setOp{
|
||||
Type: setOpTypeIntersect,
|
||||
TableView: t,
|
||||
})
|
||||
return s
|
||||
}
|
||||
|
||||
// IntersectAll appends the INTERSECT ALL clause to the query.
|
||||
func (s *Selector) IntersectAll(t TableView) *Selector {
|
||||
if s.sqlite() {
|
||||
s.AddError(errors.New("INTERSECT ALL is not supported by SQLite"))
|
||||
} else {
|
||||
s.setOps = append(s.setOps, setOp{
|
||||
Type: setOpTypeIntersect,
|
||||
All: true,
|
||||
TableView: t,
|
||||
})
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Prefix prefixes the query with list of queries.
|
||||
func (s *Selector) Prefix(queries ...Querier) *Selector {
|
||||
s.prefix = append(s.prefix, queries...)
|
||||
@@ -2779,8 +2828,8 @@ func (s *Selector) Query() (string, []any) {
|
||||
b.WriteString(" HAVING ")
|
||||
b.Join(s.having)
|
||||
}
|
||||
if len(s.union) > 0 {
|
||||
s.joinUnion(&b)
|
||||
if len(s.setOps) > 0 {
|
||||
s.joinSetOps(&b)
|
||||
}
|
||||
joinOrder(s.order, &b)
|
||||
if s.limit != nil {
|
||||
@@ -2822,13 +2871,13 @@ func (s *Selector) joinLock(b *Builder) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Selector) joinUnion(b *Builder) {
|
||||
for _, union := range s.union {
|
||||
b.WriteString(" UNION ")
|
||||
if union.unionType != "" {
|
||||
b.WriteString(string(union.unionType) + " ")
|
||||
func (s *Selector) joinSetOps(b *Builder) {
|
||||
for _, op := range s.setOps {
|
||||
b.WriteString(" " + string(op.Type) + " ")
|
||||
if op.All {
|
||||
b.WriteString("ALL ")
|
||||
}
|
||||
switch view := union.TableView.(type) {
|
||||
switch view := op.TableView.(type) {
|
||||
case *SelectTable:
|
||||
view.SetDialect(s.dialect)
|
||||
b.WriteString(view.ref())
|
||||
|
||||
@@ -1716,34 +1716,96 @@ func TestSelector_SelectExpr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSelector_Union(t *testing.T) {
|
||||
query, args := Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(Table("users")).
|
||||
Where(EQ("active", true)).
|
||||
Union(
|
||||
Select("*").
|
||||
From(Table("old_users1")).
|
||||
Where(
|
||||
And(
|
||||
EQ("is_active", true),
|
||||
GT("age", 20),
|
||||
),
|
||||
),
|
||||
).
|
||||
UnionAll(
|
||||
Select("*").
|
||||
From(Table("old_users2")).
|
||||
Where(
|
||||
And(
|
||||
EQ("is_active", "true"),
|
||||
LT("age", 18),
|
||||
),
|
||||
),
|
||||
).
|
||||
Query()
|
||||
require.Equal(t, `SELECT * FROM "users" WHERE "active" UNION SELECT * FROM "old_users1" WHERE "is_active" AND "age" > $1 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $2 AND "age" < $3`, query)
|
||||
require.Equal(t, []any{20, "true", 18}, args)
|
||||
query, args := Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(Table("users")).
|
||||
Where(EQ("active", true)).
|
||||
Union(
|
||||
Select("*").
|
||||
From(Table("old_users1")).
|
||||
Where(
|
||||
And(
|
||||
EQ("is_active", true),
|
||||
GT("age", 20),
|
||||
),
|
||||
),
|
||||
).
|
||||
UnionAll(
|
||||
Select("*").
|
||||
From(Table("old_users2")).
|
||||
Where(
|
||||
And(
|
||||
EQ("is_active", "true"),
|
||||
LT("age", 18),
|
||||
),
|
||||
),
|
||||
).
|
||||
Query()
|
||||
require.Equal(t, `SELECT * FROM "users" WHERE "active" UNION SELECT * FROM "old_users1" WHERE "is_active" AND "age" > $1 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $2 AND "age" < $3`, query)
|
||||
require.Equal(t, []any{20, "true", 18}, args)
|
||||
}
|
||||
|
||||
func TestSelector_Except(t *testing.T) {
|
||||
query, args := Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(Table("users")).
|
||||
Where(EQ("active", true)).
|
||||
Except(
|
||||
Select("*").
|
||||
From(Table("old_users1")).
|
||||
Where(
|
||||
And(
|
||||
EQ("is_active", true),
|
||||
GT("age", 20),
|
||||
),
|
||||
),
|
||||
).
|
||||
ExceptAll(
|
||||
Select("*").
|
||||
From(Table("old_users2")).
|
||||
Where(
|
||||
And(
|
||||
EQ("is_active", "true"),
|
||||
LT("age", 18),
|
||||
),
|
||||
),
|
||||
).
|
||||
Query()
|
||||
require.Equal(t, `SELECT * FROM "users" WHERE "active" EXCEPT SELECT * FROM "old_users1" WHERE "is_active" AND "age" > $1 EXCEPT ALL SELECT * FROM "old_users2" WHERE "is_active" = $2 AND "age" < $3`, query)
|
||||
require.Equal(t, []any{20, "true", 18}, args)
|
||||
}
|
||||
|
||||
func TestSelector_Intersect(t *testing.T) {
|
||||
query, args := Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(Table("users")).
|
||||
Where(EQ("active", true)).
|
||||
Intersect(
|
||||
Select("*").
|
||||
From(Table("old_users1")).
|
||||
Where(
|
||||
And(
|
||||
EQ("is_active", true),
|
||||
GT("age", 20),
|
||||
),
|
||||
),
|
||||
).
|
||||
IntersectAll(
|
||||
Select("*").
|
||||
From(Table("old_users2")).
|
||||
Where(
|
||||
And(
|
||||
EQ("is_active", "true"),
|
||||
LT("age", 18),
|
||||
),
|
||||
),
|
||||
).
|
||||
Query()
|
||||
require.Equal(t, `SELECT * FROM "users" WHERE "active" INTERSECT SELECT * FROM "old_users1" WHERE "is_active" AND "age" > $1 INTERSECT ALL SELECT * FROM "old_users2" WHERE "is_active" = $2 AND "age" < $3`, query)
|
||||
require.Equal(t, []any{20, "true", 18}, args)
|
||||
}
|
||||
|
||||
func TestSelector_SetOperatorWithRecursive(t *testing.T) {
|
||||
t1, t2, t3 := Table("files"), Table("files"), Table("path")
|
||||
n := Queries{
|
||||
WithRecursive("path", "id", "name", "parent_id").
|
||||
@@ -1768,7 +1830,7 @@ func TestSelector_Union(t *testing.T) {
|
||||
Select(t3.Columns("id", "name", "parent_id")...).
|
||||
From(t3),
|
||||
}
|
||||
query, args = n.Query()
|
||||
query, args := n.Query()
|
||||
require.Equal(t, "WITH RECURSIVE `path`(`id`, `name`, `parent_id`) AS (SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` WHERE `files`.`parent_id` IS NULL AND NOT `files`.`deleted` UNION ALL SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` JOIN `path` AS `t1` ON `files`.`parent_id` = `t1`.`id` WHERE NOT `files`.`deleted`) SELECT `t1`.`id`, `t1`.`name`, `t1`.`parent_id` FROM `path` AS `t1`", query)
|
||||
require.Nil(t, args)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user