diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 3faa9cc61..9d2ec605f 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1782,6 +1782,7 @@ type Selector struct { distinct bool union []union prefix Queries + lock *LockConfig } // WithContext sets the context into the *Selector. @@ -2069,6 +2070,100 @@ func (s *Selector) Count(columns ...string) *Selector { return s } +// LockAction tells the transaction what to do in case of +// requesting a row that is locked by other transaction. +type LockAction string + +const ( + // NoWait means never wait and returns an error. + NoWait LockAction = "NOWAIT" + // SkipLocked means never wait and skip. + SkipLocked LockAction = "SKIP LOCKED" +) + +// LockStrength defines the strength of the lock (see the list below). +type LockStrength string + +// A list of all locking clauses. +const ( + LockShare LockStrength = "SHARE" + LockUpdate LockStrength = "UPDATE" + LockNoKeyUpdate LockStrength = "NO KEY UPDATE" + LockKeyShare LockStrength = "KEY SHARE" +) + +type ( + // LockConfig defines a SELECT statement + // lock for protecting concurrent updates. + LockConfig struct { + // Strength of the lock. + Strength LockStrength + // Action of the lock. + Action LockAction + // Tables are an option tables. + Tables []string + // custom clause for locking. + clause string + } + // LockOption allows configuring the LockConfig using functional options. + LockOption func(*LockConfig) +) + +// WithLockAction sets the Action of the lock. +func WithLockAction(action LockAction) LockOption { + return func(c *LockConfig) { + c.Action = action + } +} + +// WithLockTables sets the Tables of the lock. +func WithLockTables(tables ...*SelectTable) LockOption { + return func(c *LockConfig) { + names := make([]string, len(tables)) + for i := range tables { + names[i] = tables[i].name + } + c.Tables = names + } +} + +// WithLockClause allows providing a custom clause for +// locking the statement. For example, in MySQL <= 8.22: +// +// Select(). +// From(Table("users")). +// ForShare( +// WithLockClause("LOCK IN SHARE MODE"), +// ) +// +func WithLockClause(clause string) LockOption { + return func(c *LockConfig) { + c.clause = clause + } +} + +// For sets the lock configuration for suffixing the `SELECT` +// statement with the `FOR [SHARE | UPDATE] ...` clause. +func (s *Selector) For(l LockStrength, opts ...LockOption) *Selector { + s.lock = &LockConfig{Strength: l} + for _, opt := range opts { + opt(s.lock) + } + return s +} + +// ForShare sets the lock configuration for suffixing the +// `SELECT` statement with the `FOR SHARE` clause. +func (s *Selector) ForShare(opts ...LockOption) *Selector { + return s.For(LockShare, opts...) +} + +// LockUpdate sets the lock configuration for suffixing the +// `SELECT` statement with the `FOR UPDATE` clause. +func (s *Selector) ForUpdate(opts ...LockOption) *Selector { + return s.For(LockUpdate, opts...) +} + // Clone returns a duplicate of the selector, including all associated steps. It can be // used to prepare common SELECT statements and use them differently after the clone is made. func (s *Selector) Clone() *Selector { @@ -2218,6 +2313,7 @@ func (s *Selector) Query() (string, []interface{}) { if len(s.union) > 0 { s.joinUnion(&b) } + s.joinLock(&b) s.total = b.total s.AddError(b.Err()) return b.String(), b.args @@ -2226,7 +2322,25 @@ func (s *Selector) Query() (string, []interface{}) { func (s *Selector) joinPrefix(b *Builder) { if len(s.prefix) > 0 { b.join(s.prefix, " ") - b.WriteByte(' ') + b.Pad() + } +} + +func (s *Selector) joinLock(b *Builder) { + if s.lock == nil { + return + } + b.Pad() + if s.lock.clause != "" { + b.WriteString(s.lock.clause) + return + } + b.WriteString("FOR ").WriteString(string(s.lock.Strength)) + if len(s.lock.Tables) > 0 { + b.WriteString(" OF ").IdentComma(s.lock.Tables...) + } + if s.lock.Action != "" { + b.Pad().WriteString(string(s.lock.Action)) } } diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 8bce0f9eb..211d91693 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1641,3 +1641,47 @@ func TestParamFormatter(t *testing.T) { require.Equal(t, "SELECT * FROM `users` WHERE `point` = ST_GeomFromWKB(?)", query) require.Equal(t, p, args[0]) } + +func TestSelectWithLock(t *testing.T) { + query, args := Dialect(dialect.MySQL). + Select(). + From(Table("users")). + Where(EQ("id", 1)). + ForUpdate(). + Query() + require.Equal(t, "SELECT * FROM `users` WHERE `id` = ? FOR UPDATE", query) + require.Equal(t, 1, args[0]) + + query, args = Dialect(dialect.Postgres). + Select(). + From(Table("users")). + Where(EQ("id", 1)). + ForUpdate(WithLockAction(NoWait)). + Query() + require.Equal(t, `SELECT * FROM "users" WHERE "id" = $1 FOR UPDATE NOWAIT`, query) + require.Equal(t, 1, args[0]) + + users, pets := Table("users"), Table("pets") + query, args = Dialect(dialect.Postgres). + Select(). + From(pets). + Join(users). + On(pets.C("owner_id"), users.C("id")). + Where(EQ("id", 20)). + ForUpdate( + WithLockAction(SkipLocked), + WithLockTables(pets), + ). + Query() + require.Equal(t, `SELECT * FROM "pets" JOIN "users" AS "t1" ON "pets"."owner_id" = "t1"."id" WHERE "id" = $1 FOR UPDATE OF "pets" SKIP LOCKED`, query) + require.Equal(t, 20, args[0]) + + query, args = Dialect(dialect.MySQL). + Select(). + From(Table("users")). + Where(EQ("id", 20)). + ForShare(WithLockClause("LOCK IN SHARE MODE")). + Query() + require.Equal(t, "SELECT * FROM `users` WHERE `id` = ? LOCK IN SHARE MODE", query) + require.Equal(t, 20, args[0]) +}