diff --git a/dialect/sql/schema/writer.go b/dialect/sql/schema/writer.go index 0293f14f6..611caacb1 100644 --- a/dialect/sql/schema/writer.go +++ b/dialect/sql/schema/writer.go @@ -111,7 +111,46 @@ func (w *WriteDriver) Query(ctx context.Context, query string, args, res any) er return err } if rr, ok := res.(*sql.Rows); ok { - *rr = sql.Rows{ColumnScanner: noRows{}} + cols := func() []string { + // If the query has a RETURNING clause, mock the result. + var clause string + outer: + for i := 0; i < len(query); i++ { + switch q := query[i]; { + case q == '\'', q == '"', q == '`': // string or identifier + _, skip := skipQuoted(query, i) + if skip == -1 { + return nil // malformed SQL + } + i = skip + continue + case reReturning.MatchString(query[i:]): + var j int + inner: + // Forward until next unquoted ';' appears, or we reach the end of the query. + for j = i; j < len(query); j++ { + switch query[j] { + case '\'', '"', '`': // string or identifier + _, skip := skipQuoted(query, j) + if skip == -1 { + return nil // malformed RETURNING clause + } + j = skip + case ';': + break inner + } + } + clause = query[i:j] + break outer + } + } + cols := strings.Split(reReturning.ReplaceAllString(clause, ""), ",") + for i := range cols { + cols[i] = strings.TrimSpace(cols[i]) + } + return cols + }() + *rr = sql.Rows{ColumnScanner: &noRows{cols: cols}} } return nil } @@ -237,22 +276,7 @@ var reReturning = regexp.MustCompile(`(?i)^\s?RETURNING`) // trimReturning trims any RETURNING suffix from INSERT/UPDATE queries. // Note, that the output may be incorrect or unsafe SQL and require manual changes. func trimReturning(query []byte) []byte { - var returning = []byte("returning") - var ( - b bytes.Buffer - skipQuoted = func(query []byte, idx int) ([]byte, int) { - for j := idx + 1; j < len(query); j++ { - switch query[j] { - case '\\': - j++ - case query[idx]: - return query[idx : j+1], j - } - } - // Unexpected EOS. - return query, -1 - } - ) + var b bytes.Buffer loop: for i := 0; i < len(query); i++ { switch q := query[i]; { @@ -266,7 +290,7 @@ loop: continue case reReturning.Match(query[i:]): // Forward until next unquoted ';' appears. - for j := i + len(returning); j < len(query); j++ { + for j := i; j < len(query); j++ { // skip "RETURNING" switch query[j] { case '\'', '"', '`': // string or identifier _, skip := skipQuoted(query, j) @@ -286,6 +310,19 @@ loop: return b.Bytes() } +func skipQuoted[T []byte | string](query T, idx int) (T, int) { + for j := idx + 1; j < len(query); j++ { + switch query[j] { + case '\\': + j++ + case query[idx]: + return query[idx : j+1], j + } + } + // Unexpected EOS. + return query, -1 +} + // Tx writes the transaction start. func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) { return dialect.NopTx(w), nil @@ -298,11 +335,23 @@ func (noResult) LastInsertId() (int64, error) { return 0, nil } func (noResult) RowsAffected() (int64, error) { return 0, nil } // noRows represents no rows. -type noRows struct{ sql.ColumnScanner } +type noRows struct { + sql.ColumnScanner + cols []string + done bool +} -func (noRows) Close() error { return nil } -func (noRows) Err() error { return nil } -func (noRows) Next() bool { return false } +func (*noRows) Close() error { return nil } +func (*noRows) Err() error { return nil } +func (r *noRows) Next() bool { + if !r.done { + r.done = true + return true + } + return false +} +func (r *noRows) Columns() ([]string, error) { return r.cols, nil } +func (*noRows) Scan(...any) error { return nil } type nopDriver struct { dialect.Driver diff --git a/dialect/sql/schema/writer_test.go b/dialect/sql/schema/writer_test.go index 2d35a3938..35dd00439 100644 --- a/dialect/sql/schema/writer_test.go +++ b/dialect/sql/schema/writer_test.go @@ -70,6 +70,24 @@ func TestWriteDriver(t *testing.T) { err = w.Query(ctx, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id`, nil, nil) require.NoError(t, err) require.Equal(t, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id;`+"\n", b.String()) + + // correct columns are extracted from a returning clause and returned by sql.ColumnScanner. + for q, cols := range map[string][]string{ + `INSERT INTO "users" (name) VALUES("a8m") RETURNING id`: {"id"}, + `INSERT INTO "users" (name) VALUES("a8m") RETURNING id, "name"`: {"id", `"name"`}, + `INSERT INTO "users" (name) VALUES("a8m") RETURNING "id", "name"`: {`"id"`, `"name"`}, + `INSERT INTO "users" (name) VALUES("a8m") RETURNING "id", "name"; DROP "groups"`: {`"id"`, `"name"`}, + } { + var rows sql.Rows + err = w.Query(ctx, q, nil, &rows) + require.NoError(t, err) + require.True(t, rows.Next()) + c, err := rows.Columns() + require.NoError(t, err) + require.Equal(t, cols, c) + require.NoError(t, rows.Scan()) + } + b.Reset() } func TestDirWriter(t *testing.T) { @@ -134,6 +152,13 @@ func TestDirWriter(t *testing.T) { "-- Seed groups table\nINSERT INTO `groups` (`name`) VALUES ('admins');\n", }, ""), }, + { + dialect.SQLite + " no space", + []string{"INSERT INTO `users` (`name`) VALUES (?)RETURNING `id`"}, + []string{"Seed users table"}, + [][]any{{"masseelch"}}, + "-- Seed users table\nINSERT INTO `users` (`name`) VALUES ('masseelch');\n", + }, } { t.Run(tt.dialect, func(t *testing.T) { var (