diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 8a5b4d1c7..4a4cec94b 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -2138,7 +2138,7 @@ type Selector struct { limit *int offset *int distinct bool - union []union + setOps []setOp prefix Queries lock *LockOptions } @@ -2424,23 +2424,27 @@ func (s *Selector) join(kind string, t TableView) *Selector { return s } -// unionType describes an UNION type. -type unionType string - -const ( - unionAll unionType = "ALL" - unionDistinct unionType = "DISTINCT" +type ( + // setOp represents a set/compound operation. + setOp struct { + Type setOpType // Set operation type. + All bool // Quantifier was set to ALL (defaults to DISTINCT). + TableView // Query or table to operate on. + } + // setOpType is a set operation type. + setOpType string ) -// union query option. -type union struct { - unionType - TableView -} +const ( + setOpTypeUnion setOpType = "UNION" + setOpTypeExcept setOpType = "EXCEPT" + setOpTypeIntersect setOpType = "INTERSECT" +) -// Union appends the UNION clause to the query. +// Union appends the UNION (DISTINCT) clause to the query. func (s *Selector) Union(t TableView) *Selector { - s.union = append(s.union, union{ + s.setOps = append(s.setOps, setOp{ + Type: setOpTypeUnion, TableView: t, }) return s @@ -2448,22 +2452,67 @@ func (s *Selector) Union(t TableView) *Selector { // UnionAll appends the UNION ALL clause to the query. func (s *Selector) UnionAll(t TableView) *Selector { - s.union = append(s.union, union{ - unionType: unionAll, + s.setOps = append(s.setOps, setOp{ + Type: setOpTypeUnion, + All: true, TableView: t, }) return s } // UnionDistinct appends the UNION DISTINCT clause to the query. +// Deprecated: use Union instead as by default, duplicate rows +// are eliminated unless ALL is specified. func (s *Selector) UnionDistinct(t TableView) *Selector { - s.union = append(s.union, union{ - unionType: unionDistinct, + return s.Union(t) +} + +// Except appends the EXCEPT clause to the query. +func (s *Selector) Except(t TableView) *Selector { + s.setOps = append(s.setOps, setOp{ + Type: setOpTypeExcept, TableView: t, }) return s } +// ExceptAll appends the EXCEPT ALL clause to the query. +func (s *Selector) ExceptAll(t TableView) *Selector { + if s.sqlite() { + s.AddError(errors.New("EXCEPT ALL is not supported by SQLite")) + } else { + s.setOps = append(s.setOps, setOp{ + Type: setOpTypeExcept, + All: true, + TableView: t, + }) + } + return s +} + +// Intersect appends the INTERSECT clause to the query. +func (s *Selector) Intersect(t TableView) *Selector { + s.setOps = append(s.setOps, setOp{ + Type: setOpTypeIntersect, + TableView: t, + }) + return s +} + +// IntersectAll appends the INTERSECT ALL clause to the query. +func (s *Selector) IntersectAll(t TableView) *Selector { + if s.sqlite() { + s.AddError(errors.New("INTERSECT ALL is not supported by SQLite")) + } else { + s.setOps = append(s.setOps, setOp{ + Type: setOpTypeIntersect, + All: true, + TableView: t, + }) + } + return s +} + // Prefix prefixes the query with list of queries. func (s *Selector) Prefix(queries ...Querier) *Selector { s.prefix = append(s.prefix, queries...) @@ -2779,8 +2828,8 @@ func (s *Selector) Query() (string, []any) { b.WriteString(" HAVING ") b.Join(s.having) } - if len(s.union) > 0 { - s.joinUnion(&b) + if len(s.setOps) > 0 { + s.joinSetOps(&b) } joinOrder(s.order, &b) if s.limit != nil { @@ -2822,13 +2871,13 @@ func (s *Selector) joinLock(b *Builder) { } } -func (s *Selector) joinUnion(b *Builder) { - for _, union := range s.union { - b.WriteString(" UNION ") - if union.unionType != "" { - b.WriteString(string(union.unionType) + " ") +func (s *Selector) joinSetOps(b *Builder) { + for _, op := range s.setOps { + b.WriteString(" " + string(op.Type) + " ") + if op.All { + b.WriteString("ALL ") } - switch view := union.TableView.(type) { + switch view := op.TableView.(type) { case *SelectTable: view.SetDialect(s.dialect) b.WriteString(view.ref()) diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 9860a3415..172c67fd1 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1716,34 +1716,96 @@ func TestSelector_SelectExpr(t *testing.T) { } func TestSelector_Union(t *testing.T) { - query, args := Dialect(dialect.Postgres). - Select("*"). - From(Table("users")). - Where(EQ("active", true)). - Union( - Select("*"). - From(Table("old_users1")). - Where( - And( - EQ("is_active", true), - GT("age", 20), - ), - ), - ). - UnionAll( - Select("*"). - From(Table("old_users2")). - Where( - And( - EQ("is_active", "true"), - LT("age", 18), - ), - ), - ). - Query() - require.Equal(t, `SELECT * FROM "users" WHERE "active" UNION SELECT * FROM "old_users1" WHERE "is_active" AND "age" > $1 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $2 AND "age" < $3`, query) - require.Equal(t, []any{20, "true", 18}, args) + query, args := Dialect(dialect.Postgres). + Select("*"). + From(Table("users")). + Where(EQ("active", true)). + Union( + Select("*"). + From(Table("old_users1")). + Where( + And( + EQ("is_active", true), + GT("age", 20), + ), + ), + ). + UnionAll( + Select("*"). + From(Table("old_users2")). + Where( + And( + EQ("is_active", "true"), + LT("age", 18), + ), + ), + ). + Query() + require.Equal(t, `SELECT * FROM "users" WHERE "active" UNION SELECT * FROM "old_users1" WHERE "is_active" AND "age" > $1 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $2 AND "age" < $3`, query) + require.Equal(t, []any{20, "true", 18}, args) +} +func TestSelector_Except(t *testing.T) { + query, args := Dialect(dialect.Postgres). + Select("*"). + From(Table("users")). + Where(EQ("active", true)). + Except( + Select("*"). + From(Table("old_users1")). + Where( + And( + EQ("is_active", true), + GT("age", 20), + ), + ), + ). + ExceptAll( + Select("*"). + From(Table("old_users2")). + Where( + And( + EQ("is_active", "true"), + LT("age", 18), + ), + ), + ). + Query() + require.Equal(t, `SELECT * FROM "users" WHERE "active" EXCEPT SELECT * FROM "old_users1" WHERE "is_active" AND "age" > $1 EXCEPT ALL SELECT * FROM "old_users2" WHERE "is_active" = $2 AND "age" < $3`, query) + require.Equal(t, []any{20, "true", 18}, args) +} + +func TestSelector_Intersect(t *testing.T) { + query, args := Dialect(dialect.Postgres). + Select("*"). + From(Table("users")). + Where(EQ("active", true)). + Intersect( + Select("*"). + From(Table("old_users1")). + Where( + And( + EQ("is_active", true), + GT("age", 20), + ), + ), + ). + IntersectAll( + Select("*"). + From(Table("old_users2")). + Where( + And( + EQ("is_active", "true"), + LT("age", 18), + ), + ), + ). + Query() + require.Equal(t, `SELECT * FROM "users" WHERE "active" INTERSECT SELECT * FROM "old_users1" WHERE "is_active" AND "age" > $1 INTERSECT ALL SELECT * FROM "old_users2" WHERE "is_active" = $2 AND "age" < $3`, query) + require.Equal(t, []any{20, "true", 18}, args) +} + +func TestSelector_SetOperatorWithRecursive(t *testing.T) { t1, t2, t3 := Table("files"), Table("files"), Table("path") n := Queries{ WithRecursive("path", "id", "name", "parent_id"). @@ -1768,7 +1830,7 @@ func TestSelector_Union(t *testing.T) { Select(t3.Columns("id", "name", "parent_id")...). From(t3), } - query, args = n.Query() + query, args := n.Query() require.Equal(t, "WITH RECURSIVE `path`(`id`, `name`, `parent_id`) AS (SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` WHERE `files`.`parent_id` IS NULL AND NOT `files`.`deleted` UNION ALL SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` JOIN `path` AS `t1` ON `files`.`parent_id` = `t1`.`id` WHERE NOT `files`.`deleted`) SELECT `t1`.`id`, `t1`.`name`, `t1`.`parent_id` FROM `path` AS `t1`", query) require.Nil(t, args) }