From a6df15101fffc656d330962f4605af808f7a3c89 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Tue, 15 Oct 2019 13:37:18 -0700 Subject: [PATCH] dialect/sql: fix bad query generation on clone Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/97 Reviewed By: alexsn Differential Revision: D17931134 fbshipit-source-id: e181e253f6e29b757f5ca647195d87d7ac89a0c8 --- dialect/sql/builder.go | 86 +++++++++++++++++----------- entc/gen/internal/bindata.go | 4 +- entc/integration/integration_test.go | 6 ++ 3 files changed, 59 insertions(+), 37 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index e67c36e90..7f8d70436 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1205,6 +1205,14 @@ type join struct { table TableView } +// clone a joiner. +func (j join) clone() join { + if sel, ok := j.table.(*Selector); ok { + j.table = sel.Clone() + } + return j +} + // Selector is a builder for the `SELECT` statement. type Selector struct { Builder @@ -1386,6 +1394,10 @@ func (s *Selector) Clone() *Selector { if s == nil { return nil } + joins := make([]join, len(s.joins)) + for i := range s.joins { + joins[i] = s.joins[i].clone() + } return &Selector{ Builder: s.Builder.clone(), as: s.as, @@ -1397,7 +1409,7 @@ func (s *Selector) Clone() *Selector { distinct: s.distinct, where: s.where.clone(), having: s.having.clone(), - joins: append([]join{}, s.joins...), + joins: append([]join{}, joins...), group: append([]string{}, s.group...), order: append([]string{}, s.order...), columns: append([]string{}, s.columns...), @@ -1437,77 +1449,78 @@ func (s *Selector) Having(p *Predicate) *Selector { } // Query returns query representation of a `SELECT` statement. -func (s *Selector) Query() (string, []interface{}) { - s.WriteString("SELECT ") +func (s Selector) Query() (string, []interface{}) { + b := s.Builder.clone() + b.WriteString("SELECT ") if s.distinct { - s.WriteString("DISTINCT ") + b.WriteString("DISTINCT ") } if len(s.columns) > 0 { - s.IdentComma(s.columns...) + b.IdentComma(s.columns...) } else { - s.WriteString("*") + b.WriteString("*") } - s.WriteString(" FROM ") + b.WriteString(" FROM ") switch t := s.from.(type) { case *SelectTable: t.SetDialect(s.dialect) - s.WriteString(t.ref()) + b.WriteString(t.ref()) case *Selector: t.SetDialect(s.dialect) query, args := t.Query() - s.Nested(func(b *Builder) { + b.Nested(func(b *Builder) { b.WriteString(query) }) - s.WriteString(" AS ") - s.Ident(t.as) - s.args = append(s.args, args...) + b.WriteString(" AS ") + b.Ident(t.as) + b.args = append(b.args, args...) } for _, join := range s.joins { - s.WriteString(" " + join.kind + " ") + b.WriteString(" " + join.kind + " ") switch view := join.table.(type) { case *SelectTable: view.SetDialect(s.dialect) - s.WriteString(view.ref()) + b.WriteString(view.ref()) case *Selector: view.SetDialect(s.dialect) query, args := view.Query() - s.Nested(func(b *Builder) { + b.Nested(func(b *Builder) { b.WriteString(query) }) - s.WriteString(" AS ") - s.Ident(view.as) - s.args = append(s.args, args...) + b.WriteString(" AS ") + b.Ident(view.as) + b.args = append(b.args, args...) } if join.on != "" { - s.WriteString(" ON ") - s.WriteString(join.on) + b.WriteString(" ON ") + b.WriteString(join.on) } } if s.where != nil { - s.WriteString(" WHERE ") - s.Builder.Join(s.where) + b.WriteString(" WHERE ") + b.Join(s.where) } if len(s.group) > 0 { - s.WriteString(" GROUP BY ") - s.IdentComma(s.group...) + b.WriteString(" GROUP BY ") + b.IdentComma(s.group...) } if s.having != nil { - s.WriteString(" HAVING ") - s.Builder.Join(s.having) + b.WriteString(" HAVING ") + b.Join(s.having) } if len(s.order) > 0 { - s.WriteString(" ORDER BY ") - s.IdentComma(s.order...) + b.WriteString(" ORDER BY ") + b.IdentComma(s.order...) } if s.limit != nil { - s.WriteString(" LIMIT ") - s.Arg(*s.limit) + b.WriteString(" LIMIT ") + b.Arg(*s.limit) } if s.offset != nil { - s.WriteString(" OFFSET ") - s.Arg(*s.offset) + b.WriteString(" OFFSET ") + b.Arg(*s.offset) } - return s.String(), s.args + return b.String(), b.args } // implement the table view interface. @@ -1796,8 +1809,11 @@ func (b Builder) Query() (string, []interface{}) { // clone returns a shallow clone of a builder. func (b Builder) clone() Builder { - c := Builder{dialect: b.dialect, args: append([]interface{}{}, b.args...)} - c.Buffer.Write(c.Bytes()) + c := Builder{dialect: b.dialect} + if len(b.args) > 0 { + c.args = append(c.args, b.args...) + } + c.Buffer.Write(b.Bytes()) return c } diff --git a/entc/gen/internal/bindata.go b/entc/gen/internal/bindata.go index efa8a9e79..4587cf4ef 100644 --- a/entc/gen/internal/bindata.go +++ b/entc/gen/internal/bindata.go @@ -794,7 +794,7 @@ func templateEntTmpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "template/ent.tmpl", size: 4173, mode: os.FileMode(420), modTime: time.Unix(1571161760, 0)} + info := bindataFileInfo{name: "template/ent.tmpl", size: 4173, mode: os.FileMode(420), modTime: time.Unix(1571167239, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -854,7 +854,7 @@ func templateImportTmpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "template/import.tmpl", size: 1048, mode: os.FileMode(420), modTime: time.Unix(1571161760, 0)} + info := bindataFileInfo{name: "template/import.tmpl", size: 1048, mode: os.FileMode(420), modTime: time.Unix(1571167239, 0)} a := &asset{bytes: bytes, info: info} return a, nil } diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index 212b08b4d..f117427df 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -205,6 +205,12 @@ func Clone(t *testing.T, client *ent.Client) { base := client.File.Query().Where(file.Name("foo")) require.Equal(t, f1.Size, base.Clone().Where(file.Size(f1.Size)).OnlyX(ctx).Size) require.Equal(t, f2.Size, base.Clone().Where(file.Size(f2.Size)).OnlyX(ctx).Size) + // ensure clone emits valid code. + query := client.Pet.Query().Where(pet.Name("unknown")).QueryTeam() + for i := 0; i < 10; i++ { + _, err := query.Clone().Where(user.Name("unknown")).First(ctx) + require.True(t, ent.IsNotFound(err), "should not return syntax error") + } } func Paging(t *testing.T, client *ent.Client) {