ent: export query interceptors (#3157)

This commit is contained in:
Ariel Mashraki
2022-12-19 10:17:10 +02:00
committed by GitHub
parent 3328201ba8
commit f226627d67
493 changed files with 22829 additions and 10766 deletions

View File

@@ -1762,7 +1762,7 @@ func (p *Predicate) Contains(col, substr string) *Predicate {
return p.escapedLike(col, "%", "%", substr)
}
// ContainsFold is a helper predicate that checks substring using the LIKE predicate.
// ContainsFold is a helper predicate that checks substring using the LIKE predicate with case-folding.
func ContainsFold(col, sub string) *Predicate { return P().ContainsFold(col, sub) }
// ContainsFold is a helper predicate that applies the LIKE predicate with case-folding.

136
dialect/sql/predicate.go Normal file
View File

@@ -0,0 +1,136 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package sql
// This file provides extra helpers to simplify the way raw predicates
// are defined and used in both ent/schema and generated code. For full
// predicates, check out the sql.P in builder.go.
// FieldIsNull returns a raw predicate to check if the given field is NULL.
func FieldIsNull(name string) func(*Selector) {
return func(s *Selector) {
s.Where(IsNull(s.C(name)))
}
}
// FieldNotNull returns a raw predicate to check if the given field is not NULL.
func FieldNotNull(name string) func(*Selector) {
return func(s *Selector) {
s.Where(NotNull(s.C(name)))
}
}
// FieldEQ returns a raw predicate to check if the given field equals to the given value.
func FieldEQ(name string, v any) func(*Selector) {
return func(s *Selector) {
s.Where(EQ(s.C(name), v))
}
}
// FieldsEQ returns a raw predicate to check if the given fields (columns) are equal.
func FieldsEQ(field1, field2 string) func(*Selector) {
return func(s *Selector) {
s.Where(ColumnsEQ(s.C(field1), s.C(field2)))
}
}
// FieldNEQ returns a raw predicate to check if the given field does not equal to the given value.
func FieldNEQ(name string, v any) func(*Selector) {
return func(s *Selector) {
s.Where(NEQ(s.C(name), v))
}
}
// FieldsNEQ returns a raw predicate to check if the given fields (columns) are not equal.
func FieldsNEQ(field1, field2 string) func(*Selector) {
return func(s *Selector) {
s.Where(ColumnsNEQ(s.C(field1), s.C(field2)))
}
}
// FieldGT returns a raw predicate to check if the given field is greater than the given value.
func FieldGT(name string, v any) func(*Selector) {
return func(s *Selector) {
s.Where(GT(s.C(name), v))
}
}
// FieldGTE returns a raw predicate to check if the given field is greater than or equal the given value.
func FieldGTE(name string, v any) func(*Selector) {
return func(s *Selector) {
s.Where(GTE(s.C(name), v))
}
}
// FieldLT returns a raw predicate to check if the value of the field is less than the given value.
func FieldLT(name string, v any) func(*Selector) {
return func(s *Selector) {
s.Where(LT(s.C(name), v))
}
}
// FieldLTE returns a raw predicate to check if the value of the field is less than the given value.
func FieldLTE(name string, v any) func(*Selector) {
return func(s *Selector) {
s.Where(LTE(s.C(name), v))
}
}
// FieldIn returns a raw predicate to check if the value of the field is IN the given values.
func FieldIn[T any](name string, vs ...T) func(*Selector) {
return func(s *Selector) {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
s.Where(In(s.C(name), v...))
}
}
// FieldNotIn returns a raw predicate to check if the value of the field is NOT IN the given values.
func FieldNotIn[T any](name string, vs ...T) func(*Selector) {
return func(s *Selector) {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
s.Where(NotIn(s.C(name), v...))
}
}
// FieldEqualFold returns a raw predicate to check if the field has the given prefix with case-folding.
func FieldEqualFold(name string, substr string) func(*Selector) {
return func(s *Selector) {
s.Where(EqualFold(s.C(name), substr))
}
}
// FieldHasPrefix returns a raw predicate to check if the field has the given prefix.
func FieldHasPrefix(name string, prefix string) func(*Selector) {
return func(s *Selector) {
s.Where(HasPrefix(s.C(name), prefix))
}
}
// FieldHasSuffix returns a raw predicate to check if the field has the given suffix.
func FieldHasSuffix(name string, suffix string) func(*Selector) {
return func(s *Selector) {
s.Where(HasSuffix(s.C(name), suffix))
}
}
// FieldContains returns a raw predicate to check if the field contains the given substring.
func FieldContains(name string, substr string) func(*Selector) {
return func(s *Selector) {
s.Where(Contains(s.C(name), substr))
}
}
// FieldContainsFold returns a raw predicate to check if the field contains the given substring with case-folding.
func FieldContainsFold(name string, substr string) func(*Selector) {
return func(s *Selector) {
s.Where(ContainsFold(s.C(name), substr))
}
}

View File

@@ -0,0 +1,319 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package sql
import (
"testing"
"entgo.io/ent/dialect"
"github.com/stretchr/testify/require"
)
func TestFieldIsNull(t *testing.T) {
p := FieldIsNull("name")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`name` IS NULL", query)
require.Empty(t, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."name" IS NULL`, query)
require.Empty(t, args)
})
}
func TestFieldNotNull(t *testing.T) {
p := FieldNotNull("name")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`name` IS NOT NULL", query)
require.Empty(t, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."name" IS NOT NULL`, query)
require.Empty(t, args)
})
}
func TestFieldEQ(t *testing.T) {
p := FieldEQ("name", "a8m")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`name` = ?", query)
require.Equal(t, []any{"a8m"}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."name" = $1`, query)
require.Equal(t, []any{"a8m"}, args)
})
}
func TestFieldsEQ(t *testing.T) {
p := FieldsEQ("create_time", "update_time")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`create_time` = `users`.`update_time`", query)
require.Empty(t, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."create_time" = "users"."update_time"`, query)
require.Empty(t, args)
})
}
func TestFieldsNEQ(t *testing.T) {
p := FieldsNEQ("create_time", "update_time")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`create_time` <> `users`.`update_time`", query)
require.Empty(t, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."create_time" <> "users"."update_time"`, query)
require.Empty(t, args)
})
}
func TestFieldNEQ(t *testing.T) {
p := FieldNEQ("name", "a8m")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`name` <> ?", query)
require.Equal(t, []any{"a8m"}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."name" <> $1`, query)
require.Equal(t, []any{"a8m"}, args)
})
}
func TestFieldGT(t *testing.T) {
p := FieldGT("stars", 1000)
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`stars` > ?", query)
require.Equal(t, []any{1000}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."stars" > $1`, query)
require.Equal(t, []any{1000}, args)
})
}
func TestFieldGTE(t *testing.T) {
p := FieldGTE("stars", 1000)
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`stars` >= ?", query)
require.Equal(t, []any{1000}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."stars" >= $1`, query)
require.Equal(t, []any{1000}, args)
})
}
func TestFieldLT(t *testing.T) {
p := FieldLT("stars", 1000)
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`stars` < ?", query)
require.Equal(t, []any{1000}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."stars" < $1`, query)
require.Equal(t, []any{1000}, args)
})
}
func TestFieldLTE(t *testing.T) {
p := FieldLTE("stars", 1000)
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`stars` <= ?", query)
require.Equal(t, []any{1000}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."stars" <= $1`, query)
require.Equal(t, []any{1000}, args)
})
}
func TestFieldIn(t *testing.T) {
p := FieldIn("name", "a8m", "foo", "bar")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`name` IN (?, ?, ?)", query)
require.Equal(t, []any{"a8m", "foo", "bar"}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."name" IN ($1, $2, $3)`, query)
require.Equal(t, []any{"a8m", "foo", "bar"}, args)
})
}
func TestFieldNotIn(t *testing.T) {
p := FieldNotIn("id", 1, 2, 3)
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`id` NOT IN (?, ?, ?)", query)
require.Equal(t, []any{1, 2, 3}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."id" NOT IN ($1, $2, $3)`, query)
require.Equal(t, []any{1, 2, 3}, args)
})
}
func TestFieldEqualFold(t *testing.T) {
p := FieldEqualFold("name", "a8m")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`name` COLLATE utf8mb4_general_ci = ?", query)
require.Equal(t, []any{"a8m"}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."name" ILIKE $1`, query)
require.Equal(t, []any{"a8m"}, args)
})
}
func TestFieldHasPrefix(t *testing.T) {
p := FieldHasPrefix("name", "a8m")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`name` LIKE ?", query)
require.Equal(t, []any{"a8m%"}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."name" LIKE $1`, query)
require.Equal(t, []any{"a8m%"}, args)
})
}
func TestFieldHasSuffix(t *testing.T) {
p := FieldHasSuffix("name", "a8m")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`name` LIKE ?", query)
require.Equal(t, []any{"%a8m"}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."name" LIKE $1`, query)
require.Equal(t, []any{"%a8m"}, args)
})
}
func TestFieldContains(t *testing.T) {
p := FieldContains("name", "a8m")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`name` LIKE ?", query)
require.Equal(t, []any{"%a8m%"}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."name" LIKE $1`, query)
require.Equal(t, []any{"%a8m%"}, args)
})
}
func TestFieldContainsFold(t *testing.T) {
p := FieldContainsFold("name", "a8m")
t.Run("MySQL", func(t *testing.T) {
s := Dialect(dialect.MySQL).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, "SELECT * FROM `users` WHERE `users`.`name` COLLATE utf8mb4_general_ci LIKE ?", query)
require.Equal(t, []any{"%a8m%"}, args)
})
t.Run("PostgreSQL", func(t *testing.T) {
s := Dialect(dialect.Postgres).Select("*").From(Table("users"))
p(s)
query, args := s.Query()
require.Equal(t, `SELECT * FROM "users" WHERE "users"."name" ILIKE $1`, query)
require.Equal(t, []any{"%a8m%"}, args)
})
}