From 90289b7494699ac7350419f2b4d39c751578eb80 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Sun, 2 Apr 2023 13:01:36 +0300 Subject: [PATCH] dialect/sql: additional predicate helpers (#3429) --- dialect/sql/{predicate.go => sql.go} | 34 ++++++++- .../sql/{predicate_test.go => sql_test.go} | 72 +++++++++++++++++++ doc/md/predicates.md | 22 ++++++ 3 files changed, 125 insertions(+), 3 deletions(-) rename dialect/sql/{predicate.go => sql.go} (80%) rename dialect/sql/{predicate_test.go => sql_test.go} (82%) diff --git a/dialect/sql/predicate.go b/dialect/sql/sql.go similarity index 80% rename from dialect/sql/predicate.go rename to dialect/sql/sql.go index 161057f9b..f315ff926 100644 --- a/dialect/sql/predicate.go +++ b/dialect/sql/sql.go @@ -4,9 +4,9 @@ package sql -// This file provides extra helpers to simplify the way raw predicates -// are defined and used in both ent/schema and generated code. For full -// predicates, check out the sql.P in builder.go. +// The following helpers exist to simplify the way raw predicates +// are defined and used in both ent/schema and generated code. For +// full predicates API, check out the sql.P in builder.go. // FieldIsNull returns a raw predicate to check if the given field is NULL. func FieldIsNull(name string) func(*Selector) { @@ -57,6 +57,13 @@ func FieldGT(name string, v any) func(*Selector) { } } +// FieldsGT returns a raw predicate to check if field1 is greater than field2. +func FieldsGT(field1, field2 string) func(*Selector) { + return func(s *Selector) { + s.Where(ColumnsGT(s.C(field1), s.C(field2))) + } +} + // FieldGTE returns a raw predicate to check if the given field is greater than or equal the given value. func FieldGTE(name string, v any) func(*Selector) { return func(s *Selector) { @@ -64,6 +71,13 @@ func FieldGTE(name string, v any) func(*Selector) { } } +// FieldsGTE returns a raw predicate to check if field1 is greater than or equal field2. +func FieldsGTE(field1, field2 string) func(*Selector) { + return func(s *Selector) { + s.Where(ColumnsGTE(s.C(field1), s.C(field2))) + } +} + // FieldLT returns a raw predicate to check if the value of the field is less than the given value. func FieldLT(name string, v any) func(*Selector) { return func(s *Selector) { @@ -71,6 +85,13 @@ func FieldLT(name string, v any) func(*Selector) { } } +// FieldsLT returns a raw predicate to check if field1 is lower than field2. +func FieldsLT(field1, field2 string) func(*Selector) { + return func(s *Selector) { + s.Where(ColumnsLT(s.C(field1), s.C(field2))) + } +} + // FieldLTE returns a raw predicate to check if the value of the field is less than the given value. func FieldLTE(name string, v any) func(*Selector) { return func(s *Selector) { @@ -78,6 +99,13 @@ func FieldLTE(name string, v any) func(*Selector) { } } +// FieldsLTE returns a raw predicate to check if field1 is lower than or equal field2. +func FieldsLTE(field1, field2 string) func(*Selector) { + return func(s *Selector) { + s.Where(ColumnsLTE(s.C(field1), s.C(field2))) + } +} + // FieldIn returns a raw predicate to check if the value of the field is IN the given values. func FieldIn[T any](name string, vs ...T) func(*Selector) { return func(s *Selector) { diff --git a/dialect/sql/predicate_test.go b/dialect/sql/sql_test.go similarity index 82% rename from dialect/sql/predicate_test.go rename to dialect/sql/sql_test.go index 6496e7b90..73a5653ab 100644 --- a/dialect/sql/predicate_test.go +++ b/dialect/sql/sql_test.go @@ -138,6 +138,24 @@ func TestFieldGT(t *testing.T) { }) } +func TestFieldsGT(t *testing.T) { + p := FieldsGT("a", "b") + t.Run("MySQL", func(t *testing.T) { + s := Dialect(dialect.MySQL).Select("*").From(Table("users")) + p(s) + query, args := s.Query() + require.Equal(t, "SELECT * FROM `users` WHERE `users`.`a` > `users`.`b`", query) + require.Empty(t, args) + }) + t.Run("PostgreSQL", func(t *testing.T) { + s := Dialect(dialect.Postgres).Select("*").From(Table("users")) + p(s) + query, args := s.Query() + require.Equal(t, `SELECT * FROM "users" WHERE "users"."a" > "users"."b"`, query) + require.Empty(t, args) + }) +} + func TestFieldGTE(t *testing.T) { p := FieldGTE("stars", 1000) t.Run("MySQL", func(t *testing.T) { @@ -156,6 +174,24 @@ func TestFieldGTE(t *testing.T) { }) } +func TestFieldsGTE(t *testing.T) { + p := FieldsGTE("a", "b") + t.Run("MySQL", func(t *testing.T) { + s := Dialect(dialect.MySQL).Select("*").From(Table("users")) + p(s) + query, args := s.Query() + require.Equal(t, "SELECT * FROM `users` WHERE `users`.`a` >= `users`.`b`", query) + require.Empty(t, args) + }) + t.Run("PostgreSQL", func(t *testing.T) { + s := Dialect(dialect.Postgres).Select("*").From(Table("users")) + p(s) + query, args := s.Query() + require.Equal(t, `SELECT * FROM "users" WHERE "users"."a" >= "users"."b"`, query) + require.Empty(t, args) + }) +} + func TestFieldLT(t *testing.T) { p := FieldLT("stars", 1000) t.Run("MySQL", func(t *testing.T) { @@ -174,6 +210,24 @@ func TestFieldLT(t *testing.T) { }) } +func TestFieldsLT(t *testing.T) { + p := FieldsLT("a", "b") + t.Run("MySQL", func(t *testing.T) { + s := Dialect(dialect.MySQL).Select("*").From(Table("users")) + p(s) + query, args := s.Query() + require.Equal(t, "SELECT * FROM `users` WHERE `users`.`a` < `users`.`b`", query) + require.Empty(t, args) + }) + t.Run("PostgreSQL", func(t *testing.T) { + s := Dialect(dialect.Postgres).Select("*").From(Table("users")) + p(s) + query, args := s.Query() + require.Equal(t, `SELECT * FROM "users" WHERE "users"."a" < "users"."b"`, query) + require.Empty(t, args) + }) +} + func TestFieldLTE(t *testing.T) { p := FieldLTE("stars", 1000) t.Run("MySQL", func(t *testing.T) { @@ -192,6 +246,24 @@ func TestFieldLTE(t *testing.T) { }) } +func TestFieldsLTE(t *testing.T) { + p := FieldsLTE("a", "b") + t.Run("MySQL", func(t *testing.T) { + s := Dialect(dialect.MySQL).Select("*").From(Table("users")) + p(s) + query, args := s.Query() + require.Equal(t, "SELECT * FROM `users` WHERE `users`.`a` <= `users`.`b`", query) + require.Empty(t, args) + }) + t.Run("PostgreSQL", func(t *testing.T) { + s := Dialect(dialect.Postgres).Select("*").From(Table("users")) + p(s) + query, args := s.Query() + require.Equal(t, `SELECT * FROM "users" WHERE "users"."a" <= "users"."b"`, query) + require.Empty(t, args) + }) +} + func TestFieldIn(t *testing.T) { p := FieldIn("name", "a8m", "foo", "bar") t.Run("MySQL", func(t *testing.T) { diff --git a/doc/md/predicates.md b/doc/md/predicates.md index b3242271a..84f52c1c0 100644 --- a/doc/md/predicates.md +++ b/doc/md/predicates.md @@ -323,3 +323,25 @@ sqljson.ValueIn(user.FieldURL, []any{"https", "ftp"}, sqljson.Path("Scheme")) sqljson.ValueNotIn(user.FieldURL, []any{"github", "gitlab"}, sqljson.Path("Host")) ``` +## Comparing Fields + +The `dialect/sql` package provides a set of comparison functions that can be used to compare fields in a query. + +```go +client.Order.Query(). + Where( + sql.FieldsEQ(order.FieldTotal, order.FieldTax), + sql.FieldsNEQ(order.FieldTotal, order.FieldDiscount), + ). + All(ctx) + +client.Order.Query(). + Where( + order.Or( + sql.FieldsGT(order.FieldTotal, order.FieldTax), + sql.FieldsLT(order.FieldTotal, order.FieldDiscount), + ), + ). + All(ctx) +``` +