mirror of
https://github.com/ent/ent.git
synced 2026-04-28 21:50:56 +03:00
dialect/sql/sqlgraph: pass context.Context to *sql.Selector (#1186)
* Ensure sqlgraph passes the context to *sql.Selector * Update dialect/sql/sqlgraph/graph_test.go Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com> * Update dialect/sql/sqlgraph/graph_test.go Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com> * gofmt Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com>
This commit is contained in:
@@ -271,6 +271,7 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
|
||||
Join(to).
|
||||
On(edge.C(pk1), to.C(s.To.Column))
|
||||
matches := builder.Select().From(to)
|
||||
matches.WithContext(q.Context())
|
||||
pred(matches)
|
||||
join.FromSelect(matches)
|
||||
q.Where(sql.In(from.C(s.From.Column), join))
|
||||
@@ -279,6 +280,7 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
|
||||
to := builder.Table(s.To.Table).Schema(s.To.Schema)
|
||||
matches := builder.Select(to.C(s.To.Column)).
|
||||
From(to)
|
||||
matches.WithContext(q.Context())
|
||||
pred(matches)
|
||||
q.Where(sql.In(from.C(s.Edge.Columns[0]), matches))
|
||||
case r == O2M || (r == O2O && !s.Edge.Inverse):
|
||||
@@ -286,6 +288,7 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
|
||||
to := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
|
||||
matches := builder.Select(to.C(s.Edge.Columns[0])).
|
||||
From(to)
|
||||
matches.WithContext(q.Context())
|
||||
pred(matches)
|
||||
q.Where(sql.In(from.C(s.From.Column), matches))
|
||||
}
|
||||
@@ -462,7 +465,8 @@ func DeleteNodes(ctx context.Context, drv dialect.Driver, spec *DeleteSpec) (int
|
||||
builder = sql.Dialect(drv.Dialect())
|
||||
)
|
||||
selector := builder.Select().
|
||||
From(builder.Table(spec.Node.Table).Schema(spec.Node.Schema))
|
||||
From(builder.Table(spec.Node.Table).Schema(spec.Node.Schema)).
|
||||
WithContext(ctx)
|
||||
if pred := spec.Predicate; pred != nil {
|
||||
pred(selector)
|
||||
}
|
||||
@@ -556,7 +560,7 @@ type query struct {
|
||||
|
||||
func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
|
||||
rows := &sql.Rows{}
|
||||
selector, err := q.selector()
|
||||
selector, err := q.selector(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -586,7 +590,7 @@ func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
|
||||
|
||||
func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) {
|
||||
rows := &sql.Rows{}
|
||||
selector, err := q.selector()
|
||||
selector, err := q.selector(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -603,8 +607,11 @@ func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) {
|
||||
return sql.ScanInt(rows)
|
||||
}
|
||||
|
||||
func (q *query) selector() (*sql.Selector, error) {
|
||||
selector := q.builder.Select().From(q.builder.Table(q.Node.Table).Schema(q.Node.Schema))
|
||||
func (q *query) selector(ctx context.Context) (*sql.Selector, error) {
|
||||
selector := q.builder.
|
||||
Select().
|
||||
From(q.builder.Table(q.Node.Table).Schema(q.Node.Schema)).
|
||||
WithContext(ctx)
|
||||
if q.From != nil {
|
||||
selector = q.From
|
||||
}
|
||||
@@ -678,7 +685,8 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error
|
||||
multiple = u.hasExternalEdges(addEdges, clearEdges)
|
||||
update = u.builder.Update(u.Node.Table).Schema(u.Node.Schema)
|
||||
selector = u.builder.Select(u.Node.ID.Column).
|
||||
From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema))
|
||||
From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).
|
||||
WithContext(ctx)
|
||||
)
|
||||
if pred := u.Predicate; pred != nil {
|
||||
pred(selector)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/facebook/ent/dialect"
|
||||
"github.com/facebook/ent/dialect/sql"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
|
||||
@@ -805,6 +806,32 @@ WHERE "s1"."users"."id" IN
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasNeighborsWithContext(t *testing.T) {
|
||||
type key string
|
||||
ctx := context.WithValue(context.Background(), key("mykey"), "myval")
|
||||
for _, rel := range [...]Rel{M2M, O2M, O2O} {
|
||||
t.Run(rel.String(), func(t *testing.T) {
|
||||
sel := sql.Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(sql.Table("users")).
|
||||
WithContext(ctx)
|
||||
step := NewStep(
|
||||
From("users", "id"),
|
||||
To("groups", "id"),
|
||||
Edge(rel, false, "user_groups", "user_id", "group_id"),
|
||||
)
|
||||
var called bool
|
||||
pred := func(s *sql.Selector) {
|
||||
called = true
|
||||
got := s.Context().Value(key("mykey")).(string)
|
||||
require.Equal(t, "myval", got)
|
||||
}
|
||||
HasNeighborsWith(sel, step, pred)
|
||||
require.True(t, called, "expected predicate function to be called")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateNode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
Reference in New Issue
Block a user