From 9e610c7b544faebdc0a66b3b21e2efd90ae515b6 Mon Sep 17 00:00:00 2001 From: Ruben de Vries Date: Sun, 7 Feb 2021 17:07:48 +0100 Subject: [PATCH] dialect/sql: fix Builder.Join implementation for postgres (#1212) Don't use .Total() in builder.Join and purely rely on len(args). --- dialect/sql/builder.go | 5 +---- dialect/sql/builder_test.go | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 0b2644ff2..173fb076d 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -2244,10 +2244,7 @@ func (b *Builder) join(qs []Querier, sep string) *Builder { query, args := q.Query() b.WriteString(query) b.args = append(b.args, args...) - b.total = len(b.args) - if ok { - b.total = st.Total() - } + b.total += len(args) } return b } diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 02d052c05..7d8f15e36 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1372,6 +1372,24 @@ WHERE wantQuery: "SELECT * FROM `users` JOIN `pets` AS `t0` ON `users`.`id` = `t0`.`owner_id` WHERE `t0`.`name` = ?", wantArgs: []interface{}{"pedro"}, }, + { + input: func() Querier { + t1 := Table("users") + sel := Select("*"). + From(t1). + Where(P(func(b *Builder) { + b.Join(Expr("name = $1", "pedro")) + })). + Where(P(func(b *Builder) { + b.Join(Expr("name = $2", "pedro")) + })). + Where(EQ("name", "pedro")) + sel.SetDialect(dialect.Postgres) + return sel + }(), + wantQuery: `SELECT * FROM "users" WHERE (name = $1 AND name = $2) AND "name" = $3`, + wantArgs: []interface{}{"pedro", "pedro", "pedro"}, + }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { @@ -1382,6 +1400,25 @@ WHERE } } +func TestAnd(t *testing.T) { + assert := require.New(t) + + p1 := P(func(b *Builder) { + b.Join(Expr("name = $1", "pedro")) + }) + p2 := P(func(b *Builder) { + b.Join(Expr("name = $2", "pedro")) + }) + + and := And(p1, p2) + + _, _ = and.Query() + + assert.Equal(1, p1.Total()) + assert.Equal(2, p2.Total()) + assert.Equal(2, and.Total()) +} + func TestBuilder_Err(t *testing.T) { b := Select("i-") require.NoError(t, b.Err())