diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index 51e58394e..0b983496f 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -142,3 +142,37 @@ func SetNeighbors(dialect string, s *Step) (q *Selector) { } return q } + +// HasNeighbors applies on the given Selector a neighbors check. +func HasNeighbors(q *Selector, s *Step) { + builder := Dialect(q.dialect) + switch r := s.Edge.Rel; { + case r == M2M: + pk1 := s.Edge.Columns[1] + if s.Edge.Inverse { + pk1 = s.Edge.Columns[0] + } + from := q.Table() + join := builder.Table(s.Edge.Table) + q.Where( + In( + from.C(s.From.Column), + builder.Select(join.C(pk1)).From(join), + ), + ) + case r == M2O || (r == O2O && s.Edge.Inverse): + from := q.Table() + q.Where(NotNull(from.C(s.Edge.Columns[0]))) + case r == O2M || (r == O2O && !s.Edge.Inverse): + from := q.Table() + to := builder.Table(s.Edge.Table) + q.Where( + In( + from.C(s.From.Column), + builder.Select(to.C(s.Edge.Columns[0])). + From(to). + Where(NotNull(to.C(s.Edge.Columns[0]))), + ), + ) + } +} diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index 8436c4cd2..11183ad6b 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -311,3 +311,61 @@ JOIN }) } } + +func TestHasNeighbors(t *testing.T) { + tests := []struct { + name string + step *Step + selector *Selector + wantQuery string + }{ + { + name: "O2O/1type", + step: func() *Step { + // A nodes table; linked-list (next->prev). The "prev" + // node holds association pointer. The neighbors query + // here checks if a node "has-next". + step := &Step{} + step.From.V = 1 + step.From.Table = "nodes" + step.From.Column = "id" + step.To.Table = "nodes" + step.To.Column = "id" + step.Edge.Rel = O2O + step.Edge.Table = "nodes" + step.Edge.Columns = []string{"prev_id"} + return step + }(), + selector: Select("*").From(Table("nodes")), + wantQuery: "SELECT * FROM `nodes` WHERE `nodes`.`id` IN (SELECT `nodes`.`prev_id` FROM `nodes` WHERE `nodes`.`prev_id` IS NOT NULL)", + }, + { + name: "O2O/1type/inverse", + step: func() *Step { + // Same example as above, but the neighbors + // query checks if a node "has-previous". + step := &Step{} + step.From.V = 1 + step.From.Table = "nodes" + step.From.Column = "id" + step.To.Table = "nodes" + step.To.Column = "id" + step.Edge.Rel = O2O + step.Edge.Inverse = true + step.Edge.Table = "nodes" + step.Edge.Columns = []string{"prev_id"} + return step + }(), + selector: Select("*").From(Table("nodes")), + wantQuery: "SELECT * FROM `nodes` WHERE `nodes`.`prev_id` IS NOT NULL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + HasNeighbors(tt.selector, tt.step) + query, args := tt.selector.Query() + require.Equal(t, tt.wantQuery, query) + require.Empty(t, args) + }) + } +}