diff --git a/dialect/sql/sqljson/sqljson.go b/dialect/sql/sqljson/sqljson.go index de40aeb38..cdec95af4 100644 --- a/dialect/sql/sqljson/sqljson.go +++ b/dialect/sql/sqljson/sqljson.go @@ -184,6 +184,36 @@ func ValueContains(column string, arg interface{}, opts ...Option) *sql.Predicat }) } +// StringHasPrefix return a predicate for checking that a JSON string value +// (returned by the path) has the given substring as prefix +func StringHasPrefix(column string, prefix string, opts ...Option) *sql.Predicate { + return sql.P(func(b *sql.Builder) { + opts = append([]Option{Unquote(true)}, opts...) + ValuePath(b, column, opts...) + b.Join(sql.HasPrefix("", prefix)) + }) +} + +// StringHasSuffix return a predicate for checking that a JSON string value +// (returned by the path) has the given substring as suffix +func StringHasSuffix(column string, suffix string, opts ...Option) *sql.Predicate { + return sql.P(func(b *sql.Builder) { + opts = append([]Option{Unquote(true)}, opts...) + ValuePath(b, column, opts...) + b.Join(sql.HasSuffix("", suffix)) + }) +} + +// StringContains return a predicate for checking that a JSON string value +// (returned by the path) contains the given substring +func StringContains(column string, sub string, opts ...Option) *sql.Predicate { + return sql.P(func(b *sql.Builder) { + opts = append([]Option{Unquote(true)}, opts...) + ValuePath(b, column, opts...) + b.Join(sql.Contains("", sub)) + }) +} + // LenEQ return a predicate for checking that an array length // of a JSON (returned by the path) is equal to the given argument. // diff --git a/dialect/sql/sqljson/sqljson_test.go b/dialect/sql/sqljson/sqljson_test.go index fb5b396b3..8c40a5782 100644 --- a/dialect/sql/sqljson/sqljson_test.go +++ b/dialect/sql/sqljson/sqljson_test.go @@ -243,6 +243,80 @@ func TestWritePath(t *testing.T) { Where(sqljson.ValueIsNull("c", sqljson.Path("a"))), wantQuery: "SELECT * FROM `users` WHERE JSON_TYPE(`c`, \"$.a\") = 'null'", }, + { + input: sql.Dialect(dialect.Postgres). + Select("*"). + From(sql.Table("users")). + Where(sqljson.StringContains("a", "substr", sqljson.Path("b", "c", "[1]", "d"))), + wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' LIKE $1`, + wantArgs: []interface{}{"%substr%"}, + }, + { + input: sql.Dialect(dialect.Postgres). + Select("*"). + From(sql.Table("users")). + Where( + sql.And( + sqljson.StringContains("a", "c", sqljson.Path("a")), + sqljson.StringContains("b", "d", sqljson.Path("b")), + ), + ), + wantQuery: `SELECT * FROM "users" WHERE "a"->>'a' LIKE $1 AND "b"->>'b' LIKE $2`, + wantArgs: []interface{}{"%c%", "%d%"}, + }, + { + input: sql.Dialect(dialect.MySQL). + Select("*"). + From(sql.Table("users")). + Where(sqljson.StringContains("a", "substr", sqljson.Path("b", "c", "[1]", "d"))), + wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, \"$.b.c[1].d\")) LIKE ?", + wantArgs: []interface{}{"%substr%"}, + }, + { + input: sql.Dialect(dialect.MySQL). + Select("*"). + From(sql.Table("users")). + Where( + sql.And( + sqljson.StringContains("a", "c", sqljson.Path("a")), + sqljson.StringContains("b", "d", sqljson.Path("b")), + ), + ), + wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, \"$.a\")) LIKE ? AND JSON_UNQUOTE(JSON_EXTRACT(`b`, \"$.b\")) LIKE ?", + wantArgs: []interface{}{"%c%", "%d%"}, + }, + { + input: sql.Dialect(dialect.Postgres). + Select("*"). + From(sql.Table("users")). + Where(sqljson.StringHasPrefix("a", "substr", sqljson.Path("b", "c", "[1]", "d"))), + wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' LIKE $1`, + wantArgs: []interface{}{"substr%"}, + }, + { + input: sql.Dialect(dialect.MySQL). + Select("*"). + From(sql.Table("users")). + Where(sqljson.StringHasPrefix("a", "substr", sqljson.Path("b", "c", "[1]", "d"))), + wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, \"$.b.c[1].d\")) LIKE ?", + wantArgs: []interface{}{"substr%"}, + }, + { + input: sql.Dialect(dialect.Postgres). + Select("*"). + From(sql.Table("users")). + Where(sqljson.StringHasSuffix("a", "substr", sqljson.Path("b", "c", "[1]", "d"))), + wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' LIKE $1`, + wantArgs: []interface{}{"%substr"}, + }, + { + input: sql.Dialect(dialect.MySQL). + Select("*"). + From(sql.Table("users")). + Where(sqljson.StringHasSuffix("a", "substr", sqljson.Path("b", "c", "[1]", "d"))), + wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, \"$.b.c[1].d\")) LIKE ?", + wantArgs: []interface{}{"%substr"}, + }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { diff --git a/entc/integration/json/json_test.go b/entc/integration/json/json_test.go index 9deb0128d..d97f25707 100644 --- a/entc/integration/json/json_test.go +++ b/entc/integration/json/json_test.go @@ -300,6 +300,7 @@ func Predicates(t *testing.T, client *ent.Client) { s.Where(sqljson.ValueContains(user.FieldInts, 3)) }).OnlyX(ctx) require.Contains(t, r.Ints, 3) + r = client.User.Query().Where(func(s *sql.Selector) { s.Where(sqljson.ValueContains(user.FieldT, 3, sqljson.Path("li"))) }).OnlyX(ctx) @@ -334,4 +335,58 @@ func Predicates(t *testing.T, client *ent.Client) { }).CountX(ctx) require.Equal(t, 2, n, "both u1 and u2 have a 'User' key") }) + + t.Run("Strings", func(t *testing.T) { + client.User.Delete().ExecX(ctx) + u, err := url.Parse("https://github.com/a8m") + require.NoError(t, err) + dirs := []http.Dir{"/dev/null"} + _, err = client.User.CreateBulk( + client.User.Create().SetURL(u), + client.User.Create().SetDirs(dirs), + client.User.Create().SetT(&schema.T{S: "foobar", Ls: []string{"foo", "bar"}}), + ).Save(ctx) + require.NoError(t, err) + + ps := []*sql.Predicate{ + sqljson.StringContains(user.FieldDirs, "dev", sqljson.Path("[0]")), + sqljson.StringHasPrefix(user.FieldDirs, "/dev", sqljson.Path("[0]")), + sqljson.StringHasSuffix(user.FieldDirs, "/null", sqljson.Path("[0]")), + } + for _, p := range ps { + r = client.User.Query().Where(func(s *sql.Selector) { s.Where(p) }).OnlyX(ctx) + require.Equal(t, dirs, r.Dirs) + } + r = client.User.Query().Where(func(s *sql.Selector) { s.Where(sql.And(ps...)) }).OnlyX(ctx) + require.Equal(t, dirs, r.Dirs) + + ps = []*sql.Predicate{ + sqljson.StringContains(user.FieldURL, "hub", sqljson.Path("Host")), + sqljson.StringHasPrefix(user.FieldURL, "github", sqljson.Path("Host")), + sqljson.StringHasSuffix(user.FieldURL, "hub.com", sqljson.Path("Host")), + } + for _, p := range ps { + r = client.User.Query().Where(func(s *sql.Selector) { s.Where(p) }).OnlyX(ctx) + require.Equal(t, u, r.URL) + } + + ps = []*sql.Predicate{ + sqljson.StringHasPrefix(user.FieldT, "foo", sqljson.Path("ls", "[0]")), + sqljson.StringHasSuffix(user.FieldT, "bar", sqljson.DotPath("ls[1]")), + sql.And( + sql.Or( + sqljson.StringContains(user.FieldT, "foo", sqljson.DotPath("ls[0]")), + sqljson.StringContains(user.FieldT, "foo", sqljson.DotPath("ls[1]")), + ), + sql.Or( + sqljson.StringContains(user.FieldT, "bar", sqljson.DotPath("ls[0]")), + sqljson.StringContains(user.FieldT, "bar", sqljson.DotPath("ls[1]")), + ), + ), + } + for _, p := range ps { + r = client.User.Query().Where(func(s *sql.Selector) { s.Where(p) }).OnlyX(ctx) + require.Equal(t, []string{"foo", "bar"}, r.T.Ls) + } + }) }