dialect/sql/sqljson: cast marshaled args as json (#3008)

This commit is contained in:
Ariel Mashraki
2022-10-11 14:15:01 +03:00
committed by GitHub
parent a26e21ff6a
commit cf137c665a
14 changed files with 310 additions and 34 deletions

View File

@@ -6,6 +6,7 @@ package sqljson
import (
"fmt"
"reflect"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
@@ -14,7 +15,7 @@ import (
type sqlite struct{}
// Append implements the driver.Append method.
func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
func (d *sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
setCase(u, column, when{
Cond: func(b *sql.Builder) {
typ := func(b *sql.Builder) *sql.Builder {
@@ -32,10 +33,10 @@ func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
identPath(column, opts...).mysqlPath(b)
b.Comma().WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
b.Comma().Argf("JSON(?)", marshalArg(elems))
})
} else {
b.WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
b.Arg(marshalArg(elems))
}
},
Else: func(b *sql.Builder) {
@@ -54,17 +55,27 @@ func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...
b.Comma()
}
path(b)
b.Comma().Arg(e)
b.Comma()
d.appendArg(b, e)
}
})
},
})
}
func (d *sqlite) appendArg(b *sql.Builder, v any) {
switch {
case !isPrimitive(v):
b.Argf("JSON(?)", marshalArg(v))
default:
b.Arg(v)
}
}
type mysql struct{}
// Append implements the driver.Append method.
func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
func (d *mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
setCase(u, column, when{
Cond: func(b *sql.Builder) {
typ := func(b *sql.Builder) *sql.Builder {
@@ -82,10 +93,10 @@ func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...O
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
identPath(column, opts...).mysqlPath(b)
b.Comma().WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
b.Comma().WriteString("JSON_ARRAY(").Args(d.marshalArgs(elems)...).WriteByte(')')
})
} else {
b.WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
b.WriteString("JSON_ARRAY(").Args(d.marshalArgs(elems)...).WriteByte(')')
}
},
Else: func(b *sql.Builder) {
@@ -96,13 +107,34 @@ func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...O
b.Comma()
}
identPath(column, opts...).mysqlPath(b)
b.Comma().Arg(e)
b.Comma()
d.appendArg(b, e)
}
})
},
})
}
func (d *mysql) marshalArgs(args []any) []any {
vs := make([]any, len(args))
for i, v := range args {
if !isPrimitive(v) {
v = marshalArg(v)
}
vs[i] = v
}
return vs
}
func (d *mysql) appendArg(b *sql.Builder, v any) {
switch {
case !isPrimitive(v):
b.Argf("CAST(? AS JSON)", marshalArg(v))
default:
b.Arg(v)
}
}
type postgres struct{}
// Append implements the driver.Append method.
@@ -120,11 +152,11 @@ func (*postgres) Append(u *sql.UpdateBuilder, column string, elems []any, opts .
b.WriteString("jsonb_set").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
identPath(column, opts...).pgArrayPath(b)
b.Comma().Arg(marshal(elems))
b.Comma().Arg(marshalArg(elems))
b.Comma().WriteString("true")
})
} else {
b.Arg(marshal(elems))
b.Arg(marshalArg(elems))
}
},
Else: func(b *sql.Builder) {
@@ -135,11 +167,11 @@ func (*postgres) Append(u *sql.UpdateBuilder, column string, elems []any, opts .
b.Comma()
path := identPath(column, opts...)
path.value(b)
b.WriteString(" || ").Arg(marshal(elems))
b.WriteString(" || ").Arg(marshalArg(elems))
b.Comma().WriteString("true")
})
} else {
b.Ident(column).WriteString(" || ").Arg(marshal(elems))
b.Ident(column).WriteString(" || ").Arg(marshalArg(elems))
}
},
})
@@ -180,3 +212,11 @@ func setCase(u *sql.UpdateBuilder, column string, w when) {
b.WriteString(" END")
}))
}
func isPrimitive(v any) bool {
switch reflect.TypeOf(v).Kind() {
case reflect.Array, reflect.Slice, reflect.Map, reflect.Struct, reflect.Ptr, reflect.Interface:
return false
}
return true
}

View File

@@ -147,7 +147,7 @@ func ValueContains(column string, arg any, opts ...Option) *sql.Predicate {
case dialect.MySQL:
b.WriteString("JSON_CONTAINS").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
b.Arg(marshal(arg)).Comma()
b.Arg(marshalArg(arg)).Comma()
path.mysqlPath(b)
})
b.WriteOp(sql.OpEQ).Arg(1)
@@ -163,7 +163,7 @@ func ValueContains(column string, arg any, opts ...Option) *sql.Predicate {
opts = normalizePG(b, arg, opts)
path.Cast = "jsonb"
path.value(b)
b.WriteString(" @> ").Arg(marshal(arg))
b.WriteString(" @> ").Arg(marshalArg(arg))
}
})
}
@@ -642,8 +642,8 @@ func allString(v []any) bool {
return true
}
// marshal stringifies the given argument to a valid JSON document.
func marshal(arg any) any {
// marshalArg stringifies the given argument to a valid JSON document.
func marshalArg(arg any) any {
if buf, err := json.Marshal(arg); err == nil {
arg = string(buf)
}

View File

@@ -457,17 +457,17 @@ func TestAppend(t *testing.T) {
sqljson.Append(u, "c", []string{"a"})
return u
}(),
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$') IS NULL OR JSON_TYPE(`c`, '$') = 'null') THEN JSON_ARRAY(?) ELSE JSON_INSERT(`c`, '$[#]', ?) END",
wantArgs: []any{"a", "a"},
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$') IS NULL OR JSON_TYPE(`c`, '$') = 'null') THEN ? ELSE JSON_INSERT(`c`, '$[#]', ?) END",
wantArgs: []any{`["a"]`, "a"},
},
{
input: func() sql.Querier {
u := sql.Dialect(dialect.SQLite).Update("t")
sqljson.Append(u, "c", []string{"a"}, sqljson.Path("a"))
sqljson.Append(u, "c", []any{"a", struct{}{}}, sqljson.Path("a"))
return u
}(),
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$.a') IS NULL OR JSON_TYPE(`c`, '$.a') = 'null') THEN JSON_SET(`c`, '$.a', JSON_ARRAY(?)) ELSE JSON_INSERT(`c`, '$.a[#]', ?) END",
wantArgs: []any{"a", "a"},
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$.a') IS NULL OR JSON_TYPE(`c`, '$.a') = 'null') THEN JSON_SET(`c`, '$.a', JSON(?)) ELSE JSON_INSERT(`c`, '$.a[#]', ?, '$.a[#]', JSON(?)) END",
wantArgs: []any{`["a",{}]`, "a", "{}"},
},
{
input: func() sql.Querier {