mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql: add on-conflict handling to sql builder (#1370)
* Adds conflict handling to sql builder * Revert some builder tests * Fix builder test * Revert * Revert another line * Update dialect/sql/builder.go * Move conflict ops * Refactor conflict handling builder Co-authored-by: Ivan Vanderbyl <ivanvanderbyl@gmail.com> Co-authored-by: Ivan Vanderbyl <ivanvanderbyl@users.noreply.github.com>
This commit is contained in:
committed by
Ariel Mashraki
parent
9520911f3d
commit
afa3beca6b
@@ -622,6 +622,12 @@ type InsertBuilder struct {
|
||||
defaults bool
|
||||
returning []string
|
||||
values [][]interface{}
|
||||
|
||||
// Upsert
|
||||
conflictColumns []string
|
||||
updateColumns []string
|
||||
updateValues []interface{}
|
||||
onConflictOp ConflictResolutionOp
|
||||
}
|
||||
|
||||
// Insert creates a builder for the `INSERT INTO` statement.
|
||||
@@ -657,7 +663,38 @@ func (i *InsertBuilder) Columns(columns ...string) *InsertBuilder {
|
||||
return i
|
||||
}
|
||||
|
||||
// Values appends a value tuple to the INSERT statement.
|
||||
// ConflictColumns sets the unique constraints that trigger the conflict resolution on insert
|
||||
// to perform an upsert operation. The columns must have a unqiue constraint applied to trigger this behaviour.
|
||||
func (i *InsertBuilder) ConflictColumns(values ...string) *InsertBuilder {
|
||||
i.conflictColumns = append(i.conflictColumns, values...)
|
||||
return i
|
||||
}
|
||||
|
||||
// A ConflictResolutionOp represents a possible action to take when an insert conflict occurrs.
|
||||
type ConflictResolutionOp int
|
||||
|
||||
// Conflict Operations
|
||||
const (
|
||||
OpResolveWithNewValues ConflictResolutionOp = iota // Update conflict columns using EXCLUDED.column (postres) or c = VALUES(c) (mysql)
|
||||
OpResolveWithIgnore // Sets each column to itself to force an update and return the ID, otherwise does not change any data. This may still trigger update hooks in the database.
|
||||
OpResolveWithAlternateValues // Update using provided values across all rows.
|
||||
)
|
||||
|
||||
// OnConflict sets the conflict resolution behaviour when a unique constraint
|
||||
// violation occurrs, triggering an upsert.
|
||||
func (i *InsertBuilder) OnConflict(op ConflictResolutionOp) *InsertBuilder {
|
||||
i.onConflictOp = op
|
||||
return i
|
||||
}
|
||||
|
||||
// UpdateSet sets a column and a its value for use on upsert
|
||||
func (i *InsertBuilder) UpdateSet(column string, v interface{}) *InsertBuilder {
|
||||
i.updateColumns = append(i.updateColumns, column)
|
||||
i.updateValues = append(i.updateValues, v)
|
||||
return i
|
||||
}
|
||||
|
||||
// Values append a value tuple for the insert statement.
|
||||
func (i *InsertBuilder) Values(values ...interface{}) *InsertBuilder {
|
||||
i.values = append(i.values, values)
|
||||
return i
|
||||
@@ -701,6 +738,11 @@ func (i *InsertBuilder) Query() (string, []interface{}) {
|
||||
i.WriteByte('(').Args(v...).WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
if len(i.conflictColumns) > 0 {
|
||||
i.buildConflictHandling()
|
||||
}
|
||||
|
||||
if len(i.returning) > 0 && i.postgres() {
|
||||
i.WriteString(" RETURNING ")
|
||||
i.IdentComma(i.returning...)
|
||||
@@ -708,6 +750,72 @@ func (i *InsertBuilder) Query() (string, []interface{}) {
|
||||
return i.String(), i.args
|
||||
}
|
||||
|
||||
func (i *InsertBuilder) buildConflictHandling() {
|
||||
switch i.Dialect() {
|
||||
case dialect.Postgres, dialect.SQLite:
|
||||
i.Pad().
|
||||
WriteString("ON CONFLICT").
|
||||
Pad().
|
||||
Nested(func(b *Builder) {
|
||||
b.IdentComma(i.conflictColumns...)
|
||||
}).
|
||||
Pad().
|
||||
WriteString("DO UPDATE SET ")
|
||||
|
||||
switch i.onConflictOp {
|
||||
case OpResolveWithNewValues:
|
||||
for j, c := range i.columns {
|
||||
if j > 0 {
|
||||
i.Comma()
|
||||
}
|
||||
i.Ident(c).WriteOp(OpEQ).Ident("excluded").WriteByte('.').Ident(c)
|
||||
}
|
||||
case OpResolveWithIgnore:
|
||||
writeIgnoreValues(i)
|
||||
case OpResolveWithAlternateValues:
|
||||
writeUpdateValues(i, i.updateColumns, i.updateValues)
|
||||
}
|
||||
|
||||
case dialect.MySQL:
|
||||
i.Pad().WriteString("ON DUPLICATE KEY UPDATE ")
|
||||
|
||||
switch i.onConflictOp {
|
||||
case OpResolveWithIgnore:
|
||||
writeIgnoreValues(i)
|
||||
case OpResolveWithNewValues:
|
||||
for j, c := range i.columns {
|
||||
if j > 0 {
|
||||
i.Comma()
|
||||
}
|
||||
// update column with the value we tried to insert
|
||||
i.Ident(c).WriteOp(OpEQ).WriteString("VALUES").WriteByte('(').Ident(c).WriteByte(')')
|
||||
}
|
||||
case OpResolveWithAlternateValues:
|
||||
writeUpdateValues(i, i.updateColumns, i.updateValues)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeUpdateValues(builder *InsertBuilder, columns []string, values []interface{}) {
|
||||
for i, c := range columns {
|
||||
if i > 0 {
|
||||
builder.Comma()
|
||||
}
|
||||
builder.Ident(c).WriteString(" = ").Arg(builder.updateValues[i])
|
||||
}
|
||||
}
|
||||
|
||||
// writeIgnoreValues ignores conflicts by setting each column to itself e.g. "c" = "c",
|
||||
// performimg an update without changing any values so that it returns the record ID.
|
||||
func writeIgnoreValues(builder *InsertBuilder) {
|
||||
for j, c := range builder.columns {
|
||||
if j > 0 {
|
||||
builder.Comma()
|
||||
}
|
||||
builder.Ident(c).WriteOp(OpEQ).Ident(c)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateBuilder is a builder for `UPDATE` statement.
|
||||
type UpdateBuilder struct {
|
||||
Builder
|
||||
|
||||
@@ -1314,14 +1314,14 @@ func TestBuilder(t *testing.T) {
|
||||
Or().
|
||||
Where(And(NEQ("f", "f"), NEQ("g", "g"))),
|
||||
wantQuery: strings.NewReplacer("\n", "", "\t", "").Replace(`
|
||||
SELECT * FROM "users"
|
||||
WHERE
|
||||
(
|
||||
(("id" = $1 AND "group_id" IN ($2, $3)) OR ("id" = $4 AND "group_id" IN ($5, $6)))
|
||||
AND
|
||||
(("a" = $7 OR ("b" = $8 AND "c" = $9)) AND (NOT ("d" IS NULL OR "e" IS NOT NULL)))
|
||||
)
|
||||
OR ("f" <> $10 AND "g" <> $11)`),
|
||||
SELECT * FROM "users"
|
||||
WHERE
|
||||
(
|
||||
(("id" = $1 AND "group_id" IN ($2, $3)) OR ("id" = $4 AND "group_id" IN ($5, $6)))
|
||||
AND
|
||||
(("a" = $7 OR ("b" = $8 AND "c" = $9)) AND (NOT ("d" IS NULL OR "e" IS NOT NULL)))
|
||||
)
|
||||
OR ("f" <> $10 AND "g" <> $11)`),
|
||||
wantArgs: []interface{}{1, 2, 3, 2, 4, 5, "a", "b", "c", "f", "g"},
|
||||
},
|
||||
{
|
||||
@@ -1414,6 +1414,41 @@ WHERE (((("users"."id1" = "users"."id2" AND "users"."id1" <> "users"."id2")
|
||||
AND "users"."id1" > "users"."id2") AND "users"."id1" >= "users"."id2")
|
||||
AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", ""),
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).Insert("users").Columns("id", "email").Values("1", "user@example.com").ConflictColumns("id").UpdateSet("email", "user-1@example.com"),
|
||||
wantQuery: `INSERT INTO "users" ("id", "email") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "id" = "excluded"."id", "email" = "excluded"."email"`,
|
||||
wantArgs: []interface{}{"1", "user@example.com"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).Insert("users").Columns("id", "email").Values("1", "user@example.com").OnConflict(OpResolveWithIgnore).ConflictColumns("id"),
|
||||
wantQuery: `INSERT INTO "users" ("id", "email") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "id" = "id", "email" = "email"`,
|
||||
wantArgs: []interface{}{"1", "user@example.com"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.MySQL).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithAlternateValues).UpdateSet("email", "user-1@example.com").ConflictColumns("email"),
|
||||
wantQuery: "INSERT INTO `users` (`email`) VALUES (?) ON DUPLICATE KEY UPDATE `email` = ?",
|
||||
wantArgs: []interface{}{"user@example.com", "user-1@example.com"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithAlternateValues).UpdateSet("email", "user-1@example.com").ConflictColumns("email"),
|
||||
wantQuery: `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") DO UPDATE SET "email" = $2`,
|
||||
wantArgs: []interface{}{"user@example.com", "user-1@example.com"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithIgnore).ConflictColumns("email"),
|
||||
wantQuery: `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") DO UPDATE SET "email" = "email"`,
|
||||
wantArgs: []interface{}{"user@example.com"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.MySQL).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithIgnore).ConflictColumns("email"),
|
||||
wantQuery: "INSERT INTO `users` (`email`) VALUES (?) ON DUPLICATE KEY UPDATE `email` = `email`",
|
||||
wantArgs: []interface{}{"user@example.com"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.MySQL).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithNewValues).ConflictColumns("email"),
|
||||
wantQuery: "INSERT INTO `users` (`email`) VALUES (?) ON DUPLICATE KEY UPDATE `email` = VALUES(`email`)",
|
||||
wantArgs: []interface{}{"user@example.com"},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user