From f86e39f179b8a4770d61e195699d026eaafd6dc4 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Mon, 9 Dec 2019 21:19:23 +0200 Subject: [PATCH] dialect/sql/sqlgraph: update nodes using predicate Currently, only fields and own-FK. Next PR will edge types: M2M, O2M and O2O (non-inverse). --- dialect/sql/graph.go | 135 +++++++++++++++++++++++++++++++------- dialect/sql/graph_test.go | 107 +++++++++++++++++++++++++++--- 2 files changed, 210 insertions(+), 32 deletions(-) diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index 21aed61fc..d56dc3bb7 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -378,6 +378,21 @@ func UpdateNode(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) error return tx.Commit() } +// UpdateNodes applies the UpdateSpec on a set of nodes in the graph. +func UpdateNodes(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) (int, error) { + tx, err := drv.Tx(ctx) + if err != nil { + return 0, err + } + gr := graph{tx: tx, builder: Dialect(drv.Dialect())} + cr := &updater{UpdateSpec: spec, graph: gr} + affected, err := cr.nodes(ctx, tx) + if err != nil { + return 0, rollback(tx, err) + } + return affected, tx.Commit() +} + type updater struct { graph *UpdateSpec @@ -387,16 +402,16 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error { var ( // id holds the PK of the node used for linking // it with the other nodes. - id = u.Node.ID.Value - res sql.Result + id = []driver.Value{u.Node.ID.Value} addEdges = EdgeSpecs(u.Edges.Add).GroupRel() clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel() ) - update := u.builder.Update(u.Node.Table).Where(EQ(u.Node.ID.Column, id)) + update := u.builder.Update(u.Node.Table).Where(EQ(u.Node.ID.Column, id[0])) if err := u.setTableColumns(update, addEdges, clearEdges); err != nil { return err } if !update.Empty() { + var res Result query, args := update.Query() if err := tx.Exec(ctx, query, args, &res); err != nil { return err @@ -426,6 +441,46 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error { return u.scan(rows) } +func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error) { + var ( + ids []driver.Value + addEdges = EdgeSpecs(u.Edges.Add).GroupRel() + clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel() + ) + selector := u.builder.Select(u.Node.ID.Column). + From(u.builder.Table(u.Node.Table)) + if pred := u.Predicate; pred != nil { + pred(selector) + } + query, args := selector.Query() + rows := &Rows{} + if err := u.tx.Query(ctx, query, args, rows); err != nil { + return 0, fmt.Errorf("querying table %s: %v", u.Node.Table, err) + } + defer rows.Close() + if err := ScanSlice(rows, &ids); err != nil { + return 0, fmt.Errorf("scan node ids: %v", err) + } + if err := rows.Close(); err != nil { + return 0, err + } + if len(ids) == 0 { + return 0, nil + } + update := u.builder.Update(u.Node.Table).Where(matchID(u.Node.ID.Column, ids)) + if err := u.setTableColumns(update, addEdges, clearEdges); err != nil { + return 0, err + } + if !update.Empty() { + var res Result + query, args := update.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, err + } + } + return len(ids), nil +} + // setTableColumns sets the table columns and foreign_keys used in insert. func (u *updater) setTableColumns(update *UpdateBuilder, addEdges, clearEdges map[Rel][]*EdgeSpec) error { for _, fi := range u.Fields.Clear { @@ -479,10 +534,10 @@ func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error { if err := c.insert(ctx, tx, insert); err != nil { return fmt.Errorf("insert node to table %s: %v", c.Table, err) } - if err := c.graph.addM2MEdges(ctx, c.ID.Value, edges[M2M]); err != nil { + if err := c.graph.addM2MEdges(ctx, []driver.Value{c.ID.Value}, edges[M2M]); err != nil { return err } - if err := c.graph.addFKEdges(ctx, c.ID.Value, append(edges[O2M], edges[O2O]...)); err != nil { + if err := c.graph.addFKEdges(ctx, []driver.Value{c.ID.Value}, append(edges[O2M], edges[O2O]...)); err != nil { return err } return nil @@ -543,7 +598,7 @@ type graph struct { builder *dialectBuilder } -func (g *graph) clearM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpecs) error { +func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error { var ( res Result // Delete all M2M edges from the same type at once. @@ -554,13 +609,13 @@ func (g *graph) clearM2MEdges(ctx context.Context, id driver.Value, edges EdgeSp edges := tables[table] preds := make([]*Predicate, 0, len(edges)) for _, edge := range edges { - pk1, pk2 := id, edge.Target.Nodes[0] + pk1, pk2 := ids, edge.Target.Nodes if edge.Inverse { pk1, pk2 = pk2, pk1 } - preds = append(preds, EQ(edge.Columns[0], pk1).And().EQ(edge.Columns[1], pk2)) + preds = append(preds, matchIDs(edge.Columns[0], pk1, edge.Columns[1], pk2)) if edge.Bidi { - preds = append(preds, EQ(edge.Columns[0], pk2).And().EQ(edge.Columns[1], pk1)) + preds = append(preds, matchIDs(edge.Columns[0], pk2, edge.Columns[1], pk1)) } } query, args := g.builder.Delete(table).Where(Or(preds...)).Query() @@ -571,7 +626,7 @@ func (g *graph) clearM2MEdges(ctx context.Context, id driver.Value, edges EdgeSp return nil } -func (g *graph) addM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpecs) error { +func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error { var ( res Result // Insert all M2M edges from the same type at once. @@ -582,13 +637,15 @@ func (g *graph) addM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpec edges := tables[table] insert := g.builder.Insert(table).Columns(edges[0].Columns...) for _, edge := range edges { - pk1, pk2 := id, edge.Target.Nodes[0] + pk1, pk2 := ids, edge.Target.Nodes if edge.Inverse { pk1, pk2 = pk2, pk1 } - insert.Values(pk1, pk2) - if edge.Bidi { - insert.Values(pk2, pk1) + for _, pair := range product(pk1, pk2) { + insert.Values(pair[0], pair[1]) + if edge.Bidi { + insert.Values(pair[1], pair[0]) + } } } query, args := insert.Query() @@ -599,20 +656,20 @@ func (g *graph) addM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpec return nil } -func (g *graph) clearFKEdges(ctx context.Context, id driver.Value, edges []*EdgeSpec) error { +func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error { for _, edge := range edges { if edge.Rel == O2O && edge.Inverse { continue } - p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0]) - // Use "IN" predicate instead of list of "OR" - // in case of more than on nodes to connect. - if len(edge.Target.Nodes) > 1 { - p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...) + // O2O relations can be cleared without + // passing the target ids. + pred := matchID(edge.Columns[0], ids) + if nodes := edge.Target.Nodes; len(nodes) > 0 { + pred = matchIDs(edge.Target.IDSpec.Column, edge.Target.Nodes, edge.Columns[0], ids) } query, args := g.builder.Update(edge.Table). SetNull(edge.Columns[0]). - Where(And(p, EQ(edge.Columns[0], id))). + Where(pred). Query() var res Result if err := g.tx.Exec(ctx, query, args, &res); err != nil { @@ -622,7 +679,13 @@ func (g *graph) clearFKEdges(ctx context.Context, id driver.Value, edges []*Edge return nil } -func (g *graph) addFKEdges(ctx context.Context, id driver.Value, edges []*EdgeSpec) error { +func (g *graph) addFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error { + id := ids[0] + if len(ids) > 1 { + // O2M and O2O edges are defined by a FK in the "other" table. + // Therefore, ids[i+1] will override ids[i] which is invalid. + return fmt.Errorf("unable to link FK edge to more than 1 node: %v", ids) + } for _, edge := range edges { if edge.Rel == O2O && edge.Inverse { continue @@ -711,3 +774,31 @@ func sortedKeys(m map[string][]*EdgeSpec) []string { sort.Strings(keys) return keys } + +func matchID(column string, pk []driver.Value) *Predicate { + if len(pk) > 1 { + return InValues(column, pk...) + } + return EQ(column, pk[0]) +} + +func matchIDs(column1 string, pk1 []driver.Value, column2 string, pk2 []driver.Value) *Predicate { + p := matchID(column1, pk1) + if len(pk2) > 1 { + // Use "IN" predicate instead of list of "OR" + // in case of more than on nodes to connect. + return p.And().InValues(column2, pk2...) + } + return p.And().EQ(column2, pk2[0]) +} + +// cartesian product of 2 id sets. +func product(a, b []driver.Value) [][2]driver.Value { + c := make([][2]driver.Value, 0, len(a)*len(b)) + for i := range a { + for j := range b { + c = append(c, [2]driver.Value{a[i], b[j]}) + } + } + return c +} diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index ed676e05f..08e4502fa 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -757,7 +757,7 @@ func (u *user) assign(values ...interface{}) error { return nil } -func TestUpdateOne(t *testing.T) { +func TestUpdateNode(t *testing.T) { tests := []struct { name string spec *UpdateSpec @@ -865,6 +865,7 @@ func TestUpdateOne(t *testing.T) { }, Edges: EdgeMut{ Clear: []*EdgeSpec{ + {Rel: O2O, Table: "users", Bidi: true, Columns: []string{"partner_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}}, {Rel: O2O, Table: "users", Bidi: true, Columns: []string{"spouse_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}}, }, Add: []*EdgeSpec{ @@ -874,12 +875,16 @@ func TestUpdateOne(t *testing.T) { }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() - // Clear "spouse 2" from 1's column, and set "spouse 3". - mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = NULL, `spouse_id` = ? WHERE `id` = ?")). + // Clear the "partner" and "spouse 2" from 1's column, and set "spouse 3". + mock.ExpectExec(escape("UPDATE `users` SET `partner_id` = NULL, `spouse_id` = NULL, `spouse_id` = ? WHERE `id` = ?")). WithArgs(3, 1). WillReturnResult(sqlmock.NewResult(1, 1)) + // Clear the "partner_id" column from previous 1's partner. + mock.ExpectExec(escape("UPDATE `users` SET `partner_id` = NULL WHERE `partner_id` = ?")). + WithArgs(1). + WillReturnResult(sqlmock.NewResult(1, 1)) // Clear "spouse 1" from 3's column. - mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = NULL WHERE (`id` = ?) AND (`spouse_id` = ?)")). + mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = NULL WHERE `id` = ? AND `spouse_id` = ?")). WithArgs(2, 1). WillReturnResult(sqlmock.NewResult(1, 1)) // Set 3's column to point "spouse 1". @@ -905,28 +910,28 @@ func TestUpdateOne(t *testing.T) { Edges: EdgeMut{ Clear: []*EdgeSpec{ {Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}}, - {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{3}}}, + {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{3, 7}}}, }, Add: []*EdgeSpec{ {Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{4}}}, {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{5}}}, - {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{6}}}, + {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{6, 8}}}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() // Clear user groups. - mock.ExpectExec(escape("DELETE FROM `group_users` WHERE (`group_id` = ? AND `user_id` = ?)")). - WithArgs(3, 1). + mock.ExpectExec(escape("DELETE FROM `group_users` WHERE (`group_id` IN (?, ?) AND `user_id` = ?)")). + WithArgs(3, 7, 1). WillReturnResult(sqlmock.NewResult(1, 1)) // Clear user friends. mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE ((`user_id` = ? AND `friend_id` = ?) OR (`user_id` = ? AND `friend_id` = ?))")). WithArgs(1, 2, 2, 1). WillReturnResult(sqlmock.NewResult(1, 1)) // Add new groups. - mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")). - WithArgs(5, 1, 6, 1). + mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?), (?, ?)")). + WithArgs(5, 1, 6, 1, 8, 1). WillReturnResult(sqlmock.NewResult(1, 1)) // Add new friends. mock.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?)")). @@ -956,6 +961,88 @@ func TestUpdateOne(t *testing.T) { } } +func TestUpdateNodes(t *testing.T) { + tests := []struct { + name string + spec *UpdateSpec + prepare func(sqlmock.Sqlmock) + wantErr bool + wantAffected int + }{ + { + name: "without predicate", + spec: &UpdateSpec{ + Node: &NodeSpec{ + Table: "users", + ID: &FieldSpec{Column: "id", Type: field.TypeInt}, + }, + Fields: FieldMut{ + Set: []*FieldSpec{ + {Column: "age", Type: field.TypeInt, Value: 30}, + {Column: "name", Type: field.TypeString, Value: "Ariel"}, + }, + }, + }, + prepare: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + // Get all node ids first. + mock.ExpectQuery(escape("SELECT `id` FROM `users`")). + WillReturnRows(sqlmock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2)) + // Apply field changes. + mock.ExpectExec(escape("UPDATE `users` SET `age` = ?, `name` = ? WHERE `id` IN (?, ?)")). + WithArgs(30, "Ariel", 1, 2). + WillReturnResult(sqlmock.NewResult(0, 2)) + mock.ExpectCommit() + }, + wantAffected: 2, + }, + { + name: "with", + spec: &UpdateSpec{ + Node: &NodeSpec{ + Table: "users", + ID: &FieldSpec{Column: "id", Type: field.TypeInt}, + }, + Fields: FieldMut{ + Clear: []*FieldSpec{ + {Column: "age", Type: field.TypeInt}, + {Column: "name", Type: field.TypeString}, + }, + }, + Predicate: func(s *Selector) { + s.Where(EQ("name", "a8m")) + }, + }, + prepare: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + // Get all node ids first. + mock.ExpectQuery(escape("SELECT `id` FROM `users` WHERE `name` = ?")). + WithArgs("a8m"). + WillReturnRows(sqlmock.NewRows([]string{"id"}). + AddRow(1)) + // Clear fields. + mock.ExpectExec(escape("UPDATE `users` SET `age` = NULL, `name` = NULL WHERE `id` = ?")). + WithArgs(1). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + }, + wantAffected: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + tt.prepare(mock) + affected, err := UpdateNodes(context.Background(), OpenDB("", db), tt.spec) + require.Equal(t, tt.wantErr, err != nil, err) + require.Equal(t, tt.wantAffected, affected) + }) + } +} + func escape(query string) string { rows := strings.Split(query, "\n") for i := range rows {