diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 1ae4eea42..743233a43 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -708,7 +708,8 @@ func (i *InsertBuilder) Default() *InsertBuilder { return i } -// Returning adds the `RETURNING` clause to the insert statement. PostgreSQL only. +// Returning adds the `RETURNING` clause to the insert statement. +// Supported by SQLite and PostgreSQL. func (i *InsertBuilder) Returning(columns ...string) *InsertBuilder { i.returning = columns return i @@ -948,92 +949,91 @@ func (u *UpdateSet) SetExcluded(name string) *UpdateSet { // Query returns query representation of an `INSERT INTO` statement. func (i *InsertBuilder) Query() (string, []any) { - i.WriteString("INSERT INTO ") - i.writeSchema(i.schema) - i.Ident(i.table).Pad() + b := i.Builder.clone() + b.WriteString("INSERT INTO ") + b.writeSchema(i.schema) + b.Ident(i.table).Pad() if i.defaults && len(i.columns) == 0 { - i.writeDefault() + i.writeDefault(&b) } else { - i.WriteByte('(').IdentComma(i.columns...).WriteByte(')') - i.WriteString(" VALUES ") + b.WriteByte('(').IdentComma(i.columns...).WriteByte(')') + b.WriteString(" VALUES ") for j, v := range i.values { if j > 0 { - i.Comma() + b.Comma() } - i.WriteByte('(').Args(v...).WriteByte(')') + b.WriteByte('(').Args(v...).WriteByte(')') } } if i.conflict != nil { - i.writeConflict() + i.writeConflict(&b) } - if len(i.returning) > 0 && !i.mysql() { - i.WriteString(" RETURNING ") - i.IdentComma(i.returning...) - } - return i.String(), i.args + joinReturning(i.returning, &b) + return b.String(), b.args } -func (i *InsertBuilder) writeDefault() { +func (i *InsertBuilder) writeDefault(b *Builder) { switch i.Dialect() { case dialect.MySQL: - i.WriteString("VALUES ()") + b.WriteString("VALUES ()") case dialect.SQLite, dialect.Postgres: - i.WriteString("DEFAULT VALUES") + b.WriteString("DEFAULT VALUES") } } -func (i *InsertBuilder) writeConflict() { +func (i *InsertBuilder) writeConflict(b *Builder) { switch i.Dialect() { case dialect.MySQL: - i.WriteString(" ON DUPLICATE KEY UPDATE ") + b.WriteString(" ON DUPLICATE KEY UPDATE ") if i.conflict.action.nothing { - i.AddError(fmt.Errorf("invalid CONFLICT action ('DO NOTHING')")) + b.AddError(fmt.Errorf("invalid CONFLICT action ('DO NOTHING')")) } case dialect.SQLite, dialect.Postgres: - i.WriteString(" ON CONFLICT") + b.WriteString(" ON CONFLICT") switch t := i.conflict.target; { case t.constraint != "" && len(t.columns) != 0: - i.AddError(fmt.Errorf("duplicate CONFLICT clauses: %q, %q", t.constraint, t.columns)) + b.AddError(fmt.Errorf("duplicate CONFLICT clauses: %q, %q", t.constraint, t.columns)) case t.constraint != "": - i.WriteString(" ON CONSTRAINT ").Ident(t.constraint) + b.WriteString(" ON CONSTRAINT ").Ident(t.constraint) case len(t.columns) != 0: - i.WriteString(" (").IdentComma(t.columns...).WriteByte(')') + b.WriteString(" (").IdentComma(t.columns...).WriteByte(')') } if p := i.conflict.target.where; p != nil { - i.WriteString(" WHERE ").Join(p) + b.WriteString(" WHERE ").Join(p) } if i.conflict.action.nothing { - i.WriteString(" DO NOTHING") + b.WriteString(" DO NOTHING") return } - i.WriteString(" DO UPDATE SET ") + b.WriteString(" DO UPDATE SET ") } if len(i.conflict.action.update) == 0 { - i.AddError(errors.New("missing action for 'DO UPDATE SET' clause")) + b.AddError(errors.New("missing action for 'DO UPDATE SET' clause")) } u := &UpdateSet{columns: i.columns, update: Dialect(i.dialect).Update(i.table)} - u.update.Builder = i.Builder + u.update.Builder = *b for _, f := range i.conflict.action.update { f(u) } - u.update.writeSetter(&i.Builder) + u.update.writeSetter(b) if p := i.conflict.action.where; p != nil { p.qualifier = i.table - i.WriteString(" WHERE ").Join(p) + b.WriteString(" WHERE ").Join(p) } } // UpdateBuilder is a builder for `UPDATE` statement. type UpdateBuilder struct { Builder - table string - schema string - where *Predicate - nulls []string - columns []string - values []any - order []any - prefix Queries + table string + schema string + where *Predicate + nulls []string + columns []string + returning []string + values []any + order []any + prefix Queries } // Update creates a builder for the `UPDATE` statement. @@ -1125,6 +1125,13 @@ func (u *UpdateBuilder) Prefix(stmts ...Querier) *UpdateBuilder { return u } +// Returning adds the `RETURNING` clause to the insert statement. +// Supported by SQLite and PostgreSQL. +func (u *UpdateBuilder) Returning(columns ...string) *UpdateBuilder { + u.returning = columns + return u +} + // Query returns query representation of an `UPDATE` statement. func (u *UpdateBuilder) Query() (string, []any) { b := u.Builder.clone() @@ -1141,6 +1148,7 @@ func (u *UpdateBuilder) Query() (string, []any) { b.Join(u.where) } joinOrder(u.order, &b) + joinReturning(u.returning, &b) return b.String(), b.args } @@ -2835,6 +2843,14 @@ func joinOrder(order []any, b *Builder) { } } +func joinReturning(columns []string, b *Builder) { + if len(columns) == 0 || (!b.postgres() && !b.sqlite()) { + return + } + b.WriteString(" RETURNING ") + b.IdentComma(columns...) +} + func (s *Selector) joinSelect(b *Builder) { for i := range s.selection { if i > 0 { @@ -3498,9 +3514,9 @@ func (b Builder) postgres() bool { return b.Dialect() == dialect.Postgres } -// mysql reports if the builder dialect is MySQL. -func (b Builder) mysql() bool { - return b.Dialect() == dialect.MySQL +// sqlite reports if the builder dialect is SQLite. +func (b Builder) sqlite() bool { + return b.Dialect() == dialect.SQLite } // fromIdent sets the builder dialect from the identifier format. diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 813132578..3c7e3fc58 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -302,6 +302,16 @@ func TestBuilder(t *testing.T) { wantQuery: `UPDATE "users" SET "name" = $1`, wantArgs: []any{"foo"}, }, + { + input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Returning("*"), + wantQuery: `UPDATE "users" SET "name" = $1 RETURNING *`, + wantArgs: []any{"foo"}, + }, + { + input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Returning("id", "name"), + wantQuery: `UPDATE "users" SET "name" = $1 RETURNING "id", "name"`, + wantArgs: []any{"foo"}, + }, { input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Schema("mydb"), wantQuery: `UPDATE "mydb"."users" SET "name" = $1`, @@ -317,6 +327,11 @@ func TestBuilder(t *testing.T) { wantQuery: "UPDATE `users` SET `name` = ?, `age` = ?", wantArgs: []any{"foo", 10}, }, + { + input: Dialect(dialect.SQLite).Update("users").Set("name", "foo").Returning("id", "name"), + wantQuery: "UPDATE `users` SET `name` = ? RETURNING `id`, `name`", + wantArgs: []any{"foo"}, + }, { input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Set("age", 10), wantQuery: `UPDATE "users" SET "name" = $1, "age" = $2`,