dialect/sql/sqljson: cast marshaled args as json (#3008)

This commit is contained in:
Ariel Mashraki
2022-10-11 14:15:01 +03:00
committed by GitHub
parent a26e21ff6a
commit cf137c665a
14 changed files with 310 additions and 34 deletions

View File

@@ -6,6 +6,7 @@ package sqljson
import (
"fmt"
"reflect"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
@@ -14,7 +15,7 @@ import (
type sqlite struct{}
// Append implements the driver.Append method.
func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
func (d *sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
setCase(u, column, when{
Cond: func(b *sql.Builder) {
typ := func(b *sql.Builder) *sql.Builder {
@@ -32,10 +33,10 @@ func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
identPath(column, opts...).mysqlPath(b)
b.Comma().WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
b.Comma().Argf("JSON(?)", marshalArg(elems))
})
} else {
b.WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
b.Arg(marshalArg(elems))
}
},
Else: func(b *sql.Builder) {
@@ -54,17 +55,27 @@ func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...
b.Comma()
}
path(b)
b.Comma().Arg(e)
b.Comma()
d.appendArg(b, e)
}
})
},
})
}
func (d *sqlite) appendArg(b *sql.Builder, v any) {
switch {
case !isPrimitive(v):
b.Argf("JSON(?)", marshalArg(v))
default:
b.Arg(v)
}
}
type mysql struct{}
// Append implements the driver.Append method.
func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
func (d *mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
setCase(u, column, when{
Cond: func(b *sql.Builder) {
typ := func(b *sql.Builder) *sql.Builder {
@@ -82,10 +93,10 @@ func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...O
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
identPath(column, opts...).mysqlPath(b)
b.Comma().WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
b.Comma().WriteString("JSON_ARRAY(").Args(d.marshalArgs(elems)...).WriteByte(')')
})
} else {
b.WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
b.WriteString("JSON_ARRAY(").Args(d.marshalArgs(elems)...).WriteByte(')')
}
},
Else: func(b *sql.Builder) {
@@ -96,13 +107,34 @@ func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...O
b.Comma()
}
identPath(column, opts...).mysqlPath(b)
b.Comma().Arg(e)
b.Comma()
d.appendArg(b, e)
}
})
},
})
}
func (d *mysql) marshalArgs(args []any) []any {
vs := make([]any, len(args))
for i, v := range args {
if !isPrimitive(v) {
v = marshalArg(v)
}
vs[i] = v
}
return vs
}
func (d *mysql) appendArg(b *sql.Builder, v any) {
switch {
case !isPrimitive(v):
b.Argf("CAST(? AS JSON)", marshalArg(v))
default:
b.Arg(v)
}
}
type postgres struct{}
// Append implements the driver.Append method.
@@ -120,11 +152,11 @@ func (*postgres) Append(u *sql.UpdateBuilder, column string, elems []any, opts .
b.WriteString("jsonb_set").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
identPath(column, opts...).pgArrayPath(b)
b.Comma().Arg(marshal(elems))
b.Comma().Arg(marshalArg(elems))
b.Comma().WriteString("true")
})
} else {
b.Arg(marshal(elems))
b.Arg(marshalArg(elems))
}
},
Else: func(b *sql.Builder) {
@@ -135,11 +167,11 @@ func (*postgres) Append(u *sql.UpdateBuilder, column string, elems []any, opts .
b.Comma()
path := identPath(column, opts...)
path.value(b)
b.WriteString(" || ").Arg(marshal(elems))
b.WriteString(" || ").Arg(marshalArg(elems))
b.Comma().WriteString("true")
})
} else {
b.Ident(column).WriteString(" || ").Arg(marshal(elems))
b.Ident(column).WriteString(" || ").Arg(marshalArg(elems))
}
},
})
@@ -180,3 +212,11 @@ func setCase(u *sql.UpdateBuilder, column string, w when) {
b.WriteString(" END")
}))
}
func isPrimitive(v any) bool {
switch reflect.TypeOf(v).Kind() {
case reflect.Array, reflect.Slice, reflect.Map, reflect.Struct, reflect.Ptr, reflect.Interface:
return false
}
return true
}