dialect/sql: move onconflict clause to functional options

This commit is contained in:
Ariel Mashraki
2021-07-26 20:35:08 +03:00
committed by Ariel Mashraki
parent 09fd9957c0
commit b8f8ea0f06
2 changed files with 378 additions and 146 deletions

View File

@@ -671,12 +671,7 @@ type InsertBuilder struct {
defaults bool
returning []string
values [][]interface{}
// Upsert
conflictColumns []string
updateColumns []string
updateValues []interface{}
onConflictOp ConflictResolutionOp
conflict *conflict
}
// Insert creates a builder for the `INSERT INTO` statement.
@@ -712,37 +707,6 @@ func (i *InsertBuilder) Columns(columns ...string) *InsertBuilder {
return i
}
// ConflictColumns sets the unique constraints that trigger the conflict resolution on insert
// to perform an upsert operation. The columns must have a unqiue constraint applied to trigger this behaviour.
func (i *InsertBuilder) ConflictColumns(values ...string) *InsertBuilder {
i.conflictColumns = append(i.conflictColumns, values...)
return i
}
// A ConflictResolutionOp represents a possible action to take when an insert conflict occurrs.
type ConflictResolutionOp int
// Conflict Operations
const (
OpResolveWithNewValues ConflictResolutionOp = iota // Update conflict columns using EXCLUDED.column (postres) or c = VALUES(c) (mysql)
OpResolveWithIgnore // Sets each column to itself to force an update and return the ID, otherwise does not change any data. This may still trigger update hooks in the database.
OpResolveWithAlternateValues // Update using provided values across all rows.
)
// OnConflict sets the conflict resolution behaviour when a unique constraint
// violation occurrs, triggering an upsert.
func (i *InsertBuilder) OnConflict(op ConflictResolutionOp) *InsertBuilder {
i.onConflictOp = op
return i
}
// UpdateSet sets a column and a its value for use on upsert
func (i *InsertBuilder) UpdateSet(column string, v interface{}) *InsertBuilder {
i.updateColumns = append(i.updateColumns, column)
i.updateValues = append(i.updateValues, v)
return i
}
// Values append a value tuple for the insert statement.
func (i *InsertBuilder) Values(values ...interface{}) *InsertBuilder {
i.values = append(i.values, values)
@@ -755,21 +719,235 @@ func (i *InsertBuilder) Default() *InsertBuilder {
return i
}
func (i *InsertBuilder) writeDefault() {
switch i.Dialect() {
case dialect.MySQL:
i.WriteString("VALUES ()")
case dialect.SQLite, dialect.Postgres:
i.WriteString("DEFAULT VALUES")
}
}
// Returning adds the `RETURNING` clause to the insert statement. PostgreSQL only.
func (i *InsertBuilder) Returning(columns ...string) *InsertBuilder {
i.returning = columns
return i
}
type (
// conflict holds the configuration for the
// `ON CONFLICT` / `ON DUPLICATE KEY` clause.
conflict struct {
target struct {
constraint string
columns []string
where *Predicate
}
action struct {
nothing bool
where *Predicate
update func(*UpdateSet)
}
}
// ConflictOption allows configuring the
// conflict config using functional options.
ConflictOption func(*conflict)
)
// ConflictColumns sets the unique constraints that trigger the conflict
// resolution on insert to perform an upsert operation. The columns must
// have a unique constraint applied to trigger this behaviour.
//
// sql.Insert("users").
// Columns("id", "name").
// Values(1, "Mashraki").
// OnConflict(
// sql.ConflictColumns("id"),
// sql.ResolveWithNewValues(),
// )
//
func ConflictColumns(names ...string) ConflictOption {
return func(c *conflict) {
c.target.columns = names
}
}
// ConflictConstraint allows setting the constraint
// name (i.e. `ON CONSTRAINT <name>`) for PostgreSQL.
//
// sql.Insert("users").
// Columns("id", "name").
// Values(1, "Mashraki").
// OnConflict(
// sql.ConflictConstraint("users_pkey"),
// sql.ResolveWithNewValues(),
// )
//
func ConflictConstraint(name string) ConflictOption {
return func(c *conflict) {
c.target.constraint = name
}
}
// ConflictWhere allows inference of partial unique indexes. See, PostgreSQL
// doc: https://www.postgresql.org/docs/current/sql-insert.html#SQL-ON-CONFLICT
func ConflictWhere(p *Predicate) ConflictOption {
return func(c *conflict) {
c.target.where = p
}
}
// UpdateWhere allows setting the an update condition. Only rows
// for which this expression returns true will be updated.
func UpdateWhere(p *Predicate) ConflictOption {
return func(c *conflict) {
c.action.where = p
}
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported by SQLite and PostgreSQL.
//
// sql.Insert("users").
// Columns("id", "name").
// Values(1, "Mashraki").
// OnConflict(
// sql.ConflictColumns("id"),
// sql.DoNothing()
// )
//
func DoNothing() ConflictOption {
return func(c *conflict) {
c.action.nothing = true
}
}
// ResolveWithIgnore sets each column to itself to force an update and return the ID,
// otherwise does not change any data. This may still trigger update hooks in the database.
//
// sql.Insert("users").
// Columns("id").
// Values(1).
// OnConflict(
// sql.ConflictColumns("id"),
// sql.ResolveWithIgnore()
// )
//
// // Output:
// // MySQL: INSERT INTO `users` (`id`) VALUES(1) ON DUPLICATE KEY UPDATE `id` = `users`.`id`
// // PostgreSQL: INSERT INTO "users" ("id") VALUES(1) ON CONFLICT ("id") DO UPDATE SET "id" = "users"."id
//
func ResolveWithIgnore() ConflictOption {
return func(c *conflict) {
c.action.update = func(u *UpdateSet) {
for _, c := range u.columns {
u.Set(c, Expr(u.Table().C(c)))
}
}
}
}
// ResolveWithNewValues updates columns using the new values proposed
// for insertion using the special EXCLUDED/VALUES table.
//
// sql.Insert("users").
// Columns("id", "name").
// Values(1, "Mashraki").
// OnConflict(
// sql.ConflictColumns("id"),
// sql.ResolveWithNewValues()
// )
//
// // Output:
// // MySQL: INSERT INTO `users` (`id`, `name`) VALUES(1, 'Mashraki) ON DUPLICATE KEY UPDATE `id` = VALUES(`id`), `name` = VALUES(`name`),
// // PostgreSQL: INSERT INTO "users" ("id") VALUES(1) ON CONFLICT ("id") DO UPDATE SET "id" = "excluded"."id, "name" = "excluded"."name"
//
func ResolveWithNewValues() ConflictOption {
return func(c *conflict) {
c.action.update = func(u *UpdateSet) {
for _, c := range u.columns {
u.SetExcluded(c)
}
}
}
}
// ResolveWith allows setting a custom function to set the `UPDATE` clause.
//
// Insert("users").
// Columns("id", "name").
// Values(1, "Mashraki").
// OnConflict(
// ConflictColumns("name"),
// ResolveWith(func(s *UpdateSet) {
// s.SetNull("created_at")
// s.Set("name", Expr(s.Excluded().C("name")))
// }),
// )
//
func ResolveWith(fn func(*UpdateSet)) ConflictOption {
return func(c *conflict) {
c.action.update = fn
}
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// sql.Insert("users").
// Columns("id", "name").
// Values(1, "Mashraki").
// OnConflict(
// sql.ConflictColumns("id"),
// sql.ResolveWithNewValues()
// )
//
func (i *InsertBuilder) OnConflict(opts ...ConflictOption) *InsertBuilder {
if i.conflict == nil {
i.conflict = &conflict{}
}
for _, opt := range opts {
opt(i.conflict)
}
return i
}
// UpdateSet describes a set of changes of the `DO UPDATE` clause.
type UpdateSet struct {
table string
columns []string
update *UpdateBuilder
}
// Table returns the table the `UPSERT` statement is executed on.
func (u *UpdateSet) Table() *SelectTable {
return Dialect(u.update.dialect).Table(u.table)
}
// Columns returns all columns in the `INSERT` statement.
func (u *UpdateSet) Columns() []string {
return u.columns
}
// Set sets a column to a given value.
func (u *UpdateSet) Set(column string, v interface{}) *UpdateSet {
u.update.Set(column, v)
return u
}
// SetNull sets a column as null value.
func (u *UpdateSet) SetNull(column string) *UpdateSet {
u.update.SetNull(column)
return u
}
// SetExcluded sets the column name to its EXCLUDED/VALUES value.
// For example, "c" = "excluded"."c", or `c` = VALUES(`c`).
func (u *UpdateSet) SetExcluded(name string) *UpdateSet {
switch u.update.Dialect() {
case dialect.MySQL:
u.update.Set(name, ExprFunc(func(b *Builder) {
b.WriteString("VALUES(").Ident(name).WriteByte(')')
}))
default:
t := Dialect(u.update.dialect).Table("excluded")
u.update.Set(name, Expr(t.C(name)))
}
return u
}
// Query returns query representation of an `INSERT INTO` statement.
func (i *InsertBuilder) Query() (string, []interface{}) {
i.WriteString("INSERT INTO ")
@@ -787,8 +965,8 @@ func (i *InsertBuilder) Query() (string, []interface{}) {
i.WriteByte('(').Args(v...).WriteByte(')')
}
}
if len(i.conflictColumns) > 0 {
i.buildConflictHandling()
if i.conflict != nil {
i.writeConflict()
}
if len(i.returning) > 0 && i.postgres() {
i.WriteString(" RETURNING ")
@@ -797,69 +975,47 @@ func (i *InsertBuilder) Query() (string, []interface{}) {
return i.String(), i.args
}
func (i *InsertBuilder) buildConflictHandling() {
func (i *InsertBuilder) writeDefault() {
switch i.Dialect() {
case dialect.Postgres, dialect.SQLite:
i.Pad().
WriteString("ON CONFLICT").
Pad().
Nested(func(b *Builder) {
b.IdentComma(i.conflictColumns...)
}).
Pad().
WriteString("DO UPDATE SET ")
switch i.onConflictOp {
case OpResolveWithNewValues:
for j, c := range i.columns {
if j > 0 {
i.Comma()
}
i.Ident(c).WriteOp(OpEQ).Ident("excluded").WriteByte('.').Ident(c)
}
case OpResolveWithIgnore:
writeIgnoreValues(i)
case OpResolveWithAlternateValues:
writeUpdateValues(i, i.updateColumns, i.updateValues)
}
case dialect.MySQL:
i.Pad().WriteString("ON DUPLICATE KEY UPDATE ")
switch i.onConflictOp {
case OpResolveWithIgnore:
writeIgnoreValues(i)
case OpResolveWithNewValues:
for j, c := range i.columns {
if j > 0 {
i.Comma()
}
// update column with the value we tried to insert
i.Ident(c).WriteOp(OpEQ).WriteString("VALUES").WriteByte('(').Ident(c).WriteByte(')')
}
case OpResolveWithAlternateValues:
writeUpdateValues(i, i.updateColumns, i.updateValues)
}
i.WriteString("VALUES ()")
case dialect.SQLite, dialect.Postgres:
i.WriteString("DEFAULT VALUES")
}
}
func writeUpdateValues(builder *InsertBuilder, columns []string, values []interface{}) {
for i, c := range columns {
if i > 0 {
builder.Comma()
func (i *InsertBuilder) writeConflict() {
switch i.Dialect() {
case dialect.MySQL:
i.WriteString(" ON DUPLICATE KEY UPDATE ")
if i.conflict.action.nothing {
i.AddError(fmt.Errorf("invalid CONFLICT action ('DO NOTHING')"))
}
builder.Ident(c).WriteString(" = ").Arg(builder.updateValues[i])
case dialect.SQLite, dialect.Postgres:
i.WriteString(" ON CONFLICT")
switch t := i.conflict.target; {
case t.constraint != "" && len(t.columns) != 0:
i.AddError(fmt.Errorf("duplicate CONFLICT clauses: %q, %q", t.constraint, t.columns))
case t.constraint != "":
i.WriteString(" ON CONSTRAINT ").Ident(t.constraint)
case len(t.columns) != 0:
i.WriteString(" (").IdentComma(t.columns...).WriteByte(')')
}
if p := i.conflict.target.where; p != nil {
i.WriteString(" WHERE ").Join(p)
}
if i.conflict.action.nothing {
i.WriteString(" DO NOTHING")
return
}
i.WriteString(" DO UPDATE SET ")
}
}
// writeIgnoreValues ignores conflicts by setting each column to itself e.g. "c" = "c",
// performimg an update without changing any values so that it returns the record ID.
func writeIgnoreValues(builder *InsertBuilder) {
for j, c := range builder.columns {
if j > 0 {
builder.Comma()
}
builder.Ident(c).WriteOp(OpEQ).Ident(c)
u := &UpdateSet{table: i.table, columns: i.columns, update: &UpdateBuilder{}}
u.update.Builder = i.Builder
i.conflict.action.update(u)
u.update.writeSetter()
if p := i.conflict.action.where; p != nil {
i.WriteString(" WHERE ").Join(p)
}
}
@@ -886,7 +1042,7 @@ func (u *UpdateBuilder) Schema(name string) *UpdateBuilder {
return u
}
// Set sets a column and a its value.
// Set sets a column to a given value.
func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder {
u.columns = append(u.columns, column)
u.values = append(u.values, v)
@@ -896,7 +1052,7 @@ func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder {
// Add adds a numeric value to the given column.
func (u *UpdateBuilder) Add(column string, v interface{}) *UpdateBuilder {
u.columns = append(u.columns, column)
u.values = append(u.values, P().Append(func(b *Builder) {
u.values = append(u.values, ExprFunc(func(b *Builder) {
b.WriteString("COALESCE")
b.Nested(func(b *Builder) {
b.Ident(column).Comma().Arg(0)
@@ -942,6 +1098,16 @@ func (u *UpdateBuilder) Query() (string, []interface{}) {
u.WriteString("UPDATE ")
u.writeSchema(u.schema)
u.Ident(u.table).WriteString(" SET ")
u.writeSetter()
if u.where != nil {
u.WriteString(" WHERE ")
u.Join(u.where)
}
return u.String(), u.args
}
// writeSetter writes the "SET" clause for the UPDATE statement.
func (u *UpdateBuilder) writeSetter() {
for i, c := range u.nulls {
if i > 0 {
u.Comma()
@@ -963,11 +1129,6 @@ func (u *UpdateBuilder) Query() (string, []interface{}) {
u.Arg(v)
}
}
if u.where != nil {
u.WriteString(" WHERE ")
u.Join(u.where)
}
return u.String(), u.args
}
// DeleteBuilder is a builder for `DELETE` statement.