dialect/sql/sqlgraph: pass context.Context to *sql.Selector (#1186)

* Ensure sqlgraph passes the context to *sql.Selector

* Update dialect/sql/sqlgraph/graph_test.go

Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com>

* Update dialect/sql/sqlgraph/graph_test.go

Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com>

* gofmt

Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com>
This commit is contained in:
Marwan Sulaiman
2021-01-18 12:41:59 -05:00
committed by GitHub
parent ddb25280cd
commit 3439ca207f
2 changed files with 41 additions and 6 deletions

View File

@@ -271,6 +271,7 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
Join(to).
On(edge.C(pk1), to.C(s.To.Column))
matches := builder.Select().From(to)
matches.WithContext(q.Context())
pred(matches)
join.FromSelect(matches)
q.Where(sql.In(from.C(s.From.Column), join))
@@ -279,6 +280,7 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
to := builder.Table(s.To.Table).Schema(s.To.Schema)
matches := builder.Select(to.C(s.To.Column)).
From(to)
matches.WithContext(q.Context())
pred(matches)
q.Where(sql.In(from.C(s.Edge.Columns[0]), matches))
case r == O2M || (r == O2O && !s.Edge.Inverse):
@@ -286,6 +288,7 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
to := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
matches := builder.Select(to.C(s.Edge.Columns[0])).
From(to)
matches.WithContext(q.Context())
pred(matches)
q.Where(sql.In(from.C(s.From.Column), matches))
}
@@ -462,7 +465,8 @@ func DeleteNodes(ctx context.Context, drv dialect.Driver, spec *DeleteSpec) (int
builder = sql.Dialect(drv.Dialect())
)
selector := builder.Select().
From(builder.Table(spec.Node.Table).Schema(spec.Node.Schema))
From(builder.Table(spec.Node.Table).Schema(spec.Node.Schema)).
WithContext(ctx)
if pred := spec.Predicate; pred != nil {
pred(selector)
}
@@ -556,7 +560,7 @@ type query struct {
func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
rows := &sql.Rows{}
selector, err := q.selector()
selector, err := q.selector(ctx)
if err != nil {
return err
}
@@ -586,7 +590,7 @@ func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) {
rows := &sql.Rows{}
selector, err := q.selector()
selector, err := q.selector(ctx)
if err != nil {
return 0, err
}
@@ -603,8 +607,11 @@ func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) {
return sql.ScanInt(rows)
}
func (q *query) selector() (*sql.Selector, error) {
selector := q.builder.Select().From(q.builder.Table(q.Node.Table).Schema(q.Node.Schema))
func (q *query) selector(ctx context.Context) (*sql.Selector, error) {
selector := q.builder.
Select().
From(q.builder.Table(q.Node.Table).Schema(q.Node.Schema)).
WithContext(ctx)
if q.From != nil {
selector = q.From
}
@@ -678,7 +685,8 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error
multiple = u.hasExternalEdges(addEdges, clearEdges)
update = u.builder.Update(u.Node.Table).Schema(u.Node.Schema)
selector = u.builder.Select(u.Node.ID.Column).
From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema))
From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).
WithContext(ctx)
)
if pred := u.Predicate; pred != nil {
pred(selector)

View File

@@ -12,6 +12,7 @@ import (
"strings"
"testing"
"github.com/facebook/ent/dialect"
"github.com/facebook/ent/dialect/sql"
"github.com/facebook/ent/schema/field"
@@ -805,6 +806,32 @@ WHERE "s1"."users"."id" IN
}
}
func TestHasNeighborsWithContext(t *testing.T) {
type key string
ctx := context.WithValue(context.Background(), key("mykey"), "myval")
for _, rel := range [...]Rel{M2M, O2M, O2O} {
t.Run(rel.String(), func(t *testing.T) {
sel := sql.Dialect(dialect.Postgres).
Select("*").
From(sql.Table("users")).
WithContext(ctx)
step := NewStep(
From("users", "id"),
To("groups", "id"),
Edge(rel, false, "user_groups", "user_id", "group_id"),
)
var called bool
pred := func(s *sql.Selector) {
called = true
got := s.Context().Value(key("mykey")).(string)
require.Equal(t, "myval", got)
}
HasNeighborsWith(sel, step, pred)
require.True(t, called, "expected predicate function to be called")
})
}
}
func TestCreateNode(t *testing.T) {
tests := []struct {
name string