dialect/sql/schema: mock result in write driver if returning clause is present (#3369)

This commit is contained in:
Jannik Clausen
2023-03-07 10:43:01 +01:00
committed by GitHub
parent 809b22be7c
commit 55e98b5b9b
2 changed files with 96 additions and 22 deletions

View File

@@ -111,7 +111,46 @@ func (w *WriteDriver) Query(ctx context.Context, query string, args, res any) er
return err
}
if rr, ok := res.(*sql.Rows); ok {
*rr = sql.Rows{ColumnScanner: noRows{}}
cols := func() []string {
// If the query has a RETURNING clause, mock the result.
var clause string
outer:
for i := 0; i < len(query); i++ {
switch q := query[i]; {
case q == '\'', q == '"', q == '`': // string or identifier
_, skip := skipQuoted(query, i)
if skip == -1 {
return nil // malformed SQL
}
i = skip
continue
case reReturning.MatchString(query[i:]):
var j int
inner:
// Forward until next unquoted ';' appears, or we reach the end of the query.
for j = i; j < len(query); j++ {
switch query[j] {
case '\'', '"', '`': // string or identifier
_, skip := skipQuoted(query, j)
if skip == -1 {
return nil // malformed RETURNING clause
}
j = skip
case ';':
break inner
}
}
clause = query[i:j]
break outer
}
}
cols := strings.Split(reReturning.ReplaceAllString(clause, ""), ",")
for i := range cols {
cols[i] = strings.TrimSpace(cols[i])
}
return cols
}()
*rr = sql.Rows{ColumnScanner: &noRows{cols: cols}}
}
return nil
}
@@ -237,22 +276,7 @@ var reReturning = regexp.MustCompile(`(?i)^\s?RETURNING`)
// trimReturning trims any RETURNING suffix from INSERT/UPDATE queries.
// Note, that the output may be incorrect or unsafe SQL and require manual changes.
func trimReturning(query []byte) []byte {
var returning = []byte("returning")
var (
b bytes.Buffer
skipQuoted = func(query []byte, idx int) ([]byte, int) {
for j := idx + 1; j < len(query); j++ {
switch query[j] {
case '\\':
j++
case query[idx]:
return query[idx : j+1], j
}
}
// Unexpected EOS.
return query, -1
}
)
var b bytes.Buffer
loop:
for i := 0; i < len(query); i++ {
switch q := query[i]; {
@@ -266,7 +290,7 @@ loop:
continue
case reReturning.Match(query[i:]):
// Forward until next unquoted ';' appears.
for j := i + len(returning); j < len(query); j++ {
for j := i; j < len(query); j++ { // skip "RETURNING"
switch query[j] {
case '\'', '"', '`': // string or identifier
_, skip := skipQuoted(query, j)
@@ -286,6 +310,19 @@ loop:
return b.Bytes()
}
func skipQuoted[T []byte | string](query T, idx int) (T, int) {
for j := idx + 1; j < len(query); j++ {
switch query[j] {
case '\\':
j++
case query[idx]:
return query[idx : j+1], j
}
}
// Unexpected EOS.
return query, -1
}
// Tx writes the transaction start.
func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) {
return dialect.NopTx(w), nil
@@ -298,11 +335,23 @@ func (noResult) LastInsertId() (int64, error) { return 0, nil }
func (noResult) RowsAffected() (int64, error) { return 0, nil }
// noRows represents no rows.
type noRows struct{ sql.ColumnScanner }
type noRows struct {
sql.ColumnScanner
cols []string
done bool
}
func (noRows) Close() error { return nil }
func (noRows) Err() error { return nil }
func (noRows) Next() bool { return false }
func (*noRows) Close() error { return nil }
func (*noRows) Err() error { return nil }
func (r *noRows) Next() bool {
if !r.done {
r.done = true
return true
}
return false
}
func (r *noRows) Columns() ([]string, error) { return r.cols, nil }
func (*noRows) Scan(...any) error { return nil }
type nopDriver struct {
dialect.Driver

View File

@@ -70,6 +70,24 @@ func TestWriteDriver(t *testing.T) {
err = w.Query(ctx, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id`, nil, nil)
require.NoError(t, err)
require.Equal(t, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id;`+"\n", b.String())
// correct columns are extracted from a returning clause and returned by sql.ColumnScanner.
for q, cols := range map[string][]string{
`INSERT INTO "users" (name) VALUES("a8m") RETURNING id`: {"id"},
`INSERT INTO "users" (name) VALUES("a8m") RETURNING id, "name"`: {"id", `"name"`},
`INSERT INTO "users" (name) VALUES("a8m") RETURNING "id", "name"`: {`"id"`, `"name"`},
`INSERT INTO "users" (name) VALUES("a8m") RETURNING "id", "name"; DROP "groups"`: {`"id"`, `"name"`},
} {
var rows sql.Rows
err = w.Query(ctx, q, nil, &rows)
require.NoError(t, err)
require.True(t, rows.Next())
c, err := rows.Columns()
require.NoError(t, err)
require.Equal(t, cols, c)
require.NoError(t, rows.Scan())
}
b.Reset()
}
func TestDirWriter(t *testing.T) {
@@ -134,6 +152,13 @@ func TestDirWriter(t *testing.T) {
"-- Seed groups table\nINSERT INTO `groups` (`name`) VALUES ('admins');\n",
}, ""),
},
{
dialect.SQLite + " no space",
[]string{"INSERT INTO `users` (`name`) VALUES (?)RETURNING `id`"},
[]string{"Seed users table"},
[][]any{{"masseelch"}},
"-- Seed users table\nINSERT INTO `users` (`name`) VALUES ('masseelch');\n",
},
} {
t.Run(tt.dialect, func(t *testing.T) {
var (