diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index c94013272..f21673fa4 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -443,13 +443,20 @@ type QuerySpec struct { Assign func(...interface{}) error } -// QueryNodes query the nodes in the graph and scan them to the given values. +// QueryNodes query the nodes in the graph query and scans them to the given values. func QueryNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) error { builder := Dialect(drv.Dialect()) qr := &query{graph: graph{builder: builder}, QuerySpec: spec} return qr.nodes(ctx, drv) } +// CountNodes counts the nodes in the given graph query. +func CountNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) (int, error) { + builder := Dialect(drv.Dialect()) + qr := &query{graph: graph{builder: builder}, QuerySpec: spec} + return qr.count(ctx, drv) +} + type query struct { graph *QuerySpec @@ -476,7 +483,12 @@ func (q *query) nodes(ctx context.Context, drv dialect.Driver) error { func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) { rows := &Rows{} - query, args := q.selector().Count(q.Node.ID.Column).Query() + selector := q.selector().Count(q.Node.ID.Column) + if q.Unique { + selector.distinct = false + selector.Count(Distinct(q.Node.ID.Column)) + } + query, args := selector.Query() if err := drv.Query(ctx, query, args, rows); err != nil { return 0, err } diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index 4859b65e2..3297eb9f4 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -1159,35 +1159,50 @@ func TestQueryNodes(t *testing.T) { AddRow(1, 10, nil, nil, nil). AddRow(2, 20, "", 0, 0). AddRow(3, 30, "a8m", 1, 1)) - var users []*user - err = QueryNodes(context.Background(), OpenDB("", db), &QuerySpec{ - Node: &NodeSpec{ - Table: "users", - Columns: []string{"id", "age", "name", "fk1", "fk2"}, - ID: &FieldSpec{Column: "id", Type: field.TypeInt}, - }, - Limit: 3, - Offset: 4, - Unique: true, - Order: func(s *Selector) { - s.OrderBy("id") - }, - Predicate: func(s *Selector) { - s.Where(LT("age", 40)) - }, - ScanValues: func() []interface{} { - u := &user{} - users = append(users, u) - return append(u.values(), &NullInt64{}, &NullInt64{}) // extra values for fks. - }, - Assign: func(values ...interface{}) error { - return users[len(users)-1].assign(values...) - }, - }) + mock.ExpectQuery(escape("SELECT COUNT(DISTINCT `id`) FROM `users` WHERE `age` < ? ORDER BY `id` LIMIT ? OFFSET ?")). + WithArgs(40, 3, 4). + WillReturnRows(sqlmock.NewRows([]string{"COUNT"}). + AddRow(3)) + + var ( + users []*user + spec = &QuerySpec{ + Node: &NodeSpec{ + Table: "users", + Columns: []string{"id", "age", "name", "fk1", "fk2"}, + ID: &FieldSpec{Column: "id", Type: field.TypeInt}, + }, + Limit: 3, + Offset: 4, + Unique: true, + Order: func(s *Selector) { + s.OrderBy("id") + }, + Predicate: func(s *Selector) { + s.Where(LT("age", 40)) + }, + ScanValues: func() []interface{} { + u := &user{} + users = append(users, u) + return append(u.values(), &NullInt64{}, &NullInt64{}) // extra values for fks. + }, + Assign: func(values ...interface{}) error { + return users[len(users)-1].assign(values...) + }, + } + ) + + // Query and scan. + err = QueryNodes(context.Background(), OpenDB("", db), spec) require.NoError(t, err) require.Equal(t, &user{id: 1, age: 10, name: ""}, users[0]) require.Equal(t, &user{id: 2, age: 20, name: ""}, users[1]) require.Equal(t, &user{id: 3, age: 30, name: "a8m", edges: struct{ fk1, fk2 int }{1, 1}}, users[2]) + + // Count nodes. + n, err := CountNodes(context.Background(), OpenDB("", db), spec) + require.NoError(t, err) + require.Equal(t, 3, n) } func escape(query string) string {