diff --git a/dialect/sql/sqlgraph/entql_test.go b/dialect/sql/sqlgraph/entql_test.go index aba4a19bc..d1693f758 100644 --- a/dialect/sql/sqlgraph/entql_test.go +++ b/dialect/sql/sqlgraph/entql_test.go @@ -136,7 +136,7 @@ func TestGraph_EvalP(t *testing.T) { { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.HasEdgeWith("pets", entql.Or(entql.FieldEQ("name", "pedro"), entql.FieldEQ("name", "xabi"))), - wantQuery: `SELECT * FROM "users" WHERE "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."name" = $1 OR "pets"."name" = $2)`, + wantQuery: `SELECT * FROM "users" WHERE EXISTS (SELECT "pets"."owner_id" FROM "pets" WHERE "users"."uid" = "pets"."owner_id" AND ("pets"."name" = $1 OR "pets"."name" = $2))`, wantArgs: []any{"pedro", "xabi"}, }, { @@ -155,7 +155,7 @@ func TestGraph_EvalP(t *testing.T) { p: entql.HasEdgeWith("pets", entql.FieldEQ("name", "pedro"), WrapFunc(func(s *sql.Selector) { s.Where(sql.EQ("owner_id", 10)) })), - wantQuery: `SELECT * FROM "users" WHERE "active" AND "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."name" = $1 AND "owner_id" = $2)`, + wantQuery: `SELECT * FROM "users" WHERE "active" AND EXISTS (SELECT "pets"."owner_id" FROM "pets" WHERE ("users"."uid" = "pets"."owner_id" AND "pets"."name" = $1) AND "owner_id" = $2)`, wantArgs: []any{"pedro", 10}, }, } diff --git a/dialect/sql/sqlgraph/graph.go b/dialect/sql/sqlgraph/graph.go index c1a3a6579..c41c6b9ca 100644 --- a/dialect/sql/sqlgraph/graph.go +++ b/dialect/sql/sqlgraph/graph.go @@ -299,18 +299,56 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) { q.Where(sql.In(q.C(s.From.Column), join)) case s.FromEdgeOwner(): to := builder.Table(s.To.Table).Schema(s.To.Schema) + // Avoid ambiguity in case both source + // and edge tables are the same. + if s.To.Table == q.TableName() { + to.As(fmt.Sprintf("%s_edge", s.To.Table)) + // Choose the alias name until we do not + // have a collision. Limit to 5 iterations. + for i := 1; i <= 5; i++ { + if to.C("c") != q.C("c") { + break + } + to.As(fmt.Sprintf("%s_edge_%d", s.To.Table, i)) + } + } matches := builder.Select(to.C(s.To.Column)). From(to) matches.WithContext(q.Context()) + matches.Where( + sql.ColumnsEQ( + q.C(s.Edge.Columns[0]), + to.C(s.To.Column), + ), + ) pred(matches) - q.Where(sql.In(q.C(s.Edge.Columns[0]), matches)) + q.Where(sql.Exists(matches)) case s.ToEdgeOwner(): to := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) + // Avoid ambiguity in case both source + // and edge tables are the same. + if s.Edge.Table == q.TableName() { + to.As(fmt.Sprintf("%s_edge", s.Edge.Table)) + // Choose the alias name until we do not + // have a collision. Limit to 5 iterations. + for i := 1; i <= 5; i++ { + if to.C("c") != q.C("c") { + break + } + to.As(fmt.Sprintf("%s_edge_%d", s.Edge.Table, i)) + } + } matches := builder.Select(to.C(s.Edge.Columns[0])). From(to) matches.WithContext(q.Context()) + matches.Where( + sql.ColumnsEQ( + q.C(s.From.Column), + to.C(s.Edge.Columns[0]), + ), + ) pred(matches) - q.Where(sql.In(q.C(s.From.Column), matches)) + q.Where(sql.Exists(matches)) } } diff --git a/dialect/sql/sqlgraph/graph_test.go b/dialect/sql/sqlgraph/graph_test.go index 1cb2946d8..eed87d1eb 100644 --- a/dialect/sql/sqlgraph/graph_test.go +++ b/dialect/sql/sqlgraph/graph_test.go @@ -655,7 +655,7 @@ func TestHasNeighborsWith(t *testing.T) { predicate: func(s *sql.Selector) { s.Where(sql.EQ("expired", false)) }, - wantQuery: `SELECT * FROM "users" WHERE "users"."id" IN (SELECT "cards"."owner_id" FROM "cards" WHERE NOT "expired")`, + wantQuery: `SELECT * FROM "users" WHERE EXISTS (SELECT "cards"."owner_id" FROM "cards" WHERE "users"."id" = "cards"."owner_id" AND NOT "expired")`, }, { name: "O2O/inverse", @@ -668,7 +668,7 @@ func TestHasNeighborsWith(t *testing.T) { predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "a8m")) }, - wantQuery: `SELECT * FROM "cards" WHERE "cards"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "name" = $1)`, + wantQuery: `SELECT * FROM "cards" WHERE EXISTS (SELECT "users"."id" FROM "users" WHERE "cards"."owner_id" = "users"."id" AND "name" = $1)`, wantArgs: []any{"a8m"}, }, { @@ -684,7 +684,7 @@ func TestHasNeighborsWith(t *testing.T) { predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "pedro")) }, - wantQuery: `SELECT * FROM "users" WHERE "last_name" = $1 AND "users"."id" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $2)`, + wantQuery: `SELECT * FROM "users" WHERE "last_name" = $1 AND EXISTS (SELECT "pets"."owner_id" FROM "pets" WHERE "users"."id" = "pets"."owner_id" AND "name" = $2)`, wantArgs: []any{"mashraki", "pedro"}, }, { @@ -700,7 +700,7 @@ func TestHasNeighborsWith(t *testing.T) { predicate: func(s *sql.Selector) { s.Where(sql.EQ("last_name", "mashraki")) }, - wantQuery: `SELECT * FROM "pets" WHERE "name" = $1 AND "pets"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "last_name" = $2)`, + wantQuery: `SELECT * FROM "pets" WHERE "name" = $1 AND EXISTS (SELECT "users"."id" FROM "users" WHERE "pets"."owner_id" = "users"."id" AND "last_name" = $2)`, wantArgs: []any{"pedro", "mashraki"}, }, { @@ -778,7 +778,7 @@ WHERE "groups"."id" IN predicate: func(s *sql.Selector) { s.Where(sql.EQ("expired", false)) }, - wantQuery: `SELECT * FROM "s1"."users" WHERE "s1"."users"."id" IN (SELECT "s2"."cards"."owner_id" FROM "s2"."cards" WHERE NOT "expired")`, + wantQuery: `SELECT * FROM "s1"."users" WHERE EXISTS (SELECT "s2"."cards"."owner_id" FROM "s2"."cards" WHERE "s1"."users"."id" = "s2"."cards"."owner_id" AND NOT "expired")`, }, { name: "schema/O2M", @@ -797,7 +797,7 @@ WHERE "groups"."id" IN predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "pedro")) }, - wantQuery: `SELECT * FROM "s1"."users" WHERE "last_name" = $1 AND "s1"."users"."id" IN (SELECT "s2"."pets"."owner_id" FROM "s2"."pets" WHERE "name" = $2)`, + wantQuery: `SELECT * FROM "s1"."users" WHERE "last_name" = $1 AND EXISTS (SELECT "s2"."pets"."owner_id" FROM "s2"."pets" WHERE "s1"."users"."id" = "s2"."pets"."owner_id" AND "name" = $2)`, wantArgs: []any{"mashraki", "pedro"}, }, { @@ -838,7 +838,7 @@ WHERE "s1"."users"."id" IN predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "pedro")) }, - wantQuery: `SELECT * FROM (SELECT * FROM "users") AS "users" WHERE "last_name" = $1 AND "users"."id" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $2)`, + wantQuery: `SELECT * FROM (SELECT * FROM "users") AS "users" WHERE "last_name" = $1 AND EXISTS (SELECT "pets"."owner_id" FROM "pets" WHERE "users"."id" = "pets"."owner_id" AND "name" = $2)`, wantArgs: []any{"mashraki", "pedro"}, }, { @@ -854,7 +854,7 @@ WHERE "s1"."users"."id" IN predicate: func(s *sql.Selector) { s.Where(sql.EQ("last_name", "mashraki")) }, - wantQuery: `SELECT * FROM (SELECT * FROM "pets") AS "pets" WHERE "name" = $1 AND "pets"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "last_name" = $2)`, + wantQuery: `SELECT * FROM (SELECT * FROM "pets") AS "pets" WHERE "name" = $1 AND EXISTS (SELECT "users"."id" FROM "users" WHERE "pets"."owner_id" = "users"."id" AND "last_name" = $2)`, wantArgs: []any{"pedro", "mashraki"}, }, {