From 94eee235b4003d7d657f470b6d7f1051fd92aa8f Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Mon, 24 Aug 2020 14:41:12 +0300 Subject: [PATCH] dialect/sql: add json-path for sql builders --- dialect/sql/builder.go | 172 ++++++++++++++++++++++++++++++------ dialect/sql/builder_test.go | 36 ++++++++ 2 files changed, 182 insertions(+), 26 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 55d96cd76..711964d65 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -10,6 +10,7 @@ import ( "fmt" "strconv" "strings" + "unicode" "github.com/facebook/ent/dialect" ) @@ -697,7 +698,7 @@ func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder { // Add adds a numeric value to the given column. func (u *UpdateBuilder) Add(column string, v interface{}) *UpdateBuilder { u.columns = append(u.columns, column) - u.values = append(u.values, P().append(func(b *Builder) { + u.values = append(u.values, P().Append(func(b *Builder) { b.WriteString("COALESCE") b.Nested(func(b *Builder) { b.Ident(column).Comma().Arg(0) @@ -833,7 +834,7 @@ func P() *Predicate { return &Predicate{} } // func Or(preds ...*Predicate) *Predicate { p := P() - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { p.mayWrap(preds, b, "OR") }) } @@ -848,7 +849,7 @@ func False() *Predicate { // False appends FALSE to the predicate. func (p *Predicate) False() *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.WriteString("FALSE") }) } @@ -858,7 +859,7 @@ func (p *Predicate) False() *Predicate { // Not(Or(EQ("name", "foo"), EQ("name", "bar"))) // func Not(pred *Predicate) *Predicate { - return P().Not().append(func(b *Builder) { + return P().Not().Append(func(b *Builder) { b.Nested(func(b *Builder) { b.Join(pred) }) @@ -867,7 +868,7 @@ func Not(pred *Predicate) *Predicate { // Not appends NOT to the predicate. func (p *Predicate) Not() *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.WriteString("NOT ") }) } @@ -875,7 +876,7 @@ func (p *Predicate) Not() *Predicate { // And combines all given predicates with AND between them. func And(preds ...*Predicate) *Predicate { p := P() - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { p.mayWrap(preds, b, "AND") }) } @@ -887,7 +888,7 @@ func EQ(col string, value interface{}) *Predicate { // EQ appends a "=" predicate. func (p *Predicate) EQ(col string, arg interface{}) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.Ident(col).WriteString(" = ") b.Arg(arg) }) @@ -900,7 +901,7 @@ func NEQ(col string, value interface{}) *Predicate { // NEQ appends a "<>" predicate. func (p *Predicate) NEQ(col string, arg interface{}) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.Ident(col).WriteString(" <> ") b.Arg(arg) }) @@ -913,7 +914,7 @@ func LT(col string, value interface{}) *Predicate { // LT appends a "<" predicate. func (p *Predicate) LT(col string, arg interface{}) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.Ident(col).WriteString(" < ") b.Arg(arg) }) @@ -926,7 +927,7 @@ func LTE(col string, value interface{}) *Predicate { // LTE appends a "<=" predicate. func (p *Predicate) LTE(col string, arg interface{}) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.Ident(col).WriteString(" <= ") b.Arg(arg) }) @@ -939,7 +940,7 @@ func GT(col string, value interface{}) *Predicate { // GT appends a ">" predicate. func (p *Predicate) GT(col string, arg interface{}) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.Ident(col).WriteString(" > ") b.Arg(arg) }) @@ -952,7 +953,7 @@ func GTE(col string, value interface{}) *Predicate { // GTE appends a ">=" predicate. func (p *Predicate) GTE(col string, arg interface{}) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.Ident(col).WriteString(" >= ") b.Arg(arg) }) @@ -965,7 +966,7 @@ func NotNull(col string) *Predicate { // NotNull appends the `IS NOT NULL` predicate. func (p *Predicate) NotNull(col string) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.Ident(col).WriteString(" IS NOT NULL") }) } @@ -977,7 +978,7 @@ func IsNull(col string) *Predicate { // IsNull appends the `IS NULL` predicate. func (p *Predicate) IsNull(col string) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.Ident(col).WriteString(" IS NULL") }) } @@ -992,7 +993,7 @@ func (p *Predicate) In(col string, args ...interface{}) *Predicate { if len(args) == 0 { return p } - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.Ident(col).WriteString(" IN ") b.Nested(func(b *Builder) { if s, ok := args[0].(*Selector); ok { @@ -1039,7 +1040,7 @@ func NotIn(col string, args ...interface{}) *Predicate { // NotIn appends the `Not IN` predicate. func (p *Predicate) NotIn(col string, args ...interface{}) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.Ident(col).WriteString(" NOT IN ") b.Nested(func(b *Builder) { b.Args(args...) @@ -1054,7 +1055,7 @@ func Like(col, pattern string) *Predicate { // Like appends the `LIKE` predicate. func (p *Predicate) Like(col, pattern string) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.Ident(col).WriteString(" LIKE ") b.Arg(pattern) }) @@ -1083,7 +1084,7 @@ func EqualFold(col, sub string) *Predicate { return (&Predicate{}).EqualFold(col // EqualFold is a helper predicate that applies the "=" predicate with case-folding. func (p *Predicate) EqualFold(col, sub string) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { f := &Func{} f.SetDialect(b.dialect) b.Ident(f.Lower(col)).WriteString(" = ") @@ -1104,7 +1105,7 @@ func ContainsFold(col, sub string) *Predicate { return (&Predicate{}).ContainsFo // ContainsFold is a helper predicate that applies the LIKE predicate with case-folding. func (p *Predicate) ContainsFold(col, sub string) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { f := &Func{} f.SetDialect(b.dialect) switch b.dialect { @@ -1130,7 +1131,7 @@ func CompositeLT(columns []string, args ...interface{}) *Predicate { } func (p *Predicate) compositeP(operator string, columns []string, args ...interface{}) *Predicate { - return p.append(func(b *Builder) { + return p.Append(func(b *Builder) { b.Nested(func(nb *Builder) { nb.IdentComma(columns...) }) @@ -1153,6 +1154,13 @@ func (p *Predicate) CompositeLT(columns []string, args ...interface{}) *Predicat return p.compositeP(operator, columns, args...) } +// Append appends a new function to the predicate callbacks. +// The callback list are executed on call to Query. +func (p *Predicate) Append(f func(*Builder)) *Predicate { + p.fns = append(p.fns, f) + return p +} + // Query returns query representation of a predicate. func (p *Predicate) Query() (string, []interface{}) { for _, f := range p.fns { @@ -1169,11 +1177,6 @@ func (p *Predicate) clone() *Predicate { return &Predicate{fns: append([]func(*Builder){}, p.fns...)} } -func (p *Predicate) append(f func(*Builder)) *Predicate { - p.fns = append(p.fns, f) - return p -} - func (p *Predicate) mayWrap(preds []*Predicate, b *Builder, op string) { switch n := len(preds); { case n == 1: @@ -1873,6 +1876,100 @@ func (b *Builder) IdentComma(s ...string) *Builder { return b } +// JSONOption allows for calling database JSON paths with functional options. +type JSONOption func(*JSONPath) + +// Path sets the path to the JSON value of a column. +// +// b.JSONPath("column", Path("a", "b", "[1]", "c")) +// +func Path(path ...string) JSONOption { + return func(p *JSONPath) { + p.path = path + } +} + +// Unquote indicates that the result value should be unquoted. +// +// b.JSONPath("column", Path("a", "b", "[1]", "c"), Unquote(true)) +// +func Unquote(unquote bool) JSONOption { + return func(p *JSONPath) { + p.unquote = unquote + } +} + +// Cast indicates that the result value should be casted to the given type. +// +// b.JSONPath("column", Path("a", "b", "[1]", "c"), Cast("int")) +// +func Cast(typ string) JSONOption { + return func(p *JSONPath) { + p.cast = typ + } +} + +// JSONPath represents a path to a JSON value. +type JSONPath struct { + ident string + path []string + cast string + unquote bool +} + +// writeTo writes the JSON path to the builder. +func (p *JSONPath) writeTo(b *Builder) { + switch { + case len(p.path) == 0: + b.Ident(p.ident) + case b.postgres(): + if p.cast != "" { + b.WriteString("CAST(") + defer b.WriteString(" AS " + p.cast + ")") + } + b.Ident(p.ident) + for i, s := range p.path { + b.WriteString("->") + if p.unquote && i == len(p.path)-1 { + b.WriteString(">") + } + if idx, ok := isJSONIdx(s); ok { + b.WriteString(idx) + } else { + b.WriteString("'" + s + "'") + } + } + default: + if p.unquote && b.mysql() { + b.WriteString("JSON_UNQUOTE(") + defer b.WriteByte(')') + } + b.WriteString("JSON_EXTRACT(") + b.Ident(p.ident).Comma() + b.WriteString(`"$`) + for _, p := range p.path { + if _, ok := isJSONIdx(p); ok { + b.WriteString(p) + } else { + b.WriteString("." + p) + } + } + b.WriteString(`")`) + } +} + +// JSONPath appends the given JSON paths to the builder. +// +// b.JSONPath("column", Path("a", "b", "[1]", "c"), Cast("int")) +// +func (b *Builder) JSONPath(ident string, opts ...JSONOption) { + path := &JSONPath{ident: ident} + for i := range opts { + opts[i](path) + } + path.writeTo(b) +} + // Arg appends an input argument to the builder. func (b *Builder) Arg(a interface{}) *Builder { if r, ok := a.(*raw); ok { @@ -1995,11 +2092,16 @@ func (b Builder) clone() Builder { return c } -// postgres reports if the builder dialect is postgres. +// postgres reports if the builder dialect is PostgreSQL. func (b Builder) postgres() bool { return b.Dialect() == dialect.Postgres } +// mysql reports if the builder dialect is MySQL. +func (b Builder) mysql() bool { + return b.Dialect() == dialect.MySQL +} + // fromIdent sets the builder dialect from the identifier format. func (b *Builder) fromIdent(ident string) { if strings.Contains(ident, `"`) { @@ -2185,6 +2287,24 @@ func (d *DialectBuilder) DropIndex(name string) *DropIndexBuilder { return b } +// isNumber reports whether the string is a number (category N). +func isNumber(s string) bool { + for _, r := range s { + if !unicode.IsNumber(r) { + return false + } + } + return true +} + +// isJSONIdx reports whether the string represents a JSON index. +func isJSONIdx(s string) (string, bool) { + if len(s) > 2 && s[0] == '[' && s[len(s)-1] == ']' && isNumber(s[1:len(s)-1]) { + return s[1 : len(s)-1], true + } + return "", false +} + func isFunc(s string) bool { return strings.Contains(s, "(") && strings.Contains(s, ")") } diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 68aa217db..b2ced1713 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1213,6 +1213,42 @@ WHERE OR ("f" <> $10 AND "g" <> $11)`), wantArgs: []interface{}{1, 2, 3, 2, 4, 5, "a", "b", "c", "f", "g"}, }, + { + input: Dialect(dialect.MySQL). + Select("*"). + From(Table("users")). + Where(P().Append(func(b *Builder) { + b.JSONPath("a", Path("b", "c", "[1]", "d"), Unquote(true)) + b.WriteString(" = ") + b.Arg("a") + })), + wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, \"$.b.c[1].d\")) = ?", + wantArgs: []interface{}{"a"}, + }, + { + input: Dialect(dialect.Postgres). + Select("*"). + From(Table("users")). + Where(P().Append(func(b *Builder) { + b.JSONPath("a", Path("b", "c", "[1]", "d"), Unquote(true)) + b.WriteString(" = ") + b.Arg("a") + })), + wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' = $1`, + wantArgs: []interface{}{"a"}, + }, + { + input: Dialect(dialect.Postgres). + Select("*"). + From(Table("users")). + Where(P().Append(func(b *Builder) { + b.JSONPath("a", Path("b", "c", "[1]", "d"), Cast("int")) + b.WriteString(" = ") + b.Arg(1) + })), + wantQuery: `SELECT * FROM "users" WHERE CAST("a"->'b'->'c'->1->'d' AS int) = $1`, + wantArgs: []interface{}{1}, + }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) {