mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/sqlgraph: add update node api
Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/222 Reviewed By: alexsn Differential Revision: D18833733 fbshipit-source-id: e833d84f4e5e5c73b1c85e7387472c9a87b7947e
This commit is contained in:
committed by
Facebook Github Bot
parent
bb051603ac
commit
0fb33aaa5e
@@ -10,6 +10,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/facebookincubator/ent/dialect"
|
||||
"github.com/facebookincubator/ent/schema/field"
|
||||
@@ -304,37 +305,170 @@ type (
|
||||
// EdgeSpecs used for perform common operations on list of edges.
|
||||
EdgeSpecs []*EdgeSpec
|
||||
|
||||
// CreateSpec holds the information for creating a node
|
||||
// in the graph.
|
||||
CreateSpec struct {
|
||||
Table string
|
||||
ID *FieldSpec
|
||||
Fields []*FieldSpec
|
||||
Edges []*EdgeSpec
|
||||
// NodeSpec defines the information for querying and
|
||||
// decoding nodes in the graph.
|
||||
NodeSpec struct {
|
||||
Table string
|
||||
Columns []string
|
||||
ID *FieldSpec
|
||||
}
|
||||
)
|
||||
|
||||
// CreateSpec holds the information for creating
|
||||
// a node in the graph.
|
||||
type CreateSpec struct {
|
||||
Table string
|
||||
ID *FieldSpec
|
||||
Fields []*FieldSpec
|
||||
Edges []*EdgeSpec
|
||||
}
|
||||
|
||||
// CreateNode applies the CreateSpec on the graph.
|
||||
func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error {
|
||||
tx, err := drv.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cr := &creator{CreateSpec: spec, builder: Dialect(drv.Dialect())}
|
||||
gr := graph{tx: tx, builder: Dialect(drv.Dialect())}
|
||||
cr := &creator{CreateSpec: spec, graph: gr}
|
||||
if err := cr.node(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
type (
|
||||
// EdgeMut defines edge mutations.
|
||||
EdgeMut struct {
|
||||
Add []*EdgeSpec
|
||||
Clear []*EdgeSpec
|
||||
}
|
||||
|
||||
// FieldMut defines field mutations.
|
||||
FieldMut struct {
|
||||
Set []*FieldSpec // field = ?
|
||||
Add []*FieldSpec // field = field + ?
|
||||
Clear []*FieldSpec // field = NULL
|
||||
}
|
||||
|
||||
// UpdateSpec holds the information for updating one
|
||||
// or more nodes in the graph in the graph.
|
||||
UpdateSpec struct {
|
||||
Node *NodeSpec
|
||||
Edges EdgeMut
|
||||
Fields FieldMut
|
||||
Predicate func(*Selector)
|
||||
|
||||
ScanTypes []interface{}
|
||||
Assign func(...interface{}) error
|
||||
}
|
||||
)
|
||||
|
||||
// UpdateNode applies the UpdateSpec on one node in the graph.
|
||||
func UpdateNode(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) error {
|
||||
tx, err := drv.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
gr := graph{tx: tx, builder: Dialect(drv.Dialect())}
|
||||
cr := &updater{UpdateSpec: spec, graph: gr}
|
||||
if err := cr.node(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
type updater struct {
|
||||
graph
|
||||
*UpdateSpec
|
||||
}
|
||||
|
||||
func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
var (
|
||||
// id holds the PK of the node used for linking
|
||||
// it with the other nodes.
|
||||
id = u.Node.ID.Value
|
||||
res sql.Result
|
||||
addEdges = EdgeSpecs(u.Edges.Add).GroupRel()
|
||||
clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
|
||||
)
|
||||
update := u.builder.Update(u.Node.Table).Where(EQ(u.Node.ID.Column, id))
|
||||
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
|
||||
return err
|
||||
}
|
||||
if !update.Empty() {
|
||||
query, args := update.Query()
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := u.graph.clearM2MEdges(ctx, id, clearEdges[M2M]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := u.graph.addM2MEdges(ctx, id, addEdges[M2M]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := u.graph.clearFKEdges(ctx, id, append(clearEdges[O2M], clearEdges[O2O]...)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := u.graph.addFKEdges(ctx, id, append(addEdges[O2M], addEdges[O2O]...)); err != nil {
|
||||
return err
|
||||
}
|
||||
// Query and scan the node.
|
||||
selector := u.builder.Select(u.Node.Columns...).
|
||||
From(u.builder.Table(u.Node.Table)).
|
||||
Where(EQ(u.Node.ID.Column, u.Node.ID.Value))
|
||||
rows := &Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
return u.scan(rows)
|
||||
}
|
||||
|
||||
// setTableColumns sets the table columns and foreign_keys used in insert.
|
||||
func (u *updater) setTableColumns(update *UpdateBuilder, addEdges, clearEdges map[Rel][]*EdgeSpec) error {
|
||||
for _, fi := range u.Fields.Clear {
|
||||
update.SetNull(fi.Column)
|
||||
}
|
||||
for _, e := range clearEdges[M2O] {
|
||||
update.SetNull(e.Columns[0])
|
||||
}
|
||||
for _, e := range clearEdges[O2O] {
|
||||
if e.Inverse || e.Bidi {
|
||||
update.SetNull(e.Columns[0])
|
||||
}
|
||||
}
|
||||
err := setTableColumns(u.Fields.Set, addEdges, func(column string, value driver.Value) {
|
||||
update.Set(column, value)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, fi := range u.Fields.Add {
|
||||
update.Add(fi.Column, fi.Value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) scan(rows *Rows) error {
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return fmt.Errorf("record with id %v not found in table %s", u.Node.ID.Value, u.Node.Table)
|
||||
}
|
||||
if err := rows.Scan(u.ScanTypes...); err != nil {
|
||||
return fmt.Errorf("failed scanning rows: %v", err)
|
||||
}
|
||||
return u.Assign(u.ScanTypes...)
|
||||
}
|
||||
|
||||
type creator struct {
|
||||
graph
|
||||
*CreateSpec
|
||||
builder *dialectBuilder
|
||||
}
|
||||
|
||||
func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
var (
|
||||
res sql.Result
|
||||
edges = EdgeSpecs(c.Edges).GroupRel()
|
||||
insert = c.builder.Insert(c.Table).Default()
|
||||
)
|
||||
@@ -345,76 +479,21 @@ func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
if err := c.insert(ctx, tx, insert); err != nil {
|
||||
return fmt.Errorf("insert node to table %s: %v", c.Table, err)
|
||||
}
|
||||
// Insert all M2M edges from the same type at once.
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables := EdgeSpecs(edges[M2M]).GroupTable()
|
||||
for table, edges := range tables {
|
||||
edge := edges[0]
|
||||
insert = c.builder.Insert(table).Columns(edge.Columns...)
|
||||
for _, edge := range edges {
|
||||
pk1, pk2 := c.ID.Value, edge.Target.Nodes[0]
|
||||
if edge.Inverse {
|
||||
pk1, pk2 = pk2, pk1
|
||||
}
|
||||
insert.Values(pk1, pk2)
|
||||
if edge.Bidi {
|
||||
insert.Values(pk2, pk1)
|
||||
}
|
||||
}
|
||||
query, args := insert.Query()
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("add m2m edge for table %s: %v", table, err)
|
||||
}
|
||||
if err := c.graph.addM2MEdges(ctx, c.ID.Value, edges[M2M]); err != nil {
|
||||
return err
|
||||
}
|
||||
// O2M and non-inverse O2O edges also reside in external tables.
|
||||
for _, edge := range append(edges[O2M], edges[O2O]...) {
|
||||
if edge.Rel == O2O && edge.Inverse {
|
||||
continue
|
||||
}
|
||||
p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0])
|
||||
// Use "IN" predicate instead of list of "OR"
|
||||
// in case of more than on nodes to connect.
|
||||
if len(edge.Target.Nodes) > 1 {
|
||||
p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...)
|
||||
}
|
||||
query, args := c.builder.Update(edge.Table).
|
||||
Set(edge.Columns[0], c.ID.Value).
|
||||
Where(And(p, IsNull(edge.Columns[0]))).
|
||||
Query()
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("add m2m edge for table %s: %v", edge.Table, err)
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ids := edge.Target.Nodes; int(affected) < len(ids) {
|
||||
return fmt.Errorf("one of %v is already connected to a different %s", ids, edge.Columns[0])
|
||||
}
|
||||
if err := c.graph.addFKEdges(ctx, c.ID.Value, append(edges[O2M], edges[O2O]...)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setTableColumns sets the table columns and foreign_keys used in insert.
|
||||
func (c *creator) setTableColumns(insert *InsertBuilder, edges map[Rel][]*EdgeSpec) (err error) {
|
||||
for _, fi := range c.Fields {
|
||||
value := fi.Value
|
||||
if fi.Type == field.TypeJSON {
|
||||
if value, err = json.Marshal(value); err != nil {
|
||||
return fmt.Errorf("marshal value for column %s: %v", fi.Column, err)
|
||||
}
|
||||
}
|
||||
insert.Set(fi.Column, value)
|
||||
}
|
||||
for _, e := range edges[M2O] {
|
||||
insert.Set(e.Columns[0], e.Target.Nodes[0])
|
||||
}
|
||||
for _, e := range edges[O2O] {
|
||||
if e.Inverse || e.Bidi {
|
||||
insert.Set(e.Columns[0], e.Target.Nodes[0])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
func (c *creator) setTableColumns(insert *InsertBuilder, edges map[Rel][]*EdgeSpec) error {
|
||||
err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) {
|
||||
insert.Set(column, value)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// insert inserts the node to its table and sets its ID if it wasn't provided by the user.
|
||||
@@ -452,6 +531,149 @@ func (es EdgeSpecs) GroupTable() map[string][]*EdgeSpec {
|
||||
return edges
|
||||
}
|
||||
|
||||
// The common operations shared between the different builders.
|
||||
//
|
||||
// M2M edges reside in join tables and require INSERT and DELETE
|
||||
// queries for adding or removing edges respectively.
|
||||
//
|
||||
// O2M and non-inverse O2O edges also reside in external tables,
|
||||
// but use UPDATE queries (fk = ?, fk = NULL).
|
||||
type graph struct {
|
||||
tx dialect.ExecQuerier
|
||||
builder *dialectBuilder
|
||||
}
|
||||
|
||||
func (g *graph) clearM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpecs) error {
|
||||
var (
|
||||
res Result
|
||||
// Delete all M2M edges from the same type at once.
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables = edges.GroupTable()
|
||||
)
|
||||
for _, table := range sortedKeys(tables) {
|
||||
edges := tables[table]
|
||||
preds := make([]*Predicate, 0, len(edges))
|
||||
for _, edge := range edges {
|
||||
pk1, pk2 := id, edge.Target.Nodes[0]
|
||||
if edge.Inverse {
|
||||
pk1, pk2 = pk2, pk1
|
||||
}
|
||||
preds = append(preds, EQ(edge.Columns[0], pk1).And().EQ(edge.Columns[1], pk2))
|
||||
if edge.Bidi {
|
||||
preds = append(preds, EQ(edge.Columns[0], pk2).And().EQ(edge.Columns[1], pk1))
|
||||
}
|
||||
}
|
||||
query, args := g.builder.Delete(table).Where(Or(preds...)).Query()
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("remove m2m edge for table %s: %v", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *graph) addM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpecs) error {
|
||||
var (
|
||||
res Result
|
||||
// Insert all M2M edges from the same type at once.
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables = edges.GroupTable()
|
||||
)
|
||||
for _, table := range sortedKeys(tables) {
|
||||
edges := tables[table]
|
||||
insert := g.builder.Insert(table).Columns(edges[0].Columns...)
|
||||
for _, edge := range edges {
|
||||
pk1, pk2 := id, edge.Target.Nodes[0]
|
||||
if edge.Inverse {
|
||||
pk1, pk2 = pk2, pk1
|
||||
}
|
||||
insert.Values(pk1, pk2)
|
||||
if edge.Bidi {
|
||||
insert.Values(pk2, pk1)
|
||||
}
|
||||
}
|
||||
query, args := insert.Query()
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("add m2m edge for table %s: %v", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *graph) clearFKEdges(ctx context.Context, id driver.Value, edges []*EdgeSpec) error {
|
||||
for _, edge := range edges {
|
||||
if edge.Rel == O2O && edge.Inverse {
|
||||
continue
|
||||
}
|
||||
p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0])
|
||||
// Use "IN" predicate instead of list of "OR"
|
||||
// in case of more than on nodes to connect.
|
||||
if len(edge.Target.Nodes) > 1 {
|
||||
p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...)
|
||||
}
|
||||
query, args := g.builder.Update(edge.Table).
|
||||
SetNull(edge.Columns[0]).
|
||||
Where(And(p, EQ(edge.Columns[0], id))).
|
||||
Query()
|
||||
var res Result
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("add %s edge for table %s: %v", edge.Rel, edge.Table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *graph) addFKEdges(ctx context.Context, id driver.Value, edges []*EdgeSpec) error {
|
||||
for _, edge := range edges {
|
||||
if edge.Rel == O2O && edge.Inverse {
|
||||
continue
|
||||
}
|
||||
p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0])
|
||||
// Use "IN" predicate instead of list of "OR"
|
||||
// in case of more than on nodes to connect.
|
||||
if len(edge.Target.Nodes) > 1 {
|
||||
p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...)
|
||||
}
|
||||
query, args := g.builder.Update(edge.Table).
|
||||
Set(edge.Columns[0], id).
|
||||
Where(And(p, IsNull(edge.Columns[0]))).
|
||||
Query()
|
||||
var res Result
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("add %s edge for table %s: %v", edge.Rel, edge.Table, err)
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ids := edge.Target.Nodes; int(affected) < len(ids) {
|
||||
return fmt.Errorf("one of %v is already connected to a different %s", ids, edge.Columns[0])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setTableColumns is shared between updater and creator.
|
||||
func setTableColumns(fields []*FieldSpec, edges map[Rel][]*EdgeSpec, set func(string, driver.Value)) (err error) {
|
||||
for _, fi := range fields {
|
||||
value := fi.Value
|
||||
if fi.Type == field.TypeJSON {
|
||||
if value, err = json.Marshal(value); err != nil {
|
||||
return fmt.Errorf("marshal value for column %s: %v", fi.Column, err)
|
||||
}
|
||||
}
|
||||
set(fi.Column, value)
|
||||
}
|
||||
for _, e := range edges[M2O] {
|
||||
set(e.Columns[0], e.Target.Nodes[0])
|
||||
}
|
||||
for _, e := range edges[O2O] {
|
||||
if e.Inverse || e.Bidi {
|
||||
set(e.Columns[0], e.Target.Nodes[0])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertLastID invokes the insert query on the transaction and returns the LastInsertID.
|
||||
func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBuilder) (int64, error) {
|
||||
query, args := insert.Query()
|
||||
@@ -480,3 +702,12 @@ func rollback(tx dialect.Tx, err error) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sortedKeys(m map[string][]*EdgeSpec) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
@@ -719,12 +719,12 @@ func TestCreateNode(t *testing.T) {
|
||||
m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
|
||||
WithArgs("mashraki").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")).
|
||||
WithArgs(1, 2, 2, 1, 1, 3, 3, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")).
|
||||
WithArgs(4, 1, 5, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")).
|
||||
WithArgs(1, 2, 2, 1, 1, 3, 3, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.ExpectCommit()
|
||||
},
|
||||
},
|
||||
@@ -740,6 +740,222 @@ func TestCreateNode(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type user struct {
|
||||
id int
|
||||
age int
|
||||
name string
|
||||
}
|
||||
|
||||
func (*user) values() []interface{} {
|
||||
return []interface{}{&NullInt64{}, &NullInt64{}, &NullString{}}
|
||||
}
|
||||
|
||||
func (u *user) assign(values ...interface{}) error {
|
||||
u.id = int(values[0].(*NullInt64).Int64)
|
||||
u.age = int(values[1].(*NullInt64).Int64)
|
||||
u.name = values[2].(*NullString).String
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUpdateOne(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spec *UpdateSpec
|
||||
prepare func(sqlmock.Sqlmock)
|
||||
wantErr bool
|
||||
wantUser *user
|
||||
}{
|
||||
{
|
||||
name: "fields/set",
|
||||
spec: &UpdateSpec{
|
||||
Node: &NodeSpec{
|
||||
Table: "users",
|
||||
Columns: []string{"id", "name", "age"},
|
||||
ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
|
||||
},
|
||||
Fields: FieldMut{
|
||||
Set: []*FieldSpec{
|
||||
{Column: "age", Type: field.TypeInt, Value: 30},
|
||||
{Column: "name", Type: field.TypeString, Value: "Ariel"},
|
||||
},
|
||||
},
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `age` = ?, `name` = ? WHERE `id` = ?")).
|
||||
WithArgs(30, "Ariel", 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
|
||||
AddRow(1, 30, "Ariel"))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
wantUser: &user{name: "Ariel", age: 30, id: 1},
|
||||
},
|
||||
{
|
||||
name: "fields/add_clear",
|
||||
spec: &UpdateSpec{
|
||||
Node: &NodeSpec{
|
||||
Table: "users",
|
||||
Columns: []string{"id", "name", "age"},
|
||||
ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
|
||||
},
|
||||
Fields: FieldMut{
|
||||
Add: []*FieldSpec{
|
||||
{Column: "age", Type: field.TypeInt, Value: 1},
|
||||
},
|
||||
Clear: []*FieldSpec{
|
||||
{Column: "name", Type: field.TypeString},
|
||||
},
|
||||
},
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`age`, ?) + ? WHERE `id` = ?")).
|
||||
WithArgs(0, 1, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
|
||||
AddRow(1, 31, nil))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
wantUser: &user{age: 31, id: 1},
|
||||
},
|
||||
{
|
||||
name: "edges/o2o_non_inverse and m2o",
|
||||
spec: &UpdateSpec{
|
||||
Node: &NodeSpec{
|
||||
Table: "users",
|
||||
Columns: []string{"id", "name", "age"},
|
||||
ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
|
||||
},
|
||||
Edges: EdgeMut{
|
||||
Clear: []*EdgeSpec{
|
||||
{Rel: O2O, Columns: []string{"car_id"}, Inverse: true},
|
||||
{Rel: M2O, Columns: []string{"workplace_id"}, Inverse: true},
|
||||
},
|
||||
Add: []*EdgeSpec{
|
||||
{Rel: O2O, Columns: []string{"card_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}},
|
||||
{Rel: M2O, Columns: []string{"parent_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `workplace_id` = NULL, `car_id` = NULL, `parent_id` = ?, `card_id` = ? WHERE `id` = ?")).
|
||||
WithArgs(2, 2, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
|
||||
AddRow(1, 31, nil))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
wantUser: &user{age: 31, id: 1},
|
||||
},
|
||||
{
|
||||
name: "edges/o2o_bidi",
|
||||
spec: &UpdateSpec{
|
||||
Node: &NodeSpec{
|
||||
Table: "users",
|
||||
Columns: []string{"id", "name", "age"},
|
||||
ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
|
||||
},
|
||||
Edges: EdgeMut{
|
||||
Clear: []*EdgeSpec{
|
||||
{Rel: O2O, Table: "users", Bidi: true, Columns: []string{"spouse_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
|
||||
},
|
||||
Add: []*EdgeSpec{
|
||||
{Rel: O2O, Table: "users", Bidi: true, Columns: []string{"spouse_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{3}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
// Clear "spouse 2" from 1's column, and set "spouse 3".
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = NULL, `spouse_id` = ? WHERE `id` = ?")).
|
||||
WithArgs(3, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
// Clear "spouse 1" from 3's column.
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = NULL WHERE (`id` = ?) AND (`spouse_id` = ?)")).
|
||||
WithArgs(2, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
// Set 3's column to point "spouse 1".
|
||||
mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = ? WHERE (`id` = ?) AND (`spouse_id` IS NULL)")).
|
||||
WithArgs(1, 3).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
|
||||
AddRow(1, 31, nil))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
wantUser: &user{age: 31, id: 1},
|
||||
},
|
||||
{
|
||||
name: "edges/clear_add_m2m",
|
||||
spec: &UpdateSpec{
|
||||
Node: &NodeSpec{
|
||||
Table: "users",
|
||||
Columns: []string{"id", "name", "age"},
|
||||
ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
|
||||
},
|
||||
Edges: EdgeMut{
|
||||
Clear: []*EdgeSpec{
|
||||
{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
|
||||
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{3}}},
|
||||
},
|
||||
Add: []*EdgeSpec{
|
||||
{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{4}}},
|
||||
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{5}}},
|
||||
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{6}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
// Clear user groups.
|
||||
mock.ExpectExec(escape("DELETE FROM `group_users` WHERE (`group_id` = ? AND `user_id` = ?)")).
|
||||
WithArgs(3, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
// Clear user friends.
|
||||
mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE ((`user_id` = ? AND `friend_id` = ?) OR (`user_id` = ? AND `friend_id` = ?))")).
|
||||
WithArgs(1, 2, 2, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
// Add new groups.
|
||||
mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")).
|
||||
WithArgs(5, 1, 6, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
// Add new friends.
|
||||
mock.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?)")).
|
||||
WithArgs(1, 4, 4, 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
|
||||
AddRow(1, 31, nil))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
wantUser: &user{age: 31, id: 1},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
tt.prepare(mock)
|
||||
usr := &user{}
|
||||
tt.spec.Assign = usr.assign
|
||||
tt.spec.ScanTypes = usr.values()
|
||||
err = UpdateNode(context.Background(), OpenDB("", db), tt.spec)
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
require.Equal(t, tt.wantUser, usr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func escape(query string) string {
|
||||
rows := strings.Split(query, "\n")
|
||||
for i := range rows {
|
||||
|
||||
@@ -195,6 +195,7 @@ func Sanity(t *testing.T, client *ent.Client) {
|
||||
usr = client.User.UpdateOne(usr).SetName("baz").AddGroups(grp).SaveX(ctx)
|
||||
require.Equal("baz", usr.Name)
|
||||
require.NotEmpty(usr.QueryGroups().AllX(ctx))
|
||||
|
||||
// grouping.
|
||||
var v []struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
Reference in New Issue
Block a user