dialect/sql/sqlgraph: pass context.Context to *sql.Selector (#1186)

* Ensure sqlgraph passes the context to *sql.Selector

* Update dialect/sql/sqlgraph/graph_test.go

Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com>

* Update dialect/sql/sqlgraph/graph_test.go

Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com>

* gofmt

Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com>
This commit is contained in:
Marwan Sulaiman
2021-01-18 12:41:59 -05:00
committed by GitHub
parent ddb25280cd
commit 3439ca207f
2 changed files with 41 additions and 6 deletions

View File

@@ -12,6 +12,7 @@ import (
"strings"
"testing"
"github.com/facebook/ent/dialect"
"github.com/facebook/ent/dialect/sql"
"github.com/facebook/ent/schema/field"
@@ -805,6 +806,32 @@ WHERE "s1"."users"."id" IN
}
}
func TestHasNeighborsWithContext(t *testing.T) {
type key string
ctx := context.WithValue(context.Background(), key("mykey"), "myval")
for _, rel := range [...]Rel{M2M, O2M, O2O} {
t.Run(rel.String(), func(t *testing.T) {
sel := sql.Dialect(dialect.Postgres).
Select("*").
From(sql.Table("users")).
WithContext(ctx)
step := NewStep(
From("users", "id"),
To("groups", "id"),
Edge(rel, false, "user_groups", "user_id", "group_id"),
)
var called bool
pred := func(s *sql.Selector) {
called = true
got := s.Context().Value(key("mykey")).(string)
require.Equal(t, "myval", got)
}
HasNeighborsWith(sel, step, pred)
require.True(t, called, "expected predicate function to be called")
})
}
}
func TestCreateNode(t *testing.T) {
tests := []struct {
name string