From 5d3cc575b316bec51423c3ada729b2e2ea0f78ab Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Tue, 25 May 2021 18:27:32 +0300 Subject: [PATCH] dialect/sql: add union and with-recursive api for builder (#1595) --- dialect/sql/builder.go | 104 +++++++++++++++++++++++++++++++++--- dialect/sql/builder_test.go | 66 +++++++++++++++++++++-- 2 files changed, 160 insertions(+), 10 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index e7fadba36..2de12003c 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1730,6 +1730,7 @@ type Selector struct { limit *int offset *int distinct bool + union []union } // WithContext sets the context into the *Selector. @@ -1904,6 +1905,46 @@ func (s *Selector) join(kind string, t TableView) *Selector { return s } +// unionType describes a union type. +type unionType string + +const ( + unionAll unionType = "ALL" + unionDistinct unionType = "DISTINCT" +) + +// union query option. +type union struct { + unionType + TableView +} + +// Union appends the UNION clause to the query. +func (s *Selector) Union(t TableView) *Selector { + s.union = append(s.union, union{ + TableView: t, + }) + return s +} + +// UnionAll appends the UNION ALL clause to the query. +func (s *Selector) UnionAll(t TableView) *Selector { + s.union = append(s.union, union{ + unionType: unionAll, + TableView: t, + }) + return s +} + +// UnionDistinct appends the UNION DISTINCT clause to the query. +func (s *Selector) UnionDistinct(t TableView) *Selector { + s.union = append(s.union, union{ + unionType: unionDistinct, + TableView: t, + }) + return s +} + // C returns a formatted string for a selected column from this statement. func (s *Selector) C(column string) string { if s.as != "" { @@ -2104,11 +2145,35 @@ func (s *Selector) Query() (string, []interface{}) { b.WriteString(" OFFSET ") b.WriteString(strconv.Itoa(*s.offset)) } + if len(s.union) > 0 { + s.joinUnion(&b) + } s.total = b.total s.AddError(b.Err()) return b.String(), b.args } +func (s *Selector) joinUnion(b *Builder) { + for _, union := range s.union { + b.WriteString(" UNION ") + if union.unionType != "" { + b.WriteString(string(union.unionType) + " ") + } + switch view := union.TableView.(type) { + case *SelectTable: + view.SetDialect(s.dialect) + b.WriteString(view.ref()) + case *Selector: + view.SetDialect(s.dialect) + b.Join(view) + if view.as != "" { + b.WriteString(" AS ") + b.Ident(view.as) + } + } + } +} + func (s *Selector) joinOrder(b *Builder) { b.WriteString(" ORDER BY ") for i := range s.order { @@ -2130,17 +2195,34 @@ func (*Selector) view() {} // WithBuilder is the builder for the `WITH` statement. type WithBuilder struct { Builder - name string - s *Selector + recursive bool + name string + columns []string + s *Selector } // With returns a new builder for the `WITH` statement. // -// n := Queries{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))} +// n := Queries{ +// With("users_view").As(Select().From(Table("users"))), +// Select().From(Table("users_view")), +// } // return n.Query() // -func With(name string) *WithBuilder { - return &WithBuilder{name: name} +func With(name string, columns ...string) *WithBuilder { + return &WithBuilder{name: name, columns: columns} +} + +// WithRecursive returns a new builder for the `WITH RECURSIVE` statement. +// +// n := Queries{ +// WithRecursive("users_view").As(Select().From(Table("users"))), +// Select().From(Table("users_view")), +// } +// return n.Query() +// +func WithRecursive(name string, columns ...string) *WithBuilder { + return &WithBuilder{name: name, columns: columns, recursive: true} } // Name returns the name of the view. @@ -2154,7 +2236,17 @@ func (w *WithBuilder) As(s *Selector) *WithBuilder { // Query returns query representation of a `WITH` clause. func (w *WithBuilder) Query() (string, []interface{}) { - w.WriteString(fmt.Sprintf("WITH %s AS ", w.name)) + w.WriteString("WITH ") + if w.recursive { + w.WriteString("RECURSIVE ") + } + w.Ident(w.name) + if len(w.columns) > 0 { + w.WriteByte('(') + w.IdentComma(w.columns...) + w.WriteByte(')') + } + w.WriteString(" AS ") w.Nested(func(b *Builder) { b.Join(w.s) }) diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index e81b48894..de25fbf77 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -978,7 +978,7 @@ func TestBuilder(t *testing.T) { Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))) return Queries{With("users_view").As(s1), Select("name").From(Table("users_view"))} }(), - wantQuery: "WITH users_view AS (SELECT * FROM `users` WHERE NOT (`name` = ? AND `age` = ?)) SELECT `name` FROM `users_view`", + wantQuery: "WITH `users_view` AS (SELECT * FROM `users` WHERE NOT (`name` = ? AND `age` = ?)) SELECT `name` FROM `users_view`", wantArgs: []interface{}{"foo", "bar"}, }, { @@ -989,7 +989,7 @@ func TestBuilder(t *testing.T) { Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))) return Queries{d.With("users_view").As(s1), d.Select("name").From(Table("users_view"))} }(), - wantQuery: `WITH users_view AS (SELECT * FROM "users" WHERE NOT ("name" = $1 AND "age" = $2)) SELECT "name" FROM "users_view"`, + wantQuery: `WITH "users_view" AS (SELECT * FROM "users" WHERE NOT ("name" = $1 AND "age" = $2)) SELECT "name" FROM "users_view"`, wantArgs: []interface{}{"foo", "bar"}, }, { @@ -1208,14 +1208,14 @@ func TestBuilder(t *testing.T) { }, { input: Queries{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))}, - wantQuery: "WITH users_view AS (SELECT * FROM `users`) SELECT * FROM `users_view`", + wantQuery: "WITH `users_view` AS (SELECT * FROM `users`) SELECT * FROM `users_view`", }, { input: func() Querier { base := Select("*").From(Table("groups")) return Queries{With("groups").As(base.Clone().Where(EQ("name", "bar"))), base.Select("age")} }(), - wantQuery: "WITH groups AS (SELECT * FROM `groups` WHERE `name` = ?) SELECT `age` FROM `groups`", + wantQuery: "WITH `groups` AS (SELECT * FROM `groups` WHERE `name` = ?) SELECT `age` FROM `groups`", wantArgs: []interface{}{"bar"}, }, { @@ -1516,6 +1516,64 @@ func TestSelector_OrderByExpr(t *testing.T) { require.Equal(t, []interface{}{28, 1, 2}, args) } +func TestSelector_Union(t *testing.T) { + query, args := Dialect(dialect.Postgres). + Select("*"). + From(Table("users")). + Where(EQ("active", true)). + Union( + Select("*"). + From(Table("old_users1")). + Where( + And( + EQ("is_active", true), + GT("age", 20), + ), + ), + ). + UnionAll( + Select("*"). + From(Table("old_users2")). + Where( + And( + EQ("is_active", "true"), + LT("age", 18), + ), + ), + ). + Query() + require.Equal(t, `SELECT * FROM "users" WHERE "active" = $1 UNION SELECT * FROM "old_users1" WHERE "is_active" = $2 AND "age" > $3 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $4 AND "age" < $5`, query) + require.Equal(t, []interface{}{true, true, 20, "true", 18}, args) + + t1, t2, t3 := Table("files"), Table("files"), Table("path") + n := Queries{ + WithRecursive("path", "id", "name", "parent_id"). + As(Select(t1.Columns("id", "name", "parent_id")...). + From(t1). + Where( + And( + IsNull(t1.C("parent_id")), + EQ(t1.C("deleted"), false), + ), + ). + UnionAll( + Select(t2.Columns("id", "name", "parent_id")...). + From(t2). + Join(t3). + On(t2.C("parent_id"), t3.C("id")). + Where( + EQ(t2.C("deleted"), false), + ), + ), + ), + Select(t3.Columns("id", "name", "parent_id")...). + From(t3), + } + query, args = n.Query() + require.Equal(t, "WITH RECURSIVE `path`(`id`, `name`, `parent_id`) AS (SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` WHERE `files`.`parent_id` IS NULL AND `files`.`deleted` = ? UNION ALL SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` JOIN `path` AS `t1` ON `files`.`parent_id` = `t1`.`id` WHERE `files`.`deleted` = ?) SELECT `t1`.`id`, `t1`.`name`, `t1`.`parent_id` FROM `path` AS `t1`", query) + require.Equal(t, []interface{}{false, false}, args) +} + func TestBuilderContext(t *testing.T) { type key string want := "myval"