diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index 568f7b4b7..0e4d660d0 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -68,7 +68,7 @@ type Step struct { } // Neighbors returns a Selector for evaluating the path-step -// and getting the neighbors of one or more vertices. +// and getting the neighbors of one vertex. func Neighbors(dialect string, s *Step) (q *Selector) { builder := Dialect(dialect) switch r := s.Edge.Rel; { @@ -102,3 +102,43 @@ func Neighbors(dialect string, s *Step) (q *Selector) { } return q } + +// SetNeighbors returns a Selector for evaluating the path-step +// and getting the neighbors of set of vertices. +func SetNeighbors(dialect string, s *Step) (q *Selector) { + set := s.From.V.(*Selector) + builder := Dialect(dialect) + switch r := s.Edge.Rel; { + case r == M2M: + pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0] + if s.Edge.Inverse { + pk1, pk2 = pk2, pk1 + } + to := builder.Table(s.To.Table) + set.Select(s.From.Column) + join := builder.Table(s.Edge.Table) + match := builder.Select(join.C(pk1)). + From(join). + Join(set). + On(join.C(pk2), set.C(s.From.Column)) + q = builder.Select(). + From(to). + Join(match). + On(to.C(s.To.Column), match.C(pk1)) + case r == M2O || (r == O2O && s.Edge.Inverse): + t1 := builder.Table(s.To.Table) + set.Select(s.Edge.Columns[0]) + q = builder.Select(). + From(t1). + Join(set). + On(t1.C(s.To.Column), set.C(s.Edge.Columns[0])) + case r == O2M || (r == O2O && !s.Edge.Inverse): + t1 := builder.Table(s.To.Table) + set.Select(s.From.Column) + q = builder.Select(). + From(t1). + Join(set). + On(t1.C(s.Edge.Columns[0]), set.C(s.From.Column)) + } + return q +} diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index 1a028add8..94e74981a 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -5,6 +5,7 @@ package sql import ( + "strings" "testing" "github.com/stretchr/testify/require" @@ -204,3 +205,109 @@ func TestNeighbors(t *testing.T) { }) } } + +func TestSetNeighbors(t *testing.T) { + tests := []struct { + name string + input *Step + wantQuery string + wantArgs []interface{} + }{ + { + name: "O2M/2types", + input: func() *Step { + step := &Step{} + step.From.V = Select().From(Table("users")).Where(EQ("name", "a8m")) + step.From.Table = "users" + step.From.Column = "id" + step.To.Table = "pets" + step.To.Column = "id" + step.Edge.Rel = O2M + step.Edge.Table = "pets" + step.Edge.Columns = []string{"owner_id"} + return step + }(), + wantQuery: `SELECT * FROM "pets" JOIN (SELECT "id" FROM "users" WHERE "name" = $1) AS "t1" ON "pets"."owner_id" = "t1"."id"`, + wantArgs: []interface{}{"a8m"}, + }, + { + name: "M2O/2types", + input: func() *Step { + step := &Step{} + step.From.V = Select().From(Table("pets")).Where(EQ("name", "pedro")) + step.From.Table = "pets" + step.From.Column = "id" + step.To.Table = "users" + step.To.Column = "id" + step.Edge.Rel = M2O + step.Edge.Table = "pets" + step.Edge.Columns = []string{"owner_id"} + return step + }(), + wantQuery: `SELECT * FROM "users" JOIN (SELECT "owner_id" FROM "pets" WHERE "name" = $1) AS "t1" ON "users"."id" = "t1"."owner_id"`, + wantArgs: []interface{}{"pedro"}, + }, + { + name: "M2M/2types", + input: func() *Step { + step := &Step{} + step.From.V = Select().From(Table("users")).Where(EQ("name", "a8m")) + step.From.Table = "users" + step.From.Column = "id" + step.To.Table = "groups" + step.To.Column = "id" + step.Edge.Rel = M2M + step.Edge.Table = "user_groups" + step.Edge.Columns = []string{"user_id", "group_id"} + return step + }(), + wantQuery: ` +SELECT * +FROM "groups" +JOIN + (SELECT "user_groups"."group_id" + FROM "user_groups" + JOIN + (SELECT "id" + FROM "users" + WHERE "name" = $1) AS "t1" ON "user_groups"."user_id" = "t1"."id") AS "t1" ON "groups"."id" = "t1"."group_id"`, + wantArgs: []interface{}{"a8m"}, + }, + { + name: "M2M/2types/inverse", + input: func() *Step { + step := &Step{} + step.From.V = Select().From(Table("groups")).Where(EQ("name", "GitHub")) + step.From.Table = "groups" + step.From.Column = "id" + step.To.Table = "users" + step.To.Column = "id" + step.Edge.Rel = M2M + step.Edge.Inverse = true + step.Edge.Table = "user_groups" + step.Edge.Columns = []string{"user_id", "group_id"} + return step + }(), + wantQuery: ` +SELECT * +FROM "users" +JOIN + (SELECT "user_groups"."user_id" + FROM "user_groups" + JOIN + (SELECT "id" + FROM "groups" + WHERE "name" = $1) AS "t1" ON "user_groups"."group_id" = "t1"."id") AS "t1" ON "users"."id" = "t1"."user_id"`, + wantArgs: []interface{}{"GitHub"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selector := SetNeighbors("postgres", tt.input) + query, args := selector.Query() + tt.wantQuery = strings.Join(strings.Fields(tt.wantQuery), " ") + require.Equal(t, tt.wantQuery, query) + require.Equal(t, tt.wantArgs, args) + }) + } +}