mirror of
https://github.com/ent/ent.git
synced 2026-05-05 00:50:54 +03:00
dialect/sql: Add context.Context to sql.Selector (#1185)
This commit is contained in:
@@ -13,6 +13,7 @@ package sql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"strconv"
|
||||
@@ -1516,6 +1517,9 @@ func (j join) clone() join {
|
||||
// Selector is a builder for the `SELECT` statement.
|
||||
type Selector struct {
|
||||
Builder
|
||||
// ctx stores contextual data typically from
|
||||
// generated code such as alternate table schemas.
|
||||
ctx context.Context
|
||||
as string
|
||||
columns []string
|
||||
from TableView
|
||||
@@ -1531,6 +1535,24 @@ type Selector struct {
|
||||
distinct bool
|
||||
}
|
||||
|
||||
// WithContext sets the context into the *Selector.
|
||||
func (s *Selector) WithContext(ctx context.Context) *Selector {
|
||||
if ctx == nil {
|
||||
panic("nil context")
|
||||
}
|
||||
s.ctx = ctx
|
||||
return s
|
||||
}
|
||||
|
||||
// Context returns the Selector context or Background
|
||||
// if nil.
|
||||
func (s *Selector) Context() context.Context {
|
||||
if s.ctx != nil {
|
||||
return s.ctx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// Select returns a new selector for the `SELECT` statement.
|
||||
//
|
||||
// t1 := Table("users").As("u")
|
||||
@@ -1748,6 +1770,7 @@ func (s *Selector) Clone() *Selector {
|
||||
}
|
||||
return &Selector{
|
||||
Builder: s.Builder.clone(),
|
||||
ctx: s.ctx,
|
||||
as: s.as,
|
||||
or: s.or,
|
||||
not: s.not,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -1400,3 +1401,16 @@ func TestSelector_OrderByExpr(t *testing.T) {
|
||||
require.Equal(t, "SELECT * FROM `users` WHERE `age` > ? ORDER BY `name`, CASE WHEN id=? THEN id WHEN id=? THEN name END DESC", query)
|
||||
require.Equal(t, []interface{}{28, 1, 2}, args)
|
||||
}
|
||||
|
||||
func TestBuilderContext(t *testing.T) {
|
||||
type key string
|
||||
want := "myval"
|
||||
ctx := context.WithValue(context.Background(), key("mykey"), want)
|
||||
sel := Dialect(dialect.Postgres).Select().WithContext(ctx)
|
||||
if got := sel.Context().Value(key("mykey")).(string); got != want {
|
||||
t.Fatalf("expected selector context key to be %q but got %q", want, got)
|
||||
}
|
||||
if got := sel.Clone().Context().Value(key("mykey")).(string); got != want {
|
||||
t.Fatalf("expected cloned selector context key to be %q but got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user