mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/sqlgraph: update nodes using predicate
Currently, only fields and own-FK. Next PR will edge types: M2M, O2M and O2O (non-inverse).
This commit is contained in:
@@ -378,6 +378,21 @@ func UpdateNode(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) error
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// UpdateNodes applies the UpdateSpec on a set of nodes in the graph.
|
||||
func UpdateNodes(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) (int, error) {
|
||||
tx, err := drv.Tx(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
gr := graph{tx: tx, builder: Dialect(drv.Dialect())}
|
||||
cr := &updater{UpdateSpec: spec, graph: gr}
|
||||
affected, err := cr.nodes(ctx, tx)
|
||||
if err != nil {
|
||||
return 0, rollback(tx, err)
|
||||
}
|
||||
return affected, tx.Commit()
|
||||
}
|
||||
|
||||
type updater struct {
|
||||
graph
|
||||
*UpdateSpec
|
||||
@@ -387,16 +402,16 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
var (
|
||||
// id holds the PK of the node used for linking
|
||||
// it with the other nodes.
|
||||
id = u.Node.ID.Value
|
||||
res sql.Result
|
||||
id = []driver.Value{u.Node.ID.Value}
|
||||
addEdges = EdgeSpecs(u.Edges.Add).GroupRel()
|
||||
clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
|
||||
)
|
||||
update := u.builder.Update(u.Node.Table).Where(EQ(u.Node.ID.Column, id))
|
||||
update := u.builder.Update(u.Node.Table).Where(EQ(u.Node.ID.Column, id[0]))
|
||||
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
|
||||
return err
|
||||
}
|
||||
if !update.Empty() {
|
||||
var res Result
|
||||
query, args := update.Query()
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return err
|
||||
@@ -426,6 +441,46 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
return u.scan(rows)
|
||||
}
|
||||
|
||||
func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error) {
|
||||
var (
|
||||
ids []driver.Value
|
||||
addEdges = EdgeSpecs(u.Edges.Add).GroupRel()
|
||||
clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
|
||||
)
|
||||
selector := u.builder.Select(u.Node.ID.Column).
|
||||
From(u.builder.Table(u.Node.Table))
|
||||
if pred := u.Predicate; pred != nil {
|
||||
pred(selector)
|
||||
}
|
||||
query, args := selector.Query()
|
||||
rows := &Rows{}
|
||||
if err := u.tx.Query(ctx, query, args, rows); err != nil {
|
||||
return 0, fmt.Errorf("querying table %s: %v", u.Node.Table, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if err := ScanSlice(rows, &ids); err != nil {
|
||||
return 0, fmt.Errorf("scan node ids: %v", err)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
update := u.builder.Update(u.Node.Table).Where(matchID(u.Node.ID.Column, ids))
|
||||
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !update.Empty() {
|
||||
var res Result
|
||||
query, args := update.Query()
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return len(ids), nil
|
||||
}
|
||||
|
||||
// setTableColumns sets the table columns and foreign_keys used in insert.
|
||||
func (u *updater) setTableColumns(update *UpdateBuilder, addEdges, clearEdges map[Rel][]*EdgeSpec) error {
|
||||
for _, fi := range u.Fields.Clear {
|
||||
@@ -479,10 +534,10 @@ func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
if err := c.insert(ctx, tx, insert); err != nil {
|
||||
return fmt.Errorf("insert node to table %s: %v", c.Table, err)
|
||||
}
|
||||
if err := c.graph.addM2MEdges(ctx, c.ID.Value, edges[M2M]); err != nil {
|
||||
if err := c.graph.addM2MEdges(ctx, []driver.Value{c.ID.Value}, edges[M2M]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.graph.addFKEdges(ctx, c.ID.Value, append(edges[O2M], edges[O2O]...)); err != nil {
|
||||
if err := c.graph.addFKEdges(ctx, []driver.Value{c.ID.Value}, append(edges[O2M], edges[O2O]...)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -543,7 +598,7 @@ type graph struct {
|
||||
builder *dialectBuilder
|
||||
}
|
||||
|
||||
func (g *graph) clearM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpecs) error {
|
||||
func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error {
|
||||
var (
|
||||
res Result
|
||||
// Delete all M2M edges from the same type at once.
|
||||
@@ -554,13 +609,13 @@ func (g *graph) clearM2MEdges(ctx context.Context, id driver.Value, edges EdgeSp
|
||||
edges := tables[table]
|
||||
preds := make([]*Predicate, 0, len(edges))
|
||||
for _, edge := range edges {
|
||||
pk1, pk2 := id, edge.Target.Nodes[0]
|
||||
pk1, pk2 := ids, edge.Target.Nodes
|
||||
if edge.Inverse {
|
||||
pk1, pk2 = pk2, pk1
|
||||
}
|
||||
preds = append(preds, EQ(edge.Columns[0], pk1).And().EQ(edge.Columns[1], pk2))
|
||||
preds = append(preds, matchIDs(edge.Columns[0], pk1, edge.Columns[1], pk2))
|
||||
if edge.Bidi {
|
||||
preds = append(preds, EQ(edge.Columns[0], pk2).And().EQ(edge.Columns[1], pk1))
|
||||
preds = append(preds, matchIDs(edge.Columns[0], pk2, edge.Columns[1], pk1))
|
||||
}
|
||||
}
|
||||
query, args := g.builder.Delete(table).Where(Or(preds...)).Query()
|
||||
@@ -571,7 +626,7 @@ func (g *graph) clearM2MEdges(ctx context.Context, id driver.Value, edges EdgeSp
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *graph) addM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpecs) error {
|
||||
func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error {
|
||||
var (
|
||||
res Result
|
||||
// Insert all M2M edges from the same type at once.
|
||||
@@ -582,13 +637,15 @@ func (g *graph) addM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpec
|
||||
edges := tables[table]
|
||||
insert := g.builder.Insert(table).Columns(edges[0].Columns...)
|
||||
for _, edge := range edges {
|
||||
pk1, pk2 := id, edge.Target.Nodes[0]
|
||||
pk1, pk2 := ids, edge.Target.Nodes
|
||||
if edge.Inverse {
|
||||
pk1, pk2 = pk2, pk1
|
||||
}
|
||||
insert.Values(pk1, pk2)
|
||||
if edge.Bidi {
|
||||
insert.Values(pk2, pk1)
|
||||
for _, pair := range product(pk1, pk2) {
|
||||
insert.Values(pair[0], pair[1])
|
||||
if edge.Bidi {
|
||||
insert.Values(pair[1], pair[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
query, args := insert.Query()
|
||||
@@ -599,20 +656,20 @@ func (g *graph) addM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpec
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *graph) clearFKEdges(ctx context.Context, id driver.Value, edges []*EdgeSpec) error {
|
||||
func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error {
|
||||
for _, edge := range edges {
|
||||
if edge.Rel == O2O && edge.Inverse {
|
||||
continue
|
||||
}
|
||||
p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0])
|
||||
// Use "IN" predicate instead of list of "OR"
|
||||
// in case of more than on nodes to connect.
|
||||
if len(edge.Target.Nodes) > 1 {
|
||||
p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...)
|
||||
// O2O relations can be cleared without
|
||||
// passing the target ids.
|
||||
pred := matchID(edge.Columns[0], ids)
|
||||
if nodes := edge.Target.Nodes; len(nodes) > 0 {
|
||||
pred = matchIDs(edge.Target.IDSpec.Column, edge.Target.Nodes, edge.Columns[0], ids)
|
||||
}
|
||||
query, args := g.builder.Update(edge.Table).
|
||||
SetNull(edge.Columns[0]).
|
||||
Where(And(p, EQ(edge.Columns[0], id))).
|
||||
Where(pred).
|
||||
Query()
|
||||
var res Result
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
@@ -622,7 +679,13 @@ func (g *graph) clearFKEdges(ctx context.Context, id driver.Value, edges []*Edge
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *graph) addFKEdges(ctx context.Context, id driver.Value, edges []*EdgeSpec) error {
|
||||
func (g *graph) addFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error {
|
||||
id := ids[0]
|
||||
if len(ids) > 1 {
|
||||
// O2M and O2O edges are defined by a FK in the "other" table.
|
||||
// Therefore, ids[i+1] will override ids[i] which is invalid.
|
||||
return fmt.Errorf("unable to link FK edge to more than 1 node: %v", ids)
|
||||
}
|
||||
for _, edge := range edges {
|
||||
if edge.Rel == O2O && edge.Inverse {
|
||||
continue
|
||||
@@ -711,3 +774,31 @@ func sortedKeys(m map[string][]*EdgeSpec) []string {
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func matchID(column string, pk []driver.Value) *Predicate {
|
||||
if len(pk) > 1 {
|
||||
return InValues(column, pk...)
|
||||
}
|
||||
return EQ(column, pk[0])
|
||||
}
|
||||
|
||||
func matchIDs(column1 string, pk1 []driver.Value, column2 string, pk2 []driver.Value) *Predicate {
|
||||
p := matchID(column1, pk1)
|
||||
if len(pk2) > 1 {
|
||||
// Use "IN" predicate instead of list of "OR"
|
||||
// in case of more than on nodes to connect.
|
||||
return p.And().InValues(column2, pk2...)
|
||||
}
|
||||
return p.And().EQ(column2, pk2[0])
|
||||
}
|
||||
|
||||
// cartesian product of 2 id sets.
|
||||
func product(a, b []driver.Value) [][2]driver.Value {
|
||||
c := make([][2]driver.Value, 0, len(a)*len(b))
|
||||
for i := range a {
|
||||
for j := range b {
|
||||
c = append(c, [2]driver.Value{a[i], b[j]})
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -757,7 +757,7 @@ func (u *user) assign(values ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUpdateOne(t *testing.T) {
|
||||
func TestUpdateNode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spec *UpdateSpec
|
||||
@@ -865,6 +865,7 @@ func TestUpdateOne(t *testing.T) {
|
||||
},
|
||||
Edges: EdgeMut{
|
||||
Clear: []*EdgeSpec{
|
||||
{Rel: O2O, Table: "users", Bidi: true, Columns: []string{"partner_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}},
|
||||
{Rel: O2O, Table: "users", Bidi: true, Columns: []string{"spouse_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
|
||||
},
|
||||
Add: []*EdgeSpec{
|
||||
@@ -874,12 +875,16 @@ func TestUpdateOne(t *testing.T) {
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
// Clear "spouse 2" from 1's column, and set "spouse 3".
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = NULL, `spouse_id` = ? WHERE `id` = ?")).
|
||||
// Clear the "partner" and "spouse 2" from 1's column, and set "spouse 3".
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `partner_id` = NULL, `spouse_id` = NULL, `spouse_id` = ? WHERE `id` = ?")).
|
||||
WithArgs(3, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
// Clear the "partner_id" column from previous 1's partner.
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `partner_id` = NULL WHERE `partner_id` = ?")).
|
||||
WithArgs(1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
// Clear "spouse 1" from 3's column.
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = NULL WHERE (`id` = ?) AND (`spouse_id` = ?)")).
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = NULL WHERE `id` = ? AND `spouse_id` = ?")).
|
||||
WithArgs(2, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
// Set 3's column to point "spouse 1".
|
||||
@@ -905,28 +910,28 @@ func TestUpdateOne(t *testing.T) {
|
||||
Edges: EdgeMut{
|
||||
Clear: []*EdgeSpec{
|
||||
{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
|
||||
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{3}}},
|
||||
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{3, 7}}},
|
||||
},
|
||||
Add: []*EdgeSpec{
|
||||
{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{4}}},
|
||||
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{5}}},
|
||||
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{6}}},
|
||||
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{6, 8}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
// Clear user groups.
|
||||
mock.ExpectExec(escape("DELETE FROM `group_users` WHERE (`group_id` = ? AND `user_id` = ?)")).
|
||||
WithArgs(3, 1).
|
||||
mock.ExpectExec(escape("DELETE FROM `group_users` WHERE (`group_id` IN (?, ?) AND `user_id` = ?)")).
|
||||
WithArgs(3, 7, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
// Clear user friends.
|
||||
mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE ((`user_id` = ? AND `friend_id` = ?) OR (`user_id` = ? AND `friend_id` = ?))")).
|
||||
WithArgs(1, 2, 2, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
// Add new groups.
|
||||
mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")).
|
||||
WithArgs(5, 1, 6, 1).
|
||||
mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?), (?, ?)")).
|
||||
WithArgs(5, 1, 6, 1, 8, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
// Add new friends.
|
||||
mock.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?)")).
|
||||
@@ -956,6 +961,88 @@ func TestUpdateOne(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateNodes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spec *UpdateSpec
|
||||
prepare func(sqlmock.Sqlmock)
|
||||
wantErr bool
|
||||
wantAffected int
|
||||
}{
|
||||
{
|
||||
name: "without predicate",
|
||||
spec: &UpdateSpec{
|
||||
Node: &NodeSpec{
|
||||
Table: "users",
|
||||
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
|
||||
},
|
||||
Fields: FieldMut{
|
||||
Set: []*FieldSpec{
|
||||
{Column: "age", Type: field.TypeInt, Value: 30},
|
||||
{Column: "name", Type: field.TypeString, Value: "Ariel"},
|
||||
},
|
||||
},
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
// Get all node ids first.
|
||||
mock.ExpectQuery(escape("SELECT `id` FROM `users`")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).
|
||||
AddRow(1).
|
||||
AddRow(2))
|
||||
// Apply field changes.
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `age` = ?, `name` = ? WHERE `id` IN (?, ?)")).
|
||||
WithArgs(30, "Ariel", 1, 2).
|
||||
WillReturnResult(sqlmock.NewResult(0, 2))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
wantAffected: 2,
|
||||
},
|
||||
{
|
||||
name: "with",
|
||||
spec: &UpdateSpec{
|
||||
Node: &NodeSpec{
|
||||
Table: "users",
|
||||
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
|
||||
},
|
||||
Fields: FieldMut{
|
||||
Clear: []*FieldSpec{
|
||||
{Column: "age", Type: field.TypeInt},
|
||||
{Column: "name", Type: field.TypeString},
|
||||
},
|
||||
},
|
||||
Predicate: func(s *Selector) {
|
||||
s.Where(EQ("name", "a8m"))
|
||||
},
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
// Get all node ids first.
|
||||
mock.ExpectQuery(escape("SELECT `id` FROM `users` WHERE `name` = ?")).
|
||||
WithArgs("a8m").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).
|
||||
AddRow(1))
|
||||
// Clear fields.
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `age` = NULL, `name` = NULL WHERE `id` = ?")).
|
||||
WithArgs(1).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
wantAffected: 1,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
tt.prepare(mock)
|
||||
affected, err := UpdateNodes(context.Background(), OpenDB("", db), tt.spec)
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
require.Equal(t, tt.wantAffected, affected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func escape(query string) string {
|
||||
rows := strings.Split(query, "\n")
|
||||
for i := range rows {
|
||||
|
||||
Reference in New Issue
Block a user