diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index 0b983496f..e5136a5db 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -176,3 +176,39 @@ func HasNeighbors(q *Selector, s *Step) { ) } } + +// HasNeighborsWith applies on the given Selector a neighbors check. +// The given predicate applies its filtering on the selector. +func HasNeighborsWith(q *Selector, s *Step, pred func(*Selector)) { + builder := Dialect(q.dialect) + switch r := s.Edge.Rel; { + case r == M2M: + pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0] + if s.Edge.Inverse { + pk1, pk2 = pk2, pk1 + } + from := q.Table() + to := builder.Table(s.To.Table) + join := builder.Table(s.Edge.Table) + matches := builder.Select(join.C(pk2)). + From(join). + Join(to). + On(join.C(pk1), to.C(s.To.Column)) + pred(matches) + q.Where(In(from.C(s.From.Column), matches)) + case r == M2O || (r == O2O && s.Edge.Inverse): + from := q.Table() + to := builder.Table(s.To.Table) + matches := builder.Select(to.C(s.To.Column)). + From(to) + pred(matches) + q.Where(In(from.C(s.From.Column), matches)) + case r == O2M || (r == O2O && !s.Edge.Inverse): + from := q.Table() + to := builder.Table(s.Edge.Table) + matches := builder.Select(to.C(s.Edge.Columns[0])). + From(to) + pred(matches) + q.Where(In(from.C(s.From.Column), matches)) + } +} diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index 11183ad6b..f94be237a 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -326,7 +326,6 @@ func TestHasNeighbors(t *testing.T) { // node holds association pointer. The neighbors query // here checks if a node "has-next". step := &Step{} - step.From.V = 1 step.From.Table = "nodes" step.From.Column = "id" step.To.Table = "nodes" @@ -345,7 +344,6 @@ func TestHasNeighbors(t *testing.T) { // Same example as above, but the neighbors // query checks if a node "has-previous". step := &Step{} - step.From.V = 1 step.From.Table = "nodes" step.From.Column = "id" step.To.Table = "nodes" @@ -369,3 +367,77 @@ func TestHasNeighbors(t *testing.T) { }) } } + +func TestHasNeighborsWith(t *testing.T) { + tests := []struct { + name string + step *Step + selector *Selector + predicate func(*Selector) + wantQuery string + wantArgs []interface{} + }{ + { + name: "M2M", + step: func() *Step { + step := &Step{} + step.From.Table = "users" + step.From.Column = "id" + step.To.Table = "groups" + step.To.Column = "id" + step.Edge.Rel = M2M + step.Edge.Table = "user_groups" + step.Edge.Columns = []string{"user_id", "group_id"} + return step + }(), + selector: Dialect("postgres").Select("*").From(Table("users")), + predicate: func(s *Selector) { + s.Where(EQ("name", "GitHub")) + }, + wantQuery: ` +SELECT * +FROM "users" +WHERE "users"."id" IN + (SELECT "user_groups"."user_id" + FROM "user_groups" + JOIN "groups" AS "t0" ON "user_groups"."group_id" = "t0"."id" WHERE "name" = $1)`, + wantArgs: []interface{}{"GitHub"}, + }, + { + name: "M2M/inverse", + step: func() *Step { + step := &Step{} + step.From.Table = "groups" + step.From.Column = "id" + step.To.Table = "users" + step.To.Column = "id" + step.Edge.Rel = M2M + step.Edge.Table = "user_groups" + step.Edge.Inverse = true + step.Edge.Columns = []string{"user_id", "group_id"} + return step + }(), + selector: Dialect("postgres").Select("*").From(Table("groups")), + predicate: func(s *Selector) { + s.Where(EQ("name", "a8m")) + }, + wantQuery: ` +SELECT * +FROM "groups" +WHERE "groups"."id" IN + (SELECT "user_groups"."group_id" + FROM "user_groups" + JOIN "users" AS "t0" ON "user_groups"."user_id" = "t0"."id" WHERE "name" = $1)`, + wantArgs: []interface{}{"a8m"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + HasNeighborsWith(tt.selector, tt.step, tt.predicate) + query, args := tt.selector.Query() + tt.wantQuery = strings.Join(strings.Fields(tt.wantQuery), " ") + require.Equal(t, tt.wantQuery, query) + require.Equal(t, tt.wantArgs, args) + }) + } +}