dialect/sql/sqlgraph: add tests for nodes count

This commit is contained in:
Ariel Mashraki
2019-12-12 17:22:36 +02:00
parent 26ccecdf6d
commit fb4e1b6234
2 changed files with 54 additions and 27 deletions

View File

@@ -443,13 +443,20 @@ type QuerySpec struct {
Assign func(...interface{}) error
}
// QueryNodes query the nodes in the graph and scan them to the given values.
// QueryNodes query the nodes in the graph query and scans them to the given values.
func QueryNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) error {
builder := Dialect(drv.Dialect())
qr := &query{graph: graph{builder: builder}, QuerySpec: spec}
return qr.nodes(ctx, drv)
}
// CountNodes counts the nodes in the given graph query.
func CountNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) (int, error) {
builder := Dialect(drv.Dialect())
qr := &query{graph: graph{builder: builder}, QuerySpec: spec}
return qr.count(ctx, drv)
}
type query struct {
graph
*QuerySpec
@@ -476,7 +483,12 @@ func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) {
rows := &Rows{}
query, args := q.selector().Count(q.Node.ID.Column).Query()
selector := q.selector().Count(q.Node.ID.Column)
if q.Unique {
selector.distinct = false
selector.Count(Distinct(q.Node.ID.Column))
}
query, args := selector.Query()
if err := drv.Query(ctx, query, args, rows); err != nil {
return 0, err
}

View File

@@ -1159,35 +1159,50 @@ func TestQueryNodes(t *testing.T) {
AddRow(1, 10, nil, nil, nil).
AddRow(2, 20, "", 0, 0).
AddRow(3, 30, "a8m", 1, 1))
var users []*user
err = QueryNodes(context.Background(), OpenDB("", db), &QuerySpec{
Node: &NodeSpec{
Table: "users",
Columns: []string{"id", "age", "name", "fk1", "fk2"},
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
},
Limit: 3,
Offset: 4,
Unique: true,
Order: func(s *Selector) {
s.OrderBy("id")
},
Predicate: func(s *Selector) {
s.Where(LT("age", 40))
},
ScanValues: func() []interface{} {
u := &user{}
users = append(users, u)
return append(u.values(), &NullInt64{}, &NullInt64{}) // extra values for fks.
},
Assign: func(values ...interface{}) error {
return users[len(users)-1].assign(values...)
},
})
mock.ExpectQuery(escape("SELECT COUNT(DISTINCT `id`) FROM `users` WHERE `age` < ? ORDER BY `id` LIMIT ? OFFSET ?")).
WithArgs(40, 3, 4).
WillReturnRows(sqlmock.NewRows([]string{"COUNT"}).
AddRow(3))
var (
users []*user
spec = &QuerySpec{
Node: &NodeSpec{
Table: "users",
Columns: []string{"id", "age", "name", "fk1", "fk2"},
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
},
Limit: 3,
Offset: 4,
Unique: true,
Order: func(s *Selector) {
s.OrderBy("id")
},
Predicate: func(s *Selector) {
s.Where(LT("age", 40))
},
ScanValues: func() []interface{} {
u := &user{}
users = append(users, u)
return append(u.values(), &NullInt64{}, &NullInt64{}) // extra values for fks.
},
Assign: func(values ...interface{}) error {
return users[len(users)-1].assign(values...)
},
}
)
// Query and scan.
err = QueryNodes(context.Background(), OpenDB("", db), spec)
require.NoError(t, err)
require.Equal(t, &user{id: 1, age: 10, name: ""}, users[0])
require.Equal(t, &user{id: 2, age: 20, name: ""}, users[1])
require.Equal(t, &user{id: 3, age: 30, name: "a8m", edges: struct{ fk1, fk2 int }{1, 1}}, users[2])
// Count nodes.
n, err := CountNodes(context.Background(), OpenDB("", db), spec)
require.NoError(t, err)
require.Equal(t, 3, n)
}
func escape(query string) string {