dialect/sql/sqljson: add dialect-aware Append function

This commit is contained in:
Ariel Mashraki
2022-09-28 12:07:10 +03:00
committed by Ariel Mashraki
parent 5330f87759
commit eb4ea68356
10 changed files with 422 additions and 47 deletions

View File

@@ -0,0 +1,182 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package sqljson
import (
"fmt"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
)
type sqlite struct{}
// Append implements the driver.Append method.
func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
setCase(u, column, when{
Cond: func(b *sql.Builder) {
typ := func(b *sql.Builder) *sql.Builder {
return b.WriteString("JSON_TYPE").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
identPath(column, opts...).mysqlPath(b)
})
}
typ(b).WriteOp(sql.OpIsNull)
b.WriteString(" OR ")
typ(b).WriteOp(sql.OpEQ).WriteString("'null'")
},
Then: func(b *sql.Builder) {
if len(opts) > 0 {
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
identPath(column, opts...).mysqlPath(b)
b.Comma().WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
})
} else {
b.WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
}
},
Else: func(b *sql.Builder) {
b.WriteString("JSON_INSERT").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
// If no path was provided the top-level value is
// a JSON array. i.e. JSON_INSERT(c, '$[#]', ?).
path := func(b *sql.Builder) { b.WriteString("'$[#]'") }
if len(opts) > 0 {
p := identPath(column, opts...)
p.Path = append(p.Path, "[#]")
path = p.mysqlPath
}
for i, e := range elems {
if i > 0 {
b.Comma()
}
path(b)
b.Comma().Arg(e)
}
})
},
})
}
type mysql struct{}
// Append implements the driver.Append method.
func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
setCase(u, column, when{
Cond: func(b *sql.Builder) {
typ := func(b *sql.Builder) *sql.Builder {
b.WriteString("JSON_TYPE(JSON_EXTRACT(")
b.Ident(column).Comma()
identPath(column, opts...).mysqlPath(b)
return b.WriteString("))")
}
typ(b).WriteOp(sql.OpIsNull)
b.WriteString(" OR ")
typ(b).WriteOp(sql.OpEQ).WriteString("'NULL'")
},
Then: func(b *sql.Builder) {
if len(opts) > 0 {
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
identPath(column, opts...).mysqlPath(b)
b.Comma().WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
})
} else {
b.WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
}
},
Else: func(b *sql.Builder) {
b.WriteString("JSON_ARRAY_APPEND").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
for i, e := range elems {
if i > 0 {
b.Comma()
}
identPath(column, opts...).mysqlPath(b)
b.Comma().Arg(e)
}
})
},
})
}
type postgres struct{}
// Append implements the driver.Append method.
func (*postgres) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
setCase(u, column, when{
Cond: func(b *sql.Builder) {
ValuePath(b, column, append(opts, Cast("jsonb"))...)
b.WriteOp(sql.OpIsNull)
b.WriteString(" OR ")
ValuePath(b, column, append(opts, Cast("jsonb"))...)
b.WriteOp(sql.OpEQ).WriteString("'null'::jsonb")
},
Then: func(b *sql.Builder) {
if len(opts) > 0 {
b.WriteString("jsonb_set").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
identPath(column, opts...).pgArrayPath(b)
b.Comma().Arg(marshal(elems))
b.Comma().WriteString("true")
})
} else {
b.Arg(marshal(elems))
}
},
Else: func(b *sql.Builder) {
if len(opts) > 0 {
b.WriteString("jsonb_set").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
identPath(column, opts...).pgArrayPath(b)
b.Comma()
path := identPath(column, opts...)
path.value(b)
b.WriteString(" || ").Arg(marshal(elems))
b.Comma().WriteString("true")
})
} else {
b.Ident(column).WriteString(" || ").Arg(marshal(elems))
}
},
})
}
// driver groups all dialect-specific methods.
type driver interface {
Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option)
}
func newDriver(name string) (driver, error) {
switch name {
case dialect.SQLite:
return (*sqlite)(nil), nil
case dialect.MySQL:
return (*mysql)(nil), nil
case dialect.Postgres:
return (*postgres)(nil), nil
default:
return nil, fmt.Errorf("sqljson: unknown driver %q", name)
}
}
type when struct{ Cond, Then, Else func(*sql.Builder) }
// setCase sets the column value using the "CASE WHEN" statement.
// The x defines the condition/predicate, t is the true (if) case,
// and 'f' defines the false (else).
func setCase(u *sql.UpdateBuilder, column string, w when) {
u.Set(column, sql.ExprFunc(func(b *sql.Builder) {
b.WriteString("CASE WHEN ").Wrap(func(b *sql.Builder) {
w.Cond(b)
})
b.WriteString(" THEN ")
w.Then(b)
b.WriteString(" ELSE ")
w.Else(b)
b.WriteString(" END")
}))
}

