mirror of
https://github.com/ent/ent.git
synced 2026-04-29 06:00:55 +03:00
dialect/sql/sqlgraph: use selector's column instead of table's column in HasNeighbors (#2060)
Allows for HasNeighbors' selector to be a generic selector instead of *SelectTable. A use case for this is doing HasNeighbors* function on materialized queries instead of on concrete tables only.
This commit is contained in:
@@ -228,23 +228,20 @@ func HasNeighbors(q *sql.Selector, s *Step) {
|
||||
if s.Edge.Inverse {
|
||||
pk1 = s.Edge.Columns[1]
|
||||
}
|
||||
from := q.Table()
|
||||
join := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
|
||||
q.Where(
|
||||
sql.In(
|
||||
from.C(s.From.Column),
|
||||
q.C(s.From.Column),
|
||||
builder.Select(join.C(pk1)).From(join),
|
||||
),
|
||||
)
|
||||
case r == M2O || (r == O2O && s.Edge.Inverse):
|
||||
from := q.Table()
|
||||
q.Where(sql.NotNull(from.C(s.Edge.Columns[0])))
|
||||
q.Where(sql.NotNull(q.C(s.Edge.Columns[0])))
|
||||
case r == O2M || (r == O2O && !s.Edge.Inverse):
|
||||
from := q.Table()
|
||||
to := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
|
||||
q.Where(
|
||||
sql.In(
|
||||
from.C(s.From.Column),
|
||||
q.C(s.From.Column),
|
||||
builder.Select(to.C(s.Edge.Columns[0])).
|
||||
From(to).
|
||||
Where(sql.NotNull(to.C(s.Edge.Columns[0]))),
|
||||
@@ -263,7 +260,6 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
|
||||
if s.Edge.Inverse {
|
||||
pk1, pk2 = pk2, pk1
|
||||
}
|
||||
from := q.Table()
|
||||
to := builder.Table(s.To.Table).Schema(s.To.Schema)
|
||||
edge := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
|
||||
join := builder.Select(edge.C(pk2)).
|
||||
@@ -274,23 +270,21 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
|
||||
matches.WithContext(q.Context())
|
||||
pred(matches)
|
||||
join.FromSelect(matches)
|
||||
q.Where(sql.In(from.C(s.From.Column), join))
|
||||
q.Where(sql.In(q.C(s.From.Column), join))
|
||||
case r == M2O || (r == O2O && s.Edge.Inverse):
|
||||
from := q.Table()
|
||||
to := builder.Table(s.To.Table).Schema(s.To.Schema)
|
||||
matches := builder.Select(to.C(s.To.Column)).
|
||||
From(to)
|
||||
matches.WithContext(q.Context())
|
||||
pred(matches)
|
||||
q.Where(sql.In(from.C(s.Edge.Columns[0]), matches))
|
||||
q.Where(sql.In(q.C(s.Edge.Columns[0]), matches))
|
||||
case r == O2M || (r == O2O && !s.Edge.Inverse):
|
||||
from := q.Table()
|
||||
to := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
|
||||
matches := builder.Select(to.C(s.Edge.Columns[0])).
|
||||
From(to)
|
||||
matches.WithContext(q.Context())
|
||||
pred(matches)
|
||||
q.Where(sql.In(from.C(s.From.Column), matches))
|
||||
q.Where(sql.In(q.C(s.From.Column), matches))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -592,6 +592,36 @@ func TestHasNeighbors(t *testing.T) {
|
||||
selector: sql.Select("*").From(sql.Table("users").Schema("s1")),
|
||||
wantQuery: "SELECT * FROM `s1`.`users` WHERE `s1`.`users`.`id` IN (SELECT `s2`.`group_users`.`user_id` FROM `s2`.`group_users`)",
|
||||
},
|
||||
{
|
||||
name: "O2M/2type2/selector",
|
||||
step: NewStep(
|
||||
From("users", "id"),
|
||||
To("pets", "id"),
|
||||
Edge(O2M, false, "pets", "owner_id"),
|
||||
),
|
||||
selector: sql.Select("*").From(sql.Select("*").From(sql.Table("users")).As("users")).As("users"),
|
||||
wantQuery: "SELECT * FROM (SELECT * FROM `users`) AS `users` WHERE `users`.`id` IN (SELECT `pets`.`owner_id` FROM `pets` WHERE `pets`.`owner_id` IS NOT NULL)",
|
||||
},
|
||||
{
|
||||
name: "M2O/2type2/selector",
|
||||
step: NewStep(
|
||||
From("pets", "id"),
|
||||
To("users", "id"),
|
||||
Edge(M2O, true, "pets", "owner_id"),
|
||||
),
|
||||
selector: sql.Select("*").From(sql.Select("*").From(sql.Table("pets")).As("pets")).As("pets"),
|
||||
wantQuery: "SELECT * FROM (SELECT * FROM `pets`) AS `pets` WHERE `pets`.`owner_id` IS NOT NULL",
|
||||
},
|
||||
{
|
||||
name: "M2M/2types/selector",
|
||||
step: NewStep(
|
||||
From("users", "id"),
|
||||
To("groups", "id"),
|
||||
Edge(M2M, false, "user_groups", "user_id", "group_id"),
|
||||
),
|
||||
selector: sql.Select("*").From(sql.Select("*").From(sql.Table("users")).As("users")).As("users"),
|
||||
wantQuery: "SELECT * FROM (SELECT * FROM `users`) AS `users` WHERE `users`.`id` IN (SELECT `user_groups`.`user_id` FROM `user_groups`)",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -795,6 +825,52 @@ WHERE "s1"."users"."id" IN
|
||||
JOIN "s3"."groups" AS "t1" ON "s2"."user_groups"."group_id" = "t1"."id" WHERE "name" = $1)`,
|
||||
wantArgs: []interface{}{"GitHub"},
|
||||
},
|
||||
{
|
||||
name: "O2M/selector",
|
||||
step: NewStep(
|
||||
From("users", "id"),
|
||||
To("pets", "id"),
|
||||
Edge(O2M, false, "pets", "owner_id"),
|
||||
),
|
||||
selector: sql.Dialect("postgres").Select("*").
|
||||
From(sql.Select("*").From(sql.Table("users")).As("users")).
|
||||
Where(sql.EQ("last_name", "mashraki")).As("users"),
|
||||
predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.EQ("name", "pedro"))
|
||||
},
|
||||
wantQuery: `SELECT * FROM (SELECT * FROM "users") AS "users" WHERE "last_name" = $1 AND "users"."id" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $2)`,
|
||||
wantArgs: []interface{}{"mashraki", "pedro"},
|
||||
},
|
||||
{
|
||||
name: "M2O/selector",
|
||||
step: NewStep(
|
||||
From("pets", "id"),
|
||||
To("users", "id"),
|
||||
Edge(M2O, true, "pets", "owner_id"),
|
||||
),
|
||||
selector: sql.Dialect("postgres").Select("*").
|
||||
From(sql.Select("*").From(sql.Table("pets")).As("pets")).
|
||||
Where(sql.EQ("name", "pedro")).As("pets"),
|
||||
predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.EQ("last_name", "mashraki"))
|
||||
},
|
||||
wantQuery: `SELECT * FROM (SELECT * FROM "pets") AS "pets" WHERE "name" = $1 AND "pets"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "last_name" = $2)`,
|
||||
wantArgs: []interface{}{"pedro", "mashraki"},
|
||||
},
|
||||
{
|
||||
name: "M2M/selector",
|
||||
step: NewStep(
|
||||
From("users", "id"),
|
||||
To("groups", "id"),
|
||||
Edge(M2M, false, "user_groups", "user_id", "group_id"),
|
||||
),
|
||||
selector: sql.Dialect("postgres").Select("*").From(sql.Select("*").From(sql.Table("users")).As("users")).As("users"),
|
||||
predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.EQ("name", "GitHub"))
|
||||
},
|
||||
wantQuery: `SELECT * FROM (SELECT * FROM "users") AS "users" WHERE "users"."id" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t1" ON "user_groups"."group_id" = "t1"."id" WHERE "name" = $1)`,
|
||||
wantArgs: []interface{}{"GitHub"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user