mirror of
https://github.com/ent/ent.git
synced 2026-04-28 13:40:56 +03:00
dialect/sql/sqlgraph: allow scanning non-numeric IDs in batch creations (#3830)
* Use same logic between lastInsertId & lastInsertIds * fix for bulkcreator.insertLastIds
This commit is contained in:
@@ -1924,7 +1924,14 @@ func (c *batchCreator) insertLastIDs(ctx context.Context, tx dialect.ExecQuerier
|
||||
defer rows.Close()
|
||||
for i := 0; rows.Next(); i++ {
|
||||
node := c.Nodes[i]
|
||||
if node.ID.Type.Numeric() {
|
||||
switch _, ok := node.ID.Value.(field.ValueScanner); {
|
||||
case ok:
|
||||
// If the ID implements the sql.Scanner
|
||||
// interface it should be a pointer type.
|
||||
if err := rows.Scan(node.ID.Value); err != nil {
|
||||
return err
|
||||
}
|
||||
case node.ID.Type.Numeric():
|
||||
// Normalize the type to int64 to make it looks
|
||||
// like LastInsertId.
|
||||
var id int64
|
||||
@@ -1932,8 +1939,10 @@ func (c *batchCreator) insertLastIDs(ctx context.Context, tx dialect.ExecQuerier
|
||||
return err
|
||||
}
|
||||
node.ID.Value = id
|
||||
} else if err := rows.Scan(&node.ID.Value); err != nil {
|
||||
return err
|
||||
default:
|
||||
if err := rows.Scan(&node.ID.Value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return rows.Err()
|
||||
|
||||
Reference in New Issue
Block a user