entc/gen: add support for upsert/on-conflict feature-flag

This commit is contained in:
Ariel Mashraki
2021-08-02 19:19:46 +03:00
committed by Ariel Mashraki
parent a5c931ed13
commit 09c4306378
95 changed files with 9556 additions and 279 deletions

View File

@@ -974,6 +974,10 @@ func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*Ed
// insert inserts the node to its table and sets its ID if it wasn't provided by the user.
func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
if opts := c.CreateSpec.OnConflict; len(opts) > 0 {
insert.OnConflict(opts...)
c.ensureLastInsertID(insert)
}
var res sql.Result
// If the id field was provided by the user.
if c.ID.Value != nil {
@@ -981,9 +985,6 @@ func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sq
query, args := insert.Query()
return tx.Exec(ctx, query, args, &res)
}
if opts := c.CreateSpec.OnConflict; len(opts) > 0 {
insert.OnConflict(opts...)
}
id, err := insertLastID(ctx, tx, insert.Returning(c.ID.Column))
if err != nil {
return err
@@ -992,6 +993,22 @@ func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sq
return nil
}
// ensureLastInsertID ensures the LAST_INSERT_ID was added to the
// 'ON DUPLICATE .. UPDATE' clause in it was not provided.
func (c *creator) ensureLastInsertID(insert *sql.InsertBuilder) {
if !c.ID.Type.Numeric() || c.ID.Value != nil || insert.Dialect() != dialect.MySQL {
return
}
insert.OnConflict(sql.ResolveWith(func(s *sql.UpdateSet) {
for _, column := range s.UpdateColumns() {
if column == c.ID.Column {
return
}
}
s.Set(c.ID.Column, sql.Expr(fmt.Sprintf("LAST_INSERT_ID(%s)", s.Table().C(c.ID.Column))))
}))
}
// batchInsert inserts a batch of nodes to their table and sets their ID if it was not provided by the user.
func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
if opts := c.BatchCreateSpec.OnConflict; len(opts) > 0 {
@@ -1002,8 +1019,9 @@ func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, inser
return err
}
for i, node := range c.Nodes {
// ID field was provided by the user.
if node.ID.Value == nil {
// If the ID field was not provided by the user,
// but was returned by the `RETURNING` clause.
if node.ID.Value == nil && i < len(ids) {
node.ID.Value = ids[i]
}
}