mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/sqljson: add dialect-aware Append function
This commit is contained in:
committed by
Ariel Mashraki
parent
5330f87759
commit
eb4ea68356
@@ -99,7 +99,7 @@ func (c *ColumnBuilder) Query() (string, []any) {
|
||||
}
|
||||
if c.check != nil {
|
||||
c.WriteString(" CHECK ")
|
||||
c.Nested(c.check)
|
||||
c.Wrap(c.check)
|
||||
}
|
||||
return c.String(), c.args
|
||||
}
|
||||
@@ -213,11 +213,11 @@ func (t *TableBuilder) Query() (string, []any) {
|
||||
t.WriteString("IF NOT EXISTS ")
|
||||
}
|
||||
t.Ident(t.name)
|
||||
t.Nested(func(b *Builder) {
|
||||
t.Wrap(func(b *Builder) {
|
||||
b.JoinComma(t.columns...)
|
||||
if len(t.primary) > 0 {
|
||||
b.Comma().WriteString("PRIMARY KEY")
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
b.IdentComma(t.primary...)
|
||||
})
|
||||
}
|
||||
@@ -340,7 +340,7 @@ func (t *TableAlter) AddIndex(idx *IndexBuilder) *TableAlter {
|
||||
}
|
||||
b.WriteString("INDEX ")
|
||||
b.Ident(idx.name)
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
b.IdentComma(idx.columns...)
|
||||
})
|
||||
t.Queries = append(t.Queries, b)
|
||||
@@ -467,7 +467,7 @@ func (fk *ForeignKeyBuilder) Query() (string, []any) {
|
||||
fk.Ident(fk.symbol).Pad()
|
||||
}
|
||||
fk.WriteString("FOREIGN KEY")
|
||||
fk.Nested(func(b *Builder) {
|
||||
fk.Wrap(func(b *Builder) {
|
||||
b.IdentComma(fk.columns...)
|
||||
})
|
||||
fk.Pad().Join(fk.ref)
|
||||
@@ -505,7 +505,7 @@ func (r *ReferenceBuilder) Columns(s ...string) *ReferenceBuilder {
|
||||
func (r *ReferenceBuilder) Query() (string, []any) {
|
||||
r.WriteString("REFERENCES ")
|
||||
r.Ident(r.table)
|
||||
r.Nested(func(b *Builder) {
|
||||
r.Wrap(func(b *Builder) {
|
||||
b.IdentComma(r.columns...)
|
||||
})
|
||||
return r.String(), r.args
|
||||
@@ -593,18 +593,18 @@ func (i *IndexBuilder) Query() (string, []any) {
|
||||
if i.method != "" {
|
||||
i.WriteString(" USING ").Ident(i.method)
|
||||
}
|
||||
i.Nested(func(b *Builder) {
|
||||
i.Wrap(func(b *Builder) {
|
||||
b.IdentComma(i.columns...)
|
||||
})
|
||||
case dialect.MySQL:
|
||||
i.Nested(func(b *Builder) {
|
||||
i.Wrap(func(b *Builder) {
|
||||
b.IdentComma(i.columns...)
|
||||
})
|
||||
if i.method != "" {
|
||||
i.WriteString(" USING " + i.method)
|
||||
}
|
||||
default:
|
||||
i.Nested(func(b *Builder) {
|
||||
i.Wrap(func(b *Builder) {
|
||||
b.IdentComma(i.columns...)
|
||||
})
|
||||
}
|
||||
@@ -1067,7 +1067,7 @@ func (u *UpdateBuilder) Add(column string, v any) *UpdateBuilder {
|
||||
u.columns = append(u.columns, column)
|
||||
u.values = append(u.values, ExprFunc(func(b *Builder) {
|
||||
b.WriteString("COALESCE")
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
b.Ident(Table(u.table).C(column)).Comma().WriteByte('0')
|
||||
})
|
||||
b.WriteString(" + ")
|
||||
@@ -1281,7 +1281,7 @@ func (p *Predicate) False() *Predicate {
|
||||
// Not(Or(EQ("name", "foo"), EQ("name", "bar")))
|
||||
func Not(pred *Predicate) *Predicate {
|
||||
return P().Not().Append(func(b *Builder) {
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
b.Join(pred)
|
||||
})
|
||||
})
|
||||
@@ -1542,7 +1542,7 @@ func (p *Predicate) In(col string, args ...any) *Predicate {
|
||||
}
|
||||
return p.Append(func(b *Builder) {
|
||||
b.Ident(col).WriteOp(OpIn)
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
if s, ok := args[0].(*Selector); ok {
|
||||
b.Join(s)
|
||||
} else {
|
||||
@@ -1594,7 +1594,7 @@ func (p *Predicate) NotIn(col string, args ...any) *Predicate {
|
||||
}
|
||||
return p.Append(func(b *Builder) {
|
||||
b.Ident(col).WriteOp(OpNotIn)
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
if s, ok := args[0].(*Selector); ok {
|
||||
b.Join(s)
|
||||
} else {
|
||||
@@ -1613,7 +1613,7 @@ func Exists(query Querier) *Predicate {
|
||||
func (p *Predicate) Exists(query Querier) *Predicate {
|
||||
return p.Append(func(b *Builder) {
|
||||
b.WriteString("EXISTS ")
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
b.Join(query)
|
||||
})
|
||||
})
|
||||
@@ -1628,7 +1628,7 @@ func NotExists(query Querier) *Predicate {
|
||||
func (p *Predicate) NotExists(query Querier) *Predicate {
|
||||
return p.Append(func(b *Builder) {
|
||||
b.WriteString("NOT EXISTS ")
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
b.Join(query)
|
||||
})
|
||||
})
|
||||
@@ -1777,7 +1777,7 @@ func CompositeLT(columns []string, args ...any) *Predicate {
|
||||
|
||||
func (p *Predicate) compositeP(operator string, columns []string, args ...any) *Predicate {
|
||||
return p.Append(func(b *Builder) {
|
||||
b.Nested(func(nb *Builder) {
|
||||
b.Wrap(func(nb *Builder) {
|
||||
nb.IdentComma(columns...)
|
||||
})
|
||||
b.WriteString(operator)
|
||||
@@ -1822,7 +1822,7 @@ func (p *Predicate) Query() (string, []any) {
|
||||
func (*Predicate) arg(b *Builder, a any) {
|
||||
switch a.(type) {
|
||||
case *Selector:
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
b.Arg(a)
|
||||
})
|
||||
default:
|
||||
@@ -1855,7 +1855,7 @@ func (p *Predicate) mayWrap(preds []*Predicate, b *Builder, op string) {
|
||||
b.WriteByte(' ')
|
||||
}
|
||||
if len(preds[i].fns) > 1 {
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
b.Join(preds[i])
|
||||
})
|
||||
} else {
|
||||
@@ -1948,7 +1948,7 @@ func (f *Func) Avg(ident string) {
|
||||
func (f *Func) byName(fn, ident string) {
|
||||
f.Append(func(b *Builder) {
|
||||
f.WriteString(fn)
|
||||
f.Nested(func(b *Builder) {
|
||||
f.Wrap(func(b *Builder) {
|
||||
b.Ident(ident)
|
||||
})
|
||||
})
|
||||
@@ -2707,7 +2707,7 @@ func (s *Selector) Query() (string, []any) {
|
||||
b.WriteString(t.ref())
|
||||
case *Selector:
|
||||
t.SetDialect(s.dialect)
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
b.Join(t)
|
||||
})
|
||||
b.WriteString(" AS ")
|
||||
@@ -2727,7 +2727,7 @@ func (s *Selector) Query() (string, []any) {
|
||||
b.WriteString(view.ref())
|
||||
case *Selector:
|
||||
view.SetDialect(s.dialect)
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
b.Join(view)
|
||||
})
|
||||
b.WriteString(" AS ")
|
||||
@@ -2936,7 +2936,7 @@ func (w *WithBuilder) Query() (string, []any) {
|
||||
w.WriteByte(')')
|
||||
}
|
||||
w.WriteString(" AS ")
|
||||
w.Nested(func(b *Builder) {
|
||||
w.Wrap(func(b *Builder) {
|
||||
b.Join(cte.s)
|
||||
})
|
||||
}
|
||||
@@ -3002,7 +3002,7 @@ func (w *WindowBuilder) OrderExpr(exprs ...Querier) *WindowBuilder {
|
||||
func (w *WindowBuilder) Query() (string, []any) {
|
||||
w.WriteString(w.fn)
|
||||
w.WriteString("() OVER ")
|
||||
w.Nested(func(b *Builder) {
|
||||
w.Wrap(func(b *Builder) {
|
||||
if w.partition != nil {
|
||||
b.WriteString("PARTITION BY ")
|
||||
w.partition(b)
|
||||
@@ -3404,8 +3404,8 @@ func (b *Builder) join(qs []Querier, sep string) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
// Nested gets a callback, and wraps its result with parentheses.
|
||||
func (b *Builder) Nested(f func(*Builder)) *Builder {
|
||||
// Wrap gets a callback, and wraps its result with parentheses.
|
||||
func (b *Builder) Wrap(f func(*Builder)) *Builder {
|
||||
nb := &Builder{dialect: b.dialect, total: b.total, sb: &strings.Builder{}}
|
||||
nb.WriteByte('(')
|
||||
f(nb)
|
||||
@@ -3416,6 +3416,13 @@ func (b *Builder) Nested(f func(*Builder)) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
// Nested gets a callback, and wraps its result with parentheses.
|
||||
//
|
||||
// Deprecated: Use Builder.Wrap instead.
|
||||
func (b *Builder) Nested(f func(*Builder)) *Builder {
|
||||
return b.Wrap(f)
|
||||
}
|
||||
|
||||
// SetDialect sets the builder dialect. It's used for garnering dialect specific queries.
|
||||
func (b *Builder) SetDialect(dialect string) {
|
||||
b.dialect = dialect
|
||||
|
||||
@@ -1670,7 +1670,7 @@ func TestSelector_SelectExpr(t *testing.T) {
|
||||
AppendSelectExpr(
|
||||
Expr("age + $1", 1),
|
||||
ExprFunc(func(b *Builder) {
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Wrap(func(b *Builder) {
|
||||
b.WriteString("similarity(").Ident("name").Comma().Arg("A").WriteByte(')')
|
||||
b.WriteOp(OpAdd)
|
||||
b.WriteString("similarity(").Ident("desc").Comma().Arg("D").WriteByte(')')
|
||||
|
||||
182
dialect/sql/sqljson/dialect.go
Normal file
182
dialect/sql/sqljson/dialect.go
Normal file
@@ -0,0 +1,182 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package sqljson
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
type sqlite struct{}
|
||||
|
||||
// Append implements the driver.Append method.
|
||||
func (*sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
||||
setCase(u, column, when{
|
||||
Cond: func(b *sql.Builder) {
|
||||
typ := func(b *sql.Builder) *sql.Builder {
|
||||
return b.WriteString("JSON_TYPE").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
identPath(column, opts...).mysqlPath(b)
|
||||
})
|
||||
}
|
||||
typ(b).WriteOp(sql.OpIsNull)
|
||||
b.WriteString(" OR ")
|
||||
typ(b).WriteOp(sql.OpEQ).WriteString("'null'")
|
||||
},
|
||||
Then: func(b *sql.Builder) {
|
||||
if len(opts) > 0 {
|
||||
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
identPath(column, opts...).mysqlPath(b)
|
||||
b.Comma().WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
|
||||
})
|
||||
} else {
|
||||
b.WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
|
||||
}
|
||||
},
|
||||
Else: func(b *sql.Builder) {
|
||||
b.WriteString("JSON_INSERT").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
// If no path was provided the top-level value is
|
||||
// a JSON array. i.e. JSON_INSERT(c, '$[#]', ?).
|
||||
path := func(b *sql.Builder) { b.WriteString("'$[#]'") }
|
||||
if len(opts) > 0 {
|
||||
p := identPath(column, opts...)
|
||||
p.Path = append(p.Path, "[#]")
|
||||
path = p.mysqlPath
|
||||
}
|
||||
for i, e := range elems {
|
||||
if i > 0 {
|
||||
b.Comma()
|
||||
}
|
||||
path(b)
|
||||
b.Comma().Arg(e)
|
||||
}
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
type mysql struct{}
|
||||
|
||||
// Append implements the driver.Append method.
|
||||
func (*mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
||||
setCase(u, column, when{
|
||||
Cond: func(b *sql.Builder) {
|
||||
typ := func(b *sql.Builder) *sql.Builder {
|
||||
b.WriteString("JSON_TYPE(JSON_EXTRACT(")
|
||||
b.Ident(column).Comma()
|
||||
identPath(column, opts...).mysqlPath(b)
|
||||
return b.WriteString("))")
|
||||
}
|
||||
typ(b).WriteOp(sql.OpIsNull)
|
||||
b.WriteString(" OR ")
|
||||
typ(b).WriteOp(sql.OpEQ).WriteString("'NULL'")
|
||||
},
|
||||
Then: func(b *sql.Builder) {
|
||||
if len(opts) > 0 {
|
||||
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
identPath(column, opts...).mysqlPath(b)
|
||||
b.Comma().WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
|
||||
})
|
||||
} else {
|
||||
b.WriteString("JSON_ARRAY(").Args(elems...).WriteByte(')')
|
||||
}
|
||||
},
|
||||
Else: func(b *sql.Builder) {
|
||||
b.WriteString("JSON_ARRAY_APPEND").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
for i, e := range elems {
|
||||
if i > 0 {
|
||||
b.Comma()
|
||||
}
|
||||
identPath(column, opts...).mysqlPath(b)
|
||||
b.Comma().Arg(e)
|
||||
}
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
type postgres struct{}
|
||||
|
||||
// Append implements the driver.Append method.
|
||||
func (*postgres) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
||||
setCase(u, column, when{
|
||||
Cond: func(b *sql.Builder) {
|
||||
ValuePath(b, column, append(opts, Cast("jsonb"))...)
|
||||
b.WriteOp(sql.OpIsNull)
|
||||
b.WriteString(" OR ")
|
||||
ValuePath(b, column, append(opts, Cast("jsonb"))...)
|
||||
b.WriteOp(sql.OpEQ).WriteString("'null'::jsonb")
|
||||
},
|
||||
Then: func(b *sql.Builder) {
|
||||
if len(opts) > 0 {
|
||||
b.WriteString("jsonb_set").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
identPath(column, opts...).pgArrayPath(b)
|
||||
b.Comma().Arg(marshal(elems))
|
||||
b.Comma().WriteString("true")
|
||||
})
|
||||
} else {
|
||||
b.Arg(marshal(elems))
|
||||
}
|
||||
},
|
||||
Else: func(b *sql.Builder) {
|
||||
if len(opts) > 0 {
|
||||
b.WriteString("jsonb_set").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
identPath(column, opts...).pgArrayPath(b)
|
||||
b.Comma()
|
||||
path := identPath(column, opts...)
|
||||
path.value(b)
|
||||
b.WriteString(" || ").Arg(marshal(elems))
|
||||
b.Comma().WriteString("true")
|
||||
})
|
||||
} else {
|
||||
b.Ident(column).WriteString(" || ").Arg(marshal(elems))
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// driver groups all dialect-specific methods.
|
||||
type driver interface {
|
||||
Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option)
|
||||
}
|
||||
|
||||
func newDriver(name string) (driver, error) {
|
||||
switch name {
|
||||
case dialect.SQLite:
|
||||
return (*sqlite)(nil), nil
|
||||
case dialect.MySQL:
|
||||
return (*mysql)(nil), nil
|
||||
case dialect.Postgres:
|
||||
return (*postgres)(nil), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("sqljson: unknown driver %q", name)
|
||||
}
|
||||
}
|
||||
|
||||
type when struct{ Cond, Then, Else func(*sql.Builder) }
|
||||
|
||||
// setCase sets the column value using the "CASE WHEN" statement.
|
||||
// The x defines the condition/predicate, t is the true (if) case,
|
||||
// and 'f' defines the false (else).
|
||||
func setCase(u *sql.UpdateBuilder, column string, w when) {
|
||||
u.Set(column, sql.ExprFunc(func(b *sql.Builder) {
|
||||
b.WriteString("CASE WHEN ").Wrap(func(b *sql.Builder) {
|
||||
w.Cond(b)
|
||||
})
|
||||
b.WriteString(" THEN ")
|
||||
w.Then(b)
|
||||
b.WriteString(" ELSE ")
|
||||
w.Else(b)
|
||||
b.WriteString(" END")
|
||||
}))
|
||||
}
|
||||
@@ -46,7 +46,7 @@ func ValueIsNull(column string, opts ...Option) *sql.Predicate {
|
||||
switch b.Dialect() {
|
||||
case dialect.MySQL:
|
||||
path := identPath(column, opts...)
|
||||
b.WriteString("JSON_CONTAINS").Nested(func(b *sql.Builder) {
|
||||
b.WriteString("JSON_CONTAINS").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
b.WriteString("'null'").Comma()
|
||||
path.mysqlPath(b)
|
||||
@@ -145,15 +145,15 @@ func ValueContains(column string, arg any, opts ...Option) *sql.Predicate {
|
||||
path := identPath(column, opts...)
|
||||
switch b.Dialect() {
|
||||
case dialect.MySQL:
|
||||
b.WriteString("JSON_CONTAINS").Nested(func(b *sql.Builder) {
|
||||
b.WriteString("JSON_CONTAINS").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
b.Arg(marshal(arg)).Comma()
|
||||
path.mysqlPath(b)
|
||||
})
|
||||
b.WriteOp(sql.OpEQ).Arg(1)
|
||||
case dialect.SQLite:
|
||||
b.WriteString("EXISTS").Nested(func(b *sql.Builder) {
|
||||
b.WriteString("SELECT * FROM JSON_EACH").Nested(func(b *sql.Builder) {
|
||||
b.WriteString("EXISTS").Wrap(func(b *sql.Builder) {
|
||||
b.WriteString("SELECT * FROM JSON_EACH").Wrap(func(b *sql.Builder) {
|
||||
b.Ident(column).Comma()
|
||||
path.mysqlPath(b)
|
||||
})
|
||||
@@ -227,7 +227,7 @@ func valueInOp(column string, args []any, opts []Option, op sql.Op) *sql.Predica
|
||||
}
|
||||
ValuePath(b, column, opts...)
|
||||
b.WriteOp(op)
|
||||
b.Nested(func(b *sql.Builder) {
|
||||
b.Wrap(func(b *sql.Builder) {
|
||||
if s, ok := args[0].(*sql.Selector); ok {
|
||||
b.Join(s)
|
||||
} else {
|
||||
@@ -325,6 +325,37 @@ func LenPath(b *sql.Builder, column string, opts ...Option) {
|
||||
path.length(b)
|
||||
}
|
||||
|
||||
// Append writes to the given SQL builder the SQL command for appending JSON values
|
||||
// into the array, optionally defined as a key. Note, the generated SQL will use the
|
||||
// Go semantics, the JSON column/key will be set to the given Array in case it is `null`
|
||||
// or NULL. For example:
|
||||
//
|
||||
// Append(u, column, []string{"a", "b"})
|
||||
// UPDATE "t" SET "c" = CASE
|
||||
// WHEN ("c" IS NULL OR "c" = 'null'::jsonb)
|
||||
// THEN $1 ELSE "c" || $2 END
|
||||
//
|
||||
// Append(u, column, []any{"a", 1}, sqljson.Path("a"))
|
||||
// UPDATE "t" SET "c" = CASE
|
||||
// WHEN (("c"->'a')::jsonb IS NULL OR ("c"->'a')::jsonb = 'null'::jsonb)
|
||||
// THEN jsonb_set("c", '{a}', $1, true) ELSE jsonb_set("c", '{a}', "c"->'a' || $2, true) END
|
||||
func Append[T any](u *sql.UpdateBuilder, column string, elems []T, opts ...Option) {
|
||||
if len(elems) == 0 {
|
||||
u.AddError(fmt.Errorf("sqljson: cannot append an empty array to column %q", column))
|
||||
return
|
||||
}
|
||||
drv, err := newDriver(u.Dialect())
|
||||
if err != nil {
|
||||
u.AddError(err)
|
||||
return
|
||||
}
|
||||
vs := make([]any, len(elems))
|
||||
for i, e := range elems {
|
||||
vs[i] = e
|
||||
}
|
||||
drv.Append(u, column, vs, opts...)
|
||||
}
|
||||
|
||||
// Option allows for calling database JSON paths with functional options.
|
||||
type Option func(*PathOptions)
|
||||
|
||||
@@ -359,7 +390,7 @@ func Unquote(unquote bool) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// Cast indicates that the result value should be casted to the given type.
|
||||
// Cast indicates that the result value should be cast to the given type.
|
||||
//
|
||||
// ValuePath(b, "column", Path("a", "b", "[1]", "c"), Cast("int"))
|
||||
func Cast(typ string) Option {
|
||||
@@ -395,7 +426,7 @@ func (p *PathOptions) value(b *sql.Builder) {
|
||||
b.WriteByte('(')
|
||||
defer b.WriteString(")::" + p.Cast)
|
||||
}
|
||||
p.pgPath(b)
|
||||
p.pgTextPath(b)
|
||||
default:
|
||||
if p.Unquote && b.Dialect() == dialect.MySQL {
|
||||
b.WriteString("JSON_UNQUOTE(")
|
||||
@@ -410,7 +441,7 @@ func (p *PathOptions) length(b *sql.Builder) {
|
||||
switch {
|
||||
case b.Dialect() == dialect.Postgres:
|
||||
b.WriteString("JSONB_ARRAY_LENGTH(")
|
||||
p.pgPath(b)
|
||||
p.pgTextPath(b)
|
||||
b.WriteByte(')')
|
||||
case b.Dialect() == dialect.MySQL:
|
||||
p.mysqlFunc("JSON_LENGTH", b)
|
||||
@@ -444,8 +475,8 @@ func (p *PathOptions) mysqlPath(b *sql.Builder) {
|
||||
b.WriteByte('\'')
|
||||
}
|
||||
|
||||
// pgPath writes the JSON path in Postgres format `"a"->'b'->>'c'`.
|
||||
func (p *PathOptions) pgPath(b *sql.Builder) {
|
||||
// pgTextPath writes the JSON path in PostgreSQL text format: `"a"->'b'->>'c'`.
|
||||
func (p *PathOptions) pgTextPath(b *sql.Builder) {
|
||||
b.Ident(p.Ident)
|
||||
for i, s := range p.Path {
|
||||
b.WriteString("->")
|
||||
@@ -460,6 +491,21 @@ func (p *PathOptions) pgPath(b *sql.Builder) {
|
||||
}
|
||||
}
|
||||
|
||||
// pgArrayPath writes the JSON path in PostgreSQL array text[] format: '{a,1,b}'.
|
||||
func (p *PathOptions) pgArrayPath(b *sql.Builder) {
|
||||
b.WriteString("'{")
|
||||
for i, s := range p.Path {
|
||||
if i > 0 {
|
||||
b.Comma()
|
||||
}
|
||||
if idx, ok := isJSONIdx(s); ok {
|
||||
s = idx
|
||||
}
|
||||
b.WriteString(s)
|
||||
}
|
||||
b.WriteString("}'")
|
||||
}
|
||||
|
||||
// ParsePath parses the "dotpath" for the DotPath option.
|
||||
//
|
||||
// "a.b" => ["a", "b"]
|
||||
@@ -553,7 +599,7 @@ func isQuoted(s string) bool {
|
||||
|
||||
// isJSONIdx reports whether the string represents a JSON index.
|
||||
func isJSONIdx(s string) (string, bool) {
|
||||
if len(s) > 2 && s[0] == '[' && s[len(s)-1] == ']' && isNumber(s[1:len(s)-1]) {
|
||||
if len(s) > 2 && s[0] == '[' && s[len(s)-1] == ']' && (isNumber(s[1:len(s)-1]) || s[1] == '#' && isNumber(s[2:len(s)-1])) {
|
||||
return s[1 : len(s)-1], true
|
||||
}
|
||||
return "", false
|
||||
|
||||
@@ -413,3 +413,73 @@ func TestParsePath(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppend(t *testing.T) {
|
||||
tests := []struct {
|
||||
input sql.Querier
|
||||
wantQuery string
|
||||
wantArgs []any
|
||||
}{
|
||||
{
|
||||
input: func() sql.Querier {
|
||||
u := sql.Dialect(dialect.Postgres).Update("t")
|
||||
sqljson.Append(u, "c", []string{"a"})
|
||||
return u
|
||||
}(),
|
||||
wantQuery: `UPDATE "t" SET "c" = CASE WHEN ("c" IS NULL OR "c" = 'null'::jsonb) THEN $1 ELSE "c" || $2 END`,
|
||||
wantArgs: []any{`["a"]`, `["a"]`},
|
||||
},
|
||||
{
|
||||
input: func() sql.Querier {
|
||||
u := sql.Dialect(dialect.Postgres).Update("t")
|
||||
sqljson.Append(u, "c", []string{"a"}, sqljson.Path("a"))
|
||||
return u
|
||||
}(),
|
||||
wantQuery: `UPDATE "t" SET "c" = CASE WHEN (("c"->'a')::jsonb IS NULL OR ("c"->'a')::jsonb = 'null'::jsonb) THEN jsonb_set("c", '{a}', $1, true) ELSE jsonb_set("c", '{a}', "c"->'a' || $2, true) END`,
|
||||
wantArgs: []any{`["a"]`, `["a"]`},
|
||||
},
|
||||
{
|
||||
input: func() sql.Querier {
|
||||
u := sql.Dialect(dialect.SQLite).Update("t")
|
||||
sqljson.Append(u, "c", []string{"a"})
|
||||
return u
|
||||
}(),
|
||||
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$') IS NULL OR JSON_TYPE(`c`, '$') = 'null') THEN JSON_ARRAY(?) ELSE JSON_INSERT(`c`, '$[#]', ?) END",
|
||||
wantArgs: []any{"a", "a"},
|
||||
},
|
||||
{
|
||||
input: func() sql.Querier {
|
||||
u := sql.Dialect(dialect.SQLite).Update("t")
|
||||
sqljson.Append(u, "c", []string{"a"}, sqljson.Path("a"))
|
||||
return u
|
||||
}(),
|
||||
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(`c`, '$.a') IS NULL OR JSON_TYPE(`c`, '$.a') = 'null') THEN JSON_SET(`c`, '$.a', JSON_ARRAY(?)) ELSE JSON_INSERT(`c`, '$.a[#]', ?) END",
|
||||
wantArgs: []any{"a", "a"},
|
||||
},
|
||||
{
|
||||
input: func() sql.Querier {
|
||||
u := sql.Dialect(dialect.MySQL).Update("t")
|
||||
sqljson.Append(u, "c", []string{"a"})
|
||||
return u
|
||||
}(),
|
||||
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(JSON_EXTRACT(`c`, '$')) IS NULL OR JSON_TYPE(JSON_EXTRACT(`c`, '$')) = 'NULL') THEN JSON_ARRAY(?) ELSE JSON_ARRAY_APPEND(`c`, '$', ?) END",
|
||||
wantArgs: []any{"a", "a"},
|
||||
},
|
||||
{
|
||||
input: func() sql.Querier {
|
||||
u := sql.Dialect(dialect.MySQL).Update("t")
|
||||
sqljson.Append(u, "c", []string{"a"}, sqljson.Path("a"))
|
||||
return u
|
||||
}(),
|
||||
wantQuery: "UPDATE `t` SET `c` = CASE WHEN (JSON_TYPE(JSON_EXTRACT(`c`, '$.a')) IS NULL OR JSON_TYPE(JSON_EXTRACT(`c`, '$.a')) = 'NULL') THEN JSON_SET(`c`, '$.a', JSON_ARRAY(?)) ELSE JSON_ARRAY_APPEND(`c`, '$.a', ?) END",
|
||||
wantArgs: []any{"a", "a"},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
query, args := tt.input.Query()
|
||||
require.Equal(t, tt.wantQuery, query)
|
||||
require.Equal(t, tt.wantArgs, args)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user