dialect/sql/sqlgraph: initial work for batch insert (#573)

This is the first part for adding batch insert support for the framework.
The second part if the codegen.
This commit is contained in:
Ariel Mashraki
2020-07-08 17:48:26 +03:00
committed by GitHub
parent 7df2e02343
commit 720766432a
2 changed files with 305 additions and 11 deletions

View File

@@ -324,14 +324,21 @@ type (
}
)
// CreateSpec holds the information for creating
// a node in the graph.
type CreateSpec struct {
Table string
ID *FieldSpec
Fields []*FieldSpec
Edges []*EdgeSpec
}
type (
// CreateSpec holds the information for creating
// a node in the graph.
CreateSpec struct {
Table string
ID *FieldSpec
Fields []*FieldSpec
Edges []*EdgeSpec
}
// BatchCreateSpec holds the information for creating
// multiple nodes in the graph.
BatchCreateSpec struct {
Nodes []*CreateSpec
}
)
// CreateNode applies the CreateSpec on the graph.
func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error {
@@ -347,6 +354,20 @@ func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error
return tx.Commit()
}
// BatchCreate applies the BatchCreateSpec on the graph.
func BatchCreate(ctx context.Context, drv dialect.Driver, spec *BatchCreateSpec) error {
tx, err := drv.Tx(ctx)
if err != nil {
return err
}
gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())}
cr := &creator{BatchCreateSpec: spec, graph: gr}
if err := cr.nodes(ctx, tx); err != nil {
return rollback(tx, err)
}
return tx.Commit()
}
type (
// EdgeMut defines edge mutations.
EdgeMut struct {
@@ -737,6 +758,7 @@ func (u *updater) scan(rows *sql.Rows) error {
type creator struct {
graph
*CreateSpec
*BatchCreateSpec
}
func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
@@ -760,6 +782,70 @@ func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
return nil
}
func (c *creator) nodes(ctx context.Context, tx dialect.ExecQuerier) error {
if len(c.Nodes) == 0 {
return nil
}
columns := make(map[string]struct{})
values := make([]map[string]driver.Value, len(c.Nodes))
for i, node := range c.Nodes {
if i > 0 && node.Table != c.Nodes[i-1].Table {
return fmt.Errorf("more than 1 table for batch insert: %q != %q", node.Table, c.Nodes[i-1].Table)
}
values[i] = make(map[string]driver.Value)
if node.ID.Value != nil {
columns[node.ID.Column] = struct{}{}
values[i][node.ID.Column] = node.ID.Value
}
edges := EdgeSpecs(node.Edges).GroupRel()
err := setTableColumns(node.Fields, edges, func(column string, value driver.Value) {
columns[column] = struct{}{}
values[i][column] = value
})
if err != nil {
return err
}
}
for column := range columns {
for i := range values {
switch _, exists := values[i][column]; {
case column == c.Nodes[i].ID.Column && !exists:
// If the ID value was provided to one of the nodes, it should be
// provided to all others because this affects the way we calculate
// their values in MySQL and SQLite dialects.
return fmt.Errorf("incosistent id values for batch insert")
case !exists:
// Assign NULL values for empty placeholders.
values[i][column] = nil
}
}
}
sorted := keys(columns)
insert := c.builder.Insert(c.Nodes[0].Table).Default().Columns(sorted...)
for i := range values {
vs := make([]interface{}, len(sorted))
for j, c := range sorted {
vs[j] = values[i][c]
}
insert.Values(vs...)
}
if err := c.batchInsert(ctx, tx, insert); err != nil {
return fmt.Errorf("insert nodes to table %q: %v", c.Nodes[0].Table, err)
}
if err := c.batchAddM2M(ctx, c.BatchCreateSpec); err != nil {
return err
}
// FKs that exist in different tables can't be updated in batch (using the CASE
// statement), because we rely on RowsAffected to check if the FK column is NULL.
for _, node := range c.Nodes {
edges := EdgeSpecs(node.Edges).GroupRel()
if err := c.graph.addFKEdges(ctx, []driver.Value{node.ID.Value}, append(edges[O2M], edges[O2O]...)); err != nil {
return err
}
}
return nil
}
// setTableColumns sets the table columns and foreign_keys used in insert.
func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*EdgeSpec) error {
err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) {
@@ -785,6 +871,18 @@ func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sq
return nil
}
// batchInsert inserts a batch of nodes to their table and sets their ID if it wasn't provided by the user.
func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
ids, err := insertLastIDs(ctx, tx, insert.Returning(c.Nodes[0].ID.Column))
if err != nil {
return err
}
for i, node := range c.Nodes {
node.ID.Value = ids[i]
}
return nil
}
// GroupRel groups edges by their relation type.
func (es EdgeSpecs) GroupRel() map[Rel][]*EdgeSpec {
edges := make(map[Rel][]*EdgeSpec)
@@ -803,6 +901,17 @@ func (es EdgeSpecs) GroupTable() map[string][]*EdgeSpec {
return edges
}
// FilterRel returns edges for the given relation type.
func (es EdgeSpecs) FilterRel(r Rel) EdgeSpecs {
edges := make([]*EdgeSpec, 0, len(es))
for _, edge := range es {
if edge.Rel == r {
edges = append(edges, edge)
}
}
return edges
}
// The common operations shared between the different builders.
//
// M2M edges reside in join tables and require INSERT and DELETE
@@ -822,7 +931,7 @@ func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges Edg
// The EdgeSpec is the same for all members in a group.
tables = edges.GroupTable()
)
for _, table := range sortedKeys(tables) {
for _, table := range edgeKeys(tables) {
edges := tables[table]
preds := make([]*sql.Predicate, 0, len(edges))
for _, edge := range edges {
@@ -850,7 +959,7 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS
// The EdgeSpec is the same for all members in a group.
tables = edges.GroupTable()
)
for _, table := range sortedKeys(tables) {
for _, table := range edgeKeys(tables) {
edges := tables[table]
insert := g.builder.Insert(table).Columns(edges[0].Columns...)
for _, edge := range edges {
@@ -873,6 +982,44 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS
return nil
}
func (g *graph) batchAddM2M(ctx context.Context, spec *BatchCreateSpec) error {
tables := make(map[string]*sql.InsertBuilder)
for _, node := range spec.Nodes {
edges := EdgeSpecs(node.Edges).FilterRel(M2M)
for t, edges := range edges.GroupTable() {
insert, ok := tables[t]
if !ok {
insert = g.builder.Insert(t).Columns(edges[0].Columns...)
}
tables[t] = insert
if len(edges) != 1 {
return fmt.Errorf("expect exactly 1 edge-spec per table, but got %d", len(edges))
}
edge := edges[0]
pk1, pk2 := []driver.Value{node.ID.Value}, edge.Target.Nodes
if edge.Inverse {
pk1, pk2 = pk2, pk1
}
for _, pair := range product(pk1, pk2) {
insert.Values(pair[0], pair[1])
if edge.Bidi {
insert.Values(pair[1], pair[0])
}
}
}
}
for _, table := range insertKeys(tables) {
var (
res sql.Result
query, args = tables[table].Query()
)
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
return fmt.Errorf("add m2m edge for table %s: %v", table, err)
}
}
return nil
}
func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error {
for _, edge := range edges {
if edge.Rel == O2O && edge.Inverse {
@@ -925,6 +1072,8 @@ func (g *graph) addFKEdges(ctx context.Context, ids []driver.Value, edges []*Edg
if err != nil {
return err
}
// Setting the FK value of the "other" table
// without clearing it before, is not allowed.
if ids := edge.Target.Nodes; int(affected) < len(ids) {
return &ConstraintError{msg: fmt.Sprintf("one of %v is already connected to a different %s", ids, edge.Columns[0])}
}
@@ -979,6 +1128,45 @@ func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *sql.Inser
return res.LastInsertId()
}
// insertLastIDs invokes the batch insert query on the transaction and returns the LastInsertID of all entities.
func insertLastIDs(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) (ids []int64, err error) {
query, args := insert.Query()
// PostgreSQL does not support the LastInsertId() method of sql.Result
// on Exec, and should be extracted manually using the `RETURNING` clause.
if insert.Dialect() == dialect.Postgres {
rows := &sql.Rows{}
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, err
}
defer rows.Close()
return ids, sql.ScanSlice(rows, &ids)
}
// MySQL, SQLite, etc.
var res sql.Result
if err := tx.Exec(ctx, query, args, &res); err != nil {
return nil, err
}
id, err := res.LastInsertId()
if err != nil {
return nil, err
}
affected, err := res.RowsAffected()
if err != nil {
return nil, err
}
ids = make([]int64, 0, affected)
switch insert.Dialect() {
case dialect.SQLite:
id -= affected - 1
fallthrough
case dialect.MySQL:
for i := int64(0); i < affected; i++ {
ids = append(ids, id+i)
}
}
return ids, nil
}
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
func rollback(tx dialect.Tx, err error) error {
if rerr := tx.Rollback(); rerr != nil {
@@ -987,7 +1175,25 @@ func rollback(tx dialect.Tx, err error) error {
return err
}
func sortedKeys(m map[string][]*EdgeSpec) []string {
func edgeKeys(m map[string][]*EdgeSpec) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
func insertKeys(m map[string]*sql.InsertBuilder) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
func keys(m map[string]struct{}) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)

View File

@@ -741,6 +741,94 @@ func TestCreateNode(t *testing.T) {
}
}
func TestBatchCreate(t *testing.T) {
tests := []struct {
name string
nodes []*CreateSpec
expect func(sqlmock.Sqlmock)
wantErr bool
}{
{
name: "empty",
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectCommit()
},
},
{
name: "multiple",
nodes: []*CreateSpec{
{
Table: "users",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "age", Type: field.TypeInt, Value: 32},
{Column: "name", Type: field.TypeString, Value: "a8m"},
{Column: "active", Type: field.TypeBool, Value: false},
},
Edges: []*EdgeSpec{
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
{Rel: M2M, Table: "user_products", Columns: []string{"user_id", "product_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
{Rel: M2O, Table: "company", Columns: []string{"workplace_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}}},
{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
},
},
{
Table: "users",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "age", Type: field.TypeInt, Value: 30},
{Column: "name", Type: field.TypeString, Value: "nati"},
},
Edges: []*EdgeSpec{
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
{Rel: M2M, Table: "user_products", Columns: []string{"user_id", "product_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{3}, IDSpec: &FieldSpec{Column: "id"}}},
},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
// Insert nodes with FKs.
m.ExpectExec(escape("INSERT INTO `users` (`active`, `age`, `name`, `workplace_id`) VALUES (?, ?, ?, ?), (?, ?, ?, ?)")).
WithArgs(false, 32, "a8m", 2, nil, 30, "nati", nil).
WillReturnResult(sqlmock.NewResult(10, 2))
// Insert M2M inverse-edges.
m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")).
WithArgs(2, 10, 2, 11).
WillReturnResult(sqlmock.NewResult(2, 2))
// Insert M2M bidirectional edges.
m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")).
WithArgs(10, 2, 2, 10, 11, 2, 2, 11).
WillReturnResult(sqlmock.NewResult(2, 2))
// Insert M2M edges.
m.ExpectExec(escape("INSERT INTO `user_products` (`user_id`, `product_id`) VALUES (?, ?), (?, ?)")).
WithArgs(10, 2, 11, 2).
WillReturnResult(sqlmock.NewResult(2, 2))
// Update FKs exist in different tables.
m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE (`id` = ?) AND (`owner_id` IS NULL")).
WithArgs(10 /* id of the 1st new node */, 2 /* pet id */).
WillReturnResult(sqlmock.NewResult(2, 2))
m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE (`id` = ?) AND (`owner_id` IS NULL")).
WithArgs(11 /* id of the 2nd new node */, 3 /* pet id */).
WillReturnResult(sqlmock.NewResult(2, 2))
m.ExpectCommit()
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
tt.expect(mock)
err = BatchCreate(context.Background(), sql.OpenDB("mysql", db), &BatchCreateSpec{Nodes: tt.nodes})
require.Equal(t, tt.wantErr, err != nil, err)
})
}
}
type user struct {
id int
age int