From 09be472be8f9577998a76471b865484927f6182e Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Thu, 10 Sep 2020 20:48:34 +0300 Subject: [PATCH] dialect/sql: add option for adding and getting error from builders --- dialect/sql/builder.go | 116 ++++++++++++++++++++++++++++-------- dialect/sql/builder_test.go | 10 ++++ 2 files changed, 101 insertions(+), 25 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 45bb8219f..e3b6946ce 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1101,7 +1101,8 @@ func (p *Predicate) EqualFold(col, sub string) *Predicate { return p.Append(func(b *Builder) { f := &Func{} f.SetDialect(b.dialect) - b.Ident(f.Lower(col)) + f.Lower(col) + b.WriteString(f.String()) b.WriteOp(OpEQ) b.Arg(strings.ToLower(sub)) }) @@ -1131,7 +1132,8 @@ func (p *Predicate) ContainsFold(col, sub string) *Predicate { case dialect.Postgres: b.Ident(col).WriteString(" ILIKE ") default: // SQLite. - b.Ident(f.Lower(col)).WriteString(" LIKE ") + f.Lower(col) + b.WriteString(f.String()).WriteString(" LIKE ") } b.Arg("%" + strings.ToLower(sub) + "%") }) @@ -1221,66 +1223,107 @@ func (p *Predicate) mayWrap(preds []*Predicate, b *Builder, op string) { // Func represents an SQL function. type Func struct { Builder + fns []func(*Builder) } // Lower wraps the given column with the LOWER function. // // P().EQ(sql.Lower("name"), "a8m") // -func Lower(name string) string { return Func{}.Lower(name) } +func Lower(ident string) string { + f := &Func{} + f.Lower(ident) + return f.String() +} // Lower wraps the given ident with the LOWER function. -func (f Func) Lower(name string) string { - return f.byName("LOWER", name) +func (f *Func) Lower(ident string) { + f.byName("LOWER", ident) } // Count wraps the ident with the COUNT aggregation function. -func Count(name string) string { return Func{}.Count(name) } +func Count(ident string) string { + f := &Func{} + f.Count(ident) + return f.String() +} // Count wraps the ident with the COUNT aggregation function. -func (f Func) Count(ident string) string { - return f.byName("COUNT", ident) +func (f *Func) Count(ident string) { + f.byName("COUNT", ident) } // Max wraps the ident with the MAX aggregation function. -func Max(name string) string { return Func{}.Max(name) } +func Max(ident string) string { + f := &Func{} + f.Max(ident) + return f.String() +} // Max wraps the ident with the MAX aggregation function. -func (f Func) Max(ident string) string { - return f.byName("MAX", ident) +func (f *Func) Max(ident string) { + f.byName("MAX", ident) } // Min wraps the ident with the MIN aggregation function. -func Min(name string) string { return Func{}.Min(name) } +func Min(ident string) string { + f := &Func{} + f.Min(ident) + return f.String() +} // Min wraps the ident with the MIN aggregation function. -func (f Func) Min(ident string) string { - return f.byName("MIN", ident) +func (f *Func) Min(ident string) { + f.byName("MIN", ident) } // Sum wraps the ident with the SUM aggregation function. -func Sum(name string) string { return Func{}.Sum(name) } +func Sum(ident string) string { + f := &Func{} + f.Sum(ident) + return f.String() +} // Sum wraps the ident with the SUM aggregation function. -func (f Func) Sum(ident string) string { - return f.byName("SUM", ident) +func (f *Func) Sum(ident string) { + f.byName("SUM", ident) } // Avg wraps the ident with the AVG aggregation function. -func Avg(name string) string { return Func{}.Avg(name) } +func Avg(ident string) string { + f := &Func{} + f.Avg(ident) + return f.String() +} // Avg wraps the ident with the AVG aggregation function. -func (f Func) Avg(ident string) string { - return f.byName("AVG", ident) +func (f *Func) Avg(ident string) { + f.byName("AVG", ident) } // byName wraps an identifier with a function name. -func (f Func) byName(fn, ident string) string { - f.WriteString(fn) - f.Nested(func(b *Builder) { - b.Ident(ident) +func (f *Func) byName(fn, ident string) { + f.Append(func(b *Builder) { + f.WriteString(fn) + f.Nested(func(b *Builder) { + b.Ident(ident) + }) }) - return f.String() +} + +// Append appends a new function to the function callbacks. +// The callback list are executed on call to String. +func (f *Func) Append(fn func(*Builder)) *Func { + f.fns = append(f.fns, fn) + return f +} + +// String implements the fmt.Stringer. +func (f *Func) String() string { + for _, fn := range f.fns { + fn(&f.Builder) + } + return f.Builder.String() } // As suffixed the given column with an alias (`a` AS `b`). @@ -1833,6 +1876,7 @@ type Builder struct { dialect string // configured dialect. args []interface{} // query parameters. total int // total number of parameters in query tree. + errs []error // errors that added during the query construction. } // Quote quotes the given identifier with the characters based @@ -1893,6 +1937,28 @@ func (b *Builder) WriteString(s string) *Builder { return b } +// AddError appends an error to the builder errors. +func (b *Builder) AddError(err error) *Builder { + b.errs = append(b.errs, err) + return b +} + +// Err returns a concatenated error of all errors encountered during +// the query-building, or were added manually by calling AddError. +func (b *Builder) Err() error { + if len(b.errs) == 0 { + return nil + } + br := strings.Builder{} + for i := range b.errs { + if i > 0 { + br.WriteString("; ") + } + br.WriteString(b.errs[i].Error()) + } + return fmt.Errorf(br.String()) +} + // An Op represents a predicate operator. type Op int diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index ba149bc28..e0c104192 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -5,6 +5,7 @@ package sql import ( + "fmt" "strconv" "strings" "testing" @@ -1242,3 +1243,12 @@ WHERE }) } } + +func TestBuilder_Err(t *testing.T) { + b := Select("i-") + require.NoError(t, b.Err()) + b.AddError(fmt.Errorf("invalid")) + require.EqualError(t, b.Err(), "invalid") + b.AddError(fmt.Errorf("unexpected")) + require.EqualError(t, b.Err(), "invalid; unexpected") +}