mirror of
https://github.com/ent/ent.git
synced 2026-04-28 13:40:56 +03:00
dialect/sql/sqlgraph: pass context.Context to *sql.Selector (#1186)
* Ensure sqlgraph passes the context to *sql.Selector * Update dialect/sql/sqlgraph/graph_test.go Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com> * Update dialect/sql/sqlgraph/graph_test.go Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com> * gofmt Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com>
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/facebook/ent/dialect"
|
||||
"github.com/facebook/ent/dialect/sql"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
|
||||
@@ -805,6 +806,32 @@ WHERE "s1"."users"."id" IN
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasNeighborsWithContext(t *testing.T) {
|
||||
type key string
|
||||
ctx := context.WithValue(context.Background(), key("mykey"), "myval")
|
||||
for _, rel := range [...]Rel{M2M, O2M, O2O} {
|
||||
t.Run(rel.String(), func(t *testing.T) {
|
||||
sel := sql.Dialect(dialect.Postgres).
|
||||
Select("*").
|
||||
From(sql.Table("users")).
|
||||
WithContext(ctx)
|
||||
step := NewStep(
|
||||
From("users", "id"),
|
||||
To("groups", "id"),
|
||||
Edge(rel, false, "user_groups", "user_id", "group_id"),
|
||||
)
|
||||
var called bool
|
||||
pred := func(s *sql.Selector) {
|
||||
called = true
|
||||
got := s.Context().Value(key("mykey")).(string)
|
||||
require.Equal(t, "myval", got)
|
||||
}
|
||||
HasNeighborsWith(sel, step, pred)
|
||||
require.True(t, called, "expected predicate function to be called")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateNode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
Reference in New Issue
Block a user