From b19ac669c7a9ada241519cbbbc5609410cf00361 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Sun, 1 Aug 2021 09:35:57 +0300 Subject: [PATCH] dialect/sql: override column values on Updater.Set Avoid cases like 'SET a = 1, a = 2'. --- dialect/sql/builder.go | 26 ++++++++++++++++---------- dialect/sql/builder_test.go | 14 ++++++++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 6d7caa6f0..25eae9f99 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -737,7 +737,7 @@ type ( action struct { nothing bool where *Predicate - update func(*UpdateSet) + update []func(*UpdateSet) } } @@ -831,11 +831,11 @@ func DoNothing() ConflictOption { // func ResolveWithIgnore() ConflictOption { return func(c *conflict) { - c.action.update = func(u *UpdateSet) { + c.action.update = append(c.action.update, func(u *UpdateSet) { for _, c := range u.columns { u.Set(c, Expr(u.Table().C(c))) } - } + }) } } @@ -856,11 +856,11 @@ func ResolveWithIgnore() ConflictOption { // func ResolveWithNewValues() ConflictOption { return func(c *conflict) { - c.action.update = func(u *UpdateSet) { + c.action.update = append(c.action.update, func(u *UpdateSet) { for _, c := range u.columns { u.SetExcluded(c) } - } + }) } } @@ -879,7 +879,7 @@ func ResolveWithNewValues() ConflictOption { // func ResolveWith(fn func(*UpdateSet)) ConflictOption { return func(c *conflict) { - c.action.update = fn + c.action.update = append(c.action.update, fn) } } @@ -1012,7 +1012,9 @@ func (i *InsertBuilder) writeConflict() { } u := &UpdateSet{table: i.table, columns: i.columns, update: &UpdateBuilder{}} u.update.Builder = i.Builder - i.conflict.action.update(u) + for _, f := range i.conflict.action.update { + f(u) + } u.update.writeSetter() if p := i.conflict.action.where; p != nil { i.WriteString(" WHERE ").Join(p) @@ -1044,6 +1046,12 @@ func (u *UpdateBuilder) Schema(name string) *UpdateBuilder { // Set sets a column to a given value. func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder { + for i := range u.columns { + if column == u.columns[i] { + u.values[i] = v + return u + } + } u.columns = append(u.columns, column) u.values = append(u.values, v) return u @@ -1051,8 +1059,7 @@ func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder { // Add adds a numeric value to the given column. func (u *UpdateBuilder) Add(column string, v interface{}) *UpdateBuilder { - u.columns = append(u.columns, column) - u.values = append(u.values, ExprFunc(func(b *Builder) { + return u.Set(column, ExprFunc(func(b *Builder) { b.WriteString("COALESCE") b.Nested(func(b *Builder) { b.Ident(column).Comma().Arg(0) @@ -1060,7 +1067,6 @@ func (u *UpdateBuilder) Add(column string, v interface{}) *UpdateBuilder { b.WriteString(" + ") b.Arg(v) })) - return u } // SetNull sets a column as null value. diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 529de6f2b..59fd6e9ff 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1802,5 +1802,19 @@ func TestInsert_OnConflict(t *testing.T) { Query() require.Equal(t, "INSERT INTO `users` (`id`, `name`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `created_at` = NULL, `name` = VALUES(`name`)", query) require.Equal(t, []interface{}{"1", "Mashraki"}, args) + + query, args = Dialect(dialect.MySQL). + Insert("users"). + Columns("name"). + Values("Mashraki"). + OnConflict( + ResolveWithNewValues(), + ResolveWith(func(s *UpdateSet) { + s.Set("id", Expr("LAST_INSERT_ID(`id`)")) + }), + ). + Query() + require.Equal(t, "INSERT INTO `users` (`name`) VALUES (?) ON DUPLICATE KEY UPDATE `name` = VALUES(`name`), `id` = LAST_INSERT_ID(`id`)", query) + require.Equal(t, []interface{}{"Mashraki"}, args) }) }