mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql: add Joined<T> helpers for Selector to avoid double joining (#3419)
This commit is contained in:
@@ -2388,11 +2388,11 @@ func (s *Selector) Table() *SelectTable {
|
||||
}
|
||||
|
||||
// selectTable returns a *SelectTable from the given TableView.
|
||||
func selectTable(tb TableView) *SelectTable {
|
||||
if tb == nil {
|
||||
func selectTable(t TableView) *SelectTable {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
switch view := tb.(type) {
|
||||
switch view := t.(type) {
|
||||
case *SelectTable:
|
||||
return view
|
||||
case *Selector:
|
||||
@@ -2400,8 +2400,10 @@ func selectTable(tb TableView) *SelectTable {
|
||||
return nil
|
||||
}
|
||||
return selectTable(view.from[0])
|
||||
case *queryView, *WithBuilder:
|
||||
return nil
|
||||
default:
|
||||
panic(fmt.Sprintf("unhandled TableView type %T", tb))
|
||||
panic(fmt.Sprintf("unexpected TableView %T", t))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2422,6 +2424,38 @@ func (s *Selector) HasJoins() bool {
|
||||
return len(s.joins) > 0
|
||||
}
|
||||
|
||||
// JoinedTable returns the first joined table with the given name.
|
||||
func (s *Selector) JoinedTable(name string) (*SelectTable, bool) {
|
||||
for _, j := range s.joins {
|
||||
if t := selectTable(j.table); t != nil && t.name == name {
|
||||
return t, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// JoinedTableView returns the first joined TableView with the given name or alias.
|
||||
func (s *Selector) JoinedTableView(name string) (TableView, bool) {
|
||||
for _, j := range s.joins {
|
||||
switch t := j.table.(type) {
|
||||
case *SelectTable:
|
||||
if t.name == name || t.as == name {
|
||||
return t, true
|
||||
}
|
||||
case *Selector:
|
||||
if t.as == name {
|
||||
return t, true
|
||||
}
|
||||
for _, t2 := range t.from {
|
||||
if t3 := selectTable(t2); t3 != nil && (t3.name == name || t3.as == name) {
|
||||
return t3, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Join appends a `JOIN` clause to the statement.
|
||||
func (s *Selector) Join(t TableView) *Selector {
|
||||
return s.join("JOIN", t)
|
||||
|
||||
@@ -2378,3 +2378,36 @@ func TestSelector_HasJoins(t *testing.T) {
|
||||
s.Join(Table("t2"))
|
||||
require.True(t, s.HasJoins())
|
||||
}
|
||||
|
||||
func TestSelector_JoinedTable(t *testing.T) {
|
||||
s := Select("*").From(Table("t1"))
|
||||
t2, ok := s.JoinedTable("t2")
|
||||
require.False(t, ok)
|
||||
require.Nil(t, t2)
|
||||
s.Join(Table("t2").As("t2"))
|
||||
t2, ok = s.JoinedTable("t2")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "`t2`.`c`", t2.C("c"))
|
||||
s.LeftJoin(Select().From(Table("t3").As("t3")).Where(EQ("id", 1)))
|
||||
t3, ok := s.JoinedTable("t3")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "`t3`.`c`", t3.C("c"))
|
||||
}
|
||||
|
||||
func TestSelector_JoinedTableView(t *testing.T) {
|
||||
s := Select("*").From(Table("t1"))
|
||||
t2, ok := s.JoinedTableView("t2")
|
||||
require.False(t, ok)
|
||||
require.Nil(t, t2)
|
||||
s.Join(Table("users").As("t2"))
|
||||
t2, ok = s.JoinedTableView("t2")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "`t2`.`c`", t2.C("c"))
|
||||
s.LeftJoin(Select().From(Table("pets").As("t3")).Where(EQ("id", 1)).As("t4"))
|
||||
t3, ok := s.JoinedTableView("t3")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "`t3`.`c`", t3.C("c"))
|
||||
t4, ok := s.JoinedTableView("t4")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "`t4`.`c`", t4.C("c"))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user