diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index d56dc3bb7..7e321ff5f 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -402,11 +402,11 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error { var ( // id holds the PK of the node used for linking // it with the other nodes. - id = []driver.Value{u.Node.ID.Value} + id = u.Node.ID.Value addEdges = EdgeSpecs(u.Edges.Add).GroupRel() clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel() ) - update := u.builder.Update(u.Node.Table).Where(EQ(u.Node.ID.Column, id[0])) + update := u.builder.Update(u.Node.Table).Where(EQ(u.Node.ID.Column, id)) if err := u.setTableColumns(update, addEdges, clearEdges); err != nil { return err } @@ -417,19 +417,9 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error { return err } } - if err := u.graph.clearM2MEdges(ctx, id, clearEdges[M2M]); err != nil { + if err := u.setExternalEdges(ctx, []driver.Value{id}, addEdges, clearEdges); err != nil { return err } - if err := u.graph.addM2MEdges(ctx, id, addEdges[M2M]); err != nil { - return err - } - if err := u.graph.clearFKEdges(ctx, id, append(clearEdges[O2M], clearEdges[O2O]...)); err != nil { - return err - } - if err := u.graph.addFKEdges(ctx, id, append(addEdges[O2M], addEdges[O2O]...)); err != nil { - return err - } - // Query and scan the node. selector := u.builder.Select(u.Node.Columns...). From(u.builder.Table(u.Node.Table)). Where(EQ(u.Node.ID.Column, u.Node.ID.Value)) @@ -478,9 +468,28 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error return 0, err } } + if err := u.setExternalEdges(ctx, ids, addEdges, clearEdges); err != nil { + return 0, err + } return len(ids), nil } +func (u *updater) setExternalEdges(ctx context.Context, ids []driver.Value, addEdges, clearEdges map[Rel][]*EdgeSpec) error { + if err := u.graph.clearM2MEdges(ctx, ids, clearEdges[M2M]); err != nil { + return err + } + if err := u.graph.addM2MEdges(ctx, ids, addEdges[M2M]); err != nil { + return err + } + if err := u.graph.clearFKEdges(ctx, ids, append(clearEdges[O2M], clearEdges[O2O]...)); err != nil { + return err + } + if err := u.graph.addFKEdges(ctx, ids, append(addEdges[O2M], addEdges[O2O]...)); err != nil { + return err + } + return nil +} + // setTableColumns sets the table columns and foreign_keys used in insert. func (u *updater) setTableColumns(update *UpdateBuilder, addEdges, clearEdges map[Rel][]*EdgeSpec) error { for _, fi := range u.Fields.Clear { @@ -681,7 +690,7 @@ func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*E func (g *graph) addFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error { id := ids[0] - if len(ids) > 1 { + if len(ids) > 1 && len(edges) != 0 { // O2M and O2O edges are defined by a FK in the "other" table. // Therefore, ids[i+1] will override ids[i] which is invalid. return fmt.Errorf("unable to link FK edge to more than 1 node: %v", ids)