dialect/sql: add Selector.HasJoins helper (#3418)

This commit is contained in:
Ariel Mashraki
2023-03-28 14:31:43 +03:00
committed by GitHub
parent d9e7adfa05
commit 427aaf7d45
2 changed files with 23 additions and 1 deletions

View File

@@ -2025,6 +2025,9 @@ func Distinct(idents ...string) string {
// TableView is a view that returns a table view. Can be a Table, Selector or a View (WITH statement). // TableView is a view that returns a table view. Can be a Table, Selector or a View (WITH statement).
type TableView interface { type TableView interface {
view() view()
// C returns a formatted string prefixed
// with the table view qualifier.
C(string) string
} }
// queryView allows using Querier (expressions) in the FROM clause. // queryView allows using Querier (expressions) in the FROM clause.
@@ -2032,6 +2035,13 @@ type queryView struct{ Querier }
func (*queryView) view() {} func (*queryView) view() {}
func (q *queryView) C(column string) string {
if tv, ok := q.Querier.(TableView); ok {
return tv.C(column)
}
return column
}
// SelectTable is a table selector. // SelectTable is a table selector.
type SelectTable struct { type SelectTable struct {
Builder Builder
@@ -2407,6 +2417,11 @@ func (s *Selector) TableName() string {
} }
} }
// HasJoins reports if the selector has any JOINs.
func (s *Selector) HasJoins() bool {
return len(s.joins) > 0
}
// Join appends a `JOIN` clause to the statement. // Join appends a `JOIN` clause to the statement.
func (s *Selector) Join(t TableView) *Selector { func (s *Selector) Join(t TableView) *Selector {
return s.join("JOIN", t) return s.join("JOIN", t)
@@ -3090,7 +3105,7 @@ func RowNumber() *WindowBuilder {
} }
// Window returns a new window clause with a custom selector allowing // Window returns a new window clause with a custom selector allowing
// for custom windown functions. // for custom window functions.
// //
// Window(func(b *Builder) { // Window(func(b *Builder) {
// b.WriteString(Sum(posts.C("duration"))) // b.WriteString(Sum(posts.C("duration")))

View File

@@ -2371,3 +2371,10 @@ func TestFormattedColumnFromSubQuery(t *testing.T) {
}), "score").From(Table("table_name").As("table_name_alias"))) }), "score").From(Table("table_name").As("table_name_alias")))
require.Equal(t, "`table_name_alias`.`score`", q.C("score")) require.Equal(t, "`table_name_alias`.`score`", q.C("score"))
} }
func TestSelector_HasJoins(t *testing.T) {
s := Select("*").From(Table("t1"))
require.False(t, s.HasJoins())
s.Join(Table("t2"))
require.True(t, s.HasJoins())
}