diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index e8ffd72fe..50d40712c 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -12,7 +12,6 @@ package sql import ( - "bytes" "context" "database/sql/driver" "fmt" @@ -738,11 +737,9 @@ func (i *InsertBuilder) Query() (string, []interface{}) { i.WriteByte('(').Args(v...).WriteByte(')') } } - if len(i.conflictColumns) > 0 { i.buildConflictHandling() } - if len(i.returning) > 0 && i.postgres() { i.WriteString(" RETURNING ") i.IdentComma(i.returning...) @@ -2232,11 +2229,11 @@ func (n Queries) Query() (string, []interface{}) { // Builder is the base query builder for the sql dsl. type Builder struct { - bytes.Buffer // underlying buffer. - dialect string // configured dialect. - args []interface{} // query parameters. - total int // total number of parameters in query tree. - errs []error // errors that added during the query construction. + sb *strings.Builder // underlying builder. + dialect string // configured dialect. + args []interface{} // query parameters. + total int // total number of parameters in query tree. + errs []error // errors that added during the query construction. } // Quote quotes the given identifier with the characters based @@ -2285,15 +2282,45 @@ func (b *Builder) IdentComma(s ...string) *Builder { return b } +// String returns the accumulated string. +func (b *Builder) String() string { + if b.sb == nil { + return "" + } + return b.sb.String() +} + // WriteByte wraps the Buffer.WriteByte to make it chainable with other methods. func (b *Builder) WriteByte(c byte) *Builder { - b.Buffer.WriteByte(c) + if b.sb == nil { + b.sb = &strings.Builder{} + } + b.sb.WriteByte(c) return b } // WriteString wraps the Buffer.WriteString to make it chainable with other methods. func (b *Builder) WriteString(s string) *Builder { - b.Buffer.WriteString(s) + if b.sb == nil { + b.sb = &strings.Builder{} + } + b.sb.WriteString(s) + return b +} + +// Len returns the number of accumulated bytes. +func (b *Builder) Len() int { + if b.sb == nil { + return 0 + } + return b.sb.Len() +} + +// Reset resets the Builder to be empty. +func (b *Builder) Reset() *Builder { + if b.sb != nil { + b.sb.Reset() + } return b } @@ -2467,11 +2494,11 @@ func (b *Builder) join(qs []Querier, sep string) *Builder { // Nested gets a callback, and wraps its result with parentheses. func (b *Builder) Nested(f func(*Builder)) *Builder { - nb := &Builder{dialect: b.dialect, total: b.total} + nb := &Builder{dialect: b.dialect, total: b.total, sb: &strings.Builder{}} nb.WriteByte('(') f(nb) nb.WriteByte(')') - nb.WriteTo(b) + b.WriteString(nb.String()) b.args = append(b.args, nb.args...) b.total = nb.total return b @@ -2505,11 +2532,13 @@ func (b Builder) Query() (string, []interface{}) { // clone returns a shallow clone of a builder. func (b Builder) clone() Builder { - c := Builder{dialect: b.dialect, total: b.total} + c := Builder{dialect: b.dialect, total: b.total, sb: &strings.Builder{}} if len(b.args) > 0 { c.args = append(c.args, b.args...) } - c.Buffer.Write(b.Bytes()) + if b.sb != nil { + c.sb.WriteString(b.sb.String()) + } return c }