mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql: support the RETURNING clause in UPDATE (#3016)
* dialect/sql: support the RETURNING clause in UPDATE
This commit is contained in:
@@ -708,7 +708,8 @@ func (i *InsertBuilder) Default() *InsertBuilder {
|
||||
return i
|
||||
}
|
||||
|
||||
// Returning adds the `RETURNING` clause to the insert statement. PostgreSQL only.
|
||||
// Returning adds the `RETURNING` clause to the insert statement.
|
||||
// Supported by SQLite and PostgreSQL.
|
||||
func (i *InsertBuilder) Returning(columns ...string) *InsertBuilder {
|
||||
i.returning = columns
|
||||
return i
|
||||
@@ -948,92 +949,91 @@ func (u *UpdateSet) SetExcluded(name string) *UpdateSet {
|
||||
|
||||
// Query returns query representation of an `INSERT INTO` statement.
|
||||
func (i *InsertBuilder) Query() (string, []any) {
|
||||
i.WriteString("INSERT INTO ")
|
||||
i.writeSchema(i.schema)
|
||||
i.Ident(i.table).Pad()
|
||||
b := i.Builder.clone()
|
||||
b.WriteString("INSERT INTO ")
|
||||
b.writeSchema(i.schema)
|
||||
b.Ident(i.table).Pad()
|
||||
if i.defaults && len(i.columns) == 0 {
|
||||
i.writeDefault()
|
||||
i.writeDefault(&b)
|
||||
} else {
|
||||
i.WriteByte('(').IdentComma(i.columns...).WriteByte(')')
|
||||
i.WriteString(" VALUES ")
|
||||
b.WriteByte('(').IdentComma(i.columns...).WriteByte(')')
|
||||
b.WriteString(" VALUES ")
|
||||
for j, v := range i.values {
|
||||
if j > 0 {
|
||||
i.Comma()
|
||||
b.Comma()
|
||||
}
|
||||
i.WriteByte('(').Args(v...).WriteByte(')')
|
||||
b.WriteByte('(').Args(v...).WriteByte(')')
|
||||
}
|
||||
}
|
||||
if i.conflict != nil {
|
||||
i.writeConflict()
|
||||
i.writeConflict(&b)
|
||||
}
|
||||
if len(i.returning) > 0 && !i.mysql() {
|
||||
i.WriteString(" RETURNING ")
|
||||
i.IdentComma(i.returning...)
|
||||
}
|
||||
return i.String(), i.args
|
||||
joinReturning(i.returning, &b)
|
||||
return b.String(), b.args
|
||||
}
|
||||
|
||||
func (i *InsertBuilder) writeDefault() {
|
||||
func (i *InsertBuilder) writeDefault(b *Builder) {
|
||||
switch i.Dialect() {
|
||||
case dialect.MySQL:
|
||||
i.WriteString("VALUES ()")
|
||||
b.WriteString("VALUES ()")
|
||||
case dialect.SQLite, dialect.Postgres:
|
||||
i.WriteString("DEFAULT VALUES")
|
||||
b.WriteString("DEFAULT VALUES")
|
||||
}
|
||||
}
|
||||
|
||||
func (i *InsertBuilder) writeConflict() {
|
||||
func (i *InsertBuilder) writeConflict(b *Builder) {
|
||||
switch i.Dialect() {
|
||||
case dialect.MySQL:
|
||||
i.WriteString(" ON DUPLICATE KEY UPDATE ")
|
||||
b.WriteString(" ON DUPLICATE KEY UPDATE ")
|
||||
if i.conflict.action.nothing {
|
||||
i.AddError(fmt.Errorf("invalid CONFLICT action ('DO NOTHING')"))
|
||||
b.AddError(fmt.Errorf("invalid CONFLICT action ('DO NOTHING')"))
|
||||
}
|
||||
case dialect.SQLite, dialect.Postgres:
|
||||
i.WriteString(" ON CONFLICT")
|
||||
b.WriteString(" ON CONFLICT")
|
||||
switch t := i.conflict.target; {
|
||||
case t.constraint != "" && len(t.columns) != 0:
|
||||
i.AddError(fmt.Errorf("duplicate CONFLICT clauses: %q, %q", t.constraint, t.columns))
|
||||
b.AddError(fmt.Errorf("duplicate CONFLICT clauses: %q, %q", t.constraint, t.columns))
|
||||
case t.constraint != "":
|
||||
i.WriteString(" ON CONSTRAINT ").Ident(t.constraint)
|
||||
b.WriteString(" ON CONSTRAINT ").Ident(t.constraint)
|
||||
case len(t.columns) != 0:
|
||||
i.WriteString(" (").IdentComma(t.columns...).WriteByte(')')
|
||||
b.WriteString(" (").IdentComma(t.columns...).WriteByte(')')
|
||||
}
|
||||
if p := i.conflict.target.where; p != nil {
|
||||
i.WriteString(" WHERE ").Join(p)
|
||||
b.WriteString(" WHERE ").Join(p)
|
||||
}
|
||||
if i.conflict.action.nothing {
|
||||
i.WriteString(" DO NOTHING")
|
||||
b.WriteString(" DO NOTHING")
|
||||
return
|
||||
}
|
||||
i.WriteString(" DO UPDATE SET ")
|
||||
b.WriteString(" DO UPDATE SET ")
|
||||
}
|
||||
if len(i.conflict.action.update) == 0 {
|
||||
i.AddError(errors.New("missing action for 'DO UPDATE SET' clause"))
|
||||
b.AddError(errors.New("missing action for 'DO UPDATE SET' clause"))
|
||||
}
|
||||
u := &UpdateSet{columns: i.columns, update: Dialect(i.dialect).Update(i.table)}
|
||||
u.update.Builder = i.Builder
|
||||
u.update.Builder = *b
|
||||
for _, f := range i.conflict.action.update {
|
||||
f(u)
|
||||
}
|
||||
u.update.writeSetter(&i.Builder)
|
||||
u.update.writeSetter(b)
|
||||
if p := i.conflict.action.where; p != nil {
|
||||
p.qualifier = i.table
|
||||
i.WriteString(" WHERE ").Join(p)
|
||||
b.WriteString(" WHERE ").Join(p)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateBuilder is a builder for `UPDATE` statement.
|
||||
type UpdateBuilder struct {
|
||||
Builder
|
||||
table string
|
||||
schema string
|
||||
where *Predicate
|
||||
nulls []string
|
||||
columns []string
|
||||
values []any
|
||||
order []any
|
||||
prefix Queries
|
||||
table string
|
||||
schema string
|
||||
where *Predicate
|
||||
nulls []string
|
||||
columns []string
|
||||
returning []string
|
||||
values []any
|
||||
order []any
|
||||
prefix Queries
|
||||
}
|
||||
|
||||
// Update creates a builder for the `UPDATE` statement.
|
||||
@@ -1125,6 +1125,13 @@ func (u *UpdateBuilder) Prefix(stmts ...Querier) *UpdateBuilder {
|
||||
return u
|
||||
}
|
||||
|
||||
// Returning adds the `RETURNING` clause to the insert statement.
|
||||
// Supported by SQLite and PostgreSQL.
|
||||
func (u *UpdateBuilder) Returning(columns ...string) *UpdateBuilder {
|
||||
u.returning = columns
|
||||
return u
|
||||
}
|
||||
|
||||
// Query returns query representation of an `UPDATE` statement.
|
||||
func (u *UpdateBuilder) Query() (string, []any) {
|
||||
b := u.Builder.clone()
|
||||
@@ -1141,6 +1148,7 @@ func (u *UpdateBuilder) Query() (string, []any) {
|
||||
b.Join(u.where)
|
||||
}
|
||||
joinOrder(u.order, &b)
|
||||
joinReturning(u.returning, &b)
|
||||
return b.String(), b.args
|
||||
}
|
||||
|
||||
@@ -2835,6 +2843,14 @@ func joinOrder(order []any, b *Builder) {
|
||||
}
|
||||
}
|
||||
|
||||
func joinReturning(columns []string, b *Builder) {
|
||||
if len(columns) == 0 || (!b.postgres() && !b.sqlite()) {
|
||||
return
|
||||
}
|
||||
b.WriteString(" RETURNING ")
|
||||
b.IdentComma(columns...)
|
||||
}
|
||||
|
||||
func (s *Selector) joinSelect(b *Builder) {
|
||||
for i := range s.selection {
|
||||
if i > 0 {
|
||||
@@ -3498,9 +3514,9 @@ func (b Builder) postgres() bool {
|
||||
return b.Dialect() == dialect.Postgres
|
||||
}
|
||||
|
||||
// mysql reports if the builder dialect is MySQL.
|
||||
func (b Builder) mysql() bool {
|
||||
return b.Dialect() == dialect.MySQL
|
||||
// sqlite reports if the builder dialect is SQLite.
|
||||
func (b Builder) sqlite() bool {
|
||||
return b.Dialect() == dialect.SQLite
|
||||
}
|
||||
|
||||
// fromIdent sets the builder dialect from the identifier format.
|
||||
|
||||
@@ -302,6 +302,16 @@ func TestBuilder(t *testing.T) {
|
||||
wantQuery: `UPDATE "users" SET "name" = $1`,
|
||||
wantArgs: []any{"foo"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Returning("*"),
|
||||
wantQuery: `UPDATE "users" SET "name" = $1 RETURNING *`,
|
||||
wantArgs: []any{"foo"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Returning("id", "name"),
|
||||
wantQuery: `UPDATE "users" SET "name" = $1 RETURNING "id", "name"`,
|
||||
wantArgs: []any{"foo"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Schema("mydb"),
|
||||
wantQuery: `UPDATE "mydb"."users" SET "name" = $1`,
|
||||
@@ -317,6 +327,11 @@ func TestBuilder(t *testing.T) {
|
||||
wantQuery: "UPDATE `users` SET `name` = ?, `age` = ?",
|
||||
wantArgs: []any{"foo", 10},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.SQLite).Update("users").Set("name", "foo").Returning("id", "name"),
|
||||
wantQuery: "UPDATE `users` SET `name` = ? RETURNING `id`, `name`",
|
||||
wantArgs: []any{"foo"},
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Set("age", 10),
|
||||
wantQuery: `UPDATE "users" SET "name" = $1, "age" = $2`,
|
||||
|
||||
Reference in New Issue
Block a user