diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 0354e1012..e2902e96c 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -622,6 +622,12 @@ type InsertBuilder struct { defaults bool returning []string values [][]interface{} + + // Upsert + conflictColumns []string + updateColumns []string + updateValues []interface{} + onConflictOp ConflictResolutionOp } // Insert creates a builder for the `INSERT INTO` statement. @@ -657,7 +663,38 @@ func (i *InsertBuilder) Columns(columns ...string) *InsertBuilder { return i } -// Values appends a value tuple to the INSERT statement. +// ConflictColumns sets the unique constraints that trigger the conflict resolution on insert +// to perform an upsert operation. The columns must have a unqiue constraint applied to trigger this behaviour. +func (i *InsertBuilder) ConflictColumns(values ...string) *InsertBuilder { + i.conflictColumns = append(i.conflictColumns, values...) + return i +} + +// A ConflictResolutionOp represents a possible action to take when an insert conflict occurrs. +type ConflictResolutionOp int + +// Conflict Operations +const ( + OpResolveWithNewValues ConflictResolutionOp = iota // Update conflict columns using EXCLUDED.column (postres) or c = VALUES(c) (mysql) + OpResolveWithIgnore // Sets each column to itself to force an update and return the ID, otherwise does not change any data. This may still trigger update hooks in the database. + OpResolveWithAlternateValues // Update using provided values across all rows. +) + +// OnConflict sets the conflict resolution behaviour when a unique constraint +// violation occurrs, triggering an upsert. +func (i *InsertBuilder) OnConflict(op ConflictResolutionOp) *InsertBuilder { + i.onConflictOp = op + return i +} + +// UpdateSet sets a column and a its value for use on upsert +func (i *InsertBuilder) UpdateSet(column string, v interface{}) *InsertBuilder { + i.updateColumns = append(i.updateColumns, column) + i.updateValues = append(i.updateValues, v) + return i +} + +// Values append a value tuple for the insert statement. func (i *InsertBuilder) Values(values ...interface{}) *InsertBuilder { i.values = append(i.values, values) return i @@ -701,6 +738,11 @@ func (i *InsertBuilder) Query() (string, []interface{}) { i.WriteByte('(').Args(v...).WriteByte(')') } } + + if len(i.conflictColumns) > 0 { + i.buildConflictHandling() + } + if len(i.returning) > 0 && i.postgres() { i.WriteString(" RETURNING ") i.IdentComma(i.returning...) @@ -708,6 +750,72 @@ func (i *InsertBuilder) Query() (string, []interface{}) { return i.String(), i.args } +func (i *InsertBuilder) buildConflictHandling() { + switch i.Dialect() { + case dialect.Postgres, dialect.SQLite: + i.Pad(). + WriteString("ON CONFLICT"). + Pad(). + Nested(func(b *Builder) { + b.IdentComma(i.conflictColumns...) + }). + Pad(). + WriteString("DO UPDATE SET ") + + switch i.onConflictOp { + case OpResolveWithNewValues: + for j, c := range i.columns { + if j > 0 { + i.Comma() + } + i.Ident(c).WriteOp(OpEQ).Ident("excluded").WriteByte('.').Ident(c) + } + case OpResolveWithIgnore: + writeIgnoreValues(i) + case OpResolveWithAlternateValues: + writeUpdateValues(i, i.updateColumns, i.updateValues) + } + + case dialect.MySQL: + i.Pad().WriteString("ON DUPLICATE KEY UPDATE ") + + switch i.onConflictOp { + case OpResolveWithIgnore: + writeIgnoreValues(i) + case OpResolveWithNewValues: + for j, c := range i.columns { + if j > 0 { + i.Comma() + } + // update column with the value we tried to insert + i.Ident(c).WriteOp(OpEQ).WriteString("VALUES").WriteByte('(').Ident(c).WriteByte(')') + } + case OpResolveWithAlternateValues: + writeUpdateValues(i, i.updateColumns, i.updateValues) + } + } +} + +func writeUpdateValues(builder *InsertBuilder, columns []string, values []interface{}) { + for i, c := range columns { + if i > 0 { + builder.Comma() + } + builder.Ident(c).WriteString(" = ").Arg(builder.updateValues[i]) + } +} + +// writeIgnoreValues ignores conflicts by setting each column to itself e.g. "c" = "c", +// performimg an update without changing any values so that it returns the record ID. +func writeIgnoreValues(builder *InsertBuilder) { + for j, c := range builder.columns { + if j > 0 { + builder.Comma() + } + builder.Ident(c).WriteOp(OpEQ).Ident(c) + } +} + // UpdateBuilder is a builder for `UPDATE` statement. type UpdateBuilder struct { Builder diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 597e12ae5..4aa297746 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1314,14 +1314,14 @@ func TestBuilder(t *testing.T) { Or(). Where(And(NEQ("f", "f"), NEQ("g", "g"))), wantQuery: strings.NewReplacer("\n", "", "\t", "").Replace(` -SELECT * FROM "users" -WHERE - ( - (("id" = $1 AND "group_id" IN ($2, $3)) OR ("id" = $4 AND "group_id" IN ($5, $6))) - AND - (("a" = $7 OR ("b" = $8 AND "c" = $9)) AND (NOT ("d" IS NULL OR "e" IS NOT NULL))) - ) - OR ("f" <> $10 AND "g" <> $11)`), + SELECT * FROM "users" + WHERE + ( + (("id" = $1 AND "group_id" IN ($2, $3)) OR ("id" = $4 AND "group_id" IN ($5, $6))) + AND + (("a" = $7 OR ("b" = $8 AND "c" = $9)) AND (NOT ("d" IS NULL OR "e" IS NOT NULL))) + ) + OR ("f" <> $10 AND "g" <> $11)`), wantArgs: []interface{}{1, 2, 3, 2, 4, 5, "a", "b", "c", "f", "g"}, }, { @@ -1414,6 +1414,41 @@ WHERE (((("users"."id1" = "users"."id2" AND "users"."id1" <> "users"."id2") AND "users"."id1" > "users"."id2") AND "users"."id1" >= "users"."id2") AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", ""), }, + { + input: Dialect(dialect.Postgres).Insert("users").Columns("id", "email").Values("1", "user@example.com").ConflictColumns("id").UpdateSet("email", "user-1@example.com"), + wantQuery: `INSERT INTO "users" ("id", "email") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "id" = "excluded"."id", "email" = "excluded"."email"`, + wantArgs: []interface{}{"1", "user@example.com"}, + }, + { + input: Dialect(dialect.Postgres).Insert("users").Columns("id", "email").Values("1", "user@example.com").OnConflict(OpResolveWithIgnore).ConflictColumns("id"), + wantQuery: `INSERT INTO "users" ("id", "email") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "id" = "id", "email" = "email"`, + wantArgs: []interface{}{"1", "user@example.com"}, + }, + { + input: Dialect(dialect.MySQL).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithAlternateValues).UpdateSet("email", "user-1@example.com").ConflictColumns("email"), + wantQuery: "INSERT INTO `users` (`email`) VALUES (?) ON DUPLICATE KEY UPDATE `email` = ?", + wantArgs: []interface{}{"user@example.com", "user-1@example.com"}, + }, + { + input: Dialect(dialect.Postgres).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithAlternateValues).UpdateSet("email", "user-1@example.com").ConflictColumns("email"), + wantQuery: `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") DO UPDATE SET "email" = $2`, + wantArgs: []interface{}{"user@example.com", "user-1@example.com"}, + }, + { + input: Dialect(dialect.Postgres).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithIgnore).ConflictColumns("email"), + wantQuery: `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") DO UPDATE SET "email" = "email"`, + wantArgs: []interface{}{"user@example.com"}, + }, + { + input: Dialect(dialect.MySQL).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithIgnore).ConflictColumns("email"), + wantQuery: "INSERT INTO `users` (`email`) VALUES (?) ON DUPLICATE KEY UPDATE `email` = `email`", + wantArgs: []interface{}{"user@example.com"}, + }, + { + input: Dialect(dialect.MySQL).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithNewValues).ConflictColumns("email"), + wantQuery: "INSERT INTO `users` (`email`) VALUES (?) ON DUPLICATE KEY UPDATE `email` = VALUES(`email`)", + wantArgs: []interface{}{"user@example.com"}, + }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) {