Ariel Mashraki
2021-10-24 20:52:44 +03:00
committed by Ariel Mashraki
parent 4919889eb4
commit a1f6de2793
3 changed files with 115 additions and 12 deletions

View File

@@ -1595,6 +1595,42 @@ func (p *Predicate) Like(col, pattern string) *Predicate {
})
}
// escape escapes w with the default escape character ('/'),
// to be used by the pattern matching functions below.
// The second return value indicates if w was escaped or not.
func escape(w string) (string, bool) {
var n int
for i := range w {
if c := w[i]; c == '%' || c == '_' || c == '\\' {
n++
}
}
// No characters to escape.
if n == 0 {
return w, false
}
var b strings.Builder
b.Grow(len(w) + n)
for i := range w {
if c := w[i]; c == '%' || c == '_' || c == '\\' {
b.WriteByte('\\')
}
b.WriteByte(w[i])
}
return b.String(), true
}
func (p *Predicate) escapedLike(col, left, right, word string) *Predicate {
return p.Append(func(b *Builder) {
w, escaped := escape(word)
b.Ident(col).WriteOp(OpLike)
b.Arg(left + w + right)
if p.dialect == dialect.SQLite && escaped {
p.WriteString(" ESCAPE ").Arg("\\")
}
})
}
// HasPrefix is a helper predicate that checks prefix using the LIKE predicate.
func HasPrefix(col, prefix string) *Predicate {
return P().HasPrefix(col, prefix)
@@ -1602,7 +1638,7 @@ func HasPrefix(col, prefix string) *Predicate {
// HasPrefix is a helper predicate that checks prefix using the LIKE predicate.
func (p *Predicate) HasPrefix(col, prefix string) *Predicate {
return p.Like(col, prefix+"%")
return p.escapedLike(col, "", "%", prefix)
}
// HasSuffix is a helper predicate that checks suffix using the LIKE predicate.
@@ -1610,7 +1646,7 @@ func HasSuffix(col, suffix string) *Predicate { return P().HasSuffix(col, suffix
// HasSuffix is a helper predicate that checks suffix using the LIKE predicate.
func (p *Predicate) HasSuffix(col, suffix string) *Predicate {
return p.Like(col, "%"+suffix)
return p.escapedLike(col, "%", "", suffix)
}
// EqualFold is a helper predicate that applies the "=" predicate with case-folding.
@@ -1641,39 +1677,45 @@ func (p *Predicate) EqualFold(col, sub string) *Predicate {
func Contains(col, sub string) *Predicate { return P().Contains(col, sub) }
// Contains is a helper predicate that checks substring using the LIKE predicate.
func (p *Predicate) Contains(col, sub string) *Predicate {
return p.Like(col, "%"+sub+"%")
func (p *Predicate) Contains(col, substr string) *Predicate {
return p.escapedLike(col, "%", "%", substr)
}
// ContainsFold is a helper predicate that checks substring using the LIKE predicate.
func ContainsFold(col, sub string) *Predicate { return P().ContainsFold(col, sub) }
// ContainsFold is a helper predicate that applies the LIKE predicate with case-folding.
func (p *Predicate) ContainsFold(col, sub string) *Predicate {
func (p *Predicate) ContainsFold(col, substr string) *Predicate {
return p.Append(func(b *Builder) {
f := &Func{}
f.SetDialect(b.dialect)
w, escaped := escape(substr)
switch b.dialect {
case dialect.MySQL:
// We assume the CHARACTER SET is configured to utf8mb4,
// because this how it is defined in dialect/sql/schema.
b.Ident(col).WriteString(" COLLATE utf8mb4_general_ci LIKE ")
b.Arg("%" + strings.ToLower(w) + "%")
case dialect.Postgres:
b.Ident(col).WriteString(" ILIKE ")
b.Arg("%" + strings.ToLower(w) + "%")
default: // SQLite.
var f Func
f.SetDialect(b.dialect)
f.Lower(col)
b.WriteString(f.String()).WriteString(" LIKE ")
b.Arg("%" + strings.ToLower(w) + "%")
if escaped {
p.WriteString(" ESCAPE ").Arg("\\")
}
}
b.Arg("%" + strings.ToLower(sub) + "%")
})
}
// CompositeGT returns a comiposite ">" predicate
// CompositeGT returns a composite ">" predicate
func CompositeGT(columns []string, args ...interface{}) *Predicate {
return P().CompositeGT(columns, args...)
}
// CompositeLT returns a comiposite "<" predicate
// CompositeLT returns a composite "<" predicate
func CompositeLT(columns []string, args ...interface{}) *Predicate {
return P().CompositeLT(columns, args...)
}

View File

@@ -1842,3 +1842,35 @@ func TestInsert_OnConflict(t *testing.T) {
require.Equal(t, []interface{}{"Mashraki"}, args)
})
}
func TestEscapePatterns(t *testing.T) {
q, args := Dialect(dialect.MySQL).
Update("users").
SetNull("name").
Where(
Or(
HasPrefix("nickname", "%a8m%"),
HasSuffix("nickname", "_alexsn_"),
Contains("nickname", "\\pedro\\"),
ContainsFold("nickname", "%AbcD%efg"),
),
).
Query()
require.Equal(t, "UPDATE `users` SET `name` = NULL WHERE `nickname` LIKE ? OR `nickname` LIKE ? OR `nickname` LIKE ? OR `nickname` COLLATE utf8mb4_general_ci LIKE ?", q)
require.Equal(t, []interface{}{"\\%a8m\\%%", "%\\_alexsn\\_", "%\\\\pedro\\\\%", "%\\%abcd\\%efg%"}, args)
q, args = Dialect(dialect.SQLite).
Update("users").
SetNull("name").
Where(
Or(
HasPrefix("nickname", "%a8m%"),
HasSuffix("nickname", "_alexsn_"),
Contains("nickname", "\\pedro\\"),
ContainsFold("nickname", "%AbcD%efg"),
),
).
Query()
require.Equal(t, "UPDATE `users` SET `name` = NULL WHERE `nickname` LIKE ? ESCAPE ? OR `nickname` LIKE ? ESCAPE ? OR `nickname` LIKE ? ESCAPE ? OR LOWER(`nickname`) LIKE ? ESCAPE ?", q)
require.Equal(t, []interface{}{"\\%a8m\\%%", "\\", "%\\_alexsn\\_", "\\", "%\\\\pedro\\\\%", "\\", "%\\%abcd\\%efg%", "\\"}, args)
}