diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index fdb2cb2a5..7dc2df2c2 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -237,13 +237,15 @@ func HasNeighborsWith(q *Selector, s *Step, pred func(*Selector)) { } from := q.Table() to := builder.Table(s.To.Table) - join := builder.Table(s.Edge.Table) - matches := builder.Select(join.C(pk2)). - From(join). + edge := builder.Table(s.Edge.Table) + join := builder.Select(edge.C(pk2)). + From(edge). Join(to). - On(join.C(pk1), to.C(s.To.Column)) + On(edge.C(pk1), to.C(s.To.Column)) + matches := builder.Select().From(to) pred(matches) - q.Where(In(from.C(s.From.Column), matches)) + join.FromSelect(matches) + q.Where(In(from.C(s.From.Column), join)) case r == M2O || (r == O2O && s.Edge.Inverse): from := q.Table() to := builder.Table(s.To.Table) diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index 7193554c2..f3c619129 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -400,6 +400,26 @@ WHERE "groups"."id" IN JOIN "users" AS "t0" ON "user_groups"."user_id" = "t0"."id" WHERE "name" = $1)`, wantArgs: []interface{}{"a8m"}, }, + { + name: "M2M/inverse", + step: NewStep( + From("groups", "id"), + To("users", "id"), + Edge(M2M, true, "user_groups", "user_id", "group_id"), + ), + selector: Dialect("postgres").Select("*").From(Table("groups")), + predicate: func(s *Selector) { + s.Where(And(NotNull("name"), EQ("name", "a8m"))) + }, + wantQuery: ` +SELECT * +FROM "groups" +WHERE "groups"."id" IN + (SELECT "user_groups"."group_id" + FROM "user_groups" + JOIN "users" AS "t0" ON "user_groups"."user_id" = "t0"."id" WHERE ("name" IS NOT NULL) AND ("name" = $1))`, + wantArgs: []interface{}{"a8m"}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {