mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql: additional predicate helpers (#3429)
This commit is contained in:
@@ -4,9 +4,9 @@
|
||||
|
||||
package sql
|
||||
|
||||
// This file provides extra helpers to simplify the way raw predicates
|
||||
// are defined and used in both ent/schema and generated code. For full
|
||||
// predicates, check out the sql.P in builder.go.
|
||||
// The following helpers exist to simplify the way raw predicates
|
||||
// are defined and used in both ent/schema and generated code. For
|
||||
// full predicates API, check out the sql.P in builder.go.
|
||||
|
||||
// FieldIsNull returns a raw predicate to check if the given field is NULL.
|
||||
func FieldIsNull(name string) func(*Selector) {
|
||||
@@ -57,6 +57,13 @@ func FieldGT(name string, v any) func(*Selector) {
|
||||
}
|
||||
}
|
||||
|
||||
// FieldsGT returns a raw predicate to check if field1 is greater than field2.
|
||||
func FieldsGT(field1, field2 string) func(*Selector) {
|
||||
return func(s *Selector) {
|
||||
s.Where(ColumnsGT(s.C(field1), s.C(field2)))
|
||||
}
|
||||
}
|
||||
|
||||
// FieldGTE returns a raw predicate to check if the given field is greater than or equal the given value.
|
||||
func FieldGTE(name string, v any) func(*Selector) {
|
||||
return func(s *Selector) {
|
||||
@@ -64,6 +71,13 @@ func FieldGTE(name string, v any) func(*Selector) {
|
||||
}
|
||||
}
|
||||
|
||||
// FieldsGTE returns a raw predicate to check if field1 is greater than or equal field2.
|
||||
func FieldsGTE(field1, field2 string) func(*Selector) {
|
||||
return func(s *Selector) {
|
||||
s.Where(ColumnsGTE(s.C(field1), s.C(field2)))
|
||||
}
|
||||
}
|
||||
|
||||
// FieldLT returns a raw predicate to check if the value of the field is less than the given value.
|
||||
func FieldLT(name string, v any) func(*Selector) {
|
||||
return func(s *Selector) {
|
||||
@@ -71,6 +85,13 @@ func FieldLT(name string, v any) func(*Selector) {
|
||||
}
|
||||
}
|
||||
|
||||
// FieldsLT returns a raw predicate to check if field1 is lower than field2.
|
||||
func FieldsLT(field1, field2 string) func(*Selector) {
|
||||
return func(s *Selector) {
|
||||
s.Where(ColumnsLT(s.C(field1), s.C(field2)))
|
||||
}
|
||||
}
|
||||
|
||||
// FieldLTE returns a raw predicate to check if the value of the field is less than the given value.
|
||||
func FieldLTE(name string, v any) func(*Selector) {
|
||||
return func(s *Selector) {
|
||||
@@ -78,6 +99,13 @@ func FieldLTE(name string, v any) func(*Selector) {
|
||||
}
|
||||
}
|
||||
|
||||
// FieldsLTE returns a raw predicate to check if field1 is lower than or equal field2.
|
||||
func FieldsLTE(field1, field2 string) func(*Selector) {
|
||||
return func(s *Selector) {
|
||||
s.Where(ColumnsLTE(s.C(field1), s.C(field2)))
|
||||
}
|
||||
}
|
||||
|
||||
// FieldIn returns a raw predicate to check if the value of the field is IN the given values.
|
||||
func FieldIn[T any](name string, vs ...T) func(*Selector) {
|
||||
return func(s *Selector) {
|
||||
@@ -138,6 +138,24 @@ func TestFieldGT(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldsGT(t *testing.T) {
|
||||
p := FieldsGT("a", "b")
|
||||
t.Run("MySQL", func(t *testing.T) {
|
||||
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
|
||||
p(s)
|
||||
query, args := s.Query()
|
||||
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`a` > `users`.`b`", query)
|
||||
require.Empty(t, args)
|
||||
})
|
||||
t.Run("PostgreSQL", func(t *testing.T) {
|
||||
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
|
||||
p(s)
|
||||
query, args := s.Query()
|
||||
require.Equal(t, `SELECT * FROM "users" WHERE "users"."a" > "users"."b"`, query)
|
||||
require.Empty(t, args)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldGTE(t *testing.T) {
|
||||
p := FieldGTE("stars", 1000)
|
||||
t.Run("MySQL", func(t *testing.T) {
|
||||
@@ -156,6 +174,24 @@ func TestFieldGTE(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldsGTE(t *testing.T) {
|
||||
p := FieldsGTE("a", "b")
|
||||
t.Run("MySQL", func(t *testing.T) {
|
||||
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
|
||||
p(s)
|
||||
query, args := s.Query()
|
||||
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`a` >= `users`.`b`", query)
|
||||
require.Empty(t, args)
|
||||
})
|
||||
t.Run("PostgreSQL", func(t *testing.T) {
|
||||
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
|
||||
p(s)
|
||||
query, args := s.Query()
|
||||
require.Equal(t, `SELECT * FROM "users" WHERE "users"."a" >= "users"."b"`, query)
|
||||
require.Empty(t, args)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldLT(t *testing.T) {
|
||||
p := FieldLT("stars", 1000)
|
||||
t.Run("MySQL", func(t *testing.T) {
|
||||
@@ -174,6 +210,24 @@ func TestFieldLT(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldsLT(t *testing.T) {
|
||||
p := FieldsLT("a", "b")
|
||||
t.Run("MySQL", func(t *testing.T) {
|
||||
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
|
||||
p(s)
|
||||
query, args := s.Query()
|
||||
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`a` < `users`.`b`", query)
|
||||
require.Empty(t, args)
|
||||
})
|
||||
t.Run("PostgreSQL", func(t *testing.T) {
|
||||
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
|
||||
p(s)
|
||||
query, args := s.Query()
|
||||
require.Equal(t, `SELECT * FROM "users" WHERE "users"."a" < "users"."b"`, query)
|
||||
require.Empty(t, args)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldLTE(t *testing.T) {
|
||||
p := FieldLTE("stars", 1000)
|
||||
t.Run("MySQL", func(t *testing.T) {
|
||||
@@ -192,6 +246,24 @@ func TestFieldLTE(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldsLTE(t *testing.T) {
|
||||
p := FieldsLTE("a", "b")
|
||||
t.Run("MySQL", func(t *testing.T) {
|
||||
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
|
||||
p(s)
|
||||
query, args := s.Query()
|
||||
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`a` <= `users`.`b`", query)
|
||||
require.Empty(t, args)
|
||||
})
|
||||
t.Run("PostgreSQL", func(t *testing.T) {
|
||||
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
|
||||
p(s)
|
||||
query, args := s.Query()
|
||||
require.Equal(t, `SELECT * FROM "users" WHERE "users"."a" <= "users"."b"`, query)
|
||||
require.Empty(t, args)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldIn(t *testing.T) {
|
||||
p := FieldIn("name", "a8m", "foo", "bar")
|
||||
t.Run("MySQL", func(t *testing.T) {
|
||||
Reference in New Issue
Block a user