dialect/sql: add json-path for sql builders

This commit is contained in:
Ariel Mashraki
2020-08-24 14:41:12 +03:00
parent c907c8bbbc
commit 94eee235b4
2 changed files with 182 additions and 26 deletions

View File

@@ -10,6 +10,7 @@ import (
"fmt"
"strconv"
"strings"
"unicode"
"github.com/facebook/ent/dialect"
)
@@ -697,7 +698,7 @@ func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder {
// Add adds a numeric value to the given column.
func (u *UpdateBuilder) Add(column string, v interface{}) *UpdateBuilder {
u.columns = append(u.columns, column)
u.values = append(u.values, P().append(func(b *Builder) {
u.values = append(u.values, P().Append(func(b *Builder) {
b.WriteString("COALESCE")
b.Nested(func(b *Builder) {
b.Ident(column).Comma().Arg(0)
@@ -833,7 +834,7 @@ func P() *Predicate { return &Predicate{} }
//
func Or(preds ...*Predicate) *Predicate {
p := P()
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
p.mayWrap(preds, b, "OR")
})
}
@@ -848,7 +849,7 @@ func False() *Predicate {
// False appends FALSE to the predicate.
func (p *Predicate) False() *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.WriteString("FALSE")
})
}
@@ -858,7 +859,7 @@ func (p *Predicate) False() *Predicate {
// Not(Or(EQ("name", "foo"), EQ("name", "bar")))
//
func Not(pred *Predicate) *Predicate {
return P().Not().append(func(b *Builder) {
return P().Not().Append(func(b *Builder) {
b.Nested(func(b *Builder) {
b.Join(pred)
})
@@ -867,7 +868,7 @@ func Not(pred *Predicate) *Predicate {
// Not appends NOT to the predicate.
func (p *Predicate) Not() *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.WriteString("NOT ")
})
}
@@ -875,7 +876,7 @@ func (p *Predicate) Not() *Predicate {
// And combines all given predicates with AND between them.
func And(preds ...*Predicate) *Predicate {
p := P()
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
p.mayWrap(preds, b, "AND")
})
}
@@ -887,7 +888,7 @@ func EQ(col string, value interface{}) *Predicate {
// EQ appends a "=" predicate.
func (p *Predicate) EQ(col string, arg interface{}) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.Ident(col).WriteString(" = ")
b.Arg(arg)
})
@@ -900,7 +901,7 @@ func NEQ(col string, value interface{}) *Predicate {
// NEQ appends a "<>" predicate.
func (p *Predicate) NEQ(col string, arg interface{}) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.Ident(col).WriteString(" <> ")
b.Arg(arg)
})
@@ -913,7 +914,7 @@ func LT(col string, value interface{}) *Predicate {
// LT appends a "<" predicate.
func (p *Predicate) LT(col string, arg interface{}) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.Ident(col).WriteString(" < ")
b.Arg(arg)
})
@@ -926,7 +927,7 @@ func LTE(col string, value interface{}) *Predicate {
// LTE appends a "<=" predicate.
func (p *Predicate) LTE(col string, arg interface{}) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.Ident(col).WriteString(" <= ")
b.Arg(arg)
})
@@ -939,7 +940,7 @@ func GT(col string, value interface{}) *Predicate {
// GT appends a ">" predicate.
func (p *Predicate) GT(col string, arg interface{}) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.Ident(col).WriteString(" > ")
b.Arg(arg)
})
@@ -952,7 +953,7 @@ func GTE(col string, value interface{}) *Predicate {
// GTE appends a ">=" predicate.
func (p *Predicate) GTE(col string, arg interface{}) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.Ident(col).WriteString(" >= ")
b.Arg(arg)
})
@@ -965,7 +966,7 @@ func NotNull(col string) *Predicate {
// NotNull appends the `IS NOT NULL` predicate.
func (p *Predicate) NotNull(col string) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.Ident(col).WriteString(" IS NOT NULL")
})
}
@@ -977,7 +978,7 @@ func IsNull(col string) *Predicate {
// IsNull appends the `IS NULL` predicate.
func (p *Predicate) IsNull(col string) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.Ident(col).WriteString(" IS NULL")
})
}
@@ -992,7 +993,7 @@ func (p *Predicate) In(col string, args ...interface{}) *Predicate {
if len(args) == 0 {
return p
}
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.Ident(col).WriteString(" IN ")
b.Nested(func(b *Builder) {
if s, ok := args[0].(*Selector); ok {
@@ -1039,7 +1040,7 @@ func NotIn(col string, args ...interface{}) *Predicate {
// NotIn appends the `Not IN` predicate.
func (p *Predicate) NotIn(col string, args ...interface{}) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.Ident(col).WriteString(" NOT IN ")
b.Nested(func(b *Builder) {
b.Args(args...)
@@ -1054,7 +1055,7 @@ func Like(col, pattern string) *Predicate {
// Like appends the `LIKE` predicate.
func (p *Predicate) Like(col, pattern string) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.Ident(col).WriteString(" LIKE ")
b.Arg(pattern)
})
@@ -1083,7 +1084,7 @@ func EqualFold(col, sub string) *Predicate { return (&Predicate{}).EqualFold(col
// EqualFold is a helper predicate that applies the "=" predicate with case-folding.
func (p *Predicate) EqualFold(col, sub string) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
f := &Func{}
f.SetDialect(b.dialect)
b.Ident(f.Lower(col)).WriteString(" = ")
@@ -1104,7 +1105,7 @@ func ContainsFold(col, sub string) *Predicate { return (&Predicate{}).ContainsFo
// ContainsFold is a helper predicate that applies the LIKE predicate with case-folding.
func (p *Predicate) ContainsFold(col, sub string) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
f := &Func{}
f.SetDialect(b.dialect)
switch b.dialect {
@@ -1130,7 +1131,7 @@ func CompositeLT(columns []string, args ...interface{}) *Predicate {
}
func (p *Predicate) compositeP(operator string, columns []string, args ...interface{}) *Predicate {
return p.append(func(b *Builder) {
return p.Append(func(b *Builder) {
b.Nested(func(nb *Builder) {
nb.IdentComma(columns...)
})
@@ -1153,6 +1154,13 @@ func (p *Predicate) CompositeLT(columns []string, args ...interface{}) *Predicat
return p.compositeP(operator, columns, args...)
}
// Append appends a new function to the predicate callbacks.
// The callback list are executed on call to Query.
func (p *Predicate) Append(f func(*Builder)) *Predicate {
p.fns = append(p.fns, f)
return p
}
// Query returns query representation of a predicate.
func (p *Predicate) Query() (string, []interface{}) {
for _, f := range p.fns {
@@ -1169,11 +1177,6 @@ func (p *Predicate) clone() *Predicate {
return &Predicate{fns: append([]func(*Builder){}, p.fns...)}
}
func (p *Predicate) append(f func(*Builder)) *Predicate {
p.fns = append(p.fns, f)
return p
}
func (p *Predicate) mayWrap(preds []*Predicate, b *Builder, op string) {
switch n := len(preds); {
case n == 1:
@@ -1873,6 +1876,100 @@ func (b *Builder) IdentComma(s ...string) *Builder {
return b
}
// JSONOption allows for calling database JSON paths with functional options.
type JSONOption func(*JSONPath)
// Path sets the path to the JSON value of a column.
//
// b.JSONPath("column", Path("a", "b", "[1]", "c"))
//
func Path(path ...string) JSONOption {
return func(p *JSONPath) {
p.path = path
}
}
// Unquote indicates that the result value should be unquoted.
//
// b.JSONPath("column", Path("a", "b", "[1]", "c"), Unquote(true))
//
func Unquote(unquote bool) JSONOption {
return func(p *JSONPath) {
p.unquote = unquote
}
}
// Cast indicates that the result value should be casted to the given type.
//
// b.JSONPath("column", Path("a", "b", "[1]", "c"), Cast("int"))
//
func Cast(typ string) JSONOption {
return func(p *JSONPath) {
p.cast = typ
}
}
// JSONPath represents a path to a JSON value.
type JSONPath struct {
ident string
path []string
cast string
unquote bool
}
// writeTo writes the JSON path to the builder.
func (p *JSONPath) writeTo(b *Builder) {
switch {
case len(p.path) == 0:
b.Ident(p.ident)
case b.postgres():
if p.cast != "" {
b.WriteString("CAST(")
defer b.WriteString(" AS " + p.cast + ")")
}
b.Ident(p.ident)
for i, s := range p.path {
b.WriteString("->")
if p.unquote && i == len(p.path)-1 {
b.WriteString(">")
}
if idx, ok := isJSONIdx(s); ok {
b.WriteString(idx)
} else {
b.WriteString("'" + s + "'")
}
}
default:
if p.unquote && b.mysql() {
b.WriteString("JSON_UNQUOTE(")
defer b.WriteByte(')')
}
b.WriteString("JSON_EXTRACT(")
b.Ident(p.ident).Comma()
b.WriteString(`"$`)
for _, p := range p.path {
if _, ok := isJSONIdx(p); ok {
b.WriteString(p)
} else {
b.WriteString("." + p)
}
}
b.WriteString(`")`)
}
}
// JSONPath appends the given JSON paths to the builder.
//
// b.JSONPath("column", Path("a", "b", "[1]", "c"), Cast("int"))
//
func (b *Builder) JSONPath(ident string, opts ...JSONOption) {
path := &JSONPath{ident: ident}
for i := range opts {
opts[i](path)
}
path.writeTo(b)
}
// Arg appends an input argument to the builder.
func (b *Builder) Arg(a interface{}) *Builder {
if r, ok := a.(*raw); ok {
@@ -1995,11 +2092,16 @@ func (b Builder) clone() Builder {
return c
}
// postgres reports if the builder dialect is postgres.
// postgres reports if the builder dialect is PostgreSQL.
func (b Builder) postgres() bool {
return b.Dialect() == dialect.Postgres
}
// mysql reports if the builder dialect is MySQL.
func (b Builder) mysql() bool {
return b.Dialect() == dialect.MySQL
}
// fromIdent sets the builder dialect from the identifier format.
func (b *Builder) fromIdent(ident string) {
if strings.Contains(ident, `"`) {
@@ -2185,6 +2287,24 @@ func (d *DialectBuilder) DropIndex(name string) *DropIndexBuilder {
return b
}
// isNumber reports whether the string is a number (category N).
func isNumber(s string) bool {
for _, r := range s {
if !unicode.IsNumber(r) {
return false
}
}
return true
}
// isJSONIdx reports whether the string represents a JSON index.
func isJSONIdx(s string) (string, bool) {
if len(s) > 2 && s[0] == '[' && s[len(s)-1] == ']' && isNumber(s[1:len(s)-1]) {
return s[1 : len(s)-1], true
}
return "", false
}
func isFunc(s string) bool {
return strings.Contains(s, "(") && strings.Contains(s, ")")
}

View File

@@ -1213,6 +1213,42 @@ WHERE
OR ("f" <> $10 AND "g" <> $11)`),
wantArgs: []interface{}{1, 2, 3, 2, 4, 5, "a", "b", "c", "f", "g"},
},
{
input: Dialect(dialect.MySQL).
Select("*").
From(Table("users")).
Where(P().Append(func(b *Builder) {
b.JSONPath("a", Path("b", "c", "[1]", "d"), Unquote(true))
b.WriteString(" = ")
b.Arg("a")
})),
wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, \"$.b.c[1].d\")) = ?",
wantArgs: []interface{}{"a"},
},
{
input: Dialect(dialect.Postgres).
Select("*").
From(Table("users")).
Where(P().Append(func(b *Builder) {
b.JSONPath("a", Path("b", "c", "[1]", "d"), Unquote(true))
b.WriteString(" = ")
b.Arg("a")
})),
wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' = $1`,
wantArgs: []interface{}{"a"},
},
{
input: Dialect(dialect.Postgres).
Select("*").
From(Table("users")).
Where(P().Append(func(b *Builder) {
b.JSONPath("a", Path("b", "c", "[1]", "d"), Cast("int"))
b.WriteString(" = ")
b.Arg(1)
})),
wantQuery: `SELECT * FROM "users" WHERE CAST("a"->'b'->'c'->1->'d' AS int) = $1`,
wantArgs: []interface{}{1},
},
}
for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {