mirror of
https://github.com/ent/ent.git
synced 2026-04-28 21:50:56 +03:00
entql: add experimental implementation for entql
This commit is contained in:
committed by
Ariel Mashraki
parent
dfc4dee8a5
commit
16e804a788
262
dialect/sql/sqlgraph/entql.go
Normal file
262
dialect/sql/sqlgraph/entql.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package sqlgraph
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/facebook/ent/dialect/sql"
|
||||
"github.com/facebook/ent/entql"
|
||||
)
|
||||
|
||||
type (
|
||||
// A Graph holds multiple ent/schemas and their relations in the graph.
|
||||
// It is used for translating common graph traversal operations to the
|
||||
// underlying SQL storage. For example, an operation like `has_edge(E)`,
|
||||
// will be translated to an SQL lookup based on the relation type and the
|
||||
// FK configuration.
|
||||
Graph struct {
|
||||
Nodes []*Node
|
||||
}
|
||||
|
||||
// A Node in the graph holds the SQL information for an ent/schema.
|
||||
Node struct {
|
||||
NodeSpec
|
||||
|
||||
// Type or label holds the
|
||||
Type string
|
||||
|
||||
// Fields maps from field names to their spec.
|
||||
Fields map[string]*FieldSpec
|
||||
|
||||
// Edges maps from edge names to their spec.
|
||||
Edges map[string]struct {
|
||||
To *Node
|
||||
Spec *EdgeSpec
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
// AddEdge adds an edge to the graph. It fails, if one of the node
|
||||
// types is missing.
|
||||
//
|
||||
// g.AddE("pets", spec, "user", "pet")
|
||||
// g.AddE("friends", spec, "user", "user")
|
||||
//
|
||||
func (g *Graph) AddE(name string, spec *EdgeSpec, from, to string) error {
|
||||
var fromT, toT *Node
|
||||
for i := range g.Nodes {
|
||||
t := g.Nodes[i].Type
|
||||
if t == from {
|
||||
fromT = g.Nodes[i]
|
||||
}
|
||||
if t == to {
|
||||
toT = g.Nodes[i]
|
||||
}
|
||||
}
|
||||
if fromT == nil || toT == nil {
|
||||
return fmt.Errorf("from/to type was not found")
|
||||
}
|
||||
if fromT.Edges == nil {
|
||||
fromT.Edges = make(map[string]struct {
|
||||
To *Node
|
||||
Spec *EdgeSpec
|
||||
})
|
||||
}
|
||||
fromT.Edges[name] = struct {
|
||||
To *Node
|
||||
Spec *EdgeSpec
|
||||
}{
|
||||
To: toT,
|
||||
Spec: spec,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EvalP evaluates the entql predicate on the query builder.
|
||||
func (g *Graph) EvalP(nodeType string, p entql.P, selector *sql.Selector) error {
|
||||
var node *Node
|
||||
for i := range g.Nodes {
|
||||
if g.Nodes[i].Type == nodeType {
|
||||
node = g.Nodes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if node == nil {
|
||||
return fmt.Errorf("node %s was not found in the graph", nodeType)
|
||||
}
|
||||
pr, err := execExpr(node, selector, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
selector.Where(pr)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
binary = [...]sql.Op{
|
||||
entql.OpEQ: sql.OpEQ,
|
||||
entql.OpNEQ: sql.OpNEQ,
|
||||
entql.OpGT: sql.OpGT,
|
||||
entql.OpGTE: sql.OpGTE,
|
||||
entql.OpLT: sql.OpLT,
|
||||
entql.OpLTE: sql.OpLTE,
|
||||
entql.OpIn: sql.OpIn,
|
||||
entql.OpNotIn: sql.OpNotIn,
|
||||
}
|
||||
nary = [...]func(...*sql.Predicate) *sql.Predicate{
|
||||
entql.OpAnd: sql.And,
|
||||
entql.OpOr: sql.Or,
|
||||
}
|
||||
strFunc = [...]func(string, string) *sql.Predicate{
|
||||
entql.OpContains: sql.Contains,
|
||||
entql.OpContainsFold: sql.ContainsFold,
|
||||
entql.OpEqualFold: sql.EqualFold,
|
||||
entql.OpHasPrefix: sql.HasPrefix,
|
||||
entql.OpHasSuffix: sql.HasSuffix,
|
||||
}
|
||||
)
|
||||
|
||||
type exec struct {
|
||||
sql.Builder
|
||||
context *Node
|
||||
selector *sql.Selector
|
||||
}
|
||||
|
||||
func execExpr(context *Node, selector *sql.Selector, expr entql.Expr) (p *sql.Predicate, err error) {
|
||||
ex := &exec{
|
||||
context: context,
|
||||
selector: selector,
|
||||
}
|
||||
defer catch(&err)
|
||||
p = ex.evalExpr(expr)
|
||||
return
|
||||
}
|
||||
|
||||
func (e *exec) evalExpr(expr entql.Expr) *sql.Predicate {
|
||||
switch expr := expr.(type) {
|
||||
case *entql.BinaryExpr:
|
||||
return e.evalBinary(expr)
|
||||
case *entql.UnaryExpr:
|
||||
return sql.Not(e.evalExpr(expr.X))
|
||||
case *entql.NaryExpr:
|
||||
ps := make([]*sql.Predicate, len(expr.Xs))
|
||||
for i, x := range expr.Xs {
|
||||
ps[i] = e.evalExpr(x)
|
||||
}
|
||||
return nary[expr.Op](ps...)
|
||||
case *entql.CallExpr:
|
||||
switch expr.Op {
|
||||
case entql.OpHasPrefix, entql.OpHasSuffix, entql.OpContains, entql.OpEqualFold, entql.OpContainsFold:
|
||||
expect(len(expr.Args) == 2, "invalid number of arguments for %s", expr.Op)
|
||||
f, ok := expr.Args[0].(*entql.Field)
|
||||
expect(ok, "*entql.Field, got %T", expr.Args[0])
|
||||
v, ok := expr.Args[1].(*entql.Value)
|
||||
expect(ok, "*entql.Value, got %T", expr.Args[1])
|
||||
s, ok := v.V.(string)
|
||||
expect(ok, "string value, got %T", v.V)
|
||||
return strFunc[expr.Op](e.field(f), s)
|
||||
case entql.OpHasEdge:
|
||||
expect(len(expr.Args) > 0, "invalid number of arguments for %s", expr.Op)
|
||||
edge, ok := expr.Args[0].(*entql.Edge)
|
||||
expect(ok, "*entql.Edge, got %T", expr.Args[0])
|
||||
return e.evalEdge(edge.Name, expr.Args[1:]...)
|
||||
}
|
||||
}
|
||||
panic("invalid")
|
||||
}
|
||||
|
||||
func (e *exec) evalBinary(expr *entql.BinaryExpr) *sql.Predicate {
|
||||
if (expr.Op == entql.OpEQ || expr.Op == entql.OpNEQ) && expr.Y == (*entql.Value)(nil) {
|
||||
f, ok := expr.X.(*entql.Field)
|
||||
expect(ok, "*entql.Field, got %T", expr.Y)
|
||||
if expr.Op == entql.OpEQ {
|
||||
return sql.IsNull(e.field(f))
|
||||
}
|
||||
return sql.NotNull(e.field(f))
|
||||
}
|
||||
switch expr.Op {
|
||||
case entql.OpOr:
|
||||
return sql.Or(e.evalExpr(expr.X), e.evalExpr(expr.Y))
|
||||
case entql.OpAnd:
|
||||
return sql.And(e.evalExpr(expr.X), e.evalExpr(expr.Y))
|
||||
default:
|
||||
field, ok := expr.X.(*entql.Field)
|
||||
expect(ok, "expr.X to be *entql.Field (got %T)", expr.X)
|
||||
_, ok = expr.Y.(*entql.Field)
|
||||
if !ok {
|
||||
_, ok = expr.Y.(*entql.Value)
|
||||
}
|
||||
expect(ok, "expr.Y to be *entql.Field or *entql.Value (got %T)", expr.X)
|
||||
return sql.P(func(b *sql.Builder) {
|
||||
b.Ident(e.field(field))
|
||||
b.WriteOp(binary[expr.Op])
|
||||
switch x := expr.Y.(type) {
|
||||
case *entql.Field:
|
||||
b.Ident(e.field(x))
|
||||
case *entql.Value:
|
||||
args(b, x)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *exec) field(f *entql.Field) string {
|
||||
_, ok := e.context.Fields[f.Name]
|
||||
expect(ok || e.context.ID.Column == f.Name, "field %q was not found for node %q", f.Name, e.context.Type)
|
||||
return f.Name
|
||||
}
|
||||
|
||||
func (e *exec) evalEdge(name string, exprs ...entql.Expr) *sql.Predicate {
|
||||
edge, ok := e.context.Edges[name]
|
||||
expect(ok, "edge %q was not found for node %q", name, e.context.Type)
|
||||
step := NewStep(
|
||||
From(e.context.Table, e.context.ID.Column),
|
||||
To(edge.To.Table, edge.To.ID.Column),
|
||||
Edge(edge.Spec.Rel, edge.Spec.Inverse, edge.Spec.Table, edge.Spec.Columns...),
|
||||
)
|
||||
selector := e.selector.Clone().SetP(nil)
|
||||
selector.SetTotal(e.Total())
|
||||
if len(exprs) > 0 {
|
||||
HasNeighborsWith(selector, step, func(s *sql.Selector) {
|
||||
for i := range exprs {
|
||||
p, err := execExpr(edge.To, s, exprs[i])
|
||||
expect(err == nil, "edge evaluation failed for %s->%s: %s", e.context.Type, name, err)
|
||||
s.Where(p)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
HasNeighbors(selector, step)
|
||||
}
|
||||
return selector.P()
|
||||
}
|
||||
|
||||
func args(b *sql.Builder, v *entql.Value) {
|
||||
vs, ok := v.V.([]interface{})
|
||||
if !ok {
|
||||
b.Arg(v.V)
|
||||
return
|
||||
}
|
||||
b.Args(vs...)
|
||||
}
|
||||
|
||||
// expect panics if the condition is false.
|
||||
func expect(cond bool, msg string, args ...interface{}) {
|
||||
if !cond {
|
||||
panic(execError{fmt.Sprintf("expect "+msg, args...)})
|
||||
}
|
||||
}
|
||||
|
||||
type execError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (p execError) Error() string { return fmt.Sprintf("sqlgraph: %s", p.msg) }
|
||||
|
||||
func catch(err *error) {
|
||||
if e := recover(); e != nil {
|
||||
xerr, ok := e.(execError)
|
||||
if !ok {
|
||||
panic(e)
|
||||
}
|
||||
*err = xerr
|
||||
}
|
||||
}
|
||||
155
dialect/sql/sqlgraph/entql_test.go
Normal file
155
dialect/sql/sqlgraph/entql_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package sqlgraph
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/facebook/ent/dialect"
|
||||
"github.com/facebook/ent/dialect/sql"
|
||||
"github.com/facebook/ent/entql"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGraph_AddE(t *testing.T) {
|
||||
g := &Graph{
|
||||
Nodes: []*Node{{Type: "user"}, {Type: "pet"}},
|
||||
}
|
||||
err := g.AddE("pets", &EdgeSpec{Rel: O2M}, "user", "pet")
|
||||
assert.NoError(t, err)
|
||||
err = g.AddE("owner", &EdgeSpec{Rel: O2M}, "pet", "user")
|
||||
assert.NoError(t, err)
|
||||
err = g.AddE("groups", &EdgeSpec{Rel: M2M}, "pet", "groups")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGraph_EvalP(t *testing.T) {
|
||||
g := &Graph{
|
||||
Nodes: []*Node{
|
||||
{
|
||||
Type: "user",
|
||||
NodeSpec: NodeSpec{
|
||||
Table: "users",
|
||||
ID: &FieldSpec{Column: "uid"},
|
||||
},
|
||||
Fields: map[string]*FieldSpec{
|
||||
"name": {Column: "name", Type: field.TypeString},
|
||||
"last": {Column: "last", Type: field.TypeString},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "pet",
|
||||
NodeSpec: NodeSpec{
|
||||
Table: "pets",
|
||||
ID: &FieldSpec{Column: "pid"},
|
||||
},
|
||||
Fields: map[string]*FieldSpec{
|
||||
"name": {Column: "name", Type: field.TypeString},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "group",
|
||||
NodeSpec: NodeSpec{
|
||||
Table: "groups",
|
||||
ID: &FieldSpec{Column: "gid"},
|
||||
},
|
||||
Fields: map[string]*FieldSpec{
|
||||
"name": {Column: "name", Type: field.TypeString},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := g.AddE("pets", &EdgeSpec{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}}, "user", "pet")
|
||||
require.NoError(t, err)
|
||||
err = g.AddE("owner", &EdgeSpec{Rel: M2O, Inverse: true, Table: "pets", Columns: []string{"owner_id"}}, "pet", "user")
|
||||
require.NoError(t, err)
|
||||
err = g.AddE("groups", &EdgeSpec{Rel: M2M, Table: "user_groups", Columns: []string{"user_id", "group_id"}}, "user", "group")
|
||||
require.NoError(t, err)
|
||||
err = g.AddE("users", &EdgeSpec{Rel: M2M, Inverse: true, Table: "user_groups", Columns: []string{"user_id", "group_id"}}, "group", "user")
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
s *sql.Selector
|
||||
p entql.P
|
||||
wantQuery string
|
||||
wantArgs []interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")),
|
||||
p: entql.FieldHasPrefix("name", "a"),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "name" LIKE $1`,
|
||||
wantArgs: []interface{}{"a%"},
|
||||
},
|
||||
{
|
||||
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).
|
||||
Where(sql.EQ("age", 1)),
|
||||
p: entql.FieldHasPrefix("name", "a"),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "age" = $1 AND "name" LIKE $2`,
|
||||
wantArgs: []interface{}{1, "a%"},
|
||||
},
|
||||
{
|
||||
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).
|
||||
Where(sql.EQ("age", 1)),
|
||||
p: entql.FieldHasPrefix("name", "a"),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "age" = $1 AND "name" LIKE $2`,
|
||||
wantArgs: []interface{}{1, "a%"},
|
||||
},
|
||||
{
|
||||
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")),
|
||||
p: entql.EQ(entql.F("name"), entql.F("last")),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "name" = "last"`,
|
||||
},
|
||||
{
|
||||
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")),
|
||||
p: entql.EQ(entql.F("name"), entql.F("last")),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "name" = "last"`,
|
||||
},
|
||||
{
|
||||
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).
|
||||
Where(sql.EQ("foo", "bar")),
|
||||
p: entql.Or(entql.FieldEQ("name", "foo"), entql.FieldEQ("name", "baz")),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "foo" = $1 AND ("name" = $2 OR "name" = $3)`,
|
||||
wantArgs: []interface{}{"bar", "foo", "baz"},
|
||||
},
|
||||
{
|
||||
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")),
|
||||
p: entql.HasEdge("pets"),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL)`,
|
||||
},
|
||||
{
|
||||
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")),
|
||||
p: entql.HasEdge("groups"),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups")`,
|
||||
},
|
||||
{
|
||||
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")),
|
||||
p: entql.HasEdgeWith("pets", entql.Or(entql.FieldEQ("name", "pedro"), entql.FieldEQ("name", "xabi"))),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $1 OR "name" = $2)`,
|
||||
wantArgs: []interface{}{"pedro", "xabi"},
|
||||
},
|
||||
{
|
||||
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)),
|
||||
p: entql.HasEdgeWith("groups", entql.Or(entql.FieldEQ("name", "GitHub"), entql.FieldEQ("name", "GitLab"))),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t0" ON "user_groups"."group_id" = "t0"."gid" WHERE "name" = $2 OR "name" = $3)`,
|
||||
wantArgs: []interface{}{true, "GitHub", "GitLab"},
|
||||
},
|
||||
{
|
||||
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)),
|
||||
p: entql.And(entql.HasEdge("pets"), entql.HasEdge("groups"), entql.EQ(entql.F("name"), entql.F("uid"))),
|
||||
wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND ("users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL) AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups") AND "name" = "uid")`,
|
||||
wantArgs: []interface{}{true},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
err = g.EvalP("user", tt.p, tt.s)
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
query, args := tt.s.Query()
|
||||
require.Equal(t, tt.wantQuery, query)
|
||||
require.Equal(t, tt.wantArgs, args)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user