mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
entc/gen: add Aggregate to <T>Select and <T>Query
This commit is contained in:
committed by
Ariel Mashraki
parent
1bc4d48a51
commit
765ec09d31
@@ -336,6 +336,11 @@ func (cq *CarQuery) Select(fields ...string) *CarSelect {
|
||||
return selbuild
|
||||
}
|
||||
|
||||
// Aggregate returns a CarSelect configured with the given aggregations.
|
||||
func (cq *CarQuery) Aggregate(fns ...AggregateFunc) *CarSelect {
|
||||
return cq.Select().Aggregate(fns...)
|
||||
}
|
||||
|
||||
func (cq *CarQuery) prepareQuery(ctx context.Context) error {
|
||||
for _, f := range cq.fields {
|
||||
if !car.ValidColumn(f) {
|
||||
@@ -576,8 +581,6 @@ func (cgb *CarGroupBy) sqlQuery() *sql.Selector {
|
||||
for _, fn := range cgb.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
// If no columns were selected in a custom aggregation function, the default
|
||||
// selection is the fields used for "group-by", and the aggregation functions.
|
||||
if len(selector.SelectedColumns()) == 0 {
|
||||
columns := make([]string, 0, len(cgb.fields)+len(cgb.fns))
|
||||
for _, f := range cgb.fields {
|
||||
@@ -597,6 +600,12 @@ type CarSelect struct {
|
||||
sql *sql.Selector
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the selector query.
|
||||
func (cs *CarSelect) Aggregate(fns ...AggregateFunc) *CarSelect {
|
||||
cs.fns = append(cs.fns, fns...)
|
||||
return cs
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (cs *CarSelect) Scan(ctx context.Context, v any) error {
|
||||
if err := cs.prepareQuery(ctx); err != nil {
|
||||
@@ -607,6 +616,16 @@ func (cs *CarSelect) Scan(ctx context.Context, v any) error {
|
||||
}
|
||||
|
||||
func (cs *CarSelect) sqlScan(ctx context.Context, v any) error {
|
||||
aggregation := make([]string, 0, len(cs.fns))
|
||||
for _, fn := range cs.fns {
|
||||
aggregation = append(aggregation, fn(cs.sql))
|
||||
}
|
||||
switch n := len(*cs.selector.flds); {
|
||||
case n == 0 && len(aggregation) > 0:
|
||||
cs.sql.Select(aggregation...)
|
||||
case n != 0 && len(aggregation) > 0:
|
||||
cs.sql.AppendSelect(aggregation...)
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := cs.sql.Query()
|
||||
if err := cs.driver.Query(ctx, query, args, rows); err != nil {
|
||||
|
||||
@@ -271,6 +271,7 @@ func IsConstraintError(err error) bool {
|
||||
type selector struct {
|
||||
label string
|
||||
flds *[]string
|
||||
fns []AggregateFunc
|
||||
scan func(context.Context, any) error
|
||||
}
|
||||
|
||||
|
||||
@@ -336,6 +336,11 @@ func (gq *GroupQuery) Select(fields ...string) *GroupSelect {
|
||||
return selbuild
|
||||
}
|
||||
|
||||
// Aggregate returns a GroupSelect configured with the given aggregations.
|
||||
func (gq *GroupQuery) Aggregate(fns ...AggregateFunc) *GroupSelect {
|
||||
return gq.Select().Aggregate(fns...)
|
||||
}
|
||||
|
||||
func (gq *GroupQuery) prepareQuery(ctx context.Context) error {
|
||||
for _, f := range gq.fields {
|
||||
if !group.ValidColumn(f) {
|
||||
@@ -599,8 +604,6 @@ func (ggb *GroupGroupBy) sqlQuery() *sql.Selector {
|
||||
for _, fn := range ggb.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
// If no columns were selected in a custom aggregation function, the default
|
||||
// selection is the fields used for "group-by", and the aggregation functions.
|
||||
if len(selector.SelectedColumns()) == 0 {
|
||||
columns := make([]string, 0, len(ggb.fields)+len(ggb.fns))
|
||||
for _, f := range ggb.fields {
|
||||
@@ -620,6 +623,12 @@ type GroupSelect struct {
|
||||
sql *sql.Selector
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the selector query.
|
||||
func (gs *GroupSelect) Aggregate(fns ...AggregateFunc) *GroupSelect {
|
||||
gs.fns = append(gs.fns, fns...)
|
||||
return gs
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (gs *GroupSelect) Scan(ctx context.Context, v any) error {
|
||||
if err := gs.prepareQuery(ctx); err != nil {
|
||||
@@ -630,6 +639,16 @@ func (gs *GroupSelect) Scan(ctx context.Context, v any) error {
|
||||
}
|
||||
|
||||
func (gs *GroupSelect) sqlScan(ctx context.Context, v any) error {
|
||||
aggregation := make([]string, 0, len(gs.fns))
|
||||
for _, fn := range gs.fns {
|
||||
aggregation = append(aggregation, fn(gs.sql))
|
||||
}
|
||||
switch n := len(*gs.selector.flds); {
|
||||
case n == 0 && len(aggregation) > 0:
|
||||
gs.sql.Select(aggregation...)
|
||||
case n != 0 && len(aggregation) > 0:
|
||||
gs.sql.AppendSelect(aggregation...)
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := gs.sql.Query()
|
||||
if err := gs.driver.Query(ctx, query, args, rows); err != nil {
|
||||
|
||||
@@ -372,6 +372,11 @@ func (uq *UserQuery) Select(fields ...string) *UserSelect {
|
||||
return selbuild
|
||||
}
|
||||
|
||||
// Aggregate returns a UserSelect configured with the given aggregations.
|
||||
func (uq *UserQuery) Aggregate(fns ...AggregateFunc) *UserSelect {
|
||||
return uq.Select().Aggregate(fns...)
|
||||
}
|
||||
|
||||
func (uq *UserQuery) prepareQuery(ctx context.Context) error {
|
||||
for _, f := range uq.fields {
|
||||
if !user.ValidColumn(f) {
|
||||
@@ -674,8 +679,6 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector {
|
||||
for _, fn := range ugb.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
// If no columns were selected in a custom aggregation function, the default
|
||||
// selection is the fields used for "group-by", and the aggregation functions.
|
||||
if len(selector.SelectedColumns()) == 0 {
|
||||
columns := make([]string, 0, len(ugb.fields)+len(ugb.fns))
|
||||
for _, f := range ugb.fields {
|
||||
@@ -695,6 +698,12 @@ type UserSelect struct {
|
||||
sql *sql.Selector
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the selector query.
|
||||
func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect {
|
||||
us.fns = append(us.fns, fns...)
|
||||
return us
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (us *UserSelect) Scan(ctx context.Context, v any) error {
|
||||
if err := us.prepareQuery(ctx); err != nil {
|
||||
@@ -705,6 +714,16 @@ func (us *UserSelect) Scan(ctx context.Context, v any) error {
|
||||
}
|
||||
|
||||
func (us *UserSelect) sqlScan(ctx context.Context, v any) error {
|
||||
aggregation := make([]string, 0, len(us.fns))
|
||||
for _, fn := range us.fns {
|
||||
aggregation = append(aggregation, fn(us.sql))
|
||||
}
|
||||
switch n := len(*us.selector.flds); {
|
||||
case n == 0 && len(aggregation) > 0:
|
||||
us.sql.Select(aggregation...)
|
||||
case n != 0 && len(aggregation) > 0:
|
||||
us.sql.AppendSelect(aggregation...)
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := us.sql.Query()
|
||||
if err := us.driver.Query(ctx, query, args, rows); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user