dialect/sql/sqljson: add dialect-aware Append function

This commit is contained in:
Ariel Mashraki
2022-09-28 12:07:10 +03:00
committed by Ariel Mashraki
parent 5330f87759
commit eb4ea68356
10 changed files with 422 additions and 47 deletions

View File

@@ -4,4 +4,4 @@
package ent
//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by ent, DO NOT EDIT." ./schema
//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/modifier --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by ent, DO NOT EDIT." ./schema

View File

@@ -53,8 +53,8 @@ type T struct {
B bool `json:"b,omitempty"`
S string `json:"s,omitempty"`
T *T `json:"t,omitempty"`
Li []int `json:"li,omitempty"`
Ls []string `json:"ls,omitempty"`
Li []int `json:"li"`
Ls []string `json:"ls"`
// Do not omit empty or null maps.
M map[string]any `json:"m"`
}

View File

@@ -27,6 +27,7 @@ type UserQuery struct {
order []OrderFunc
fields []string
predicates []predicate.User
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@@ -328,6 +329,9 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
nodes = append(nodes, node)
return node.assignValues(columns, values)
}
if len(uq.modifiers) > 0 {
_spec.Modifiers = uq.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
@@ -342,6 +346,9 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) {
_spec := uq.querySpec()
if len(uq.modifiers) > 0 {
_spec.Modifiers = uq.modifiers
}
_spec.Node.Columns = uq.fields
if len(uq.fields) > 0 {
_spec.Unique = uq.unique != nil && *uq.unique
@@ -423,6 +430,9 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
if uq.unique != nil && *uq.unique {
selector.Distinct()
}
for _, m := range uq.modifiers {
m(selector)
}
for _, p := range uq.predicates {
p(selector)
}
@@ -440,6 +450,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// Modify adds a query modifier for attaching custom logic to queries.
func (uq *UserQuery) Modify(modifiers ...func(s *sql.Selector)) *UserSelect {
uq.modifiers = append(uq.modifiers, modifiers...)
return uq.Select()
}
// UserGroupBy is the group-by builder for User entities.
type UserGroupBy struct {
config
@@ -531,3 +547,9 @@ func (us *UserSelect) sqlScan(ctx context.Context, v any) error {
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Modify adds a query modifier for attaching custom logic to queries.
func (us *UserSelect) Modify(modifiers ...func(s *sql.Selector)) *UserSelect {
us.modifiers = append(us.modifiers, modifiers...)
return us
}

View File

@@ -25,8 +25,9 @@ import (
// UserUpdate is the builder for updating User entities.
type UserUpdate struct {
config
hooks []Hook
mutation *UserMutation
hooks []Hook
mutation *UserMutation
modifiers []func(*sql.UpdateBuilder)
}
// Where appends a list predicates to the UserUpdate builder.
@@ -192,6 +193,12 @@ func (uu *UserUpdate) ExecX(ctx context.Context) {
}
}
// Modify adds a statement modifier for attaching custom logic to the UPDATE statement.
func (uu *UserUpdate) Modify(modifiers ...func(u *sql.UpdateBuilder)) *UserUpdate {
uu.modifiers = append(uu.modifiers, modifiers...)
return uu
}
func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
_spec := &sqlgraph.UpdateSpec{
Node: &sqlgraph.NodeSpec{
@@ -308,6 +315,7 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
Column: user.FieldAddr,
})
}
_spec.Modifiers = uu.modifiers
if n, err = sqlgraph.UpdateNodes(ctx, uu.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@@ -322,9 +330,10 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
// UserUpdateOne is the builder for updating a single User entity.
type UserUpdateOne struct {
config
fields []string
hooks []Hook
mutation *UserMutation
fields []string
hooks []Hook
mutation *UserMutation
modifiers []func(*sql.UpdateBuilder)
}
// SetT sets the "t" field.
@@ -497,6 +506,12 @@ func (uuo *UserUpdateOne) ExecX(ctx context.Context) {
}
}
// Modify adds a statement modifier for attaching custom logic to the UPDATE statement.
func (uuo *UserUpdateOne) Modify(modifiers ...func(u *sql.UpdateBuilder)) *UserUpdateOne {
uuo.modifiers = append(uuo.modifiers, modifiers...)
return uuo
}
func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
_spec := &sqlgraph.UpdateSpec{
Node: &sqlgraph.NodeSpec{
@@ -630,6 +645,7 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error)
Column: user.FieldAddr,
})
}
_spec.Modifiers = uuo.modifiers
_node = &User{config: uuo.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

View File

@@ -47,12 +47,12 @@ func TestMySQL(t *testing.T) {
Dirs(t, client)
Ints(t, client)
Floats(t, client)
Strings(t, client)
NetAddr(t, client)
RawMessage(t, client)
// Skip predicates test for MySQL old versions.
// Skip tests with JSON functions for old MySQL versions.
if version != "56" {
Predicates(t, client)
Strings(t, client)
}
})
}
@@ -183,6 +183,38 @@ func Strings(t *testing.T, client *ent.Client) {
require.Empty(t, usr.Strings)
require.Empty(t, client.User.GetX(ctx, usr.ID).Strings)
require.Zero(t, client.User.Query().Where(user.StringsNotNil()).CountX(ctx))
// Append to an empty array.
usr.Update().SetStrings([]string{}).SetT(&schema.T{Ls: []string{}}).ExecX(ctx)
usr = usr.Update().Modify(func(u *sql.UpdateBuilder) {
sqljson.Append(u, user.FieldStrings, []string{"foo"})
sqljson.Append(u, user.FieldT, []string{"foo"}, sqljson.Path("ls"))
}).SaveX(ctx)
require.Equal(t, []string{"foo"}, usr.Strings)
require.Equal(t, []string{"foo"}, usr.T.Ls)
// Set a 'null' (or an undefined) value.
usr.Update().ClearStrings().ClearT().ExecX(ctx)
usr.Update().SetStrings(nil).SetT(&schema.T{Ls: nil}).ExecX(ctx)
usr = usr.Update().Modify(func(u *sql.UpdateBuilder) {
sqljson.Append(u, user.FieldStrings, []string{"foo"})
sqljson.Append(u, user.FieldT, []string{"foo"}, sqljson.Path("ls"))
}).SaveX(ctx)
require.Equal(t, []string{"foo"}, usr.Strings)
require.Equal(t, []string{"foo"}, usr.T.Ls)
usr = usr.Update().Modify(func(u *sql.UpdateBuilder) {
sqljson.Append(u, user.FieldStrings, []string{"bar", "baz"})
sqljson.Append(u, user.FieldT, []string{"bar", "baz"}, sqljson.Path("ls"))
}).SaveX(ctx)
require.Equal(t, []string{"foo", "bar", "baz"}, usr.Strings)
require.Equal(t, []string{"foo", "bar", "baz"}, usr.T.Ls)
// Set a NULL (or an undefined) value.
usr.Update().ClearStrings().ExecX(ctx)
usr = usr.Update().Modify(func(u *sql.UpdateBuilder) {
sqljson.Append(u, user.FieldStrings, []string{"foo"})
}).SaveX(ctx)
require.Equal(t, []string{"foo"}, usr.Strings)
}
func RawMessage(t *testing.T, client *ent.Client) {