dialect/sql: add on-conflict handling to sql builder (#1370)

* Adds conflict handling to sql builder

* Revert some builder tests

* Fix builder test

* Revert

* Revert another line

* Update dialect/sql/builder.go

* Move conflict ops

* Refactor conflict handling builder

Co-authored-by: Ivan Vanderbyl <ivanvanderbyl@gmail.com>
Co-authored-by: Ivan Vanderbyl <ivanvanderbyl@users.noreply.github.com>
This commit is contained in:
Ivan Vanderbyl
2021-03-22 12:26:38 +02:00
committed by Ariel Mashraki
parent 9520911f3d
commit afa3beca6b
2 changed files with 152 additions and 9 deletions

View File

@@ -622,6 +622,12 @@ type InsertBuilder struct {
defaults bool
returning []string
values [][]interface{}
// Upsert
conflictColumns []string
updateColumns []string
updateValues []interface{}
onConflictOp ConflictResolutionOp
}
// Insert creates a builder for the `INSERT INTO` statement.
@@ -657,7 +663,38 @@ func (i *InsertBuilder) Columns(columns ...string) *InsertBuilder {
return i
}
// Values appends a value tuple to the INSERT statement.
// ConflictColumns sets the unique constraints that trigger the conflict resolution on insert
// to perform an upsert operation. The columns must have a unqiue constraint applied to trigger this behaviour.
func (i *InsertBuilder) ConflictColumns(values ...string) *InsertBuilder {
i.conflictColumns = append(i.conflictColumns, values...)
return i
}
// A ConflictResolutionOp represents a possible action to take when an insert conflict occurrs.
type ConflictResolutionOp int
// Conflict Operations
const (
OpResolveWithNewValues ConflictResolutionOp = iota // Update conflict columns using EXCLUDED.column (postres) or c = VALUES(c) (mysql)
OpResolveWithIgnore // Sets each column to itself to force an update and return the ID, otherwise does not change any data. This may still trigger update hooks in the database.
OpResolveWithAlternateValues // Update using provided values across all rows.
)
// OnConflict sets the conflict resolution behaviour when a unique constraint
// violation occurrs, triggering an upsert.
func (i *InsertBuilder) OnConflict(op ConflictResolutionOp) *InsertBuilder {
i.onConflictOp = op
return i
}
// UpdateSet sets a column and a its value for use on upsert
func (i *InsertBuilder) UpdateSet(column string, v interface{}) *InsertBuilder {
i.updateColumns = append(i.updateColumns, column)
i.updateValues = append(i.updateValues, v)
return i
}
// Values append a value tuple for the insert statement.
func (i *InsertBuilder) Values(values ...interface{}) *InsertBuilder {
i.values = append(i.values, values)
return i
@@ -701,6 +738,11 @@ func (i *InsertBuilder) Query() (string, []interface{}) {
i.WriteByte('(').Args(v...).WriteByte(')')
}
}
if len(i.conflictColumns) > 0 {
i.buildConflictHandling()
}
if len(i.returning) > 0 && i.postgres() {
i.WriteString(" RETURNING ")
i.IdentComma(i.returning...)
@@ -708,6 +750,72 @@ func (i *InsertBuilder) Query() (string, []interface{}) {
return i.String(), i.args
}
func (i *InsertBuilder) buildConflictHandling() {
switch i.Dialect() {
case dialect.Postgres, dialect.SQLite:
i.Pad().
WriteString("ON CONFLICT").
Pad().
Nested(func(b *Builder) {
b.IdentComma(i.conflictColumns...)
}).
Pad().
WriteString("DO UPDATE SET ")
switch i.onConflictOp {
case OpResolveWithNewValues:
for j, c := range i.columns {
if j > 0 {
i.Comma()
}
i.Ident(c).WriteOp(OpEQ).Ident("excluded").WriteByte('.').Ident(c)
}
case OpResolveWithIgnore:
writeIgnoreValues(i)
case OpResolveWithAlternateValues:
writeUpdateValues(i, i.updateColumns, i.updateValues)
}
case dialect.MySQL:
i.Pad().WriteString("ON DUPLICATE KEY UPDATE ")
switch i.onConflictOp {
case OpResolveWithIgnore:
writeIgnoreValues(i)
case OpResolveWithNewValues:
for j, c := range i.columns {
if j > 0 {
i.Comma()
}
// update column with the value we tried to insert
i.Ident(c).WriteOp(OpEQ).WriteString("VALUES").WriteByte('(').Ident(c).WriteByte(')')
}
case OpResolveWithAlternateValues:
writeUpdateValues(i, i.updateColumns, i.updateValues)
}
}
}
func writeUpdateValues(builder *InsertBuilder, columns []string, values []interface{}) {
for i, c := range columns {
if i > 0 {
builder.Comma()
}
builder.Ident(c).WriteString(" = ").Arg(builder.updateValues[i])
}
}
// writeIgnoreValues ignores conflicts by setting each column to itself e.g. "c" = "c",
// performimg an update without changing any values so that it returns the record ID.
func writeIgnoreValues(builder *InsertBuilder) {
for j, c := range builder.columns {
if j > 0 {
builder.Comma()
}
builder.Ident(c).WriteOp(OpEQ).Ident(c)
}
}
// UpdateBuilder is a builder for `UPDATE` statement.
type UpdateBuilder struct {
Builder

View File

@@ -1314,14 +1314,14 @@ func TestBuilder(t *testing.T) {
Or().
Where(And(NEQ("f", "f"), NEQ("g", "g"))),
wantQuery: strings.NewReplacer("\n", "", "\t", "").Replace(`
SELECT * FROM "users"
WHERE
(
(("id" = $1 AND "group_id" IN ($2, $3)) OR ("id" = $4 AND "group_id" IN ($5, $6)))
AND
(("a" = $7 OR ("b" = $8 AND "c" = $9)) AND (NOT ("d" IS NULL OR "e" IS NOT NULL)))
)
OR ("f" <> $10 AND "g" <> $11)`),
SELECT * FROM "users"
WHERE
(
(("id" = $1 AND "group_id" IN ($2, $3)) OR ("id" = $4 AND "group_id" IN ($5, $6)))
AND
(("a" = $7 OR ("b" = $8 AND "c" = $9)) AND (NOT ("d" IS NULL OR "e" IS NOT NULL)))
)
OR ("f" <> $10 AND "g" <> $11)`),
wantArgs: []interface{}{1, 2, 3, 2, 4, 5, "a", "b", "c", "f", "g"},
},
{
@@ -1414,6 +1414,41 @@ WHERE (((("users"."id1" = "users"."id2" AND "users"."id1" <> "users"."id2")
AND "users"."id1" > "users"."id2") AND "users"."id1" >= "users"."id2")
AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", ""),
},
{
input: Dialect(dialect.Postgres).Insert("users").Columns("id", "email").Values("1", "user@example.com").ConflictColumns("id").UpdateSet("email", "user-1@example.com"),
wantQuery: `INSERT INTO "users" ("id", "email") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "id" = "excluded"."id", "email" = "excluded"."email"`,
wantArgs: []interface{}{"1", "user@example.com"},
},
{
input: Dialect(dialect.Postgres).Insert("users").Columns("id", "email").Values("1", "user@example.com").OnConflict(OpResolveWithIgnore).ConflictColumns("id"),
wantQuery: `INSERT INTO "users" ("id", "email") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "id" = "id", "email" = "email"`,
wantArgs: []interface{}{"1", "user@example.com"},
},
{
input: Dialect(dialect.MySQL).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithAlternateValues).UpdateSet("email", "user-1@example.com").ConflictColumns("email"),
wantQuery: "INSERT INTO `users` (`email`) VALUES (?) ON DUPLICATE KEY UPDATE `email` = ?",
wantArgs: []interface{}{"user@example.com", "user-1@example.com"},
},
{
input: Dialect(dialect.Postgres).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithAlternateValues).UpdateSet("email", "user-1@example.com").ConflictColumns("email"),
wantQuery: `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") DO UPDATE SET "email" = $2`,
wantArgs: []interface{}{"user@example.com", "user-1@example.com"},
},
{
input: Dialect(dialect.Postgres).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithIgnore).ConflictColumns("email"),
wantQuery: `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") DO UPDATE SET "email" = "email"`,
wantArgs: []interface{}{"user@example.com"},
},
{
input: Dialect(dialect.MySQL).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithIgnore).ConflictColumns("email"),
wantQuery: "INSERT INTO `users` (`email`) VALUES (?) ON DUPLICATE KEY UPDATE `email` = `email`",
wantArgs: []interface{}{"user@example.com"},
},
{
input: Dialect(dialect.MySQL).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithNewValues).ConflictColumns("email"),
wantQuery: "INSERT INTO `users` (`email`) VALUES (?) ON DUPLICATE KEY UPDATE `email` = VALUES(`email`)",
wantArgs: []interface{}{"user@example.com"},
},
}
for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {