mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql: add left/right join support for selector builder
This commit is contained in:
@@ -1423,7 +1423,7 @@ func (*SelectTable) view() {}
|
||||
|
||||
// join table option.
|
||||
type join struct {
|
||||
on string
|
||||
on *Predicate
|
||||
kind string
|
||||
table TableView
|
||||
}
|
||||
@@ -1433,6 +1433,7 @@ func (j join) clone() join {
|
||||
if sel, ok := j.table.(*Selector); ok {
|
||||
j.table = sel.Clone()
|
||||
}
|
||||
j.on = j.on.clone()
|
||||
return j
|
||||
}
|
||||
|
||||
@@ -1563,8 +1564,23 @@ func (s *Selector) Table() *SelectTable {
|
||||
|
||||
// Join appends a `JOIN` clause to the statement.
|
||||
func (s *Selector) Join(t TableView) *Selector {
|
||||
return s.join("JOIN", t)
|
||||
}
|
||||
|
||||
// LeftJoin appends a `LEFT JOIN` clause to the statement.
|
||||
func (s *Selector) LeftJoin(t TableView) *Selector {
|
||||
return s.join("LEFT JOIN", t)
|
||||
}
|
||||
|
||||
// RightJoin appends a `RIGHT JOIN` clause to the statement.
|
||||
func (s *Selector) RightJoin(t TableView) *Selector {
|
||||
return s.join("RIGHT JOIN", t)
|
||||
}
|
||||
|
||||
// join
|
||||
func (s *Selector) join(kind string, t TableView) *Selector {
|
||||
s.joins = append(s.joins, join{
|
||||
kind: "JOIN",
|
||||
kind: kind,
|
||||
table: t,
|
||||
})
|
||||
switch view := t.(type) {
|
||||
@@ -1604,11 +1620,26 @@ func (s *Selector) Columns(columns ...string) []string {
|
||||
return names
|
||||
}
|
||||
|
||||
// OnP sets or appends the given predicate for the `ON` clause of the statement.
|
||||
func (s *Selector) OnP(p *Predicate) *Selector {
|
||||
if len(s.joins) > 0 {
|
||||
join := &s.joins[len(s.joins)-1]
|
||||
|
||||
switch {
|
||||
case join.on == nil:
|
||||
join.on = p
|
||||
default:
|
||||
join.on = And(join.on, p)
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// On sets the `ON` clause for the `JOIN` operation.
|
||||
func (s *Selector) On(c1, c2 string) *Selector {
|
||||
if len(s.joins) > 0 {
|
||||
s.joins[len(s.joins)-1].on = fmt.Sprintf("%s = %s", c1, c2)
|
||||
}
|
||||
s.OnP(P(func(builder *Builder) {
|
||||
builder.Ident(c1).WriteOp(OpEQ).Ident(c2)
|
||||
}))
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -1729,9 +1760,9 @@ func (s *Selector) Query() (string, []interface{}) {
|
||||
b.WriteString(" AS ")
|
||||
b.Ident(view.as)
|
||||
}
|
||||
if join.on != "" {
|
||||
if join.on != nil {
|
||||
b.WriteString(" ON ")
|
||||
b.WriteString(join.on)
|
||||
b.Join(join.on)
|
||||
}
|
||||
}
|
||||
if s.where != nil {
|
||||
@@ -2055,7 +2086,7 @@ func (b *Builder) JoinComma(qs ...Querier) *Builder {
|
||||
return b.join(qs, ", ")
|
||||
}
|
||||
|
||||
// join joins a list of Queries to the builder with a given separator.
|
||||
// join adds a join table to the selector with the given kind.
|
||||
func (b *Builder) join(qs []Querier, sep string) *Builder {
|
||||
for i, q := range qs {
|
||||
if i > 0 {
|
||||
|
||||
@@ -660,6 +660,44 @@ func TestBuilder(t *testing.T) {
|
||||
wantQuery: `SELECT "u"."id", "g"."name" FROM "users" AS "u" JOIN "groups" AS "g" ON "u"."id" = "g"."user_id" WHERE "u"."name" = $1 AND "g"."name" IS NOT NULL`,
|
||||
wantArgs: []interface{}{"bar"},
|
||||
},
|
||||
{
|
||||
input: func() Querier {
|
||||
t1 := Table("users").As("u")
|
||||
t2 := Table("user_groups").As("ug")
|
||||
return Select(t1.C("id"), As(Count("`*`"), "group_count")).
|
||||
From(t1).
|
||||
LeftJoin(t2).
|
||||
On(t1.C("id"), t2.C("user_id")).
|
||||
GroupBy(t1.C("id"))
|
||||
}(),
|
||||
wantQuery: "SELECT `u`.`id`, COUNT(`*`) AS `group_count` FROM `users` AS `u` LEFT JOIN `user_groups` AS `ug` ON `u`.`id` = `ug`.`user_id` GROUP BY `u`.`id`",
|
||||
},
|
||||
{
|
||||
input: func() Querier {
|
||||
t1 := Table("users").As("u")
|
||||
t2 := Table("user_groups").As("ug")
|
||||
return Select(t1.C("id"), As(Count("`*`"), "group_count")).
|
||||
From(t1).
|
||||
LeftJoin(t2).
|
||||
OnP(P(func(builder *Builder) {
|
||||
builder.Ident(t1.C("id")).WriteOp(OpEQ).Ident(t2.C("user_id"))
|
||||
})).
|
||||
GroupBy(t1.C("id")).Clone()
|
||||
}(),
|
||||
wantQuery: "SELECT `u`.`id`, COUNT(`*`) AS `group_count` FROM `users` AS `u` LEFT JOIN `user_groups` AS `ug` ON `u`.`id` = `ug`.`user_id` GROUP BY `u`.`id`",
|
||||
},
|
||||
{
|
||||
input: func() Querier {
|
||||
t1 := Table("groups").As("g")
|
||||
t2 := Table("user_groups").As("ug")
|
||||
return Select(t1.C("id"), As(Count("`*`"), "user_count")).
|
||||
From(t1).
|
||||
RightJoin(t2).
|
||||
On(t1.C("id"), t2.C("group_id")).
|
||||
GroupBy(t1.C("id"))
|
||||
}(),
|
||||
wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` RIGHT JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`",
|
||||
},
|
||||
{
|
||||
input: func() Querier {
|
||||
t1 := Table("users").As("u")
|
||||
|
||||
Reference in New Issue
Block a user