dialect/sql/sqlgraph: better support for update nodes with predicates (#2574)

This commit is contained in:
Ariel Mashraki
2022-05-29 16:23:52 +03:00
committed by GitHub
parent 0917701f91
commit 5b81d7d832
7 changed files with 116 additions and 21 deletions

View File

@@ -19,7 +19,7 @@ import (
"entgo.io/ent/schema/field"
)
// Rel is a relation type of an edge.
// Rel is an edge relation type.
type Rel int
// Relation types.
@@ -671,10 +671,22 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
return err
}
if !update.Empty() {
var res sql.Result
query, args := update.Query()
if err := tx.Exec(ctx, query, args, nil); err != nil {
if err := tx.Exec(ctx, query, args, &res); err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
// In case there are zero affected rows by this statement, we need to distinguish
// between the case of "record was not found" and "record was not changed".
if affected == 0 && u.Predicate != nil {
if err := u.ensureExists(ctx); err != nil {
return err
}
}
}
if err := u.setExternalEdges(ctx, []driver.Value{id}, addEdges, clearEdges); err != nil {
return err
@@ -686,10 +698,9 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
}
selector := u.builder.Select(u.Node.Columns...).
From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).
// Skip adding the custom predicates that were attached to the updater
// as they may point to columns that were changed by the UPDATE statement.
Where(sql.EQ(u.Node.ID.Column, u.Node.ID.Value))
if pred := u.Predicate; pred != nil {
pred(selector)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
@@ -858,6 +869,25 @@ func (u *updater) scan(rows *sql.Rows) error {
return nil
}
func (u *updater) ensureExists(ctx context.Context) error {
exists := u.builder.Select().From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).Where(sql.EQ(u.Node.ID.Column, u.Node.ID.Value))
u.Predicate(exists)
query, args := u.builder.SelectExpr(sql.Exists(exists)).Query()
rows := &sql.Rows{}
if err := u.tx.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
found, err := sql.ScanBool(rows)
if err != nil {
return err
}
if !found {
return &NotFoundError{table: u.Node.Table, id: u.Node.ID.Value}
}
return nil
}
type creator struct {
graph
*CreateSpec