entc/gen: add Aggregate to <T>Select and <T>Query

This commit is contained in:
Ariel Mashraki
2022-10-22 22:39:13 +03:00
committed by Ariel Mashraki
parent 1bc4d48a51
commit 765ec09d31
165 changed files with 2629 additions and 226 deletions

View File

@@ -336,6 +336,11 @@ func (cq *CityQuery) Select(fields ...string) *CitySelect {
return selbuild
}
// Aggregate returns a CitySelect configured with the given aggregations.
func (cq *CityQuery) Aggregate(fns ...AggregateFunc) *CitySelect {
return cq.Select().Aggregate(fns...)
}
func (cq *CityQuery) prepareQuery(ctx context.Context) error {
for _, f := range cq.fields {
if !city.ValidColumn(f) {
@@ -572,8 +577,6 @@ func (cgb *CityGroupBy) sqlQuery() *sql.Selector {
for _, fn := range cgb.fns {
aggregation = append(aggregation, fn(selector))
}
// If no columns were selected in a custom aggregation function, the default
// selection is the fields used for "group-by", and the aggregation functions.
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(cgb.fields)+len(cgb.fns))
for _, f := range cgb.fields {
@@ -593,6 +596,12 @@ type CitySelect struct {
sql *sql.Selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (cs *CitySelect) Aggregate(fns ...AggregateFunc) *CitySelect {
cs.fns = append(cs.fns, fns...)
return cs
}
// Scan applies the selector query and scans the result into the given value.
func (cs *CitySelect) Scan(ctx context.Context, v any) error {
if err := cs.prepareQuery(ctx); err != nil {
@@ -603,6 +612,16 @@ func (cs *CitySelect) Scan(ctx context.Context, v any) error {
}
func (cs *CitySelect) sqlScan(ctx context.Context, v any) error {
aggregation := make([]string, 0, len(cs.fns))
for _, fn := range cs.fns {
aggregation = append(aggregation, fn(cs.sql))
}
switch n := len(*cs.selector.flds); {
case n == 0 && len(aggregation) > 0:
cs.sql.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
cs.sql.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := cs.sql.Query()
if err := cs.driver.Query(ctx, query, args, rows); err != nil {

View File

@@ -269,6 +269,7 @@ func IsConstraintError(err error) bool {
type selector struct {
label string
flds *[]string
fns []AggregateFunc
scan func(context.Context, any) error
}

View File

@@ -336,6 +336,11 @@ func (sq *StreetQuery) Select(fields ...string) *StreetSelect {
return selbuild
}
// Aggregate returns a StreetSelect configured with the given aggregations.
func (sq *StreetQuery) Aggregate(fns ...AggregateFunc) *StreetSelect {
return sq.Select().Aggregate(fns...)
}
func (sq *StreetQuery) prepareQuery(ctx context.Context) error {
for _, f := range sq.fields {
if !street.ValidColumn(f) {
@@ -576,8 +581,6 @@ func (sgb *StreetGroupBy) sqlQuery() *sql.Selector {
for _, fn := range sgb.fns {
aggregation = append(aggregation, fn(selector))
}
// If no columns were selected in a custom aggregation function, the default
// selection is the fields used for "group-by", and the aggregation functions.
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(sgb.fields)+len(sgb.fns))
for _, f := range sgb.fields {
@@ -597,6 +600,12 @@ type StreetSelect struct {
sql *sql.Selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (ss *StreetSelect) Aggregate(fns ...AggregateFunc) *StreetSelect {
ss.fns = append(ss.fns, fns...)
return ss
}
// Scan applies the selector query and scans the result into the given value.
func (ss *StreetSelect) Scan(ctx context.Context, v any) error {
if err := ss.prepareQuery(ctx); err != nil {
@@ -607,6 +616,16 @@ func (ss *StreetSelect) Scan(ctx context.Context, v any) error {
}
func (ss *StreetSelect) sqlScan(ctx context.Context, v any) error {
aggregation := make([]string, 0, len(ss.fns))
for _, fn := range ss.fns {
aggregation = append(aggregation, fn(ss.sql))
}
switch n := len(*ss.selector.flds); {
case n == 0 && len(aggregation) > 0:
ss.sql.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
ss.sql.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := ss.sql.Query()
if err := ss.driver.Query(ctx, query, args, rows); err != nil {