dialect/sql/sqljson: move json predicates to a package (#735)

This commit is contained in:
Ariel Mashraki
2020-09-07 21:22:12 +03:00
committed by GitHub
parent 5450481513
commit ce48ab99b8
5 changed files with 378 additions and 373 deletions

View File

@@ -0,0 +1,217 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package sqljson
import (
"fmt"
"strings"
"unicode"
"github.com/facebook/ent/dialect"
"github.com/facebook/ent/dialect/sql"
)
// HasKey return a predicate for checking that a JSON key
// exists and not NULL.
//
// sqljson.HasKey("column", sql.DotPath("a.b[2].c"))
//
func HasKey(column string, opts ...Option) *sql.Predicate {
return sql.P(func(b *sql.Builder) {
WritePath(b, column, opts...)
b.WriteOp(sql.OpNotNull)
})
}
// ValueEQ return a predicate for checking that a JSON value
// (returned by the path) is equal to the given argument.
//
// P().JSONValueEQ("column", "a.b[2].c", arg)
//
func ValueEQ(column string, arg interface{}, opts ...Option) *sql.Predicate {
return sql.P(func(b *sql.Builder) {
WritePath(b, column, opts...)
b.WriteOp(sql.OpEQ).Arg(arg)
})
}
// WritePath writes the JSON path from the given options to the SQL builder.
//
// sqljson.WritePath(b, Path("a", "b", "[1]", "c"), Cast("int"))
//
func WritePath(b *sql.Builder, column string, opts ...Option) {
path := &PathOptions{Ident: column}
for i := range opts {
opts[i](path)
}
path.WriteTo(b)
}
// Option allows for calling database JSON paths with functional options.
type Option func(*PathOptions)
// Path sets the path to the JSON value of a column.
//
// WritePath(b, "column", Path("a", "b", "[1]", "c"))
//
func Path(path ...string) Option {
return func(p *PathOptions) {
p.Path = path
}
}
// DotPath is similar to Path, but accepts string with dot format.
//
// WritePath(b, "column", DotPath("a.b.c"))
// WritePath(b, "column", DotPath("a.b[2].c"))
//
// Note that DotPath is ignored if the input is invalid.
func DotPath(dotpath string) Option {
path, _ := ParsePath(dotpath)
return func(p *PathOptions) {
p.Path = path
}
}
// Unquote indicates that the result value should be unquoted.
//
// WritePath(b, "column", Path("a", "b", "[1]", "c"), Unquote(true))
//
func Unquote(unquote bool) Option {
return func(p *PathOptions) {
p.Unquote = unquote
}
}
// Cast indicates that the result value should be casted to the given type.
//
// WritePath(b, "column", Path("a", "b", "[1]", "c"), Cast("int"))
//
func Cast(typ string) Option {
return func(p *PathOptions) {
p.Cast = typ
}
}
// PathOptions holds the options for accessing a JSON value from an identifier.
type PathOptions struct {
Ident string
Path []string
Cast string
Unquote bool
}
// WriteTo writes the JSON path to the sql.Builder.
func (p *PathOptions) WriteTo(b *sql.Builder) {
switch {
case len(p.Path) == 0:
b.Ident(p.Ident)
case b.Dialect() == dialect.Postgres:
if p.Cast != "" {
b.WriteString("CAST(")
defer b.WriteString(" AS " + p.Cast + ")")
}
b.Ident(p.Ident)
for i, s := range p.Path {
b.WriteString("->")
if p.Unquote && i == len(p.Path)-1 {
b.WriteString(">")
}
if idx, ok := isJSONIdx(s); ok {
b.WriteString(idx)
} else {
b.WriteString("'" + s + "'")
}
}
default:
if p.Unquote && b.Dialect() == dialect.MySQL {
b.WriteString("JSON_UNQUOTE(")
defer b.WriteByte(')')
}
b.WriteString("JSON_EXTRACT(")
b.Ident(p.Ident).Comma()
b.WriteString(`"$`)
for _, p := range p.Path {
if _, ok := isJSONIdx(p); ok {
b.WriteString(p)
} else {
b.WriteString("." + p)
}
}
b.WriteString(`")`)
}
}
// ParsePath parses the "dotpath" for the DotPath option.
//
// "a.b" => ["a", "b"]
// "a[1][2]" => ["a", "[1]", "[2]"]
// "a.\"b.c\" => ["a", "\"b.c\""]
//
func ParsePath(dotpath string) ([]string, error) {
var (
i, p int
path []string
)
for i < len(dotpath) {
switch r := dotpath[i]; {
case r == '"':
if i == len(dotpath)-1 {
return nil, fmt.Errorf("unexpected quote")
}
idx := strings.IndexRune(dotpath[i+1:], '"')
if idx == -1 || idx == 0 {
return nil, fmt.Errorf("unbalanced quote")
}
i += idx + 2
case r == '[':
if p != i {
path = append(path, dotpath[p:i])
}
p = i
if i == len(dotpath)-1 {
return nil, fmt.Errorf("unexpected bracket")
}
idx := strings.IndexRune(dotpath[i:], ']')
if idx == -1 || idx == 1 {
return nil, fmt.Errorf("unbalanced bracket")
}
if !isNumber(dotpath[i+1 : i+idx]) {
return nil, fmt.Errorf("invalid index %q", dotpath[i:i+idx+1])
}
i += idx + 1
case r == '.' || r == ']':
if p != i {
path = append(path, dotpath[p:i])
}
i++
p = i
default:
i++
}
}
if p != i {
path = append(path, dotpath[p:i])
}
return path, nil
}
// isJSONIdx reports whether the string represents a JSON index.
func isJSONIdx(s string) (string, bool) {
if len(s) > 2 && s[0] == '[' && s[len(s)-1] == ']' && isNumber(s[1:len(s)-1]) {
return s[1 : len(s)-1], true
}
return "", false
}
// isNumber reports whether the string is a number (category N).
func isNumber(s string) bool {
for _, r := range s {
if !unicode.IsNumber(r) {
return false
}
}
return true
}

View File

@@ -0,0 +1,158 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package sqljson_test
import (
"strconv"
"testing"
"github.com/facebook/ent/dialect"
"github.com/facebook/ent/dialect/sql"
"github.com/facebook/ent/dialect/sql/sqljson"
"github.com/stretchr/testify/require"
)
func TestWritePath(t *testing.T) {
tests := []struct {
input sql.Querier
wantQuery string
wantArgs []interface{}
}{
{
input: sql.Dialect(dialect.Postgres).
Select("*").
From(sql.Table("users")).
Where(sqljson.ValueEQ("a", 1, sqljson.Path("b", "c", "[1]", "d"), sqljson.Cast("int"))),
wantQuery: `SELECT * FROM "users" WHERE CAST("a"->'b'->'c'->1->'d' AS int) = $1`,
wantArgs: []interface{}{1},
},
{
input: sql.Dialect(dialect.MySQL).
Select("*").
From(sql.Table("users")).
Where(sqljson.ValueEQ("a", "a", sqljson.DotPath("b.c[1].d"))),
wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, \"$.b.c[1].d\") = ?",
wantArgs: []interface{}{"a"},
},
{
input: sql.Dialect(dialect.MySQL).
Select("*").
From(sql.Table("users")).
Where(sqljson.ValueEQ("a", "a", sqljson.DotPath("b.\"c[1]\".d[1][2].e"))),
wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, \"$.b.\"c[1]\".d[1][2].e\") = ?",
wantArgs: []interface{}{"a"},
},
{
input: sql.Select("*").
From(sql.Table("test")).
Where(sqljson.HasKey("j", sqljson.DotPath("a.*.c"))),
wantQuery: "SELECT * FROM `test` WHERE JSON_EXTRACT(`j`, \"$.a.*.c\") IS NOT NULL",
},
{
input: sql.Dialect(dialect.Postgres).
Select("*").
From(sql.Table("test")).
Where(sqljson.HasKey("j", sqljson.DotPath("a.b.c"))),
wantQuery: `SELECT * FROM "test" WHERE "j"->'a'->'b'->'c' IS NOT NULL`,
},
{
input: sql.Dialect(dialect.Postgres).
Select("*").
From(sql.Table("test")).
Where(sql.And(
sql.EQ("e", 10),
sqljson.ValueEQ("a", 1, sqljson.DotPath("b.c")),
)),
wantQuery: `SELECT * FROM "test" WHERE "e" = $1 AND "a"->'b'->'c' = $2`,
wantArgs: []interface{}{10, 1},
},
{
input: sql.Dialect(dialect.MySQL).
Select("*").
From(sql.Table("users")).
Where(sqljson.ValueEQ("a", "a", sqljson.Path("b", "c", "[1]", "d"), sqljson.Unquote(true))),
wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, \"$.b.c[1].d\")) = ?",
wantArgs: []interface{}{"a"},
},
{
input: sql.Dialect(dialect.Postgres).
Select("*").
From(sql.Table("users")).
Where(sqljson.ValueEQ("a", "a", sqljson.Path("b", "c", "[1]", "d"), sqljson.Unquote(true))),
wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' = $1`,
wantArgs: []interface{}{"a"},
},
{
input: sql.Dialect(dialect.Postgres).
Select("*").
From(sql.Table("users")).
Where(sqljson.ValueEQ("a", 1, sqljson.Path("b", "c", "[1]", "d"), sqljson.Cast("int"))),
wantQuery: `SELECT * FROM "users" WHERE CAST("a"->'b'->'c'->1->'d' AS int) = $1`,
wantArgs: []interface{}{1},
},
}
for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
query, args := tt.input.Query()
require.Equal(t, tt.wantQuery, query)
require.Equal(t, tt.wantArgs, args)
})
}
}
func TestParsePath(t *testing.T) {
tests := []struct {
input string
wantPath []string
wantErr bool
}{
{
input: "a.b.c",
wantPath: []string{"a", "b", "c"},
},
{
input: "a[1][2]",
wantPath: []string{"a", "[1]", "[2]"},
},
{
input: "a[1][2].b",
wantPath: []string{"a", "[1]", "[2]", "b"},
},
{
input: `a."b.c[0]"`,
wantPath: []string{"a", `"b.c[0]"`},
},
{
input: `a."b.c[0]".d`,
wantPath: []string{"a", `"b.c[0]"`, "d"},
},
{
input: `...`,
},
{
input: `.a.b.`,
wantPath: []string{"a", "b"},
},
{
input: `a."`,
wantErr: true,
},
{
input: `a[`,
wantErr: true,
},
{
input: `a[a]`,
wantErr: true,
},
}
for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
path, err := sqljson.ParsePath(tt.input)
require.Equal(t, tt.wantPath, path)
require.Equal(t, tt.wantErr, err != nil)
})
}
}