From bb051603ac718c475d10bced631f8ca0d8482c0b Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Wed, 4 Dec 2019 09:43:29 -0800 Subject: [PATCH] dialect/sqlgraph: add edges in node creation (#216) Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/216 WIP - ignore for now Reviewed By: alexsn Differential Revision: D18795361 fbshipit-source-id: d3a4ef5562be5faf0837cad6364130ec203a9d37 --- dialect/sql/builder.go | 15 +++ dialect/sql/graph.go | 166 ++++++++++++++++++++++----- dialect/sql/graph_test.go | 231 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 385 insertions(+), 27 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 10fa8605e..571ecf9ad 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -6,6 +6,7 @@ package sql import ( "bytes" + "database/sql/driver" "fmt" "strconv" "strings" @@ -943,6 +944,11 @@ func InInts(col string, args ...int) *Predicate { return (&Predicate{}).InInts(col, args...) } +// InValues adds the `IN` predicate for slice of driver.Value. +func InValues(col string, args ...driver.Value) *Predicate { + return (&Predicate{}).InValues(col, args...) +} + // InInts adds the `IN` predicate for ints. func (p *Predicate) InInts(col string, args ...int) *Predicate { iface := make([]interface{}, len(args)) @@ -952,6 +958,15 @@ func (p *Predicate) InInts(col string, args ...int) *Predicate { return p.In(col, iface...) } +// InValues adds the `IN` predicate for slice of driver.Value. +func (p *Predicate) InValues(col string, args ...driver.Value) *Predicate { + iface := make([]interface{}, len(args)) + for i := range args { + iface[i] = args[i] + } + return p.In(col, iface...) +} + // NotIn returns the `Not IN` predicate. func NotIn(col string, args ...interface{}) *Predicate { return (&Predicate{}).NotIn(col, args...) diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index fba3dbdd5..7f86c097b 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -283,61 +283,173 @@ type ( Value driver.Value // value to be stored. } + // EdgeTarget holds the information for the target nodes + // of an edge. + EdgeTarget struct { + Nodes []driver.Value + IDSpec *FieldSpec + } + // EdgeSpec holds the information for updating a field // column in the database. EdgeSpec struct { Rel Rel + Inverse bool Table string Columns []string - Inverse bool - Value driver.Value + Bidi bool // bidirectional edge. + Target *EdgeTarget // target nodes. } + // EdgeSpecs used for perform common operations on list of edges. + EdgeSpecs []*EdgeSpec + // CreateSpec holds the information for creating a node // in the graph. CreateSpec struct { - // Type or table name. - Table string - // ID field. - ID *FieldSpec - // Fields. + Table string + ID *FieldSpec Fields []*FieldSpec - // Edges. - Edges []*EdgeSpec + Edges []*EdgeSpec } ) -// CreateNode applies the spec on the graph. +// CreateNode applies the CreateSpec on the graph. func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error { tx, err := drv.Tx(ctx) if err != nil { return err } - insert := Dialect(drv.Dialect()).Insert(spec.Table).Default() - for _, fi := range spec.Fields { + cr := &creator{CreateSpec: spec, builder: Dialect(drv.Dialect())} + if err := cr.node(ctx, tx); err != nil { + return rollback(tx, err) + } + return tx.Commit() +} + +type creator struct { + *CreateSpec + builder *dialectBuilder +} + +func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error { + var ( + res sql.Result + edges = EdgeSpecs(c.Edges).GroupRel() + insert = c.builder.Insert(c.Table).Default() + ) + // Set and create the node. + if err := c.setTableColumns(insert, edges); err != nil { + return err + } + if err := c.insert(ctx, tx, insert); err != nil { + return fmt.Errorf("insert node to table %s: %v", c.Table, err) + } + // Insert all M2M edges from the same type at once. + // The EdgeSpec is the same for all members in a group. + tables := EdgeSpecs(edges[M2M]).GroupTable() + for table, edges := range tables { + edge := edges[0] + insert = c.builder.Insert(table).Columns(edge.Columns...) + for _, edge := range edges { + pk1, pk2 := c.ID.Value, edge.Target.Nodes[0] + if edge.Inverse { + pk1, pk2 = pk2, pk1 + } + insert.Values(pk1, pk2) + if edge.Bidi { + insert.Values(pk2, pk1) + } + } + query, args := insert.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return fmt.Errorf("add m2m edge for table %s: %v", table, err) + } + } + // O2M and non-inverse O2O edges also reside in external tables. + for _, edge := range append(edges[O2M], edges[O2O]...) { + if edge.Rel == O2O && edge.Inverse { + continue + } + p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0]) + // Use "IN" predicate instead of list of "OR" + // in case of more than on nodes to connect. + if len(edge.Target.Nodes) > 1 { + p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...) + } + query, args := c.builder.Update(edge.Table). + Set(edge.Columns[0], c.ID.Value). + Where(And(p, IsNull(edge.Columns[0]))). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return fmt.Errorf("add m2m edge for table %s: %v", edge.Table, err) + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if ids := edge.Target.Nodes; int(affected) < len(ids) { + return fmt.Errorf("one of %v is already connected to a different %s", ids, edge.Columns[0]) + } + } + return nil +} + +// setTableColumns sets the table columns and foreign_keys used in insert. +func (c *creator) setTableColumns(insert *InsertBuilder, edges map[Rel][]*EdgeSpec) (err error) { + for _, fi := range c.Fields { value := fi.Value if fi.Type == field.TypeJSON { if value, err = json.Marshal(value); err != nil { - return err + return fmt.Errorf("marshal value for column %s: %v", fi.Column, err) } } insert.Set(fi.Column, value) } - // ID was provided by the user. - if spec.ID.Value != nil { - insert.Set(spec.ID.Column, spec.ID.Value) - query, args := insert.Query() - if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil { - return rollback(tx, err) - } - } else { - id, err := insertLastID(ctx, tx, insert.Returning(spec.ID.Column)) - if err != nil { - return rollback(tx, err) - } - spec.ID.Value = id + for _, e := range edges[M2O] { + insert.Set(e.Columns[0], e.Target.Nodes[0]) } - return tx.Commit() + for _, e := range edges[O2O] { + if e.Inverse || e.Bidi { + insert.Set(e.Columns[0], e.Target.Nodes[0]) + } + } + return nil +} + +// insert inserts the node to its table and sets its ID if it wasn't provided by the user. +func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBuilder) error { + var res sql.Result + // If the id field was provided by the user. + if c.ID.Value != nil { + insert.Set(c.ID.Column, c.ID.Value) + query, args := insert.Query() + return tx.Exec(ctx, query, args, &res) + } + id, err := insertLastID(ctx, tx, insert.Returning(c.ID.Column)) + if err != nil { + return err + } + c.ID.Value = id + return nil +} + +// GroupRel groups edges by their relation type. +func (es EdgeSpecs) GroupRel() map[Rel][]*EdgeSpec { + edges := make(map[Rel][]*EdgeSpec) + for _, edge := range es { + edges[edge.Rel] = append(edges[edge.Rel], edge) + } + return edges +} + +// GroupTable groups edges by their table name. +func (es EdgeSpecs) GroupTable() map[string][]*EdgeSpec { + edges := make(map[string][]*EdgeSpec) + for _, edge := range es { + edges[edge.Table] = append(edges[edge.Table], edge) + } + return edges } // insertLastID invokes the insert query on the transaction and returns the LastInsertID. diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index aa4895c5b..364f6cda3 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -6,6 +6,7 @@ package sql import ( "context" + "database/sql/driver" "regexp" "strings" "testing" @@ -497,6 +498,236 @@ func TestCreateNode(t *testing.T) { m.ExpectCommit() }, }, + { + name: "edges/m2o", + spec: &CreateSpec{ + Table: "pets", + ID: &FieldSpec{Column: "id"}, + Fields: []*FieldSpec{ + {Column: "name", Type: field.TypeString, Value: "pedro"}, + }, + Edges: []*EdgeSpec{ + {Rel: M2O, Columns: []string{"owner_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `pets` (`name`, `owner_id`) VALUES (?, ?)")). + WithArgs("pedro", 2). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectCommit() + }, + }, + { + name: "edges/o2o/inverse", + spec: &CreateSpec{ + Table: "cards", + ID: &FieldSpec{Column: "id"}, + Fields: []*FieldSpec{ + {Column: "number", Type: field.TypeString, Value: "0001"}, + }, + Edges: []*EdgeSpec{ + {Rel: O2O, Columns: []string{"owner_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `cards` (`number`, `owner_id`) VALUES (?, ?)")). + WithArgs("0001", 2). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectCommit() + }, + }, + { + name: "edges/o2m", + spec: &CreateSpec{ + Table: "users", + ID: &FieldSpec{Column: "id"}, + Fields: []*FieldSpec{ + {Column: "name", Type: field.TypeString, Value: "a8m"}, + }, + Edges: []*EdgeSpec{ + {Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). + WithArgs("a8m"). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE (`id` = ?) AND (`owner_id` IS NULL)")). + WithArgs(1, 2). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectCommit() + }, + }, + { + name: "edges/o2m", + spec: &CreateSpec{ + Table: "users", + ID: &FieldSpec{Column: "id"}, + Fields: []*FieldSpec{ + {Column: "name", Type: field.TypeString, Value: "a8m"}, + }, + Edges: []*EdgeSpec{ + {Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2, 3, 4}, IDSpec: &FieldSpec{Column: "id"}}}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). + WithArgs("a8m"). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE (`id` IN (?, ?, ?)) AND (`owner_id` IS NULL)")). + WithArgs(1, 2, 3, 4). + WillReturnResult(sqlmock.NewResult(1, 3)) + m.ExpectCommit() + }, + }, + { + name: "edges/o2o", + spec: &CreateSpec{ + Table: "users", + ID: &FieldSpec{Column: "id"}, + Fields: []*FieldSpec{ + {Column: "name", Type: field.TypeString, Value: "a8m"}, + }, + Edges: []*EdgeSpec{ + {Rel: O2O, Table: "cards", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). + WithArgs("a8m"). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectExec(escape("UPDATE `cards` SET `owner_id` = ? WHERE (`id` = ?) AND (`owner_id` IS NULL)")). + WithArgs(1, 2). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectCommit() + }, + }, + { + name: "edges/o2o/bidi", + spec: &CreateSpec{ + Table: "users", + ID: &FieldSpec{Column: "id"}, + Fields: []*FieldSpec{ + {Column: "name", Type: field.TypeString, Value: "a8m"}, + }, + Edges: []*EdgeSpec{ + {Rel: O2O, Bidi: true, Table: "users", Columns: []string{"spouse_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `users` (`name`, `spouse_id`) VALUES (?, ?)")). + WithArgs("a8m", 2). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectExec(escape("UPDATE `users` SET `spouse_id` = ? WHERE (`id` = ?) AND (`spouse_id` IS NULL)")). + WithArgs(1, 2). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectCommit() + }, + }, + { + name: "edges/m2m", + spec: &CreateSpec{ + Table: "groups", + ID: &FieldSpec{Column: "id"}, + Fields: []*FieldSpec{ + {Column: "name", Type: field.TypeString, Value: "GitHub"}, + }, + Edges: []*EdgeSpec{ + {Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `groups` (`name`) VALUES (?)")). + WithArgs("GitHub"). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?)")). + WithArgs(1, 2). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectCommit() + }, + }, + { + name: "edges/m2m/inverse", + spec: &CreateSpec{ + Table: "users", + ID: &FieldSpec{Column: "id"}, + Fields: []*FieldSpec{ + {Column: "name", Type: field.TypeString, Value: "mashraki"}, + }, + Edges: []*EdgeSpec{ + {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). + WithArgs("mashraki"). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?)")). + WithArgs(2, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectCommit() + }, + }, + { + name: "edges/m2m/bidi", + spec: &CreateSpec{ + Table: "users", + ID: &FieldSpec{Column: "id"}, + Fields: []*FieldSpec{ + {Column: "name", Type: field.TypeString, Value: "mashraki"}, + }, + Edges: []*EdgeSpec{ + {Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). + WithArgs("mashraki"). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?)")). + WithArgs(1, 2, 2, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectCommit() + }, + }, + { + name: "edges/m2m/bidi/batch", + spec: &CreateSpec{ + Table: "users", + ID: &FieldSpec{Column: "id"}, + Fields: []*FieldSpec{ + {Column: "name", Type: field.TypeString, Value: "mashraki"}, + }, + Edges: []*EdgeSpec{ + {Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, + {Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{3}, IDSpec: &FieldSpec{Column: "id"}}}, + {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{4}, IDSpec: &FieldSpec{Column: "id"}}}, + {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{5}, IDSpec: &FieldSpec{Column: "id"}}}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). + WithArgs("mashraki"). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")). + WithArgs(1, 2, 2, 1, 1, 3, 3, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")). + WithArgs(4, 1, 5, 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectCommit() + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {