mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql: add support for statement prefix in update builder (#2904)
This commit is contained in:
@@ -1033,6 +1033,7 @@ type UpdateBuilder struct {
|
||||
columns []string
|
||||
values []any
|
||||
order []any
|
||||
prefix Queries
|
||||
}
|
||||
|
||||
// Update creates a builder for the `UPDATE` statement.
|
||||
@@ -1118,9 +1119,19 @@ func (u *UpdateBuilder) OrderBy(columns ...string) *UpdateBuilder {
|
||||
return u
|
||||
}
|
||||
|
||||
// Prefix prefixes the UPDATE statement with list of statements.
|
||||
func (u *UpdateBuilder) Prefix(stmts ...Querier) *UpdateBuilder {
|
||||
u.prefix = append(u.prefix, stmts...)
|
||||
return u
|
||||
}
|
||||
|
||||
// Query returns query representation of an `UPDATE` statement.
|
||||
func (u *UpdateBuilder) Query() (string, []any) {
|
||||
b := u.Builder.clone()
|
||||
if len(u.prefix) > 0 {
|
||||
b.join(u.prefix, " ")
|
||||
b.Pad()
|
||||
}
|
||||
b.WriteString("UPDATE ")
|
||||
b.writeSchema(u.schema)
|
||||
b.Ident(u.table).WriteString(" SET ")
|
||||
@@ -3333,7 +3344,7 @@ func (b *Builder) JoinComma(qs ...Querier) *Builder {
|
||||
return b.join(qs, ", ")
|
||||
}
|
||||
|
||||
// join joins a list of Queries to the builder with a given separator.
|
||||
// join a list of Queries to the builder with a given separator.
|
||||
func (b *Builder) join(qs []Querier, sep string) *Builder {
|
||||
for i, q := range qs {
|
||||
if i > 0 {
|
||||
|
||||
@@ -853,18 +853,18 @@ func TestBuilder(t *testing.T) {
|
||||
}(),
|
||||
wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` RIGHT JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`",
|
||||
},
|
||||
{
|
||||
input: func() Querier {
|
||||
t1 := Table("groups").As("g")
|
||||
t2 := Table("user_groups").As("ug")
|
||||
return Select(t1.C("id"), As(Count("`*`"), "user_count")).
|
||||
From(t1).
|
||||
FullJoin(t2).
|
||||
On(t1.C("id"), t2.C("group_id")).
|
||||
GroupBy(t1.C("id"))
|
||||
}(),
|
||||
wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` FULL JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`",
|
||||
},
|
||||
{
|
||||
input: func() Querier {
|
||||
t1 := Table("groups").As("g")
|
||||
t2 := Table("user_groups").As("ug")
|
||||
return Select(t1.C("id"), As(Count("`*`"), "user_count")).
|
||||
From(t1).
|
||||
FullJoin(t2).
|
||||
On(t1.C("id"), t2.C("group_id")).
|
||||
GroupBy(t1.C("id"))
|
||||
}(),
|
||||
wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` FULL JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`",
|
||||
},
|
||||
{
|
||||
input: func() Querier {
|
||||
t1 := Table("users").As("u")
|
||||
@@ -2169,3 +2169,27 @@ func TestUpdateBuilder_OrderBy(t *testing.T) {
|
||||
u = Dialect(dialect.Postgres).Update("users").Set("id", Expr("id + 1")).OrderBy("id")
|
||||
require.Error(t, u.Err())
|
||||
}
|
||||
|
||||
func TestUpdateBuilder_WithPrefix(t *testing.T) {
|
||||
u := Dialect(dialect.MySQL).
|
||||
Update("users").
|
||||
Prefix(ExprFunc(func(b *Builder) {
|
||||
b.WriteString("SET @i = ").Arg(1).WriteByte(';')
|
||||
})).
|
||||
Set("id", Expr("(@i:=@i+1)")).
|
||||
OrderBy("id")
|
||||
require.NoError(t, u.Err())
|
||||
query, args := u.Query()
|
||||
require.Equal(t, []any{1}, args)
|
||||
require.Equal(t, "SET @i = ?; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query)
|
||||
|
||||
u = Dialect(dialect.MySQL).
|
||||
Update("users").
|
||||
Prefix(Expr("SET @i = 1;")).
|
||||
Set("id", Expr("(@i:=@i+1)")).
|
||||
OrderBy("id")
|
||||
require.NoError(t, u.Err())
|
||||
query, args = u.Query()
|
||||
require.Empty(t, args)
|
||||
require.Equal(t, "SET @i = 1; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user