dialect/sql/sqlgraph: support custom modifiers in UPDATE commands

This commit is contained in:
Ariel Mashraki
2022-08-05 11:45:30 +03:00
committed by Ariel Mashraki
parent 9f481d8716
commit 0fd641333c
2 changed files with 59 additions and 1 deletions

View File

@@ -134,7 +134,6 @@ func Edge(rel Rel, inverse bool, table string, columns ...string) StepOption {
// To("table", "pk"),
// Edge("name", O2M, "fk"),
// )
//
func NewStep(opts ...StepOption) *Step {
s := &Step{}
for _, opt := range opts {
@@ -409,6 +408,7 @@ type (
Edges EdgeMut
Fields FieldMut
Predicate func(*sql.Selector)
Modifiers []func(*sql.UpdateBuilder)
ScanValues func(columns []string) ([]interface{}, error)
Assign func(columns []string, values []interface{}) error
@@ -686,6 +686,12 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
return err
}
for _, m := range u.Modifiers {
m(update)
}
if err := update.Err(); err != nil {
return err
}
if !update.Empty() {
var res sql.Result
query, args := update.Query()
@@ -793,6 +799,12 @@ func (u *updater) nodes(ctx context.Context, drv dialect.Driver) (int, error) {
}
func (u *updater) updateTable(ctx context.Context, stmt *sql.UpdateBuilder) (int, error) {
for _, m := range u.Modifiers {
m(stmt)
}
if err := stmt.Err(); err != nil {
return 0, err
}
if stmt.Empty() {
return 0, nil
}

View File

@@ -1571,6 +1571,33 @@ func TestUpdateNode(t *testing.T) {
},
wantUser: &user{name: "Ariel", age: 30, id: 1},
},
{
name: "fields/set_modifier",
spec: &UpdateSpec{
Node: &NodeSpec{
Table: "users",
Columns: []string{"id", "name", "age"},
ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
},
Modifiers: []func(*sql.UpdateBuilder){
func(u *sql.UpdateBuilder) {
u.Set("name", sql.Expr(sql.Lower("name")))
},
},
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec(escape("UPDATE `users` SET `name` = LOWER(`name`) WHERE `id` = ?")).
WithArgs(1).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
AddRow(1, 30, "Ariel"))
mock.ExpectCommit()
},
wantUser: &user{name: "Ariel", age: 30, id: 1},
},
{
name: "fields/add_set_clear",
spec: &UpdateSpec{
@@ -1909,6 +1936,25 @@ func TestUpdateNodes(t *testing.T) {
},
wantAffected: 1,
},
{
name: "with modifier",
spec: &UpdateSpec{
Node: &NodeSpec{
Table: "users",
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
},
Modifiers: []func(*sql.UpdateBuilder){
func(u *sql.UpdateBuilder) {
u.Set("id", sql.Expr("id + 1")).OrderBy("id")
},
},
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(escape("UPDATE `users` SET `id` = id + 1 ORDER BY `id`")).
WillReturnResult(sqlmock.NewResult(0, 1))
},
wantAffected: 1,
},
{
name: "own_fks/m2o_o2o_inverse",
spec: &UpdateSpec{