mirror of
https://github.com/ent/ent.git
synced 2026-04-30 06:30:55 +03:00
dialect/sql/sqlgraph: pass context.Context to *sql.Selector (#1186)
* Ensure sqlgraph passes the context to *sql.Selector * Update dialect/sql/sqlgraph/graph_test.go Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com> * Update dialect/sql/sqlgraph/graph_test.go Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com> * gofmt Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com>
This commit is contained in:
@@ -271,6 +271,7 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
|
||||
Join(to).
|
||||
On(edge.C(pk1), to.C(s.To.Column))
|
||||
matches := builder.Select().From(to)
|
||||
matches.WithContext(q.Context())
|
||||
pred(matches)
|
||||
join.FromSelect(matches)
|
||||
q.Where(sql.In(from.C(s.From.Column), join))
|
||||
@@ -279,6 +280,7 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
|
||||
to := builder.Table(s.To.Table).Schema(s.To.Schema)
|
||||
matches := builder.Select(to.C(s.To.Column)).
|
||||
From(to)
|
||||
matches.WithContext(q.Context())
|
||||
pred(matches)
|
||||
q.Where(sql.In(from.C(s.Edge.Columns[0]), matches))
|
||||
case r == O2M || (r == O2O && !s.Edge.Inverse):
|
||||
@@ -286,6 +288,7 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
|
||||
to := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
|
||||
matches := builder.Select(to.C(s.Edge.Columns[0])).
|
||||
From(to)
|
||||
matches.WithContext(q.Context())
|
||||
pred(matches)
|
||||
q.Where(sql.In(from.C(s.From.Column), matches))
|
||||
}
|
||||
@@ -462,7 +465,8 @@ func DeleteNodes(ctx context.Context, drv dialect.Driver, spec *DeleteSpec) (int
|
||||
builder = sql.Dialect(drv.Dialect())
|
||||
)
|
||||
selector := builder.Select().
|
||||
From(builder.Table(spec.Node.Table).Schema(spec.Node.Schema))
|
||||
From(builder.Table(spec.Node.Table).Schema(spec.Node.Schema)).
|
||||
WithContext(ctx)
|
||||
if pred := spec.Predicate; pred != nil {
|
||||
pred(selector)
|
||||
}
|
||||
@@ -556,7 +560,7 @@ type query struct {
|
||||
|
||||
func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
|
||||
rows := &sql.Rows{}
|
||||
selector, err := q.selector()
|
||||
selector, err := q.selector(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -586,7 +590,7 @@ func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
|
||||
|
||||
func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) {
|
||||
rows := &sql.Rows{}
|
||||
selector, err := q.selector()
|
||||
selector, err := q.selector(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -603,8 +607,11 @@ func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) {
|
||||
return sql.ScanInt(rows)
|
||||
}
|
||||
|
||||
func (q *query) selector() (*sql.Selector, error) {
|
||||
selector := q.builder.Select().From(q.builder.Table(q.Node.Table).Schema(q.Node.Schema))
|
||||
func (q *query) selector(ctx context.Context) (*sql.Selector, error) {
|
||||
selector := q.builder.
|
||||
Select().
|
||||
From(q.builder.Table(q.Node.Table).Schema(q.Node.Schema)).
|
||||
WithContext(ctx)
|
||||
if q.From != nil {
|
||||
selector = q.From
|
||||
}
|
||||
@@ -678,7 +685,8 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error
|
||||
multiple = u.hasExternalEdges(addEdges, clearEdges)
|
||||
update = u.builder.Update(u.Node.Table).Schema(u.Node.Schema)
|
||||
selector = u.builder.Select(u.Node.ID.Column).
|
||||
From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema))
|
||||
From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).
|
||||
WithContext(ctx)
|
||||
)
|
||||
if pred := u.Predicate; pred != nil {
|
||||
pred(selector)
|
||||
|
||||
Reference in New Issue
Block a user