From 651a2a166e9d818259b0f226e6da8955de9941c0 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Tue, 28 Mar 2023 15:33:21 +0300 Subject: [PATCH] dialect/sql: add Joined helpers for Selector to avoid double joining (#3419) --- dialect/sql/builder.go | 42 +++++++++++++++++++++++++++++++++---- dialect/sql/builder_test.go | 33 +++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 542daf15b..7b150d762 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -2388,11 +2388,11 @@ func (s *Selector) Table() *SelectTable { } // selectTable returns a *SelectTable from the given TableView. -func selectTable(tb TableView) *SelectTable { - if tb == nil { +func selectTable(t TableView) *SelectTable { + if t == nil { return nil } - switch view := tb.(type) { + switch view := t.(type) { case *SelectTable: return view case *Selector: @@ -2400,8 +2400,10 @@ func selectTable(tb TableView) *SelectTable { return nil } return selectTable(view.from[0]) + case *queryView, *WithBuilder: + return nil default: - panic(fmt.Sprintf("unhandled TableView type %T", tb)) + panic(fmt.Sprintf("unexpected TableView %T", t)) } } @@ -2422,6 +2424,38 @@ func (s *Selector) HasJoins() bool { return len(s.joins) > 0 } +// JoinedTable returns the first joined table with the given name. +func (s *Selector) JoinedTable(name string) (*SelectTable, bool) { + for _, j := range s.joins { + if t := selectTable(j.table); t != nil && t.name == name { + return t, true + } + } + return nil, false +} + +// JoinedTableView returns the first joined TableView with the given name or alias. +func (s *Selector) JoinedTableView(name string) (TableView, bool) { + for _, j := range s.joins { + switch t := j.table.(type) { + case *SelectTable: + if t.name == name || t.as == name { + return t, true + } + case *Selector: + if t.as == name { + return t, true + } + for _, t2 := range t.from { + if t3 := selectTable(t2); t3 != nil && (t3.name == name || t3.as == name) { + return t3, true + } + } + } + } + return nil, false +} + // Join appends a `JOIN` clause to the statement. func (s *Selector) Join(t TableView) *Selector { return s.join("JOIN", t) diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index e6c518941..8ec020536 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -2378,3 +2378,36 @@ func TestSelector_HasJoins(t *testing.T) { s.Join(Table("t2")) require.True(t, s.HasJoins()) } + +func TestSelector_JoinedTable(t *testing.T) { + s := Select("*").From(Table("t1")) + t2, ok := s.JoinedTable("t2") + require.False(t, ok) + require.Nil(t, t2) + s.Join(Table("t2").As("t2")) + t2, ok = s.JoinedTable("t2") + require.True(t, ok) + require.Equal(t, "`t2`.`c`", t2.C("c")) + s.LeftJoin(Select().From(Table("t3").As("t3")).Where(EQ("id", 1))) + t3, ok := s.JoinedTable("t3") + require.True(t, ok) + require.Equal(t, "`t3`.`c`", t3.C("c")) +} + +func TestSelector_JoinedTableView(t *testing.T) { + s := Select("*").From(Table("t1")) + t2, ok := s.JoinedTableView("t2") + require.False(t, ok) + require.Nil(t, t2) + s.Join(Table("users").As("t2")) + t2, ok = s.JoinedTableView("t2") + require.True(t, ok) + require.Equal(t, "`t2`.`c`", t2.C("c")) + s.LeftJoin(Select().From(Table("pets").As("t3")).Where(EQ("id", 1)).As("t4")) + t3, ok := s.JoinedTableView("t3") + require.True(t, ok) + require.Equal(t, "`t3`.`c`", t3.C("c")) + t4, ok := s.JoinedTableView("t4") + require.True(t, ok) + require.Equal(t, "`t4`.`c`", t4.C("c")) +}