dialect/sql: add method for finding selection occurrences in queries (#3473)

This commit is contained in:
Ariel Mashraki
2023-04-12 23:22:52 +03:00
committed by GitHub
parent 44b8648720
commit 8cb27bc7cf
3 changed files with 167 additions and 64 deletions

View File

@@ -2143,7 +2143,7 @@ type Selector struct {
// generated code such as alternate table schemas.
ctx context.Context
as string
selection []any
selection []selection
from []TableView
joins []join
where *Predicate
@@ -2196,12 +2196,19 @@ func SelectExpr(exprs ...Querier) *Selector {
return (&Selector{}).SelectExpr(exprs...)
}
// selection represents a column or an expression selection.
type selection struct {
x Querier
c string
as string
}
// Select changes the columns selection of the SELECT statement.
// Empty selection means all columns *.
func (s *Selector) Select(columns ...string) *Selector {
s.selection = make([]any, len(columns))
s.selection = make([]selection, len(columns))
for i := range columns {
s.selection[i] = columns[i]
s.selection[i] = selection{c: columns[i]}
}
return s
}
@@ -2209,31 +2216,23 @@ func (s *Selector) Select(columns ...string) *Selector {
// AppendSelect appends additional columns to the SELECT statement.
func (s *Selector) AppendSelect(columns ...string) *Selector {
for i := range columns {
s.selection = append(s.selection, columns[i])
s.selection = append(s.selection, selection{c: columns[i]})
}
return s
}
// AppendSelectAs appends additional column to the SELECT statement with the given alias.
func (s *Selector) AppendSelectAs(column, as string) *Selector {
s.selection = append(s.selection, ExprFunc(func(b *Builder) {
if b.isIdent(column) || isFunc(column) || isModifier(column) {
b.WriteString(column)
} else {
b.WriteString(s.C(column))
}
b.WriteString(" AS ")
b.Ident(as)
}))
s.selection = append(s.selection, selection{c: column, as: as})
return s
}
// SelectExpr changes the columns selection of the SELECT statement
// with custom list of expressions.
func (s *Selector) SelectExpr(exprs ...Querier) *Selector {
s.selection = make([]any, len(exprs))
s.selection = make([]selection, len(exprs))
for i := range exprs {
s.selection[i] = exprs[i]
s.selection[i] = selection{x: exprs[i]}
}
return s
}
@@ -2241,30 +2240,69 @@ func (s *Selector) SelectExpr(exprs ...Querier) *Selector {
// AppendSelectExpr appends additional expressions to the SELECT statement.
func (s *Selector) AppendSelectExpr(exprs ...Querier) *Selector {
for i := range exprs {
s.selection = append(s.selection, exprs[i])
s.selection = append(s.selection, selection{x: exprs[i]})
}
return s
}
// AppendSelectExprAs appends additional expressions to the SELECT statement with the given name.
func (s *Selector) AppendSelectExprAs(expr Querier, as string) *Selector {
s.selection = append(s.selection, ExprFunc(func(b *Builder) {
switch expr.(type) {
case *raw:
// Raw expressions are not wrapped in parentheses.
b.Join(expr).S(" AS ").Ident(as)
default:
b.S("(").Join(expr).S(") AS ").Ident(as)
}
}))
x := expr
if _, ok := expr.(*raw); !ok {
x = ExprFunc(func(b *Builder) {
b.S("(").Join(expr).S(")")
})
}
s.selection = append(s.selection, selection{
x: x,
as: as,
})
return s
}
// FindSelection returns all occurrences in the selection that match the given column name.
// For example, for column "a" the following match: a, "a", "t"."a", "t"."b" AS "a".
func (s *Selector) FindSelection(name string) (matches []string) {
matchC := func(qualified string) bool {
switch ident, pg := s.isIdent(qualified), s.postgres(); {
case !ident:
if i := strings.IndexRune(qualified, '.'); i > 0 {
return qualified[i+1:] == name
}
case ident && pg:
if i := strings.Index(qualified, `"."`); i > 0 {
return s.unquote(qualified[i+2:]) == name
}
case ident:
if i := strings.Index(qualified, "`.`"); i > 0 {
return s.unquote(qualified[i+2:]) == name
}
}
return false
}
for _, c := range s.selection {
switch {
// Match aliases.
case c.as != "":
if ident := s.isIdent(c.as); !ident && c.as == name || ident && s.unquote(c.as) == name {
matches = append(matches, c.as)
}
// Match qualified columns.
case c.c != "" && s.isQualified(c.c) && matchC(c.c):
matches = append(matches, c.c)
// Match unqualified columns.
case c.c != "" && (c.c == name || s.isIdent(c.c) && s.unquote(c.c) == name):
matches = append(matches, c.c)
}
}
return matches
}
// SelectedColumns returns the selected columns in the Selector.
func (s *Selector) SelectedColumns() []string {
columns := make([]string, 0, len(s.selection))
for i := range s.selection {
if c, ok := s.selection[i].(string); ok {
if c := s.selection[i].c; c != "" {
columns = append(columns, c)
}
}
@@ -2276,8 +2314,8 @@ func (s *Selector) SelectedColumns() []string {
func (s *Selector) UnqualifiedColumns() []string {
columns := make([]string, 0, len(s.selection))
for i := range s.selection {
c, ok := s.selection[i].(string)
if !ok {
c := s.selection[i].c
if c == "" {
continue
}
if s.isIdent(c) {
@@ -2792,7 +2830,7 @@ func (s *Selector) Clone() *Selector {
joins: append([]join{}, joins...),
group: append([]string{}, s.group...),
order: append([]any{}, s.order...),
selection: append([]any{}, s.selection...),
selection: append([]selection{}, s.selection...),
}
}
@@ -3038,15 +3076,19 @@ func joinReturning(columns []string, b *Builder) {
}
func (s *Selector) joinSelect(b *Builder) {
for i := range s.selection {
for i, sc := range s.selection {
if i > 0 {
b.Comma()
}
switch s := s.selection[i].(type) {
case string:
b.Ident(s)
case Querier:
b.Join(s)
switch {
case sc.c != "":
b.Ident(sc.c)
case sc.x != nil:
b.Join(sc.x)
}
if sc.as != "" {
b.WriteString(" AS ")
b.Ident(sc.as)
}
}
}
@@ -3727,6 +3769,18 @@ func (b *Builder) isIdent(s string) bool {
}
}
// unquote database identifiers.
func (b *Builder) unquote(s string) string {
switch pg := b.postgres(); {
case len(s) < 2:
case !pg && s[0] == '`' && s[len(s)-1] == '`', pg && s[0] == '"' && s[len(s)-1] == '"':
if u, err := strconv.Unquote(s); err == nil {
return u
}
}
return s
}
// isIdent reports if the given string is a qualified identifier.
func (b *Builder) isQualified(s string) bool {
ident, pg := b.isIdent(s), b.postgres()