mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/sqlgraph: better support for update nodes with predicates (#2574)
This commit is contained in:
@@ -36,7 +36,7 @@ func ScanOne(rows ColumnScanner, v interface{}) error {
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// ScanInt64 scans and returns an int64 from the rows columns.
|
||||
// ScanInt64 scans and returns an int64 from the rows.
|
||||
func ScanInt64(rows ColumnScanner) (int64, error) {
|
||||
var n int64
|
||||
if err := ScanOne(rows, &n); err != nil {
|
||||
@@ -45,7 +45,7 @@ func ScanInt64(rows ColumnScanner) (int64, error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// ScanInt scans and returns an int from the rows columns.
|
||||
// ScanInt scans and returns an int from the rows.
|
||||
func ScanInt(rows ColumnScanner) (int, error) {
|
||||
n, err := ScanInt64(rows)
|
||||
if err != nil {
|
||||
@@ -54,7 +54,16 @@ func ScanInt(rows ColumnScanner) (int, error) {
|
||||
return int(n), nil
|
||||
}
|
||||
|
||||
// ScanString scans and returns a string from the rows columns.
|
||||
// ScanBool scans and returns a boolean from the rows.
|
||||
func ScanBool(rows ColumnScanner) (bool, error) {
|
||||
var b bool
|
||||
if err := ScanOne(rows, &b); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// ScanString scans and returns a string from the rows.
|
||||
func ScanString(rows ColumnScanner) (string, error) {
|
||||
var s string
|
||||
if err := ScanOne(rows, &s); err != nil {
|
||||
@@ -63,7 +72,7 @@ func ScanString(rows ColumnScanner) (string, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ScanValue scans and returns a driver.Value from the rows columns.
|
||||
// ScanValue scans and returns a driver.Value from the rows.
|
||||
func ScanValue(rows ColumnScanner) (driver.Value, error) {
|
||||
var v driver.Value
|
||||
if err := ScanOne(rows, &v); err != nil {
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"entgo.io/ent/schema/field"
|
||||
)
|
||||
|
||||
// Rel is a relation type of an edge.
|
||||
// Rel is an edge relation type.
|
||||
type Rel int
|
||||
|
||||
// Relation types.
|
||||
@@ -671,10 +671,22 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
return err
|
||||
}
|
||||
if !update.Empty() {
|
||||
var res sql.Result
|
||||
query, args := update.Query()
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// In case there are zero affected rows by this statement, we need to distinguish
|
||||
// between the case of "record was not found" and "record was not changed".
|
||||
if affected == 0 && u.Predicate != nil {
|
||||
if err := u.ensureExists(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := u.setExternalEdges(ctx, []driver.Value{id}, addEdges, clearEdges); err != nil {
|
||||
return err
|
||||
@@ -686,10 +698,9 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
}
|
||||
selector := u.builder.Select(u.Node.Columns...).
|
||||
From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).
|
||||
// Skip adding the custom predicates that were attached to the updater
|
||||
// as they may point to columns that were changed by the UPDATE statement.
|
||||
Where(sql.EQ(u.Node.ID.Column, u.Node.ID.Value))
|
||||
if pred := u.Predicate; pred != nil {
|
||||
pred(selector)
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
@@ -858,6 +869,25 @@ func (u *updater) scan(rows *sql.Rows) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) ensureExists(ctx context.Context) error {
|
||||
exists := u.builder.Select().From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).Where(sql.EQ(u.Node.ID.Column, u.Node.ID.Value))
|
||||
u.Predicate(exists)
|
||||
query, args := u.builder.SelectExpr(sql.Exists(exists)).Query()
|
||||
rows := &sql.Rows{}
|
||||
if err := u.tx.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
found, err := sql.ScanBool(rows)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !found {
|
||||
return &NotFoundError{table: u.Node.Table, id: u.Node.ID.Value}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type creator struct {
|
||||
graph
|
||||
*CreateSpec
|
||||
|
||||
@@ -1572,7 +1572,7 @@ func TestUpdateNode(t *testing.T) {
|
||||
wantUser: &user{name: "Ariel", age: 30, id: 1},
|
||||
},
|
||||
{
|
||||
name: "fields/add_clear",
|
||||
name: "fields/add_set_clear",
|
||||
spec: &UpdateSpec{
|
||||
Node: &NodeSpec{
|
||||
Table: "users",
|
||||
@@ -1586,6 +1586,9 @@ func TestUpdateNode(t *testing.T) {
|
||||
Add: []*FieldSpec{
|
||||
{Column: "age", Type: field.TypeInt, Value: 1},
|
||||
},
|
||||
Set: []*FieldSpec{
|
||||
{Column: "deleted", Type: field.TypeBool, Value: true},
|
||||
},
|
||||
Clear: []*FieldSpec{
|
||||
{Column: "name", Type: field.TypeString},
|
||||
},
|
||||
@@ -1593,10 +1596,10 @@ func TestUpdateNode(t *testing.T) {
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND NOT `deleted`")).
|
||||
WithArgs(1, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ? AND NOT `deleted`")).
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `deleted` = ?, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND NOT `deleted`")).
|
||||
WithArgs(true, 1, 1).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
|
||||
AddRow(1, 31, nil))
|
||||
@@ -1604,6 +1607,43 @@ func TestUpdateNode(t *testing.T) {
|
||||
},
|
||||
wantUser: &user{age: 31, id: 1},
|
||||
},
|
||||
{
|
||||
name: "fields/ensure_exists",
|
||||
spec: &UpdateSpec{
|
||||
Node: &NodeSpec{
|
||||
Table: "users",
|
||||
Columns: []string{"id", "name", "age"},
|
||||
ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
|
||||
},
|
||||
Predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.EQ("deleted", false))
|
||||
},
|
||||
Fields: FieldMut{
|
||||
Add: []*FieldSpec{
|
||||
{Column: "age", Type: field.TypeInt, Value: 1},
|
||||
},
|
||||
Set: []*FieldSpec{
|
||||
{Column: "deleted", Type: field.TypeBool, Value: true},
|
||||
},
|
||||
Clear: []*FieldSpec{
|
||||
{Column: "name", Type: field.TypeString},
|
||||
},
|
||||
},
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `deleted` = ?, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND NOT `deleted`")).
|
||||
WithArgs(true, 1, 1).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery(escape("SELECT EXISTS (SELECT * FROM `users` WHERE `id` = ? AND NOT `deleted`)")).
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).
|
||||
AddRow(false))
|
||||
mock.ExpectRollback()
|
||||
},
|
||||
wantErr: true,
|
||||
wantUser: &user{},
|
||||
},
|
||||
{
|
||||
name: "edges/o2o_non_inverse and m2o",
|
||||
spec: &UpdateSpec{
|
||||
|
||||
Reference in New Issue
Block a user