dialect/sql/sqlgraph: better support for update nodes with predicates (#2574)

This commit is contained in:
Ariel Mashraki
2022-05-29 16:23:52 +03:00
committed by GitHub
parent 0917701f91
commit 5b81d7d832
7 changed files with 116 additions and 21 deletions

View File

@@ -36,7 +36,7 @@ func ScanOne(rows ColumnScanner, v interface{}) error {
return rows.Err()
}
// ScanInt64 scans and returns an int64 from the rows columns.
// ScanInt64 scans and returns an int64 from the rows.
func ScanInt64(rows ColumnScanner) (int64, error) {
var n int64
if err := ScanOne(rows, &n); err != nil {
@@ -45,7 +45,7 @@ func ScanInt64(rows ColumnScanner) (int64, error) {
return n, nil
}
// ScanInt scans and returns an int from the rows columns.
// ScanInt scans and returns an int from the rows.
func ScanInt(rows ColumnScanner) (int, error) {
n, err := ScanInt64(rows)
if err != nil {
@@ -54,7 +54,16 @@ func ScanInt(rows ColumnScanner) (int, error) {
return int(n), nil
}
// ScanString scans and returns a string from the rows columns.
// ScanBool scans and returns a boolean from the rows.
func ScanBool(rows ColumnScanner) (bool, error) {
var b bool
if err := ScanOne(rows, &b); err != nil {
return false, err
}
return b, nil
}
// ScanString scans and returns a string from the rows.
func ScanString(rows ColumnScanner) (string, error) {
var s string
if err := ScanOne(rows, &s); err != nil {
@@ -63,7 +72,7 @@ func ScanString(rows ColumnScanner) (string, error) {
return s, nil
}
// ScanValue scans and returns a driver.Value from the rows columns.
// ScanValue scans and returns a driver.Value from the rows.
func ScanValue(rows ColumnScanner) (driver.Value, error) {
var v driver.Value
if err := ScanOne(rows, &v); err != nil {

View File

@@ -19,7 +19,7 @@ import (
"entgo.io/ent/schema/field"
)
// Rel is a relation type of an edge.
// Rel is an edge relation type.
type Rel int
// Relation types.
@@ -671,10 +671,22 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
return err
}
if !update.Empty() {
var res sql.Result
query, args := update.Query()
if err := tx.Exec(ctx, query, args, nil); err != nil {
if err := tx.Exec(ctx, query, args, &res); err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
// In case there are zero affected rows by this statement, we need to distinguish
// between the case of "record was not found" and "record was not changed".
if affected == 0 && u.Predicate != nil {
if err := u.ensureExists(ctx); err != nil {
return err
}
}
}
if err := u.setExternalEdges(ctx, []driver.Value{id}, addEdges, clearEdges); err != nil {
return err
@@ -686,10 +698,9 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
}
selector := u.builder.Select(u.Node.Columns...).
From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).
// Skip adding the custom predicates that were attached to the updater
// as they may point to columns that were changed by the UPDATE statement.
Where(sql.EQ(u.Node.ID.Column, u.Node.ID.Value))
if pred := u.Predicate; pred != nil {
pred(selector)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
@@ -858,6 +869,25 @@ func (u *updater) scan(rows *sql.Rows) error {
return nil
}
func (u *updater) ensureExists(ctx context.Context) error {
exists := u.builder.Select().From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).Where(sql.EQ(u.Node.ID.Column, u.Node.ID.Value))
u.Predicate(exists)
query, args := u.builder.SelectExpr(sql.Exists(exists)).Query()
rows := &sql.Rows{}
if err := u.tx.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
found, err := sql.ScanBool(rows)
if err != nil {
return err
}
if !found {
return &NotFoundError{table: u.Node.Table, id: u.Node.ID.Value}
}
return nil
}
type creator struct {
graph
*CreateSpec

View File

@@ -1572,7 +1572,7 @@ func TestUpdateNode(t *testing.T) {
wantUser: &user{name: "Ariel", age: 30, id: 1},
},
{
name: "fields/add_clear",
name: "fields/add_set_clear",
spec: &UpdateSpec{
Node: &NodeSpec{
Table: "users",
@@ -1586,6 +1586,9 @@ func TestUpdateNode(t *testing.T) {
Add: []*FieldSpec{
{Column: "age", Type: field.TypeInt, Value: 1},
},
Set: []*FieldSpec{
{Column: "deleted", Type: field.TypeBool, Value: true},
},
Clear: []*FieldSpec{
{Column: "name", Type: field.TypeString},
},
@@ -1593,10 +1596,10 @@ func TestUpdateNode(t *testing.T) {
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND NOT `deleted`")).
WithArgs(1, 1).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ? AND NOT `deleted`")).
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `deleted` = ?, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND NOT `deleted`")).
WithArgs(true, 1, 1).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
AddRow(1, 31, nil))
@@ -1604,6 +1607,43 @@ func TestUpdateNode(t *testing.T) {
},
wantUser: &user{age: 31, id: 1},
},
{
name: "fields/ensure_exists",
spec: &UpdateSpec{
Node: &NodeSpec{
Table: "users",
Columns: []string{"id", "name", "age"},
ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
},
Predicate: func(s *sql.Selector) {
s.Where(sql.EQ("deleted", false))
},
Fields: FieldMut{
Add: []*FieldSpec{
{Column: "age", Type: field.TypeInt, Value: 1},
},
Set: []*FieldSpec{
{Column: "deleted", Type: field.TypeBool, Value: true},
},
Clear: []*FieldSpec{
{Column: "name", Type: field.TypeString},
},
},
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `deleted` = ?, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND NOT `deleted`")).
WithArgs(true, 1, 1).
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectQuery(escape("SELECT EXISTS (SELECT * FROM `users` WHERE `id` = ? AND NOT `deleted`)")).
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).
AddRow(false))
mock.ExpectRollback()
},
wantErr: true,
wantUser: &user{},
},
{
name: "edges/o2o_non_inverse and m2o",
spec: &UpdateSpec{