diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index e3b6946ce..44f7b74b6 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1423,7 +1423,7 @@ func (*SelectTable) view() {} // join table option. type join struct { - on string + on *Predicate kind string table TableView } @@ -1433,6 +1433,7 @@ func (j join) clone() join { if sel, ok := j.table.(*Selector); ok { j.table = sel.Clone() } + j.on = j.on.clone() return j } @@ -1563,8 +1564,23 @@ func (s *Selector) Table() *SelectTable { // Join appends a `JOIN` clause to the statement. func (s *Selector) Join(t TableView) *Selector { + return s.join("JOIN", t) +} + +// LeftJoin appends a `LEFT JOIN` clause to the statement. +func (s *Selector) LeftJoin(t TableView) *Selector { + return s.join("LEFT JOIN", t) +} + +// RightJoin appends a `RIGHT JOIN` clause to the statement. +func (s *Selector) RightJoin(t TableView) *Selector { + return s.join("RIGHT JOIN", t) +} + +// join +func (s *Selector) join(kind string, t TableView) *Selector { s.joins = append(s.joins, join{ - kind: "JOIN", + kind: kind, table: t, }) switch view := t.(type) { @@ -1604,11 +1620,26 @@ func (s *Selector) Columns(columns ...string) []string { return names } +// OnP sets or appends the given predicate for the `ON` clause of the statement. +func (s *Selector) OnP(p *Predicate) *Selector { + if len(s.joins) > 0 { + join := &s.joins[len(s.joins)-1] + + switch { + case join.on == nil: + join.on = p + default: + join.on = And(join.on, p) + } + } + return s +} + // On sets the `ON` clause for the `JOIN` operation. func (s *Selector) On(c1, c2 string) *Selector { - if len(s.joins) > 0 { - s.joins[len(s.joins)-1].on = fmt.Sprintf("%s = %s", c1, c2) - } + s.OnP(P(func(builder *Builder) { + builder.Ident(c1).WriteOp(OpEQ).Ident(c2) + })) return s } @@ -1729,9 +1760,9 @@ func (s *Selector) Query() (string, []interface{}) { b.WriteString(" AS ") b.Ident(view.as) } - if join.on != "" { + if join.on != nil { b.WriteString(" ON ") - b.WriteString(join.on) + b.Join(join.on) } } if s.where != nil { @@ -2055,7 +2086,7 @@ func (b *Builder) JoinComma(qs ...Querier) *Builder { return b.join(qs, ", ") } -// join joins a list of Queries to the builder with a given separator. +// join adds a join table to the selector with the given kind. func (b *Builder) join(qs []Querier, sep string) *Builder { for i, q := range qs { if i > 0 { diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index e0c104192..b06ba48c5 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -660,6 +660,44 @@ func TestBuilder(t *testing.T) { wantQuery: `SELECT "u"."id", "g"."name" FROM "users" AS "u" JOIN "groups" AS "g" ON "u"."id" = "g"."user_id" WHERE "u"."name" = $1 AND "g"."name" IS NOT NULL`, wantArgs: []interface{}{"bar"}, }, + { + input: func() Querier { + t1 := Table("users").As("u") + t2 := Table("user_groups").As("ug") + return Select(t1.C("id"), As(Count("`*`"), "group_count")). + From(t1). + LeftJoin(t2). + On(t1.C("id"), t2.C("user_id")). + GroupBy(t1.C("id")) + }(), + wantQuery: "SELECT `u`.`id`, COUNT(`*`) AS `group_count` FROM `users` AS `u` LEFT JOIN `user_groups` AS `ug` ON `u`.`id` = `ug`.`user_id` GROUP BY `u`.`id`", + }, + { + input: func() Querier { + t1 := Table("users").As("u") + t2 := Table("user_groups").As("ug") + return Select(t1.C("id"), As(Count("`*`"), "group_count")). + From(t1). + LeftJoin(t2). + OnP(P(func(builder *Builder) { + builder.Ident(t1.C("id")).WriteOp(OpEQ).Ident(t2.C("user_id")) + })). + GroupBy(t1.C("id")).Clone() + }(), + wantQuery: "SELECT `u`.`id`, COUNT(`*`) AS `group_count` FROM `users` AS `u` LEFT JOIN `user_groups` AS `ug` ON `u`.`id` = `ug`.`user_id` GROUP BY `u`.`id`", + }, + { + input: func() Querier { + t1 := Table("groups").As("g") + t2 := Table("user_groups").As("ug") + return Select(t1.C("id"), As(Count("`*`"), "user_count")). + From(t1). + RightJoin(t2). + On(t1.C("id"), t2.C("group_id")). + GroupBy(t1.C("id")) + }(), + wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` RIGHT JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`", + }, { input: func() Querier { t1 := Table("users").As("u")