dialect/sqlgraph: test and fix M2M neighbors check

Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/190

Reviewed By: alexsn

Differential Revision: D18680440

fbshipit-source-id: b2dce0ee3e0a898fab901c4287c6c1c1844e51d6
This commit is contained in:
Ariel Mashraki
2019-11-24 22:09:28 -08:00
committed by Facebook Github Bot
parent 0344904a4e
commit 038dc2899a
2 changed files with 68 additions and 2 deletions

View File

@@ -148,9 +148,9 @@ func HasNeighbors(q *Selector, s *Step) {
builder := Dialect(q.dialect)
switch r := s.Edge.Rel; {
case r == M2M:
pk1 := s.Edge.Columns[1]
pk1 := s.Edge.Columns[0]
if s.Edge.Inverse {
pk1 = s.Edge.Columns[0]
pk1 = s.Edge.Columns[1]
}
from := q.Table()
join := builder.Table(s.Edge.Table)

View File

@@ -357,6 +357,72 @@ func TestHasNeighbors(t *testing.T) {
selector: Select("*").From(Table("nodes")),
wantQuery: "SELECT * FROM `nodes` WHERE `nodes`.`prev_id` IS NOT NULL",
},
{
name: "O2M/2type2",
step: func() *Step {
step := &Step{}
step.From.Table = "users"
step.From.Column = "id"
step.To.Table = "pets"
step.To.Column = "id"
step.Edge.Rel = O2M
step.Edge.Table = "pets"
step.Edge.Columns = []string{"owner_id"}
return step
}(),
selector: Select("*").From(Table("users")),
wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `pets`.`owner_id` FROM `pets` WHERE `pets`.`owner_id` IS NOT NULL)",
},
{
name: "M2O/2type2",
step: func() *Step {
step := &Step{}
step.From.Table = "pets"
step.From.Column = "id"
step.To.Table = "users"
step.To.Column = "id"
step.Edge.Rel = M2O
step.Edge.Inverse = true
step.Edge.Table = "pets"
step.Edge.Columns = []string{"owner_id"}
return step
}(),
selector: Select("*").From(Table("pets")),
wantQuery: "SELECT * FROM `pets` WHERE `pets`.`owner_id` IS NOT NULL",
},
{
name: "M2M/2types",
step: func() *Step {
step := &Step{}
step.From.Table = "users"
step.From.Column = "id"
step.To.Table = "groups"
step.To.Column = "id"
step.Edge.Rel = M2M
step.Edge.Table = "user_groups"
step.Edge.Columns = []string{"user_id", "group_id"}
return step
}(),
selector: Select("*").From(Table("users")),
wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `user_groups`.`user_id` FROM `user_groups`)",
},
{
name: "M2M/2types/inverse",
step: func() *Step {
step := &Step{}
step.From.Table = "users"
step.From.Column = "id"
step.To.Table = "groups"
step.To.Column = "id"
step.Edge.Rel = M2M
step.Edge.Inverse = true
step.Edge.Table = "group_users"
step.Edge.Columns = []string{"group_id", "user_id"}
return step
}(),
selector: Select("*").From(Table("users")),
wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `group_users`.`user_id` FROM `group_users`)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {