dialect/sql: add union and with-recursive api for builder (#1595)

This commit is contained in:
Ariel Mashraki
2021-05-25 18:27:32 +03:00
committed by GitHub
parent b46a6f0d69
commit 5d3cc575b3
2 changed files with 160 additions and 10 deletions

View File

@@ -1730,6 +1730,7 @@ type Selector struct {
limit *int
offset *int
distinct bool
union []union
}
// WithContext sets the context into the *Selector.
@@ -1904,6 +1905,46 @@ func (s *Selector) join(kind string, t TableView) *Selector {
return s
}
// unionType describes a union type.
type unionType string
const (
unionAll unionType = "ALL"
unionDistinct unionType = "DISTINCT"
)
// union query option.
type union struct {
unionType
TableView
}
// Union appends the UNION clause to the query.
func (s *Selector) Union(t TableView) *Selector {
s.union = append(s.union, union{
TableView: t,
})
return s
}
// UnionAll appends the UNION ALL clause to the query.
func (s *Selector) UnionAll(t TableView) *Selector {
s.union = append(s.union, union{
unionType: unionAll,
TableView: t,
})
return s
}
// UnionDistinct appends the UNION DISTINCT clause to the query.
func (s *Selector) UnionDistinct(t TableView) *Selector {
s.union = append(s.union, union{
unionType: unionDistinct,
TableView: t,
})
return s
}
// C returns a formatted string for a selected column from this statement.
func (s *Selector) C(column string) string {
if s.as != "" {
@@ -2104,11 +2145,35 @@ func (s *Selector) Query() (string, []interface{}) {
b.WriteString(" OFFSET ")
b.WriteString(strconv.Itoa(*s.offset))
}
if len(s.union) > 0 {
s.joinUnion(&b)
}
s.total = b.total
s.AddError(b.Err())
return b.String(), b.args
}
func (s *Selector) joinUnion(b *Builder) {
for _, union := range s.union {
b.WriteString(" UNION ")
if union.unionType != "" {
b.WriteString(string(union.unionType) + " ")
}
switch view := union.TableView.(type) {
case *SelectTable:
view.SetDialect(s.dialect)
b.WriteString(view.ref())
case *Selector:
view.SetDialect(s.dialect)
b.Join(view)
if view.as != "" {
b.WriteString(" AS ")
b.Ident(view.as)
}
}
}
}
func (s *Selector) joinOrder(b *Builder) {
b.WriteString(" ORDER BY ")
for i := range s.order {
@@ -2130,17 +2195,34 @@ func (*Selector) view() {}
// WithBuilder is the builder for the `WITH` statement.
type WithBuilder struct {
Builder
name string
s *Selector
recursive bool
name string
columns []string
s *Selector
}
// With returns a new builder for the `WITH` statement.
//
// n := Queries{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))}
// n := Queries{
// With("users_view").As(Select().From(Table("users"))),
// Select().From(Table("users_view")),
// }
// return n.Query()
//
func With(name string) *WithBuilder {
return &WithBuilder{name: name}
func With(name string, columns ...string) *WithBuilder {
return &WithBuilder{name: name, columns: columns}
}
// WithRecursive returns a new builder for the `WITH RECURSIVE` statement.
//
// n := Queries{
// WithRecursive("users_view").As(Select().From(Table("users"))),
// Select().From(Table("users_view")),
// }
// return n.Query()
//
func WithRecursive(name string, columns ...string) *WithBuilder {
return &WithBuilder{name: name, columns: columns, recursive: true}
}
// Name returns the name of the view.
@@ -2154,7 +2236,17 @@ func (w *WithBuilder) As(s *Selector) *WithBuilder {
// Query returns query representation of a `WITH` clause.
func (w *WithBuilder) Query() (string, []interface{}) {
w.WriteString(fmt.Sprintf("WITH %s AS ", w.name))
w.WriteString("WITH ")
if w.recursive {
w.WriteString("RECURSIVE ")
}
w.Ident(w.name)
if len(w.columns) > 0 {
w.WriteByte('(')
w.IdentComma(w.columns...)
w.WriteByte(')')
}
w.WriteString(" AS ")
w.Nested(func(b *Builder) {
b.Join(w.s)
})