From 038dc2899a9bca30cd6a0328dc1c060b95c77a73 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Sun, 24 Nov 2019 22:09:28 -0800 Subject: [PATCH] dialect/sqlgraph: test and fix M2M neighbors check Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/190 Reviewed By: alexsn Differential Revision: D18680440 fbshipit-source-id: b2dce0ee3e0a898fab901c4287c6c1c1844e51d6 --- dialect/sql/graph.go | 4 +-- dialect/sql/graph_test.go | 66 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index e5136a5db..1fbea896f 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -148,9 +148,9 @@ func HasNeighbors(q *Selector, s *Step) { builder := Dialect(q.dialect) switch r := s.Edge.Rel; { case r == M2M: - pk1 := s.Edge.Columns[1] + pk1 := s.Edge.Columns[0] if s.Edge.Inverse { - pk1 = s.Edge.Columns[0] + pk1 = s.Edge.Columns[1] } from := q.Table() join := builder.Table(s.Edge.Table) diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index f94be237a..ab8f37bdc 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -357,6 +357,72 @@ func TestHasNeighbors(t *testing.T) { selector: Select("*").From(Table("nodes")), wantQuery: "SELECT * FROM `nodes` WHERE `nodes`.`prev_id` IS NOT NULL", }, + { + name: "O2M/2type2", + step: func() *Step { + step := &Step{} + step.From.Table = "users" + step.From.Column = "id" + step.To.Table = "pets" + step.To.Column = "id" + step.Edge.Rel = O2M + step.Edge.Table = "pets" + step.Edge.Columns = []string{"owner_id"} + return step + }(), + selector: Select("*").From(Table("users")), + wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `pets`.`owner_id` FROM `pets` WHERE `pets`.`owner_id` IS NOT NULL)", + }, + { + name: "M2O/2type2", + step: func() *Step { + step := &Step{} + step.From.Table = "pets" + step.From.Column = "id" + step.To.Table = "users" + step.To.Column = "id" + step.Edge.Rel = M2O + step.Edge.Inverse = true + step.Edge.Table = "pets" + step.Edge.Columns = []string{"owner_id"} + return step + }(), + selector: Select("*").From(Table("pets")), + wantQuery: "SELECT * FROM `pets` WHERE `pets`.`owner_id` IS NOT NULL", + }, + { + name: "M2M/2types", + step: func() *Step { + step := &Step{} + step.From.Table = "users" + step.From.Column = "id" + step.To.Table = "groups" + step.To.Column = "id" + step.Edge.Rel = M2M + step.Edge.Table = "user_groups" + step.Edge.Columns = []string{"user_id", "group_id"} + return step + }(), + selector: Select("*").From(Table("users")), + wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `user_groups`.`user_id` FROM `user_groups`)", + }, + { + name: "M2M/2types/inverse", + step: func() *Step { + step := &Step{} + step.From.Table = "users" + step.From.Column = "id" + step.To.Table = "groups" + step.To.Column = "id" + step.Edge.Rel = M2M + step.Edge.Inverse = true + step.Edge.Table = "group_users" + step.Edge.Columns = []string{"group_id", "user_id"} + return step + }(), + selector: Select("*").From(Table("users")), + wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `group_users`.`user_id` FROM `group_users`)", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {