From 809b22be7c35160e6c6d1fd3124159c7afad6718 Mon Sep 17 00:00:00 2001 From: Jannik Clausen <12862103+masseelch@users.noreply.github.com> Date: Tue, 7 Mar 2023 07:46:34 +0100 Subject: [PATCH] dialect/sql/schema: strip returning from seed (#3367) --- dialect/sql/schema/writer.go | 59 ++++++++++++++++- dialect/sql/schema/writer_test.go | 106 +++++++++++++++++++++++++----- 2 files changed, 146 insertions(+), 19 deletions(-) diff --git a/dialect/sql/schema/writer.go b/dialect/sql/schema/writer.go index 163c930c3..0293f14f6 100644 --- a/dialect/sql/schema/writer.go +++ b/dialect/sql/schema/writer.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "regexp" "strconv" "strings" "time" @@ -42,7 +43,9 @@ type ( ) // Write implements the io.Writer interface. -func (d *DirWriter) Write(p []byte) (int, error) { return d.b.Write(p) } +func (d *DirWriter) Write(p []byte) (int, error) { + return d.b.Write(trimReturning(p)) +} // Change converts all written statement so far into a migration // change with the given comment. @@ -229,6 +232,60 @@ func (w *WriteDriver) formatArg(v any) (string, error) { } } +var reReturning = regexp.MustCompile(`(?i)^\s?RETURNING`) + +// trimReturning trims any RETURNING suffix from INSERT/UPDATE queries. +// Note, that the output may be incorrect or unsafe SQL and require manual changes. +func trimReturning(query []byte) []byte { + var returning = []byte("returning") + var ( + b bytes.Buffer + skipQuoted = func(query []byte, idx int) ([]byte, int) { + for j := idx + 1; j < len(query); j++ { + switch query[j] { + case '\\': + j++ + case query[idx]: + return query[idx : j+1], j + } + } + // Unexpected EOS. + return query, -1 + } + ) +loop: + for i := 0; i < len(query); i++ { + switch q := query[i]; { + case q == '\'', q == '"', q == '`': // string or identifier + s, skip := skipQuoted(query, i) + if skip == -1 { + return query + } + b.Write(s) + i = skip + continue + case reReturning.Match(query[i:]): + // Forward until next unquoted ';' appears. + for j := i + len(returning); j < len(query); j++ { + switch query[j] { + case '\'', '"', '`': // string or identifier + _, skip := skipQuoted(query, j) + if skip == -1 { + return query + } + j = skip + case ';': + b.WriteString(";") + i += j + continue loop + } + } + } + b.WriteByte(query[i]) + } + return b.Bytes() +} + // Tx writes the transaction start. func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) { return dialect.NopTx(w), nil diff --git a/dialect/sql/schema/writer_test.go b/dialect/sql/schema/writer_test.go index 62aad6d2c..2d35a3938 100644 --- a/dialect/sql/schema/writer_test.go +++ b/dialect/sql/schema/writer_test.go @@ -73,22 +73,92 @@ func TestWriteDriver(t *testing.T) { } func TestDirWriter(t *testing.T) { - p := t.TempDir() - dir, err := migrate.NewLocalDir(p) - require.NoError(t, err) - w := &DirWriter{Dir: dir} - drv := NewWriteDriver(dialect.MySQL, w) - require.NoError(t, drv.Exec(context.Background(), "UPDATE `test`.`users` SET `a` = ?", []any{1}, nil)) - w.Change("Comment 1.") - require.NoError(t, drv.Exec(context.Background(), "UPDATE `test`.`users` SET `b` = ?", []any{2}, nil)) - w.Change("Comment 2.") - require.NoError(t, w.Flush("migration_file")) - files, err := os.ReadDir(p) - require.NoError(t, err) - require.Len(t, files, 2) - require.Contains(t, files[0].Name(), "_migration_file.sql") - buf, err := os.ReadFile(filepath.Join(p, files[0].Name())) - require.NoError(t, err) - require.Equal(t, "-- Comment 1.\nUPDATE `test`.`users` SET `a` = 1;\n-- Comment 2.\nUPDATE `test`.`users` SET `b` = 2;\n", string(buf)) - require.Equal(t, "atlas.sum", files[1].Name()) + for _, tt := range []struct { + dialect string + exec []string + comments []string + args [][]any + want string + }{ + { + dialect.MySQL, + []string{ + "UPDATE `test`.`users` SET `a` = ?", + "UPDATE `test`.`users` SET `b` = ?", + }, + []string{ + "Comment 1.", + "Comment 2.", + }, + [][]any{ + {1}, + {2}, + }, + "-- Comment 1.\nUPDATE `test`.`users` SET `a` = 1;\n-- Comment 2.\nUPDATE `test`.`users` SET `b` = 2;\n", + }, + { + dialect.Postgres, + []string{ + "INSERT INTO \"users\" (\"name\", \"email\") VALUES ($1, $2) RETURNING \"id\"", + "INSERT INTO \"groups\" (\"name\") VALUES ($1) RETURNING \"id\"", + }, + []string{ + "Seed users table", + "Seed groups table", + }, + [][]any{ + {"masseelch", "j@ariga.io"}, + {"admins"}, + }, + strings.Join([]string{ + "-- Seed users table\nINSERT INTO \"users\" (\"name\", \"email\") VALUES ('masseelch', 'j@ariga.io');\n", + "-- Seed groups table\nINSERT INTO \"groups\" (\"name\") VALUES ('admins');\n", + }, ""), + }, + { + dialect.SQLite, + []string{ + "INSERT INTO `users` (`name`, `email`) VALUES (?, ?) RETURNING `id`", + "INSERT INTO `groups` (`name`) VALUES (?) RETURNING `id`", + }, + []string{ + "Seed users table", + "Seed groups table", + }, + [][]any{ + {"masseelch", "j@ariga.io"}, + {"admins"}, + }, + strings.Join([]string{ + "-- Seed users table\nINSERT INTO `users` (`name`, `email`) VALUES ('masseelch', 'j@ariga.io');\n", + "-- Seed groups table\nINSERT INTO `groups` (`name`) VALUES ('admins');\n", + }, ""), + }, + } { + t.Run(tt.dialect, func(t *testing.T) { + var ( + p = t.TempDir() + dir = func() migrate.Dir { + d, err := migrate.NewLocalDir(p) + require.NoError(t, err) + return d + }() + w = &DirWriter{Dir: dir} + drv = NewWriteDriver(tt.dialect, w) + ) + for i := range tt.exec { + require.NoError(t, drv.Exec(context.Background(), tt.exec[i], tt.args[i], nil)) + w.Change(tt.comments[i]) + } + require.NoError(t, w.Flush("migration_file")) + files, err := os.ReadDir(p) + require.NoError(t, err) + require.Len(t, files, 2) + require.Contains(t, files[0].Name(), "_migration_file.sql") + buf, err := os.ReadFile(filepath.Join(p, files[0].Name())) + require.NoError(t, err) + require.Equal(t, tt.want, string(buf)) + require.Equal(t, "atlas.sum", files[1].Name()) + }) + } }