dialect/sqlgraph: add edges in node creation (#216)

Summary:
Pull Request resolved: https://github.com/facebookincubator/ent/pull/216

WIP - ignore for now

Reviewed By: alexsn

Differential Revision: D18795361

fbshipit-source-id: d3a4ef5562be5faf0837cad6364130ec203a9d37
This commit is contained in:
Ariel Mashraki
2019-12-04 09:43:29 -08:00
committed by Facebook Github Bot
parent a5e4a9cf54
commit bb051603ac
3 changed files with 385 additions and 27 deletions

View File

@@ -6,6 +6,7 @@ package sql
import (
"bytes"
"database/sql/driver"
"fmt"
"strconv"
"strings"
@@ -943,6 +944,11 @@ func InInts(col string, args ...int) *Predicate {
return (&Predicate{}).InInts(col, args...)
}
// InValues adds the `IN` predicate for slice of driver.Value.
func InValues(col string, args ...driver.Value) *Predicate {
return (&Predicate{}).InValues(col, args...)
}
// InInts adds the `IN` predicate for ints.
func (p *Predicate) InInts(col string, args ...int) *Predicate {
iface := make([]interface{}, len(args))
@@ -952,6 +958,15 @@ func (p *Predicate) InInts(col string, args ...int) *Predicate {
return p.In(col, iface...)
}
// InValues adds the `IN` predicate for slice of driver.Value.
func (p *Predicate) InValues(col string, args ...driver.Value) *Predicate {
iface := make([]interface{}, len(args))
for i := range args {
iface[i] = args[i]
}
return p.In(col, iface...)
}
// NotIn returns the `Not IN` predicate.
func NotIn(col string, args ...interface{}) *Predicate {
return (&Predicate{}).NotIn(col, args...)

View File

@@ -283,61 +283,173 @@ type (
Value driver.Value // value to be stored.
}
// EdgeTarget holds the information for the target nodes
// of an edge.
EdgeTarget struct {
Nodes []driver.Value
IDSpec *FieldSpec
}
// EdgeSpec holds the information for updating a field
// column in the database.
EdgeSpec struct {
Rel Rel
Inverse bool
Table string
Columns []string
Inverse bool
Value driver.Value
Bidi bool // bidirectional edge.
Target *EdgeTarget // target nodes.
}
// EdgeSpecs used for perform common operations on list of edges.
EdgeSpecs []*EdgeSpec
// CreateSpec holds the information for creating a node
// in the graph.
CreateSpec struct {
// Type or table name.
Table string
// ID field.
ID *FieldSpec
// Fields.
Table string
ID *FieldSpec
Fields []*FieldSpec
// Edges.
Edges []*EdgeSpec
Edges []*EdgeSpec
}
)
// CreateNode applies the spec on the graph.
// CreateNode applies the CreateSpec on the graph.
func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error {
tx, err := drv.Tx(ctx)
if err != nil {
return err
}
insert := Dialect(drv.Dialect()).Insert(spec.Table).Default()
for _, fi := range spec.Fields {
cr := &creator{CreateSpec: spec, builder: Dialect(drv.Dialect())}
if err := cr.node(ctx, tx); err != nil {
return rollback(tx, err)
}
return tx.Commit()
}
type creator struct {
*CreateSpec
builder *dialectBuilder
}
func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
var (
res sql.Result
edges = EdgeSpecs(c.Edges).GroupRel()
insert = c.builder.Insert(c.Table).Default()
)
// Set and create the node.
if err := c.setTableColumns(insert, edges); err != nil {
return err
}
if err := c.insert(ctx, tx, insert); err != nil {
return fmt.Errorf("insert node to table %s: %v", c.Table, err)
}
// Insert all M2M edges from the same type at once.
// The EdgeSpec is the same for all members in a group.
tables := EdgeSpecs(edges[M2M]).GroupTable()
for table, edges := range tables {
edge := edges[0]
insert = c.builder.Insert(table).Columns(edge.Columns...)
for _, edge := range edges {
pk1, pk2 := c.ID.Value, edge.Target.Nodes[0]
if edge.Inverse {
pk1, pk2 = pk2, pk1
}
insert.Values(pk1, pk2)
if edge.Bidi {
insert.Values(pk2, pk1)
}
}
query, args := insert.Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
return fmt.Errorf("add m2m edge for table %s: %v", table, err)
}
}
// O2M and non-inverse O2O edges also reside in external tables.
for _, edge := range append(edges[O2M], edges[O2O]...) {
if edge.Rel == O2O && edge.Inverse {
continue
}
p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0])
// Use "IN" predicate instead of list of "OR"
// in case of more than on nodes to connect.
if len(edge.Target.Nodes) > 1 {
p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...)
}
query, args := c.builder.Update(edge.Table).
Set(edge.Columns[0], c.ID.Value).
Where(And(p, IsNull(edge.Columns[0]))).
Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
return fmt.Errorf("add m2m edge for table %s: %v", edge.Table, err)
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if ids := edge.Target.Nodes; int(affected) < len(ids) {
return fmt.Errorf("one of %v is already connected to a different %s", ids, edge.Columns[0])
}
}
return nil
}
// setTableColumns sets the table columns and foreign_keys used in insert.
func (c *creator) setTableColumns(insert *InsertBuilder, edges map[Rel][]*EdgeSpec) (err error) {
for _, fi := range c.Fields {
value := fi.Value
if fi.Type == field.TypeJSON {
if value, err = json.Marshal(value); err != nil {
return err
return fmt.Errorf("marshal value for column %s: %v", fi.Column, err)
}
}
insert.Set(fi.Column, value)
}
// ID was provided by the user.
if spec.ID.Value != nil {
insert.Set(spec.ID.Column, spec.ID.Value)
query, args := insert.Query()
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
return rollback(tx, err)
}
} else {
id, err := insertLastID(ctx, tx, insert.Returning(spec.ID.Column))
if err != nil {
return rollback(tx, err)
}
spec.ID.Value = id
for _, e := range edges[M2O] {
insert.Set(e.Columns[0], e.Target.Nodes[0])
}
return tx.Commit()
for _, e := range edges[O2O] {
if e.Inverse || e.Bidi {
insert.Set(e.Columns[0], e.Target.Nodes[0])
}
}
return nil
}
// insert inserts the node to its table and sets its ID if it wasn't provided by the user.
func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBuilder) error {
var res sql.Result
// If the id field was provided by the user.
if c.ID.Value != nil {
insert.Set(c.ID.Column, c.ID.Value)
query, args := insert.Query()
return tx.Exec(ctx, query, args, &res)
}
id, err := insertLastID(ctx, tx, insert.Returning(c.ID.Column))
if err != nil {
return err
}
c.ID.Value = id
return nil
}
// GroupRel groups edges by their relation type.
func (es EdgeSpecs) GroupRel() map[Rel][]*EdgeSpec {
edges := make(map[Rel][]*EdgeSpec)
for _, edge := range es {
edges[edge.Rel] = append(edges[edge.Rel], edge)
}
return edges
}
// GroupTable groups edges by their table name.
func (es EdgeSpecs) GroupTable() map[string][]*EdgeSpec {
edges := make(map[string][]*EdgeSpec)
for _, edge := range es {
edges[edge.Table] = append(edges[edge.Table], edge)
}
return edges
}
// insertLastID invokes the insert query on the transaction and returns the LastInsertID.

