dialect/sql/sqlgraph: update nodes using predicate

Currently, only fields and own-FK.
Next PR will edge types: M2M, O2M and O2O (non-inverse).
This commit is contained in:
Ariel Mashraki
2019-12-09 21:19:23 +02:00
parent 75c70bac1b
commit f86e39f179
2 changed files with 210 additions and 32 deletions

View File

@@ -378,6 +378,21 @@ func UpdateNode(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) error
return tx.Commit()
}
// UpdateNodes applies the UpdateSpec on a set of nodes in the graph.
func UpdateNodes(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) (int, error) {
tx, err := drv.Tx(ctx)
if err != nil {
return 0, err
}
gr := graph{tx: tx, builder: Dialect(drv.Dialect())}
cr := &updater{UpdateSpec: spec, graph: gr}
affected, err := cr.nodes(ctx, tx)
if err != nil {
return 0, rollback(tx, err)
}
return affected, tx.Commit()
}
type updater struct {
graph
*UpdateSpec
@@ -387,16 +402,16 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
var (
// id holds the PK of the node used for linking
// it with the other nodes.
id = u.Node.ID.Value
res sql.Result
id = []driver.Value{u.Node.ID.Value}
addEdges = EdgeSpecs(u.Edges.Add).GroupRel()
clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
)
update := u.builder.Update(u.Node.Table).Where(EQ(u.Node.ID.Column, id))
update := u.builder.Update(u.Node.Table).Where(EQ(u.Node.ID.Column, id[0]))
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
return err
}
if !update.Empty() {
var res Result
query, args := update.Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
return err
@@ -426,6 +441,46 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
return u.scan(rows)
}
func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error) {
var (
ids []driver.Value
addEdges = EdgeSpecs(u.Edges.Add).GroupRel()
clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
)
selector := u.builder.Select(u.Node.ID.Column).
From(u.builder.Table(u.Node.Table))
if pred := u.Predicate; pred != nil {
pred(selector)
}
query, args := selector.Query()
rows := &Rows{}
if err := u.tx.Query(ctx, query, args, rows); err != nil {
return 0, fmt.Errorf("querying table %s: %v", u.Node.Table, err)
}
defer rows.Close()
if err := ScanSlice(rows, &ids); err != nil {
return 0, fmt.Errorf("scan node ids: %v", err)
}
if err := rows.Close(); err != nil {
return 0, err
}
if len(ids) == 0 {
return 0, nil
}
update := u.builder.Update(u.Node.Table).Where(matchID(u.Node.ID.Column, ids))
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
return 0, err
}
if !update.Empty() {
var res Result
query, args := update.Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
return 0, err
}
}
return len(ids), nil
}
// setTableColumns sets the table columns and foreign_keys used in insert.
func (u *updater) setTableColumns(update *UpdateBuilder, addEdges, clearEdges map[Rel][]*EdgeSpec) error {
for _, fi := range u.Fields.Clear {
@@ -479,10 +534,10 @@ func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
if err := c.insert(ctx, tx, insert); err != nil {
return fmt.Errorf("insert node to table %s: %v", c.Table, err)
}
if err := c.graph.addM2MEdges(ctx, c.ID.Value, edges[M2M]); err != nil {
if err := c.graph.addM2MEdges(ctx, []driver.Value{c.ID.Value}, edges[M2M]); err != nil {
return err
}
if err := c.graph.addFKEdges(ctx, c.ID.Value, append(edges[O2M], edges[O2O]...)); err != nil {
if err := c.graph.addFKEdges(ctx, []driver.Value{c.ID.Value}, append(edges[O2M], edges[O2O]...)); err != nil {
return err
}
return nil
@@ -543,7 +598,7 @@ type graph struct {
builder *dialectBuilder
}
func (g *graph) clearM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpecs) error {
func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error {
var (
res Result
// Delete all M2M edges from the same type at once.
@@ -554,13 +609,13 @@ func (g *graph) clearM2MEdges(ctx context.Context, id driver.Value, edges EdgeSp
edges := tables[table]
preds := make([]*Predicate, 0, len(edges))
for _, edge := range edges {
pk1, pk2 := id, edge.Target.Nodes[0]
pk1, pk2 := ids, edge.Target.Nodes
if edge.Inverse {
pk1, pk2 = pk2, pk1
}
preds = append(preds, EQ(edge.Columns[0], pk1).And().EQ(edge.Columns[1], pk2))
preds = append(preds, matchIDs(edge.Columns[0], pk1, edge.Columns[1], pk2))
if edge.Bidi {
preds = append(preds, EQ(edge.Columns[0], pk2).And().EQ(edge.Columns[1], pk1))
preds = append(preds, matchIDs(edge.Columns[0], pk2, edge.Columns[1], pk1))
}
}
query, args := g.builder.Delete(table).Where(Or(preds...)).Query()
@@ -571,7 +626,7 @@ func (g *graph) clearM2MEdges(ctx context.Context, id driver.Value, edges EdgeSp
return nil
}
func (g *graph) addM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpecs) error {
func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error {
var (
res Result
// Insert all M2M edges from the same type at once.
@@ -582,13 +637,15 @@ func (g *graph) addM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpec
edges := tables[table]
insert := g.builder.Insert(table).Columns(edges[0].Columns...)
for _, edge := range edges {
pk1, pk2 := id, edge.Target.Nodes[0]
pk1, pk2 := ids, edge.Target.Nodes
if edge.Inverse {
pk1, pk2 = pk2, pk1
}
insert.Values(pk1, pk2)
if edge.Bidi {
insert.Values(pk2, pk1)
for _, pair := range product(pk1, pk2) {
insert.Values(pair[0], pair[1])
if edge.Bidi {
insert.Values(pair[1], pair[0])
}
}
}
query, args := insert.Query()
@@ -599,20 +656,20 @@ func (g *graph) addM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpec
return nil
}
func (g *graph) clearFKEdges(ctx context.Context, id driver.Value, edges []*EdgeSpec) error {
func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error {
for _, edge := range edges {
if edge.Rel == O2O && edge.Inverse {
continue
}
p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0])
// Use "IN" predicate instead of list of "OR"
// in case of more than on nodes to connect.
if len(edge.Target.Nodes) > 1 {
p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...)
// O2O relations can be cleared without
// passing the target ids.
pred := matchID(edge.Columns[0], ids)
if nodes := edge.Target.Nodes; len(nodes) > 0 {
pred = matchIDs(edge.Target.IDSpec.Column, edge.Target.Nodes, edge.Columns[0], ids)
}
query, args := g.builder.Update(edge.Table).
SetNull(edge.Columns[0]).
Where(And(p, EQ(edge.Columns[0], id))).
Where(pred).
Query()
var res Result
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
@@ -622,7 +679,13 @@ func (g *graph) clearFKEdges(ctx context.Context, id driver.Value, edges []*Edge
return nil
}
func (g *graph) addFKEdges(ctx context.Context, id driver.Value, edges []*EdgeSpec) error {
func (g *graph) addFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error {
id := ids[0]
if len(ids) > 1 {
// O2M and O2O edges are defined by a FK in the "other" table.
// Therefore, ids[i+1] will override ids[i] which is invalid.
return fmt.Errorf("unable to link FK edge to more than 1 node: %v", ids)
}
for _, edge := range edges {
if edge.Rel == O2O && edge.Inverse {
continue
@@ -711,3 +774,31 @@ func sortedKeys(m map[string][]*EdgeSpec) []string {
sort.Strings(keys)
return keys
}
func matchID(column string, pk []driver.Value) *Predicate {
if len(pk) > 1 {
return InValues(column, pk...)
}
return EQ(column, pk[0])
}
func matchIDs(column1 string, pk1 []driver.Value, column2 string, pk2 []driver.Value) *Predicate {
p := matchID(column1, pk1)
if len(pk2) > 1 {
// Use "IN" predicate instead of list of "OR"
// in case of more than on nodes to connect.
return p.And().InValues(column2, pk2...)
}
return p.And().EQ(column2, pk2[0])
}
// cartesian product of 2 id sets.
func product(a, b []driver.Value) [][2]driver.Value {
c := make([][2]driver.Value, 0, len(a)*len(b))
for i := range a {
for j := range b {
c = append(c, [2]driver.Value{a[i], b[j]})
}
}
return c
}