dialect/sql/sqljson: initial work for json_contains predicate (#886)

This commit is contained in:
Ariel Mashraki
2020-10-26 14:11:22 +02:00
committed by GitHub
parent 989deeb951
commit 9ea996593b
4 changed files with 128 additions and 16 deletions

View File

@@ -106,6 +106,42 @@ func ValueLTE(column string, arg interface{}, opts ...Option) *sql.Predicate {
})
}
// ValueContains return a predicate for checking that a JSON
// value (returned by the path) contains the given argument.
//
// sqljson.ValueContains("a", 1, sqljson.Path("b"))
//
func ValueContains(column string, arg interface{}, opts ...Option) *sql.Predicate {
return sql.P(func(b *sql.Builder) {
path := &PathOptions{Ident: column}
for i := range opts {
opts[i](path)
}
switch b.Dialect() {
case dialect.MySQL:
b.WriteString("JSON_CONTAINS").Nested(func(b *sql.Builder) {
b.Ident(column).Comma()
b.Arg(marshal(arg)).Comma()
path.mysqlPath(b)
})
b.WriteOp(sql.OpEQ).Arg(1)
case dialect.SQLite:
b.WriteString("EXISTS").Nested(func(b *sql.Builder) {
b.WriteString("SELECT * FROM JSON_EACH").Nested(func(b *sql.Builder) {
b.Ident(column).Comma()
path.mysqlPath(b)
})
b.WriteString(" WHERE ").Ident("value").WriteOp(sql.OpEQ).Arg(arg)
})
case dialect.Postgres:
opts, arg = normalizePG(b, arg, opts)
path.Cast = "jsonb"
path.value(b)
b.WriteString(" @> ").Arg(marshal(arg))
}
})
}
// LenEQ return a predicate for checking that an array length
// of a JSON (returned by the path) is equal to the given argument.
//
@@ -278,7 +314,7 @@ func (p *PathOptions) value(b *sql.Builder) {
b.WriteString("JSON_UNQUOTE(")
defer b.WriteByte(')')
}
p.mysqlPath("JSON_EXTRACT", b)
p.mysqlFunc("JSON_EXTRACT", b)
}
}
@@ -290,17 +326,23 @@ func (p *PathOptions) length(b *sql.Builder) {
p.pgPath(b)
b.WriteByte(')')
case b.Dialect() == dialect.MySQL:
p.mysqlPath("JSON_LENGTH", b)
p.mysqlFunc("JSON_LENGTH", b)
default:
p.mysqlPath("JSON_ARRAY_LENGTH", b)
p.mysqlFunc("JSON_ARRAY_LENGTH", b)
}
}
// mysqlPath writes the JSON path in MySQL format for the
// mysqlFunc writes the JSON path in MySQL format for the
// the given function. `JSON_EXTRACT("a", '$.b.c')`.
func (p *PathOptions) mysqlPath(fn string, b *sql.Builder) {
func (p *PathOptions) mysqlFunc(fn string, b *sql.Builder) {
b.WriteString(fn).WriteByte('(')
b.Ident(p.Ident).Comma()
p.mysqlPath(b)
b.WriteByte(')')
}
// mysqlPath writes the JSON path in MySQL (or SQLite) format.
func (p *PathOptions) mysqlPath(b *sql.Builder) {
b.WriteString(`"$`)
for _, p := range p.Path {
if _, ok := isJSONIdx(p); ok {
@@ -309,7 +351,7 @@ func (p *PathOptions) mysqlPath(fn string, b *sql.Builder) {
b.WriteString("." + p)
}
}
b.WriteString(`")`)
b.WriteByte('"')
}
// pgPath writes the JSON path in Postgres format `"a"->'b'->>'c'`.
@@ -398,9 +440,7 @@ func normalizePG(b *sql.Builder, arg interface{}, opts []Option) ([]Option, inte
case int8, int16, int32, int64, int, uint8, uint16, uint32, uint64:
base = append(base, Cast("int"))
default: // convert unknown types to text.
if buf, err := json.Marshal(arg); err == nil {
arg = string(buf)
}
arg = marshal(arg)
}
return append(base, opts...), arg
}
@@ -422,3 +462,11 @@ func isNumber(s string) bool {
}
return true
}
// marshal stringifies the given argument to a valid JSON document.
func marshal(arg interface{}) interface{} {
if buf, err := json.Marshal(arg); err == nil {
arg = string(buf)
}
return arg
}