dialect/sql/sqlgraph: initial work for batch insert (#573)

This is the first part for adding batch insert support for the framework.
The second part is the codegen.
This commit is contained in:
Ariel Mashraki
2020-07-08 17:48:26 +03:00
committed by GitHub
parent 7df2e02343
commit 720766432a
2 changed files with 305 additions and 11 deletions

View File

@@ -324,14 +324,21 @@ type (
}
)
// CreateSpec holds the information for creating
// a node in the graph.
type CreateSpec struct {
// Table is the name of the table the node is inserted into.
Table string
// ID describes the ID column of the node and, when provided, its value.
ID *FieldSpec
// Fields are the field specs whose columns and values are set on insert.
Fields []*FieldSpec
// Edges are the edge specs of the node.
Edges []*EdgeSpec
}
type (
// CreateSpec holds the information for creating
// a node in the graph.
CreateSpec struct {
// Table is the name of the table the node is inserted into.
Table string
// ID describes the ID column of the node and, when provided, its value.
ID *FieldSpec
// Fields are the field specs whose columns and values are set on insert.
Fields []*FieldSpec
// Edges are the edge specs of the node.
Edges []*EdgeSpec
}
// BatchCreateSpec holds the information for creating
// multiple nodes in the graph.
BatchCreateSpec struct {
// Nodes are the specs of the nodes to create. The batch insert
// rejects specs that do not all refer to the same table.
Nodes []*CreateSpec
}
)
// CreateNode applies the CreateSpec on the graph.
func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error {
@@ -347,6 +354,20 @@ func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error
return tx.Commit()
}
// BatchCreate applies the BatchCreateSpec on the graph.
//
// The nodes are created inside a transaction opened on the given driver.
// The transaction is committed on success and rolled back if the creation
// of the nodes fails.
func BatchCreate(ctx context.Context, drv dialect.Driver, spec *BatchCreateSpec) error {
	tx, err := drv.Tx(ctx)
	if err != nil {
		return err
	}
	c := &creator{
		BatchCreateSpec: spec,
		graph: graph{
			tx:      tx,
			builder: sql.Dialect(drv.Dialect()),
		},
	}
	err = c.nodes(ctx, tx)
	if err != nil {
		return rollback(tx, err)
	}
	return tx.Commit()
}
type (
// EdgeMut defines edge mutations.
EdgeMut struct {
@@ -737,6 +758,7 @@ func (u *updater) scan(rows *sql.Rows) error {
// creator holds the state used for creating nodes: the embedded graph
// (transaction and SQL builder) together with the spec of a single node
// (CreateSpec) or of a batch of nodes (BatchCreateSpec).
// NOTE(review): BatchCreate sets only BatchCreateSpec; presumably CreateNode
// sets only CreateSpec - confirm against its body.
type creator struct {
graph
*CreateSpec
*BatchCreateSpec
}
func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
@@ -760,6 +782,70 @@ func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
return nil
}
// nodes inserts all nodes of the BatchCreateSpec using a single INSERT
// statement and then adds their edges.
//
// All nodes must belong to the same table, and the ID value must be provided
// either for all of the nodes or for none of them; otherwise an error is
// returned. Columns that are set only on some of the nodes are filled with
// NULL for the others. After the insert, the M2M edges are added in batch and
// the O2M and O2O foreign-keys are updated node by node.
func (c *creator) nodes(ctx context.Context, tx dialect.ExecQuerier) error {
	if len(c.Nodes) == 0 {
		return nil
	}
	// columns is the union of the columns set by all nodes, and
	// values[i] maps the columns that were set by node i to their values.
	columns := make(map[string]struct{})
	values := make([]map[string]driver.Value, len(c.Nodes))
	for i, node := range c.Nodes {
		if i > 0 && node.Table != c.Nodes[i-1].Table {
			return fmt.Errorf("more than 1 table for batch insert: %q != %q", node.Table, c.Nodes[i-1].Table)
		}
		values[i] = make(map[string]driver.Value)
		if node.ID.Value != nil {
			columns[node.ID.Column] = struct{}{}
			values[i][node.ID.Column] = node.ID.Value
		}
		edges := EdgeSpecs(node.Edges).GroupRel()
		err := setTableColumns(node.Fields, edges, func(column string, value driver.Value) {
			columns[column] = struct{}{}
			values[i][column] = value
		})
		if err != nil {
			return err
		}
	}
	for column := range columns {
		for i := range values {
			switch _, exists := values[i][column]; {
			case column == c.Nodes[i].ID.Column && !exists:
				// If the ID value was provided to one of the nodes, it should be
				// provided to all others because this affects the way we calculate
				// their values in MySQL and SQLite dialects.
				return fmt.Errorf("inconsistent id values for batch insert")
			case !exists:
				// Assign NULL values for empty placeholders.
				values[i][column] = nil
			}
		}
	}
	// Sort the columns to get a deterministic statement,
	// and add the values of each node in the same order.
	sorted := keys(columns)
	insert := c.builder.Insert(c.Nodes[0].Table).Default().Columns(sorted...)
	for i := range values {
		vs := make([]interface{}, len(sorted))
		for j, column := range sorted {
			vs[j] = values[i][column]
		}
		insert.Values(vs...)
	}
	if err := c.batchInsert(ctx, tx, insert); err != nil {
		return fmt.Errorf("insert nodes to table %q: %v", c.Nodes[0].Table, err)
	}
	if err := c.batchAddM2M(ctx, c.BatchCreateSpec); err != nil {
		return err
	}
	// FKs that exist in different tables can't be updated in batch (using the CASE
	// statement), because we rely on RowsAffected to check if the FK column is NULL.
	for _, node := range c.Nodes {
		edges := EdgeSpecs(node.Edges).GroupRel()
		if err := c.graph.addFKEdges(ctx, []driver.Value{node.ID.Value}, append(edges[O2M], edges[O2O]...)); err != nil {
			return err
		}
	}
	return nil
}
// setTableColumns sets the table columns and foreign_keys used in insert.
func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*EdgeSpec) error {
err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) {
@@ -785,6 +871,18 @@ func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sq
return nil
}
// batchInsert inserts a batch of nodes to their table and sets their ID if it wasn't provided by the user.
//
// The IDs returned by the insert are assigned to the nodes by their position
// in the batch. An ID value that was provided by the user is left untouched.
// An error is returned if the insert fails, or if it did not return an ID for
// a node that needs one.
func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
	ids, err := insertLastIDs(ctx, tx, insert.Returning(c.Nodes[0].ID.Column))
	if err != nil {
		return err
	}
	for i, node := range c.Nodes {
		// The ID was provided by the user, do not override it
		// with the value that was calculated from the result.
		if node.ID.Value != nil {
			continue
		}
		// Guard against an out-of-range access in case the
		// database returned less IDs than the nodes inserted.
		if i >= len(ids) {
			return fmt.Errorf("unexpected number of ids for batch insert: got %d, want %d", len(ids), len(c.Nodes))
		}
		node.ID.Value = ids[i]
	}
	return nil
}
// GroupRel groups edges by their relation type.
func (es EdgeSpecs) GroupRel() map[Rel][]*EdgeSpec {
edges := make(map[Rel][]*EdgeSpec)
@@ -803,6 +901,17 @@ func (es EdgeSpecs) GroupTable() map[string][]*EdgeSpec {
return edges
}
// FilterRel returns the edges whose relation type equals r,
// keeping their original order.
func (es EdgeSpecs) FilterRel(r Rel) EdgeSpecs {
	filtered := make([]*EdgeSpec, 0, len(es))
	for i := range es {
		if es[i].Rel != r {
			continue
		}
		filtered = append(filtered, es[i])
	}
	return filtered
}
// The common operations shared between the different builders.
//
// M2M edges reside in join tables and require INSERT and DELETE
@@ -822,7 +931,7 @@ func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges Edg
// The EdgeSpec is the same for all members in a group.
tables = edges.GroupTable()
)
for _, table := range sortedKeys(tables) {
for _, table := range edgeKeys(tables) {
edges := tables[table]
preds := make([]*sql.Predicate, 0, len(edges))
for _, edge := range edges {
@@ -850,7 +959,7 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS
// The EdgeSpec is the same for all members in a group.
tables = edges.GroupTable()
)
for _, table := range sortedKeys(tables) {
for _, table := range edgeKeys(tables) {
edges := tables[table]
insert := g.builder.Insert(table).Columns(edges[0].Columns...)
for _, edge := range edges {
@@ -873,6 +982,44 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS
return nil
}
// batchAddM2M adds the M2M edges of all nodes in the given spec. The edges
// of all nodes are collected into one INSERT statement per table, and the
// statements are executed ordered by their table name.
//
// An error is returned if a node holds more than one edge-spec for the same
// table, or if one of the statements fails.
func (g *graph) batchAddM2M(ctx context.Context, spec *BatchCreateSpec) error {
	inserts := make(map[string]*sql.InsertBuilder)
	for _, n := range spec.Nodes {
		m2m := EdgeSpecs(n.Edges).FilterRel(M2M)
		for name, group := range m2m.GroupTable() {
			// Create the builder of the table on its first occurrence.
			if _, exists := inserts[name]; !exists {
				inserts[name] = g.builder.Insert(name).Columns(group[0].Columns...)
			}
			if count := len(group); count != 1 {
				return fmt.Errorf("expect exactly 1 edge-spec per table, but got %d", count)
			}
			var (
				e       = group[0]
				builder = inserts[name]
				left    = []driver.Value{n.ID.Value}
				right   = e.Target.Nodes
			)
			if e.Inverse {
				left, right = right, left
			}
			for _, p := range product(left, right) {
				builder.Values(p[0], p[1])
				if !e.Bidi {
					continue
				}
				// Bidirectional edges are stored in both directions.
				builder.Values(p[1], p[0])
			}
		}
	}
	for _, name := range insertKeys(inserts) {
		var res sql.Result
		query, args := inserts[name].Query()
		if err := g.tx.Exec(ctx, query, args, &res); err != nil {
			return fmt.Errorf("add m2m edge for table %s: %v", name, err)
		}
	}
	return nil
}
func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error {
for _, edge := range edges {
if edge.Rel == O2O && edge.Inverse {
@@ -925,6 +1072,8 @@ func (g *graph) addFKEdges(ctx context.Context, ids []driver.Value, edges []*Edg
if err != nil {
return err
}
// Setting the FK value of the "other" table
// without clearing it before, is not allowed.
if ids := edge.Target.Nodes; int(affected) < len(ids) {
return &ConstraintError{msg: fmt.Sprintf("one of %v is already connected to a different %s", ids, edge.Columns[0])}
}
@@ -979,6 +1128,45 @@ func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *sql.Inser
return res.LastInsertId()
}
// insertLastIDs invokes the batch insert query on the transaction and returns the LastInsertID of all entities.
//
// In PostgreSQL, the IDs are scanned from the rows returned by the query.
// In MySQL and SQLite, the IDs are calculated from the LastInsertId and the
// RowsAffected of the result, assuming the IDs of the batch are sequential.
// For any other dialect, an empty list is returned.
func insertLastIDs(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) (ids []int64, err error) {
	query, args := insert.Query()
	// PostgreSQL does not support the LastInsertId() method of sql.Result
	// on Exec, and should be extracted manually using the `RETURNING` clause.
	if insert.Dialect() == dialect.Postgres {
		rows := &sql.Rows{}
		if err := tx.Query(ctx, query, args, rows); err != nil {
			return nil, err
		}
		defer rows.Close()
		// Scan before returning. The order of evaluation between reading
		// `ids` and calling a function that sets it is not specified when
		// both reside in the same return statement.
		if err := sql.ScanSlice(rows, &ids); err != nil {
			return nil, err
		}
		return ids, nil
	}
	// MySQL, SQLite, etc.
	var res sql.Result
	if err := tx.Exec(ctx, query, args, &res); err != nil {
		return nil, err
	}
	id, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return nil, err
	}
	ids = make([]int64, 0, affected)
	switch insert.Dialect() {
	case dialect.SQLite:
		// The ID is treated here as the ID of the last row of the batch.
		// Move it back to the first one.
		id -= affected - 1
		fallthrough
	case dialect.MySQL:
		// The ID is treated here as the ID of the first row of the batch.
		for i := int64(0); i < affected; i++ {
			ids = append(ids, id+i)
		}
	}
	return ids, nil
}
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
func rollback(tx dialect.Tx, err error) error {
if rerr := tx.Rollback(); rerr != nil {
@@ -987,7 +1175,25 @@ func rollback(tx dialect.Tx, err error) error {
return err
}
func sortedKeys(m map[string][]*EdgeSpec) []string {
// edgeKeys returns the keys of the given map sorted in increasing order.
func edgeKeys(m map[string][]*EdgeSpec) []string {
	names := make([]string, len(m))
	i := 0
	for name := range m {
		names[i] = name
		i++
	}
	sort.Strings(names)
	return names
}
// insertKeys returns the keys of the given map sorted in increasing order.
func insertKeys(m map[string]*sql.InsertBuilder) []string {
	tables := make([]string, len(m))
	next := 0
	for table := range m {
		tables[next] = table
		next++
	}
	sort.Strings(tables)
	return tables
}
func keys(m map[string]struct{}) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)