mirror of
https://github.com/ent/ent.git
synced 2026-04-28 21:50:56 +03:00
223 lines
5.6 KiB
Go
223 lines
5.6 KiB
Go
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package sqljson
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"entgo.io/ent/dialect"
|
|
"entgo.io/ent/dialect/sql"
|
|
)
|
|
|
|
// sqlite is the stateless SQLite implementation of driver. It has no
// fields, and newDriver hands it out as a nil *sqlite.
type sqlite struct{}
|
|
|
|
// Append implements the driver.Append method.
|
|
func (d *sqlite) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
|
setCase(u, column, when{
|
|
Cond: func(b *sql.Builder) {
|
|
typ := func(b *sql.Builder) *sql.Builder {
|
|
return b.WriteString("JSON_TYPE").Wrap(func(b *sql.Builder) {
|
|
b.Ident(column).Comma()
|
|
identPath(column, opts...).mysqlPath(b)
|
|
})
|
|
}
|
|
typ(b).WriteOp(sql.OpIsNull)
|
|
b.WriteString(" OR ")
|
|
typ(b).WriteOp(sql.OpEQ).WriteString("'null'")
|
|
},
|
|
Then: func(b *sql.Builder) {
|
|
if len(opts) > 0 {
|
|
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
|
|
b.Ident(column).Comma()
|
|
identPath(column, opts...).mysqlPath(b)
|
|
b.Comma().Argf("JSON(?)", marshalArg(elems))
|
|
})
|
|
} else {
|
|
b.Arg(marshalArg(elems))
|
|
}
|
|
},
|
|
Else: func(b *sql.Builder) {
|
|
b.WriteString("JSON_INSERT").Wrap(func(b *sql.Builder) {
|
|
b.Ident(column).Comma()
|
|
// If no path was provided the top-level value is
|
|
// a JSON array. i.e. JSON_INSERT(c, '$[#]', ?).
|
|
path := func(b *sql.Builder) { b.WriteString("'$[#]'") }
|
|
if len(opts) > 0 {
|
|
p := identPath(column, opts...)
|
|
p.Path = append(p.Path, "[#]")
|
|
path = p.mysqlPath
|
|
}
|
|
for i, e := range elems {
|
|
if i > 0 {
|
|
b.Comma()
|
|
}
|
|
path(b)
|
|
b.Comma()
|
|
d.appendArg(b, e)
|
|
}
|
|
})
|
|
},
|
|
})
|
|
}
|
|
|
|
func (d *sqlite) appendArg(b *sql.Builder, v any) {
|
|
switch {
|
|
case !isPrimitive(v):
|
|
b.Argf("JSON(?)", marshalArg(v))
|
|
default:
|
|
b.Arg(v)
|
|
}
|
|
}
|
|
|
|
// mysql is the stateless MySQL implementation of driver. It has no fields,
// and newDriver hands it out as a nil *mysql.
type mysql struct{}
|
|
|
|
// Append implements the driver.Append method.
|
|
func (d *mysql) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
|
setCase(u, column, when{
|
|
Cond: func(b *sql.Builder) {
|
|
typ := func(b *sql.Builder) *sql.Builder {
|
|
b.WriteString("JSON_TYPE(JSON_EXTRACT(")
|
|
b.Ident(column).Comma()
|
|
identPath(column, opts...).mysqlPath(b)
|
|
return b.WriteString("))")
|
|
}
|
|
typ(b).WriteOp(sql.OpIsNull)
|
|
b.WriteString(" OR ")
|
|
typ(b).WriteOp(sql.OpEQ).WriteString("'NULL'")
|
|
},
|
|
Then: func(b *sql.Builder) {
|
|
if len(opts) > 0 {
|
|
b.WriteString("JSON_SET").Wrap(func(b *sql.Builder) {
|
|
b.Ident(column).Comma()
|
|
identPath(column, opts...).mysqlPath(b)
|
|
b.Comma().WriteString("JSON_ARRAY(").Args(d.marshalArgs(elems)...).WriteByte(')')
|
|
})
|
|
} else {
|
|
b.WriteString("JSON_ARRAY(").Args(d.marshalArgs(elems)...).WriteByte(')')
|
|
}
|
|
},
|
|
Else: func(b *sql.Builder) {
|
|
b.WriteString("JSON_ARRAY_APPEND").Wrap(func(b *sql.Builder) {
|
|
b.Ident(column).Comma()
|
|
for i, e := range elems {
|
|
if i > 0 {
|
|
b.Comma()
|
|
}
|
|
identPath(column, opts...).mysqlPath(b)
|
|
b.Comma()
|
|
d.appendArg(b, e)
|
|
}
|
|
})
|
|
},
|
|
})
|
|
}
|
|
|
|
func (d *mysql) marshalArgs(args []any) []any {
|
|
vs := make([]any, len(args))
|
|
for i, v := range args {
|
|
if !isPrimitive(v) {
|
|
v = marshalArg(v)
|
|
}
|
|
vs[i] = v
|
|
}
|
|
return vs
|
|
}
|
|
|
|
func (d *mysql) appendArg(b *sql.Builder, v any) {
|
|
switch {
|
|
case !isPrimitive(v):
|
|
b.Argf("CAST(? AS JSON)", marshalArg(v))
|
|
default:
|
|
b.Arg(v)
|
|
}
|
|
}
|
|
|
|
// postgres is the stateless PostgreSQL implementation of driver. It has no
// fields, and newDriver hands it out as a nil *postgres.
type postgres struct{}
|
|
|
|
// Append implements the driver.Append method.
|
|
func (*postgres) Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option) {
|
|
setCase(u, column, when{
|
|
Cond: func(b *sql.Builder) {
|
|
valuePath(b, column, append(opts, Cast("jsonb"))...)
|
|
b.WriteOp(sql.OpIsNull)
|
|
b.WriteString(" OR ")
|
|
valuePath(b, column, append(opts, Cast("jsonb"))...)
|
|
b.WriteOp(sql.OpEQ).WriteString("'null'::jsonb")
|
|
},
|
|
Then: func(b *sql.Builder) {
|
|
if len(opts) > 0 {
|
|
b.WriteString("jsonb_set").Wrap(func(b *sql.Builder) {
|
|
b.Ident(column).Comma()
|
|
identPath(column, opts...).pgArrayPath(b)
|
|
b.Comma().Arg(marshalArg(elems))
|
|
b.Comma().WriteString("true")
|
|
})
|
|
} else {
|
|
b.Arg(marshalArg(elems))
|
|
}
|
|
},
|
|
Else: func(b *sql.Builder) {
|
|
if len(opts) > 0 {
|
|
b.WriteString("jsonb_set").Wrap(func(b *sql.Builder) {
|
|
b.Ident(column).Comma()
|
|
identPath(column, opts...).pgArrayPath(b)
|
|
b.Comma()
|
|
path := identPath(column, opts...)
|
|
path.value(b)
|
|
b.WriteString(" || ").Arg(marshalArg(elems))
|
|
b.Comma().WriteString("true")
|
|
})
|
|
} else {
|
|
b.Ident(column).WriteString(" || ").Arg(marshalArg(elems))
|
|
}
|
|
},
|
|
})
|
|
}
|
|
|
|
// driver groups all dialect-specific methods. newDriver selects the
// implementation (sqlite, mysql or postgres) by dialect name.
type driver interface {
	// Append sets column on u to a CASE WHEN expression that appends elems
	// to the JSON array stored in column, or at the nested location selected
	// by opts. When that location is missing or holds JSON null, it is set
	// from elems instead of appended to.
	Append(u *sql.UpdateBuilder, column string, elems []any, opts ...Option)
}
|
|
|
|
func newDriver(name string) (driver, error) {
|
|
switch name {
|
|
case dialect.SQLite:
|
|
return (*sqlite)(nil), nil
|
|
case dialect.MySQL:
|
|
return (*mysql)(nil), nil
|
|
case dialect.Postgres:
|
|
return (*postgres)(nil), nil
|
|
default:
|
|
return nil, fmt.Errorf("sqljson: unknown driver %q", name)
|
|
}
|
|
}
|
|
|
|
// when holds the three SQL fragment writers rendered by setCase.
type when struct {
	// Cond writes the CASE WHEN predicate.
	Cond func(*sql.Builder)
	// Then writes the value used when Cond holds.
	Then func(*sql.Builder)
	// Else writes the value used when Cond does not hold.
	Else func(*sql.Builder)
}
|
|
|
|
// setCase sets the column value using the "CASE WHEN" statement.
|
|
// The x defines the condition/predicate, t is the true (if) case,
|
|
// and 'f' defines the false (else).
|
|
func setCase(u *sql.UpdateBuilder, column string, w when) {
|
|
u.Set(column, sql.ExprFunc(func(b *sql.Builder) {
|
|
b.WriteString("CASE WHEN ").Wrap(func(b *sql.Builder) {
|
|
w.Cond(b)
|
|
})
|
|
b.WriteString(" THEN ")
|
|
w.Then(b)
|
|
b.WriteString(" ELSE ")
|
|
w.Else(b)
|
|
b.WriteString(" END")
|
|
}))
|
|
}
|
|
|
|
// isPrimitive reports whether v can be bound directly as a SQL argument.
// Arrays, slices, maps, structs, pointers and interface kinds report false;
// the callers in this file JSON-marshal those values instead.
//
// An untyped nil reports true. reflect.TypeOf(nil) returns a nil Type, and
// calling Kind on it would panic, so nil is passed to the builder unchanged.
func isPrimitive(v any) bool {
	t := reflect.TypeOf(v)
	if t == nil {
		return true
	}
	switch t.Kind() {
	case reflect.Array, reflect.Slice, reflect.Map, reflect.Struct, reflect.Ptr, reflect.Interface:
		return false
	default:
		return true
	}
}
|