mirror of
https://github.com/ent/ent.git
synced 2026-04-29 22:20:54 +03:00
dialect/sql/sqlgraph: avoid query on update when it's not needed (#932)
Closed #909
This commit is contained in:
@@ -659,28 +659,35 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error
|
||||
ids []driver.Value
|
||||
addEdges = EdgeSpecs(u.Edges.Add).GroupRel()
|
||||
clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
|
||||
multiple = u.hasExternalEdges(addEdges, clearEdges)
|
||||
update = u.builder.Update(u.Node.Table)
|
||||
selector = u.builder.Select(u.Node.ID.Column).
|
||||
From(u.builder.Table(u.Node.Table))
|
||||
)
|
||||
selector := u.builder.Select(u.Node.ID.Column).
|
||||
From(u.builder.Table(u.Node.Table))
|
||||
if pred := u.Predicate; pred != nil {
|
||||
pred(selector)
|
||||
}
|
||||
query, args := selector.Query()
|
||||
rows := &sql.Rows{}
|
||||
if err := u.tx.Query(ctx, query, args, rows); err != nil {
|
||||
return 0, fmt.Errorf("querying table %s: %v", u.Node.Table, err)
|
||||
// If this change-set contains multiple table updates.
|
||||
if multiple {
|
||||
query, args := selector.Query()
|
||||
rows := &sql.Rows{}
|
||||
if err := u.tx.Query(ctx, query, args, rows); err != nil {
|
||||
return 0, fmt.Errorf("querying table %s: %v", u.Node.Table, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if err := sql.ScanSlice(rows, &ids); err != nil {
|
||||
return 0, fmt.Errorf("scan node ids: %v", err)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
update.Where(matchID(u.Node.ID.Column, ids))
|
||||
} else {
|
||||
update.FromSelect(selector)
|
||||
}
|
||||
defer rows.Close()
|
||||
if err := sql.ScanSlice(rows, &ids); err != nil {
|
||||
return 0, fmt.Errorf("scan node ids: %v", err)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
update := u.builder.Update(u.Node.Table).Where(matchID(u.Node.ID.Column, ids))
|
||||
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -690,6 +697,13 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !multiple {
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(affected), nil
|
||||
}
|
||||
}
|
||||
if err := u.setExternalEdges(ctx, ids, addEdges, clearEdges); err != nil {
|
||||
return 0, err
|
||||
@@ -713,6 +727,23 @@ func (u *updater) setExternalEdges(ctx context.Context, ids []driver.Value, addE
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*updater) hasExternalEdges(addEdges, clearEdges map[Rel][]*EdgeSpec) bool {
|
||||
// M2M edges reside in a join-table, and O2M edges reside
|
||||
// in the M2O table (the entity that holds the FK).
|
||||
if len(clearEdges[M2M]) > 0 || len(addEdges[M2M]) > 0 ||
|
||||
len(clearEdges[O2M]) > 0 || len(addEdges[O2M]) > 0 {
|
||||
return true
|
||||
}
|
||||
for _, edges := range [][]*EdgeSpec{clearEdges[O2O], addEdges[O2O]} {
|
||||
for _, e := range edges {
|
||||
if !e.Inverse {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// setTableColumns sets the table columns and foreign_keys used in insert.
|
||||
func (u *updater) setTableColumns(update *sql.UpdateBuilder, addEdges, clearEdges map[Rel][]*EdgeSpec) error {
|
||||
// Avoid multiple assignments to the same column.
|
||||
|
||||
@@ -1102,14 +1102,9 @@ func TestUpdateNodes(t *testing.T) {
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
// Get all node ids first.
|
||||
mock.ExpectQuery(escape("SELECT `id` FROM `users`")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).
|
||||
AddRow(1).
|
||||
AddRow(2))
|
||||
// Apply field changes.
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `age` = ?, `name` = ? WHERE `id` IN (?, ?)")).
|
||||
WithArgs(30, "Ariel", 1, 2).
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `age` = ?, `name` = ?")).
|
||||
WithArgs(30, "Ariel").
|
||||
WillReturnResult(sqlmock.NewResult(0, 2))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
@@ -1134,14 +1129,9 @@ func TestUpdateNodes(t *testing.T) {
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
// Get all node ids first.
|
||||
mock.ExpectQuery(escape("SELECT `id` FROM `users` WHERE `name` = ?")).
|
||||
WithArgs("a8m").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).
|
||||
AddRow(1))
|
||||
// Clear fields.
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `age` = NULL, `name` = NULL WHERE `id` = ?")).
|
||||
WithArgs(1).
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `age` = NULL, `name` = NULL WHERE `name` = ?")).
|
||||
WithArgs("a8m").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
@@ -1167,17 +1157,13 @@ func TestUpdateNodes(t *testing.T) {
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
// Get all node ids first.
|
||||
mock.ExpectQuery(escape("SELECT `id` FROM `users`")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).
|
||||
AddRow(1))
|
||||
// Clear "car" and "workplace" foreign_keys and add "card" and a "parent".
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `workplace_id` = NULL, `car_id` = NULL, `parent_id` = ?, `card_id` = ? WHERE `id` = ?")).
|
||||
WithArgs(4, 3, 1).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `workplace_id` = NULL, `car_id` = NULL, `parent_id` = ?, `card_id` = ?")).
|
||||
WithArgs(4, 3).
|
||||
WillReturnResult(sqlmock.NewResult(0, 3))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
wantAffected: 1,
|
||||
wantAffected: 3,
|
||||
},
|
||||
{
|
||||
name: "m2m_one",
|
||||
@@ -1406,5 +1392,5 @@ func escape(query string) string {
|
||||
rows[i] = strings.TrimPrefix(rows[i], " ")
|
||||
}
|
||||
query = strings.Join(rows, " ")
|
||||
return regexp.QuoteMeta(query)
|
||||
return strings.TrimSpace(regexp.QuoteMeta(query)) + "$"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user