From b067d5d8b4eebf9f4a5aef1b4c679f149d1dd9a8 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Mon, 9 Sep 2019 05:49:35 -0700 Subject: [PATCH] dialect/sql: increment and decrement columns using "add" Summary: Use case: ``` query, args := Update("data_packs"). Add("bytes_left", count). Where(expr...). Query() ``` Reviewed By: alexsn Differential Revision: D17257590 fbshipit-source-id: f27b4b388a711a16deb0c3b790e24957e044204d --- dialect/sql/builder.go | 52 ++++++++++++++++++++++++++++++------ dialect/sql/builder_test.go | 16 +++++++++++ entc/gen/internal/bindata.go | 8 +++--- 3 files changed, 64 insertions(+), 12 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index ee3f43f17..906bbe250 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -36,10 +36,11 @@ func (n Queries) Query() (string, []interface{}) { return b.String(), b.args } -// Builder is a query builder for the sql dsl. +// Builder is the base query builder for the sql dsl. type Builder struct { bytes.Buffer - args []interface{} + args []interface{} + dialect string } // Append appends the given string as a quoted parameter @@ -141,6 +142,22 @@ func (b Builder) clone() Builder { return c } +// SetDialect sets the builder dialect. It's used for garnering dialect specific queries. +func (b *Builder) SetDialect(dialect string) *Builder { + b.dialect = dialect + return b +} + +// Dialect returns the dialect of the builder. +func (b Builder) Dialect() string { + return b.dialect +} + +// Query implements the Querier interface. +func (b Builder) Query() (string, []interface{}) { + return b.String(), b.args +} + // ColumnBuilder is a builder for column definition in table creation. type ColumnBuilder struct { b Builder @@ -671,6 +688,20 @@ func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder { return u } +// Add adds a numeric value to the given column. +func (u *UpdateBuilder) Add(column string, v interface{}) *UpdateBuilder { + u.columns = append(u.columns, column) + var b Builder + b.WriteString("COALESCE") + b.Nested(func(b *Builder) { + b.Append(column).Comma().Arg(0) + }) + b.WriteString(" + ") + b.Arg(v) + u.values = append(u.values, b) + return u +} + // SetNull sets a column as null value. func (u *UpdateBuilder) SetNull(column string) *UpdateBuilder { u.nulls = append(u.nulls, column) @@ -691,8 +722,8 @@ func (u *UpdateBuilder) Where(p *Predicate) *UpdateBuilder { func (u *UpdateBuilder) Query() (string, []interface{}) { u.b.WriteString("UPDATE ") u.b.Append(u.table).Pad().WriteString("SET ") - for j, c := range u.nulls { - if j > 0 { + for i, c := range u.nulls { + if i > 0 { u.b.Comma() } u.b.Append(c).WriteString(" = NULL") @@ -700,13 +731,18 @@ func (u *UpdateBuilder) Query() (string, []interface{}) { if len(u.nulls) > 0 && len(u.columns) > 0 { u.b.Comma() } - for j, c := range u.columns { - if j > 0 { + for i, c := range u.columns { + if i > 0 { u.b.Comma() } - u.b.Append(c).WriteString(" = ?") + u.b.Append(c).WriteString(" = ") + switch v := u.values[i].(type) { + case Querier: + u.b.Join(v) + default: + u.b.Arg(v) + } } - u.b.args = append(u.b.args, u.values...) if u.where != nil { u.b.WriteString(" WHERE ") u.b.Join(u.where) diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 5ac94a2af..91c9d27b4 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -188,6 +188,22 @@ func TestBuilder(t *testing.T) { wantQuery: "UPDATE `users` SET `name` = ? WHERE `nickname` LIKE ? AND `lastname` LIKE ?", wantArgs: []interface{}{"foo", "a8m%", "%mash%"}, }, + { + input: Update("users"). + Add("age", 1). + Where(HasPrefix("nickname", "a8m")), + wantQuery: "UPDATE `users` SET `age` = COALESCE(`age`, ?) + ? WHERE `nickname` LIKE ?", + wantArgs: []interface{}{0, 1, "a8m%"}, + }, + { + input: Update("users"). + Add("age", 1). + Set("nickname", "a8m"). + Add("version", 10). + Set("name", "mashraki"), + wantQuery: "UPDATE `users` SET `age` = COALESCE(`age`, ?) + ?, `nickname` = ?, `version` = COALESCE(`version`, ?) + ?, `name` = ?", + wantArgs: []interface{}{0, 1, "a8m", 0, 10, "mashraki"}, + }, { input: Select(). From(Table("users")). diff --git a/entc/gen/internal/bindata.go b/entc/gen/internal/bindata.go index 2cfa669f2..7e98dcb0d 100644 --- a/entc/gen/internal/bindata.go +++ b/entc/gen/internal/bindata.go @@ -201,7 +201,7 @@ func templateBuilderSetterTmpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "template/builder/setter.tmpl", size: 3452, mode: os.FileMode(420), modTime: time.Unix(1567330680, 0)} + info := bindataFileInfo{name: "template/builder/setter.tmpl", size: 3452, mode: os.FileMode(420), modTime: time.Unix(1568032077, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -221,7 +221,7 @@ func templateBuilderUpdateTmpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "template/builder/update.tmpl", size: 7847, mode: os.FileMode(420), modTime: time.Unix(1567330684, 0)} + info := bindataFileInfo{name: "template/builder/update.tmpl", size: 7847, mode: os.FileMode(420), modTime: time.Unix(1568032077, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -721,7 +721,7 @@ func templateDialectSqlUpdateTmpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "template/dialect/sql/update.tmpl", size: 11628, mode: os.FileMode(420), modTime: time.Unix(1567330621, 0)} + info := bindataFileInfo{name: "template/dialect/sql/update.tmpl", size: 11628, mode: os.FileMode(420), modTime: time.Unix(1568032077, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -921,7 +921,7 @@ func templateWhereTmpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "template/where.tmpl", size: 5099, mode: os.FileMode(420), modTime: time.Unix(1567330531, 0)} + info := bindataFileInfo{name: "template/where.tmpl", size: 5099, mode: os.FileMode(420), modTime: time.Unix(1567957074, 0)} a := &asset{bytes: bytes, info: info} return a, nil }