dialect/sql: add Joined<T> helpers for Selector to avoid double joining (#3419)

This commit is contained in:
Ariel Mashraki
2023-03-28 15:33:21 +03:00
committed by GitHub
parent 427aaf7d45
commit 651a2a166e
2 changed files with 71 additions and 4 deletions

View File

@@ -2388,11 +2388,11 @@ func (s *Selector) Table() *SelectTable {
}
// selectTable returns a *SelectTable from the given TableView.
func selectTable(tb TableView) *SelectTable {
if tb == nil {
func selectTable(t TableView) *SelectTable {
if t == nil {
return nil
}
switch view := tb.(type) {
switch view := t.(type) {
case *SelectTable:
return view
case *Selector:
@@ -2400,8 +2400,10 @@ func selectTable(tb TableView) *SelectTable {
return nil
}
return selectTable(view.from[0])
case *queryView, *WithBuilder:
return nil
default:
panic(fmt.Sprintf("unhandled TableView type %T", tb))
panic(fmt.Sprintf("unexpected TableView %T", t))
}
}
@@ -2422,6 +2424,38 @@ func (s *Selector) HasJoins() bool {
return len(s.joins) > 0
}
// JoinedTable returns the first joined table with the given name.
func (s *Selector) JoinedTable(name string) (*SelectTable, bool) {
for _, j := range s.joins {
if t := selectTable(j.table); t != nil && t.name == name {
return t, true
}
}
return nil, false
}
// JoinedTableView returns the first joined TableView with the given name or alias.
func (s *Selector) JoinedTableView(name string) (TableView, bool) {
for _, j := range s.joins {
switch t := j.table.(type) {
case *SelectTable:
if t.name == name || t.as == name {
return t, true
}
case *Selector:
if t.as == name {
return t, true
}
for _, t2 := range t.from {
if t3 := selectTable(t2); t3 != nil && (t3.name == name || t3.as == name) {
return t3, true
}
}
}
}
return nil, false
}
// Join appends a `JOIN` clause to the statement.
func (s *Selector) Join(t TableView) *Selector {
return s.join("JOIN", t)