dialect/sql: support setting ORDER BY for UPDATE statements

This commit is contained in:
Ariel Mashraki
2022-08-05 11:27:01 +03:00
committed by Ariel Mashraki
parent 47972774c5
commit 9f481d8716
2 changed files with 35 additions and 57 deletions

View File

@@ -49,7 +49,6 @@ type ColumnBuilder struct {
// Column returns a new ColumnBuilder with the given name.
//
// sql.Column("group_id").Type("int").Attr("UNIQUE")
//
func Column(name string) *ColumnBuilder { return &ColumnBuilder{name: name} }
// Type sets the column type.
@@ -127,7 +126,6 @@ type TableBuilder struct {
// Column("name").Type("varchar(255)"),
// ).
// PrimaryKey("id")
//
func CreateTable(name string) *TableBuilder { return &TableBuilder{name: name} }
// IfNotExists appends the `IF NOT EXISTS` clause to the `CREATE TABLE` statement.
@@ -206,9 +204,9 @@ func (t *TableBuilder) Options(s string) *TableBuilder {
// Query returns query representation of a `CREATE TABLE` statement.
//
// CREATE TABLE [IF NOT EXISTS] name
// (table definition)
// [charset and collation]
//
// (table definition)
// [charset and collation]
func (t *TableBuilder) Query() (string, []interface{}) {
t.WriteString("CREATE TABLE ")
if t.exists {
@@ -251,7 +249,6 @@ type DescribeBuilder struct {
// Describe returns a query builder for the `DESCRIBE` statement.
//
// Describe("users")
//
func Describe(name string) *DescribeBuilder { return &DescribeBuilder{name: name} }
// Query returns query representation of a `DESCRIBE` statement.
@@ -275,7 +272,6 @@ type TableAlter struct {
// AddForeignKey(ForeignKey().Columns("group_id").
// Reference(Reference().Table("groups").Columns("id")).OnDelete("CASCADE")),
// )
//
func AlterTable(name string) *TableAlter { return &TableAlter{name: name} }
// AddColumn appends the `ADD COLUMN` clause to the given `ALTER TABLE` statement.
@@ -373,7 +369,6 @@ func (t *TableAlter) DropForeignKey(ident string) *TableAlter {
//
// ALTER TABLE name
// [alter_specification]
//
func (t *TableAlter) Query() (string, []interface{}) {
t.WriteString("ALTER TABLE ")
t.Ident(t.name)
@@ -393,7 +388,6 @@ type IndexAlter struct {
//
// AlterIndex("old_key").
// Rename("new_key")
//
func AlterIndex(name string) *IndexAlter { return &IndexAlter{name: name} }
// Rename appends the `RENAME TO` clause to the `ALTER INDEX` statement.
@@ -406,7 +400,6 @@ func (i *IndexAlter) Rename(name string) *IndexAlter {
//
// ALTER INDEX name
// [alter_specification]
//
func (i *IndexAlter) Query() (string, []interface{}) {
i.WriteString("ALTER INDEX ")
i.Ident(i.name)
@@ -426,11 +419,10 @@ type ForeignKeyBuilder struct {
// ForeignKey returns a builder for the foreign-key constraint clause in create/alter table statements.
//
// ForeignKey().
// Columns("group_id").
// ForeignKey().
// Columns("group_id").
// Reference(Reference().Table("groups").Columns("id")).
// OnDelete("CASCADE")
//
func ForeignKey(symbol ...string) *ForeignKeyBuilder {
fk := &ForeignKeyBuilder{}
if len(symbol) != 0 {
@@ -495,7 +487,6 @@ type ReferenceBuilder struct {
// Reference create a reference builder for the reference_option clause.
//
// Reference().Table("groups").Columns("id")
//
func Reference() *ReferenceBuilder { return &ReferenceBuilder{} }
// Table sets the referenced table.
@@ -544,7 +535,6 @@ type IndexBuilder struct {
// Unique().
// Table("users").
// Columns("name", "age")
//
func CreateIndex(name string) *IndexBuilder {
return &IndexBuilder{name: name}
}
@@ -638,7 +628,6 @@ type DropIndexBuilder struct {
// SQLite/PostgreSQL:
//
// DropIndex("index_name")
//
func DropIndex(name string) *DropIndexBuilder {
return &DropIndexBuilder{name: name}
}
@@ -652,7 +641,6 @@ func (d *DropIndexBuilder) Table(table string) *DropIndexBuilder {
// Query returns query representation of a reference clause.
//
// DROP INDEX index_name [ON table_name]
//
func (d *DropIndexBuilder) Query() (string, []interface{}) {
d.WriteString("DROP INDEX ")
d.Ident(d.name)
@@ -758,7 +746,6 @@ type (
// sql.ConflictColumns("id"),
// sql.ResolveWithNewValues(),
// )
//
func ConflictColumns(names ...string) ConflictOption {
return func(c *conflict) {
c.target.columns = names
@@ -775,7 +762,6 @@ func ConflictColumns(names ...string) ConflictOption {
// sql.ConflictConstraint("users_pkey"),
// sql.ResolveWithNewValues(),
// )
//
func ConflictConstraint(name string) ConflictOption {
return func(c *conflict) {
c.target.constraint = name
@@ -808,7 +794,6 @@ func UpdateWhere(p *Predicate) ConflictOption {
// sql.ConflictColumns("id"),
// sql.DoNothing()
// )
//
func DoNothing() ConflictOption {
return func(c *conflict) {
c.action.nothing = true
@@ -829,7 +814,6 @@ func DoNothing() ConflictOption {
// // Output:
// // MySQL: INSERT INTO `users` (`id`) VALUES(1) ON DUPLICATE KEY UPDATE `id` = `users`.`id`
// // PostgreSQL: INSERT INTO "users" ("id") VALUES(1) ON CONFLICT ("id") DO UPDATE SET "id" = "users"."id
//
func ResolveWithIgnore() ConflictOption {
return func(c *conflict) {
c.action.update = append(c.action.update, func(u *UpdateSet) {
@@ -854,7 +838,6 @@ func ResolveWithIgnore() ConflictOption {
// // Output:
// // MySQL: INSERT INTO `users` (`id`, `name`) VALUES(1, 'Mashraki) ON DUPLICATE KEY UPDATE `id` = VALUES(`id`), `name` = VALUES(`name`),
// // PostgreSQL: INSERT INTO "users" ("id") VALUES(1) ON CONFLICT ("id") DO UPDATE SET "id" = "excluded"."id, "name" = "excluded"."name"
//
func ResolveWithNewValues() ConflictOption {
return func(c *conflict) {
c.action.update = append(c.action.update, func(u *UpdateSet) {
@@ -878,7 +861,6 @@ func ResolveWithNewValues() ConflictOption {
// u.Set("name", Expr(u.Excluded().C("name")))
// }),
// )
//
func ResolveWith(fn func(*UpdateSet)) ConflictOption {
return func(c *conflict) {
c.action.update = append(c.action.update, fn)
@@ -895,7 +877,6 @@ func ResolveWith(fn func(*UpdateSet)) ConflictOption {
// sql.ConflictColumns("id"),
// sql.ResolveWithNewValues()
// )
//
func (i *InsertBuilder) OnConflict(opts ...ConflictOption) *InsertBuilder {
if i.conflict == nil {
i.conflict = &conflict{}
@@ -1051,12 +1032,12 @@ type UpdateBuilder struct {
nulls []string
columns []string
values []interface{}
order []interface{}
}
// Update creates a builder for the `UPDATE` statement.
//
// Update("users").Set("name", "foo").Set("age", 10)
//
func Update(table string) *UpdateBuilder { return &UpdateBuilder{table: table} }
// Schema sets the database name for the updated table.
@@ -1124,6 +1105,19 @@ func (u *UpdateBuilder) Empty() bool {
return len(u.columns) == 0 && len(u.nulls) == 0
}
// OrderBy appends the `ORDER BY` clause to the `UPDATE` statement.
// Supported by SQLite and MySQL.
func (u *UpdateBuilder) OrderBy(columns ...string) *UpdateBuilder {
if u.postgres() {
u.AddError(errors.New("ORDER BY is not supported by PostgreSQL"))
return u
}
for i := range columns {
u.order = append(u.order, columns[i])
}
return u
}
// Query returns query representation of an `UPDATE` statement.
func (u *UpdateBuilder) Query() (string, []interface{}) {
b := u.Builder.clone()
@@ -1135,6 +1129,7 @@ func (u *UpdateBuilder) Query() (string, []interface{}) {
b.WriteString(" WHERE ")
b.Join(u.where)
}
joinOrder(u.order, &b)
return b.String(), b.args
}
@@ -1184,7 +1179,6 @@ type DeleteBuilder struct {
// ),
// ),
// )
//
func Delete(table string) *DeleteBuilder { return &DeleteBuilder{table: table} }
// Schema sets the database name for the table whose row will be deleted.
@@ -1234,7 +1228,6 @@ type Predicate struct {
// P creates a new predicate.
//
// P().EQ("name", "a8m").And().EQ("age", 30)
//
func P(fns ...func(*Builder)) *Predicate {
return &Predicate{fns: fns}
}
@@ -1242,7 +1235,6 @@ func P(fns ...func(*Builder)) *Predicate {
// ExprP creates a new predicate from the given expression.
//
// ExprP("A = ? AND B > ?", args...)
//
func ExprP(exr string, args ...interface{}) *Predicate {
return P(func(b *Builder) {
b.Join(Expr(exr, args...))
@@ -1252,7 +1244,6 @@ func ExprP(exr string, args ...interface{}) *Predicate {
// Or combines all given predicates with OR between them.
//
// Or(EQ("name", "foo"), EQ("name", "bar"))
//
func Or(preds ...*Predicate) *Predicate {
p := P()
return p.Append(func(b *Builder) {
@@ -1263,7 +1254,6 @@ func Or(preds ...*Predicate) *Predicate {
// False appends the FALSE keyword to the predicate.
//
// Delete().From("users").Where(False())
//
func False() *Predicate {
return P().False()
}
@@ -1278,7 +1268,6 @@ func (p *Predicate) False() *Predicate {
// Not wraps the given predicate with the not predicate.
//
// Not(Or(EQ("name", "foo"), EQ("name", "bar")))
//
func Not(pred *Predicate) *Predicate {
return P().Not().Append(func(b *Builder) {
b.Nested(func(b *Builder) {
@@ -1870,7 +1859,6 @@ type Func struct {
// Lower wraps the given column with the LOWER function.
//
// P().EQ(sql.Lower("name"), "a8m")
//
func Lower(ident string) string {
f := &Func{}
f.Lower(ident)
@@ -2005,7 +1993,6 @@ type SelectTable struct {
//
// t1 := Table("users").As("u")
// return Select(t1.C("name"))
//
func Table(name string) *SelectTable {
return &SelectTable{quote: true, name: name}
}
@@ -2137,7 +2124,6 @@ func (s *Selector) Context() context.Context {
// From(t1).
// Join(t2).
// On(t1.C("id"), t2.C("user_id"))
//
func Select(columns ...string) *Selector {
return (&Selector{}).Select(columns...)
}
@@ -2532,7 +2518,6 @@ func WithLockTables(tables ...string) LockOption {
// ForShare(
// WithLockClause("LOCK IN SHARE MODE"),
// )
//
func WithLockClause(clause string) LockOption {
return func(c *LockOptions) {
c.clause = clause
@@ -2716,9 +2701,7 @@ func (s *Selector) Query() (string, []interface{}) {
if len(s.union) > 0 {
s.joinUnion(&b)
}
if len(s.order) > 0 {
joinOrder(s.order, &b)
}
joinOrder(s.order, &b)
if s.limit != nil {
b.WriteString(" LIMIT ")
b.WriteString(strconv.Itoa(*s.limit))
@@ -2780,6 +2763,9 @@ func (s *Selector) joinUnion(b *Builder) {
}
func joinOrder(order []interface{}, b *Builder) {
if len(order) == 0 {
return
}
b.WriteString(" ORDER BY ")
for i := range order {
if i > 0 {
@@ -2829,7 +2815,6 @@ type WithBuilder struct {
// Select().From(Table("users_view")),
// }
// return n.Query()
//
func With(name string, columns ...string) *WithBuilder {
return &WithBuilder{
ctes: []struct {
@@ -2849,7 +2834,6 @@ func With(name string, columns ...string) *WithBuilder {
// Select().From(Table("users_view")),
// }
// return n.Query()
//
func WithRecursive(name string, columns ...string) *WithBuilder {
w := With(name, columns...)
w.recursive = true
@@ -2968,9 +2952,7 @@ func (w *WindowBuilder) Query() (string, []interface{}) {
b.WriteString("PARTITION BY ")
w.partition(b)
}
if w.order != nil {
joinOrder(w.order, b)
}
joinOrder(w.order, b)
})
return w.Builder.String(), w.args
}
@@ -3044,7 +3026,6 @@ func (e *expr) Query() (string, []interface{}) { return e.s, e.args }
// // was set before the function was executed.
// b.Ident("x").WriteOp(OpAdd).Arg(1)
// }))
//
func ExprFunc(fn func(*Builder)) Querier {
return &exprFunc{fn: fn}
}
@@ -3466,7 +3447,6 @@ func Dialect(name string) *DialectBuilder {
//
// Dialect(dialect.Postgres).
// Describe("users")
//
func (d *DialectBuilder) Describe(name string) *DescribeBuilder {
b := Describe(name)
b.SetDialect(d.dialect)
@@ -3482,7 +3462,6 @@ func (d *DialectBuilder) Describe(name string) *DescribeBuilder {
// Column("name").Type("varchar(255)"),
// ).
// PrimaryKey("id")
//
func (d *DialectBuilder) CreateTable(name string) *TableBuilder {
b := CreateTable(name)
b.SetDialect(d.dialect)
@@ -3498,7 +3477,6 @@ func (d *DialectBuilder) CreateTable(name string) *TableBuilder {
// Reference(Reference().Table("groups").Columns("id")).
// OnDelete("CASCADE"),
// )
//
func (d *DialectBuilder) AlterTable(name string) *TableAlter {
b := AlterTable(name)
b.SetDialect(d.dialect)
@@ -3510,7 +3488,6 @@ func (d *DialectBuilder) AlterTable(name string) *TableAlter {
// Dialect(dialect.Postgres).
// AlterIndex("old").
// Rename("new")
//
func (d *DialectBuilder) AlterIndex(name string) *IndexAlter {
b := AlterIndex(name)
b.SetDialect(d.dialect)
@@ -3521,7 +3498,6 @@ func (d *DialectBuilder) AlterIndex(name string) *IndexAlter {
//
// Dialect(dialect.Postgres)..
// Column("group_id").Type("int").Attr("UNIQUE")
//
func (d *DialectBuilder) Column(name string) *ColumnBuilder {
b := Column(name)
b.SetDialect(d.dialect)
@@ -3532,7 +3508,6 @@ func (d *DialectBuilder) Column(name string) *ColumnBuilder {
//
// Dialect(dialect.Postgres).
// Insert("users").Columns("age").Values(1)
//
func (d *DialectBuilder) Insert(table string) *InsertBuilder {
b := Insert(table)
b.SetDialect(d.dialect)
@@ -3543,7 +3518,6 @@ func (d *DialectBuilder) Insert(table string) *InsertBuilder {
//
// Dialect(dialect.Postgres).
// Update("users").Set("name", "foo")
//
func (d *DialectBuilder) Update(table string) *UpdateBuilder {
b := Update(table)
b.SetDialect(d.dialect)
@@ -3554,7 +3528,6 @@ func (d *DialectBuilder) Update(table string) *UpdateBuilder {
//
// Dialect(dialect.Postgres).
// Delete().From("users")
//
func (d *DialectBuilder) Delete(table string) *DeleteBuilder {
b := Delete(table)
b.SetDialect(d.dialect)
@@ -3565,7 +3538,6 @@ func (d *DialectBuilder) Delete(table string) *DeleteBuilder {
//
// Dialect(dialect.Postgres).
// Select().From(Table("users"))
//
func (d *DialectBuilder) Select(columns ...string) *Selector {
b := Select(columns...)
b.SetDialect(d.dialect)
@@ -3578,7 +3550,6 @@ func (d *DialectBuilder) Select(columns ...string) *Selector {
// Dialect(dialect.Postgres).
// SelectExpr(expr...).
// From(Table("users"))
//
func (d *DialectBuilder) SelectExpr(exprs ...Querier) *Selector {
b := SelectExpr(exprs...)
b.SetDialect(d.dialect)
@@ -3589,7 +3560,6 @@ func (d *DialectBuilder) SelectExpr(exprs ...Querier) *Selector {
//
// Dialect(dialect.Postgres).
// Table("users").As("u")
//
func (d *DialectBuilder) Table(name string) *SelectTable {
b := Table(name)
b.SetDialect(d.dialect)
@@ -3601,7 +3571,6 @@ func (d *DialectBuilder) Table(name string) *SelectTable {
// Dialect(dialect.Postgres).
// With("users_view").
// As(Select().From(Table("users")))
//
func (d *DialectBuilder) With(name string) *WithBuilder {
b := With(name)
b.SetDialect(d.dialect)
@@ -3615,7 +3584,6 @@ func (d *DialectBuilder) With(name string) *WithBuilder {
// Unique().
// Table("users").
// Columns("first", "last")
//
func (d *DialectBuilder) CreateIndex(name string) *IndexBuilder {
b := CreateIndex(name)
b.SetDialect(d.dialect)
@@ -3626,7 +3594,6 @@ func (d *DialectBuilder) CreateIndex(name string) *IndexBuilder {
//
// Dialect(dialect.Postgres).
// DropIndex("name")
//
func (d *DialectBuilder) DropIndex(name string) *DropIndexBuilder {
b := DropIndex(name)
b.SetDialect(d.dialect)