dialect/sql: add Joined<T> helpers for Selector to avoid double joining (#3419)

This commit is contained in:
Ariel Mashraki
2023-03-28 15:33:21 +03:00
committed by GitHub
parent 427aaf7d45
commit 651a2a166e
2 changed files with 71 additions and 4 deletions

View File

@@ -2388,11 +2388,11 @@ func (s *Selector) Table() *SelectTable {
}
// selectTable returns a *SelectTable from the given TableView.
func selectTable(tb TableView) *SelectTable {
if tb == nil {
func selectTable(t TableView) *SelectTable {
if t == nil {
return nil
}
switch view := tb.(type) {
switch view := t.(type) {
case *SelectTable:
return view
case *Selector:
@@ -2400,8 +2400,10 @@ func selectTable(tb TableView) *SelectTable {
return nil
}
return selectTable(view.from[0])
case *queryView, *WithBuilder:
return nil
default:
panic(fmt.Sprintf("unhandled TableView type %T", tb))
panic(fmt.Sprintf("unexpected TableView %T", t))
}
}
@@ -2422,6 +2424,38 @@ func (s *Selector) HasJoins() bool {
return len(s.joins) > 0
}
// JoinedTable returns the first joined table with the given name.
func (s *Selector) JoinedTable(name string) (*SelectTable, bool) {
for _, j := range s.joins {
if t := selectTable(j.table); t != nil && t.name == name {
return t, true
}
}
return nil, false
}
// JoinedTableView returns the first joined TableView with the given name or alias.
func (s *Selector) JoinedTableView(name string) (TableView, bool) {
for _, j := range s.joins {
switch t := j.table.(type) {
case *SelectTable:
if t.name == name || t.as == name {
return t, true
}
case *Selector:
if t.as == name {
return t, true
}
for _, t2 := range t.from {
if t3 := selectTable(t2); t3 != nil && (t3.name == name || t3.as == name) {
return t3, true
}
}
}
}
return nil, false
}
// Join appends a `JOIN` clause to the statement.
func (s *Selector) Join(t TableView) *Selector {
return s.join("JOIN", t)

View File

@@ -2378,3 +2378,36 @@ func TestSelector_HasJoins(t *testing.T) {
s.Join(Table("t2"))
require.True(t, s.HasJoins())
}
func TestSelector_JoinedTable(t *testing.T) {
s := Select("*").From(Table("t1"))
t2, ok := s.JoinedTable("t2")
require.False(t, ok)
require.Nil(t, t2)
s.Join(Table("t2").As("t2"))
t2, ok = s.JoinedTable("t2")
require.True(t, ok)
require.Equal(t, "`t2`.`c`", t2.C("c"))
s.LeftJoin(Select().From(Table("t3").As("t3")).Where(EQ("id", 1)))
t3, ok := s.JoinedTable("t3")
require.True(t, ok)
require.Equal(t, "`t3`.`c`", t3.C("c"))
}
func TestSelector_JoinedTableView(t *testing.T) {
s := Select("*").From(Table("t1"))
t2, ok := s.JoinedTableView("t2")
require.False(t, ok)
require.Nil(t, t2)
s.Join(Table("users").As("t2"))
t2, ok = s.JoinedTableView("t2")
require.True(t, ok)
require.Equal(t, "`t2`.`c`", t2.C("c"))
s.LeftJoin(Select().From(Table("pets").As("t3")).Where(EQ("id", 1)).As("t4"))
t3, ok := s.JoinedTableView("t3")
require.True(t, ok)
require.Equal(t, "`t3`.`c`", t3.C("c"))
t4, ok := s.JoinedTableView("t4")
require.True(t, ok)
require.Equal(t, "`t4`.`c`", t4.C("c"))
}