mirror of
https://github.com/ent/ent.git
synced 2026-04-28 21:50:56 +03:00
dialect/sql/sqlgraph: handle edge schema in batch inserts (#2718)
This commit is contained in:
@@ -972,7 +972,7 @@ func (c *creator) ensureConflict(insert *sql.InsertBuilder) {
|
||||
}
|
||||
|
||||
// ensureLastInsertID ensures the LAST_INSERT_ID was added to the
|
||||
// 'ON DUPLICATE .. UPDATE' clause in it was not provided.
|
||||
// 'ON DUPLICATE ... UPDATE' clause in it was not provided.
|
||||
func (c *creator) ensureLastInsertID(insert *sql.InsertBuilder) {
|
||||
if c.ID == nil || !c.ID.Type.Numeric() || c.ID.Value != nil || insert.Dialect() != dialect.MySQL {
|
||||
return
|
||||
@@ -1003,7 +1003,7 @@ func (c *batchCreator) nodes(ctx context.Context, drv dialect.Driver) error {
|
||||
return fmt.Errorf("more than 1 table for batch insert: %q != %q", node.Table, c.Nodes[i-1].Table)
|
||||
}
|
||||
values[i] = make(map[string]driver.Value)
|
||||
if node.ID.Value != nil {
|
||||
if node.ID != nil && node.ID.Value != nil {
|
||||
columns[node.ID.Column] = struct{}{}
|
||||
values[i][node.ID.Column] = node.ID.Value
|
||||
}
|
||||
@@ -1018,13 +1018,13 @@ func (c *batchCreator) nodes(ctx context.Context, drv dialect.Driver) error {
|
||||
}
|
||||
for column := range columns {
|
||||
for i := range values {
|
||||
switch _, exists := values[i][column]; {
|
||||
case column == c.Nodes[i].ID.Column && !exists:
|
||||
// If the ID value was provided to one of the nodes, it should be
|
||||
// provided to all others because this affects the way we calculate
|
||||
// their values in MySQL and SQLite dialects.
|
||||
return fmt.Errorf("incosistent id values for batch insert")
|
||||
case !exists:
|
||||
if _, exists := values[i][column]; !exists {
|
||||
if c.Nodes[i].ID != nil && column == c.Nodes[i].ID.Column {
|
||||
// If the ID value was provided to one of the nodes, it should be
|
||||
// provided to all others because this affects the way we calculate
|
||||
// their values in MySQL and SQLite dialects.
|
||||
return fmt.Errorf("incosistent id values for batch insert")
|
||||
}
|
||||
// Assign NULL values for empty placeholders.
|
||||
values[i][column] = nil
|
||||
}
|
||||
@@ -1045,6 +1045,13 @@ func (c *batchCreator) nodes(ctx context.Context, drv dialect.Driver) error {
|
||||
}
|
||||
c.tx = tx
|
||||
if err := func() error {
|
||||
// In case the spec does not contain an ID field, we assume
|
||||
// we interact with an edge-schema with composite primary key.
|
||||
if c.Nodes[0].ID == nil {
|
||||
c.ensureConflict(insert)
|
||||
query, args := insert.Query()
|
||||
return tx.Exec(ctx, query, args, nil)
|
||||
}
|
||||
if err := c.batchInsert(ctx, tx, insert); err != nil {
|
||||
return fmt.Errorf("insert nodes to table %q: %w", c.Nodes[0].Table, err)
|
||||
}
|
||||
@@ -1080,10 +1087,15 @@ func (c *batchCreator) mayTx(ctx context.Context, drv dialect.Driver) (dialect.T
|
||||
|
||||
// batchInsert inserts a batch of nodes to their table and sets their ID if it was not provided by the user.
|
||||
func (c *batchCreator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
|
||||
c.ensureConflict(insert)
|
||||
return c.insertLastIDs(ctx, tx, insert.Returning(c.Nodes[0].ID.Column))
|
||||
}
|
||||
|
||||
// ensureConflict ensures the ON CONFLICT is added to the insert statement.
|
||||
func (c *batchCreator) ensureConflict(insert *sql.InsertBuilder) {
|
||||
if opts := c.BatchCreateSpec.OnConflict; len(opts) > 0 {
|
||||
insert.OnConflict(opts...)
|
||||
}
|
||||
return c.insertLastIDs(ctx, tx, insert.Returning(c.Nodes[0].ID.Column))
|
||||
}
|
||||
|
||||
// GroupRel groups edges by their relation type.
|
||||
|
||||
Reference in New Issue
Block a user