dialect/sql: additional predicate helpers (#3429)

This commit is contained in:
Ariel Mashraki
2023-04-02 13:01:36 +03:00
committed by GitHub
parent 27bc0470eb
commit 90289b7494
3 changed files with 125 additions and 3 deletions

View File

@@ -4,9 +4,9 @@
package sql
// This file provides extra helpers to simplify the way raw predicates
// are defined and used in both ent/schema and generated code. For full
// predicates, check out the sql.P in builder.go.
// The following helpers exist to simplify the way raw predicates
// are defined and used in both ent/schema and generated code. For
// full predicates API, check out the sql.P in builder.go.
// FieldIsNull returns a raw predicate to check if the given field is NULL.
func FieldIsNull(name string) func(*Selector) {
@@ -57,6 +57,13 @@ func FieldGT(name string, v any) func(*Selector) {
}
}
// FieldsGT returns a raw predicate to check if field1 is greater than field2.
func FieldsGT(field1, field2 string) func(*Selector) {
return func(s *Selector) {
s.Where(ColumnsGT(s.C(field1), s.C(field2)))
}
}
// FieldGTE returns a raw predicate to check if the given field is greater than or equal the given value.
func FieldGTE(name string, v any) func(*Selector) {
return func(s *Selector) {
@@ -64,6 +71,13 @@ func FieldGTE(name string, v any) func(*Selector) {
}
}
// FieldsGTE returns a raw predicate to check if field1 is greater than or equal field2.
func FieldsGTE(field1, field2 string) func(*Selector) {
return func(s *Selector) {
s.Where(ColumnsGTE(s.C(field1), s.C(field2)))
}
}
// FieldLT returns a raw predicate to check if the value of the field is less than the given value.
func FieldLT(name string, v any) func(*Selector) {
return func(s *Selector) {
@@ -71,6 +85,13 @@ func FieldLT(name string, v any) func(*Selector) {
}
}
// FieldsLT returns a raw predicate to check if field1 is lower than field2.
func FieldsLT(field1, field2 string) func(*Selector) {
return func(s *Selector) {
s.Where(ColumnsLT(s.C(field1), s.C(field2)))
}
}
// FieldLTE returns a raw predicate to check if the value of the field is less than the given value.
func FieldLTE(name string, v any) func(*Selector) {
return func(s *Selector) {
@@ -78,6 +99,13 @@ func FieldLTE(name string, v any) func(*Selector) {
}
}
// FieldsLTE returns a raw predicate to check if field1 is lower than or equal field2.
func FieldsLTE(field1, field2 string) func(*Selector) {
return func(s *Selector) {
s.Where(ColumnsLTE(s.C(field1), s.C(field2)))
}
}
// FieldIn returns a raw predicate to check if the value of the field is IN the given values.
func FieldIn[T any](name string, vs ...T) func(*Selector) {
return func(s *Selector) {

View File

@@ -138,6 +138,24 @@ func TestFieldGT(t *testing.T) {
})
}
func TestFieldsGT(t *testing.T) {
p := FieldsGT("a", "b")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`a` > `users`.`b`", query)
require.Empty(t, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."a" > "users"."b"`, query)
require.Empty(t, args)
})
}
func TestFieldGTE(t *testing.T) {
p := FieldGTE("stars", 1000)
t.Run("MySQL", func(t *testing.T) {
@@ -156,6 +174,24 @@ func TestFieldGTE(t *testing.T) {
})
}
func TestFieldsGTE(t *testing.T) {
p := FieldsGTE("a", "b")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`a` >= `users`.`b`", query)
require.Empty(t, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."a" >= "users"."b"`, query)
require.Empty(t, args)
})
}
func TestFieldLT(t *testing.T) {
p := FieldLT("stars", 1000)
t.Run("MySQL", func(t *testing.T) {
@@ -174,6 +210,24 @@ func TestFieldLT(t *testing.T) {
})
}
func TestFieldsLT(t *testing.T) {
p := FieldsLT("a", "b")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`a` < `users`.`b`", query)
require.Empty(t, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."a" < "users"."b"`, query)
require.Empty(t, args)
})
}
func TestFieldLTE(t *testing.T) {
p := FieldLTE("stars", 1000)
t.Run("MySQL", func(t *testing.T) {
@@ -192,6 +246,24 @@ func TestFieldLTE(t *testing.T) {
})
}
func TestFieldsLTE(t *testing.T) {
p := FieldsLTE("a", "b")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`a` <= `users`.`b`", query)
require.Empty(t, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."a" <= "users"."b"`, query)
require.Empty(t, args)
})
}
func TestFieldIn(t *testing.T) {
p := FieldIn("name", "a8m", "foo", "bar")
t.Run("MySQL", func(t *testing.T) {