entc/gen: add typed traversers and interceptors (#3182)

This commit is contained in:
Ariel Mashraki
2022-12-20 11:12:19 +02:00
committed by GitHub
parent 34bd0b7b6f
commit 92bacc10e4
4 changed files with 85 additions and 31 deletions

View File

@@ -73,7 +73,7 @@ func (f TraverseFunc) Traverse(ctx context.Context, q {{ $pkg }}.Query) error {
}
{{- range $n := $.Nodes }}
{{ $name := print $n.Name "QueryFunc" }}
{{ $name := print $n.Name "Func" }}
{{ $type := printf "*%s.%s" $pkg $n.QueryName }}
// The {{ $name }} type is an adapter to allow the use of ordinary function as a Querier.
type {{ $name }} func(context.Context, {{ $type }}) ({{ $pkg }}.Value, error)
@@ -85,6 +85,23 @@ func (f TraverseFunc) Traverse(ctx context.Context, q {{ $pkg }}.Query) error {
}
return nil, fmt.Errorf("unexpected query type %T. expect {{ $type }}", q)
}
{{ $name = print "Traverse" $n.Name }}
// The {{ $name }} type is an adapter to allow the use of ordinary function as Traverser.
type {{ $name }} func(context.Context, {{ $type }}) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f {{ $name }}) Intercept(next {{ $pkg }}.Querier) {{ $pkg }}.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f {{ $name }}) Traverse(ctx context.Context, q {{ $pkg }}.Query) error {
if q, ok := q.({{ $type }}); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect {{ $type }}", q)
}
{{- end }}
// NewQuery returns the generic Query interface for the given typed query.

View File

@@ -71,39 +71,87 @@ func (f TraverseFunc) Traverse(ctx context.Context, q ent.Query) error {
return f(ctx, query)
}
// The CardQueryFunc type is an adapter to allow the use of ordinary function as a Querier.
type CardQueryFunc func(context.Context, *ent.CardQuery) (ent.Value, error)
// The CardFunc type is an adapter to allow the use of ordinary function as a Querier.
type CardFunc func(context.Context, *ent.CardQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f CardQueryFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
func (f CardFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.CardQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.CardQuery", q)
}
// The PetQueryFunc type is an adapter to allow the use of ordinary function as a Querier.
type PetQueryFunc func(context.Context, *ent.PetQuery) (ent.Value, error)
// The TraverseCard type is an adapter to allow the use of ordinary function as Traverser.
type TraverseCard func(context.Context, *ent.CardQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseCard) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseCard) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.CardQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.CardQuery", q)
}
// The PetFunc type is an adapter to allow the use of ordinary function as a Querier.
type PetFunc func(context.Context, *ent.PetQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f PetQueryFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
func (f PetFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.PetQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.PetQuery", q)
}
// The UserQueryFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserQueryFunc func(context.Context, *ent.UserQuery) (ent.Value, error)
// The TraversePet type is an adapter to allow the use of ordinary function as Traverser.
type TraversePet func(context.Context, *ent.PetQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraversePet) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraversePet) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.PetQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.PetQuery", q)
}
// The UserFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserFunc func(context.Context, *ent.UserQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f UserQueryFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
func (f UserFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.UserQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserQuery", q)
}
// The TraverseUser type is an adapter to allow the use of ordinary function as Traverser.
type TraverseUser func(context.Context, *ent.UserQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseUser) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseUser) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.UserQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.UserQuery", q)
}
// NewQuery returns the generic Query interface for the given typed query.
func NewQuery(q ent.Query) (Query, error) {
switch q := q.(type) {

View File

@@ -11,10 +11,8 @@ import (
"testing"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/entc/integration/hooks/ent"
"entgo.io/ent/entc/integration/hooks/ent/card"
"entgo.io/ent/entc/integration/hooks/ent/enttest"
@@ -494,9 +492,7 @@ func TestInterceptor_Sanity(t *testing.T) {
defer client.Close()
client.Intercept(
ent.InterceptFunc(func(next ent.Querier) ent.Querier {
return ent.QuerierFunc(func(ctx context.Context, query ent.Query) (ent.Value, error) {
_, err := intercept.NewQuery(query)
require.NoError(t, err)
return intercept.UserFunc(func(ctx context.Context, query *ent.UserQuery) (ent.Value, error) {
calls++
nodes, err := next.Query(ctx, query)
require.NoError(t, err)
@@ -739,13 +735,12 @@ func TestTypedTraverser(t *testing.T) {
// Add an interceptor that filters out inactive users.
client.User.Intercept(
ent.TraverseFunc(func(ctx context.Context, query ent.Query) error {
if q, ok := query.(*ent.UserQuery); ok {
q.Where(user.Active(true))
}
intercept.TraverseUser(func(ctx context.Context, q *ent.UserQuery) error {
q.Where(user.Active(true))
return nil
}),
)
// Only pets of active users are returned.
if n := client.User.Query().QueryPets().CountX(ctx); n != 2 {
t.Errorf("got %d pets, want 2", n)
@@ -786,9 +781,7 @@ func TestFilterTraverseFunc(t *testing.T) {
// Add an interceptor that filters out inactive users.
client.User.Intercept(
intercept.TraverseFunc(func(ctx context.Context, query intercept.Query) error {
query.WhereP(func(s *sql.Selector) {
s.Where(sql.EQ("active", true))
})
query.WhereP(sql.FieldEQ("active", true))
return nil
}),
)