mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/sqljson: add dialect-aware Append function
This commit is contained in:
committed by
Ariel Mashraki
parent
5330f87759
commit
eb4ea68356
@@ -4,4 +4,4 @@
|
||||
|
||||
package ent
|
||||
|
||||
//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by ent, DO NOT EDIT." ./schema
|
||||
//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/modifier --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by ent, DO NOT EDIT." ./schema
|
||||
|
||||
@@ -53,8 +53,8 @@ type T struct {
|
||||
B bool `json:"b,omitempty"`
|
||||
S string `json:"s,omitempty"`
|
||||
T *T `json:"t,omitempty"`
|
||||
Li []int `json:"li,omitempty"`
|
||||
Ls []string `json:"ls,omitempty"`
|
||||
Li []int `json:"li"`
|
||||
Ls []string `json:"ls"`
|
||||
// Do not omit empty or null maps.
|
||||
M map[string]any `json:"m"`
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ type UserQuery struct {
|
||||
order []OrderFunc
|
||||
fields []string
|
||||
predicates []predicate.User
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -328,6 +329,9 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
nodes = append(nodes, node)
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(uq.modifiers) > 0 {
|
||||
_spec.Modifiers = uq.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -342,6 +346,9 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
|
||||
func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := uq.querySpec()
|
||||
if len(uq.modifiers) > 0 {
|
||||
_spec.Modifiers = uq.modifiers
|
||||
}
|
||||
_spec.Node.Columns = uq.fields
|
||||
if len(uq.fields) > 0 {
|
||||
_spec.Unique = uq.unique != nil && *uq.unique
|
||||
@@ -423,6 +430,9 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if uq.unique != nil && *uq.unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range uq.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range uq.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -440,6 +450,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (uq *UserQuery) Modify(modifiers ...func(s *sql.Selector)) *UserSelect {
|
||||
uq.modifiers = append(uq.modifiers, modifiers...)
|
||||
return uq.Select()
|
||||
}
|
||||
|
||||
// UserGroupBy is the group-by builder for User entities.
|
||||
type UserGroupBy struct {
|
||||
config
|
||||
@@ -531,3 +547,9 @@ func (us *UserSelect) sqlScan(ctx context.Context, v any) error {
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (us *UserSelect) Modify(modifiers ...func(s *sql.Selector)) *UserSelect {
|
||||
us.modifiers = append(us.modifiers, modifiers...)
|
||||
return us
|
||||
}
|
||||
|
||||
@@ -25,8 +25,9 @@ import (
|
||||
// UserUpdate is the builder for updating User entities.
|
||||
type UserUpdate struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *UserMutation
|
||||
hooks []Hook
|
||||
mutation *UserMutation
|
||||
modifiers []func(*sql.UpdateBuilder)
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UserUpdate builder.
|
||||
@@ -192,6 +193,12 @@ func (uu *UserUpdate) ExecX(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Modify adds a statement modifier for attaching custom logic to the UPDATE statement.
|
||||
func (uu *UserUpdate) Modify(modifiers ...func(u *sql.UpdateBuilder)) *UserUpdate {
|
||||
uu.modifiers = append(uu.modifiers, modifiers...)
|
||||
return uu
|
||||
}
|
||||
|
||||
func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
|
||||
_spec := &sqlgraph.UpdateSpec{
|
||||
Node: &sqlgraph.NodeSpec{
|
||||
@@ -308,6 +315,7 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
|
||||
Column: user.FieldAddr,
|
||||
})
|
||||
}
|
||||
_spec.Modifiers = uu.modifiers
|
||||
if n, err = sqlgraph.UpdateNodes(ctx, uu.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{user.Label}
|
||||
@@ -322,9 +330,10 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
|
||||
// UserUpdateOne is the builder for updating a single User entity.
|
||||
type UserUpdateOne struct {
|
||||
config
|
||||
fields []string
|
||||
hooks []Hook
|
||||
mutation *UserMutation
|
||||
fields []string
|
||||
hooks []Hook
|
||||
mutation *UserMutation
|
||||
modifiers []func(*sql.UpdateBuilder)
|
||||
}
|
||||
|
||||
// SetT sets the "t" field.
|
||||
@@ -497,6 +506,12 @@ func (uuo *UserUpdateOne) ExecX(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Modify adds a statement modifier for attaching custom logic to the UPDATE statement.
|
||||
func (uuo *UserUpdateOne) Modify(modifiers ...func(u *sql.UpdateBuilder)) *UserUpdateOne {
|
||||
uuo.modifiers = append(uuo.modifiers, modifiers...)
|
||||
return uuo
|
||||
}
|
||||
|
||||
func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
||||
_spec := &sqlgraph.UpdateSpec{
|
||||
Node: &sqlgraph.NodeSpec{
|
||||
@@ -630,6 +645,7 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error)
|
||||
Column: user.FieldAddr,
|
||||
})
|
||||
}
|
||||
_spec.Modifiers = uuo.modifiers
|
||||
_node = &User{config: uuo.config}
|
||||
_spec.Assign = _node.assignValues
|
||||
_spec.ScanValues = _node.scanValues
|
||||
|
||||
@@ -47,12 +47,12 @@ func TestMySQL(t *testing.T) {
|
||||
Dirs(t, client)
|
||||
Ints(t, client)
|
||||
Floats(t, client)
|
||||
Strings(t, client)
|
||||
NetAddr(t, client)
|
||||
RawMessage(t, client)
|
||||
// Skip predicates test for MySQL old versions.
|
||||
// Skip tests with JSON functions for old MySQL versions.
|
||||
if version != "56" {
|
||||
Predicates(t, client)
|
||||
Strings(t, client)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -183,6 +183,38 @@ func Strings(t *testing.T, client *ent.Client) {
|
||||
require.Empty(t, usr.Strings)
|
||||
require.Empty(t, client.User.GetX(ctx, usr.ID).Strings)
|
||||
require.Zero(t, client.User.Query().Where(user.StringsNotNil()).CountX(ctx))
|
||||
|
||||
// Append to an empty array.
|
||||
usr.Update().SetStrings([]string{}).SetT(&schema.T{Ls: []string{}}).ExecX(ctx)
|
||||
usr = usr.Update().Modify(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, user.FieldStrings, []string{"foo"})
|
||||
sqljson.Append(u, user.FieldT, []string{"foo"}, sqljson.Path("ls"))
|
||||
}).SaveX(ctx)
|
||||
require.Equal(t, []string{"foo"}, usr.Strings)
|
||||
require.Equal(t, []string{"foo"}, usr.T.Ls)
|
||||
|
||||
// Set a 'null' (or an undefined) value.
|
||||
usr.Update().ClearStrings().ClearT().ExecX(ctx)
|
||||
usr.Update().SetStrings(nil).SetT(&schema.T{Ls: nil}).ExecX(ctx)
|
||||
usr = usr.Update().Modify(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, user.FieldStrings, []string{"foo"})
|
||||
sqljson.Append(u, user.FieldT, []string{"foo"}, sqljson.Path("ls"))
|
||||
}).SaveX(ctx)
|
||||
require.Equal(t, []string{"foo"}, usr.Strings)
|
||||
require.Equal(t, []string{"foo"}, usr.T.Ls)
|
||||
usr = usr.Update().Modify(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, user.FieldStrings, []string{"bar", "baz"})
|
||||
sqljson.Append(u, user.FieldT, []string{"bar", "baz"}, sqljson.Path("ls"))
|
||||
}).SaveX(ctx)
|
||||
require.Equal(t, []string{"foo", "bar", "baz"}, usr.Strings)
|
||||
require.Equal(t, []string{"foo", "bar", "baz"}, usr.T.Ls)
|
||||
|
||||
// Set a NULL (or an undefined) value.
|
||||
usr.Update().ClearStrings().ExecX(ctx)
|
||||
usr = usr.Update().Modify(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, user.FieldStrings, []string{"foo"})
|
||||
}).SaveX(ctx)
|
||||
require.Equal(t, []string{"foo"}, usr.Strings)
|
||||
}
|
||||
|
||||
func RawMessage(t *testing.T, client *ent.Client) {
|
||||
|
||||
Reference in New Issue
Block a user