mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/sqlgraph: support custom modifiers in UPDATE commands
This commit is contained in:
committed by
Ariel Mashraki
parent
9f481d8716
commit
0fd641333c
@@ -134,7 +134,6 @@ func Edge(rel Rel, inverse bool, table string, columns ...string) StepOption {
|
||||
// To("table", "pk"),
|
||||
// Edge("name", O2M, "fk"),
|
||||
// )
|
||||
//
|
||||
func NewStep(opts ...StepOption) *Step {
|
||||
s := &Step{}
|
||||
for _, opt := range opts {
|
||||
@@ -409,6 +408,7 @@ type (
|
||||
Edges EdgeMut
|
||||
Fields FieldMut
|
||||
Predicate func(*sql.Selector)
|
||||
Modifiers []func(*sql.UpdateBuilder)
|
||||
|
||||
ScanValues func(columns []string) ([]interface{}, error)
|
||||
Assign func(columns []string, values []interface{}) error
|
||||
@@ -686,6 +686,12 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, m := range u.Modifiers {
|
||||
m(update)
|
||||
}
|
||||
if err := update.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if !update.Empty() {
|
||||
var res sql.Result
|
||||
query, args := update.Query()
|
||||
@@ -793,6 +799,12 @@ func (u *updater) nodes(ctx context.Context, drv dialect.Driver) (int, error) {
|
||||
}
|
||||
|
||||
func (u *updater) updateTable(ctx context.Context, stmt *sql.UpdateBuilder) (int, error) {
|
||||
for _, m := range u.Modifiers {
|
||||
m(stmt)
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if stmt.Empty() {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -1571,6 +1571,33 @@ func TestUpdateNode(t *testing.T) {
|
||||
},
|
||||
wantUser: &user{name: "Ariel", age: 30, id: 1},
|
||||
},
|
||||
{
|
||||
name: "fields/set_modifier",
|
||||
spec: &UpdateSpec{
|
||||
Node: &NodeSpec{
|
||||
Table: "users",
|
||||
Columns: []string{"id", "name", "age"},
|
||||
ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
|
||||
},
|
||||
Modifiers: []func(*sql.UpdateBuilder){
|
||||
func(u *sql.UpdateBuilder) {
|
||||
u.Set("name", sql.Expr(sql.Lower("name")))
|
||||
},
|
||||
},
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `name` = LOWER(`name`) WHERE `id` = ?")).
|
||||
WithArgs(1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
|
||||
AddRow(1, 30, "Ariel"))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
wantUser: &user{name: "Ariel", age: 30, id: 1},
|
||||
},
|
||||
{
|
||||
name: "fields/add_set_clear",
|
||||
spec: &UpdateSpec{
|
||||
@@ -1909,6 +1936,25 @@ func TestUpdateNodes(t *testing.T) {
|
||||
},
|
||||
wantAffected: 1,
|
||||
},
|
||||
{
|
||||
name: "with modifier",
|
||||
spec: &UpdateSpec{
|
||||
Node: &NodeSpec{
|
||||
Table: "users",
|
||||
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
|
||||
},
|
||||
Modifiers: []func(*sql.UpdateBuilder){
|
||||
func(u *sql.UpdateBuilder) {
|
||||
u.Set("id", sql.Expr("id + 1")).OrderBy("id")
|
||||
},
|
||||
},
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `id` = id + 1 ORDER BY `id`")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
},
|
||||
wantAffected: 1,
|
||||
},
|
||||
{
|
||||
name: "own_fks/m2o_o2o_inverse",
|
||||
spec: &UpdateSpec{
|
||||
|
||||
Reference in New Issue
Block a user