mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql: move columns check from codegen to sql package (#3431)
This commit is contained in:
@@ -10,34 +10,29 @@ in the LICENSE file in the root directory of this source tree.
|
||||
// OrderFunc applies an ordering on the sql selector.
|
||||
type OrderFunc func(*sql.Selector)
|
||||
|
||||
// columnChecker returns a function indicates if the column exists in the given column.
|
||||
func columnChecker(table string) func(string) error {
|
||||
checks := map[string]func(string) bool{
|
||||
{{- range $n := $.Nodes }}
|
||||
{{ $n.Package }}.Table: {{ $n.Package }}.ValidColumn,
|
||||
{{- end }}
|
||||
}
|
||||
check, ok := checks[table]
|
||||
if !ok {
|
||||
return func(string) error {
|
||||
return fmt.Errorf("unknown table %q", table)
|
||||
}
|
||||
}
|
||||
return func(column string) error {
|
||||
if !check(column) {
|
||||
return fmt.Errorf("unknown column %q for table %q", column, table)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
initCheck sync.Once
|
||||
columnCheck sql.ColumnCheck
|
||||
)
|
||||
|
||||
// columnChecker checks if the column exists in the given table.
|
||||
func checkColumn(table, column string) error {
|
||||
initCheck.Do(func() {
|
||||
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
|
||||
{{- range $n := $.Nodes }}
|
||||
{{ $n.Package }}.Table: {{ $n.Package }}.ValidColumn,
|
||||
{{- end }}
|
||||
})
|
||||
})
|
||||
return columnCheck(table, column)
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
{{ define "dialect/sql/order/func" -}}
|
||||
{{- $f := $.Scope.Func -}}
|
||||
func(s *sql.Selector) {
|
||||
check := columnChecker(s.TableName())
|
||||
for _, f := range fields {
|
||||
if err := check(f); err != nil {
|
||||
if err := checkColumn(s.TableName(), f); err != nil {
|
||||
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("{{ base $.Config.Package }}: %w", err)})
|
||||
}
|
||||
s.OrderBy(sql.{{ $f }}(s.C(f)))
|
||||
@@ -61,8 +56,7 @@ func columnChecker(table string) func(string) error {
|
||||
{{- $withField := $.Scope.WithField -}}
|
||||
func(s *sql.Selector) string {
|
||||
{{- if $withField }}
|
||||
check := columnChecker(s.TableName())
|
||||
if err := check(field); err != nil {
|
||||
if err := checkColumn(s.TableName(), field); err != nil {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("{{ base $.Config.Package }}: %w", err)})
|
||||
return ""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user