From 68a6bd7fcdcbb3d556e98005902bf8ded6602c49 Mon Sep 17 00:00:00 2001 From: Ruben de Vries Date: Mon, 22 Mar 2021 10:42:12 +0100 Subject: [PATCH] dialect/sql: add helpers for basic predicates for comparing 2 columns. (#1358) --- dialect/sql/builder.go | 62 +++++++++++++++++++++++++++++++++++++ dialect/sql/builder_test.go | 19 ++++++++++++ 2 files changed, 81 insertions(+) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index bae5fbfc5..0354e1012 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -946,6 +946,14 @@ func (p *Predicate) Not() *Predicate { }) } +func (p *Predicate) columnsOp(col1, col2 string, op Op) *Predicate { + return p.Append(func(b *Builder) { + b.Ident(col1) + b.WriteOp(op) + b.Ident(col2) + }) +} + // And combines all given predicates with AND between them. func And(preds ...*Predicate) *Predicate { p := P() @@ -968,6 +976,15 @@ func (p *Predicate) EQ(col string, arg interface{}) *Predicate { }) } +func ColumnsEQ(col1 string, col2 string) *Predicate { + return P().ColumnsEQ(col1, col2) +} + +// ColumnsEQ appends a "=" predicate between 2 columns. +func (p *Predicate) ColumnsEQ(col1 string, col2 string) *Predicate { + return p.columnsOp(col1, col2, OpEQ) +} + // NEQ returns a "<>" predicate. func NEQ(col string, value interface{}) *Predicate { return P().NEQ(col, value) @@ -982,6 +999,15 @@ func (p *Predicate) NEQ(col string, arg interface{}) *Predicate { }) } +func ColumnsNEQ(col1 string, col2 string) *Predicate { + return P().ColumnsNEQ(col1, col2) +} + +// ColumnsNEQ appends a "<>" predicate between 2 columns. +func (p *Predicate) ColumnsNEQ(col1 string, col2 string) *Predicate { + return p.columnsOp(col1, col2, OpNEQ) +} + // LT returns a "<" predicate. func LT(col string, value interface{}) *Predicate { return P().LT(col, value) @@ -996,6 +1022,15 @@ func (p *Predicate) LT(col string, arg interface{}) *Predicate { }) } +func ColumnsLT(col1 string, col2 string) *Predicate { + return P().ColumnsLT(col1, col2) +} + +// ColumnsLT appends a "<" predicate between 2 columns. +func (p *Predicate) ColumnsLT(col1 string, col2 string) *Predicate { + return p.columnsOp(col1, col2, OpLT) +} + // LTE returns a "<=" predicate. func LTE(col string, value interface{}) *Predicate { return P().LTE(col, value) @@ -1010,6 +1045,15 @@ func (p *Predicate) LTE(col string, arg interface{}) *Predicate { }) } +func ColumnsLTE(col1 string, col2 string) *Predicate { + return P().ColumnsLTE(col1, col2) +} + +// ColumnsLTE appends a "<=" predicate between 2 columns. +func (p *Predicate) ColumnsLTE(col1 string, col2 string) *Predicate { + return p.columnsOp(col1, col2, OpLTE) +} + // GT returns a ">" predicate. func GT(col string, value interface{}) *Predicate { return P().GT(col, value) @@ -1024,6 +1068,15 @@ func (p *Predicate) GT(col string, arg interface{}) *Predicate { }) } +func ColumnsGT(col1 string, col2 string) *Predicate { + return P().ColumnsGT(col1, col2) +} + +// ColumnsGT appends a ">" predicate between 2 columns. +func (p *Predicate) ColumnsGT(col1 string, col2 string) *Predicate { + return p.columnsOp(col1, col2, OpGT) +} + // GTE returns a ">=" predicate. func GTE(col string, value interface{}) *Predicate { return P().GTE(col, value) @@ -1038,6 +1091,15 @@ func (p *Predicate) GTE(col string, arg interface{}) *Predicate { }) } +func ColumnsGTE(col1 string, col2 string) *Predicate { + return P().ColumnsGTE(col1, col2) +} + +// ColumnsGTE appends a ">=" predicate between 2 columns. +func (p *Predicate) ColumnsGTE(col1 string, col2 string) *Predicate { + return p.columnsOp(col1, col2, OpGTE) +} + // NotNull returns the `IS NOT NULL` predicate. func NotNull(col string) *Predicate { return P().NotNull(col) diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 7d562e8a8..597e12ae5 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1395,6 +1395,25 @@ WHERE wantQuery: `SELECT * FROM "users" WHERE ((name = $1 AND name = $2) AND "name" = $3) AND ("id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $4) AND "active" = $5)`, wantArgs: []interface{}{"pedro", "pedro", "pedro", "luna", true}, }, + { + input: func() Querier { + t1 := Table("users") + return Dialect(dialect.Postgres). + Select(). + From(t1). + Where(ColumnsEQ(t1.C("id1"), t1.C("id2"))). + Where(ColumnsNEQ(t1.C("id1"), t1.C("id2"))). + Where(ColumnsGT(t1.C("id1"), t1.C("id2"))). + Where(ColumnsGTE(t1.C("id1"), t1.C("id2"))). + Where(ColumnsLT(t1.C("id1"), t1.C("id2"))). + Where(ColumnsLTE(t1.C("id1"), t1.C("id2"))) + }(), + wantQuery: strings.ReplaceAll(` +SELECT * FROM "users" +WHERE (((("users"."id1" = "users"."id2" AND "users"."id1" <> "users"."id2") +AND "users"."id1" > "users"."id2") AND "users"."id1" >= "users"."id2") +AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", ""), + }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) {