mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql: add method for finding selection occurrences in queries (#3473)
This commit is contained in:
@@ -2143,7 +2143,7 @@ type Selector struct {
|
||||
// generated code such as alternate table schemas.
|
||||
ctx context.Context
|
||||
as string
|
||||
selection []any
|
||||
selection []selection
|
||||
from []TableView
|
||||
joins []join
|
||||
where *Predicate
|
||||
@@ -2196,12 +2196,19 @@ func SelectExpr(exprs ...Querier) *Selector {
|
||||
return (&Selector{}).SelectExpr(exprs...)
|
||||
}
|
||||
|
||||
// selection represents a column or an expression selection.
|
||||
type selection struct {
|
||||
x Querier
|
||||
c string
|
||||
as string
|
||||
}
|
||||
|
||||
// Select changes the columns selection of the SELECT statement.
|
||||
// Empty selection means all columns *.
|
||||
func (s *Selector) Select(columns ...string) *Selector {
|
||||
s.selection = make([]any, len(columns))
|
||||
s.selection = make([]selection, len(columns))
|
||||
for i := range columns {
|
||||
s.selection[i] = columns[i]
|
||||
s.selection[i] = selection{c: columns[i]}
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -2209,31 +2216,23 @@ func (s *Selector) Select(columns ...string) *Selector {
|
||||
// AppendSelect appends additional columns to the SELECT statement.
|
||||
func (s *Selector) AppendSelect(columns ...string) *Selector {
|
||||
for i := range columns {
|
||||
s.selection = append(s.selection, columns[i])
|
||||
s.selection = append(s.selection, selection{c: columns[i]})
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// AppendSelectAs appends additional column to the SELECT statement with the given alias.
|
||||
func (s *Selector) AppendSelectAs(column, as string) *Selector {
|
||||
s.selection = append(s.selection, ExprFunc(func(b *Builder) {
|
||||
if b.isIdent(column) || isFunc(column) || isModifier(column) {
|
||||
b.WriteString(column)
|
||||
} else {
|
||||
b.WriteString(s.C(column))
|
||||
}
|
||||
b.WriteString(" AS ")
|
||||
b.Ident(as)
|
||||
}))
|
||||
s.selection = append(s.selection, selection{c: column, as: as})
|
||||
return s
|
||||
}
|
||||
|
||||
// SelectExpr changes the columns selection of the SELECT statement
|
||||
// with custom list of expressions.
|
||||
func (s *Selector) SelectExpr(exprs ...Querier) *Selector {
|
||||
s.selection = make([]any, len(exprs))
|
||||
s.selection = make([]selection, len(exprs))
|
||||
for i := range exprs {
|
||||
s.selection[i] = exprs[i]
|
||||
s.selection[i] = selection{x: exprs[i]}
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -2241,30 +2240,69 @@ func (s *Selector) SelectExpr(exprs ...Querier) *Selector {
|
||||
// AppendSelectExpr appends additional expressions to the SELECT statement.
|
||||
func (s *Selector) AppendSelectExpr(exprs ...Querier) *Selector {
|
||||
for i := range exprs {
|
||||
s.selection = append(s.selection, exprs[i])
|
||||
s.selection = append(s.selection, selection{x: exprs[i]})
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// AppendSelectExprAs appends additional expressions to the SELECT statement with the given name.
|
||||
func (s *Selector) AppendSelectExprAs(expr Querier, as string) *Selector {
|
||||
s.selection = append(s.selection, ExprFunc(func(b *Builder) {
|
||||
switch expr.(type) {
|
||||
case *raw:
|
||||
// Raw expressions are not wrapped in parentheses.
|
||||
b.Join(expr).S(" AS ").Ident(as)
|
||||
default:
|
||||
b.S("(").Join(expr).S(") AS ").Ident(as)
|
||||
}
|
||||
}))
|
||||
x := expr
|
||||
if _, ok := expr.(*raw); !ok {
|
||||
x = ExprFunc(func(b *Builder) {
|
||||
b.S("(").Join(expr).S(")")
|
||||
})
|
||||
}
|
||||
s.selection = append(s.selection, selection{
|
||||
x: x,
|
||||
as: as,
|
||||
})
|
||||
return s
|
||||
}
|
||||
|
||||
// FindSelection returns all occurrences in the selection that match the given column name.
|
||||
// For example, for column "a" the following match: a, "a", "t"."a", "t"."b" AS "a".
|
||||
func (s *Selector) FindSelection(name string) (matches []string) {
|
||||
matchC := func(qualified string) bool {
|
||||
switch ident, pg := s.isIdent(qualified), s.postgres(); {
|
||||
case !ident:
|
||||
if i := strings.IndexRune(qualified, '.'); i > 0 {
|
||||
return qualified[i+1:] == name
|
||||
}
|
||||
case ident && pg:
|
||||
if i := strings.Index(qualified, `"."`); i > 0 {
|
||||
return s.unquote(qualified[i+2:]) == name
|
||||
}
|
||||
case ident:
|
||||
if i := strings.Index(qualified, "`.`"); i > 0 {
|
||||
return s.unquote(qualified[i+2:]) == name
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
for _, c := range s.selection {
|
||||
switch {
|
||||
// Match aliases.
|
||||
case c.as != "":
|
||||
if ident := s.isIdent(c.as); !ident && c.as == name || ident && s.unquote(c.as) == name {
|
||||
matches = append(matches, c.as)
|
||||
}
|
||||
// Match qualified columns.
|
||||
case c.c != "" && s.isQualified(c.c) && matchC(c.c):
|
||||
matches = append(matches, c.c)
|
||||
// Match unqualified columns.
|
||||
case c.c != "" && (c.c == name || s.isIdent(c.c) && s.unquote(c.c) == name):
|
||||
matches = append(matches, c.c)
|
||||
}
|
||||
}
|
||||
return matches
|
||||
}
|
||||
|
||||
// SelectedColumns returns the selected columns in the Selector.
|
||||
func (s *Selector) SelectedColumns() []string {
|
||||
columns := make([]string, 0, len(s.selection))
|
||||
for i := range s.selection {
|
||||
if c, ok := s.selection[i].(string); ok {
|
||||
if c := s.selection[i].c; c != "" {
|
||||
columns = append(columns, c)
|
||||
}
|
||||
}
|
||||
@@ -2276,8 +2314,8 @@ func (s *Selector) SelectedColumns() []string {
|
||||
func (s *Selector) UnqualifiedColumns() []string {
|
||||
columns := make([]string, 0, len(s.selection))
|
||||
for i := range s.selection {
|
||||
c, ok := s.selection[i].(string)
|
||||
if !ok {
|
||||
c := s.selection[i].c
|
||||
if c == "" {
|
||||
continue
|
||||
}
|
||||
if s.isIdent(c) {
|
||||
@@ -2792,7 +2830,7 @@ func (s *Selector) Clone() *Selector {
|
||||
joins: append([]join{}, joins...),
|
||||
group: append([]string{}, s.group...),
|
||||
order: append([]any{}, s.order...),
|
||||
selection: append([]any{}, s.selection...),
|
||||
selection: append([]selection{}, s.selection...),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3038,15 +3076,19 @@ func joinReturning(columns []string, b *Builder) {
|
||||
}
|
||||
|
||||
func (s *Selector) joinSelect(b *Builder) {
|
||||
for i := range s.selection {
|
||||
for i, sc := range s.selection {
|
||||
if i > 0 {
|
||||
b.Comma()
|
||||
}
|
||||
switch s := s.selection[i].(type) {
|
||||
case string:
|
||||
b.Ident(s)
|
||||
case Querier:
|
||||
b.Join(s)
|
||||
switch {
|
||||
case sc.c != "":
|
||||
b.Ident(sc.c)
|
||||
case sc.x != nil:
|
||||
b.Join(sc.x)
|
||||
}
|
||||
if sc.as != "" {
|
||||
b.WriteString(" AS ")
|
||||
b.Ident(sc.as)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3727,6 +3769,18 @@ func (b *Builder) isIdent(s string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// unquote database identifiers.
|
||||
func (b *Builder) unquote(s string) string {
|
||||
switch pg := b.postgres(); {
|
||||
case len(s) < 2:
|
||||
case !pg && s[0] == '`' && s[len(s)-1] == '`', pg && s[0] == '"' && s[len(s)-1] == '"':
|
||||
if u, err := strconv.Unquote(s); err == nil {
|
||||
return u
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// isIdent reports if the given string is a qualified identifier.
|
||||
func (b *Builder) isQualified(s string) bool {
|
||||
ident, pg := b.isIdent(s), b.postgres()
|
||||
|
||||
Reference in New Issue
Block a user