mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/sqljson: cast marshaled args as json (#3008)
This commit is contained in:
@@ -3331,22 +3331,19 @@ func (b *Builder) Arg(a any) *Builder {
|
||||
b.Join(a)
|
||||
return b
|
||||
}
|
||||
b.total++
|
||||
b.args = append(b.args, a)
|
||||
// Default placeholder param (MySQL and SQLite).
|
||||
param := "?"
|
||||
format := "?"
|
||||
if b.postgres() {
|
||||
// Postgres' arguments are referenced using the syntax $n.
|
||||
// $1 refers to the 1st argument, $2 to the 2nd, and so on.
|
||||
param = "$" + strconv.Itoa(b.total)
|
||||
format = "$" + strconv.Itoa(b.total+1)
|
||||
}
|
||||
if f, ok := a.(ParamFormatter); ok {
|
||||
param = f.FormatParam(param, &StmtInfo{
|
||||
format = f.FormatParam(format, &StmtInfo{
|
||||
Dialect: b.dialect,
|
||||
})
|
||||
}
|
||||
b.WriteString(param)
|
||||
return b
|
||||
return b.Argf(format, a)
|
||||
}
|
||||
|
||||
// Args appends a list of arguments to the builder.
|
||||
@@ -3360,6 +3357,29 @@ func (b *Builder) Args(a ...any) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
// Argf appends an input argument to the builder
|
||||
// with the given format. For example:
|
||||
//
|
||||
// FormatArg("JSON(?)", b).
|
||||
// FormatArg("ST_GeomFromText(?)", geom)
|
||||
func (b *Builder) Argf(format string, a any) *Builder {
|
||||
switch a := a.(type) {
|
||||
case nil:
|
||||
b.WriteString("NULL")
|
||||
return b
|
||||
case *raw:
|
||||
b.WriteString(a.s)
|
||||
return b
|
||||
case Querier:
|
||||
b.Join(a)
|
||||
return b
|
||||
}
|
||||
b.total++
|
||||
b.args = append(b.args, a)
|
||||
b.WriteString(format)
|
||||
return b
|
||||
}
|
||||
|
||||
// Comma adds a comma to the query.
|
||||
func (b *Builder) Comma() *Builder {
|
||||
return b.WriteString(", ")
|
||||
|
||||
@@ -6,6 +6,7 @@ package sqljson
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
type sqlite struct{}
|
||||
|
||||
// Append implements the driver.Append method.
|
||||
func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
||||
func (d *sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
||||
setCase(u, column, when{
|
||||
Cond: func(b *sql.Builder) {
|
||||
typ := func(b *sql.Builder) *sql.Builder {
|
||||
@@ -32,10 +33,10 @@ func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...
|
||||
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
identPath(column, opts...).mysqlPath(b)
|
||||
b.Comma().WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
|
||||
b.Comma().Argf("JSON(?)", marshalArg(elems))
|
||||
})
|
||||
} else {
|
||||
b.WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
|
||||
b.Arg(marshalArg(elems))
|
||||
}
|
||||
},
|
||||
Else: func(b *sql.Builder) {
|
||||
@@ -54,17 +55,27 @@ func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...
|
||||
b.Comma()
|
||||
}
|
||||
path(b)
|
||||
b.Comma().Arg(e)
|
||||
b.Comma()
|
||||
d.appendArg(b, e)
|
||||
}
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (d *sqlite) appendArg(b *sql.Builder, v any) {
|
||||
switch {
|
||||
case !isPrimitive(v):
|
||||
b.Argf("JSON(?)", marshalArg(v))
|
||||
default:
|
||||
b.Arg(v)
|
||||
}
|
||||
}
|
||||
|
||||
type mysql struct{}
|
||||
|
||||
// Append implements the driver.Append method.
|
||||
func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
||||
func (d *mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
||||
setCase(u, column, when{
|
||||
Cond: func(b *sql.Builder) {
|
||||
typ := func(b *sql.Builder) *sql.Builder {
|
||||
@@ -82,10 +93,10 @@ func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...O
|
||||
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
identPath(column, opts...).mysqlPath(b)
|
||||
b.Comma().WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
|
||||
b.Comma().WriteString("JSON_ARRAY(").Args(d.marshalArgs(elems)...).WriteByte(')')
|
||||
})
|
||||
} else {
|
||||
b.WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
|
||||
b.WriteString("JSON_ARRAY(").Args(d.marshalArgs(elems)...).WriteByte(')')
|
||||
}
|
||||
},
|
||||
Else: func(b *sql.Builder) {
|
||||
@@ -96,13 +107,34 @@ func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...O
|
||||
b.Comma()
|
||||
}
|
||||
identPath(column, opts...).mysqlPath(b)
|
||||
b.Comma().Arg(e)
|
||||
b.Comma()
|
||||
d.appendArg(b, e)
|
||||
}
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (d *mysql) marshalArgs(args []any) []any {
|
||||
vs := make([]any, len(args))
|
||||
for i, v := range args {
|
||||
if !isPrimitive(v) {
|
||||
v = marshalArg(v)
|
||||
}
|
||||
vs[i] = v
|
||||
}
|
||||
return vs
|
||||
}
|
||||
|
||||
func (d *mysql) appendArg(b *sql.Builder, v any) {
|
||||
switch {
|
||||
case !isPrimitive(v):
|
||||
b.Argf("CAST(? AS JSON)", marshalArg(v))
|
||||
default:
|
||||
b.Arg(v)
|
||||
}
|
||||
}
|
||||
|
||||
type postgres struct{}
|
||||
|
||||
// Append implements the driver.Append method.
|
||||
@@ -120,11 +152,11 @@ func (*postgres) Append(u *sql.UpdateBuilder, column string, elems []any, opts .
|
||||
b.WriteString("jsonb_set").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
identPath(column, opts...).pgArrayPath(b)
|
||||
b.Comma().Arg(marshal(elems))
|
||||
b.Comma().Arg(marshalArg(elems))
|
||||
b.Comma().WriteString("true")
|
||||
})
|
||||
} else {
|
||||
b.Arg(marshal(elems))
|
||||
b.Arg(marshalArg(elems))
|
||||
}
|
||||
},
|
||||
Else: func(b *sql.Builder) {
|
||||
@@ -135,11 +167,11 @@ func (*postgres) Append(u *sql.UpdateBuilder, column string, elems []any, opts .
|
||||
b.Comma()
|
||||
path := identPath(column, opts...)
|
||||
path.value(b)
|
||||
b.WriteString(" || ").Arg(marshal(elems))
|
||||
b.WriteString(" || ").Arg(marshalArg(elems))
|
||||
b.Comma().WriteString("true")
|
||||
})
|
||||
} else {
|
||||
b.Ident(column).WriteString(" || ").Arg(marshal(elems))
|
||||
b.Ident(column).WriteString(" || ").Arg(marshalArg(elems))
|
||||
}
|
||||
},
|
||||
})
|
||||
@@ -180,3 +212,11 @@ func setCase(u *sql.UpdateBuilder, column string, w when) {
|
||||
b.WriteString(" END")
|
||||
}))
|
||||
}
|
||||
|
||||
func isPrimitive(v any) bool {
|
||||
switch reflect.TypeOf(v).Kind() {
|
||||
case reflect.Array, reflect.Slice, reflect.Map, reflect.Struct, reflect.Ptr, reflect.Interface:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -147,7 +147,7 @@ func ValueContains(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
case dialect.MySQL:
|
||||
b.WriteString("JSON_CONTAINS").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
b.Arg(marshal(arg)).Comma()
|
||||
b.Arg(marshalArg(arg)).Comma()
|
||||
path.mysqlPath(b)
|
||||
})
|
||||
b.WriteOp(sql.OpEQ).Arg(1)
|
||||
@@ -163,7 +163,7 @@ func ValueContains(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
opts = normalizePG(b, arg, opts)
|
||||
path.Cast = "jsonb"
|
||||
path.value(b)
|
||||
b.WriteString(" @> ").Arg(marshal(arg))
|
||||
b.WriteString(" @> ").Arg(marshalArg(arg))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -642,8 +642,8 @@ func allString(v []any) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// marshal stringifies the given argument to a valid JSON document.
|
||||
func marshal(arg any) any {
|
||||
// marshalArg stringifies the given argument to a valid JSON document.
|
||||
func marshalArg(arg any) any {
|
||||
if buf, err := json.Marshal(arg); err == nil {
|
||||
arg = string(buf)
|
||||
}
|
||||
|
||||
@@ -457,17 +457,17 @@ func TestAppend(t *testing.T) {
|
||||
sqljson.Append(u, "c", []string{"a"})
|
||||
return u
|
||||
}(),
|
||||
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$') IS NULL OR JSON_TYPE(`c`, '$') = 'null') THEN JSON_ARRAY(?) ELSE JSON_INSERT(`c`, '$[#]', ?) END",
|
||||
wantArgs: []any{"a", "a"},
|
||||
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$') IS NULL OR JSON_TYPE(`c`, '$') = 'null') THEN ? ELSE JSON_INSERT(`c`, '$[#]', ?) END",
|
||||
wantArgs: []any{`["a"]`, "a"},
|
||||
},
|
||||
{
|
||||
input: func() sql.Querier {
|
||||
u := sql.Dialect(dialect.SQLite).Update("t")
|
||||
sqljson.Append(u, "c", []string{"a"}, sqljson.Path("a"))
|
||||
sqljson.Append(u, "c", []any{"a", struct{}{}}, sqljson.Path("a"))
|
||||
return u
|
||||
}(),
|
||||
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$.a') IS NULL OR JSON_TYPE(`c`, '$.a') = 'null') THEN JSON_SET(`c`, '$.a', JSON_ARRAY(?)) ELSE JSON_INSERT(`c`, '$.a[#]', ?) END",
|
||||
wantArgs: []any{"a", "a"},
|
||||
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$.a') IS NULL OR JSON_TYPE(`c`, '$.a') = 'null') THEN JSON_SET(`c`, '$.a', JSON(?)) ELSE JSON_INSERT(`c`, '$.a[#]', ?, '$.a[#]', JSON(?)) END",
|
||||
wantArgs: []any{`["a",{}]`, "a", "{}"},
|
||||
},
|
||||
{
|
||||
input: func() sql.Querier {
|
||||
|
||||
Reference in New Issue
Block a user