diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index c61839a1c..6d7caa6f0 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -671,12 +671,7 @@ type InsertBuilder struct { defaults bool returning []string values [][]interface{} - - // Upsert - conflictColumns []string - updateColumns []string - updateValues []interface{} - onConflictOp ConflictResolutionOp + conflict *conflict } // Insert creates a builder for the `INSERT INTO` statement. @@ -712,37 +707,6 @@ func (i *InsertBuilder) Columns(columns ...string) *InsertBuilder { return i } -// ConflictColumns sets the unique constraints that trigger the conflict resolution on insert -// to perform an upsert operation. The columns must have a unqiue constraint applied to trigger this behaviour. -func (i *InsertBuilder) ConflictColumns(values ...string) *InsertBuilder { - i.conflictColumns = append(i.conflictColumns, values...) - return i -} - -// A ConflictResolutionOp represents a possible action to take when an insert conflict occurrs. -type ConflictResolutionOp int - -// Conflict Operations -const ( - OpResolveWithNewValues ConflictResolutionOp = iota // Update conflict columns using EXCLUDED.column (postres) or c = VALUES(c) (mysql) - OpResolveWithIgnore // Sets each column to itself to force an update and return the ID, otherwise does not change any data. This may still trigger update hooks in the database. - OpResolveWithAlternateValues // Update using provided values across all rows. -) - -// OnConflict sets the conflict resolution behaviour when a unique constraint -// violation occurrs, triggering an upsert. -func (i *InsertBuilder) OnConflict(op ConflictResolutionOp) *InsertBuilder { - i.onConflictOp = op - return i -} - -// UpdateSet sets a column and a its value for use on upsert -func (i *InsertBuilder) UpdateSet(column string, v interface{}) *InsertBuilder { - i.updateColumns = append(i.updateColumns, column) - i.updateValues = append(i.updateValues, v) - return i -} - // Values append a value tuple for the insert statement. func (i *InsertBuilder) Values(values ...interface{}) *InsertBuilder { i.values = append(i.values, values) @@ -755,21 +719,235 @@ func (i *InsertBuilder) Default() *InsertBuilder { return i } -func (i *InsertBuilder) writeDefault() { - switch i.Dialect() { - case dialect.MySQL: - i.WriteString("VALUES ()") - case dialect.SQLite, dialect.Postgres: - i.WriteString("DEFAULT VALUES") - } -} - // Returning adds the `RETURNING` clause to the insert statement. PostgreSQL only. func (i *InsertBuilder) Returning(columns ...string) *InsertBuilder { i.returning = columns return i } +type ( + // conflict holds the configuration for the + // `ON CONFLICT` / `ON DUPLICATE KEY` clause. + conflict struct { + target struct { + constraint string + columns []string + where *Predicate + } + action struct { + nothing bool + where *Predicate + update func(*UpdateSet) + } + } + + // ConflictOption allows configuring the + // conflict config using functional options. + ConflictOption func(*conflict) +) + +// ConflictColumns sets the unique constraints that trigger the conflict +// resolution on insert to perform an upsert operation. The columns must +// have a unique constraint applied to trigger this behaviour. +// +// sql.Insert("users"). +// Columns("id", "name"). +// Values(1, "Mashraki"). +// OnConflict( +// sql.ConflictColumns("id"), +// sql.ResolveWithNewValues(), +// ) +// +func ConflictColumns(names ...string) ConflictOption { + return func(c *conflict) { + c.target.columns = names + } +} + +// ConflictConstraint allows setting the constraint +// name (i.e. `ON CONSTRAINT `) for PostgreSQL. +// +// sql.Insert("users"). +// Columns("id", "name"). +// Values(1, "Mashraki"). +// OnConflict( +// sql.ConflictConstraint("users_pkey"), +// sql.ResolveWithNewValues(), +// ) +// +func ConflictConstraint(name string) ConflictOption { + return func(c *conflict) { + c.target.constraint = name + } +} + +// ConflictWhere allows inference of partial unique indexes. See, PostgreSQL +// doc: https://www.postgresql.org/docs/current/sql-insert.html#SQL-ON-CONFLICT +func ConflictWhere(p *Predicate) ConflictOption { + return func(c *conflict) { + c.target.where = p + } +} + +// UpdateWhere allows setting the an update condition. Only rows +// for which this expression returns true will be updated. +func UpdateWhere(p *Predicate) ConflictOption { + return func(c *conflict) { + c.action.where = p + } +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported by SQLite and PostgreSQL. +// +// sql.Insert("users"). +// Columns("id", "name"). +// Values(1, "Mashraki"). +// OnConflict( +// sql.ConflictColumns("id"), +// sql.DoNothing() +// ) +// +func DoNothing() ConflictOption { + return func(c *conflict) { + c.action.nothing = true + } +} + +// ResolveWithIgnore sets each column to itself to force an update and return the ID, +// otherwise does not change any data. This may still trigger update hooks in the database. +// +// sql.Insert("users"). +// Columns("id"). +// Values(1). +// OnConflict( +// sql.ConflictColumns("id"), +// sql.ResolveWithIgnore() +// ) +// +// // Output: +// // MySQL: INSERT INTO `users` (`id`) VALUES(1) ON DUPLICATE KEY UPDATE `id` = `users`.`id` +// // PostgreSQL: INSERT INTO "users" ("id") VALUES(1) ON CONFLICT ("id") DO UPDATE SET "id" = "users"."id +// +func ResolveWithIgnore() ConflictOption { + return func(c *conflict) { + c.action.update = func(u *UpdateSet) { + for _, c := range u.columns { + u.Set(c, Expr(u.Table().C(c))) + } + } + } +} + +// ResolveWithNewValues updates columns using the new values proposed +// for insertion using the special EXCLUDED/VALUES table. +// +// sql.Insert("users"). +// Columns("id", "name"). +// Values(1, "Mashraki"). +// OnConflict( +// sql.ConflictColumns("id"), +// sql.ResolveWithNewValues() +// ) +// +// // Output: +// // MySQL: INSERT INTO `users` (`id`, `name`) VALUES(1, 'Mashraki) ON DUPLICATE KEY UPDATE `id` = VALUES(`id`), `name` = VALUES(`name`), +// // PostgreSQL: INSERT INTO "users" ("id") VALUES(1) ON CONFLICT ("id") DO UPDATE SET "id" = "excluded"."id, "name" = "excluded"."name" +// +func ResolveWithNewValues() ConflictOption { + return func(c *conflict) { + c.action.update = func(u *UpdateSet) { + for _, c := range u.columns { + u.SetExcluded(c) + } + } + } +} + +// ResolveWith allows setting a custom function to set the `UPDATE` clause. +// +// Insert("users"). +// Columns("id", "name"). +// Values(1, "Mashraki"). +// OnConflict( +// ConflictColumns("name"), +// ResolveWith(func(s *UpdateSet) { +// s.SetNull("created_at") +// s.Set("name", Expr(s.Excluded().C("name"))) +// }), +// ) +// +func ResolveWith(fn func(*UpdateSet)) ConflictOption { + return func(c *conflict) { + c.action.update = fn + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// sql.Insert("users"). +// Columns("id", "name"). +// Values(1, "Mashraki"). +// OnConflict( +// sql.ConflictColumns("id"), +// sql.ResolveWithNewValues() +// ) +// +func (i *InsertBuilder) OnConflict(opts ...ConflictOption) *InsertBuilder { + if i.conflict == nil { + i.conflict = &conflict{} + } + for _, opt := range opts { + opt(i.conflict) + } + return i +} + +// UpdateSet describes a set of changes of the `DO UPDATE` clause. +type UpdateSet struct { + table string + columns []string + update *UpdateBuilder +} + +// Table returns the table the `UPSERT` statement is executed on. +func (u *UpdateSet) Table() *SelectTable { + return Dialect(u.update.dialect).Table(u.table) +} + +// Columns returns all columns in the `INSERT` statement. +func (u *UpdateSet) Columns() []string { + return u.columns +} + +// Set sets a column to a given value. +func (u *UpdateSet) Set(column string, v interface{}) *UpdateSet { + u.update.Set(column, v) + return u +} + +// SetNull sets a column as null value. +func (u *UpdateSet) SetNull(column string) *UpdateSet { + u.update.SetNull(column) + return u +} + +// SetExcluded sets the column name to its EXCLUDED/VALUES value. +// For example, "c" = "excluded"."c", or `c` = VALUES(`c`). +func (u *UpdateSet) SetExcluded(name string) *UpdateSet { + switch u.update.Dialect() { + case dialect.MySQL: + u.update.Set(name, ExprFunc(func(b *Builder) { + b.WriteString("VALUES(").Ident(name).WriteByte(')') + })) + default: + t := Dialect(u.update.dialect).Table("excluded") + u.update.Set(name, Expr(t.C(name))) + } + return u +} + // Query returns query representation of an `INSERT INTO` statement. func (i *InsertBuilder) Query() (string, []interface{}) { i.WriteString("INSERT INTO ") @@ -787,8 +965,8 @@ func (i *InsertBuilder) Query() (string, []interface{}) { i.WriteByte('(').Args(v...).WriteByte(')') } } - if len(i.conflictColumns) > 0 { - i.buildConflictHandling() + if i.conflict != nil { + i.writeConflict() } if len(i.returning) > 0 && i.postgres() { i.WriteString(" RETURNING ") @@ -797,69 +975,47 @@ func (i *InsertBuilder) Query() (string, []interface{}) { return i.String(), i.args } -func (i *InsertBuilder) buildConflictHandling() { +func (i *InsertBuilder) writeDefault() { switch i.Dialect() { - case dialect.Postgres, dialect.SQLite: - i.Pad(). - WriteString("ON CONFLICT"). - Pad(). - Nested(func(b *Builder) { - b.IdentComma(i.conflictColumns...) - }). - Pad(). - WriteString("DO UPDATE SET ") - - switch i.onConflictOp { - case OpResolveWithNewValues: - for j, c := range i.columns { - if j > 0 { - i.Comma() - } - i.Ident(c).WriteOp(OpEQ).Ident("excluded").WriteByte('.').Ident(c) - } - case OpResolveWithIgnore: - writeIgnoreValues(i) - case OpResolveWithAlternateValues: - writeUpdateValues(i, i.updateColumns, i.updateValues) - } - case dialect.MySQL: - i.Pad().WriteString("ON DUPLICATE KEY UPDATE ") - - switch i.onConflictOp { - case OpResolveWithIgnore: - writeIgnoreValues(i) - case OpResolveWithNewValues: - for j, c := range i.columns { - if j > 0 { - i.Comma() - } - // update column with the value we tried to insert - i.Ident(c).WriteOp(OpEQ).WriteString("VALUES").WriteByte('(').Ident(c).WriteByte(')') - } - case OpResolveWithAlternateValues: - writeUpdateValues(i, i.updateColumns, i.updateValues) - } + i.WriteString("VALUES ()") + case dialect.SQLite, dialect.Postgres: + i.WriteString("DEFAULT VALUES") } } -func writeUpdateValues(builder *InsertBuilder, columns []string, values []interface{}) { - for i, c := range columns { - if i > 0 { - builder.Comma() +func (i *InsertBuilder) writeConflict() { + switch i.Dialect() { + case dialect.MySQL: + i.WriteString(" ON DUPLICATE KEY UPDATE ") + if i.conflict.action.nothing { + i.AddError(fmt.Errorf("invalid CONFLICT action ('DO NOTHING')")) } - builder.Ident(c).WriteString(" = ").Arg(builder.updateValues[i]) + case dialect.SQLite, dialect.Postgres: + i.WriteString(" ON CONFLICT") + switch t := i.conflict.target; { + case t.constraint != "" && len(t.columns) != 0: + i.AddError(fmt.Errorf("duplicate CONFLICT clauses: %q, %q", t.constraint, t.columns)) + case t.constraint != "": + i.WriteString(" ON CONSTRAINT ").Ident(t.constraint) + case len(t.columns) != 0: + i.WriteString(" (").IdentComma(t.columns...).WriteByte(')') + } + if p := i.conflict.target.where; p != nil { + i.WriteString(" WHERE ").Join(p) + } + if i.conflict.action.nothing { + i.WriteString(" DO NOTHING") + return + } + i.WriteString(" DO UPDATE SET ") } -} - -// writeIgnoreValues ignores conflicts by setting each column to itself e.g. "c" = "c", -// performimg an update without changing any values so that it returns the record ID. -func writeIgnoreValues(builder *InsertBuilder) { - for j, c := range builder.columns { - if j > 0 { - builder.Comma() - } - builder.Ident(c).WriteOp(OpEQ).Ident(c) + u := &UpdateSet{table: i.table, columns: i.columns, update: &UpdateBuilder{}} + u.update.Builder = i.Builder + i.conflict.action.update(u) + u.update.writeSetter() + if p := i.conflict.action.where; p != nil { + i.WriteString(" WHERE ").Join(p) } } @@ -886,7 +1042,7 @@ func (u *UpdateBuilder) Schema(name string) *UpdateBuilder { return u } -// Set sets a column and a its value. +// Set sets a column to a given value. func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder { u.columns = append(u.columns, column) u.values = append(u.values, v) @@ -896,7 +1052,7 @@ func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder { // Add adds a numeric value to the given column. func (u *UpdateBuilder) Add(column string, v interface{}) *UpdateBuilder { u.columns = append(u.columns, column) - u.values = append(u.values, P().Append(func(b *Builder) { + u.values = append(u.values, ExprFunc(func(b *Builder) { b.WriteString("COALESCE") b.Nested(func(b *Builder) { b.Ident(column).Comma().Arg(0) @@ -942,6 +1098,16 @@ func (u *UpdateBuilder) Query() (string, []interface{}) { u.WriteString("UPDATE ") u.writeSchema(u.schema) u.Ident(u.table).WriteString(" SET ") + u.writeSetter() + if u.where != nil { + u.WriteString(" WHERE ") + u.Join(u.where) + } + return u.String(), u.args +} + +// writeSetter writes the "SET" clause for the UPDATE statement. +func (u *UpdateBuilder) writeSetter() { for i, c := range u.nulls { if i > 0 { u.Comma() @@ -963,11 +1129,6 @@ func (u *UpdateBuilder) Query() (string, []interface{}) { u.Arg(v) } } - if u.where != nil { - u.WriteString(" WHERE ") - u.Join(u.where) - } - return u.String(), u.args } // DeleteBuilder is a builder for `DELETE` statement. diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 520237730..529de6f2b 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1492,41 +1492,6 @@ WHERE (((("users"."id1" = "users"."id2" AND "users"."id1" <> "users"."id2") AND "users"."id1" > "users"."id2") AND "users"."id1" >= "users"."id2") AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", ""), }, - { - input: Dialect(dialect.Postgres).Insert("users").Columns("id", "email").Values("1", "user@example.com").ConflictColumns("id").UpdateSet("email", "user-1@example.com"), - wantQuery: `INSERT INTO "users" ("id", "email") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "id" = "excluded"."id", "email" = "excluded"."email"`, - wantArgs: []interface{}{"1", "user@example.com"}, - }, - { - input: Dialect(dialect.Postgres).Insert("users").Columns("id", "email").Values("1", "user@example.com").OnConflict(OpResolveWithIgnore).ConflictColumns("id"), - wantQuery: `INSERT INTO "users" ("id", "email") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "id" = "id", "email" = "email"`, - wantArgs: []interface{}{"1", "user@example.com"}, - }, - { - input: Dialect(dialect.MySQL).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithAlternateValues).UpdateSet("email", "user-1@example.com").ConflictColumns("email"), - wantQuery: "INSERT INTO `users` (`email`) VALUES (?) ON DUPLICATE KEY UPDATE `email` = ?", - wantArgs: []interface{}{"user@example.com", "user-1@example.com"}, - }, - { - input: Dialect(dialect.Postgres).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithAlternateValues).UpdateSet("email", "user-1@example.com").ConflictColumns("email"), - wantQuery: `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") DO UPDATE SET "email" = $2`, - wantArgs: []interface{}{"user@example.com", "user-1@example.com"}, - }, - { - input: Dialect(dialect.Postgres).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithIgnore).ConflictColumns("email"), - wantQuery: `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") DO UPDATE SET "email" = "email"`, - wantArgs: []interface{}{"user@example.com"}, - }, - { - input: Dialect(dialect.MySQL).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithIgnore).ConflictColumns("email"), - wantQuery: "INSERT INTO `users` (`email`) VALUES (?) ON DUPLICATE KEY UPDATE `email` = `email`", - wantArgs: []interface{}{"user@example.com"}, - }, - { - input: Dialect(dialect.MySQL).Insert("users").Set("email", "user@example.com").OnConflict(OpResolveWithNewValues).ConflictColumns("email"), - wantQuery: "INSERT INTO `users` (`email`) VALUES (?) ON DUPLICATE KEY UPDATE `email` = VALUES(`email`)", - wantArgs: []interface{}{"user@example.com"}, - }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { @@ -1733,3 +1698,109 @@ func TestUpdateBuilder_SetExpr(t *testing.T) { require.Equal(t, `UPDATE "users" SET "name" = $1, "active" = NOT(active), "age" = "excluded"."age", "x" = "excluded"."x" || ' (formerly ' || "x" || ')', "y" = $2 + "excluded"."y" + $3`, query) require.Equal(t, []interface{}{"Ariel", "~", "~"}, args) } + +func TestInsert_OnConflict(t *testing.T) { + t.Run("Postgres", func(t *testing.T) { // And SQLite. + query, args := Dialect(dialect.Postgres). + Insert("users"). + Columns("id", "email"). + Values("1", "user@example.com"). + OnConflict( + ConflictColumns("email"), + ConflictWhere(EQ("name", "Ariel")), + ResolveWithNewValues(), + UpdateWhere(NEQ("updated_at", 0)), + ). + Query() + require.Equal(t, `INSERT INTO "users" ("id", "email") VALUES ($1, $2) ON CONFLICT ("email") WHERE "name" = $3 DO UPDATE SET "id" = "excluded"."id", "email" = "excluded"."email" WHERE "updated_at" <> $4`, query) + require.Equal(t, []interface{}{"1", "user@example.com", "Ariel", 0}, args) + + query, args = Dialect(dialect.Postgres). + Insert("users"). + Columns("id", "name"). + Values("1", "Mashraki"). + OnConflict( + ConflictConstraint("users_pkey"), + DoNothing(), + ). + Query() + require.Equal(t, `INSERT INTO "users" ("id", "name") VALUES ($1, $2) ON CONFLICT ON CONSTRAINT "users_pkey" DO NOTHING`, query) + require.Equal(t, []interface{}{"1", "Mashraki"}, args) + + query, args = Dialect(dialect.Postgres). + Insert("users"). + Columns("id"). + Values(1). + OnConflict( + DoNothing(), + ). + Query() + require.Equal(t, `INSERT INTO "users" ("id") VALUES ($1) ON CONFLICT DO NOTHING`, query) + require.Equal(t, []interface{}{1}, args) + + query, args = Dialect(dialect.Postgres). + Insert("users"). + Columns("id"). + Values(1). + OnConflict( + ConflictColumns("id"), + ResolveWithIgnore(), + ). + Query() + require.Equal(t, `INSERT INTO "users" ("id") VALUES ($1) ON CONFLICT ("id") DO UPDATE SET "id" = "users"."id"`, query) + require.Equal(t, []interface{}{1}, args) + + query, args = Dialect(dialect.Postgres). + Insert("users"). + Columns("id", "name"). + Values(1, "Mashraki"). + OnConflict( + ConflictColumns("name"), + ResolveWith(func(s *UpdateSet) { + s.SetExcluded("name") + s.SetNull("created_at") + }), + ). + Query() + require.Equal(t, `INSERT INTO "users" ("id", "name") VALUES ($1, $2) ON CONFLICT ("name") DO UPDATE SET "created_at" = NULL, "name" = "excluded"."name"`, query) + require.Equal(t, []interface{}{1, "Mashraki"}, args) + }) + + t.Run("MySQL", func(t *testing.T) { + query, args := Dialect(dialect.MySQL). + Insert("users"). + Columns("id", "email"). + Values("1", "user@example.com"). + OnConflict( + ResolveWithNewValues(), + ). + Query() + require.Equal(t, "INSERT INTO `users` (`id`, `email`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `id` = VALUES(`id`), `email` = VALUES(`email`)", query) + require.Equal(t, []interface{}{"1", "user@example.com"}, args) + + query, args = Dialect(dialect.MySQL). + Insert("users"). + Columns("id", "email"). + Values("1", "user@example.com"). + OnConflict( + ResolveWithIgnore(), + ). + Query() + require.Equal(t, "INSERT INTO `users` (`id`, `email`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `id` = `users`.`id`, `email` = `users`.`email`", query) + require.Equal(t, []interface{}{"1", "user@example.com"}, args) + + query, args = Dialect(dialect.MySQL). + Insert("users"). + Columns("id", "name"). + Values("1", "Mashraki"). + OnConflict( + ResolveWith(func(s *UpdateSet) { + s.SetExcluded("name") + s.SetNull("created_at") + }), + ). + Query() + require.Equal(t, "INSERT INTO `users` (`id`, `name`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `created_at` = NULL, `name` = VALUES(`name`)", query) + require.Equal(t, []interface{}{"1", "Mashraki"}, args) + }) +}