mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql: add union and with-recursive api for builder (#1595)
This commit is contained in:
@@ -1730,6 +1730,7 @@ type Selector struct {
|
||||
limit *int
|
||||
offset *int
|
||||
distinct bool
|
||||
union []union
|
||||
}
|
||||
|
||||
// WithContext sets the context into the *Selector.
|
||||
@@ -1904,6 +1905,46 @@ func (s *Selector) join(kind string, t TableView) *Selector {
|
||||
return s
|
||||
}
|
||||
|
||||
// unionType describes a union type.
|
||||
type unionType string
|
||||
|
||||
const (
|
||||
unionAll unionType = "ALL"
|
||||
unionDistinct unionType = "DISTINCT"
|
||||
)
|
||||
|
||||
// union query option.
|
||||
type union struct {
|
||||
unionType
|
||||
TableView
|
||||
}
|
||||
|
||||
// Union appends the UNION clause to the query.
|
||||
func (s *Selector) Union(t TableView) *Selector {
|
||||
s.union = append(s.union, union{
|
||||
TableView: t,
|
||||
})
|
||||
return s
|
||||
}
|
||||
|
||||
// UnionAll appends the UNION ALL clause to the query.
|
||||
func (s *Selector) UnionAll(t TableView) *Selector {
|
||||
s.union = append(s.union, union{
|
||||
unionType: unionAll,
|
||||
TableView: t,
|
||||
})
|
||||
return s
|
||||
}
|
||||
|
||||
// UnionDistinct appends the UNION DISTINCT clause to the query.
|
||||
func (s *Selector) UnionDistinct(t TableView) *Selector {
|
||||
s.union = append(s.union, union{
|
||||
unionType: unionDistinct,
|
||||
TableView: t,
|
||||
})
|
||||
return s
|
||||
}
|
||||
|
||||
// C returns a formatted string for a selected column from this statement.
|
||||
func (s *Selector) C(column string) string {
|
||||
if s.as != "" {
|
||||
@@ -2104,11 +2145,35 @@ func (s *Selector) Query() (string, []interface{}) {
|
||||
b.WriteString(" OFFSET ")
|
||||
b.WriteString(strconv.Itoa(*s.offset))
|
||||
}
|
||||
if len(s.union) > 0 {
|
||||
s.joinUnion(&b)
|
||||
}
|
||||
s.total = b.total
|
||||
s.AddError(b.Err())
|
||||
return b.String(), b.args
|
||||
}
|
||||
|
||||
func (s *Selector) joinUnion(b *Builder) {
|
||||
for _, union := range s.union {
|
||||
b.WriteString(" UNION ")
|
||||
if union.unionType != "" {
|
||||
b.WriteString(string(union.unionType) + " ")
|
||||
}
|
||||
switch view := union.TableView.(type) {
|
||||
case *SelectTable:
|
||||
view.SetDialect(s.dialect)
|
||||
b.WriteString(view.ref())
|
||||
case *Selector:
|
||||
view.SetDialect(s.dialect)
|
||||
b.Join(view)
|
||||
if view.as != "" {
|
||||
b.WriteString(" AS ")
|
||||
b.Ident(view.as)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Selector) joinOrder(b *Builder) {
|
||||
b.WriteString(" ORDER BY ")
|
||||
for i := range s.order {
|
||||
@@ -2130,17 +2195,34 @@ func (*Selector) view() {}
|
||||
// WithBuilder is the builder for the `WITH` statement.
|
||||
type WithBuilder struct {
|
||||
Builder
|
||||
name string
|
||||
s *Selector
|
||||
recursive bool
|
||||
name string
|
||||
columns []string
|
||||
s *Selector
|
||||
}
|
||||
|
||||
// With returns a new builder for the `WITH` statement.
|
||||
//
|
||||
// n := Queries{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))}
|
||||
// n := Queries{
|
||||
// With("users_view").As(Select().From(Table("users"))),
|
||||
// Select().From(Table("users_view")),
|
||||
// }
|
||||
// return n.Query()
|
||||
//
|
||||
func With(name string) *WithBuilder {
|
||||
return &WithBuilder{name: name}
|
||||
func With(name string, columns ...string) *WithBuilder {
|
||||
return &WithBuilder{name: name, columns: columns}
|
||||
}
|
||||
|
||||
// WithRecursive returns a new builder for the `WITH RECURSIVE` statement.
|
||||
//
|
||||
// n := Queries{
|
||||
// WithRecursive("users_view").As(Select().From(Table("users"))),
|
||||
// Select().From(Table("users_view")),
|
||||
// }
|
||||
// return n.Query()
|
||||
//
|
||||
func WithRecursive(name string, columns ...string) *WithBuilder {
|
||||
return &WithBuilder{name: name, columns: columns, recursive: true}
|
||||
}
|
||||
|
||||
// Name returns the name of the view.
|
||||
@@ -2154,7 +2236,17 @@ func (w *WithBuilder) As(s *Selector) *WithBuilder {
|
||||
|
||||
// Query returns query representation of a `WITH` clause.
|
||||
func (w *WithBuilder) Query() (string, []interface{}) {
|
||||
w.WriteString(fmt.Sprintf("WITH %s AS ", w.name))
|
||||
w.WriteString("WITH ")
|
||||
if w.recursive {
|
||||
w.WriteString("RECURSIVE ")
|
||||
}
|
||||
w.Ident(w.name)
|
||||
if len(w.columns) > 0 {
|
||||
w.WriteByte('(')
|
||||
w.IdentComma(w.columns...)
|
||||
w.WriteByte(')')
|
||||
}
|
||||
w.WriteString(" AS ")
|
||||
w.Nested(func(b *Builder) {
|
||||
b.Join(w.s)
|
||||
})
|
||||
|
||||
@@ -978,7 +978,7 @@ func TestBuilder(t *testing.T) {
|
||||
Where(Not(And(EQ("name", "foo"), EQ("age", "bar"))))
|
||||
return Queries{With("users_view").As(s1), Select("name").From(Table("users_view"))}
|
||||
}(),
|
||||
wantQuery: "WITH users_view AS (SELECT * FROM `users` WHERE NOT (`name` = ? AND `age` = ?)) SELECT `name` FROM `users_view`",
|
||||
wantQuery: "WITH `users_view` AS (SELECT * FROM `users` WHERE NOT (`name` = ? AND `age` = ?)) SELECT `name` FROM `users_view`",
|
||||
wantArgs: []interface{}{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
@@ -989,7 +989,7 @@ func TestBuilder(t *testing.T) {
|
||||
Where(Not(And(EQ("name", "foo"), EQ("age", "bar"))))
|
||||
return Queries{d.With("users_view").As(s1), d.Select("name").From(Table("users_view"))}
|
||||
}(),
|
||||
wantQuery: `WITH users_view AS (SELECT * FROM "users" WHERE NOT ("name" = $1 AND "age" = $2)) SELECT "name" FROM "users_view"`,
|
||||
wantQuery: `WITH "users_view" AS (SELECT * FROM "users" WHERE NOT ("name" = $1 AND "age" = $2)) SELECT "name" FROM "users_view"`,
|
||||
wantArgs: []interface{}{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
@@ -1208,14 +1208,14 @@ func TestBuilder(t *testing.T) {
|
||||
},
|
||||
{
|
||||
input: Queries{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))},
|
||||
wantQuery: "WITH users_view AS (SELECT * FROM `users`) SELECT * FROM `users_view`",
|
||||
wantQuery: "WITH `users_view` AS (SELECT * FROM `users`) SELECT * FROM `users_view`",
|
||||
},
|
||||
{
|
||||
input: func() Querier {
|
||||
base := Select("*").From(Table("groups"))
|
||||
return Queries{With("groups").As(base.Clone().Where(EQ("name", "bar"))), base.Select("age")}
|
||||
}(),
|
||||
wantQuery: "WITH groups AS (SELECT * FROM `groups` WHERE `name` = ?) SELECT `age` FROM `groups`",
|
||||
wantQuery: "WITH `groups` AS (SELECT * FROM `groups` WHERE `name` = ?) SELECT `age` FROM `groups`",
|
||||
wantArgs: []interface{}{"bar"},
|
||||
},
|
||||
{
|
||||
@@ -1516,6 +1516,64 @@ func TestSelector_OrderByExpr(t *testing.T) {
|
||||
require.Equal(t, []interface{}{28, 1, 2}, args)
|
||||
}
|
||||
|
||||
func TestSelector_Union(t *testing.T) {
|
||||
query, args := Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(Table("users")).
|
||||
Where(EQ("active", true)).
|
||||
Union(
|
||||
Select("*").
|
||||
From(Table("old_users1")).
|
||||
Where(
|
||||
And(
|
||||
EQ("is_active", true),
|
||||
GT("age", 20),
|
||||
),
|
||||
),
|
||||
).
|
||||
UnionAll(
|
||||
Select("*").
|
||||
From(Table("old_users2")).
|
||||
Where(
|
||||
And(
|
||||
EQ("is_active", "true"),
|
||||
LT("age", 18),
|
||||
),
|
||||
),
|
||||
).
|
||||
Query()
|
||||
require.Equal(t, `SELECT * FROM "users" WHERE "active" = $1 UNION SELECT * FROM "old_users1" WHERE "is_active" = $2 AND "age" > $3 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $4 AND "age" < $5`, query)
|
||||
require.Equal(t, []interface{}{true, true, 20, "true", 18}, args)
|
||||
|
||||
t1, t2, t3 := Table("files"), Table("files"), Table("path")
|
||||
n := Queries{
|
||||
WithRecursive("path", "id", "name", "parent_id").
|
||||
As(Select(t1.Columns("id", "name", "parent_id")...).
|
||||
From(t1).
|
||||
Where(
|
||||
And(
|
||||
IsNull(t1.C("parent_id")),
|
||||
EQ(t1.C("deleted"), false),
|
||||
),
|
||||
).
|
||||
UnionAll(
|
||||
Select(t2.Columns("id", "name", "parent_id")...).
|
||||
From(t2).
|
||||
Join(t3).
|
||||
On(t2.C("parent_id"), t3.C("id")).
|
||||
Where(
|
||||
EQ(t2.C("deleted"), false),
|
||||
),
|
||||
),
|
||||
),
|
||||
Select(t3.Columns("id", "name", "parent_id")...).
|
||||
From(t3),
|
||||
}
|
||||
query, args = n.Query()
|
||||
require.Equal(t, "WITH RECURSIVE `path`(`id`, `name`, `parent_id`) AS (SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` WHERE `files`.`parent_id` IS NULL AND `files`.`deleted` = ? UNION ALL SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` JOIN `path` AS `t1` ON `files`.`parent_id` = `t1`.`id` WHERE `files`.`deleted` = ?) SELECT `t1`.`id`, `t1`.`name`, `t1`.`parent_id` FROM `path` AS `t1`", query)
|
||||
require.Equal(t, []interface{}{false, false}, args)
|
||||
}
|
||||
|
||||
func TestBuilderContext(t *testing.T) {
|
||||
type key string
|
||||
want := "myval"
|
||||
|
||||
Reference in New Issue
Block a user