View File

@@ -6,6 +6,7 @@ package sql
import (
"context"
"database/sql/driver"
"regexp"
"strings"
"testing"
@@ -497,6 +498,236 @@ func TestCreateNode(t *testing.T) {
m.ExpectCommit()
},
},
{
name: "edges/m2o",
spec: &CreateSpec{
Table: "pets",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "name", Type: field.TypeString, Value: "pedro"},
},
Edges: []*EdgeSpec{
{Rel: M2O, Columns: []string{"owner_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `pets` (`name`, `owner_id`) VALUES (?, ?)")).
WithArgs("pedro", 2).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "edges/o2o/inverse",
spec: &CreateSpec{
Table: "cards",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "number", Type: field.TypeString, Value: "0001"},
},
Edges: []*EdgeSpec{
{Rel: O2O, Columns: []string{"owner_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `cards` (`number`, `owner_id`) VALUES (?, ?)")).
WithArgs("0001", 2).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "edges/o2m",
spec: &CreateSpec{
Table: "users",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "name", Type: field.TypeString, Value: "a8m"},
},
Edges: []*EdgeSpec{
{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
WithArgs("a8m").
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE (`id` = ?) AND (`owner_id` IS NULL)")).
WithArgs(1, 2).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "edges/o2m",
spec: &CreateSpec{
Table: "users",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "name", Type: field.TypeString, Value: "a8m"},
},
Edges: []*EdgeSpec{
{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2, 3, 4}, IDSpec: &FieldSpec{Column: "id"}}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
WithArgs("a8m").
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE (`id` IN (?, ?, ?)) AND (`owner_id` IS NULL)")).
WithArgs(1, 2, 3, 4).
WillReturnResult(sqlmock.NewResult(1, 3))
m.ExpectCommit()
},
},
{
name: "edges/o2o",
spec: &CreateSpec{
Table: "users",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "name", Type: field.TypeString, Value: "a8m"},
},
Edges: []*EdgeSpec{
{Rel: O2O, Table: "cards", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
WithArgs("a8m").
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectExec(escape("UPDATE `cards` SET `owner_id` = ? WHERE (`id` = ?) AND (`owner_id` IS NULL)")).
WithArgs(1, 2).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "edges/o2o/bidi",
spec: &CreateSpec{
Table: "users",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "name", Type: field.TypeString, Value: "a8m"},
},
Edges: []*EdgeSpec{
{Rel: O2O, Bidi: true, Table: "users", Columns: []string{"spouse_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `users` (`name`, `spouse_id`) VALUES (?, ?)")).
WithArgs("a8m", 2).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectExec(escape("UPDATE `users` SET `spouse_id` = ? WHERE (`id` = ?) AND (`spouse_id` IS NULL)")).
WithArgs(1, 2).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "edges/m2m",
spec: &CreateSpec{
Table: "groups",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "name", Type: field.TypeString, Value: "GitHub"},
},
Edges: []*EdgeSpec{
{Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `groups` (`name`) VALUES (?)")).
WithArgs("GitHub").
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?)")).
WithArgs(1, 2).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "edges/m2m/inverse",
spec: &CreateSpec{
Table: "users",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "name", Type: field.TypeString, Value: "mashraki"},
},
Edges: []*EdgeSpec{
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
WithArgs("mashraki").
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?)")).
WithArgs(2, 1).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "edges/m2m/bidi",
spec: &CreateSpec{
Table: "users",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "name", Type: field.TypeString, Value: "mashraki"},
},
Edges: []*EdgeSpec{
{Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
WithArgs("mashraki").
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?)")).
WithArgs(1, 2, 2, 1).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "edges/m2m/bidi/batch",
spec: &CreateSpec{
Table: "users",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "name", Type: field.TypeString, Value: "mashraki"},
},
Edges: []*EdgeSpec{
{Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
{Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{3}, IDSpec: &FieldSpec{Column: "id"}}},
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{4}, IDSpec: &FieldSpec{Column: "id"}}},
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{5}, IDSpec: &FieldSpec{Column: "id"}}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
WithArgs("mashraki").
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")).
WithArgs(1, 2, 2, 1, 1, 3, 3, 1).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")).
WithArgs(4, 1, 5, 1).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {