diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 501aaf4c3..3f3af6d9f 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1595,6 +1595,42 @@ func (p *Predicate) Like(col, pattern string) *Predicate { }) } +// escape escapes w with the default escape character ('/'), +// to be used by the pattern matching functions below. +// The second return value indicates if w was escaped or not. +func escape(w string) (string, bool) { + var n int + for i := range w { + if c := w[i]; c == '%' || c == '_' || c == '\\' { + n++ + } + } + // No characters to escape. + if n == 0 { + return w, false + } + var b strings.Builder + b.Grow(len(w) + n) + for i := range w { + if c := w[i]; c == '%' || c == '_' || c == '\\' { + b.WriteByte('\\') + } + b.WriteByte(w[i]) + } + return b.String(), true +} + +func (p *Predicate) escapedLike(col, left, right, word string) *Predicate { + return p.Append(func(b *Builder) { + w, escaped := escape(word) + b.Ident(col).WriteOp(OpLike) + b.Arg(left + w + right) + if p.dialect == dialect.SQLite && escaped { + p.WriteString(" ESCAPE ").Arg("\\") + } + }) +} + // HasPrefix is a helper predicate that checks prefix using the LIKE predicate. func HasPrefix(col, prefix string) *Predicate { return P().HasPrefix(col, prefix) @@ -1602,7 +1638,7 @@ func HasPrefix(col, prefix string) *Predicate { // HasPrefix is a helper predicate that checks prefix using the LIKE predicate. func (p *Predicate) HasPrefix(col, prefix string) *Predicate { - return p.Like(col, prefix+"%") + return p.escapedLike(col, "", "%", prefix) } // HasSuffix is a helper predicate that checks suffix using the LIKE predicate. @@ -1610,7 +1646,7 @@ func HasSuffix(col, suffix string) *Predicate { return P().HasSuffix(col, suffix // HasSuffix is a helper predicate that checks suffix using the LIKE predicate. func (p *Predicate) HasSuffix(col, suffix string) *Predicate { - return p.Like(col, "%"+suffix) + return p.escapedLike(col, "%", "", suffix) } // EqualFold is a helper predicate that applies the "=" predicate with case-folding. @@ -1641,39 +1677,45 @@ func (p *Predicate) EqualFold(col, sub string) *Predicate { func Contains(col, sub string) *Predicate { return P().Contains(col, sub) } // Contains is a helper predicate that checks substring using the LIKE predicate. -func (p *Predicate) Contains(col, sub string) *Predicate { - return p.Like(col, "%"+sub+"%") +func (p *Predicate) Contains(col, substr string) *Predicate { + return p.escapedLike(col, "%", "%", substr) } // ContainsFold is a helper predicate that checks substring using the LIKE predicate. func ContainsFold(col, sub string) *Predicate { return P().ContainsFold(col, sub) } // ContainsFold is a helper predicate that applies the LIKE predicate with case-folding. -func (p *Predicate) ContainsFold(col, sub string) *Predicate { +func (p *Predicate) ContainsFold(col, substr string) *Predicate { return p.Append(func(b *Builder) { - f := &Func{} - f.SetDialect(b.dialect) + w, escaped := escape(substr) switch b.dialect { case dialect.MySQL: // We assume the CHARACTER SET is configured to utf8mb4, // because this how it is defined in dialect/sql/schema. b.Ident(col).WriteString(" COLLATE utf8mb4_general_ci LIKE ") + b.Arg("%" + strings.ToLower(w) + "%") case dialect.Postgres: b.Ident(col).WriteString(" ILIKE ") + b.Arg("%" + strings.ToLower(w) + "%") default: // SQLite. + var f Func + f.SetDialect(b.dialect) f.Lower(col) b.WriteString(f.String()).WriteString(" LIKE ") + b.Arg("%" + strings.ToLower(w) + "%") + if escaped { + p.WriteString(" ESCAPE ").Arg("\\") + } } - b.Arg("%" + strings.ToLower(sub) + "%") }) } -// CompositeGT returns a comiposite ">" predicate +// CompositeGT returns a composite ">" predicate func CompositeGT(columns []string, args ...interface{}) *Predicate { return P().CompositeGT(columns, args...) } -// CompositeLT returns a comiposite "<" predicate +// CompositeLT returns a composite "<" predicate func CompositeLT(columns []string, args ...interface{}) *Predicate { return P().CompositeLT(columns, args...) } diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index ceee89721..f5e3e4b25 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1842,3 +1842,35 @@ func TestInsert_OnConflict(t *testing.T) { require.Equal(t, []interface{}{"Mashraki"}, args) }) } + +func TestEscapePatterns(t *testing.T) { + q, args := Dialect(dialect.MySQL). + Update("users"). + SetNull("name"). + Where( + Or( + HasPrefix("nickname", "%a8m%"), + HasSuffix("nickname", "_alexsn_"), + Contains("nickname", "\\pedro\\"), + ContainsFold("nickname", "%AbcD%efg"), + ), + ). + Query() + require.Equal(t, "UPDATE `users` SET `name` = NULL WHERE `nickname` LIKE ? OR `nickname` LIKE ? OR `nickname` LIKE ? OR `nickname` COLLATE utf8mb4_general_ci LIKE ?", q) + require.Equal(t, []interface{}{"\\%a8m\\%%", "%\\_alexsn\\_", "%\\\\pedro\\\\%", "%\\%abcd\\%efg%"}, args) + + q, args = Dialect(dialect.SQLite). + Update("users"). + SetNull("name"). + Where( + Or( + HasPrefix("nickname", "%a8m%"), + HasSuffix("nickname", "_alexsn_"), + Contains("nickname", "\\pedro\\"), + ContainsFold("nickname", "%AbcD%efg"), + ), + ). + Query() + require.Equal(t, "UPDATE `users` SET `name` = NULL WHERE `nickname` LIKE ? ESCAPE ? OR `nickname` LIKE ? ESCAPE ? OR `nickname` LIKE ? ESCAPE ? OR LOWER(`nickname`) LIKE ? ESCAPE ?", q) + require.Equal(t, []interface{}{"\\%a8m\\%%", "\\", "%\\_alexsn\\_", "\\", "%\\\\pedro\\\\%", "\\", "%\\%abcd\\%efg%", "\\"}, args) +} diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index d1c59402d..a6e5fecb4 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -260,6 +260,34 @@ func Sanity(t *testing.T, client *ent.Client) { require.True(ok) require.Equal("-", fi.Tag.Get("json")) client.User.Create().SetName("tarrence").SetAge(30).ExecX(ctx) + + t.Run("StringPredicates", func(t *testing.T) { + client.Pet.Delete().ExecX(ctx) + a := client.Pet.Create().SetName("a%").SaveX(ctx) + require.True(client.Pet.Query().Where(pet.NameHasPrefix("a%")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameHasPrefix("%a%")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.Or(pet.NameHasPrefix("%a%"), pet.NameHasPrefix("%a%"))).ExistX(ctx)) + require.True(client.Pet.Query().Where(pet.NameHasSuffix("%")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameHasSuffix("a%%")).ExistX(ctx)) + require.True(client.Pet.Query().Where(pet.NameContains("a")).ExistX(ctx)) + require.True(client.Pet.Query().Where(pet.NameContains("a%")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameContains("%a%")).ExistX(ctx)) + require.True(client.Pet.Query().Where(pet.NameContainsFold("A%")).ExistX(ctx)) + + a.Update().SetName("a_\\").ExecX(ctx) + require.True(client.Pet.Query().Where(pet.NameHasPrefix("a")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameHasPrefix("%a")).ExistX(ctx)) + require.True(client.Pet.Query().Where(pet.NameHasPrefix("a_")).ExistX(ctx)) + require.True(client.Pet.Query().Where(pet.NameHasSuffix("a_\\")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameHasSuffix("%a")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameHasSuffix("a%")).ExistX(ctx)) + require.True(client.Pet.Query().Where(pet.NameContains("a")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameContains("%a")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameContains("a%")).ExistX(ctx)) + require.True(client.Pet.Query().Where(pet.NameContainsFold("A")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameContainsFold("%A")).ExistX(ctx)) + require.False(client.Pet.Query().Where(pet.NameContainsFold("A%")).ExistX(ctx)) + }) } func Upsert(t *testing.T, client *ent.Client) { @@ -944,8 +972,9 @@ func Relation(t *testing.T, client *ent.Client) { require.Empty(client.User.Query().Where(user.NameIn("alex", "rocket")).AllX(ctx)) require.NotNil(client.User.Query().Where(user.HasParentWith(user.NameIn("a8m", "neta"))).OnlyX(ctx)) require.Len(client.User.Query().Where(user.NameContains("a8")).AllX(ctx), 1) - require.Len(client.User.Query().Where(user.NameHasPrefix("a8")).AllX(ctx), 1) - require.Len(client.User.Query().Where(user.Or(user.NameHasPrefix("a8"), user.NameHasSuffix("eta"))).AllX(ctx), 2) + require.Equal(1, client.User.Query().Where(user.NameHasPrefix("a8")).CountX(ctx)) + require.Zero(client.User.Query().Where(user.NameHasPrefix("%a8%")).CountX(ctx)) + require.Equal(2, client.User.Query().Where(user.Or(user.NameHasPrefix("a8"), user.NameHasSuffix("eta"))).CountX(ctx)) t.Log("group-by one field") names, err := client.User.Query().GroupBy(user.FieldName).Strings(ctx)