dialect/sql/sqlgraph: handle edge schema in batch inserts (#2718)

This commit is contained in:
Ariel Mashraki
2022-07-05 14:15:31 +03:00
committed by GitHub
parent 5b67bdab4f
commit 8c55008a9d
2 changed files with 44 additions and 10 deletions

View File

@@ -972,7 +972,7 @@ func (c *creator) ensureConflict(insert *sql.InsertBuilder) {
}
// ensureLastInsertID ensures the LAST_INSERT_ID was added to the
// 'ON DUPLICATE .. UPDATE' clause in it was not provided.
// 'ON DUPLICATE ... UPDATE' clause in it was not provided.
func (c *creator) ensureLastInsertID(insert *sql.InsertBuilder) {
if c.ID == nil || !c.ID.Type.Numeric() || c.ID.Value != nil || insert.Dialect() != dialect.MySQL {
return
@@ -1003,7 +1003,7 @@ func (c *batchCreator) nodes(ctx context.Context, drv dialect.Driver) error {
return fmt.Errorf("more than 1 table for batch insert: %q != %q", node.Table, c.Nodes[i-1].Table)
}
values[i] = make(map[string]driver.Value)
if node.ID.Value != nil {
if node.ID != nil && node.ID.Value != nil {
columns[node.ID.Column] = struct{}{}
values[i][node.ID.Column] = node.ID.Value
}
@@ -1018,13 +1018,13 @@ func (c *batchCreator) nodes(ctx context.Context, drv dialect.Driver) error {
}
for column := range columns {
for i := range values {
switch _, exists := values[i][column]; {
case column == c.Nodes[i].ID.Column && !exists:
// If the ID value was provided to one of the nodes, it should be
// provided to all others because this affects the way we calculate
// their values in MySQL and SQLite dialects.
return fmt.Errorf("incosistent id values for batch insert")
case !exists:
if _, exists := values[i][column]; !exists {
if c.Nodes[i].ID != nil && column == c.Nodes[i].ID.Column {
// If the ID value was provided to one of the nodes, it should be
// provided to all others because this affects the way we calculate
// their values in MySQL and SQLite dialects.
return fmt.Errorf("incosistent id values for batch insert")
}
// Assign NULL values for empty placeholders.
values[i][column] = nil
}
@@ -1045,6 +1045,13 @@ func (c *batchCreator) nodes(ctx context.Context, drv dialect.Driver) error {
}
c.tx = tx
if err := func() error {
// In case the spec does not contain an ID field, we assume
// we interact with an edge-schema with composite primary key.
if c.Nodes[0].ID == nil {
c.ensureConflict(insert)
query, args := insert.Query()
return tx.Exec(ctx, query, args, nil)
}
if err := c.batchInsert(ctx, tx, insert); err != nil {
return fmt.Errorf("insert nodes to table %q: %w", c.Nodes[0].Table, err)
}
@@ -1080,10 +1087,15 @@ func (c *batchCreator) mayTx(ctx context.Context, drv dialect.Driver) (dialect.T
// batchInsert inserts a batch of nodes to their table and sets their ID if it was not provided by the user.
func (c *batchCreator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
c.ensureConflict(insert)
return c.insertLastIDs(ctx, tx, insert.Returning(c.Nodes[0].ID.Column))
}
// ensureConflict ensures the ON CONFLICT is added to the insert statement.
func (c *batchCreator) ensureConflict(insert *sql.InsertBuilder) {
if opts := c.BatchCreateSpec.OnConflict; len(opts) > 0 {
insert.OnConflict(opts...)
}
return c.insertLastIDs(ctx, tx, insert.Returning(c.Nodes[0].ID.Column))
}
// GroupRel groups edges by their relation type.