mirror of
https://github.com/ent/ent.git
synced 2026-04-28 13:40:56 +03:00
sql/sqljson: add support for ValueIn/ValueNotIn (#2882)
This commit is contained in:
@@ -18,7 +18,6 @@ import (
|
||||
// exists and not NULL.
|
||||
//
|
||||
// sqljson.HasKey("column", sql.DotPath("a.b[2].c"))
|
||||
//
|
||||
func HasKey(column string, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
switch b.Dialect() {
|
||||
@@ -49,7 +48,6 @@ func HasKey(column string, opts ...Option) *sql.Predicate {
|
||||
// the JSON key exists, use sql.IsNull or sqljson.HasKey.
|
||||
//
|
||||
// sqljson.ValueIsNull("a", sqljson.Path("b"))
|
||||
//
|
||||
func ValueIsNull(column string, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
switch b.Dialect() {
|
||||
@@ -75,10 +73,9 @@ func ValueIsNull(column string, opts ...Option) *sql.Predicate {
|
||||
// (returned by the path) is equal to the given argument.
|
||||
//
|
||||
// sqljson.ValueEQ("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func ValueEQ(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
opts = normalizePG(b, arg, opts)
|
||||
ValuePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpEQ).Arg(arg)
|
||||
})
|
||||
@@ -88,10 +85,9 @@ func ValueEQ(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
// (returned by the path) is not equal to the given argument.
|
||||
//
|
||||
// sqljson.ValueNEQ("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func ValueNEQ(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
opts = normalizePG(b, arg, opts)
|
||||
ValuePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpNEQ).Arg(arg)
|
||||
})
|
||||
@@ -101,10 +97,9 @@ func ValueNEQ(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
// (returned by the path) is greater than the given argument.
|
||||
//
|
||||
// sqljson.ValueGT("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func ValueGT(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
opts = normalizePG(b, arg, opts)
|
||||
ValuePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpGT).Arg(arg)
|
||||
})
|
||||
@@ -115,10 +110,9 @@ func ValueGT(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
// argument.
|
||||
//
|
||||
// sqljson.ValueGTE("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func ValueGTE(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
opts = normalizePG(b, arg, opts)
|
||||
ValuePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpGTE).Arg(arg)
|
||||
})
|
||||
@@ -128,10 +122,9 @@ func ValueGTE(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
// (returned by the path) is less than the given argument.
|
||||
//
|
||||
// sqljson.ValueLT("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func ValueLT(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
opts = normalizePG(b, arg, opts)
|
||||
ValuePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpLT).Arg(arg)
|
||||
})
|
||||
@@ -142,10 +135,9 @@ func ValueLT(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
// argument.
|
||||
//
|
||||
// sqljson.ValueLTE("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func ValueLTE(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
opts = normalizePG(b, arg, opts)
|
||||
ValuePath(b, column, opts...)
|
||||
b.WriteOp(sql.OpLTE).Arg(arg)
|
||||
})
|
||||
@@ -155,7 +147,6 @@ func ValueLTE(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
// value (returned by the path) contains the given argument.
|
||||
//
|
||||
// sqljson.ValueContains("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func ValueContains(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
path := identPath(column, opts...)
|
||||
@@ -176,7 +167,7 @@ func ValueContains(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
b.WriteString(" WHERE ").Ident("value").WriteOp(sql.OpEQ).Arg(arg)
|
||||
})
|
||||
case dialect.Postgres:
|
||||
opts, arg = normalizePG(b, arg, opts)
|
||||
opts = normalizePG(b, arg, opts)
|
||||
path.Cast = "jsonb"
|
||||
path.value(b)
|
||||
b.WriteString(" @> ").Arg(marshal(arg))
|
||||
@@ -214,11 +205,49 @@ func StringContains(column string, sub string, opts ...Option) *sql.Predicate {
|
||||
})
|
||||
}
|
||||
|
||||
// ValueIn return a predicate for checking that a JSON value
|
||||
// (returned by the path) is IN the given arguments.
|
||||
//
|
||||
// sqljson.ValueIn("a", []any{1, 2, 3}, sqljson.Path("b"))
|
||||
func ValueIn(column string, args []any, opts ...Option) *sql.Predicate {
|
||||
return valueInOp(column, args, opts, sql.OpIn)
|
||||
}
|
||||
|
||||
// ValueNotIn return a predicate for checking that a JSON value
|
||||
// (returned by the path) is NOT IN the given arguments.
|
||||
//
|
||||
// sqljson.ValueNotIn("a", []any{1, 2, 3}, sqljson.Path("b"))
|
||||
func ValueNotIn(column string, args []any, opts ...Option) *sql.Predicate {
|
||||
if len(args) == 0 {
|
||||
return sql.NotIn(column)
|
||||
}
|
||||
return valueInOp(column, args, opts, sql.OpNotIn)
|
||||
}
|
||||
|
||||
func valueInOp(column string, args []any, opts []Option, op sql.Op) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
if allString(args) {
|
||||
opts = append(opts, Unquote(true))
|
||||
}
|
||||
if len(args) > 0 {
|
||||
opts = normalizePG(b, args[0], opts)
|
||||
}
|
||||
ValuePath(b, column, opts...)
|
||||
b.WriteOp(op)
|
||||
b.Nested(func(b *sql.Builder) {
|
||||
if s, ok := args[0].(*sql.Selector); ok {
|
||||
b.Join(s)
|
||||
} else {
|
||||
b.Args(args...)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// LenEQ return a predicate for checking that an array length
|
||||
// of a JSON (returned by the path) is equal to the given argument.
|
||||
//
|
||||
// sqljson.LenEQ("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func LenEQ(column string, size int, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
LenPath(b, column, opts...)
|
||||
@@ -230,7 +259,6 @@ func LenEQ(column string, size int, opts ...Option) *sql.Predicate {
|
||||
// of a JSON (returned by the path) is not equal to the given argument.
|
||||
//
|
||||
// sqljson.LenEQ("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func LenNEQ(column string, size int, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
LenPath(b, column, opts...)
|
||||
@@ -243,7 +271,6 @@ func LenNEQ(column string, size int, opts ...Option) *sql.Predicate {
|
||||
// argument.
|
||||
//
|
||||
// sqljson.LenGT("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func LenGT(column string, size int, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
LenPath(b, column, opts...)
|
||||
@@ -256,7 +283,6 @@ func LenGT(column string, size int, opts ...Option) *sql.Predicate {
|
||||
// the given argument.
|
||||
//
|
||||
// sqljson.LenGTE("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func LenGTE(column string, size int, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
LenPath(b, column, opts...)
|
||||
@@ -269,7 +295,6 @@ func LenGTE(column string, size int, opts ...Option) *sql.Predicate {
|
||||
// argument.
|
||||
//
|
||||
// sqljson.LenLT("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func LenLT(column string, size int, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
LenPath(b, column, opts...)
|
||||
@@ -282,7 +307,6 @@ func LenLT(column string, size int, opts ...Option) *sql.Predicate {
|
||||
// the given argument.
|
||||
//
|
||||
// sqljson.LenLTE("a", 1, sqljson.Path("b"))
|
||||
//
|
||||
func LenLTE(column string, size int, opts ...Option) *sql.Predicate {
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
LenPath(b, column, opts...)
|
||||
@@ -294,7 +318,6 @@ func LenLTE(column string, size int, opts ...Option) *sql.Predicate {
|
||||
// getting the value of a given JSON path.
|
||||
//
|
||||
// sqljson.ValuePath(b, Path("a", "b", "[1]", "c"), Cast("int"))
|
||||
//
|
||||
func ValuePath(b *sql.Builder, column string, opts ...Option) {
|
||||
path := identPath(column, opts...)
|
||||
path.value(b)
|
||||
@@ -304,7 +327,6 @@ func ValuePath(b *sql.Builder, column string, opts ...Option) {
|
||||
// getting the length of a given JSON path.
|
||||
//
|
||||
// sqljson.LenPath(b, Path("a", "b", "[1]", "c"))
|
||||
//
|
||||
func LenPath(b *sql.Builder, column string, opts ...Option) {
|
||||
path := identPath(column, opts...)
|
||||
path.length(b)
|
||||
@@ -316,7 +338,6 @@ type Option func(*PathOptions)
|
||||
// Path sets the path to the JSON value of a column.
|
||||
//
|
||||
// ValuePath(b, "column", Path("a", "b", "[1]", "c"))
|
||||
//
|
||||
func Path(path ...string) Option {
|
||||
return func(p *PathOptions) {
|
||||
p.Path = path
|
||||
@@ -339,7 +360,6 @@ func DotPath(dotpath string) Option {
|
||||
// Unquote indicates that the result value should be unquoted.
|
||||
//
|
||||
// ValuePath(b, "column", Path("a", "b", "[1]", "c"), Unquote(true))
|
||||
//
|
||||
func Unquote(unquote bool) Option {
|
||||
return func(p *PathOptions) {
|
||||
p.Unquote = unquote
|
||||
@@ -349,7 +369,6 @@ func Unquote(unquote bool) Option {
|
||||
// Cast indicates that the result value should be casted to the given type.
|
||||
//
|
||||
// ValuePath(b, "column", Path("a", "b", "[1]", "c"), Cast("int"))
|
||||
//
|
||||
func Cast(typ string) Option {
|
||||
return func(p *PathOptions) {
|
||||
p.Cast = typ
|
||||
@@ -450,7 +469,6 @@ func (p *PathOptions) pgPath(b *sql.Builder) {
|
||||
// "a.b" => ["a", "b"]
|
||||
// "a[1][2]" => ["a", "[1]", "[2]"]
|
||||
// "a.\"b.c\" => ["a", "\"b.c\""]
|
||||
//
|
||||
func ParsePath(dotpath string) ([]string, error) {
|
||||
var (
|
||||
i, p int
|
||||
@@ -501,9 +519,9 @@ func ParsePath(dotpath string) ([]string, error) {
|
||||
|
||||
// normalizePG adds cast option to the JSON path is the argument type is
|
||||
// not string, in order to avoid "missing type casts" error in Postgres.
|
||||
func normalizePG(b *sql.Builder, arg any, opts []Option) ([]Option, any) {
|
||||
func normalizePG(b *sql.Builder, arg any, opts []Option) []Option {
|
||||
if b.Dialect() != dialect.Postgres {
|
||||
return opts, arg
|
||||
return opts
|
||||
}
|
||||
base := []Option{Unquote(true)}
|
||||
switch arg.(type) {
|
||||
@@ -515,7 +533,7 @@ func normalizePG(b *sql.Builder, arg any, opts []Option) ([]Option, any) {
|
||||
case int8, int16, int32, int64, int, uint8, uint16, uint32, uint64:
|
||||
base = append(base, Cast("int"))
|
||||
}
|
||||
return append(base, opts...), arg
|
||||
return append(base, opts...)
|
||||
}
|
||||
|
||||
// isJSONIdx reports whether the string represents a JSON index.
|
||||
@@ -536,6 +554,16 @@ func isNumber(s string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// allString reports if the slice contains only strings.
|
||||
func allString(v []any) bool {
|
||||
for i := range v {
|
||||
if _, ok := v[i].(string); !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// marshal stringifies the given argument to a valid JSON document.
|
||||
func marshal(arg any) any {
|
||||
if buf, err := json.Marshal(arg); err == nil {
|
||||
|
||||
Reference in New Issue
Block a user