dialect/sql: add left/right join support for selector builder

This commit is contained in:
maxilozoz
2020-10-02 01:03:45 +08:00
committed by GitHub
parent 83ac7bdbae
commit 21c2b3b467
2 changed files with 77 additions and 8 deletions

View File

@@ -1423,7 +1423,7 @@ func (*SelectTable) view() {}
// join table option.
type join struct {
on string
on *Predicate
kind string
table TableView
}
@@ -1433,6 +1433,7 @@ func (j join) clone() join {
if sel, ok := j.table.(*Selector); ok {
j.table = sel.Clone()
}
j.on = j.on.clone()
return j
}
@@ -1563,8 +1564,23 @@ func (s *Selector) Table() *SelectTable {
// Join appends a `JOIN` clause to the statement.
func (s *Selector) Join(t TableView) *Selector {
return s.join("JOIN", t)
}
// LeftJoin appends a `LEFT JOIN` clause to the statement.
func (s *Selector) LeftJoin(t TableView) *Selector {
return s.join("LEFT JOIN", t)
}
// RightJoin appends a `RIGHT JOIN` clause to the statement.
func (s *Selector) RightJoin(t TableView) *Selector {
return s.join("RIGHT JOIN", t)
}
// join
func (s *Selector) join(kind string, t TableView) *Selector {
s.joins = append(s.joins, join{
kind: "JOIN",
kind: kind,
table: t,
})
switch view := t.(type) {
@@ -1604,11 +1620,26 @@ func (s *Selector) Columns(columns ...string) []string {
return names
}
// OnP sets or appends the given predicate for the `ON` clause of the statement.
func (s *Selector) OnP(p *Predicate) *Selector {
if len(s.joins) > 0 {
join := &s.joins[len(s.joins)-1]
switch {
case join.on == nil:
join.on = p
default:
join.on = And(join.on, p)
}
}
return s
}
// On sets the `ON` clause for the `JOIN` operation.
func (s *Selector) On(c1, c2 string) *Selector {
if len(s.joins) > 0 {
s.joins[len(s.joins)-1].on = fmt.Sprintf("%s = %s", c1, c2)
}
s.OnP(P(func(builder *Builder) {
builder.Ident(c1).WriteOp(OpEQ).Ident(c2)
}))
return s
}
@@ -1729,9 +1760,9 @@ func (s *Selector) Query() (string, []interface{}) {
b.WriteString(" AS ")
b.Ident(view.as)
}
if join.on != "" {
if join.on != nil {
b.WriteString(" ON ")
b.WriteString(join.on)
b.Join(join.on)
}
}
if s.where != nil {
@@ -2055,7 +2086,7 @@ func (b *Builder) JoinComma(qs ...Querier) *Builder {
return b.join(qs, ", ")
}
// join joins a list of Queries to the builder with a given separator.
// join adds a join table to the selector with the given kind.
func (b *Builder) join(qs []Querier, sep string) *Builder {
for i, q := range qs {
if i > 0 {

View File

@@ -660,6 +660,44 @@ func TestBuilder(t *testing.T) {
wantQuery: `SELECT "u"."id", "g"."name" FROM "users" AS "u" JOIN "groups" AS "g" ON "u"."id" = "g"."user_id" WHERE "u"."name" = $1 AND "g"."name" IS NOT NULL`,
wantArgs: []interface{}{"bar"},
},
{
input: func() Querier {
t1 := Table("users").As("u")
t2 := Table("user_groups").As("ug")
return Select(t1.C("id"), As(Count("`*`"), "group_count")).
From(t1).
LeftJoin(t2).
On(t1.C("id"), t2.C("user_id")).
GroupBy(t1.C("id"))
}(),
wantQuery: "SELECT `u`.`id`, COUNT(`*`) AS `group_count` FROM `users` AS `u` LEFT JOIN `user_groups` AS `ug` ON `u`.`id` = `ug`.`user_id` GROUP BY `u`.`id`",
},
{
input: func() Querier {
t1 := Table("users").As("u")
t2 := Table("user_groups").As("ug")
return Select(t1.C("id"), As(Count("`*`"), "group_count")).
From(t1).
LeftJoin(t2).
OnP(P(func(builder *Builder) {
builder.Ident(t1.C("id")).WriteOp(OpEQ).Ident(t2.C("user_id"))
})).
GroupBy(t1.C("id")).Clone()
}(),
wantQuery: "SELECT `u`.`id`, COUNT(`*`) AS `group_count` FROM `users` AS `u` LEFT JOIN `user_groups` AS `ug` ON `u`.`id` = `ug`.`user_id` GROUP BY `u`.`id`",
},
{
input: func() Querier {
t1 := Table("groups").As("g")
t2 := Table("user_groups").As("ug")
return Select(t1.C("id"), As(Count("`*`"), "user_count")).
From(t1).
RightJoin(t2).
On(t1.C("id"), t2.C("group_id")).
GroupBy(t1.C("id"))
}(),
wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` RIGHT JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`",
},
{
input: func() Querier {
t1 := Table("users").As("u")