dialect/sql/sqlgraph: add query edges function (#274)

This commit is contained in:
Ariel Mashraki
2020-01-07 19:50:33 +02:00
committed by GitHub
parent 9cb0eb7467
commit b93958ebf4
2 changed files with 80 additions and 1 deletions

View File

@@ -452,7 +452,7 @@ type QuerySpec struct {
Assign func(...interface{}) error
}
// QueryNodes query the nodes in the graph query and scans them to the given values.
// QueryNodes queries the nodes in the graph query and scans them to the given values.
func QueryNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) error {
builder := sql.Dialect(drv.Dialect())
qr := &query{graph: graph{builder: builder}, QuerySpec: spec}
@@ -466,6 +466,48 @@ func CountNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) (int,
return qr.count(ctx, drv)
}
// EdgeQuerySpec holds the information for querying
// edges in the graph.
type EdgeQuerySpec struct {
Edge *EdgeSpec
Predicate func(*sql.Selector)
ScanValues func() [2]interface{}
Assign func(out, in interface{}) error
}
// QueryEdges queries the edges in the graph and scans the result with the given dest function.
func QueryEdges(ctx context.Context, drv dialect.Driver, spec *EdgeQuerySpec) error {
if len(spec.Edge.Columns) != 2 {
return fmt.Errorf("sqlgraph: edge query requires 2 columns (out, in)")
}
out, in := spec.Edge.Columns[0], spec.Edge.Columns[1]
if spec.Edge.Inverse {
out, in = in, out
}
selector := sql.Dialect(drv.Dialect()).
Select(out, in).
From(sql.Table(spec.Edge.Table))
if p := spec.Predicate; p != nil {
p(selector)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := drv.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
for rows.Next() {
values := spec.ScanValues()
if err := rows.Scan(values[0], values[1]); err != nil {
return err
}
if err := spec.Assign(values[0], values[1]); err != nil {
return err
}
}
return nil
}
type query struct {
graph
*QuerySpec

View File

@@ -1207,6 +1207,43 @@ func TestQueryNodes(t *testing.T) {
require.Equal(t, 3, n)
}
func TestQueryEdges(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
mock.ExpectQuery(escape("SELECT `group_id`, `user_id` FROM `user_groups` WHERE `user_id` IN (?, ?, ?)")).
WithArgs(1, 2, 3).
WillReturnRows(sqlmock.NewRows([]string{"group_id", "user_id"}).
AddRow(4, 5).
AddRow(4, 6))
var (
edges [][]int64
spec = &EdgeQuerySpec{
Edge: &EdgeSpec{
Inverse: true,
Table: "user_groups",
Columns: []string{"user_id", "group_id"},
},
Predicate: func(s *sql.Selector) {
s.Where(sql.InValues("user_id", 1, 2, 3))
},
ScanValues: func() [2]interface{} {
return [2]interface{}{&sql.NullInt64{}, &sql.NullInt64{}}
},
Assign: func(out, in interface{}) error {
o, i := out.(*sql.NullInt64), in.(*sql.NullInt64)
edges = append(edges, []int64{o.Int64, i.Int64})
return nil
},
}
)
// Query and scan.
err = QueryEdges(context.Background(), sql.OpenDB("", db), spec)
require.NoError(t, err)
require.Equal(t, [][]int64{{4, 5}, {4, 6}}, edges)
}
func escape(query string) string {
rows := strings.Split(query, "\n")
for i := range rows {