diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index ab885a165..7836f34bd 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1703,14 +1703,17 @@ func (p *Predicate) EqualFold(col, sub string) *Predicate { // We assume the CHARACTER SET is configured to utf8mb4, // because this how it is defined in dialect/sql/schema. b.Ident(col).WriteString(" COLLATE utf8mb4_general_ci = ") + b.Arg(strings.ToLower(sub)) case dialect.Postgres: b.Ident(col).WriteString(" ILIKE ") + w, _ := escape(sub) + b.Arg(strings.ToLower(w)) default: // SQLite. f.Lower(col) b.WriteString(f.String()) b.WriteOp(OpEQ) + b.Arg(strings.ToLower(sub)) } - b.Arg(strings.ToLower(sub)) }) } diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 830718cb4..aa1e80f6b 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1019,6 +1019,22 @@ func TestBuilder(t *testing.T) { wantQuery: `SELECT * FROM "users" WHERE "name" ILIKE $1 OR "name" ILIKE $2`, wantArgs: []interface{}{"bar", "baz"}, }, + { + input: Dialect(dialect.Postgres). + Select(). + From(Table("users")). + Where(Or(EqualFold("name", "BAR%"), EqualFold("name", "%BAZ"))), + wantQuery: `SELECT * FROM "users" WHERE "name" ILIKE $1 OR "name" ILIKE $2`, + wantArgs: []interface{}{"bar\\%", "\\%baz"}, + }, + { + input: Dialect(dialect.Postgres). + Select(). + From(Table("users")). + Where(Or(EqualFold("name", "BAR\\"), EqualFold("name", "\\BAZ"))), + wantQuery: `SELECT * FROM "users" WHERE "name" ILIKE $1 OR "name" ILIKE $2`, + wantArgs: []interface{}{"bar\\\\", "\\\\baz"}, + }, { input: Dialect(dialect.MySQL). Select(). diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index 74a1cc939..574618836 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -298,6 +298,10 @@ func Sanity(t *testing.T, client *ent.Client) { require.True(client.Pet.Query().Where(pet.NameContainsFold("A")).ExistX(ctx)) require.False(client.Pet.Query().Where(pet.NameContainsFold("%A")).ExistX(ctx)) require.False(client.Pet.Query().Where(pet.NameContainsFold("A%")).ExistX(ctx)) + require.True(client.Pet.Query().Where(pet.NameEqualFold("A_\\")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameEqualFold("%A_\\")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameEqualFold("A_\\%")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameEqualFold("A%")).ExistX(ctx)) }) }