mirror of
https://github.com/ent/ent.git
synced 2026-04-28 21:50:56 +03:00
dialect/sql/sqlgraph: add query edges function (#274)
This commit is contained in:
@@ -452,7 +452,7 @@ type QuerySpec struct {
|
||||
Assign func(...interface{}) error
|
||||
}
|
||||
|
||||
// QueryNodes query the nodes in the graph query and scans them to the given values.
|
||||
// QueryNodes queries the nodes in the graph query and scans them to the given values.
|
||||
func QueryNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) error {
|
||||
builder := sql.Dialect(drv.Dialect())
|
||||
qr := &query{graph: graph{builder: builder}, QuerySpec: spec}
|
||||
@@ -466,6 +466,48 @@ func CountNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) (int,
|
||||
return qr.count(ctx, drv)
|
||||
}
|
||||
|
||||
// EdgeQuerySpec holds the information for querying
|
||||
// edges in the graph.
|
||||
type EdgeQuerySpec struct {
|
||||
Edge *EdgeSpec
|
||||
Predicate func(*sql.Selector)
|
||||
ScanValues func() [2]interface{}
|
||||
Assign func(out, in interface{}) error
|
||||
}
|
||||
|
||||
// QueryEdges queries the edges in the graph and scans the result with the given dest function.
|
||||
func QueryEdges(ctx context.Context, drv dialect.Driver, spec *EdgeQuerySpec) error {
|
||||
if len(spec.Edge.Columns) != 2 {
|
||||
return fmt.Errorf("sqlgraph: edge query requires 2 columns (out, in)")
|
||||
}
|
||||
out, in := spec.Edge.Columns[0], spec.Edge.Columns[1]
|
||||
if spec.Edge.Inverse {
|
||||
out, in = in, out
|
||||
}
|
||||
selector := sql.Dialect(drv.Dialect()).
|
||||
Select(out, in).
|
||||
From(sql.Table(spec.Edge.Table))
|
||||
if p := spec.Predicate; p != nil {
|
||||
p(selector)
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := drv.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
values := spec.ScanValues()
|
||||
if err := rows.Scan(values[0], values[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := spec.Assign(values[0], values[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type query struct {
|
||||
graph
|
||||
*QuerySpec
|
||||
|
||||
@@ -1207,6 +1207,43 @@ func TestQueryNodes(t *testing.T) {
|
||||
require.Equal(t, 3, n)
|
||||
}
|
||||
|
||||
func TestQueryEdges(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
mock.ExpectQuery(escape("SELECT `group_id`, `user_id` FROM `user_groups` WHERE `user_id` IN (?, ?, ?)")).
|
||||
WithArgs(1, 2, 3).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"group_id", "user_id"}).
|
||||
AddRow(4, 5).
|
||||
AddRow(4, 6))
|
||||
|
||||
var (
|
||||
edges [][]int64
|
||||
spec = &EdgeQuerySpec{
|
||||
Edge: &EdgeSpec{
|
||||
Inverse: true,
|
||||
Table: "user_groups",
|
||||
Columns: []string{"user_id", "group_id"},
|
||||
},
|
||||
Predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.InValues("user_id", 1, 2, 3))
|
||||
},
|
||||
ScanValues: func() [2]interface{} {
|
||||
return [2]interface{}{&sql.NullInt64{}, &sql.NullInt64{}}
|
||||
},
|
||||
Assign: func(out, in interface{}) error {
|
||||
o, i := out.(*sql.NullInt64), in.(*sql.NullInt64)
|
||||
edges = append(edges, []int64{o.Int64, i.Int64})
|
||||
return nil
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// Query and scan.
|
||||
err = QueryEdges(context.Background(), sql.OpenDB("", db), spec)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]int64{{4, 5}, {4, 6}}, edges)
|
||||
}
|
||||
|
||||
func escape(query string) string {
|
||||
rows := strings.Split(query, "\n")
|
||||
for i := range rows {
|
||||
|
||||
Reference in New Issue
Block a user