// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sqlgraph import ( "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/entql" ) type ( // A Schema holds a representation of ent/schema at runtime. Each Node // represents a single schema-type and its relations in the graph (storage). // // It is used for translating common graph traversal operations to the // underlying SQL storage. For example, an operation like `has_edge(E)`, // will be translated to an SQL lookup based on the relation type and the // FK configuration. Schema struct { Nodes []*Node } // A Node in the graph holds the SQL information for an ent/schema. Node struct { NodeSpec // Type holds the node type (schema name). Type string // Fields maps from field names to their spec. Fields map[string]*FieldSpec // Edges maps from edge names to their spec. Edges map[string]struct { To *Node Spec *EdgeSpec } } ) // AddE adds an edge to the graph. It fails, if one of the node // types is missing. // // g.AddE("pets", spec, "user", "pet") // g.AddE("friends", spec, "user", "user") func (g *Schema) AddE(name string, spec *EdgeSpec, from, to string) error { var fromT, toT *Node for i := range g.Nodes { t := g.Nodes[i].Type if t == from { fromT = g.Nodes[i] } if t == to { toT = g.Nodes[i] } } if fromT == nil || toT == nil { return fmt.Errorf("from/to type was not found") } if fromT.Edges == nil { fromT.Edges = make(map[string]struct { To *Node Spec *EdgeSpec }) } fromT.Edges[name] = struct { To *Node Spec *EdgeSpec }{ To: toT, Spec: spec, } return nil } // MustAddE is like AddE but panics if the edge can be added to the graph. func (g *Schema) MustAddE(name string, spec *EdgeSpec, from, to string) { if err := g.AddE(name, spec, from, to); err != nil { panic(err) } } // EvalP evaluates the entql predicate on the given selector (query builder). func (g *Schema) EvalP(nodeType string, p entql.P, selector *sql.Selector) error { var node *Node for i := range g.Nodes { if g.Nodes[i].Type == nodeType { node = g.Nodes[i] break } } if node == nil { return fmt.Errorf("node %s was not found in the graph schema", nodeType) } pr, err := evalExpr(node, selector, p) if err != nil { return err } selector.Where(pr) return nil } // FuncSelector represents a selector function to be used as an entql foreign-function. const FuncSelector entql.Func = "func_selector" // wrappedFunc wraps the selector-function to an ent-expression. type wrappedFunc struct { entql.Expr Func func(*sql.Selector) } // WrapFunc wraps a selector-func with an entql call expression. func WrapFunc(s func(*sql.Selector)) *entql.CallExpr { return &entql.CallExpr{ Func: FuncSelector, Args: []entql.Expr{wrappedFunc{Func: s}}, } } var ( binary = [...]sql.Op{ entql.OpEQ: sql.OpEQ, entql.OpNEQ: sql.OpNEQ, entql.OpGT: sql.OpGT, entql.OpGTE: sql.OpGTE, entql.OpLT: sql.OpLT, entql.OpLTE: sql.OpLTE, entql.OpIn: sql.OpIn, entql.OpNotIn: sql.OpNotIn, } nary = [...]func(...*sql.Predicate) *sql.Predicate{ entql.OpAnd: sql.And, entql.OpOr: sql.Or, } strFunc = map[entql.Func]func(string, string) *sql.Predicate{ entql.FuncContains: sql.Contains, entql.FuncContainsFold: sql.ContainsFold, entql.FuncEqualFold: sql.EqualFold, entql.FuncHasPrefix: sql.HasPrefix, entql.FuncHasSuffix: sql.HasSuffix, } nullFunc = [...]func(string) *sql.Predicate{ entql.OpEQ: sql.IsNull, entql.OpNEQ: sql.NotNull, } ) // state represents the state of a predicate evaluation. // Note that, the evaluation output is a predicate to be // applied on the database. type state struct { sql.Builder context *Node selector *sql.Selector } // evalExpr evaluates the entql expression and returns a new SQL predicate to be applied on the database. func evalExpr(context *Node, selector *sql.Selector, expr entql.Expr) (p *sql.Predicate, err error) { ex := &state{ context: context, selector: selector, } defer catch(&err) p = ex.evalExpr(expr) return } // evalExpr evaluates any expression. func (e *state) evalExpr(expr entql.Expr) *sql.Predicate { switch expr := expr.(type) { case *entql.BinaryExpr: return e.evalBinary(expr) case *entql.UnaryExpr: return sql.Not(e.evalExpr(expr.X)) case *entql.NaryExpr: ps := make([]*sql.Predicate, len(expr.Xs)) for i, x := range expr.Xs { ps[i] = e.evalExpr(x) } return nary[expr.Op](ps...) case *entql.CallExpr: switch expr.Func { case entql.FuncHasPrefix, entql.FuncHasSuffix, entql.FuncContains, entql.FuncEqualFold, entql.FuncContainsFold: expect(len(expr.Args) == 2, "invalid number of arguments for %s", expr.Func) f, ok := expr.Args[0].(*entql.Field) expect(ok, "*entql.Field, got %T", expr.Args[0]) v, ok := expr.Args[1].(*entql.Value) expect(ok, "*entql.Value, got %T", expr.Args[1]) s, ok := v.V.(string) expect(ok, "string value, got %T", v.V) return strFunc[expr.Func](e.field(f), s) case entql.FuncHasEdge: expect(len(expr.Args) > 0, "invalid number of arguments for %s", expr.Func) edge, ok := expr.Args[0].(*entql.Edge) expect(ok, "*entql.Edge, got %T", expr.Args[0]) return e.evalEdge(edge.Name, expr.Args[1:]...) } } panic("invalid") } // evalBinary evaluates binary expressions. func (e *state) evalBinary(expr *entql.BinaryExpr) *sql.Predicate { switch expr.Op { case entql.OpOr: return sql.Or(e.evalExpr(expr.X), e.evalExpr(expr.Y)) case entql.OpAnd: return sql.And(e.evalExpr(expr.X), e.evalExpr(expr.Y)) case entql.OpEQ, entql.OpNEQ: if expr.Y == (*entql.Value)(nil) { f, ok := expr.X.(*entql.Field) expect(ok, "*entql.Field, got %T", expr.Y) return nullFunc[expr.Op](e.field(f)) } fallthrough default: field, ok := expr.X.(*entql.Field) expect(ok, "expr.X to be *entql.Field (got %T)", expr.X) _, ok = expr.Y.(*entql.Field) if !ok { _, ok = expr.Y.(*entql.Value) } expect(ok, "expr.Y to be *entql.Field or *entql.Value (got %T)", expr.X) switch x := expr.Y.(type) { case *entql.Field: return sql.ColumnsOp(e.field(field), e.field(x), binary[expr.Op]) case *entql.Value: c := e.field(field) return sql.P(func(b *sql.Builder) { b.Ident(c).WriteOp(binary[expr.Op]) args(b, x) }) default: panic("unreachable") } } } // evalEdge evaluates has-edge and has-edge-with calls. func (e *state) evalEdge(name string, exprs ...entql.Expr) *sql.Predicate { edge, ok := e.context.Edges[name] expect(ok, "edge %q was not found for node %q", name, e.context.Type) var fromC, toC string switch { case edge.To.ID != nil: toC = edge.To.ID.Column // Edge-owner points to its edge schema. case edge.To.CompositeID != nil && !edge.Spec.Inverse: toC = edge.To.CompositeID[0].Column // Edge-backref points to its edge schema. case edge.To.CompositeID != nil && edge.Spec.Inverse: toC = edge.To.CompositeID[1].Column default: panic(evalError{fmt.Sprintf("expect id definition for edge %q", name)}) } switch { case e.context.ID != nil: fromC = e.context.ID.Column case e.context.CompositeID != nil && (edge.Spec.Rel == M2O || (edge.Spec.Rel == O2O && edge.Spec.Inverse)): // An edge-schema with a composite id can query // only edges that it owns (holds the foreign-key). default: panic(evalError{fmt.Sprintf("unexpected edge-query from an edge-schema %q", e.context.Type)}) } step := NewStep( From(e.context.Table, fromC), To(edge.To.Table, toC), Edge(edge.Spec.Rel, edge.Spec.Inverse, edge.Spec.Table, edge.Spec.Columns...), ) selector := e.selector.Clone().SetP(nil) selector.SetTotal(e.Total()) if len(exprs) == 0 { HasNeighbors(selector, step) return selector.P() } HasNeighborsWith(selector, step, func(s *sql.Selector) { for _, expr := range exprs { if cx, ok := expr.(*entql.CallExpr); ok && cx.Func == FuncSelector { expect(len(cx.Args) == 1, "invalid number of arguments for %s", FuncSelector) wrapped, ok := cx.Args[0].(wrappedFunc) expect(ok, "invalid argument for %s: %T", FuncSelector, cx.Args[0]) wrapped.Func(s) } else { p, err := evalExpr(edge.To, s, expr) expect(err == nil, "edge evaluation failed for %s->%s: %s", e.context.Type, name, err) s.Where(p) } } }) return selector.P() } func (e *state) field(f *entql.Field) string { _, ok := e.context.Fields[f.Name] expect(ok || e.context.ID.Column == f.Name, "field %q was not found for node %q", f.Name, e.context.Type) return e.selector.C(f.Name) } func args(b *sql.Builder, v *entql.Value) { vs, ok := v.V.([]any) if !ok { b.Arg(v.V) return } b.WriteByte('(').Args(vs...).WriteByte(')') } // expect panics if the condition is false. func expect(cond bool, msg string, args ...any) { if !cond { panic(evalError{fmt.Sprintf("expect "+msg, args...)}) } } type evalError struct { msg string } func (p evalError) Error() string { return fmt.Sprintf("sqlgraph: %s", p.msg) } func catch(err *error) { if e := recover(); e != nil { xerr, ok := e.(evalError) if !ok { panic(e) } *err = xerr } }