dialect/sql: support the RETURNING clause in UPDATE (#3016)

* dialect/sql: support the RETURNING clause in UPDATE
This commit is contained in:
Ariel Mashraki
2022-10-16 17:55:01 +03:00
committed by GitHub
parent ac725a61b7
commit bbdfea319d
2 changed files with 74 additions and 43 deletions

View File

@@ -708,7 +708,8 @@ func (i *InsertBuilder) Default() *InsertBuilder {
return i
}
// Returning adds the `RETURNING` clause to the insert statement. PostgreSQL only.
// Returning adds the `RETURNING` clause to the insert statement.
// Supported by SQLite and PostgreSQL.
func (i *InsertBuilder) Returning(columns ...string) *InsertBuilder {
i.returning = columns
return i
@@ -948,92 +949,91 @@ func (u *UpdateSet) SetExcluded(name string) *UpdateSet {
// Query returns query representation of an `INSERT INTO` statement.
func (i *InsertBuilder) Query() (string, []any) {
i.WriteString("INSERT INTO ")
i.writeSchema(i.schema)
i.Ident(i.table).Pad()
b := i.Builder.clone()
b.WriteString("INSERT INTO ")
b.writeSchema(i.schema)
b.Ident(i.table).Pad()
if i.defaults && len(i.columns) == 0 {
i.writeDefault()
i.writeDefault(&b)
} else {
i.WriteByte('(').IdentComma(i.columns...).WriteByte(')')
i.WriteString(" VALUES ")
b.WriteByte('(').IdentComma(i.columns...).WriteByte(')')
b.WriteString(" VALUES ")
for j, v := range i.values {
if j > 0 {
i.Comma()
b.Comma()
}
i.WriteByte('(').Args(v...).WriteByte(')')
b.WriteByte('(').Args(v...).WriteByte(')')
}
}
if i.conflict != nil {
i.writeConflict()
i.writeConflict(&b)
}
if len(i.returning) > 0 && !i.mysql() {
i.WriteString(" RETURNING ")
i.IdentComma(i.returning...)
}
return i.String(), i.args
joinReturning(i.returning, &b)
return b.String(), b.args
}
func (i *InsertBuilder) writeDefault() {
func (i *InsertBuilder) writeDefault(b *Builder) {
switch i.Dialect() {
case dialect.MySQL:
i.WriteString("VALUES ()")
b.WriteString("VALUES ()")
case dialect.SQLite, dialect.Postgres:
i.WriteString("DEFAULT VALUES")
b.WriteString("DEFAULT VALUES")
}
}
func (i *InsertBuilder) writeConflict() {
func (i *InsertBuilder) writeConflict(b *Builder) {
switch i.Dialect() {
case dialect.MySQL:
i.WriteString(" ON DUPLICATE KEY UPDATE ")
b.WriteString(" ON DUPLICATE KEY UPDATE ")
if i.conflict.action.nothing {
i.AddError(fmt.Errorf("invalid CONFLICT action ('DO NOTHING')"))
b.AddError(fmt.Errorf("invalid CONFLICT action ('DO NOTHING')"))
}
case dialect.SQLite, dialect.Postgres:
i.WriteString(" ON CONFLICT")
b.WriteString(" ON CONFLICT")
switch t := i.conflict.target; {
case t.constraint != "" && len(t.columns) != 0:
i.AddError(fmt.Errorf("duplicate CONFLICT clauses: %q, %q", t.constraint, t.columns))
b.AddError(fmt.Errorf("duplicate CONFLICT clauses: %q, %q", t.constraint, t.columns))
case t.constraint != "":
i.WriteString(" ON CONSTRAINT ").Ident(t.constraint)
b.WriteString(" ON CONSTRAINT ").Ident(t.constraint)
case len(t.columns) != 0:
i.WriteString(" (").IdentComma(t.columns...).WriteByte(')')
b.WriteString(" (").IdentComma(t.columns...).WriteByte(')')
}
if p := i.conflict.target.where; p != nil {
i.WriteString(" WHERE ").Join(p)
b.WriteString(" WHERE ").Join(p)
}
if i.conflict.action.nothing {
i.WriteString(" DO NOTHING")
b.WriteString(" DO NOTHING")
return
}
i.WriteString(" DO UPDATE SET ")
b.WriteString(" DO UPDATE SET ")
}
if len(i.conflict.action.update) == 0 {
i.AddError(errors.New("missing action for 'DO UPDATE SET' clause"))
b.AddError(errors.New("missing action for 'DO UPDATE SET' clause"))
}
u := &UpdateSet{columns: i.columns, update: Dialect(i.dialect).Update(i.table)}
u.update.Builder = i.Builder
u.update.Builder = *b
for _, f := range i.conflict.action.update {
f(u)
}
u.update.writeSetter(&i.Builder)
u.update.writeSetter(b)
if p := i.conflict.action.where; p != nil {
p.qualifier = i.table
i.WriteString(" WHERE ").Join(p)
b.WriteString(" WHERE ").Join(p)
}
}
// UpdateBuilder is a builder for `UPDATE` statement.
type UpdateBuilder struct {
Builder
table string
schema string
where *Predicate
nulls []string
columns []string
values []any
order []any
prefix Queries
table string
schema string
where *Predicate
nulls []string
columns []string
returning []string
values []any
order []any
prefix Queries
}
// Update creates a builder for the `UPDATE` statement.
@@ -1125,6 +1125,13 @@ func (u *UpdateBuilder) Prefix(stmts ...Querier) *UpdateBuilder {
return u
}
// Returning adds the `RETURNING` clause to the insert statement.
// Supported by SQLite and PostgreSQL.
func (u *UpdateBuilder) Returning(columns ...string) *UpdateBuilder {
u.returning = columns
return u
}
// Query returns query representation of an `UPDATE` statement.
func (u *UpdateBuilder) Query() (string, []any) {
b := u.Builder.clone()
@@ -1141,6 +1148,7 @@ func (u *UpdateBuilder) Query() (string, []any) {
b.Join(u.where)
}
joinOrder(u.order, &b)
joinReturning(u.returning, &b)
return b.String(), b.args
}
@@ -2835,6 +2843,14 @@ func joinOrder(order []any, b *Builder) {
}
}
func joinReturning(columns []string, b *Builder) {
if len(columns) == 0 || (!b.postgres() && !b.sqlite()) {
return
}
b.WriteString(" RETURNING ")
b.IdentComma(columns...)
}
func (s *Selector) joinSelect(b *Builder) {
for i := range s.selection {
if i > 0 {
@@ -3498,9 +3514,9 @@ func (b Builder) postgres() bool {
return b.Dialect() == dialect.Postgres
}
// mysql reports if the builder dialect is MySQL.
func (b Builder) mysql() bool {
return b.Dialect() == dialect.MySQL
// sqlite reports if the builder dialect is SQLite.
func (b Builder) sqlite() bool {
return b.Dialect() == dialect.SQLite
}
// fromIdent sets the builder dialect from the identifier format.