dialect/sql: add helpers for basic predicates for comparing 2 columns. (#1358)

This commit is contained in:
Ruben de Vries
2021-03-22 10:42:12 +01:00
committed by GitHub
parent 6e3c3b6960
commit 68a6bd7fcd
2 changed files with 81 additions and 0 deletions

View File

@@ -946,6 +946,14 @@ func (p *Predicate) Not() *Predicate {
})
}
func (p *Predicate) columnsOp(col1, col2 string, op Op) *Predicate {
return p.Append(func(b *Builder) {
b.Ident(col1)
b.WriteOp(op)
b.Ident(col2)
})
}
// And combines all given predicates with AND between them.
func And(preds ...*Predicate) *Predicate {
p := P()
@@ -968,6 +976,15 @@ func (p *Predicate) EQ(col string, arg interface{}) *Predicate {
})
}
func ColumnsEQ(col1 string, col2 string) *Predicate {
return P().ColumnsEQ(col1, col2)
}
// ColumnsEQ appends a "=" predicate between 2 columns.
func (p *Predicate) ColumnsEQ(col1 string, col2 string) *Predicate {
return p.columnsOp(col1, col2, OpEQ)
}
// NEQ returns a "<>" predicate.
func NEQ(col string, value interface{}) *Predicate {
return P().NEQ(col, value)
@@ -982,6 +999,15 @@ func (p *Predicate) NEQ(col string, arg interface{}) *Predicate {
})
}
func ColumnsNEQ(col1 string, col2 string) *Predicate {
return P().ColumnsNEQ(col1, col2)
}
// ColumnsNEQ appends a "<>" predicate between 2 columns.
func (p *Predicate) ColumnsNEQ(col1 string, col2 string) *Predicate {
return p.columnsOp(col1, col2, OpNEQ)
}
// LT returns a "<" predicate.
func LT(col string, value interface{}) *Predicate {
return P().LT(col, value)
@@ -996,6 +1022,15 @@ func (p *Predicate) LT(col string, arg interface{}) *Predicate {
})
}
func ColumnsLT(col1 string, col2 string) *Predicate {
return P().ColumnsLT(col1, col2)
}
// ColumnsLT appends a "<" predicate between 2 columns.
func (p *Predicate) ColumnsLT(col1 string, col2 string) *Predicate {
return p.columnsOp(col1, col2, OpLT)
}
// LTE returns a "<=" predicate.
func LTE(col string, value interface{}) *Predicate {
return P().LTE(col, value)
@@ -1010,6 +1045,15 @@ func (p *Predicate) LTE(col string, arg interface{}) *Predicate {
})
}
func ColumnsLTE(col1 string, col2 string) *Predicate {
return P().ColumnsLTE(col1, col2)
}
// ColumnsLTE appends a "<=" predicate between 2 columns.
func (p *Predicate) ColumnsLTE(col1 string, col2 string) *Predicate {
return p.columnsOp(col1, col2, OpLTE)
}
// GT returns a ">" predicate.
func GT(col string, value interface{}) *Predicate {
return P().GT(col, value)
@@ -1024,6 +1068,15 @@ func (p *Predicate) GT(col string, arg interface{}) *Predicate {
})
}
func ColumnsGT(col1 string, col2 string) *Predicate {
return P().ColumnsGT(col1, col2)
}
// ColumnsGT appends a ">" predicate between 2 columns.
func (p *Predicate) ColumnsGT(col1 string, col2 string) *Predicate {
return p.columnsOp(col1, col2, OpGT)
}
// GTE returns a ">=" predicate.
func GTE(col string, value interface{}) *Predicate {
return P().GTE(col, value)
@@ -1038,6 +1091,15 @@ func (p *Predicate) GTE(col string, arg interface{}) *Predicate {
})
}
func ColumnsGTE(col1 string, col2 string) *Predicate {
return P().ColumnsGTE(col1, col2)
}
// ColumnsGTE appends a ">=" predicate between 2 columns.
func (p *Predicate) ColumnsGTE(col1 string, col2 string) *Predicate {
return p.columnsOp(col1, col2, OpGTE)
}
// NotNull returns the `IS NOT NULL` predicate.
func NotNull(col string) *Predicate {
return P().NotNull(col)

View File

@@ -1395,6 +1395,25 @@ WHERE
wantQuery: `SELECT * FROM "users" WHERE ((name = $1 AND name = $2) AND "name" = $3) AND ("id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $4) AND "active" = $5)`,
wantArgs: []interface{}{"pedro", "pedro", "pedro", "luna", true},
},
{
input: func() Querier {
t1 := Table("users")
return Dialect(dialect.Postgres).
Select().
From(t1).
Where(ColumnsEQ(t1.C("id1"), t1.C("id2"))).
Where(ColumnsNEQ(t1.C("id1"), t1.C("id2"))).
Where(ColumnsGT(t1.C("id1"), t1.C("id2"))).
Where(ColumnsGTE(t1.C("id1"), t1.C("id2"))).
Where(ColumnsLT(t1.C("id1"), t1.C("id2"))).
Where(ColumnsLTE(t1.C("id1"), t1.C("id2")))
}(),
wantQuery: strings.ReplaceAll(`
SELECT * FROM "users"
WHERE (((("users"."id1" = "users"."id2" AND "users"."id1" <> "users"."id2")
AND "users"."id1" > "users"."id2") AND "users"."id1" >= "users"."id2")
AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", ""),
},
}
for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {