mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql: override column values on Updater.Set
Avoid cases like 'SET a = 1, a = 2'.
This commit is contained in:
committed by
Ariel Mashraki
parent
64e0116ed7
commit
b19ac669c7
@@ -737,7 +737,7 @@ type (
|
||||
action struct {
|
||||
nothing bool
|
||||
where *Predicate
|
||||
update func(*UpdateSet)
|
||||
update []func(*UpdateSet)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -831,11 +831,11 @@ func DoNothing() ConflictOption {
|
||||
//
|
||||
func ResolveWithIgnore() ConflictOption {
|
||||
return func(c *conflict) {
|
||||
c.action.update = func(u *UpdateSet) {
|
||||
c.action.update = append(c.action.update, func(u *UpdateSet) {
|
||||
for _, c := range u.columns {
|
||||
u.Set(c, Expr(u.Table().C(c)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -856,11 +856,11 @@ func ResolveWithIgnore() ConflictOption {
|
||||
//
|
||||
func ResolveWithNewValues() ConflictOption {
|
||||
return func(c *conflict) {
|
||||
c.action.update = func(u *UpdateSet) {
|
||||
c.action.update = append(c.action.update, func(u *UpdateSet) {
|
||||
for _, c := range u.columns {
|
||||
u.SetExcluded(c)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -879,7 +879,7 @@ func ResolveWithNewValues() ConflictOption {
|
||||
//
|
||||
func ResolveWith(fn func(*UpdateSet)) ConflictOption {
|
||||
return func(c *conflict) {
|
||||
c.action.update = fn
|
||||
c.action.update = append(c.action.update, fn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1012,7 +1012,9 @@ func (i *InsertBuilder) writeConflict() {
|
||||
}
|
||||
u := &UpdateSet{table: i.table, columns: i.columns, update: &UpdateBuilder{}}
|
||||
u.update.Builder = i.Builder
|
||||
i.conflict.action.update(u)
|
||||
for _, f := range i.conflict.action.update {
|
||||
f(u)
|
||||
}
|
||||
u.update.writeSetter()
|
||||
if p := i.conflict.action.where; p != nil {
|
||||
i.WriteString(" WHERE ").Join(p)
|
||||
@@ -1044,6 +1046,12 @@ func (u *UpdateBuilder) Schema(name string) *UpdateBuilder {
|
||||
|
||||
// Set sets a column to a given value.
|
||||
func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder {
|
||||
for i := range u.columns {
|
||||
if column == u.columns[i] {
|
||||
u.values[i] = v
|
||||
return u
|
||||
}
|
||||
}
|
||||
u.columns = append(u.columns, column)
|
||||
u.values = append(u.values, v)
|
||||
return u
|
||||
@@ -1051,8 +1059,7 @@ func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder {
|
||||
|
||||
// Add adds a numeric value to the given column.
|
||||
func (u *UpdateBuilder) Add(column string, v interface{}) *UpdateBuilder {
|
||||
u.columns = append(u.columns, column)
|
||||
u.values = append(u.values, ExprFunc(func(b *Builder) {
|
||||
return u.Set(column, ExprFunc(func(b *Builder) {
|
||||
b.WriteString("COALESCE")
|
||||
b.Nested(func(b *Builder) {
|
||||
b.Ident(column).Comma().Arg(0)
|
||||
@@ -1060,7 +1067,6 @@ func (u *UpdateBuilder) Add(column string, v interface{}) *UpdateBuilder {
|
||||
b.WriteString(" + ")
|
||||
b.Arg(v)
|
||||
}))
|
||||
return u
|
||||
}
|
||||
|
||||
// SetNull sets a column as null value.
|
||||
|
||||
@@ -1802,5 +1802,19 @@ func TestInsert_OnConflict(t *testing.T) {
|
||||
Query()
|
||||
require.Equal(t, "INSERT INTO `users` (`id`, `name`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `created_at` = NULL, `name` = VALUES(`name`)", query)
|
||||
require.Equal(t, []interface{}{"1", "Mashraki"}, args)
|
||||
|
||||
query, args = Dialect(dialect.MySQL).
|
||||
Insert("users").
|
||||
Columns("name").
|
||||
Values("Mashraki").
|
||||
OnConflict(
|
||||
ResolveWithNewValues(),
|
||||
ResolveWith(func(s *UpdateSet) {
|
||||
s.Set("id", Expr("LAST_INSERT_ID(`id`)"))
|
||||
}),
|
||||
).
|
||||
Query()
|
||||
require.Equal(t, "INSERT INTO `users` (`name`) VALUES (?) ON DUPLICATE KEY UPDATE `name` = VALUES(`name`), `id` = LAST_INSERT_ID(`id`)", query)
|
||||
require.Equal(t, []interface{}{"Mashraki"}, args)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user