dialect/sql/sqlgraph: avoid query on update when it's not needed (#932)

Closed #909
This commit is contained in:
Ariel Mashraki
2020-11-10 10:01:40 +02:00
committed by GitHub
parent 762df65f11
commit e775227a11
3 changed files with 67 additions and 41 deletions

View File

@@ -659,28 +659,35 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error
ids []driver.Value
addEdges = EdgeSpecs(u.Edges.Add).GroupRel()
clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
multiple = u.hasExternalEdges(addEdges, clearEdges)
update = u.builder.Update(u.Node.Table)
selector = u.builder.Select(u.Node.ID.Column).
From(u.builder.Table(u.Node.Table))
)
selector := u.builder.Select(u.Node.ID.Column).
From(u.builder.Table(u.Node.Table))
if pred := u.Predicate; pred != nil {
pred(selector)
}
query, args := selector.Query()
rows := &sql.Rows{}
if err := u.tx.Query(ctx, query, args, rows); err != nil {
return 0, fmt.Errorf("querying table %s: %v", u.Node.Table, err)
// If this change-set contains multiple table updates.
if multiple {
query, args := selector.Query()
rows := &sql.Rows{}
if err := u.tx.Query(ctx, query, args, rows); err != nil {
return 0, fmt.Errorf("querying table %s: %v", u.Node.Table, err)
}
defer rows.Close()
if err := sql.ScanSlice(rows, &ids); err != nil {
return 0, fmt.Errorf("scan node ids: %v", err)
}
if err := rows.Close(); err != nil {
return 0, err
}
if len(ids) == 0 {
return 0, nil
}
update.Where(matchID(u.Node.ID.Column, ids))
} else {
update.FromSelect(selector)
}
defer rows.Close()
if err := sql.ScanSlice(rows, &ids); err != nil {
return 0, fmt.Errorf("scan node ids: %v", err)
}
if err := rows.Close(); err != nil {
return 0, err
}
if len(ids) == 0 {
return 0, nil
}
update := u.builder.Update(u.Node.Table).Where(matchID(u.Node.ID.Column, ids))
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
return 0, err
}
@@ -690,6 +697,13 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error
if err := tx.Exec(ctx, query, args, &res); err != nil {
return 0, err
}
if !multiple {
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
return int(affected), nil
}
}
if err := u.setExternalEdges(ctx, ids, addEdges, clearEdges); err != nil {
return 0, err
@@ -713,6 +727,23 @@ func (u *updater) setExternalEdges(ctx context.Context, ids []driver.Value, addE
return nil
}
func (*updater) hasExternalEdges(addEdges, clearEdges map[Rel][]*EdgeSpec) bool {
// M2M edges reside in a join-table, and O2M edges reside
// in the M2O table (the entity that holds the FK).
if len(clearEdges[M2M]) > 0 || len(addEdges[M2M]) > 0 ||
len(clearEdges[O2M]) > 0 || len(addEdges[O2M]) > 0 {
return true
}
for _, edges := range [][]*EdgeSpec{clearEdges[O2O], addEdges[O2O]} {
for _, e := range edges {
if !e.Inverse {
return true
}
}
}
return false
}
// setTableColumns sets the table columns and foreign_keys used in insert.
func (u *updater) setTableColumns(update *sql.UpdateBuilder, addEdges, clearEdges map[Rel][]*EdgeSpec) error {
// Avoid multiple assignments to the same column.