entql: add typed-builder for field predicates

This commit is contained in:
Ariel Mashraki
2020-09-17 08:13:50 +03:00
committed by Ariel Mashraki
parent 16e804a788
commit fae1956828
26 changed files with 4883 additions and 166 deletions

View File

@@ -1,3 +1,7 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package sqlgraph
import (
@@ -8,12 +12,14 @@ import (
)
type (
// A Graph holds multiple ent/schemas and their relations in the graph.
// A Schema holds a representation of ent/schema at runtime. Each Node
// represents a single schema-type and its relations in the graph (storage).
//
// It is used for translating common graph traversal operations to the
// underlying SQL storage. For example, an operation like `has_edge(E)`,
// will be translated to an SQL lookup based on the relation type and the
// FK configuration.
Graph struct {
Schema struct {
Nodes []*Node
}
@@ -21,7 +27,7 @@ type (
Node struct {
NodeSpec
// Type or label holds the
// Type holds the node type (schema name).
Type string
// Fields maps from field names to their spec.
@@ -41,7 +47,7 @@ type (
// g.AddE("pets", spec, "user", "pet")
// g.AddE("friends", spec, "user", "user")
//
func (g *Graph) AddE(name string, spec *EdgeSpec, from, to string) error {
func (g *Schema) AddE(name string, spec *EdgeSpec, from, to string) error {
var fromT, toT *Node
for i := range g.Nodes {
t := g.Nodes[i].Type
@@ -71,8 +77,15 @@ func (g *Graph) AddE(name string, spec *EdgeSpec, from, to string) error {
return nil
}
// EvalP evaluates the entql predicate on the query builder.
func (g *Graph) EvalP(nodeType string, p entql.P, selector *sql.Selector) error {
// MustAddE is like AddE but panics if the edge can be added to the graph.
func (g *Schema) MustAddE(name string, spec *EdgeSpec, from, to string) {
if err := g.AddE(name, spec, from, to); err != nil {
panic(err)
}
}
// EvalP evaluates the entql predicate on the given selector (query builder).
func (g *Schema) EvalP(nodeType string, p entql.P, selector *sql.Selector) error {
var node *Node
for i := range g.Nodes {
if g.Nodes[i].Type == nodeType {
@@ -81,9 +94,9 @@ func (g *Graph) EvalP(nodeType string, p entql.P, selector *sql.Selector) error
}
}
if node == nil {
return fmt.Errorf("node %s was not found in the graph", nodeType)
return fmt.Errorf("node %s was not found in the graph schema", nodeType)
}
pr, err := execExpr(node, selector, p)
pr, err := evalExpr(node, selector, p)
if err != nil {
return err
}
@@ -91,6 +104,23 @@ func (g *Graph) EvalP(nodeType string, p entql.P, selector *sql.Selector) error
return nil
}
// FuncSelector represents a selector function to be used as an entql foreign-function.
const FuncSelector entql.Func = "func_selector"
// wrappedFunc wraps the selector-function to an ent-expression.
type wrappedFunc struct {
entql.Expr
Func func(*sql.Selector)
}
// WrapFunc wraps a selector-func with an entql call expression.
func WrapFunc(s func(*sql.Selector)) *entql.CallExpr {
return &entql.CallExpr{
Func: FuncSelector,
Args: []entql.Expr{wrappedFunc{Func: s}},
}
}
var (
binary = [...]sql.Op{
entql.OpEQ: sql.OpEQ,
@@ -106,23 +136,31 @@ var (
entql.OpAnd: sql.And,
entql.OpOr: sql.Or,
}
strFunc = [...]func(string, string) *sql.Predicate{
entql.OpContains: sql.Contains,
entql.OpContainsFold: sql.ContainsFold,
entql.OpEqualFold: sql.EqualFold,
entql.OpHasPrefix: sql.HasPrefix,
entql.OpHasSuffix: sql.HasSuffix,
strFunc = map[entql.Func]func(string, string) *sql.Predicate{
entql.FuncContains: sql.Contains,
entql.FuncContainsFold: sql.ContainsFold,
entql.FuncEqualFold: sql.EqualFold,
entql.FuncHasPrefix: sql.HasPrefix,
entql.FuncHasSuffix: sql.HasSuffix,
}
nullFunc = [...]func(string) *sql.Predicate{
entql.OpEQ: sql.IsNull,
entql.OpNEQ: sql.NotNull,
}
)
type exec struct {
// state represents the state of a predicate evaluation.
// Note that, the evaluation output is a predicate to be
// applied on the database.
type state struct {
sql.Builder
context *Node
selector *sql.Selector
}
func execExpr(context *Node, selector *sql.Selector, expr entql.Expr) (p *sql.Predicate, err error) {
ex := &exec{
// evalExpr evaluates the entql expression and returns a new SQL predicate to be applied on the database.
func evalExpr(context *Node, selector *sql.Selector, expr entql.Expr) (p *sql.Predicate, err error) {
ex := &state{
context: context,
selector: selector,
}
@@ -131,7 +169,8 @@ func execExpr(context *Node, selector *sql.Selector, expr entql.Expr) (p *sql.Pr
return
}
func (e *exec) evalExpr(expr entql.Expr) *sql.Predicate {
// evalExpr evaluates any expression.
func (e *state) evalExpr(expr entql.Expr) *sql.Predicate {
switch expr := expr.(type) {
case *entql.BinaryExpr:
return e.evalBinary(expr)
@@ -144,18 +183,18 @@ func (e *exec) evalExpr(expr entql.Expr) *sql.Predicate {
}
return nary[expr.Op](ps...)
case *entql.CallExpr:
switch expr.Op {
case entql.OpHasPrefix, entql.OpHasSuffix, entql.OpContains, entql.OpEqualFold, entql.OpContainsFold:
expect(len(expr.Args) == 2, "invalid number of arguments for %s", expr.Op)
switch expr.Func {
case entql.FuncHasPrefix, entql.FuncHasSuffix, entql.FuncContains, entql.FuncEqualFold, entql.FuncContainsFold:
expect(len(expr.Args) == 2, "invalid number of arguments for %s", expr.Func)
f, ok := expr.Args[0].(*entql.Field)
expect(ok, "*entql.Field, got %T", expr.Args[0])
v, ok := expr.Args[1].(*entql.Value)
expect(ok, "*entql.Value, got %T", expr.Args[1])
s, ok := v.V.(string)
expect(ok, "string value, got %T", v.V)
return strFunc[expr.Op](e.field(f), s)
case entql.OpHasEdge:
expect(len(expr.Args) > 0, "invalid number of arguments for %s", expr.Op)
return strFunc[expr.Func](e.field(f), s)
case entql.FuncHasEdge:
expect(len(expr.Args) > 0, "invalid number of arguments for %s", expr.Func)
edge, ok := expr.Args[0].(*entql.Edge)
expect(ok, "*entql.Edge, got %T", expr.Args[0])
return e.evalEdge(edge.Name, expr.Args[1:]...)
@@ -164,20 +203,20 @@ func (e *exec) evalExpr(expr entql.Expr) *sql.Predicate {
panic("invalid")
}
func (e *exec) evalBinary(expr *entql.BinaryExpr) *sql.Predicate {
if (expr.Op == entql.OpEQ || expr.Op == entql.OpNEQ) && expr.Y == (*entql.Value)(nil) {
f, ok := expr.X.(*entql.Field)
expect(ok, "*entql.Field, got %T", expr.Y)
if expr.Op == entql.OpEQ {
return sql.IsNull(e.field(f))
}
return sql.NotNull(e.field(f))
}
// evalBinary evaluates binary expressions.
func (e *state) evalBinary(expr *entql.BinaryExpr) *sql.Predicate {
switch expr.Op {
case entql.OpOr:
return sql.Or(e.evalExpr(expr.X), e.evalExpr(expr.Y))
case entql.OpAnd:
return sql.And(e.evalExpr(expr.X), e.evalExpr(expr.Y))
case entql.OpEQ, entql.OpNEQ:
if expr.Y == (*entql.Value)(nil) {
f, ok := expr.X.(*entql.Field)
expect(ok, "*entql.Field, got %T", expr.Y)
return nullFunc[expr.Op](e.field(f))
}
fallthrough
default:
field, ok := expr.X.(*entql.Field)
expect(ok, "expr.X to be *entql.Field (got %T)", expr.X)
@@ -199,13 +238,8 @@ func (e *exec) evalBinary(expr *entql.BinaryExpr) *sql.Predicate {
}
}
func (e *exec) field(f *entql.Field) string {
_, ok := e.context.Fields[f.Name]
expect(ok || e.context.ID.Column == f.Name, "field %q was not found for node %q", f.Name, e.context.Type)
return f.Name
}
func (e *exec) evalEdge(name string, exprs ...entql.Expr) *sql.Predicate {
// evalEdge evaluates has-edge and has-edge-with calls.
func (e *state) evalEdge(name string, exprs ...entql.Expr) *sql.Predicate {
edge, ok := e.context.Edges[name]
expect(ok, "edge %q was not found for node %q", name, e.context.Type)
step := NewStep(
@@ -215,20 +249,33 @@ func (e *exec) evalEdge(name string, exprs ...entql.Expr) *sql.Predicate {
)
selector := e.selector.Clone().SetP(nil)
selector.SetTotal(e.Total())
if len(exprs) > 0 {
HasNeighborsWith(selector, step, func(s *sql.Selector) {
for i := range exprs {
p, err := execExpr(edge.To, s, exprs[i])
if len(exprs) == 0 {
HasNeighbors(selector, step)
return selector.P()
}
HasNeighborsWith(selector, step, func(s *sql.Selector) {
for _, expr := range exprs {
if cx, ok := expr.(*entql.CallExpr); ok && cx.Func == FuncSelector {
expect(len(cx.Args) == 1, "invalid number of arguments for %s", FuncSelector)
wrapped, ok := cx.Args[0].(wrappedFunc)
expect(ok, "invalid argument for %s: %T", FuncSelector, cx.Args[0])
wrapped.Func(s)
} else {
p, err := evalExpr(edge.To, s, expr)
expect(err == nil, "edge evaluation failed for %s->%s: %s", e.context.Type, name, err)
s.Where(p)
}
})
} else {
HasNeighbors(selector, step)
}
}
})
return selector.P()
}
func (e *state) field(f *entql.Field) string {
_, ok := e.context.Fields[f.Name]
expect(ok || e.context.ID.Column == f.Name, "field %q was not found for node %q", f.Name, e.context.Type)
return f.Name
}
func args(b *sql.Builder, v *entql.Value) {
vs, ok := v.V.([]interface{})
if !ok {
@@ -241,19 +288,21 @@ func args(b *sql.Builder, v *entql.Value) {
// expect panics if the condition is false.
func expect(cond bool, msg string, args ...interface{}) {
if !cond {
panic(execError{fmt.Sprintf("expect "+msg, args...)})
panic(evalError{fmt.Sprintf("expect "+msg, args...)})
}
}
type execError struct {
type evalError struct {
msg string
}
func (p execError) Error() string { return fmt.Sprintf("sqlgraph: %s", p.msg) }
func (p evalError) Error() string {
return fmt.Sprintf("sqlgraph: %s", p.msg)
}
func catch(err *error) {
if e := recover(); e != nil {
xerr, ok := e.(execError)
xerr, ok := e.(evalError)
if !ok {
panic(e)
}

View File

@@ -1,3 +1,7 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package sqlgraph
import (
@@ -14,7 +18,7 @@ import (
)
func TestGraph_AddE(t *testing.T) {
g := &Graph{
g := &Schema{
Nodes: []*Node{{Type: "user"}, {Type: "pet"}},
}
err := g.AddE("pets", &EdgeSpec{Rel: O2M}, "user", "pet")
@@ -26,7 +30,7 @@ func TestGraph_AddE(t *testing.T) {
}
func TestGraph_EvalP(t *testing.T) {
g := &Graph{
g := &Schema{
Nodes: []*Node{
{
Type: "user",
@@ -107,6 +111,11 @@ func TestGraph_EvalP(t *testing.T) {
p: entql.EQ(entql.F("name"), entql.F("last")),
wantQuery: `SELECT * FROM "users" WHERE "name" = "last"`,
},
{
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")),
p: entql.And(entql.FieldNil("name"), entql.FieldNotNil("last")),
wantQuery: `SELECT * FROM "users" WHERE "name" IS NULL AND "last" IS NOT NULL`,
},
{
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).
Where(sql.EQ("foo", "bar")),
@@ -142,6 +151,14 @@ func TestGraph_EvalP(t *testing.T) {
wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND ("users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL) AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups") AND "name" = "uid")`,
wantArgs: []interface{}{true},
},
{
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)),
p: entql.HasEdgeWith("pets", entql.FieldEQ("name", "pedro"), WrapFunc(func(s *sql.Selector) {
s.Where(sql.EQ("owner_id", 10))
})),
wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $2 AND "owner_id" = $3)`,
wantArgs: []interface{}{true, "pedro", 10},
},
}
for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {