mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
add Clone methods for query builders (#10)
Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/10 Closes T46957221 Reviewed By: idoshveki Differential Revision: D16278371 fbshipit-source-id: ca2b038fccb8fca6a7e8261444de27bdd63d0b00
This commit is contained in:
committed by
Facebook Github Bot
parent
b5cdb810b8
commit
dbe2afb946
@@ -9,17 +9,17 @@ import (
|
||||
"fbc/ent/dialect"
|
||||
)
|
||||
|
||||
// Node represents a builder step in the query.
|
||||
type Node interface {
|
||||
// Querier wraps the basic Query method.
|
||||
type Querier interface {
|
||||
// Query returns the query representation of the element and its arguments (if any).
|
||||
Query() (string, []interface{})
|
||||
}
|
||||
|
||||
// Nodes are list of queries join with space between them.
|
||||
type Nodes []Node
|
||||
// Queries are list of queries join with space between them.
|
||||
type Queries []Querier
|
||||
|
||||
// Query returns query representation of Nodes.
|
||||
func (n Nodes) Query() (string, []interface{}) {
|
||||
// Query returns query representation of Queriers.
|
||||
func (n Queries) Query() (string, []interface{}) {
|
||||
b := Builder{}
|
||||
for i := range n {
|
||||
if i > 0 {
|
||||
@@ -96,8 +96,8 @@ func (b *Builder) Pad() *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
// Join joins a list of Nodes to the builder.
|
||||
func (b *Builder) Join(n ...Node) *Builder {
|
||||
// Join joins a list of Queriers to the builder.
|
||||
func (b *Builder) Join(n ...Querier) *Builder {
|
||||
for i := range n {
|
||||
query, args := n[i].Query()
|
||||
b.WriteString(query)
|
||||
@@ -106,8 +106,8 @@ func (b *Builder) Join(n ...Node) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
// JoinComma joins a list of Nodes and adds comma between them.
|
||||
func (b *Builder) JoinComma(n ...Node) *Builder {
|
||||
// JoinComma joins a list of Queriers and adds comma between them.
|
||||
func (b *Builder) JoinComma(n ...Querier) *Builder {
|
||||
for i := range n {
|
||||
if i > 0 {
|
||||
b.Comma()
|
||||
@@ -130,6 +130,13 @@ func (b *Builder) Nested(f func(*Builder)) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
// clone returns a shallow clone of a builder.
|
||||
func (b Builder) clone() Builder {
|
||||
c := Builder{args: append([]interface{}{}, b.args...)}
|
||||
c.Buffer.Write(c.Bytes())
|
||||
return c
|
||||
}
|
||||
|
||||
// ColumnBuilder is a builder for column definition in table creation.
|
||||
type ColumnBuilder struct {
|
||||
b Builder
|
||||
@@ -177,7 +184,7 @@ type TableBuilder struct {
|
||||
collation string // table collation.
|
||||
columns []*ColumnBuilder // table columns.
|
||||
primary []string // primary key.
|
||||
constraints []Node // foreign keys and indices.
|
||||
constraints []Querier // foreign keys and indices.
|
||||
}
|
||||
|
||||
// CreateTable returns a query builder for the `CREATE TABLE` statement.
|
||||
@@ -217,23 +224,23 @@ func (t *TableBuilder) PrimaryKey(column ...string) *TableBuilder {
|
||||
|
||||
// ForeignKeys adds a list of foreign-keys to the statement (without constraints).
|
||||
func (t *TableBuilder) ForeignKeys(fks ...*ForeignKeyBuilder) *TableBuilder {
|
||||
nodes := make([]Node, len(fks))
|
||||
Queriers := make([]Querier, len(fks))
|
||||
for i := range fks {
|
||||
// erase the constraint symbol/name.
|
||||
fks[i].symbol = ""
|
||||
nodes[i] = fks[i]
|
||||
Queriers[i] = fks[i]
|
||||
}
|
||||
t.constraints = append(t.constraints, nodes...)
|
||||
t.constraints = append(t.constraints, Queriers...)
|
||||
return t
|
||||
}
|
||||
|
||||
// Constraints adds a list of foreign-key constraints to the statement.
|
||||
func (t *TableBuilder) Constraints(fks ...*ForeignKeyBuilder) *TableBuilder {
|
||||
nodes := make([]Node, len(fks))
|
||||
Queriers := make([]Querier, len(fks))
|
||||
for i := range fks {
|
||||
nodes[i] = &Wrapper{"CONSTRAINT %s", fks[i]}
|
||||
Queriers[i] = &Wrapper{"CONSTRAINT %s", fks[i]}
|
||||
}
|
||||
t.constraints = append(t.constraints, nodes...)
|
||||
t.constraints = append(t.constraints, Queriers...)
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -303,9 +310,9 @@ func (t *DescribeBuilder) Query() (string, []interface{}) {
|
||||
|
||||
// TableAlter is a query builder for `ALTER TABLE` statement.
|
||||
type TableAlter struct {
|
||||
b Builder
|
||||
name string // table to alter.
|
||||
nodes []Node // columns and foreign-keys to add.
|
||||
b Builder
|
||||
name string // table to alter.
|
||||
Queriers []Querier // columns and foreign-keys to add.
|
||||
}
|
||||
|
||||
// AlterTable returns a query builder for the `ALTER TABLE` statement.
|
||||
@@ -318,19 +325,19 @@ func AlterTable(name string) *TableAlter { return &TableAlter{b: Builder{}, name
|
||||
|
||||
// AddColumn appends the `ADD COLUMN` clause to the given `ALTER TABLE` statement.
|
||||
func (t *TableAlter) AddColumn(c *ColumnBuilder) *TableAlter {
|
||||
t.nodes = append(t.nodes, &Wrapper{"ADD COLUMN %s", c})
|
||||
t.Queriers = append(t.Queriers, &Wrapper{"ADD COLUMN %s", c})
|
||||
return t
|
||||
}
|
||||
|
||||
// Modify appends the `MODIFY COLUMN` clause to the given `ALTER TABLE` statement.
|
||||
func (t *TableAlter) ModifyColumn(c *ColumnBuilder) *TableAlter {
|
||||
t.nodes = append(t.nodes, &Wrapper{"MODIFY COLUMN %s", c})
|
||||
t.Queriers = append(t.Queriers, &Wrapper{"MODIFY COLUMN %s", c})
|
||||
return t
|
||||
}
|
||||
|
||||
// AddForeignKey adds a foreign key constraint to the `ALTER TABLE` statement.
|
||||
func (t *TableAlter) AddForeignKey(fk *ForeignKeyBuilder) *TableAlter {
|
||||
t.nodes = append(t.nodes, &Wrapper{"ADD CONSTRAINT %s", fk})
|
||||
t.Queriers = append(t.Queriers, &Wrapper{"ADD CONSTRAINT %s", fk})
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -339,7 +346,7 @@ func (t *TableAlter) Query() (string, []interface{}) {
|
||||
t.b.WriteString("ALTER TABLE ")
|
||||
t.b.Append(t.name)
|
||||
t.b.Pad()
|
||||
t.b.JoinComma(t.nodes...)
|
||||
t.b.JoinComma(t.Queriers...)
|
||||
return t.b.String(), t.b.args
|
||||
}
|
||||
|
||||
@@ -921,6 +928,14 @@ func (p *Predicate) merge(pred *Predicate) *Predicate {
|
||||
return p
|
||||
}
|
||||
|
||||
// clone returns a shallow clone of p.
|
||||
func (p *Predicate) clone() *Predicate {
|
||||
if p == nil {
|
||||
return p
|
||||
}
|
||||
return &Predicate{p.b.clone()}
|
||||
}
|
||||
|
||||
// TableView is a view that returns a table view. Can ne a Table, Selector or a View (WITH statement).
|
||||
type TableView interface {
|
||||
view()
|
||||
@@ -1201,6 +1216,29 @@ func (s *Selector) Count(columns ...string) *Selector {
|
||||
return s
|
||||
}
|
||||
|
||||
// Clone returns a duplicate of the selector, including all associated steps. It can be
|
||||
// used to prepare common SELECT statements and use them differently after the clone is made.
|
||||
func (s *Selector) Clone() *Selector {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &Selector{
|
||||
as: s.as,
|
||||
or: s.or,
|
||||
not: s.not,
|
||||
from: s.from,
|
||||
limit: s.limit,
|
||||
offset: s.offset,
|
||||
distinct: s.distinct,
|
||||
where: s.where.clone(),
|
||||
having: s.having.clone(),
|
||||
joins: append([]join{}, s.joins...),
|
||||
group: append([]string{}, s.group...),
|
||||
order: append([]string{}, s.order...),
|
||||
columns: append([]string{}, s.columns...),
|
||||
}
|
||||
}
|
||||
|
||||
// Asc adds the ASC suffix for the given column.
|
||||
func Asc(column string) string {
|
||||
b := Builder{}
|
||||
@@ -1280,7 +1318,7 @@ func (s *Selector) Query() (string, []interface{}) {
|
||||
b.WriteString(query)
|
||||
b.args = append(b.args, args...)
|
||||
}
|
||||
if s.order != nil {
|
||||
if len(s.order) > 0 {
|
||||
b.WriteString(" ORDER BY ")
|
||||
b.AppendComma(s.order...)
|
||||
}
|
||||
@@ -1307,7 +1345,7 @@ type WithBuilder struct {
|
||||
|
||||
// With returns a new builder for the `WITH` statement.
|
||||
//
|
||||
// n := Nodes{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))}
|
||||
// n := Queriers{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))}
|
||||
// return n.Query()
|
||||
//
|
||||
func With(name string) *WithBuilder {
|
||||
@@ -1336,21 +1374,21 @@ func (w *WithBuilder) Query() (string, []interface{}) {
|
||||
// implement the table view interface.
|
||||
func (*WithBuilder) view() {}
|
||||
|
||||
// Wrapper wraps a given node with different format.
|
||||
// Wrapper wraps a given Querier with different format.
|
||||
// Used to prefix/suffix other queries.
|
||||
type Wrapper struct {
|
||||
format string
|
||||
wrapped Node
|
||||
wrapped Querier
|
||||
}
|
||||
|
||||
// Query returns query representation of a wrapped node.
|
||||
// Query returns query representation of a wrapped Querier.
|
||||
func (w *Wrapper) Query() (string, []interface{}) {
|
||||
query, args := w.wrapped.Query()
|
||||
return fmt.Sprintf(w.format, query), args
|
||||
}
|
||||
|
||||
// Raw returns a raw sql node that is placed as-is in the query.
|
||||
func Raw(s string) Node { return &raw{s} }
|
||||
// Raw returns a raw sql Querier that is placed as-is in the query.
|
||||
func Raw(s string) Querier { return &raw{s} }
|
||||
|
||||
type raw struct{ s string }
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
func TestBuilder(t *testing.T) {
|
||||
tests := []struct {
|
||||
input Node
|
||||
input Querier
|
||||
wantQuery string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
@@ -230,7 +230,7 @@ func TestBuilder(t *testing.T) {
|
||||
wantQuery: "SELECT * FROM `users` AS `u`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
t1 := Table("users").As("u")
|
||||
t2 := Table("groups").As("g")
|
||||
return Select(t1.C("id"), t2.C("name")).From(t1).Join(t2)
|
||||
@@ -238,7 +238,7 @@ func TestBuilder(t *testing.T) {
|
||||
wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN `groups` AS `g`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
t1 := Table("users").As("u")
|
||||
t2 := Table("groups").As("g")
|
||||
return Select(t1.C("id"), t2.C("name")).
|
||||
@@ -249,7 +249,7 @@ func TestBuilder(t *testing.T) {
|
||||
wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN `groups` AS `g` ON `u`.`id` = `g`.`user_id`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
t1 := Table("users").As("u")
|
||||
t2 := Table("groups").As("g")
|
||||
return Select(t1.C("id"), t2.C("name")).
|
||||
@@ -262,14 +262,14 @@ func TestBuilder(t *testing.T) {
|
||||
wantArgs: []interface{}{"bar"},
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
t1 := Table("users").As("u")
|
||||
return Select(t1.Columns("name", "age")...).From(t1)
|
||||
}(),
|
||||
wantQuery: "SELECT `u`.`name`, `u`.`age` FROM `users` AS `u`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
t1 := Table("users").As("u")
|
||||
t2 := Select().From(Table("groups")).Where(EQ("user_id", 10)).As("g")
|
||||
return Select(t1.C("id"), t2.C("name")).
|
||||
@@ -281,7 +281,7 @@ func TestBuilder(t *testing.T) {
|
||||
wantArgs: []interface{}{10},
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
selector := Select().Where(EQ("name", "foo").Or().EQ("name", "bar"))
|
||||
return Delete("users").FromSelect(selector)
|
||||
}(),
|
||||
@@ -289,14 +289,14 @@ func TestBuilder(t *testing.T) {
|
||||
wantArgs: []interface{}{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
selector := Select().From(Table("users")).As("t")
|
||||
return selector.Select(selector.C("name"))
|
||||
}(),
|
||||
wantQuery: "SELECT `t`.`name` FROM `users`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
selector := Select().From(Table("groups")).Where(EQ("name", "foo"))
|
||||
return Delete("users").FromSelect(selector)
|
||||
}(),
|
||||
@@ -304,7 +304,7 @@ func TestBuilder(t *testing.T) {
|
||||
wantArgs: []interface{}{"foo"},
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
selector := Select()
|
||||
return Delete("users").FromSelect(selector)
|
||||
}(),
|
||||
@@ -316,7 +316,7 @@ func TestBuilder(t *testing.T) {
|
||||
wantArgs: []interface{}{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
t1 := Table("users")
|
||||
return Select().
|
||||
From(t1).
|
||||
@@ -326,7 +326,7 @@ func TestBuilder(t *testing.T) {
|
||||
wantArgs: []interface{}{"pedro"},
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
t1 := Table("users")
|
||||
return Select().
|
||||
From(t1).
|
||||
@@ -344,7 +344,7 @@ func TestBuilder(t *testing.T) {
|
||||
wantQuery: "SELECT COUNT(DISTINCT `id`) FROM `users`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
t1 := Table("users")
|
||||
t2 := Select().From(Table("groups"))
|
||||
t3 := Select().Count().From(t1).Join(t1).On(t2.C("id"), t1.C("blocked_id"))
|
||||
@@ -357,7 +357,7 @@ func TestBuilder(t *testing.T) {
|
||||
wantQuery: "SELECT SUM(`age`), MIN(`age`) FROM `users`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
input: func() Querier {
|
||||
t1 := Table("users").As("u")
|
||||
return Select(As(Max(t1.C("age")), "max_age")).From(t1)
|
||||
}(),
|
||||
@@ -402,9 +402,17 @@ func TestBuilder(t *testing.T) {
|
||||
wantArgs: []interface{}{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
input: Nodes{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))},
|
||||
input: Queries{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))},
|
||||
wantQuery: "WITH users_view AS (SELECT * FROM `users`) SELECT * FROM `users_view`",
|
||||
},
|
||||
{
|
||||
input: func() Querier {
|
||||
base := Select("*").From(Table("groups"))
|
||||
return Queries{With("groups").As(base.Clone().Where(EQ("name", "bar"))), base.Select("age")}
|
||||
}(),
|
||||
wantQuery: "WITH groups AS (SELECT * FROM `groups` WHERE `name` = ?) SELECT `age` FROM `groups`",
|
||||
wantArgs: []interface{}{"bar"},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user