mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
entc/gen: add support for upsert/on-conflict feature-flag
This commit is contained in:
committed by
Ariel Mashraki
parent
a5c931ed13
commit
09c4306378
@@ -921,6 +921,11 @@ func (u *UpdateSet) Columns() []string {
|
||||
return u.columns
|
||||
}
|
||||
|
||||
// UpdateColumns returns all columns in the `UPDATE` statement.
|
||||
func (u *UpdateSet) UpdateColumns() []string {
|
||||
return append(u.update.nulls, u.update.columns...)
|
||||
}
|
||||
|
||||
// Set sets a column to a given value.
|
||||
func (u *UpdateSet) Set(column string, v interface{}) *UpdateSet {
|
||||
u.update.Set(column, v)
|
||||
@@ -1015,7 +1020,7 @@ func (i *InsertBuilder) writeConflict() {
|
||||
for _, f := range i.conflict.action.update {
|
||||
f(u)
|
||||
}
|
||||
u.update.writeSetter()
|
||||
u.update.writeSetter(&i.Builder)
|
||||
if p := i.conflict.action.where; p != nil {
|
||||
i.WriteString(" WHERE ").Join(p)
|
||||
}
|
||||
@@ -1044,7 +1049,8 @@ func (u *UpdateBuilder) Schema(name string) *UpdateBuilder {
|
||||
return u
|
||||
}
|
||||
|
||||
// Set sets a column to a given value.
|
||||
// Set sets a column to a given value. If `Set` was called before with
|
||||
// the same column name, it overrides the value of the previous call.
|
||||
func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder {
|
||||
for i := range u.columns {
|
||||
if column == u.columns[i] {
|
||||
@@ -1101,38 +1107,39 @@ func (u *UpdateBuilder) Empty() bool {
|
||||
|
||||
// Query returns query representation of an `UPDATE` statement.
|
||||
func (u *UpdateBuilder) Query() (string, []interface{}) {
|
||||
u.WriteString("UPDATE ")
|
||||
u.writeSchema(u.schema)
|
||||
u.Ident(u.table).WriteString(" SET ")
|
||||
u.writeSetter()
|
||||
b := u.Builder.clone()
|
||||
b.WriteString("UPDATE ")
|
||||
b.writeSchema(u.schema)
|
||||
b.Ident(u.table).WriteString(" SET ")
|
||||
u.writeSetter(&b)
|
||||
if u.where != nil {
|
||||
u.WriteString(" WHERE ")
|
||||
u.Join(u.where)
|
||||
b.WriteString(" WHERE ")
|
||||
b.Join(u.where)
|
||||
}
|
||||
return u.String(), u.args
|
||||
return b.String(), b.args
|
||||
}
|
||||
|
||||
// writeSetter writes the "SET" clause for the UPDATE statement.
|
||||
func (u *UpdateBuilder) writeSetter() {
|
||||
func (u *UpdateBuilder) writeSetter(b *Builder) {
|
||||
for i, c := range u.nulls {
|
||||
if i > 0 {
|
||||
u.Comma()
|
||||
b.Comma()
|
||||
}
|
||||
u.Ident(c).WriteString(" = NULL")
|
||||
b.Ident(c).WriteString(" = NULL")
|
||||
}
|
||||
if len(u.nulls) > 0 && len(u.columns) > 0 {
|
||||
u.Comma()
|
||||
b.Comma()
|
||||
}
|
||||
for i, c := range u.columns {
|
||||
if i > 0 {
|
||||
u.Comma()
|
||||
b.Comma()
|
||||
}
|
||||
u.Ident(c).WriteString(" = ")
|
||||
b.Ident(c).WriteString(" = ")
|
||||
switch v := u.values[i].(type) {
|
||||
case Querier:
|
||||
u.Join(v)
|
||||
b.Join(v)
|
||||
default:
|
||||
u.Arg(v)
|
||||
b.Arg(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1974,7 +1981,7 @@ type Selector struct {
|
||||
distinct bool
|
||||
union []union
|
||||
prefix Queries
|
||||
lock *LockConfig
|
||||
lock *LockOptions
|
||||
}
|
||||
|
||||
// WithContext sets the context into the *Selector.
|
||||
@@ -2292,9 +2299,9 @@ const (
|
||||
)
|
||||
|
||||
type (
|
||||
// LockConfig defines a SELECT statement
|
||||
// LockOptions defines a SELECT statement
|
||||
// lock for protecting concurrent updates.
|
||||
LockConfig struct {
|
||||
LockOptions struct {
|
||||
// Strength of the lock.
|
||||
Strength LockStrength
|
||||
// Action of the lock.
|
||||
@@ -2305,19 +2312,19 @@ type (
|
||||
clause string
|
||||
}
|
||||
// LockOption allows configuring the LockConfig using functional options.
|
||||
LockOption func(*LockConfig)
|
||||
LockOption func(*LockOptions)
|
||||
)
|
||||
|
||||
// WithLockAction sets the Action of the lock.
|
||||
func WithLockAction(action LockAction) LockOption {
|
||||
return func(c *LockConfig) {
|
||||
return func(c *LockOptions) {
|
||||
c.Action = action
|
||||
}
|
||||
}
|
||||
|
||||
// WithLockTables sets the Tables of the lock.
|
||||
func WithLockTables(tables ...string) LockOption {
|
||||
return func(c *LockConfig) {
|
||||
return func(c *LockOptions) {
|
||||
c.Tables = tables
|
||||
}
|
||||
}
|
||||
@@ -2332,7 +2339,7 @@ func WithLockTables(tables ...string) LockOption {
|
||||
// )
|
||||
//
|
||||
func WithLockClause(clause string) LockOption {
|
||||
return func(c *LockConfig) {
|
||||
return func(c *LockOptions) {
|
||||
c.clause = clause
|
||||
}
|
||||
}
|
||||
@@ -2340,7 +2347,7 @@ func WithLockClause(clause string) LockOption {
|
||||
// For sets the lock configuration for suffixing the `SELECT`
|
||||
// statement with the `FOR [SHARE | UPDATE] ...` clause.
|
||||
func (s *Selector) For(l LockStrength, opts ...LockOption) *Selector {
|
||||
s.lock = &LockConfig{Strength: l}
|
||||
s.lock = &LockOptions{Strength: l}
|
||||
for _, opt := range opts {
|
||||
opt(s.lock)
|
||||
}
|
||||
|
||||
@@ -974,6 +974,10 @@ func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*Ed
|
||||
|
||||
// insert inserts the node to its table and sets its ID if it wasn't provided by the user.
|
||||
func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
|
||||
if opts := c.CreateSpec.OnConflict; len(opts) > 0 {
|
||||
insert.OnConflict(opts...)
|
||||
c.ensureLastInsertID(insert)
|
||||
}
|
||||
var res sql.Result
|
||||
// If the id field was provided by the user.
|
||||
if c.ID.Value != nil {
|
||||
@@ -981,9 +985,6 @@ func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sq
|
||||
query, args := insert.Query()
|
||||
return tx.Exec(ctx, query, args, &res)
|
||||
}
|
||||
if opts := c.CreateSpec.OnConflict; len(opts) > 0 {
|
||||
insert.OnConflict(opts...)
|
||||
}
|
||||
id, err := insertLastID(ctx, tx, insert.Returning(c.ID.Column))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -992,6 +993,22 @@ func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sq
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureLastInsertID ensures the LAST_INSERT_ID was added to the
|
||||
// 'ON DUPLICATE .. UPDATE' clause in it was not provided.
|
||||
func (c *creator) ensureLastInsertID(insert *sql.InsertBuilder) {
|
||||
if !c.ID.Type.Numeric() || c.ID.Value != nil || insert.Dialect() != dialect.MySQL {
|
||||
return
|
||||
}
|
||||
insert.OnConflict(sql.ResolveWith(func(s *sql.UpdateSet) {
|
||||
for _, column := range s.UpdateColumns() {
|
||||
if column == c.ID.Column {
|
||||
return
|
||||
}
|
||||
}
|
||||
s.Set(c.ID.Column, sql.Expr(fmt.Sprintf("LAST_INSERT_ID(%s)", s.Table().C(c.ID.Column))))
|
||||
}))
|
||||
}
|
||||
|
||||
// batchInsert inserts a batch of nodes to their table and sets their ID if it was not provided by the user.
|
||||
func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
|
||||
if opts := c.BatchCreateSpec.OnConflict; len(opts) > 0 {
|
||||
@@ -1002,8 +1019,9 @@ func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, inser
|
||||
return err
|
||||
}
|
||||
for i, node := range c.Nodes {
|
||||
// ID field was provided by the user.
|
||||
if node.ID.Value == nil {
|
||||
// If the ID field was not provided by the user,
|
||||
// but was returned by the `RETURNING` clause.
|
||||
if node.ID.Value == nil && i < len(ids) {
|
||||
node.ID.Value = ids[i]
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user