dialect/sql: Add context.Context to sql.Selector (#1185)

This commit is contained in:
Marwan Sulaiman
2021-01-18 10:32:42 -05:00
committed by GitHub
parent 0ac3526d30
commit ddb25280cd
2 changed files with 37 additions and 0 deletions

View File

@@ -13,6 +13,7 @@ package sql
import (
"bytes"
"context"
"database/sql/driver"
"fmt"
"strconv"
@@ -1516,6 +1517,9 @@ func (j join) clone() join {
// Selector is a builder for the `SELECT` statement.
type Selector struct {
Builder
// ctx stores contextual data typically from
// generated code such as alternate table schemas.
ctx context.Context
as string
columns []string
from TableView
@@ -1531,6 +1535,24 @@ type Selector struct {
distinct bool
}
// WithContext sets the context into the *Selector.
func (s *Selector) WithContext(ctx context.Context) *Selector {
if ctx == nil {
panic("nil context")
}
s.ctx = ctx
return s
}
// Context returns the Selector context or Background
// if nil.
func (s *Selector) Context() context.Context {
if s.ctx != nil {
return s.ctx
}
return context.Background()
}
// Select returns a new selector for the `SELECT` statement.
//
// t1 := Table("users").As("u")
@@ -1748,6 +1770,7 @@ func (s *Selector) Clone() *Selector {
}
return &Selector{
Builder: s.Builder.clone(),
ctx: s.ctx,
as: s.as,
or: s.or,
not: s.not,

View File

@@ -5,6 +5,7 @@
package sql
import (
"context"
"fmt"
"strconv"
"strings"
@@ -1400,3 +1401,16 @@ func TestSelector_OrderByExpr(t *testing.T) {
require.Equal(t, "SELECT * FROM `users` WHERE `age` > ? ORDER BY `name`, CASE WHEN id=? THEN id WHEN id=? THEN name END DESC", query)
require.Equal(t, []interface{}{28, 1, 2}, args)
}
func TestBuilderContext(t *testing.T) {
type key string
want := "myval"
ctx := context.WithValue(context.Background(), key("mykey"), want)
sel := Dialect(dialect.Postgres).Select().WithContext(ctx)
if got := sel.Context().Value(key("mykey")).(string); got != want {
t.Fatalf("expected selector context key to be %q but got %q", want, got)
}
if got := sel.Clone().Context().Value(key("mykey")).(string); got != want {
t.Fatalf("expected cloned selector context key to be %q but got %q", want, got)
}
}