dialect/sql: add support for statement prefix in update builder (#2904)

This commit is contained in:
Ariel Mashraki
2022-09-06 17:34:53 +03:00
committed by GitHub
parent 1e12537a35
commit 1773fc465e
3 changed files with 53 additions and 25 deletions

View File

@@ -1033,6 +1033,7 @@ type UpdateBuilder struct {
columns []string
values []any
order []any
prefix Queries
}
// Update creates a builder for the `UPDATE` statement.
@@ -1118,9 +1119,19 @@ func (u *UpdateBuilder) OrderBy(columns ...string) *UpdateBuilder {
return u
}
// Prefix prefixes the UPDATE statement with list of statements.
func (u *UpdateBuilder) Prefix(stmts ...Querier) *UpdateBuilder {
u.prefix = append(u.prefix, stmts...)
return u
}
// Query returns query representation of an `UPDATE` statement.
func (u *UpdateBuilder) Query() (string, []any) {
b := u.Builder.clone()
if len(u.prefix) > 0 {
b.join(u.prefix, " ")
b.Pad()
}
b.WriteString("UPDATE ")
b.writeSchema(u.schema)
b.Ident(u.table).WriteString(" SET ")
@@ -3333,7 +3344,7 @@ func (b *Builder) JoinComma(qs ...Querier) *Builder {
return b.join(qs, ", ")
}
// join joins a list of Queries to the builder with a given separator.
// join a list of Queries to the builder with a given separator.
func (b *Builder) join(qs []Querier, sep string) *Builder {
for i, q := range qs {
if i > 0 {

View File

@@ -853,18 +853,18 @@ func TestBuilder(t *testing.T) {
}(),
wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` RIGHT JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`",
},
{
input: func() Querier {
t1 := Table("groups").As("g")
t2 := Table("user_groups").As("ug")
return Select(t1.C("id"), As(Count("`*`"), "user_count")).
From(t1).
FullJoin(t2).
On(t1.C("id"), t2.C("group_id")).
GroupBy(t1.C("id"))
}(),
wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` FULL JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`",
},
{
input: func() Querier {
t1 := Table("groups").As("g")
t2 := Table("user_groups").As("ug")
return Select(t1.C("id"), As(Count("`*`"), "user_count")).
From(t1).
FullJoin(t2).
On(t1.C("id"), t2.C("group_id")).
GroupBy(t1.C("id"))
}(),
wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` FULL JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`",
},
{
input: func() Querier {
t1 := Table("users").As("u")
@@ -2169,3 +2169,27 @@ func TestUpdateBuilder_OrderBy(t *testing.T) {
u = Dialect(dialect.Postgres).Update("users").Set("id", Expr("id + 1")).OrderBy("id")
require.Error(t, u.Err())
}
func TestUpdateBuilder_WithPrefix(t *testing.T) {
u := Dialect(dialect.MySQL).
Update("users").
Prefix(ExprFunc(func(b *Builder) {
b.WriteString("SET @i = ").Arg(1).WriteByte(';')
})).
Set("id", Expr("(@i:=@i+1)")).
OrderBy("id")
require.NoError(t, u.Err())
query, args := u.Query()
require.Equal(t, []any{1}, args)
require.Equal(t, "SET @i = ?; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query)
u = Dialect(dialect.MySQL).
Update("users").
Prefix(Expr("SET @i = 1;")).
Set("id", Expr("(@i:=@i+1)")).
OrderBy("id")
require.NoError(t, u.Err())
query, args = u.Query()
require.Empty(t, args)
require.Equal(t, "SET @i = 1; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query)
}