From 413bbad8d876a759f98800ef06685d3981136405 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Tue, 26 Nov 2019 12:22:08 -0800 Subject: [PATCH] dialect/sqlgraph: fix M2O relation in neighbors-with check Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/198 Reviewed By: alexsn Differential Revision: D18707741 fbshipit-source-id: 69dd010e27ee07ffe44acc12003b9772220aaa2a --- dialect/sql/graph.go | 2 +- dialect/sql/graph_test.go | 86 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index 1fbea896f..75f2b6b0d 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -202,7 +202,7 @@ func HasNeighborsWith(q *Selector, s *Step, pred func(*Selector)) { matches := builder.Select(to.C(s.To.Column)). From(to) pred(matches) - q.Where(In(from.C(s.From.Column), matches)) + q.Where(In(from.C(s.Edge.Columns[0]), matches)) case r == O2M || (r == O2O && !s.Edge.Inverse): from := q.Table() to := builder.Table(s.Edge.Table) diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index ab8f37bdc..84027baa6 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -443,6 +443,92 @@ func TestHasNeighborsWith(t *testing.T) { wantQuery string wantArgs []interface{} }{ + { + name: "O2O", + step: func() *Step { + step := &Step{} + step.From.Table = "users" + step.From.Column = "id" + step.To.Table = "cards" + step.To.Column = "id" + step.Edge.Rel = O2O + step.Edge.Table = "cards" + step.Edge.Columns = []string{"owner_id"} + return step + }(), + selector: Dialect("postgres").Select("*").From(Table("users")), + predicate: func(s *Selector) { + s.Where(EQ("expired", false)) + }, + wantQuery: `SELECT * FROM "users" WHERE "users"."id" IN (SELECT "cards"."owner_id" FROM "cards" WHERE "expired" = $1)`, + wantArgs: []interface{}{false}, + }, + { + name: "O2O/inverse", + step: func() *Step { + step := &Step{} + step.From.Table = "cards" + step.From.Column = "id" + step.To.Table = "users" + step.To.Column = "id" + step.Edge.Rel = O2O + step.Edge.Table = "cards" + step.Edge.Inverse = true + step.Edge.Columns = []string{"owner_id"} + return step + }(), + selector: Dialect("postgres").Select("*").From(Table("cards")), + predicate: func(s *Selector) { + s.Where(EQ("name", "a8m")) + }, + wantQuery: `SELECT * FROM "cards" WHERE "cards"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "name" = $1)`, + wantArgs: []interface{}{"a8m"}, + }, + { + name: "O2M", + step: func() *Step { + step := &Step{} + step.From.Table = "users" + step.From.Column = "id" + step.To.Table = "pets" + step.To.Column = "id" + step.Edge.Rel = O2M + step.Edge.Table = "pets" + step.Edge.Columns = []string{"owner_id"} + return step + }(), + selector: Dialect("postgres").Select("*"). + From(Table("users")). + Where(EQ("last_name", "mashraki")), + predicate: func(s *Selector) { + s.Where(EQ("name", "pedro")) + }, + wantQuery: `SELECT * FROM "users" WHERE "last_name" = $1 AND "users"."id" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $2)`, + wantArgs: []interface{}{"mashraki", "pedro"}, + }, + { + name: "M2O", + step: func() *Step { + step := &Step{} + step.From.Table = "pets" + step.From.Column = "id" + step.To.Table = "users" + step.To.Column = "id" + step.Edge.Rel = M2O + step.Edge.Table = "pets" + step.Edge.Inverse = true + step.Edge.Columns = []string{"owner_id"} + return step + }(), + selector: Dialect("postgres").Select("*"). + From(Table("pets")). + Where(EQ("name", "pedro")), + predicate: func(s *Selector) { + s.Where(EQ("last_name", "mashraki")) + }, + wantQuery: `SELECT * FROM "pets" WHERE "name" = $1 AND "pets"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "last_name" = $2)`, + wantArgs: []interface{}{"pedro", "mashraki"}, + }, { name: "M2M", step: func() *Step {