View File

@@ -46,7 +46,7 @@ func ValueIsNull(column string, opts ...Option) *sql.Predicate {
switch b.Dialect() {
case dialect.MySQL:
path := identPath(column, opts...)
b.WriteString("JSON_CONTAINS").Nested(func(b *sql.Builder) {
b.WriteString("JSON_CONTAINS").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
b.WriteString("'null'").Comma()
path.mysqlPath(b)
@@ -145,15 +145,15 @@ func ValueContains(column string, arg any, opts ...Option) *sql.Predicate {
path := identPath(column, opts...)
switch b.Dialect() {
case dialect.MySQL:
b.WriteString("JSON_CONTAINS").Nested(func(b *sql.Builder) {
b.WriteString("JSON_CONTAINS").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
b.Arg(marshal(arg)).Comma()
path.mysqlPath(b)
})
b.WriteOp(sql.OpEQ).Arg(1)
case dialect.SQLite:
b.WriteString("EXISTS").Nested(func(b *sql.Builder) {
b.WriteString("SELECT * FROM JSON_EACH").Nested(func(b *sql.Builder) {
b.WriteString("EXISTS").Wrap(func(b *sql.Builder) {
b.WriteString("SELECT * FROM JSON_EACH").Wrap(func(b *sql.Builder) {
b.Ident(column).Comma()
path.mysqlPath(b)
})
@@ -227,7 +227,7 @@ func valueInOp(column string, args []any, opts []Option, op sql.Op) *sql.Predica
}
ValuePath(b, column, opts...)
b.WriteOp(op)
b.Nested(func(b *sql.Builder) {
b.Wrap(func(b *sql.Builder) {
if s, ok := args[0].(*sql.Selector); ok {
b.Join(s)
} else {
@@ -325,6 +325,37 @@ func LenPath(b *sql.Builder, column string, opts ...Option) {
path.length(b)
}
// Append writes to the given SQL builder the SQL command for appending JSON values
// into the array, optionally defined as a key. Note, the generated SQL will use the
// Go semantics, the JSON column/key will be set to the given Array in case it is `null`
// or NULL. For example:
//
// Append(u, column, []string{"a", "b"})
// UPDATE "t" SET "c" = CASE
// WHEN ("c" IS NULL OR "c" = 'null'::jsonb)
// THEN $1 ELSE "c" || $2 END
//
// Append(u, column, []any{"a", 1}, sqljson.Path("a"))
// UPDATE "t" SET "c" = CASE
// WHEN (("c"->'a')::jsonb IS NULL OR ("c"->'a')::jsonb = 'null'::jsonb)
// THEN jsonb_set("c", '{a}', $1, true) ELSE jsonb_set("c", '{a}', "c"->'a' || $2, true) END
func Append[T any](u *sql.UpdateBuilder, column string, elems []T, opts ...Option) {
if len(elems) == 0 {
u.AddError(fmt.Errorf("sqljson: cannot append an empty array to column %q", column))
return
}
drv, err := newDriver(u.Dialect())
if err != nil {
u.AddError(err)
return
}
vs := make([]any, len(elems))
for i, e := range elems {
vs[i] = e
}
drv.Append(u, column, vs, opts...)
}
// Option allows for calling database JSON paths with functional options.
type Option func(*PathOptions)
@@ -359,7 +390,7 @@ func Unquote(unquote bool) Option {
}
}
// Cast indicates that the result value should be casted to the given type.
// Cast indicates that the result value should be cast to the given type.
//
// ValuePath(b, "column", Path("a", "b", "[1]", "c"), Cast("int"))
func Cast(typ string) Option {
@@ -395,7 +426,7 @@ func (p *PathOptions) value(b *sql.Builder) {
b.WriteByte('(')
defer b.WriteString(")::" + p.Cast)
}
p.pgPath(b)
p.pgTextPath(b)
default:
if p.Unquote && b.Dialect() == dialect.MySQL {
b.WriteString("JSON_UNQUOTE(")
@@ -410,7 +441,7 @@ func (p *PathOptions) length(b *sql.Builder) {
switch {
case b.Dialect() == dialect.Postgres:
b.WriteString("JSONB_ARRAY_LENGTH(")
p.pgPath(b)
p.pgTextPath(b)
b.WriteByte(')')
case b.Dialect() == dialect.MySQL:
p.mysqlFunc("JSON_LENGTH", b)
@@ -444,8 +475,8 @@ func (p *PathOptions) mysqlPath(b *sql.Builder) {
b.WriteByte('\'')
}
// pgPath writes the JSON path in Postgres format `"a"->'b'->>'c'`.
func (p *PathOptions) pgPath(b *sql.Builder) {
// pgTextPath writes the JSON path in PostgreSQL text format: `"a"->'b'->>'c'`.
func (p *PathOptions) pgTextPath(b *sql.Builder) {
b.Ident(p.Ident)
for i, s := range p.Path {
b.WriteString("->")
@@ -460,6 +491,21 @@ func (p *PathOptions) pgPath(b *sql.Builder) {
}
}
// pgArrayPath writes the JSON path in PostgreSQL array text[] format: '{a,1,b}'.
func (p *PathOptions) pgArrayPath(b *sql.Builder) {
b.WriteString("'{")
for i, s := range p.Path {
if i > 0 {
b.Comma()
}
if idx, ok := isJSONIdx(s); ok {
s = idx
}
b.WriteString(s)
}
b.WriteString("}'")
}
// ParsePath parses the "dotpath" for the DotPath option.
//
// "a.b" => ["a", "b"]
@@ -553,7 +599,7 @@ func isQuoted(s string) bool {
// isJSONIdx reports whether the string represents a JSON index.
func isJSONIdx(s string) (string, bool) {
if len(s) > 2 && s[0] == '[' && s[len(s)-1] == ']' && isNumber(s[1:len(s)-1]) {
if len(s) > 2 && s[0] == '[' && s[len(s)-1] == ']' && (isNumber(s[1:len(s)-1]) || s[1] == '#' && isNumber(s[2:len(s)-1])) {
return s[1 : len(s)-1], true
}
return "", false

View File

@@ -413,3 +413,73 @@ func TestParsePath(t *testing.T) {
})
}
}
func TestAppend(t *testing.T) {
tests := []struct {
input sql.Querier
wantQuery string
wantArgs []any
}{
{
input: func() sql.Querier {
u := sql.Dialect(dialect.Postgres).Update("t")
sqljson.Append(u, "c", []string{"a"})
return u
}(),
wantQuery: `UPDATE "t" SET "c" = CASE WHEN ("c" IS NULL OR "c" = 'null'::jsonb) THEN $1 ELSE "c" || $2 END`,
wantArgs: []any{`["a"]`, `["a"]`},
},
{
input: func() sql.Querier {
u := sql.Dialect(dialect.Postgres).Update("t")
sqljson.Append(u, "c", []string{"a"}, sqljson.Path("a"))
return u
}(),
wantQuery: `UPDATE "t" SET "c" = CASE WHEN (("c"->'a')::jsonb IS NULL OR ("c"->'a')::jsonb = 'null'::jsonb) THEN jsonb_set("c", '{a}', $1, true) ELSE jsonb_set("c", '{a}', "c"->'a' || $2, true) END`,
wantArgs: []any{`["a"]`, `["a"]`},
},
{
input: func() sql.Querier {
u := sql.Dialect(dialect.SQLite).Update("t")
sqljson.Append(u, "c", []string{"a"})
return u
}(),
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$') IS NULL OR JSON_TYPE(`c`, '$') = 'null') THEN JSON_ARRAY(?) ELSE JSON_INSERT(`c`, '$[#]', ?) END",
wantArgs: []any{"a", "a"},
},
{
input: func() sql.Querier {
u := sql.Dialect(dialect.SQLite).Update("t")
sqljson.Append(u, "c", []string{"a"}, sqljson.Path("a"))
return u
}(),
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$.a') IS NULL OR JSON_TYPE(`c`, '$.a') = 'null') THEN JSON_SET(`c`, '$.a', JSON_ARRAY(?)) ELSE JSON_INSERT(`c`, '$.a[#]', ?) END",
wantArgs: []any{"a", "a"},
},
{
input: func() sql.Querier {
u := sql.Dialect(dialect.MySQL).Update("t")
sqljson.Append(u, "c", []string{"a"})
return u
}(),
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(JSON_EXTRACT(`c`, '$')) IS NULL OR JSON_TYPE(JSON_EXTRACT(`c`, '$')) = 'NULL') THEN JSON_ARRAY(?) ELSE JSON_ARRAY_APPEND(`c`, '$', ?) END",
wantArgs: []any{"a", "a"},
},
{
input: func() sql.Querier {
u := sql.Dialect(dialect.MySQL).Update("t")
sqljson.Append(u, "c", []string{"a"}, sqljson.Path("a"))
return u
}(),
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(JSON_EXTRACT(`c`, '$.a')) IS NULL OR JSON_TYPE(JSON_EXTRACT(`c`, '$.a')) = 'NULL') THEN JSON_SET(`c`, '$.a', JSON_ARRAY(?)) ELSE JSON_ARRAY_APPEND(`c`, '$.a', ?) END",
wantArgs: []any{"a", "a"},
},
}
for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
query, args := tt.input.Query()
require.Equal(t, tt.wantQuery, query)
require.Equal(t, tt.wantArgs, args)
})
}
}