mirror of
https://github.com/ent/ent.git
synced 2026-04-29 06:00:55 +03:00
dialect/sql/sqljson: cast marshaled args as json (#3008)
This commit is contained in:
@@ -6,6 +6,7 @@ package sqljson
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
type sqlite struct{}
|
||||
|
||||
// Append implements the driver.Append method.
|
||||
func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
||||
func (d *sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
||||
setCase(u, column, when{
|
||||
Cond: func(b *sql.Builder) {
|
||||
typ := func(b *sql.Builder) *sql.Builder {
|
||||
@@ -32,10 +33,10 @@ func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...
|
||||
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
identPath(column, opts...).mysqlPath(b)
|
||||
b.Comma().WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
|
||||
b.Comma().Argf("JSON(?)", marshalArg(elems))
|
||||
})
|
||||
} else {
|
||||
b.WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
|
||||
b.Arg(marshalArg(elems))
|
||||
}
|
||||
},
|
||||
Else: func(b *sql.Builder) {
|
||||
@@ -54,17 +55,27 @@ func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...
|
||||
b.Comma()
|
||||
}
|
||||
path(b)
|
||||
b.Comma().Arg(e)
|
||||
b.Comma()
|
||||
d.appendArg(b, e)
|
||||
}
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (d *sqlite) appendArg(b *sql.Builder, v any) {
|
||||
switch {
|
||||
case !isPrimitive(v):
|
||||
b.Argf("JSON(?)", marshalArg(v))
|
||||
default:
|
||||
b.Arg(v)
|
||||
}
|
||||
}
|
||||
|
||||
type mysql struct{}
|
||||
|
||||
// Append implements the driver.Append method.
|
||||
func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
||||
func (d *mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
||||
setCase(u, column, when{
|
||||
Cond: func(b *sql.Builder) {
|
||||
typ := func(b *sql.Builder) *sql.Builder {
|
||||
@@ -82,10 +93,10 @@ func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...O
|
||||
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
identPath(column, opts...).mysqlPath(b)
|
||||
b.Comma().WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
|
||||
b.Comma().WriteString("JSON_ARRAY(").Args(d.marshalArgs(elems)...).WriteByte(')')
|
||||
})
|
||||
} else {
|
||||
b.WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
|
||||
b.WriteString("JSON_ARRAY(").Args(d.marshalArgs(elems)...).WriteByte(')')
|
||||
}
|
||||
},
|
||||
Else: func(b *sql.Builder) {
|
||||
@@ -96,13 +107,34 @@ func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...O
|
||||
b.Comma()
|
||||
}
|
||||
identPath(column, opts...).mysqlPath(b)
|
||||
b.Comma().Arg(e)
|
||||
b.Comma()
|
||||
d.appendArg(b, e)
|
||||
}
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (d *mysql) marshalArgs(args []any) []any {
|
||||
vs := make([]any, len(args))
|
||||
for i, v := range args {
|
||||
if !isPrimitive(v) {
|
||||
v = marshalArg(v)
|
||||
}
|
||||
vs[i] = v
|
||||
}
|
||||
return vs
|
||||
}
|
||||
|
||||
func (d *mysql) appendArg(b *sql.Builder, v any) {
|
||||
switch {
|
||||
case !isPrimitive(v):
|
||||
b.Argf("CAST(? AS JSON)", marshalArg(v))
|
||||
default:
|
||||
b.Arg(v)
|
||||
}
|
||||
}
|
||||
|
||||
type postgres struct{}
|
||||
|
||||
// Append implements the driver.Append method.
|
||||
@@ -120,11 +152,11 @@ func (*postgres) Append(u *sql.UpdateBuilder, column string, elems []any, opts .
|
||||
b.WriteString("jsonb_set").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
identPath(column, opts...).pgArrayPath(b)
|
||||
b.Comma().Arg(marshal(elems))
|
||||
b.Comma().Arg(marshalArg(elems))
|
||||
b.Comma().WriteString("true")
|
||||
})
|
||||
} else {
|
||||
b.Arg(marshal(elems))
|
||||
b.Arg(marshalArg(elems))
|
||||
}
|
||||
},
|
||||
Else: func(b *sql.Builder) {
|
||||
@@ -135,11 +167,11 @@ func (*postgres) Append(u *sql.UpdateBuilder, column string, elems []any, opts .
|
||||
b.Comma()
|
||||
path := identPath(column, opts...)
|
||||
path.value(b)
|
||||
b.WriteString(" || ").Arg(marshal(elems))
|
||||
b.WriteString(" || ").Arg(marshalArg(elems))
|
||||
b.Comma().WriteString("true")
|
||||
})
|
||||
} else {
|
||||
b.Ident(column).WriteString(" || ").Arg(marshal(elems))
|
||||
b.Ident(column).WriteString(" || ").Arg(marshalArg(elems))
|
||||
}
|
||||
},
|
||||
})
|
||||
@@ -180,3 +212,11 @@ func setCase(u *sql.UpdateBuilder, column string, w when) {
|
||||
b.WriteString(" END")
|
||||
}))
|
||||
}
|
||||
|
||||
func isPrimitive(v any) bool {
|
||||
switch reflect.TypeOf(v).Kind() {
|
||||
case reflect.Array, reflect.Slice, reflect.Map, reflect.Struct, reflect.Ptr, reflect.Interface:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -147,7 +147,7 @@ func ValueContains(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
case dialect.MySQL:
|
||||
b.WriteString("JSON_CONTAINS").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
b.Arg(marshal(arg)).Comma()
|
||||
b.Arg(marshalArg(arg)).Comma()
|
||||
path.mysqlPath(b)
|
||||
})
|
||||
b.WriteOp(sql.OpEQ).Arg(1)
|
||||
@@ -163,7 +163,7 @@ func ValueContains(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
opts = normalizePG(b, arg, opts)
|
||||
path.Cast = "jsonb"
|
||||
path.value(b)
|
||||
b.WriteString(" @> ").Arg(marshal(arg))
|
||||
b.WriteString(" @> ").Arg(marshalArg(arg))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -642,8 +642,8 @@ func allString(v []any) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// marshal stringifies the given argument to a valid JSON document.
|
||||
func marshal(arg any) any {
|
||||
// marshalArg stringifies the given argument to a valid JSON document.
|
||||
func marshalArg(arg any) any {
|
||||
if buf, err := json.Marshal(arg); err == nil {
|
||||
arg = string(buf)
|
||||
}
|
||||
|
||||
@@ -457,17 +457,17 @@ func TestAppend(t *testing.T) {
|
||||
sqljson.Append(u, "c", []string{"a"})
|
||||
return u
|
||||
}(),
|
||||
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$') IS NULL OR JSON_TYPE(`c`, '$') = 'null') THEN JSON_ARRAY(?) ELSE JSON_INSERT(`c`, '$[#]', ?) END",
|
||||
wantArgs: []any{"a", "a"},
|
||||
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$') IS NULL OR JSON_TYPE(`c`, '$') = 'null') THEN ? ELSE JSON_INSERT(`c`, '$[#]', ?) END",
|
||||
wantArgs: []any{`["a"]`, "a"},
|
||||
},
|
||||
{
|
||||
input: func() sql.Querier {
|
||||
u := sql.Dialect(dialect.SQLite).Update("t")
|
||||
sqljson.Append(u, "c", []string{"a"}, sqljson.Path("a"))
|
||||
sqljson.Append(u, "c", []any{"a", struct{}{}}, sqljson.Path("a"))
|
||||
return u
|
||||
}(),
|
||||
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$.a') IS NULL OR JSON_TYPE(`c`, '$.a') = 'null') THEN JSON_SET(`c`, '$.a', JSON_ARRAY(?)) ELSE JSON_INSERT(`c`, '$.a[#]', ?) END",
|
||||
wantArgs: []any{"a", "a"},
|
||||
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$.a') IS NULL OR JSON_TYPE(`c`, '$.a') = 'null') THEN JSON_SET(`c`, '$.a', JSON(?)) ELSE JSON_INSERT(`c`, '$.a[#]', ?, '$.a[#]', JSON(?)) END",
|
||||
wantArgs: []any{`["a",{}]`, "a", "{}"},
|
||||
},
|
||||
{
|
||||
input: func() sql.Querier {
|
||||
|
||||
Reference in New Issue
Block a user