diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index 7f86c097b..21aed61fc 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -10,6 +10,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "sort" "github.com/facebookincubator/ent/dialect" "github.com/facebookincubator/ent/schema/field" @@ -304,37 +305,170 @@ type ( // EdgeSpecs used for perform common operations on list of edges. EdgeSpecs []*EdgeSpec - // CreateSpec holds the information for creating a node - // in the graph. - CreateSpec struct { - Table string - ID *FieldSpec - Fields []*FieldSpec - Edges []*EdgeSpec + // NodeSpec defines the information for querying and + // decoding nodes in the graph. + NodeSpec struct { + Table string + Columns []string + ID *FieldSpec } ) +// CreateSpec holds the information for creating +// a node in the graph. +type CreateSpec struct { + Table string + ID *FieldSpec + Fields []*FieldSpec + Edges []*EdgeSpec +} + // CreateNode applies the CreateSpec on the graph. func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error { tx, err := drv.Tx(ctx) if err != nil { return err } - cr := &creator{CreateSpec: spec, builder: Dialect(drv.Dialect())} + gr := graph{tx: tx, builder: Dialect(drv.Dialect())} + cr := &creator{CreateSpec: spec, graph: gr} if err := cr.node(ctx, tx); err != nil { return rollback(tx, err) } return tx.Commit() } +type ( + // EdgeMut defines edge mutations. + EdgeMut struct { + Add []*EdgeSpec + Clear []*EdgeSpec + } + + // FieldMut defines field mutations. + FieldMut struct { + Set []*FieldSpec // field = ? + Add []*FieldSpec // field = field + ? + Clear []*FieldSpec // field = NULL + } + + // UpdateSpec holds the information for updating one + // or more nodes in the graph in the graph. + UpdateSpec struct { + Node *NodeSpec + Edges EdgeMut + Fields FieldMut + Predicate func(*Selector) + + ScanTypes []interface{} + Assign func(...interface{}) error + } +) + +// UpdateNode applies the UpdateSpec on one node in the graph. +func UpdateNode(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) error { + tx, err := drv.Tx(ctx) + if err != nil { + return err + } + gr := graph{tx: tx, builder: Dialect(drv.Dialect())} + cr := &updater{UpdateSpec: spec, graph: gr} + if err := cr.node(ctx, tx); err != nil { + return rollback(tx, err) + } + return tx.Commit() +} + +type updater struct { + graph + *UpdateSpec +} + +func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error { + var ( + // id holds the PK of the node used for linking + // it with the other nodes. + id = u.Node.ID.Value + res sql.Result + addEdges = EdgeSpecs(u.Edges.Add).GroupRel() + clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel() + ) + update := u.builder.Update(u.Node.Table).Where(EQ(u.Node.ID.Column, id)) + if err := u.setTableColumns(update, addEdges, clearEdges); err != nil { + return err + } + if !update.Empty() { + query, args := update.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return err + } + } + if err := u.graph.clearM2MEdges(ctx, id, clearEdges[M2M]); err != nil { + return err + } + if err := u.graph.addM2MEdges(ctx, id, addEdges[M2M]); err != nil { + return err + } + if err := u.graph.clearFKEdges(ctx, id, append(clearEdges[O2M], clearEdges[O2O]...)); err != nil { + return err + } + if err := u.graph.addFKEdges(ctx, id, append(addEdges[O2M], addEdges[O2O]...)); err != nil { + return err + } + // Query and scan the node. + selector := u.builder.Select(u.Node.Columns...). + From(u.builder.Table(u.Node.Table)). + Where(EQ(u.Node.ID.Column, u.Node.ID.Value)) + rows := &Rows{} + query, args := selector.Query() + if err := tx.Query(ctx, query, args, rows); err != nil { + return err + } + return u.scan(rows) +} + +// setTableColumns sets the table columns and foreign_keys used in insert. +func (u *updater) setTableColumns(update *UpdateBuilder, addEdges, clearEdges map[Rel][]*EdgeSpec) error { + for _, fi := range u.Fields.Clear { + update.SetNull(fi.Column) + } + for _, e := range clearEdges[M2O] { + update.SetNull(e.Columns[0]) + } + for _, e := range clearEdges[O2O] { + if e.Inverse || e.Bidi { + update.SetNull(e.Columns[0]) + } + } + err := setTableColumns(u.Fields.Set, addEdges, func(column string, value driver.Value) { + update.Set(column, value) + }) + if err != nil { + return err + } + for _, fi := range u.Fields.Add { + update.Add(fi.Column, fi.Value) + } + return nil +} + +func (u *updater) scan(rows *Rows) error { + defer rows.Close() + if !rows.Next() { + return fmt.Errorf("record with id %v not found in table %s", u.Node.ID.Value, u.Node.Table) + } + if err := rows.Scan(u.ScanTypes...); err != nil { + return fmt.Errorf("failed scanning rows: %v", err) + } + return u.Assign(u.ScanTypes...) +} + type creator struct { + graph *CreateSpec - builder *dialectBuilder } func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error { var ( - res sql.Result edges = EdgeSpecs(c.Edges).GroupRel() insert = c.builder.Insert(c.Table).Default() ) @@ -345,76 +479,21 @@ func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error { if err := c.insert(ctx, tx, insert); err != nil { return fmt.Errorf("insert node to table %s: %v", c.Table, err) } - // Insert all M2M edges from the same type at once. - // The EdgeSpec is the same for all members in a group. - tables := EdgeSpecs(edges[M2M]).GroupTable() - for table, edges := range tables { - edge := edges[0] - insert = c.builder.Insert(table).Columns(edge.Columns...) - for _, edge := range edges { - pk1, pk2 := c.ID.Value, edge.Target.Nodes[0] - if edge.Inverse { - pk1, pk2 = pk2, pk1 - } - insert.Values(pk1, pk2) - if edge.Bidi { - insert.Values(pk2, pk1) - } - } - query, args := insert.Query() - if err := tx.Exec(ctx, query, args, &res); err != nil { - return fmt.Errorf("add m2m edge for table %s: %v", table, err) - } + if err := c.graph.addM2MEdges(ctx, c.ID.Value, edges[M2M]); err != nil { + return err } - // O2M and non-inverse O2O edges also reside in external tables. - for _, edge := range append(edges[O2M], edges[O2O]...) { - if edge.Rel == O2O && edge.Inverse { - continue - } - p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0]) - // Use "IN" predicate instead of list of "OR" - // in case of more than on nodes to connect. - if len(edge.Target.Nodes) > 1 { - p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...) - } - query, args := c.builder.Update(edge.Table). - Set(edge.Columns[0], c.ID.Value). - Where(And(p, IsNull(edge.Columns[0]))). - Query() - if err := tx.Exec(ctx, query, args, &res); err != nil { - return fmt.Errorf("add m2m edge for table %s: %v", edge.Table, err) - } - affected, err := res.RowsAffected() - if err != nil { - return err - } - if ids := edge.Target.Nodes; int(affected) < len(ids) { - return fmt.Errorf("one of %v is already connected to a different %s", ids, edge.Columns[0]) - } + if err := c.graph.addFKEdges(ctx, c.ID.Value, append(edges[O2M], edges[O2O]...)); err != nil { + return err } return nil } // setTableColumns sets the table columns and foreign_keys used in insert. -func (c *creator) setTableColumns(insert *InsertBuilder, edges map[Rel][]*EdgeSpec) (err error) { - for _, fi := range c.Fields { - value := fi.Value - if fi.Type == field.TypeJSON { - if value, err = json.Marshal(value); err != nil { - return fmt.Errorf("marshal value for column %s: %v", fi.Column, err) - } - } - insert.Set(fi.Column, value) - } - for _, e := range edges[M2O] { - insert.Set(e.Columns[0], e.Target.Nodes[0]) - } - for _, e := range edges[O2O] { - if e.Inverse || e.Bidi { - insert.Set(e.Columns[0], e.Target.Nodes[0]) - } - } - return nil +func (c *creator) setTableColumns(insert *InsertBuilder, edges map[Rel][]*EdgeSpec) error { + err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) { + insert.Set(column, value) + }) + return err } // insert inserts the node to its table and sets its ID if it wasn't provided by the user. @@ -452,6 +531,149 @@ func (es EdgeSpecs) GroupTable() map[string][]*EdgeSpec { return edges } +// The common operations shared between the different builders. +// +// M2M edges reside in join tables and require INSERT and DELETE +// queries for adding or removing edges respectively. +// +// O2M and non-inverse O2O edges also reside in external tables, +// but use UPDATE queries (fk = ?, fk = NULL). +type graph struct { + tx dialect.ExecQuerier + builder *dialectBuilder +} + +func (g *graph) clearM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpecs) error { + var ( + res Result + // Delete all M2M edges from the same type at once. + // The EdgeSpec is the same for all members in a group. + tables = edges.GroupTable() + ) + for _, table := range sortedKeys(tables) { + edges := tables[table] + preds := make([]*Predicate, 0, len(edges)) + for _, edge := range edges { + pk1, pk2 := id, edge.Target.Nodes[0] + if edge.Inverse { + pk1, pk2 = pk2, pk1 + } + preds = append(preds, EQ(edge.Columns[0], pk1).And().EQ(edge.Columns[1], pk2)) + if edge.Bidi { + preds = append(preds, EQ(edge.Columns[0], pk2).And().EQ(edge.Columns[1], pk1)) + } + } + query, args := g.builder.Delete(table).Where(Or(preds...)).Query() + if err := g.tx.Exec(ctx, query, args, &res); err != nil { + return fmt.Errorf("remove m2m edge for table %s: %v", table, err) + } + } + return nil +} + +func (g *graph) addM2MEdges(ctx context.Context, id driver.Value, edges EdgeSpecs) error { + var ( + res Result + // Insert all M2M edges from the same type at once. + // The EdgeSpec is the same for all members in a group. + tables = edges.GroupTable() + ) + for _, table := range sortedKeys(tables) { + edges := tables[table] + insert := g.builder.Insert(table).Columns(edges[0].Columns...) + for _, edge := range edges { + pk1, pk2 := id, edge.Target.Nodes[0] + if edge.Inverse { + pk1, pk2 = pk2, pk1 + } + insert.Values(pk1, pk2) + if edge.Bidi { + insert.Values(pk2, pk1) + } + } + query, args := insert.Query() + if err := g.tx.Exec(ctx, query, args, &res); err != nil { + return fmt.Errorf("add m2m edge for table %s: %v", table, err) + } + } + return nil +} + +func (g *graph) clearFKEdges(ctx context.Context, id driver.Value, edges []*EdgeSpec) error { + for _, edge := range edges { + if edge.Rel == O2O && edge.Inverse { + continue + } + p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0]) + // Use "IN" predicate instead of list of "OR" + // in case of more than on nodes to connect. + if len(edge.Target.Nodes) > 1 { + p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...) + } + query, args := g.builder.Update(edge.Table). + SetNull(edge.Columns[0]). + Where(And(p, EQ(edge.Columns[0], id))). + Query() + var res Result + if err := g.tx.Exec(ctx, query, args, &res); err != nil { + return fmt.Errorf("add %s edge for table %s: %v", edge.Rel, edge.Table, err) + } + } + return nil +} + +func (g *graph) addFKEdges(ctx context.Context, id driver.Value, edges []*EdgeSpec) error { + for _, edge := range edges { + if edge.Rel == O2O && edge.Inverse { + continue + } + p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0]) + // Use "IN" predicate instead of list of "OR" + // in case of more than on nodes to connect. + if len(edge.Target.Nodes) > 1 { + p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...) + } + query, args := g.builder.Update(edge.Table). + Set(edge.Columns[0], id). + Where(And(p, IsNull(edge.Columns[0]))). + Query() + var res Result + if err := g.tx.Exec(ctx, query, args, &res); err != nil { + return fmt.Errorf("add %s edge for table %s: %v", edge.Rel, edge.Table, err) + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if ids := edge.Target.Nodes; int(affected) < len(ids) { + return fmt.Errorf("one of %v is already connected to a different %s", ids, edge.Columns[0]) + } + } + return nil +} + +// setTableColumns is shared between updater and creator. +func setTableColumns(fields []*FieldSpec, edges map[Rel][]*EdgeSpec, set func(string, driver.Value)) (err error) { + for _, fi := range fields { + value := fi.Value + if fi.Type == field.TypeJSON { + if value, err = json.Marshal(value); err != nil { + return fmt.Errorf("marshal value for column %s: %v", fi.Column, err) + } + } + set(fi.Column, value) + } + for _, e := range edges[M2O] { + set(e.Columns[0], e.Target.Nodes[0]) + } + for _, e := range edges[O2O] { + if e.Inverse || e.Bidi { + set(e.Columns[0], e.Target.Nodes[0]) + } + } + return nil +} + // insertLastID invokes the insert query on the transaction and returns the LastInsertID. func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBuilder) (int64, error) { query, args := insert.Query() @@ -480,3 +702,12 @@ func rollback(tx dialect.Tx, err error) error { } return err } + +func sortedKeys(m map[string][]*EdgeSpec) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index 364f6cda3..ed676e05f 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -719,12 +719,12 @@ func TestCreateNode(t *testing.T) { m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). WithArgs("mashraki"). WillReturnResult(sqlmock.NewResult(1, 1)) - m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")). - WithArgs(1, 2, 2, 1, 1, 3, 3, 1). - WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")). WithArgs(4, 1, 5, 1). WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")). + WithArgs(1, 2, 2, 1, 1, 3, 3, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectCommit() }, }, @@ -740,6 +740,222 @@ func TestCreateNode(t *testing.T) { } } +type user struct { + id int + age int + name string +} + +func (*user) values() []interface{} { + return []interface{}{&NullInt64{}, &NullInt64{}, &NullString{}} +} + +func (u *user) assign(values ...interface{}) error { + u.id = int(values[0].(*NullInt64).Int64) + u.age = int(values[1].(*NullInt64).Int64) + u.name = values[2].(*NullString).String + return nil +} + +func TestUpdateOne(t *testing.T) { + tests := []struct { + name string + spec *UpdateSpec + prepare func(sqlmock.Sqlmock) + wantErr bool + wantUser *user + }{ + { + name: "fields/set", + spec: &UpdateSpec{ + Node: &NodeSpec{ + Table: "users", + Columns: []string{"id", "name", "age"}, + ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, + }, + Fields: FieldMut{ + Set: []*FieldSpec{ + {Column: "age", Type: field.TypeInt, Value: 30}, + {Column: "name", Type: field.TypeString, Value: "Ariel"}, + }, + }, + }, + prepare: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec(escape("UPDATE `users` SET `age` = ?, `name` = ? WHERE `id` = ?")). + WithArgs(30, "Ariel", 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). + AddRow(1, 30, "Ariel")) + mock.ExpectCommit() + }, + wantUser: &user{name: "Ariel", age: 30, id: 1}, + }, + { + name: "fields/add_clear", + spec: &UpdateSpec{ + Node: &NodeSpec{ + Table: "users", + Columns: []string{"id", "name", "age"}, + ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, + }, + Fields: FieldMut{ + Add: []*FieldSpec{ + {Column: "age", Type: field.TypeInt, Value: 1}, + }, + Clear: []*FieldSpec{ + {Column: "name", Type: field.TypeString}, + }, + }, + }, + prepare: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`age`, ?) + ? WHERE `id` = ?")). + WithArgs(0, 1, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). + AddRow(1, 31, nil)) + mock.ExpectCommit() + }, + wantUser: &user{age: 31, id: 1}, + }, + { + name: "edges/o2o_non_inverse and m2o", + spec: &UpdateSpec{ + Node: &NodeSpec{ + Table: "users", + Columns: []string{"id", "name", "age"}, + ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, + }, + Edges: EdgeMut{ + Clear: []*EdgeSpec{ + {Rel: O2O, Columns: []string{"car_id"}, Inverse: true}, + {Rel: M2O, Columns: []string{"workplace_id"}, Inverse: true}, + }, + Add: []*EdgeSpec{ + {Rel: O2O, Columns: []string{"card_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}}, + {Rel: M2O, Columns: []string{"parent_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}}, + }, + }, + }, + prepare: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec(escape("UPDATE `users` SET `workplace_id` = NULL, `car_id` = NULL, `parent_id` = ?, `card_id` = ? WHERE `id` = ?")). + WithArgs(2, 2, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). + AddRow(1, 31, nil)) + mock.ExpectCommit() + }, + wantUser: &user{age: 31, id: 1}, + }, + { + name: "edges/o2o_bidi", + spec: &UpdateSpec{ + Node: &NodeSpec{ + Table: "users", + Columns: []string{"id", "name", "age"}, + ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, + }, + Edges: EdgeMut{ + Clear: []*EdgeSpec{ + {Rel: O2O, Table: "users", Bidi: true, Columns: []string{"spouse_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}}, + }, + Add: []*EdgeSpec{ + {Rel: O2O, Table: "users", Bidi: true, Columns: []string{"spouse_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{3}}}, + }, + }, + }, + prepare: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + // Clear "spouse 2" from 1's column, and set "spouse 3". + mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = NULL, `spouse_id` = ? WHERE `id` = ?")). + WithArgs(3, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + // Clear "spouse 1" from 3's column. + mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = NULL WHERE (`id` = ?) AND (`spouse_id` = ?)")). + WithArgs(2, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + // Set 3's column to point "spouse 1". + mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = ? WHERE (`id` = ?) AND (`spouse_id` IS NULL)")). + WithArgs(1, 3). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). + AddRow(1, 31, nil)) + mock.ExpectCommit() + }, + wantUser: &user{age: 31, id: 1}, + }, + { + name: "edges/clear_add_m2m", + spec: &UpdateSpec{ + Node: &NodeSpec{ + Table: "users", + Columns: []string{"id", "name", "age"}, + ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, + }, + Edges: EdgeMut{ + Clear: []*EdgeSpec{ + {Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}}, + {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{3}}}, + }, + Add: []*EdgeSpec{ + {Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{4}}}, + {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{5}}}, + {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{6}}}, + }, + }, + }, + prepare: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + // Clear user groups. + mock.ExpectExec(escape("DELETE FROM `group_users` WHERE (`group_id` = ? AND `user_id` = ?)")). + WithArgs(3, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + // Clear user friends. + mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE ((`user_id` = ? AND `friend_id` = ?) OR (`user_id` = ? AND `friend_id` = ?))")). + WithArgs(1, 2, 2, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + // Add new groups. + mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")). + WithArgs(5, 1, 6, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + // Add new friends. + mock.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?)")). + WithArgs(1, 4, 4, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). + AddRow(1, 31, nil)) + mock.ExpectCommit() + }, + wantUser: &user{age: 31, id: 1}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + tt.prepare(mock) + usr := &user{} + tt.spec.Assign = usr.assign + tt.spec.ScanTypes = usr.values() + err = UpdateNode(context.Background(), OpenDB("", db), tt.spec) + require.Equal(t, tt.wantErr, err != nil, err) + require.Equal(t, tt.wantUser, usr) + }) + } +} + func escape(query string) string { rows := strings.Split(query, "\n") for i := range rows { diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index 441f92b6d..bc5029796 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -195,6 +195,7 @@ func Sanity(t *testing.T, client *ent.Client) { usr = client.User.UpdateOne(usr).SetName("baz").AddGroups(grp).SaveX(ctx) require.Equal("baz", usr.Name) require.NotEmpty(usr.QueryGroups().AllX(ctx)) + // grouping. var v []struct { Name string `json:"name"`