mirror of
https://github.com/ent/ent.git
synced 2026-05-28 09:49:08 +03:00
entc/gen: privatize table columns check
This commit is contained in:
committed by
Ariel Mashraki
parent
07570c5e3f
commit
f12ef91829
@@ -5,41 +5,63 @@ in the LICENSE file in the root directory of this source tree.
|
||||
*/}}
|
||||
|
||||
{{ define "dialect/sql/order/signature" -}}
|
||||
// OrderFunc applies an ordering on the sql selector.
|
||||
type OrderFunc func(*sql.Selector, func(string) bool)
|
||||
// OrderFunc applies an ordering on the sql selector.
|
||||
type OrderFunc func(*sql.Selector)
|
||||
|
||||
// columnChecker returns a function indicates if the column exists in the given column.
|
||||
func columnChecker(table string) func(string) error {
|
||||
checks := map[string]func(string) bool{
|
||||
{{- range $n := $.Nodes }}
|
||||
{{ $n.Package }}.Table: {{ $n.Package }}.ValidColumn,
|
||||
{{- end }}
|
||||
}
|
||||
check, ok := checks[table]
|
||||
if !ok {
|
||||
return func(string) error {
|
||||
return fmt.Errorf("{{ base $.Config.Package }}: unknown table %q", table)
|
||||
}
|
||||
}
|
||||
return func(column string) error {
|
||||
if !check(column) {
|
||||
return fmt.Errorf("{{ base $.Config.Package }}: unknown column %q for table %q", column, table)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
{{ define "dialect/sql/order/func" -}}
|
||||
{{- $f := $.Scope.Func -}}
|
||||
func(s *sql.Selector, check func(string) bool) {
|
||||
func(s *sql.Selector) {
|
||||
check := columnChecker(s.TableName())
|
||||
for _, f := range fields {
|
||||
if check(f) {
|
||||
s.OrderBy(sql.{{ $f }}(f))
|
||||
} else {
|
||||
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)})
|
||||
if err := check(f); err != nil {
|
||||
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)})
|
||||
}
|
||||
s.OrderBy(sql.{{ $f }}(s.C(f)))
|
||||
}
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
{{/* custom signature for group-by function */}}
|
||||
{{ define "dialect/sql/group/signature" -}}
|
||||
type AggregateFunc func(*sql.Selector, func(string) bool) string
|
||||
type AggregateFunc func(*sql.Selector) string
|
||||
{{- end }}
|
||||
|
||||
{{ define "dialect/sql/group/as" -}}
|
||||
func(s *sql.Selector, check func(string) bool) string {
|
||||
return sql.As(fn(s, check), end)
|
||||
func(s *sql.Selector) string {
|
||||
return sql.As(fn(s), end)
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
{{ define "dialect/sql/group/func" -}}
|
||||
{{- $fn := $.Scope.Func -}}
|
||||
{{- $withField := $.Scope.WithField -}}
|
||||
func(s *sql.Selector, {{ if $withField }}check{{ else }}_{{ end }} func(string) bool) string {
|
||||
func(s *sql.Selector) string {
|
||||
{{- if $withField }}
|
||||
if !check(field) {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)})
|
||||
check := columnChecker(s.TableName())
|
||||
if err := check(field); err != nil {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)})
|
||||
return ""
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
@@ -33,7 +33,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery() *sql.Selector {
|
||||
columns := make([]string, 0, len({{ $receiver }}.fields) + len({{ $receiver}}.fns))
|
||||
columns = append(columns, {{ $receiver }}.fields...)
|
||||
for _, fn := range {{ $receiver }}.fns {
|
||||
columns = append(columns, fn(selector, {{ $.Package }}.ValidColumn))
|
||||
columns = append(columns, fn(selector))
|
||||
}
|
||||
return selector.Select(columns...).GroupBy({{ $receiver }}.fields...)
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec {
|
||||
if ps := {{ $receiver }}.order; len(ps) > 0 {
|
||||
_spec.Order = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector, {{ $.Package }}.ValidColumn)
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -177,7 +177,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
p(selector)
|
||||
}
|
||||
for _, p := range {{ $receiver }}.order {
|
||||
p(selector, {{ $.Package }}.ValidColumn)
|
||||
p(selector)
|
||||
}
|
||||
if offset := {{ $receiver }}.offset; offset != nil {
|
||||
// limit is mandatory for offset clause. We start
|
||||
|
||||
Reference in New Issue
Block a user