mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql: fix bad query generation on clone
Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/97 Reviewed By: alexsn Differential Revision: D17931134 fbshipit-source-id: e181e253f6e29b757f5ca647195d87d7ac89a0c8
This commit is contained in:
committed by
Facebook Github Bot
parent
3f1d942d3e
commit
a6df15101f
@@ -1205,6 +1205,14 @@ type join struct {
|
||||
table TableView
|
||||
}
|
||||
|
||||
// clone a joiner.
|
||||
func (j join) clone() join {
|
||||
if sel, ok := j.table.(*Selector); ok {
|
||||
j.table = sel.Clone()
|
||||
}
|
||||
return j
|
||||
}
|
||||
|
||||
// Selector is a builder for the `SELECT` statement.
|
||||
type Selector struct {
|
||||
Builder
|
||||
@@ -1386,6 +1394,10 @@ func (s *Selector) Clone() *Selector {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
joins := make([]join, len(s.joins))
|
||||
for i := range s.joins {
|
||||
joins[i] = s.joins[i].clone()
|
||||
}
|
||||
return &Selector{
|
||||
Builder: s.Builder.clone(),
|
||||
as: s.as,
|
||||
@@ -1397,7 +1409,7 @@ func (s *Selector) Clone() *Selector {
|
||||
distinct: s.distinct,
|
||||
where: s.where.clone(),
|
||||
having: s.having.clone(),
|
||||
joins: append([]join{}, s.joins...),
|
||||
joins: append([]join{}, joins...),
|
||||
group: append([]string{}, s.group...),
|
||||
order: append([]string{}, s.order...),
|
||||
columns: append([]string{}, s.columns...),
|
||||
@@ -1437,77 +1449,78 @@ func (s *Selector) Having(p *Predicate) *Selector {
|
||||
}
|
||||
|
||||
// Query returns query representation of a `SELECT` statement.
|
||||
func (s *Selector) Query() (string, []interface{}) {
|
||||
s.WriteString("SELECT ")
|
||||
func (s Selector) Query() (string, []interface{}) {
|
||||
b := s.Builder.clone()
|
||||
b.WriteString("SELECT ")
|
||||
if s.distinct {
|
||||
s.WriteString("DISTINCT ")
|
||||
b.WriteString("DISTINCT ")
|
||||
}
|
||||
if len(s.columns) > 0 {
|
||||
s.IdentComma(s.columns...)
|
||||
b.IdentComma(s.columns...)
|
||||
} else {
|
||||
s.WriteString("*")
|
||||
b.WriteString("*")
|
||||
}
|
||||
s.WriteString(" FROM ")
|
||||
b.WriteString(" FROM ")
|
||||
switch t := s.from.(type) {
|
||||
case *SelectTable:
|
||||
t.SetDialect(s.dialect)
|
||||
s.WriteString(t.ref())
|
||||
b.WriteString(t.ref())
|
||||
case *Selector:
|
||||
t.SetDialect(s.dialect)
|
||||
query, args := t.Query()
|
||||
s.Nested(func(b *Builder) {
|
||||
b.Nested(func(b *Builder) {
|
||||
b.WriteString(query)
|
||||
})
|
||||
s.WriteString(" AS ")
|
||||
s.Ident(t.as)
|
||||
s.args = append(s.args, args...)
|
||||
b.WriteString(" AS ")
|
||||
b.Ident(t.as)
|
||||
b.args = append(b.args, args...)
|
||||
}
|
||||
for _, join := range s.joins {
|
||||
s.WriteString(" " + join.kind + " ")
|
||||
b.WriteString(" " + join.kind + " ")
|
||||
switch view := join.table.(type) {
|
||||
case *SelectTable:
|
||||
view.SetDialect(s.dialect)
|
||||
s.WriteString(view.ref())
|
||||
b.WriteString(view.ref())
|
||||
case *Selector:
|
||||
view.SetDialect(s.dialect)
|
||||
query, args := view.Query()
|
||||
s.Nested(func(b *Builder) {
|
||||
b.Nested(func(b *Builder) {
|
||||
b.WriteString(query)
|
||||
})
|
||||
s.WriteString(" AS ")
|
||||
s.Ident(view.as)
|
||||
s.args = append(s.args, args...)
|
||||
b.WriteString(" AS ")
|
||||
b.Ident(view.as)
|
||||
b.args = append(b.args, args...)
|
||||
}
|
||||
if join.on != "" {
|
||||
s.WriteString(" ON ")
|
||||
s.WriteString(join.on)
|
||||
b.WriteString(" ON ")
|
||||
b.WriteString(join.on)
|
||||
}
|
||||
}
|
||||
if s.where != nil {
|
||||
s.WriteString(" WHERE ")
|
||||
s.Builder.Join(s.where)
|
||||
b.WriteString(" WHERE ")
|
||||
b.Join(s.where)
|
||||
}
|
||||
if len(s.group) > 0 {
|
||||
s.WriteString(" GROUP BY ")
|
||||
s.IdentComma(s.group...)
|
||||
b.WriteString(" GROUP BY ")
|
||||
b.IdentComma(s.group...)
|
||||
}
|
||||
if s.having != nil {
|
||||
s.WriteString(" HAVING ")
|
||||
s.Builder.Join(s.having)
|
||||
b.WriteString(" HAVING ")
|
||||
b.Join(s.having)
|
||||
}
|
||||
if len(s.order) > 0 {
|
||||
s.WriteString(" ORDER BY ")
|
||||
s.IdentComma(s.order...)
|
||||
b.WriteString(" ORDER BY ")
|
||||
b.IdentComma(s.order...)
|
||||
}
|
||||
if s.limit != nil {
|
||||
s.WriteString(" LIMIT ")
|
||||
s.Arg(*s.limit)
|
||||
b.WriteString(" LIMIT ")
|
||||
b.Arg(*s.limit)
|
||||
}
|
||||
if s.offset != nil {
|
||||
s.WriteString(" OFFSET ")
|
||||
s.Arg(*s.offset)
|
||||
b.WriteString(" OFFSET ")
|
||||
b.Arg(*s.offset)
|
||||
}
|
||||
return s.String(), s.args
|
||||
return b.String(), b.args
|
||||
}
|
||||
|
||||
// implement the table view interface.
|
||||
@@ -1796,8 +1809,11 @@ func (b Builder) Query() (string, []interface{}) {
|
||||
|
||||
// clone returns a shallow clone of a builder.
|
||||
func (b Builder) clone() Builder {
|
||||
c := Builder{dialect: b.dialect, args: append([]interface{}{}, b.args...)}
|
||||
c.Buffer.Write(c.Bytes())
|
||||
c := Builder{dialect: b.dialect}
|
||||
if len(b.args) > 0 {
|
||||
c.args = append(c.args, b.args...)
|
||||
}
|
||||
c.Buffer.Write(b.Bytes())
|
||||
return c
|
||||
}
|
||||
|
||||
|
||||
@@ -794,7 +794,7 @@ func templateEntTmpl() (*asset, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := bindataFileInfo{name: "template/ent.tmpl", size: 4173, mode: os.FileMode(420), modTime: time.Unix(1571161760, 0)}
|
||||
info := bindataFileInfo{name: "template/ent.tmpl", size: 4173, mode: os.FileMode(420), modTime: time.Unix(1571167239, 0)}
|
||||
a := &asset{bytes: bytes, info: info}
|
||||
return a, nil
|
||||
}
|
||||
@@ -854,7 +854,7 @@ func templateImportTmpl() (*asset, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := bindataFileInfo{name: "template/import.tmpl", size: 1048, mode: os.FileMode(420), modTime: time.Unix(1571161760, 0)}
|
||||
info := bindataFileInfo{name: "template/import.tmpl", size: 1048, mode: os.FileMode(420), modTime: time.Unix(1571167239, 0)}
|
||||
a := &asset{bytes: bytes, info: info}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
@@ -205,6 +205,12 @@ func Clone(t *testing.T, client *ent.Client) {
|
||||
base := client.File.Query().Where(file.Name("foo"))
|
||||
require.Equal(t, f1.Size, base.Clone().Where(file.Size(f1.Size)).OnlyX(ctx).Size)
|
||||
require.Equal(t, f2.Size, base.Clone().Where(file.Size(f2.Size)).OnlyX(ctx).Size)
|
||||
// ensure clone emits valid code.
|
||||
query := client.Pet.Query().Where(pet.Name("unknown")).QueryTeam()
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err := query.Clone().Where(user.Name("unknown")).First(ctx)
|
||||
require.True(t, ent.IsNotFound(err), "should not return syntax error")
|
||||
}
|
||||
}
|
||||
|
||||
func Paging(t *testing.T, client *ent.Client) {
|
||||
|
||||
Reference in New Issue
Block a user