entc/gen: add support for upsert/on-conflict feature-flag

This commit is contained in:
Ariel Mashraki
2021-08-02 19:19:46 +03:00
committed by Ariel Mashraki
parent a5c931ed13
commit 09c4306378
95 changed files with 9556 additions and 279 deletions

View File

@@ -921,6 +921,11 @@ func (u *UpdateSet) Columns() []string {
return u.columns
}
// UpdateColumns returns all columns in the `UPDATE` statement.
func (u *UpdateSet) UpdateColumns() []string {
return append(u.update.nulls, u.update.columns...)
}
// Set sets a column to a given value.
func (u *UpdateSet) Set(column string, v interface{}) *UpdateSet {
u.update.Set(column, v)
@@ -1015,7 +1020,7 @@ func (i *InsertBuilder) writeConflict() {
for _, f := range i.conflict.action.update {
f(u)
}
u.update.writeSetter()
u.update.writeSetter(&i.Builder)
if p := i.conflict.action.where; p != nil {
i.WriteString(" WHERE ").Join(p)
}
@@ -1044,7 +1049,8 @@ func (u *UpdateBuilder) Schema(name string) *UpdateBuilder {
return u
}
// Set sets a column to a given value.
// Set sets a column to a given value. If `Set` was called before with
// the same column name, it overrides the value of the previous call.
func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder {
for i := range u.columns {
if column == u.columns[i] {
@@ -1101,38 +1107,39 @@ func (u *UpdateBuilder) Empty() bool {
// Query returns query representation of an `UPDATE` statement.
func (u *UpdateBuilder) Query() (string, []interface{}) {
u.WriteString("UPDATE ")
u.writeSchema(u.schema)
u.Ident(u.table).WriteString(" SET ")
u.writeSetter()
b := u.Builder.clone()
b.WriteString("UPDATE ")
b.writeSchema(u.schema)
b.Ident(u.table).WriteString(" SET ")
u.writeSetter(&b)
if u.where != nil {
u.WriteString(" WHERE ")
u.Join(u.where)
b.WriteString(" WHERE ")
b.Join(u.where)
}
return u.String(), u.args
return b.String(), b.args
}
// writeSetter writes the "SET" clause for the UPDATE statement.
func (u *UpdateBuilder) writeSetter() {
func (u *UpdateBuilder) writeSetter(b *Builder) {
for i, c := range u.nulls {
if i > 0 {
u.Comma()
b.Comma()
}
u.Ident(c).WriteString(" = NULL")
b.Ident(c).WriteString(" = NULL")
}
if len(u.nulls) > 0 && len(u.columns) > 0 {
u.Comma()
b.Comma()
}
for i, c := range u.columns {
if i > 0 {
u.Comma()
b.Comma()
}
u.Ident(c).WriteString(" = ")
b.Ident(c).WriteString(" = ")
switch v := u.values[i].(type) {
case Querier:
u.Join(v)
b.Join(v)
default:
u.Arg(v)
b.Arg(v)
}
}
}
@@ -1974,7 +1981,7 @@ type Selector struct {
distinct bool
union []union
prefix Queries
lock *LockConfig
lock *LockOptions
}
// WithContext sets the context into the *Selector.
@@ -2292,9 +2299,9 @@ const (
)
type (
// LockConfig defines a SELECT statement
// LockOptions defines a SELECT statement
// lock for protecting concurrent updates.
LockConfig struct {
LockOptions struct {
// Strength of the lock.
Strength LockStrength
// Action of the lock.
@@ -2305,19 +2312,19 @@ type (
clause string
}
// LockOption allows configuring the LockConfig using functional options.
LockOption func(*LockConfig)
LockOption func(*LockOptions)
)
// WithLockAction sets the Action of the lock.
func WithLockAction(action LockAction) LockOption {
return func(c *LockConfig) {
return func(c *LockOptions) {
c.Action = action
}
}
// WithLockTables sets the Tables of the lock.
func WithLockTables(tables ...string) LockOption {
return func(c *LockConfig) {
return func(c *LockOptions) {
c.Tables = tables
}
}
@@ -2332,7 +2339,7 @@ func WithLockTables(tables ...string) LockOption {
// )
//
func WithLockClause(clause string) LockOption {
return func(c *LockConfig) {
return func(c *LockOptions) {
c.clause = clause
}
}
@@ -2340,7 +2347,7 @@ func WithLockClause(clause string) LockOption {
// For sets the lock configuration for suffixing the `SELECT`
// statement with the `FOR [SHARE | UPDATE] ...` clause.
func (s *Selector) For(l LockStrength, opts ...LockOption) *Selector {
s.lock = &LockConfig{Strength: l}
s.lock = &LockOptions{Strength: l}
for _, opt := range opts {
opt(s.lock)
}

View File

@@ -974,6 +974,10 @@ func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*Ed
// insert inserts the node to its table and sets its ID if it wasn't provided by the user.
func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
if opts := c.CreateSpec.OnConflict; len(opts) > 0 {
insert.OnConflict(opts...)
c.ensureLastInsertID(insert)
}
var res sql.Result
// If the id field was provided by the user.
if c.ID.Value != nil {
@@ -981,9 +985,6 @@ func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sq
query, args := insert.Query()
return tx.Exec(ctx, query, args, &res)
}
if opts := c.CreateSpec.OnConflict; len(opts) > 0 {
insert.OnConflict(opts...)
}
id, err := insertLastID(ctx, tx, insert.Returning(c.ID.Column))
if err != nil {
return err
@@ -992,6 +993,22 @@ func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sq
return nil
}
// ensureLastInsertID ensures the LAST_INSERT_ID was added to the
// 'ON DUPLICATE .. UPDATE' clause in it was not provided.
func (c *creator) ensureLastInsertID(insert *sql.InsertBuilder) {
if !c.ID.Type.Numeric() || c.ID.Value != nil || insert.Dialect() != dialect.MySQL {
return
}
insert.OnConflict(sql.ResolveWith(func(s *sql.UpdateSet) {
for _, column := range s.UpdateColumns() {
if column == c.ID.Column {
return
}
}
s.Set(c.ID.Column, sql.Expr(fmt.Sprintf("LAST_INSERT_ID(%s)", s.Table().C(c.ID.Column))))
}))
}
// batchInsert inserts a batch of nodes to their table and sets their ID if it was not provided by the user.
func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
if opts := c.BatchCreateSpec.OnConflict; len(opts) > 0 {
@@ -1002,8 +1019,9 @@ func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, inser
return err
}
for i, node := range c.Nodes {
// ID field was provided by the user.
if node.ID.Value == nil {
// If the ID field was not provided by the user,
// but was returned by the `RETURNING` clause.
if node.ID.Value == nil && i < len(ids) {
node.ID.Value = ids[i]
}
}