From 9f324ce0301ac65e85d8823102bd3680d0aba6dc Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Wed, 11 Dec 2019 15:38:29 +0200 Subject: [PATCH] dialect/sql/sqlgraph: delete nodes in the graph --- dialect/sql/graph.go | 35 ++++++++++++++++++++++++++++++++++- dialect/sql/graph_test.go | 17 +++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index 7e321ff5f..a6ea35ce3 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -352,7 +352,7 @@ type ( } // UpdateSpec holds the information for updating one - // or more nodes in the graph in the graph. + // or more nodes in the graph. UpdateSpec struct { Node *NodeSpec Edges EdgeMut @@ -393,6 +393,39 @@ func UpdateNodes(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) (int return affected, tx.Commit() } +// DeleteSpec holds the information for delete one +// or more nodes in the graph. +type DeleteSpec struct { + Node *NodeSpec + Predicate func(*Selector) +} + +// DeleteNodes applies the DeleteSpec on the graph. +func DeleteNodes(ctx context.Context, drv dialect.Driver, spec *DeleteSpec) (int, error) { + tx, err := drv.Tx(ctx) + if err != nil { + return 0, err + } + var ( + res Result + builder = Dialect(drv.Dialect()) + ) + selector := builder.Select(). + From(builder.Table(spec.Node.Table)) + if pred := spec.Predicate; pred != nil { + pred(selector) + } + query, args := builder.Delete(spec.Node.Table).FromSelect(selector).Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, rollback(tx, err) + } + affected, err := res.RowsAffected() + if err != nil { + return 0, rollback(tx, err) + } + return int(affected), tx.Commit() +} + type updater struct { graph *UpdateSpec diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index bd6d33b11..d6ecdb3d2 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -1124,6 +1124,23 @@ func TestUpdateNodes(t *testing.T) { } } +func TestDeleteNodes(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + mock.ExpectBegin() + mock.ExpectExec(escape("DELETE FROM `users`")). + WillReturnResult(sqlmock.NewResult(0, 2)) + mock.ExpectCommit() + affected, err := DeleteNodes(context.Background(), OpenDB("", db), &DeleteSpec{ + Node: &NodeSpec{ + Table: "users", + ID: &FieldSpec{Column: "id", Type: field.TypeInt}, + }, + }) + require.NoError(t, err) + require.Equal(t, 2, affected) +} + func escape(query string) string { rows := strings.Split(query, "\n") for i := range rows {