dialect/sql/sqlgraph: avoid creating tx blocks for single UPDATE statements

In PostgreSQL, every statement is executed within a transaction. Therefore, we can avoid
creating transaction blocks manually (group of statements surrounded by BEGIN and COMMIT)
for UpdateNodes operation with a single UPDATE statement.

Benchmark for 2000 operations was improved from:

	7.98s      3992160 ns/op    4887 B/op    116 allocs/op

To:

	4.42s      2209659 ns/op    4435 B/op    104 allocs/op

---

MySQL and SQLite share the same behavior. Please see #1858 for more info.
This commit is contained in:
Ariel Mashraki
2021-09-11 08:12:50 +03:00
committed by Ariel Mashraki
parent 52fa73a0d5
commit 0864659844
2 changed files with 91 additions and 42 deletions

View File

@@ -439,17 +439,9 @@ func UpdateNode(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) error
// UpdateNodes applies the UpdateSpec on a set of nodes in the graph.
func UpdateNodes(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) (int, error) {
tx, err := drv.Tx(ctx)
if err != nil {
return 0, err
}
gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())}
gr := graph{tx: drv, builder: sql.Dialect(drv.Dialect())}
cr := &updater{UpdateSpec: spec, graph: gr}
affected, err := cr.nodes(ctx, tx)
if err != nil {
return 0, rollback(tx, err)
}
return affected, tx.Commit()
return cr.nodes(ctx, drv)
}
// NotFoundError returns when trying to update an
@@ -706,9 +698,8 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
return u.scan(rows)
}
func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error) {
func (u *updater) nodes(ctx context.Context, drv dialect.Driver) (int, error) {
var (
ids []driver.Value
addEdges = EdgeSpecs(u.Edges.Add).GroupRel()
clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
multiple = hasExternalEdges(addEdges, clearEdges)
@@ -717,13 +708,28 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error
From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).
WithContext(ctx)
)
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
return 0, err
}
if pred := u.Predicate; pred != nil {
pred(selector)
}
// If this change-set contains multiple table updates.
if multiple {
query, args := selector.Query()
rows := &sql.Rows{}
// In case of single statement update, avoid opening a transaction manually.
if !multiple {
update.FromSelect(selector)
return u.updateTable(ctx, update)
}
tx, err := drv.Tx(ctx)
if err != nil {
return 0, err
}
u.tx = tx
affected, err := func() (int, error) {
var (
ids []driver.Value
rows = &sql.Rows{}
query, args = selector.Query()
)
if err := u.tx.Query(ctx, query, args, rows); err != nil {
return 0, fmt.Errorf("querying table %s: %w", u.Node.Table, err)
}
@@ -738,32 +744,39 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error
return 0, nil
}
update.Where(matchID(u.Node.ID.Column, ids))
} else {
update.FromSelect(selector)
}
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
return 0, err
}
if !update.Empty() {
var res sql.Result
query, args := update.Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
// In case of multi statement update, that change can
// affect more than 1 table, and therefore, we return
// the list of ids as number of affected records.
if _, err := u.updateTable(ctx, update); err != nil {
return 0, err
}
if !multiple {
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
return int(affected), nil
}
}
if len(ids) > 0 {
if err := u.setExternalEdges(ctx, ids, addEdges, clearEdges); err != nil {
return 0, err
}
return len(ids), nil
}()
if err != nil {
return 0, rollback(tx, err)
}
return len(ids), nil
return affected, tx.Commit()
}
func (u *updater) updateTable(ctx context.Context, stmt *sql.UpdateBuilder) (int, error) {
if stmt.Empty() {
return 0, nil
}
var (
res sql.Result
query, args = stmt.Query()
)
if err := u.tx.Exec(ctx, query, args, &res); err != nil {
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
return int(affected), nil
}
func (u *updater) setExternalEdges(ctx context.Context, ids []driver.Value, addEdges, clearEdges map[Rel][]*EdgeSpec) error {

View File

@@ -1637,12 +1637,10 @@ func TestUpdateNodes(t *testing.T) {
},
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
// Apply field changes.
mock.ExpectExec(escape("UPDATE `users` SET `age` = ?, `name` = ?")).
WithArgs(30, "Ariel").
WillReturnResult(sqlmock.NewResult(0, 2))
mock.ExpectCommit()
},
wantAffected: 2,
},
@@ -1664,12 +1662,10 @@ func TestUpdateNodes(t *testing.T) {
},
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
// Clear fields.
mock.ExpectExec(escape("UPDATE `users` SET `age` = NULL, `name` = NULL WHERE `name` = ?")).
WithArgs("a8m").
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
wantAffected: 1,
},
@@ -1692,15 +1688,55 @@ func TestUpdateNodes(t *testing.T) {
},
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
// Clear "car" and "workplace" foreign_keys and add "card" and a "parent".
mock.ExpectExec(escape("UPDATE `users` SET `workplace_id` = NULL, `car_id` = NULL, `parent_id` = ?, `card_id` = ?")).
WithArgs(4, 3).
WillReturnResult(sqlmock.NewResult(0, 3))
mock.ExpectCommit()
},
wantAffected: 3,
},
{
name: "o2m",
spec: &UpdateSpec{
Node: &NodeSpec{
Table: "users",
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
},
Fields: FieldMut{
Add: []*FieldSpec{
{Column: "version", Type: field.TypeInt, Value: 1},
},
},
Edges: EdgeMut{
Clear: []*EdgeSpec{
{Rel: O2M, Table: "cards", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{20, 30}, IDSpec: &FieldSpec{Column: "id"}}},
},
Add: []*EdgeSpec{
{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{40}, IDSpec: &FieldSpec{Column: "id"}}},
},
},
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
// Get all node ids first.
mock.ExpectQuery(escape("SELECT `id` FROM `users`")).
WillReturnRows(sqlmock.NewRows([]string{"id"}).
AddRow(10))
mock.ExpectExec(escape("UPDATE `users` SET `version` = COALESCE(`version`, ?) + ? WHERE `id` = ?")).
WithArgs(0, 1, 10).
WillReturnResult(sqlmock.NewResult(0, 1))
// Clear "owner_id" column in the "cards" table.
mock.ExpectExec(escape("UPDATE `cards` SET `owner_id` = NULL WHERE `id` IN (?, ?) AND `owner_id` = ?")).
WithArgs(20, 30, 10).
WillReturnResult(sqlmock.NewResult(0, 2))
// Set "owner_id" column in the "pets" table.
mock.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE `id` = ? AND `owner_id` IS NULL")).
WithArgs(10, 40).
WillReturnResult(sqlmock.NewResult(0, 2))
mock.ExpectCommit()
},
wantAffected: 1,
},
{
name: "m2m_one",
spec: &UpdateSpec{