mirror of
https://github.com/ent/ent.git
synced 2026-04-28 13:40:56 +03:00
json predicates (#760)
* dialect/sql/sqljson: cast postgres non-string values * entc/integration: test json predicates
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
package sqljson
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
@@ -32,6 +33,7 @@ func HasKey(column string, opts ...Option) *sql.Predicate {
|
||||
//
|
||||
func ValueEQ(column string, arg interface{}, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
WritePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpEQ).Arg(arg)
|
||||
})
|
||||
@@ -44,6 +46,7 @@ func ValueEQ(column string, arg interface{}, opts ...Option) *sql.Predicate {
|
||||
//
|
||||
func ValueNEQ(column string, arg interface{}, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
WritePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpNEQ).Arg(arg)
|
||||
})
|
||||
@@ -56,6 +59,7 @@ func ValueNEQ(column string, arg interface{}, opts ...Option) *sql.Predicate {
|
||||
//
|
||||
func ValueGT(column string, arg interface{}, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
WritePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpGT).Arg(arg)
|
||||
})
|
||||
@@ -69,6 +73,7 @@ func ValueGT(column string, arg interface{}, opts ...Option) *sql.Predicate {
|
||||
//
|
||||
func ValueGTE(column string, arg interface{}, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
WritePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpGTE).Arg(arg)
|
||||
})
|
||||
@@ -81,6 +86,7 @@ func ValueGTE(column string, arg interface{}, opts ...Option) *sql.Predicate {
|
||||
//
|
||||
func ValueLT(column string, arg interface{}, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
WritePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpLT).Arg(arg)
|
||||
})
|
||||
@@ -94,6 +100,7 @@ func ValueLT(column string, arg interface{}, opts ...Option) *sql.Predicate {
|
||||
//
|
||||
func ValueLTE(column string, arg interface{}, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
WritePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpLTE).Arg(arg)
|
||||
})
|
||||
@@ -172,8 +179,8 @@ func (p *PathOptions) WriteTo(b *sql.Builder) {
|
||||
b.Ident(p.Ident)
|
||||
case b.Dialect() == dialect.Postgres:
|
||||
if p.Cast != "" {
|
||||
b.WriteString("CAST(")
|
||||
defer b.WriteString(" AS " + p.Cast + ")")
|
||||
b.WriteByte('(')
|
||||
defer b.WriteString(")::" + p.Cast)
|
||||
}
|
||||
b.Ident(p.Ident)
|
||||
for i, s := range p.Path {
|
||||
@@ -260,6 +267,29 @@ func ParsePath(dotpath string) ([]string, error) {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// normalizePG adds cast option to the JSON path is the argument type is
|
||||
// not string, in order to avoid "missing type casts" error in Postgres.
|
||||
func normalizePG(b *sql.Builder, arg interface{}, opts []Option) ([]Option, interface{}) {
|
||||
if b.Dialect() != dialect.Postgres {
|
||||
return opts, arg
|
||||
}
|
||||
base := []Option{Unquote(true)}
|
||||
switch arg.(type) {
|
||||
case string:
|
||||
case bool:
|
||||
base = append(base, Cast("bool"))
|
||||
case float32, float64:
|
||||
base = append(base, Cast("float"))
|
||||
case int8, int16, int32, int64, int, uint8, uint16, uint32, uint64:
|
||||
base = append(base, Cast("int"))
|
||||
default: // convert unknown types to text.
|
||||
if buf, err := json.Marshal(arg); err == nil {
|
||||
arg = string(buf)
|
||||
}
|
||||
}
|
||||
return append(base, opts...), arg
|
||||
}
|
||||
|
||||
// isJSONIdx reports whether the string represents a JSON index.
|
||||
func isJSONIdx(s string) (string, bool) {
|
||||
if len(s) > 2 && s[0] == '[' && s[len(s)-1] == ']' && isNumber(s[1:len(s)-1]) {
|
||||
|
||||
Reference in New Issue
Block a user