entc/gen: support de/incrementing values on upsert

Fixed https://github.com/ent/ent/issues/1952.
This commit is contained in:
Ariel Mashraki
2021-09-17 00:29:37 +03:00
committed by Ariel Mashraki
parent 8cb468f979
commit 5f31091dcd
15 changed files with 941 additions and 67 deletions

View File

@@ -908,14 +908,13 @@ func (i *InsertBuilder) OnConflict(opts ...ConflictOption) *InsertBuilder {
// UpdateSet describes a set of changes of the `DO UPDATE` clause.
type UpdateSet struct {
table string
columns []string
update *UpdateBuilder
}
// Table returns the table the `UPSERT` statement is executed on.
func (u *UpdateSet) Table() *SelectTable {
return Dialect(u.update.dialect).Table(u.table)
return Dialect(u.update.dialect).Table(u.update.table)
}
// Columns returns all columns in the `INSERT` statement.
@@ -1031,7 +1030,7 @@ func (i *InsertBuilder) writeConflict() {
if len(i.conflict.action.update) == 0 {
i.AddError(errors.New("missing action for 'DO UPDATE SET' clause"))
}
u := &UpdateSet{table: i.table, columns: i.columns, update: &UpdateBuilder{}}
u := &UpdateSet{columns: i.columns, update: Dialect(i.dialect).Update(i.table)}
u.update.Builder = i.Builder
for _, f := range i.conflict.action.update {
f(u)
@@ -1084,7 +1083,7 @@ func (u *UpdateBuilder) Add(column string, v interface{}) *UpdateBuilder {
return u.Set(column, ExprFunc(func(b *Builder) {
b.WriteString("COALESCE")
b.Nested(func(b *Builder) {
b.Ident(column).Comma().WriteByte('0')
b.Ident(Table(u.table).C(column)).Comma().WriteByte('0')
})
b.WriteString(" + ")
b.Arg(v)

View File

@@ -445,7 +445,7 @@ func TestBuilder(t *testing.T) {
input: Update("users").
Add("age", 1).
Where(HasPrefix("nickname", "a8m")),
wantQuery: "UPDATE `users` SET `age` = COALESCE(`age`, 0) + ? WHERE `nickname` LIKE ?",
wantQuery: "UPDATE `users` SET `age` = COALESCE(`users`.`age`, 0) + ? WHERE `nickname` LIKE ?",
wantArgs: []interface{}{1, "a8m%"},
},
{
@@ -453,7 +453,7 @@ func TestBuilder(t *testing.T) {
Update("users").
Add("age", 1).
Where(HasPrefix("nickname", "a8m")),
wantQuery: `UPDATE "users" SET "age" = COALESCE("age", 0) + $1 WHERE "nickname" LIKE $2`,
wantQuery: `UPDATE "users" SET "age" = COALESCE("users"."age", 0) + $1 WHERE "nickname" LIKE $2`,
wantArgs: []interface{}{1, "a8m%"},
},
{
@@ -462,7 +462,7 @@ func TestBuilder(t *testing.T) {
Set("nickname", "a8m").
Add("version", 10).
Set("name", "mashraki"),
wantQuery: "UPDATE `users` SET `age` = COALESCE(`age`, 0) + ?, `nickname` = ?, `version` = COALESCE(`version`, 0) + ?, `name` = ?",
wantQuery: "UPDATE `users` SET `age` = COALESCE(`users`.`age`, 0) + ?, `nickname` = ?, `version` = COALESCE(`users`.`version`, 0) + ?, `name` = ?",
wantArgs: []interface{}{1, "a8m", 10, "mashraki"},
},
{
@@ -472,7 +472,7 @@ func TestBuilder(t *testing.T) {
Set("nickname", "a8m").
Add("version", 10).
Set("name", "mashraki"),
wantQuery: `UPDATE "users" SET "age" = COALESCE("age", 0) + $1, "nickname" = $2, "version" = COALESCE("version", 0) + $3, "name" = $4`,
wantQuery: `UPDATE "users" SET "age" = COALESCE("users"."age", 0) + $1, "nickname" = $2, "version" = COALESCE("users"."version", 0) + $3, "name" = $4`,
wantArgs: []interface{}{1, "a8m", 10, "mashraki"},
},
{
@@ -485,7 +485,7 @@ func TestBuilder(t *testing.T) {
Set("first", "ariel").
Add("score", 1e5).
Where(Or(EQ("age", 1), EQ("age", 2))),
wantQuery: `UPDATE "users" SET "age" = COALESCE("age", 0) + $1, "nickname" = $2, "version" = COALESCE("version", 0) + $3, "name" = $4, "first" = $5, "score" = COALESCE("score", 0) + $6 WHERE "age" = $7 OR "age" = $8`,
wantQuery: `UPDATE "users" SET "age" = COALESCE("users"."age", 0) + $1, "nickname" = $2, "version" = COALESCE("users"."version", 0) + $3, "name" = $4, "first" = $5, "score" = COALESCE("users"."score", 0) + $6 WHERE "age" = $7 OR "age" = $8`,
wantArgs: []interface{}{1, "a8m", 10, "mashraki", "ariel", 1e5, 1, 2},
},
{
@@ -1725,7 +1725,7 @@ func TestInsert_OnConflict(t *testing.T) {
UpdateWhere(NEQ("updated_at", 0)),
).
Query()
require.Equal(t, `INSERT INTO "users" ("id", "email") VALUES ($1, $2) ON CONFLICT ("email") WHERE "name" = $3 DO UPDATE SET "id" = "users"."id", "email" = "excluded"."email", "version" = COALESCE("version", 0) + $4 WHERE "updated_at" <> $5`, query)
require.Equal(t, `INSERT INTO "users" ("id", "email") VALUES ($1, $2) ON CONFLICT ("email") WHERE "name" = $3 DO UPDATE SET "id" = "users"."id", "email" = "excluded"."email", "version" = COALESCE("users"."version", 0) + $4 WHERE "updated_at" <> $5`, query)
require.Equal(t, []interface{}{"1", "user@example.com", "Ariel", 1, 0}, args)
query, args = Dialect(dialect.Postgres).
@@ -1814,7 +1814,7 @@ func TestInsert_OnConflict(t *testing.T) {
}),
).
Query()
require.Equal(t, "INSERT INTO `users` (`id`, `name`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `created_at` = NULL, `name` = VALUES(`name`), `version` = COALESCE(`version`, 0) + ?", query)
require.Equal(t, "INSERT INTO `users` (`id`, `name`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `created_at` = NULL, `name` = VALUES(`name`), `version` = COALESCE(`users`.`version`, 0) + ?", query)
require.Equal(t, []interface{}{"1", "Mashraki", 1}, args)
query, args = Dialect(dialect.MySQL).

View File

@@ -1393,7 +1393,7 @@ func TestUpdateNode(t *testing.T) {
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`age`, 0) + ? WHERE `id` = ? AND `deleted` = ?")).
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND `deleted` = ?")).
WithArgs(1, 1, false).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ? AND `deleted` = ?")).
@@ -1722,7 +1722,7 @@ func TestUpdateNodes(t *testing.T) {
mock.ExpectQuery(escape("SELECT `id` FROM `users`")).
WillReturnRows(sqlmock.NewRows([]string{"id"}).
AddRow(10))
mock.ExpectExec(escape("UPDATE `users` SET `version` = COALESCE(`version`, 0) + ? WHERE `id` = ?")).
mock.ExpectExec(escape("UPDATE `users` SET `version` = COALESCE(`users`.`version`, 0) + ? WHERE `id` = ?")).
WithArgs(1, 10).
WillReturnResult(sqlmock.NewResult(0, 1))
// Clear "owner_id" column in the "cards" table.