dialect/sql: add an option for clearing m2m edges (#730)

This commit is contained in:
Ariel Mashraki
2020-09-06 17:27:31 +03:00
committed by GitHub
parent feed51d773
commit fc03257412
2 changed files with 96 additions and 18 deletions

View File

@@ -927,7 +927,7 @@ type graph struct {
func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error {
var (
res sql.Result
// Delete all M2M edges from the same type at once.
// Remove all M2M edges from the same type at once.
// The EdgeSpec is the same for all members in a group.
tables = edges.GroupTable()
)
@@ -935,13 +935,23 @@ func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges Edg
edges := tables[table]
preds := make([]*sql.Predicate, 0, len(edges))
for _, edge := range edges {
pk1, pk2 := ids, edge.Target.Nodes
fromC, toC := edge.Columns[0], edge.Columns[1]
if edge.Inverse {
pk1, pk2 = pk2, pk1
fromC, toC = toC, fromC
}
preds = append(preds, matchIDs(edge.Columns[0], pk1, edge.Columns[1], pk2))
if edge.Bidi {
preds = append(preds, matchIDs(edge.Columns[0], pk2, edge.Columns[1], pk1))
// If there are no specific edges (to target-nodes) to remove,
// clear all edges that go out (or come in) from the nodes.
if len(edge.Target.Nodes) == 0 {
preds = append(preds, matchID(fromC, ids))
if edge.Bidi {
preds = append(preds, matchID(toC, ids))
}
} else {
pk1, pk2 := ids, edge.Target.Nodes
preds = append(preds, matchIDs(fromC, pk1, toC, pk2))
if edge.Bidi {
preds = append(preds, matchIDs(toC, pk1, fromC, pk2))
}
}
}
query, args := g.builder.Delete(table).Where(sql.Or(preds...)).Query()

View File

@@ -1010,6 +1010,12 @@ func TestUpdateNode(t *testing.T) {
Clear: []*EdgeSpec{
{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{3, 7}}},
// Clear all "following" edges (and their inverse).
{Rel: M2M, Table: "user_following", Bidi: true, Columns: []string{"following_id", "follower_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}},
// Clear all "user_blocked" edges.
{Rel: M2M, Table: "user_blocked", Columns: []string{"user_id", "blocked_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}},
// Clear all "comments" edges.
{Rel: M2M, Inverse: true, Table: "comment_responders", Columns: []string{"comment_id", "responder_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}},
},
Add: []*EdgeSpec{
{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{4}}},
@@ -1020,13 +1026,25 @@ func TestUpdateNode(t *testing.T) {
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
// Clear user groups.
mock.ExpectExec(escape("DELETE FROM `group_users` WHERE `group_id` IN (?, ?) AND `user_id` = ?")).
WithArgs(3, 7, 1).
// Clear comment responders.
mock.ExpectExec(escape("DELETE FROM `comment_responders` WHERE `responder_id` = ?")).
WithArgs(1).
WillReturnResult(sqlmock.NewResult(1, 1))
// Remove user groups.
mock.ExpectExec(escape("DELETE FROM `group_users` WHERE `user_id` = ? AND `group_id` IN (?, ?)")).
WithArgs(1, 3, 7).
WillReturnResult(sqlmock.NewResult(1, 1))
// Clear all blocked users.
mock.ExpectExec(escape("DELETE FROM `user_blocked` WHERE `user_id` = ?")).
WithArgs(1).
WillReturnResult(sqlmock.NewResult(1, 1))
// Clear all user following.
mock.ExpectExec(escape("DELETE FROM `user_following` WHERE `following_id` = ? OR `follower_id` = ?")).
WithArgs(1, 1).
WillReturnResult(sqlmock.NewResult(1, 2))
// Clear user friends.
mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE (`user_id` = ? AND `friend_id` = ?) OR (`user_id` = ? AND `friend_id` = ?)")).
WithArgs(1, 2, 2, 1).
mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE (`user_id` = ? AND `friend_id` = ?) OR (`friend_id` = ? AND `user_id` = ?)")).
WithArgs(1, 2, 1, 2).
WillReturnResult(sqlmock.NewResult(1, 1))
// Add new groups.
mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?), (?, ?)")).
@@ -1162,7 +1180,7 @@ func TestUpdateNodes(t *testing.T) {
wantAffected: 1,
},
{
name: "m2m",
name: "m2m_one",
spec: &UpdateSpec{
Node: &NodeSpec{
Table: "users",
@@ -1187,16 +1205,16 @@ func TestUpdateNodes(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"id"}).
AddRow(1))
// Clear user's groups.
mock.ExpectExec(escape("DELETE FROM `group_users` WHERE `group_id` IN (?, ?) AND `user_id` = ?")).
WithArgs(2, 3, 1).
mock.ExpectExec(escape("DELETE FROM `group_users` WHERE `user_id` = ? AND `group_id` IN (?, ?)")).
WithArgs(1, 2, 3).
WillReturnResult(sqlmock.NewResult(0, 2))
// Clear user's followers.
mock.ExpectExec(escape("DELETE FROM `user_followers` WHERE (`user_id` = ? AND `follower_id` IN (?, ?)) OR (`user_id` IN (?, ?) AND `follower_id` = ?)")).
WithArgs(1, 5, 6, 5, 6, 1).
mock.ExpectExec(escape("DELETE FROM `user_followers` WHERE (`user_id` = ? AND `follower_id` IN (?, ?)) OR (`follower_id` = ? AND `user_id` IN (?, ?))")).
WithArgs(1, 5, 6, 1, 5, 6).
WillReturnResult(sqlmock.NewResult(0, 2))
// Clear user's friends.
mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE (`user_id` = ? AND `friend_id` = ?) OR (`user_id` = ? AND `friend_id` = ?)")).
WithArgs(1, 4, 4, 1).
mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE (`user_id` = ? AND `friend_id` = ?) OR (`friend_id` = ? AND `user_id` = ?)")).
WithArgs(1, 4, 1, 4).
WillReturnResult(sqlmock.NewResult(0, 2))
// Attach new groups to user.
mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")).
@@ -1210,6 +1228,56 @@ func TestUpdateNodes(t *testing.T) {
},
wantAffected: 1,
},
{
name: "m2m_many",
spec: &UpdateSpec{
Node: &NodeSpec{
Table: "users",
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
},
Edges: EdgeMut{
Clear: []*EdgeSpec{
{Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2, 3}}},
{Rel: M2M, Table: "user_followers", Columns: []string{"user_id", "follower_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{5, 6}}},
{Rel: M2M, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{4}}},
},
Add: []*EdgeSpec{
{Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{7, 8}}},
{Rel: M2M, Table: "user_followers", Columns: []string{"user_id", "follower_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{9}}},
},
},
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
// Get all node ids first.
mock.ExpectQuery(escape("SELECT `id` FROM `users`")).
WillReturnRows(sqlmock.NewRows([]string{"id"}).
AddRow(10).
AddRow(20))
// Clear user's groups.
mock.ExpectExec(escape("DELETE FROM `group_users` WHERE `user_id` IN (?, ?) AND `group_id` IN (?, ?)")).
WithArgs(10, 20, 2, 3).
WillReturnResult(sqlmock.NewResult(0, 2))
// Clear user's followers.
mock.ExpectExec(escape("DELETE FROM `user_followers` WHERE (`user_id` IN (?, ?) AND `follower_id` IN (?, ?)) OR (`follower_id` IN (?, ?) AND `user_id` IN (?, ?))")).
WithArgs(10, 20, 5, 6, 10, 20, 5, 6).
WillReturnResult(sqlmock.NewResult(0, 2))
// Clear user's friends.
mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE (`user_id` IN (?, ?) AND `friend_id` = ?) OR (`friend_id` IN (?, ?) AND `user_id` = ?)")).
WithArgs(10, 20, 4, 10, 20, 4).
WillReturnResult(sqlmock.NewResult(0, 2))
// Attach new groups to user.
mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")).
WithArgs(7, 10, 7, 20, 8, 10, 8, 20).
WillReturnResult(sqlmock.NewResult(0, 4))
// Attach new friends to user.
mock.ExpectExec(escape("INSERT INTO `user_followers` (`user_id`, `follower_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")).
WithArgs(10, 9, 9, 10, 20, 9, 9, 20).
WillReturnResult(sqlmock.NewResult(0, 4))
mock.ExpectCommit()
},
wantAffected: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {