mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/sqljson: move json predicates to a package (#735)
This commit is contained in:
@@ -10,7 +10,6 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/facebook/ent/dialect"
|
||||
)
|
||||
@@ -967,37 +966,6 @@ func (p *Predicate) GTE(col string, arg interface{}) *Predicate {
|
||||
})
|
||||
}
|
||||
|
||||
// JSONHasKey calls Predicate.JSONHasKey.
|
||||
func JSONHasKey(col, path string) *Predicate {
|
||||
return P().JSONHasKey(col, path)
|
||||
}
|
||||
|
||||
// JSONHasKey return a predicate for checking that a JSON key exists and not NULL.
|
||||
//
|
||||
// P().JSONHasKey("column", "a.b[2].c")
|
||||
//
|
||||
func (p *Predicate) JSONHasKey(col, path string) *Predicate {
|
||||
return p.Append(func(b *Builder) {
|
||||
b.JSONPath(col, DotPath(path)).WriteOp(OpNotNull)
|
||||
})
|
||||
}
|
||||
|
||||
// JSONValueEQ calls Predicate.JSONValueEQ.
|
||||
func JSONValueEQ(col, path string, arg interface{}) *Predicate {
|
||||
return P().JSONValueEQ(col, path, arg)
|
||||
}
|
||||
|
||||
// JSONValueEQ return a predicate for checking that a JSON
|
||||
// value (returned by the path) is equal to the given argument.
|
||||
//
|
||||
// P().JSONValueEQ("column", "a.b[2].c", arg)
|
||||
//
|
||||
func (p *Predicate) JSONValueEQ(col, path string, arg interface{}) *Predicate {
|
||||
return p.Append(func(b *Builder) {
|
||||
b.JSONPath(col, DotPath(path)).WriteOp(OpEQ).Arg(arg)
|
||||
})
|
||||
}
|
||||
|
||||
// NotNull returns the `IS NOT NULL` predicate.
|
||||
func NotNull(col string) *Predicate {
|
||||
return P().NotNull(col)
|
||||
@@ -1962,114 +1930,6 @@ func (b *Builder) WriteOp(op Op) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
// JSONOption allows for calling database JSON paths with functional options.
|
||||
type JSONOption func(*JSONPath)
|
||||
|
||||
// Path sets the path to the JSON value of a column.
|
||||
//
|
||||
// b.JSONPath("column", Path("a", "b", "[1]", "c"))
|
||||
//
|
||||
func Path(path ...string) JSONOption {
|
||||
return func(p *JSONPath) {
|
||||
p.path = path
|
||||
}
|
||||
}
|
||||
|
||||
// DotPath is similar to Path, but accepts string with dot format.
|
||||
//
|
||||
// b.JSONPath("column", DotPath("a.b[2].c"))
|
||||
// b.JSONPath("column", DotPath("a.b.c"))
|
||||
//
|
||||
// Note that DotPath is ignored if the input is invalid.
|
||||
func DotPath(dotpath string) JSONOption {
|
||||
path, _ := ParsePath(dotpath)
|
||||
return func(p *JSONPath) {
|
||||
p.path = path
|
||||
}
|
||||
}
|
||||
|
||||
// Unquote indicates that the result value should be unquoted.
|
||||
//
|
||||
// b.JSONPath("column", Path("a", "b", "[1]", "c"), Unquote(true))
|
||||
//
|
||||
func Unquote(unquote bool) JSONOption {
|
||||
return func(p *JSONPath) {
|
||||
p.unquote = unquote
|
||||
}
|
||||
}
|
||||
|
||||
// Cast indicates that the result value should be casted to the given type.
|
||||
//
|
||||
// b.JSONPath("column", Path("a", "b", "[1]", "c"), Cast("int"))
|
||||
//
|
||||
func Cast(typ string) JSONOption {
|
||||
return func(p *JSONPath) {
|
||||
p.cast = typ
|
||||
}
|
||||
}
|
||||
|
||||
// JSONPath represents a path to a JSON value.
|
||||
type JSONPath struct {
|
||||
ident string
|
||||
path []string
|
||||
cast string
|
||||
unquote bool
|
||||
}
|
||||
|
||||
// writeTo writes the JSON path to the builder.
|
||||
func (p *JSONPath) writeTo(b *Builder) {
|
||||
switch {
|
||||
case len(p.path) == 0:
|
||||
b.Ident(p.ident)
|
||||
case b.postgres():
|
||||
if p.cast != "" {
|
||||
b.WriteString("CAST(")
|
||||
defer b.WriteString(" AS " + p.cast + ")")
|
||||
}
|
||||
b.Ident(p.ident)
|
||||
for i, s := range p.path {
|
||||
b.WriteString("->")
|
||||
if p.unquote && i == len(p.path)-1 {
|
||||
b.WriteString(">")
|
||||
}
|
||||
if idx, ok := isJSONIdx(s); ok {
|
||||
b.WriteString(idx)
|
||||
} else {
|
||||
b.WriteString("'" + s + "'")
|
||||
}
|
||||
}
|
||||
default:
|
||||
if p.unquote && b.mysql() {
|
||||
b.WriteString("JSON_UNQUOTE(")
|
||||
defer b.WriteByte(')')
|
||||
}
|
||||
b.WriteString("JSON_EXTRACT(")
|
||||
b.Ident(p.ident).Comma()
|
||||
b.WriteString(`"$`)
|
||||
for _, p := range p.path {
|
||||
if _, ok := isJSONIdx(p); ok {
|
||||
b.WriteString(p)
|
||||
} else {
|
||||
b.WriteString("." + p)
|
||||
}
|
||||
}
|
||||
b.WriteString(`")`)
|
||||
}
|
||||
}
|
||||
|
||||
// JSONPath appends the given JSON paths to the builder.
|
||||
//
|
||||
// b.JSONPath("column", Path("a", "b", "[1]", "c"), Cast("int"))
|
||||
//
|
||||
func (b *Builder) JSONPath(ident string, opts ...JSONOption) *Builder {
|
||||
path := &JSONPath{ident: ident}
|
||||
for i := range opts {
|
||||
opts[i](path)
|
||||
}
|
||||
path.writeTo(b)
|
||||
return b
|
||||
}
|
||||
|
||||
// Arg appends an input argument to the builder.
|
||||
func (b *Builder) Arg(a interface{}) *Builder {
|
||||
if r, ok := a.(*raw); ok {
|
||||
@@ -2197,11 +2057,6 @@ func (b Builder) postgres() bool {
|
||||
return b.Dialect() == dialect.Postgres
|
||||
}
|
||||
|
||||
// mysql reports if the builder dialect is MySQL.
|
||||
func (b Builder) mysql() bool {
|
||||
return b.Dialect() == dialect.MySQL
|
||||
}
|
||||
|
||||
// fromIdent sets the builder dialect from the identifier format.
|
||||
func (b *Builder) fromIdent(ident string) {
|
||||
if strings.Contains(ident, `"`) {
|
||||
@@ -2397,78 +2252,6 @@ func (d *DialectBuilder) DropIndex(name string) *DropIndexBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// ParsePath parses the "dotpath" for the DotPath option.
|
||||
//
|
||||
// "a.b" => ["a", "b"]
|
||||
// "a[1][2]" => ["a", "[1]", "[2]"]
|
||||
// "a.\"b.c\" => ["a", "\"b.c\""]
|
||||
//
|
||||
func ParsePath(dotpath string) ([]string, error) {
|
||||
var (
|
||||
i, p int
|
||||
path []string
|
||||
)
|
||||
for i < len(dotpath) {
|
||||
switch r := dotpath[i]; {
|
||||
case r == '"':
|
||||
if i == len(dotpath)-1 {
|
||||
return nil, fmt.Errorf("unexpected quote")
|
||||
}
|
||||
idx := strings.IndexRune(dotpath[i+1:], '"')
|
||||
if idx == -1 || idx == 0 {
|
||||
return nil, fmt.Errorf("unbalanced quote")
|
||||
}
|
||||
i += idx + 2
|
||||
case r == '[':
|
||||
if p != i {
|
||||
path = append(path, dotpath[p:i])
|
||||
}
|
||||
p = i
|
||||
if i == len(dotpath)-1 {
|
||||
return nil, fmt.Errorf("unexpected bracket")
|
||||
}
|
||||
idx := strings.IndexRune(dotpath[i:], ']')
|
||||
if idx == -1 || idx == 1 {
|
||||
return nil, fmt.Errorf("unbalanced bracket")
|
||||
}
|
||||
if !isNumber(dotpath[i+1 : i+idx]) {
|
||||
return nil, fmt.Errorf("invalid index %q", dotpath[i:i+idx+1])
|
||||
}
|
||||
i += idx + 1
|
||||
case r == '.' || r == ']':
|
||||
if p != i {
|
||||
path = append(path, dotpath[p:i])
|
||||
}
|
||||
i++
|
||||
p = i
|
||||
default:
|
||||
i++
|
||||
}
|
||||
}
|
||||
if p != i {
|
||||
path = append(path, dotpath[p:i])
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// isNumber reports whether the string is a number (category N).
|
||||
func isNumber(s string) bool {
|
||||
for _, r := range s {
|
||||
if !unicode.IsNumber(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isJSONIdx reports whether the string represents a JSON index.
|
||||
func isJSONIdx(s string) (string, bool) {
|
||||
if len(s) > 2 && s[0] == '[' && s[len(s)-1] == ']' && isNumber(s[1:len(s)-1]) {
|
||||
return s[1 : len(s)-1], true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func isFunc(s string) bool {
|
||||
return strings.Contains(s, "(") && strings.Contains(s, ")")
|
||||
}
|
||||
|
||||
@@ -1213,42 +1213,6 @@ WHERE
|
||||
OR ("f" <> $10 AND "g" <> $11)`),
|
||||
wantArgs: []interface{}{1, 2, 3, 2, 4, 5, "a", "b", "c", "f", "g"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.MySQL).
|
||||
Select("*").
|
||||
From(Table("users")).
|
||||
Where(P(func(b *Builder) {
|
||||
b.JSONPath("a", Path("b", "c", "[1]", "d"), Unquote(true))
|
||||
b.WriteOp(OpEQ)
|
||||
b.Arg("a")
|
||||
})),
|
||||
wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, \"$.b.c[1].d\")) = ?",
|
||||
wantArgs: []interface{}{"a"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(Table("users")).
|
||||
Where(P(func(b *Builder) {
|
||||
b.JSONPath("a", Path("b", "c", "[1]", "d"), Unquote(true))
|
||||
b.WriteOp(OpEQ)
|
||||
b.Arg("a")
|
||||
})),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' = $1`,
|
||||
wantArgs: []interface{}{"a"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(Table("users")).
|
||||
Where(P(func(b *Builder) {
|
||||
b.JSONPath("a", Path("b", "c", "[1]", "d"), Cast("int"))
|
||||
b.WriteOp(OpEQ)
|
||||
b.Arg(1)
|
||||
})),
|
||||
wantQuery: `SELECT * FROM "users" WHERE CAST("a"->'b'->'c'->1->'d' AS int) = $1`,
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
@@ -1269,69 +1233,6 @@ WHERE
|
||||
wantQuery: `SELECT * FROM "test" WHERE nlevel("path") > $1`,
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(Table("users")).
|
||||
Where(P(func(b *Builder) {
|
||||
b.JSONPath("a", DotPath("b.c[1].d"), Cast("int"))
|
||||
b.WriteOp(OpEQ)
|
||||
b.Arg(1)
|
||||
})),
|
||||
wantQuery: `SELECT * FROM "users" WHERE CAST("a"->'b'->'c'->1->'d' AS int) = $1`,
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.MySQL).
|
||||
Select("*").
|
||||
From(Table("users")).
|
||||
Where(P(func(b *Builder) {
|
||||
b.JSONPath("a", DotPath("b.c[1].d"))
|
||||
b.WriteOp(OpEQ)
|
||||
b.Arg("a")
|
||||
})),
|
||||
wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, \"$.b.c[1].d\") = ?",
|
||||
wantArgs: []interface{}{"a"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.MySQL).
|
||||
Select("*").
|
||||
From(Table("users")).
|
||||
Where(P(func(b *Builder) {
|
||||
b.JSONPath("a", DotPath("b.\"c[1]\".d[1][2].e"))
|
||||
b.WriteOp(OpEQ)
|
||||
b.Arg("a")
|
||||
})),
|
||||
wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, \"$.b.\"c[1]\".d[1][2].e\") = ?",
|
||||
wantArgs: []interface{}{"a"},
|
||||
},
|
||||
{
|
||||
input: Select("*").
|
||||
From(Table("test")).
|
||||
Where(JSONHasKey("j", "a.*.c")),
|
||||
wantQuery: "SELECT * FROM `test` WHERE JSON_EXTRACT(`j`, \"$.a.*.c\") IS NOT NULL",
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(Table("test")).
|
||||
Where(JSONHasKey("j", "a.b.c")),
|
||||
wantQuery: `SELECT * FROM "test" WHERE "j"->'a'->'b'->'c' IS NOT NULL`,
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(Table("test")).
|
||||
Where(JSONHasKey("j", "a.b.c")),
|
||||
wantQuery: `SELECT * FROM "test" WHERE "j"->'a'->'b'->'c' IS NOT NULL`,
|
||||
},
|
||||
{
|
||||
input: Select("*").
|
||||
From(Table("test")).
|
||||
Where(JSONValueEQ("j", "a.b.c", 1)),
|
||||
wantQuery: "SELECT * FROM `test` WHERE JSON_EXTRACT(`j`, \"$.a.b.c\") = ?",
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
@@ -1341,58 +1242,3 @@ WHERE
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantPath []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
input: "a.b.c",
|
||||
wantPath: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
input: "a[1][2]",
|
||||
wantPath: []string{"a", "[1]", "[2]"},
|
||||
},
|
||||
{
|
||||
input: "a[1][2].b",
|
||||
wantPath: []string{"a", "[1]", "[2]", "b"},
|
||||
},
|
||||
{
|
||||
input: `a."b.c[0]"`,
|
||||
wantPath: []string{"a", `"b.c[0]"`},
|
||||
},
|
||||
{
|
||||
input: `a."b.c[0]".d`,
|
||||
wantPath: []string{"a", `"b.c[0]"`, "d"},
|
||||
},
|
||||
{
|
||||
input: `...`,
|
||||
},
|
||||
{
|
||||
input: `.a.b.`,
|
||||
wantPath: []string{"a", "b"},
|
||||
},
|
||||
{
|
||||
input: `a."`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
input: `a[`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
input: `a[a]`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
path, err := ParsePath(tt.input)
|
||||
require.Equal(t, tt.wantPath, path)
|
||||
require.Equal(t, tt.wantErr, err != nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
217
dialect/sql/sqljson/sqljson.go
Normal file
217
dialect/sql/sqljson/sqljson.go
Normal file
@@ -0,0 +1,217 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package sqljson
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/facebook/ent/dialect"
|
||||
"github.com/facebook/ent/dialect/sql"
|
||||
)
|
||||
|
||||
// HasKey return a predicate for checking that a JSON key
|
||||
// exists and not NULL.
|
||||
//
|
||||
// sqljson.HasKey("column", sql.DotPath("a.b[2].c"))
|
||||
//
|
||||
func HasKey(column string, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
WritePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpNotNull)
|
||||
})
|
||||
}
|
||||
|
||||
// ValueEQ return a predicate for checking that a JSON value
|
||||
// (returned by the path) is equal to the given argument.
|
||||
//
|
||||
// P().JSONValueEQ("column", "a.b[2].c", arg)
|
||||
//
|
||||
func ValueEQ(column string, arg interface{}, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
WritePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpEQ).Arg(arg)
|
||||
})
|
||||
}
|
||||
|
||||
// WritePath writes the JSON path from the given options to the SQL builder.
|
||||
//
|
||||
// sqljson.WritePath(b, Path("a", "b", "[1]", "c"), Cast("int"))
|
||||
//
|
||||
func WritePath(b *sql.Builder, column string, opts ...Option) {
|
||||
path := &PathOptions{Ident: column}
|
||||
for i := range opts {
|
||||
opts[i](path)
|
||||
}
|
||||
path.WriteTo(b)
|
||||
}
|
||||
|
||||
// Option allows for calling database JSON paths with functional options.
|
||||
type Option func(*PathOptions)
|
||||
|
||||
// Path sets the path to the JSON value of a column.
|
||||
//
|
||||
// WritePath(b, "column", Path("a", "b", "[1]", "c"))
|
||||
//
|
||||
func Path(path ...string) Option {
|
||||
return func(p *PathOptions) {
|
||||
p.Path = path
|
||||
}
|
||||
}
|
||||
|
||||
// DotPath is similar to Path, but accepts string with dot format.
|
||||
//
|
||||
// WritePath(b, "column", DotPath("a.b.c"))
|
||||
// WritePath(b, "column", DotPath("a.b[2].c"))
|
||||
//
|
||||
// Note that DotPath is ignored if the input is invalid.
|
||||
func DotPath(dotpath string) Option {
|
||||
path, _ := ParsePath(dotpath)
|
||||
return func(p *PathOptions) {
|
||||
p.Path = path
|
||||
}
|
||||
}
|
||||
|
||||
// Unquote indicates that the result value should be unquoted.
|
||||
//
|
||||
// WritePath(b, "column", Path("a", "b", "[1]", "c"), Unquote(true))
|
||||
//
|
||||
func Unquote(unquote bool) Option {
|
||||
return func(p *PathOptions) {
|
||||
p.Unquote = unquote
|
||||
}
|
||||
}
|
||||
|
||||
// Cast indicates that the result value should be casted to the given type.
|
||||
//
|
||||
// WritePath(b, "column", Path("a", "b", "[1]", "c"), Cast("int"))
|
||||
//
|
||||
func Cast(typ string) Option {
|
||||
return func(p *PathOptions) {
|
||||
p.Cast = typ
|
||||
}
|
||||
}
|
||||
|
||||
// PathOptions holds the options for accessing a JSON value from an identifier.
|
||||
type PathOptions struct {
|
||||
Ident string
|
||||
Path []string
|
||||
Cast string
|
||||
Unquote bool
|
||||
}
|
||||
|
||||
// WriteTo writes the JSON path to the sql.Builder.
|
||||
func (p *PathOptions) WriteTo(b *sql.Builder) {
|
||||
switch {
|
||||
case len(p.Path) == 0:
|
||||
b.Ident(p.Ident)
|
||||
case b.Dialect() == dialect.Postgres:
|
||||
if p.Cast != "" {
|
||||
b.WriteString("CAST(")
|
||||
defer b.WriteString(" AS " + p.Cast + ")")
|
||||
}
|
||||
b.Ident(p.Ident)
|
||||
for i, s := range p.Path {
|
||||
b.WriteString("->")
|
||||
if p.Unquote && i == len(p.Path)-1 {
|
||||
b.WriteString(">")
|
||||
}
|
||||
if idx, ok := isJSONIdx(s); ok {
|
||||
b.WriteString(idx)
|
||||
} else {
|
||||
b.WriteString("'" + s + "'")
|
||||
}
|
||||
}
|
||||
default:
|
||||
if p.Unquote && b.Dialect() == dialect.MySQL {
|
||||
b.WriteString("JSON_UNQUOTE(")
|
||||
defer b.WriteByte(')')
|
||||
}
|
||||
b.WriteString("JSON_EXTRACT(")
|
||||
b.Ident(p.Ident).Comma()
|
||||
b.WriteString(`"$`)
|
||||
for _, p := range p.Path {
|
||||
if _, ok := isJSONIdx(p); ok {
|
||||
b.WriteString(p)
|
||||
} else {
|
||||
b.WriteString("." + p)
|
||||
}
|
||||
}
|
||||
b.WriteString(`")`)
|
||||
}
|
||||
}
|
||||
|
||||
// ParsePath parses the "dotpath" for the DotPath option.
|
||||
//
|
||||
// "a.b" => ["a", "b"]
|
||||
// "a[1][2]" => ["a", "[1]", "[2]"]
|
||||
// "a.\"b.c\" => ["a", "\"b.c\""]
|
||||
//
|
||||
func ParsePath(dotpath string) ([]string, error) {
|
||||
var (
|
||||
i, p int
|
||||
path []string
|
||||
)
|
||||
for i < len(dotpath) {
|
||||
switch r := dotpath[i]; {
|
||||
case r == '"':
|
||||
if i == len(dotpath)-1 {
|
||||
return nil, fmt.Errorf("unexpected quote")
|
||||
}
|
||||
idx := strings.IndexRune(dotpath[i+1:], '"')
|
||||
if idx == -1 || idx == 0 {
|
||||
return nil, fmt.Errorf("unbalanced quote")
|
||||
}
|
||||
i += idx + 2
|
||||
case r == '[':
|
||||
if p != i {
|
||||
path = append(path, dotpath[p:i])
|
||||
}
|
||||
p = i
|
||||
if i == len(dotpath)-1 {
|
||||
return nil, fmt.Errorf("unexpected bracket")
|
||||
}
|
||||
idx := strings.IndexRune(dotpath[i:], ']')
|
||||
if idx == -1 || idx == 1 {
|
||||
return nil, fmt.Errorf("unbalanced bracket")
|
||||
}
|
||||
if !isNumber(dotpath[i+1 : i+idx]) {
|
||||
return nil, fmt.Errorf("invalid index %q", dotpath[i:i+idx+1])
|
||||
}
|
||||
i += idx + 1
|
||||
case r == '.' || r == ']':
|
||||
if p != i {
|
||||
path = append(path, dotpath[p:i])
|
||||
}
|
||||
i++
|
||||
p = i
|
||||
default:
|
||||
i++
|
||||
}
|
||||
}
|
||||
if p != i {
|
||||
path = append(path, dotpath[p:i])
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// isJSONIdx reports whether the string represents a JSON index.
|
||||
func isJSONIdx(s string) (string, bool) {
|
||||
if len(s) > 2 && s[0] == '[' && s[len(s)-1] == ']' && isNumber(s[1:len(s)-1]) {
|
||||
return s[1 : len(s)-1], true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// isNumber reports whether the string is a number (category N).
|
||||
func isNumber(s string) bool {
|
||||
for _, r := range s {
|
||||
if !unicode.IsNumber(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
158
dialect/sql/sqljson/sqljson_test.go
Normal file
158
dialect/sql/sqljson/sqljson_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package sqljson_test
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/facebook/ent/dialect"
|
||||
"github.com/facebook/ent/dialect/sql"
|
||||
"github.com/facebook/ent/dialect/sql/sqljson"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWritePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
input sql.Querier
|
||||
wantQuery string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
{
|
||||
input: sql.Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(sql.Table("users")).
|
||||
Where(sqljson.ValueEQ("a", 1, sqljson.Path("b", "c", "[1]", "d"), sqljson.Cast("int"))),
|
||||
wantQuery: `SELECT * FROM "users" WHERE CAST("a"->'b'->'c'->1->'d' AS int) = $1`,
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
{
|
||||
input: sql.Dialect(dialect.MySQL).
|
||||
Select("*").
|
||||
From(sql.Table("users")).
|
||||
Where(sqljson.ValueEQ("a", "a", sqljson.DotPath("b.c[1].d"))),
|
||||
wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, \"$.b.c[1].d\") = ?",
|
||||
wantArgs: []interface{}{"a"},
|
||||
},
|
||||
{
|
||||
input: sql.Dialect(dialect.MySQL).
|
||||
Select("*").
|
||||
From(sql.Table("users")).
|
||||
Where(sqljson.ValueEQ("a", "a", sqljson.DotPath("b.\"c[1]\".d[1][2].e"))),
|
||||
wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, \"$.b.\"c[1]\".d[1][2].e\") = ?",
|
||||
wantArgs: []interface{}{"a"},
|
||||
},
|
||||
{
|
||||
input: sql.Select("*").
|
||||
From(sql.Table("test")).
|
||||
Where(sqljson.HasKey("j", sqljson.DotPath("a.*.c"))),
|
||||
wantQuery: "SELECT * FROM `test` WHERE JSON_EXTRACT(`j`, \"$.a.*.c\") IS NOT NULL",
|
||||
},
|
||||
{
|
||||
input: sql.Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(sql.Table("test")).
|
||||
Where(sqljson.HasKey("j", sqljson.DotPath("a.b.c"))),
|
||||
wantQuery: `SELECT * FROM "test" WHERE "j"->'a'->'b'->'c' IS NOT NULL`,
|
||||
},
|
||||
{
|
||||
input: sql.Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(sql.Table("test")).
|
||||
Where(sql.And(
|
||||
sql.EQ("e", 10),
|
||||
sqljson.ValueEQ("a", 1, sqljson.DotPath("b.c")),
|
||||
)),
|
||||
wantQuery: `SELECT * FROM "test" WHERE "e" = $1 AND "a"->'b'->'c' = $2`,
|
||||
wantArgs: []interface{}{10, 1},
|
||||
},
|
||||
{
|
||||
input: sql.Dialect(dialect.MySQL).
|
||||
Select("*").
|
||||
From(sql.Table("users")).
|
||||
Where(sqljson.ValueEQ("a", "a", sqljson.Path("b", "c", "[1]", "d"), sqljson.Unquote(true))),
|
||||
wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, \"$.b.c[1].d\")) = ?",
|
||||
wantArgs: []interface{}{"a"},
|
||||
},
|
||||
{
|
||||
input: sql.Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(sql.Table("users")).
|
||||
Where(sqljson.ValueEQ("a", "a", sqljson.Path("b", "c", "[1]", "d"), sqljson.Unquote(true))),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' = $1`,
|
||||
wantArgs: []interface{}{"a"},
|
||||
},
|
||||
{
|
||||
input: sql.Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(sql.Table("users")).
|
||||
Where(sqljson.ValueEQ("a", 1, sqljson.Path("b", "c", "[1]", "d"), sqljson.Cast("int"))),
|
||||
wantQuery: `SELECT * FROM "users" WHERE CAST("a"->'b'->'c'->1->'d' AS int) = $1`,
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
query, args := tt.input.Query()
|
||||
require.Equal(t, tt.wantQuery, query)
|
||||
require.Equal(t, tt.wantArgs, args)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantPath []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
input: "a.b.c",
|
||||
wantPath: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
input: "a[1][2]",
|
||||
wantPath: []string{"a", "[1]", "[2]"},
|
||||
},
|
||||
{
|
||||
input: "a[1][2].b",
|
||||
wantPath: []string{"a", "[1]", "[2]", "b"},
|
||||
},
|
||||
{
|
||||
input: `a."b.c[0]"`,
|
||||
wantPath: []string{"a", `"b.c[0]"`},
|
||||
},
|
||||
{
|
||||
input: `a."b.c[0]".d`,
|
||||
wantPath: []string{"a", `"b.c[0]"`, "d"},
|
||||
},
|
||||
{
|
||||
input: `...`,
|
||||
},
|
||||
{
|
||||
input: `.a.b.`,
|
||||
wantPath: []string{"a", "b"},
|
||||
},
|
||||
{
|
||||
input: `a."`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
input: `a[`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
input: `a[a]`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
path, err := sqljson.ParsePath(tt.input)
|
||||
require.Equal(t, tt.wantPath, path)
|
||||
require.Equal(t, tt.wantErr, err != nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user