dialect/sql: add except and intersect to builder (#3127)

* dialect/sql: add except and intersect for builder

* dialect/sql: report an error in case INTERSECT/EXCEPT ALL were set in SQLite

Co-authored-by: Liooo <ryoyamada3@gmail.com>
This commit is contained in:
Ariel Mashraki
2022-11-28 14:54:59 +02:00
committed by GitHub
parent 397afc3d85
commit 122574376b
2 changed files with 165 additions and 54 deletions

View File

@@ -2138,7 +2138,7 @@ type Selector struct {
limit *int
offset *int
distinct bool
union []union
setOps []setOp
prefix Queries
lock *LockOptions
}
@@ -2424,23 +2424,27 @@ func (s *Selector) join(kind string, t TableView) *Selector {
return s
}
// unionType describes an UNION type.
type unionType string
const (
unionAll unionType = "ALL"
unionDistinct unionType = "DISTINCT"
type (
// setOp represents a set/compound operation.
setOp struct {
Type setOpType // Set operation type.
All bool // Quantifier was set to ALL (defaults to DISTINCT).
TableView // Query or table to operate on.
}
// setOpType is a set operation type.
setOpType string
)
// union query option.
type union struct {
unionType
TableView
}
const (
setOpTypeUnion setOpType = "UNION"
setOpTypeExcept setOpType = "EXCEPT"
setOpTypeIntersect setOpType = "INTERSECT"
)
// Union appends the UNION clause to the query.
// Union appends the UNION (DISTINCT) clause to the query.
func (s *Selector) Union(t TableView) *Selector {
s.union = append(s.union, union{
s.setOps = append(s.setOps, setOp{
Type: setOpTypeUnion,
TableView: t,
})
return s
@@ -2448,22 +2452,67 @@ func (s *Selector) Union(t TableView) *Selector {
// UnionAll appends the UNION ALL clause to the query.
func (s *Selector) UnionAll(t TableView) *Selector {
s.union = append(s.union, union{
unionType: unionAll,
s.setOps = append(s.setOps, setOp{
Type: setOpTypeUnion,
All: true,
TableView: t,
})
return s
}
// UnionDistinct appends the UNION DISTINCT clause to the query.
// Deprecated: use Union instead as by default, duplicate rows
// are eliminated unless ALL is specified.
func (s *Selector) UnionDistinct(t TableView) *Selector {
s.union = append(s.union, union{
unionType: unionDistinct,
return s.Union(t)
}
// Except appends the EXCEPT clause to the query.
func (s *Selector) Except(t TableView) *Selector {
s.setOps = append(s.setOps, setOp{
Type: setOpTypeExcept,
TableView: t,
})
return s
}
// ExceptAll appends the EXCEPT ALL clause to the query.
func (s *Selector) ExceptAll(t TableView) *Selector {
if s.sqlite() {
s.AddError(errors.New("EXCEPT ALL is not supported by SQLite"))
} else {
s.setOps = append(s.setOps, setOp{
Type: setOpTypeExcept,
All: true,
TableView: t,
})
}
return s
}
// Intersect appends the INTERSECT clause to the query.
func (s *Selector) Intersect(t TableView) *Selector {
s.setOps = append(s.setOps, setOp{
Type: setOpTypeIntersect,
TableView: t,
})
return s
}
// IntersectAll appends the INTERSECT ALL clause to the query.
func (s *Selector) IntersectAll(t TableView) *Selector {
if s.sqlite() {
s.AddError(errors.New("INTERSECT ALL is not supported by SQLite"))
} else {
s.setOps = append(s.setOps, setOp{
Type: setOpTypeIntersect,
All: true,
TableView: t,
})
}
return s
}
// Prefix prefixes the query with list of queries.
func (s *Selector) Prefix(queries ...Querier) *Selector {
s.prefix = append(s.prefix, queries...)
@@ -2779,8 +2828,8 @@ func (s *Selector) Query() (string, []any) {
b.WriteString(" HAVING ")
b.Join(s.having)
}
if len(s.union) > 0 {
s.joinUnion(&b)
if len(s.setOps) > 0 {
s.joinSetOps(&b)
}
joinOrder(s.order, &b)
if s.limit != nil {
@@ -2822,13 +2871,13 @@ func (s *Selector) joinLock(b *Builder) {
}
}
func (s *Selector) joinUnion(b *Builder) {
for _, union := range s.union {
b.WriteString(" UNION ")
if union.unionType != "" {
b.WriteString(string(union.unionType) + " ")
func (s *Selector) joinSetOps(b *Builder) {
for _, op := range s.setOps {
b.WriteString(" " + string(op.Type) + " ")
if op.All {
b.WriteString("ALL ")
}
switch view := union.TableView.(type) {
switch view := op.TableView.(type) {
case *SelectTable:
view.SetDialect(s.dialect)
b.WriteString(view.ref())