From 5118e824225557bd4bfc245de045664a52d84c58 Mon Sep 17 00:00:00 2001 From: Marwan Sulaiman Date: Tue, 5 Jan 2021 15:10:36 -0500 Subject: [PATCH] dialect/sql: skip schema prefix when dialect is SQLite (#1135) * Skip schema prefix when dialect is SQLite * Update dialect/sql/builder.go Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com> * abstract schema checks Co-authored-by: Ariel Mashraki <7413593+a8m@users.noreply.github.com> --- dialect/sql/builder.go | 26 ++++++++++++-------------- dialect/sql/builder_test.go | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 1edcb18f2..5dde4bd54 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -682,9 +682,7 @@ func (i *InsertBuilder) Returning(columns ...string) *InsertBuilder { // Query returns query representation of an `INSERT INTO` statement. func (i *InsertBuilder) Query() (string, []interface{}) { i.WriteString("INSERT INTO ") - if i.schema != "" { - i.Ident(i.schema).WriteByte('.') - } + i.writeSchema(i.schema) i.Ident(i.table).Pad() if i.defaults != "" && len(i.columns) == 0 { i.WriteString(i.defaults) @@ -786,9 +784,7 @@ func (u *UpdateBuilder) Empty() bool { // Query returns query representation of an `UPDATE` statement. func (u *UpdateBuilder) Query() (string, []interface{}) { u.WriteString("UPDATE ") - if u.schema != "" { - u.Ident(u.schema).WriteByte('.') - } + u.writeSchema(u.schema) u.Ident(u.table).WriteString(" SET ") for i, c := range u.nulls { if i > 0 { @@ -870,9 +866,7 @@ func (d *DeleteBuilder) FromSelect(s *Selector) *DeleteBuilder { // Query returns query representation of a `DELETE` statement. func (d *DeleteBuilder) Query() (string, []interface{}) { d.WriteString("DELETE FROM ") - if d.schema != "" { - d.Ident(d.schema).WriteByte('.') - } + d.writeSchema(d.schema) d.Ident(d.table) if d.where != nil { d.WriteString(" WHERE ") @@ -1461,8 +1455,8 @@ func (s *SelectTable) C(column string) string { name = s.as } b := &Builder{dialect: s.dialect} - if s.schema != "" && s.as == "" { - b.Ident(s.schema).WriteByte('.') + if s.as == "" { + b.writeSchema(s.schema) } b.Ident(name).WriteByte('.').Ident(column) return b.String() @@ -1491,9 +1485,7 @@ func (s *SelectTable) ref() string { return s.name } b := &Builder{dialect: s.dialect} - if s.schema != "" { - b.Ident(s.schema).WriteByte('.') - } + b.writeSchema(s.schema) b.Ident(s.name) if s.as != "" { b.WriteString(" AS ") @@ -2057,6 +2049,12 @@ func (b *Builder) AddError(err error) *Builder { return b } +func (b *Builder) writeSchema(schema string) { + if schema != "" && b.dialect != dialect.SQLite { + b.Ident(schema).WriteByte('.') + } +} + // Err returns a concatenated error of all errors encountered during // the query-building, or were added manually by calling AddError. func (b *Builder) Err() error { diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index d9d44b9a9..41fd769ca 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -238,6 +238,11 @@ func TestBuilder(t *testing.T) { wantQuery: `INSERT INTO "mydb"."users" ("age") VALUES ($1)`, wantArgs: []interface{}{1}, }, + { + input: Dialect(dialect.SQLite).Insert("users").Columns("age").Values(1).Schema("mydb"), + wantQuery: "INSERT INTO `users` (`age`) VALUES (?)", + wantArgs: []interface{}{1}, + }, { input: Dialect(dialect.Postgres).Insert("users").Columns("age").Values(1).Returning("id"), wantQuery: `INSERT INTO "users" ("age") VALUES ($1) RETURNING "id"`, @@ -297,6 +302,11 @@ func TestBuilder(t *testing.T) { wantQuery: `UPDATE "mydb"."users" SET "name" = $1`, wantArgs: []interface{}{"foo"}, }, + { + input: Dialect(dialect.SQLite).Update("users").Set("name", "foo").Schema("mydb"), + wantQuery: "UPDATE `users` SET `name` = ?", + wantArgs: []interface{}{"foo"}, + }, { input: Update("users").Set("name", "foo").Set("age", 10), wantQuery: "UPDATE `users` SET `name` = ?, `age` = ?", @@ -515,6 +525,13 @@ func TestBuilder(t *testing.T) { Schema("mydb"), wantQuery: "DELETE FROM `mydb`.`users` WHERE `parent_id` IS NOT NULL", }, + { + input: Dialect(dialect.SQLite). + Delete("users"). + Where(NotNull("parent_id")). + Schema("mydb"), + wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NOT NULL", + }, { input: Dialect(dialect.Postgres). Delete("users"). @@ -1334,6 +1351,21 @@ WHERE wantQuery: "SELECT * FROM `s1`.`users` JOIN `s2`.`pets` AS `t0` ON `s1`.`users`.`id` = `t0`.`owner_id` WHERE `t0`.`name` = ?", wantArgs: []interface{}{"pedro"}, }, + { + input: func() Querier { + t1, t2 := Table("users").Schema("s1"), Table("pets").Schema("s2") + sel := Select("*"). + From(t1).Join(t2). + OnP(P(func(b *Builder) { + b.Ident(t1.C("id")).WriteOp(OpEQ).Ident(t2.C("owner_id")) + })). + Where(EQ(t2.C("name"), "pedro")) + sel.SetDialect(dialect.SQLite) + return sel + }(), + wantQuery: "SELECT * FROM `users` JOIN `pets` AS `t0` ON `users`.`id` = `t0`.`owner_id` WHERE `t0`.`name` = ?", + wantArgs: []interface{}{"pedro"}, + }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) {