mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
entc/gen: privatize table columns check
This commit is contained in:
committed by
Ariel Mashraki
parent
07570c5e3f
commit
f12ef91829
@@ -14,6 +14,7 @@ import (
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/examples/privacyadmin/ent/user"
|
||||
)
|
||||
|
||||
// ent aliases to avoid import conflicts in user's code.
|
||||
@@ -29,36 +30,55 @@ type (
|
||||
)
|
||||
|
||||
// OrderFunc applies an ordering on the sql selector.
|
||||
type OrderFunc func(*sql.Selector, func(string) bool)
|
||||
type OrderFunc func(*sql.Selector)
|
||||
|
||||
// columnChecker returns a function indicates if the column exists in the given column.
|
||||
func columnChecker(table string) func(string) error {
|
||||
checks := map[string]func(string) bool{
|
||||
user.Table: user.ValidColumn,
|
||||
}
|
||||
check, ok := checks[table]
|
||||
if !ok {
|
||||
return func(string) error {
|
||||
return fmt.Errorf("ent: unknown table %q", table)
|
||||
}
|
||||
}
|
||||
return func(column string) error {
|
||||
if !check(column) {
|
||||
return fmt.Errorf("ent: unknown column %q for table %q", column, table)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Asc applies the given fields in ASC order.
|
||||
func Asc(fields ...string) OrderFunc {
|
||||
return func(s *sql.Selector, check func(string) bool) {
|
||||
return func(s *sql.Selector) {
|
||||
check := columnChecker(s.TableName())
|
||||
for _, f := range fields {
|
||||
if check(f) {
|
||||
s.OrderBy(sql.Asc(f))
|
||||
} else {
|
||||
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)})
|
||||
if err := check(f); err != nil {
|
||||
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)})
|
||||
}
|
||||
s.OrderBy(sql.Asc(s.C(f)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Desc applies the given fields in DESC order.
|
||||
func Desc(fields ...string) OrderFunc {
|
||||
return func(s *sql.Selector, check func(string) bool) {
|
||||
return func(s *sql.Selector) {
|
||||
check := columnChecker(s.TableName())
|
||||
for _, f := range fields {
|
||||
if check(f) {
|
||||
s.OrderBy(sql.Desc(f))
|
||||
} else {
|
||||
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)})
|
||||
if err := check(f); err != nil {
|
||||
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)})
|
||||
}
|
||||
s.OrderBy(sql.Desc(s.C(f)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AggregateFunc applies an aggregation step on the group-by traversal/selector.
|
||||
type AggregateFunc func(*sql.Selector, func(string) bool) string
|
||||
type AggregateFunc func(*sql.Selector) string
|
||||
|
||||
// As is a pseudo aggregation function for renaming another other functions with custom names. For example:
|
||||
//
|
||||
@@ -67,23 +87,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string
|
||||
// Scan(ctx, &v)
|
||||
//
|
||||
func As(fn AggregateFunc, end string) AggregateFunc {
|
||||
return func(s *sql.Selector, check func(string) bool) string {
|
||||
return sql.As(fn(s, check), end)
|
||||
return func(s *sql.Selector) string {
|
||||
return sql.As(fn(s), end)
|
||||
}
|
||||
}
|
||||
|
||||
// Count applies the "count" aggregation function on each group.
|
||||
func Count() AggregateFunc {
|
||||
return func(s *sql.Selector, _ func(string) bool) string {
|
||||
return func(s *sql.Selector) string {
|
||||
return sql.Count("*")
|
||||
}
|
||||
}
|
||||
|
||||
// Max applies the "max" aggregation function on the given field of each group.
|
||||
func Max(field string) AggregateFunc {
|
||||
return func(s *sql.Selector, check func(string) bool) string {
|
||||
if !check(field) {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)})
|
||||
return func(s *sql.Selector) string {
|
||||
check := columnChecker(s.TableName())
|
||||
if err := check(field); err != nil {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)})
|
||||
return ""
|
||||
}
|
||||
return sql.Max(s.C(field))
|
||||
@@ -92,9 +113,10 @@ func Max(field string) AggregateFunc {
|
||||
|
||||
// Mean applies the "mean" aggregation function on the given field of each group.
|
||||
func Mean(field string) AggregateFunc {
|
||||
return func(s *sql.Selector, check func(string) bool) string {
|
||||
if !check(field) {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)})
|
||||
return func(s *sql.Selector) string {
|
||||
check := columnChecker(s.TableName())
|
||||
if err := check(field); err != nil {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)})
|
||||
return ""
|
||||
}
|
||||
return sql.Avg(s.C(field))
|
||||
@@ -103,9 +125,10 @@ func Mean(field string) AggregateFunc {
|
||||
|
||||
// Min applies the "min" aggregation function on the given field of each group.
|
||||
func Min(field string) AggregateFunc {
|
||||
return func(s *sql.Selector, check func(string) bool) string {
|
||||
if !check(field) {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)})
|
||||
return func(s *sql.Selector) string {
|
||||
check := columnChecker(s.TableName())
|
||||
if err := check(field); err != nil {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)})
|
||||
return ""
|
||||
}
|
||||
return sql.Min(s.C(field))
|
||||
@@ -114,9 +137,10 @@ func Min(field string) AggregateFunc {
|
||||
|
||||
// Sum applies the "sum" aggregation function on the given field of each group.
|
||||
func Sum(field string) AggregateFunc {
|
||||
return func(s *sql.Selector, check func(string) bool) string {
|
||||
if !check(field) {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)})
|
||||
return func(s *sql.Selector) string {
|
||||
check := columnChecker(s.TableName())
|
||||
if err := check(field); err != nil {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)})
|
||||
return ""
|
||||
}
|
||||
return sql.Sum(s.C(field))
|
||||
|
||||
@@ -398,7 +398,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec {
|
||||
if ps := uq.order; len(ps) > 0 {
|
||||
_spec.Order = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector, user.ValidColumn)
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -417,7 +417,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
p(selector)
|
||||
}
|
||||
for _, p := range uq.order {
|
||||
p(selector, user.ValidColumn)
|
||||
p(selector)
|
||||
}
|
||||
if offset := uq.offset; offset != nil {
|
||||
// limit is mandatory for offset clause. We start
|
||||
@@ -683,7 +683,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector {
|
||||
columns := make([]string, 0, len(ugb.fields)+len(ugb.fns))
|
||||
columns = append(columns, ugb.fields...)
|
||||
for _, fn := range ugb.fns {
|
||||
columns = append(columns, fn(selector, user.ValidColumn))
|
||||
columns = append(columns, fn(selector))
|
||||
}
|
||||
return selector.Select(columns...).GroupBy(ugb.fields...)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user