diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 0eaf1ba15..c47f6e3e0 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1726,6 +1726,32 @@ func (p *Predicate) HasPrefix(col, prefix string) *Predicate { return p.escapedLike(col, "", "%", prefix) } +// ColumnsHasPrefix appends a new predicate that checks if the given column begins with the other column (prefix). +func ColumnsHasPrefix(col, prefixC string) *Predicate { + return P().ColumnsHasPrefix(col, prefixC) +} + +// ColumnsHasPrefix appends a new predicate that checks if the given column begins with the other column (prefix). +func (p *Predicate) ColumnsHasPrefix(col, prefixC string) *Predicate { + return p.Append(func(b *Builder) { + switch p.dialect { + case dialect.MySQL: + b.Ident(col) + b.WriteOp(OpLike) + b.S("CONCAT(REPLACE(REPLACE(").Ident(prefixC).S(", '_', '\\_'), '%', '\\%'), '%')") + case dialect.Postgres, dialect.SQLite: + b.Ident(col) + b.WriteOp(OpLike) + b.S("(REPLACE(REPLACE(").Ident(prefixC).S(", '_', '\\_'), '%', '\\%') || '%')") + if p.dialect == dialect.SQLite { + p.WriteString(" ESCAPE ").Arg("\\") + } + default: + b.AddError(fmt.Errorf("ColumnsHasPrefix: unsupported dialect: %q", p.dialect)) + } + }) +} + // HasSuffix is a helper predicate that checks suffix using the LIKE predicate. func HasSuffix(col, suffix string) *Predicate { return P().HasSuffix(col, suffix) } diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index bb10317e9..8454ab3bb 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -2475,3 +2475,24 @@ func TestSelector_SelectedColumn(t *testing.T) { require.Equal(t, []string{`"t2"."e"`, "t2.e", `"t1"."e"`, "t1.e", "e"}, s.FindSelection("e")) }) } + +func TestColumnsHasPrefix(t *testing.T) { + t.Run("MySQL", func(t *testing.T) { + query, args := Dialect(dialect.MySQL). + Select("*").From(Table("t1")).Where(ColumnsHasPrefix("a", "b")).Query() + require.Equal(t, "SELECT * FROM `t1` WHERE `a` LIKE CONCAT(REPLACE(REPLACE(`b`, '_', '\\_'), '%', '\\%'), '%')", query) + require.Empty(t, args) + }) + t.Run("Postgres", func(t *testing.T) { + query, args := Dialect(dialect.Postgres). + Select("*").From(Table("t1")).Where(ColumnsHasPrefix("a", "b")).Query() + require.Equal(t, `SELECT * FROM "t1" WHERE "a" LIKE (REPLACE(REPLACE("b", '_', '\_'), '%', '\%') || '%')`, query) + require.Empty(t, args) + }) + t.Run("SQLite", func(t *testing.T) { + query, args := Dialect(dialect.SQLite). + Select("*").From(Table("t1")).Where(ColumnsHasPrefix("a", "b")).Query() + require.Equal(t, "SELECT * FROM `t1` WHERE `a` LIKE (REPLACE(REPLACE(`b`, '_', '\\_'), '%', '\\%') || '%') ESCAPE ?", query) + require.Equal(t, []any{`\`}, args) + }) +} diff --git a/dialect/sql/sql.go b/dialect/sql/sql.go index 0f8494d65..bda4ecd58 100644 --- a/dialect/sql/sql.go +++ b/dialect/sql/sql.go @@ -113,6 +113,13 @@ func FieldsLTE(field1, field2 string) func(*Selector) { } } +// FieldsHasPrefix returns a raw predicate to checks if field1 begins with the value of field2. +func FieldsHasPrefix(field1, field2 string) func(*Selector) { + return func(s *Selector) { + s.Where(ColumnsHasPrefix(s.C(field1), s.C(field2))) + } +} + // FieldIn returns a raw predicate to check if the value of the field is IN the given values. func FieldIn[T any](name string, vs ...T) func(*Selector) { return func(s *Selector) { diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index 6e5c8f3fb..c834e27b5 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -976,6 +976,15 @@ func Predicate(t *testing.T, client *ent.Client) { require.Equal(lab.ID, client.Group.Query().Where(group.Active(false)).OnlyIDX(ctx)) require.Equal(hub.ID, client.Group.Query().Where(group.ActiveNEQ(false)).OnlyIDX(ctx)) require.Equal(lab.ID, client.Group.Query().Where(group.ActiveNEQ(true)).OnlyIDX(ctx)) + + client.User.CreateBulk( + client.User.Create().SetAge(1).SetName("Ariel").SetNickname("A"), + client.User.Create().SetAge(1).SetName("Ariel").SetNickname("A%"), + ).ExecX(ctx) + a1 := client.User.Query().Where(sql.FieldsHasPrefix(user.FieldName, user.FieldNickname)).OnlyX(ctx) + require.Equal("A", a1.Nickname) + a2 := client.User.Query().Where(user.Not(sql.FieldsHasPrefix(user.FieldName, user.FieldNickname))).OnlyX(ctx) + require.Equal("A%", a2.Nickname) } func AddValues(t *testing.T, client *ent.Client) {