dialect/sql/schema: strip returning from seed (#3367)

This commit is contained in:
Jannik Clausen
2023-03-07 07:46:34 +01:00
committed by GitHub
parent 7e2da46e09
commit 809b22be7c
2 changed files with 146 additions and 19 deletions

View File

@@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"io"
"regexp"
"strconv"
"strings"
"time"
@@ -42,7 +43,9 @@ type (
)
// Write implements the io.Writer interface.
func (d *DirWriter) Write(p []byte) (int, error) { return d.b.Write(p) }
func (d *DirWriter) Write(p []byte) (int, error) {
return d.b.Write(trimReturning(p))
}
// Change converts all written statement so far into a migration
// change with the given comment.
@@ -229,6 +232,60 @@ func (w *WriteDriver) formatArg(v any) (string, error) {
}
}
var reReturning = regexp.MustCompile(`(?i)^\s?RETURNING`)
// trimReturning trims any RETURNING suffix from INSERT/UPDATE queries.
// Note, that the output may be incorrect or unsafe SQL and require manual changes.
func trimReturning(query []byte) []byte {
var returning = []byte("returning")
var (
b bytes.Buffer
skipQuoted = func(query []byte, idx int) ([]byte, int) {
for j := idx + 1; j < len(query); j++ {
switch query[j] {
case '\\':
j++
case query[idx]:
return query[idx : j+1], j
}
}
// Unexpected EOS.
return query, -1
}
)
loop:
for i := 0; i < len(query); i++ {
switch q := query[i]; {
case q == '\'', q == '"', q == '`': // string or identifier
s, skip := skipQuoted(query, i)
if skip == -1 {
return query
}
b.Write(s)
i = skip
continue
case reReturning.Match(query[i:]):
// Forward until next unquoted ';' appears.
for j := i + len(returning); j < len(query); j++ {
switch query[j] {
case '\'', '"', '`': // string or identifier
_, skip := skipQuoted(query, j)
if skip == -1 {
return query
}
j = skip
case ';':
b.WriteString(";")
i += j
continue loop
}
}
}
b.WriteByte(query[i])
}
return b.Bytes()
}
// Tx writes the transaction start.
func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) {
return dialect.NopTx(w), nil

View File

@@ -73,22 +73,92 @@ func TestWriteDriver(t *testing.T) {
}
func TestDirWriter(t *testing.T) {
p := t.TempDir()
dir, err := migrate.NewLocalDir(p)
require.NoError(t, err)
w := &DirWriter{Dir: dir}
drv := NewWriteDriver(dialect.MySQL, w)
require.NoError(t, drv.Exec(context.Background(), "UPDATE `test`.`users` SET `a` = ?", []any{1}, nil))
w.Change("Comment 1.")
require.NoError(t, drv.Exec(context.Background(), "UPDATE `test`.`users` SET `b` = ?", []any{2}, nil))
w.Change("Comment 2.")
require.NoError(t, w.Flush("migration_file"))
files, err := os.ReadDir(p)
require.NoError(t, err)
require.Len(t, files, 2)
require.Contains(t, files[0].Name(), "_migration_file.sql")
buf, err := os.ReadFile(filepath.Join(p, files[0].Name()))
require.NoError(t, err)
require.Equal(t, "-- Comment 1.\nUPDATE `test`.`users` SET `a` = 1;\n-- Comment 2.\nUPDATE `test`.`users` SET `b` = 2;\n", string(buf))
require.Equal(t, "atlas.sum", files[1].Name())
for _, tt := range []struct {
dialect string
exec []string
comments []string
args [][]any
want string
}{
{
dialect.MySQL,
[]string{
"UPDATE `test`.`users` SET `a` = ?",
"UPDATE `test`.`users` SET `b` = ?",
},
[]string{
"Comment 1.",
"Comment 2.",
},
[][]any{
{1},
{2},
},
"-- Comment 1.\nUPDATE `test`.`users` SET `a` = 1;\n-- Comment 2.\nUPDATE `test`.`users` SET `b` = 2;\n",
},
{
dialect.Postgres,
[]string{
"INSERT INTO \"users\" (\"name\", \"email\") VALUES ($1, $2) RETURNING \"id\"",
"INSERT INTO \"groups\" (\"name\") VALUES ($1) RETURNING \"id\"",
},
[]string{
"Seed users table",
"Seed groups table",
},
[][]any{
{"masseelch", "j@ariga.io"},
{"admins"},
},
strings.Join([]string{
"-- Seed users table\nINSERT INTO \"users\" (\"name\", \"email\") VALUES ('masseelch', 'j@ariga.io');\n",
"-- Seed groups table\nINSERT INTO \"groups\" (\"name\") VALUES ('admins');\n",
}, ""),
},
{
dialect.SQLite,
[]string{
"INSERT INTO `users` (`name`, `email`) VALUES (?, ?) RETURNING `id`",
"INSERT INTO `groups` (`name`) VALUES (?) RETURNING `id`",
},
[]string{
"Seed users table",
"Seed groups table",
},
[][]any{
{"masseelch", "j@ariga.io"},
{"admins"},
},
strings.Join([]string{
"-- Seed users table\nINSERT INTO `users` (`name`, `email`) VALUES ('masseelch', 'j@ariga.io');\n",
"-- Seed groups table\nINSERT INTO `groups` (`name`) VALUES ('admins');\n",
}, ""),
},
} {
t.Run(tt.dialect, func(t *testing.T) {
var (
p = t.TempDir()
dir = func() migrate.Dir {
d, err := migrate.NewLocalDir(p)
require.NoError(t, err)
return d
}()
w = &DirWriter{Dir: dir}
drv = NewWriteDriver(tt.dialect, w)
)
for i := range tt.exec {
require.NoError(t, drv.Exec(context.Background(), tt.exec[i], tt.args[i], nil))
w.Change(tt.comments[i])
}
require.NoError(t, w.Flush("migration_file"))
files, err := os.ReadDir(p)
require.NoError(t, err)
require.Len(t, files, 2)
require.Contains(t, files[0].Name(), "_migration_file.sql")
buf, err := os.ReadFile(filepath.Join(p, files[0].Name()))
require.NoError(t, err)
require.Equal(t, tt.want, string(buf))
require.Equal(t, "atlas.sum", files[1].Name())
})
}
}