diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index de2e87a45..59a5a43ec 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1033,6 +1033,7 @@ type UpdateBuilder struct { columns []string values []any order []any + prefix Queries } // Update creates a builder for the `UPDATE` statement. @@ -1118,9 +1119,19 @@ func (u *UpdateBuilder) OrderBy(columns ...string) *UpdateBuilder { return u } +// Prefix prefixes the UPDATE statement with list of statements. +func (u *UpdateBuilder) Prefix(stmts ...Querier) *UpdateBuilder { + u.prefix = append(u.prefix, stmts...) + return u +} + // Query returns query representation of an `UPDATE` statement. func (u *UpdateBuilder) Query() (string, []any) { b := u.Builder.clone() + if len(u.prefix) > 0 { + b.join(u.prefix, " ") + b.Pad() + } b.WriteString("UPDATE ") b.writeSchema(u.schema) b.Ident(u.table).WriteString(" SET ") @@ -3333,7 +3344,7 @@ func (b *Builder) JoinComma(qs ...Querier) *Builder { return b.join(qs, ", ") } -// join joins a list of Queries to the builder with a given separator. +// join a list of Queries to the builder with a given separator. func (b *Builder) join(qs []Querier, sep string) *Builder { for i, q := range qs { if i > 0 { diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index ad7690563..22e387de2 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -853,18 +853,18 @@ func TestBuilder(t *testing.T) { }(), wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` RIGHT JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`", }, - { - input: func() Querier { - t1 := Table("groups").As("g") - t2 := Table("user_groups").As("ug") - return Select(t1.C("id"), As(Count("`*`"), "user_count")). - From(t1). - FullJoin(t2). - On(t1.C("id"), t2.C("group_id")). - GroupBy(t1.C("id")) - }(), - wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` FULL JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`", - }, + { + input: func() Querier { + t1 := Table("groups").As("g") + t2 := Table("user_groups").As("ug") + return Select(t1.C("id"), As(Count("`*`"), "user_count")). + From(t1). + FullJoin(t2). + On(t1.C("id"), t2.C("group_id")). + GroupBy(t1.C("id")) + }(), + wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` FULL JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`", + }, { input: func() Querier { t1 := Table("users").As("u") @@ -2169,3 +2169,27 @@ func TestUpdateBuilder_OrderBy(t *testing.T) { u = Dialect(dialect.Postgres).Update("users").Set("id", Expr("id + 1")).OrderBy("id") require.Error(t, u.Err()) } + +func TestUpdateBuilder_WithPrefix(t *testing.T) { + u := Dialect(dialect.MySQL). + Update("users"). + Prefix(ExprFunc(func(b *Builder) { + b.WriteString("SET @i = ").Arg(1).WriteByte(';') + })). + Set("id", Expr("(@i:=@i+1)")). + OrderBy("id") + require.NoError(t, u.Err()) + query, args := u.Query() + require.Equal(t, []any{1}, args) + require.Equal(t, "SET @i = ?; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query) + + u = Dialect(dialect.MySQL). + Update("users"). + Prefix(Expr("SET @i = 1;")). + Set("id", Expr("(@i:=@i+1)")). + OrderBy("id") + require.NoError(t, u.Err()) + query, args = u.Query() + require.Empty(t, args) + require.Equal(t, "SET @i = 1; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query) +} diff --git a/entc/gen/func.go b/entc/gen/func.go index 7d6613d94..aaed0d8f3 100644 --- a/entc/gen/func.go +++ b/entc/gen/func.go @@ -144,7 +144,6 @@ func pascalWords(words []string) string { // full_name => FullName // user_id => UserID // full-admin => FullAdmin -// func pascal(s string) string { words := strings.FieldsFunc(s, isSeparator) return pascalWords(words) @@ -156,7 +155,6 @@ func pascal(s string) string { // full_name => fullName // user_id => userID // full-admin => fullAdmin -// func camel(s string) string { words := strings.FieldsFunc(s, isSeparator) if len(words) == 1 { @@ -170,7 +168,6 @@ func camel(s string) string { // Username => username // FullName => full_name // HTTPCode => http_code -// func snake(s string) string { var ( j int @@ -199,7 +196,6 @@ func snake(s string) string { // [1]T => t // User => u // UserQuery => uq -// func receiver(s string) (r string) { // Trim invalid tokens for identifier prefix. s = strings.Trim(s, "[]*&0123456789") @@ -244,7 +240,6 @@ type graphScope struct { // {{ with $scope := extend $ "key" "value" }} // {{ template "setters" $scope }} // {{ end}} -// func extend(v any, kv ...any) (any, error) { if len(kv)%2 != 0 { return nil, fmt.Errorf("invalid number of parameters: %d", len(kv)) @@ -283,7 +278,7 @@ func add(xs ...int) (n int) { func ruleset() *inflect.Ruleset { rules := inflect.NewDefaultRuleset() - // Add common initialisms from golint and more. + // Add common initialism from golint and more. for _, w := range []string{ "ACL", "API", "ASCII", "AWS", "CPU", "CSS", "DNS", "EOF", "GB", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "KB", "LHS", "MAC", "MB", @@ -291,20 +286,18 @@ func ruleset() *inflect.Ruleset { "TLS", "TTL", "UDP", "UI", "UID", "URI", "URL", "UTF8", "UUID", "VM", "XML", "XMPP", "XSRF", "XSS", } { - addAcronym(rules, w) + acronyms[w] = struct{}{} + rules.AddAcronym(w) } return rules } -func addAcronym(rules *inflect.Ruleset, word string) { +// AddAcronym adds initialism to the global ruleset. +func AddAcronym(word string) { acronyms[word] = struct{}{} rules.AddAcronym(word) } -func AddAcronym(word string) { - addAcronym(rules, word) -} - // order returns a map of sort orders. // The key is the function name, and the value its database keyword. func order() map[string]string {