mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/schema: strip returning from seed (#3367)
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -42,7 +43,9 @@ type (
|
||||
)
|
||||
|
||||
// Write implements the io.Writer interface.
|
||||
func (d *DirWriter) Write(p []byte) (int, error) { return d.b.Write(p) }
|
||||
func (d *DirWriter) Write(p []byte) (int, error) {
|
||||
return d.b.Write(trimReturning(p))
|
||||
}
|
||||
|
||||
// Change converts all written statement so far into a migration
|
||||
// change with the given comment.
|
||||
@@ -229,6 +232,60 @@ func (w *WriteDriver) formatArg(v any) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var reReturning = regexp.MustCompile(`(?i)^\s?RETURNING`)
|
||||
|
||||
// trimReturning trims any RETURNING suffix from INSERT/UPDATE queries.
|
||||
// Note, that the output may be incorrect or unsafe SQL and require manual changes.
|
||||
func trimReturning(query []byte) []byte {
|
||||
var returning = []byte("returning")
|
||||
var (
|
||||
b bytes.Buffer
|
||||
skipQuoted = func(query []byte, idx int) ([]byte, int) {
|
||||
for j := idx + 1; j < len(query); j++ {
|
||||
switch query[j] {
|
||||
case '\\':
|
||||
j++
|
||||
case query[idx]:
|
||||
return query[idx : j+1], j
|
||||
}
|
||||
}
|
||||
// Unexpected EOS.
|
||||
return query, -1
|
||||
}
|
||||
)
|
||||
loop:
|
||||
for i := 0; i < len(query); i++ {
|
||||
switch q := query[i]; {
|
||||
case q == '\'', q == '"', q == '`': // string or identifier
|
||||
s, skip := skipQuoted(query, i)
|
||||
if skip == -1 {
|
||||
return query
|
||||
}
|
||||
b.Write(s)
|
||||
i = skip
|
||||
continue
|
||||
case reReturning.Match(query[i:]):
|
||||
// Forward until next unquoted ';' appears.
|
||||
for j := i + len(returning); j < len(query); j++ {
|
||||
switch query[j] {
|
||||
case '\'', '"', '`': // string or identifier
|
||||
_, skip := skipQuoted(query, j)
|
||||
if skip == -1 {
|
||||
return query
|
||||
}
|
||||
j = skip
|
||||
case ';':
|
||||
b.WriteString(";")
|
||||
i += j
|
||||
continue loop
|
||||
}
|
||||
}
|
||||
}
|
||||
b.WriteByte(query[i])
|
||||
}
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
// Tx writes the transaction start.
|
||||
func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) {
|
||||
return dialect.NopTx(w), nil
|
||||
|
||||
@@ -73,22 +73,92 @@ func TestWriteDriver(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDirWriter(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
dir, err := migrate.NewLocalDir(p)
|
||||
require.NoError(t, err)
|
||||
w := &DirWriter{Dir: dir}
|
||||
drv := NewWriteDriver(dialect.MySQL, w)
|
||||
require.NoError(t, drv.Exec(context.Background(), "UPDATE `test`.`users` SET `a` = ?", []any{1}, nil))
|
||||
w.Change("Comment 1.")
|
||||
require.NoError(t, drv.Exec(context.Background(), "UPDATE `test`.`users` SET `b` = ?", []any{2}, nil))
|
||||
w.Change("Comment 2.")
|
||||
require.NoError(t, w.Flush("migration_file"))
|
||||
files, err := os.ReadDir(p)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 2)
|
||||
require.Contains(t, files[0].Name(), "_migration_file.sql")
|
||||
buf, err := os.ReadFile(filepath.Join(p, files[0].Name()))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "-- Comment 1.\nUPDATE `test`.`users` SET `a` = 1;\n-- Comment 2.\nUPDATE `test`.`users` SET `b` = 2;\n", string(buf))
|
||||
require.Equal(t, "atlas.sum", files[1].Name())
|
||||
for _, tt := range []struct {
|
||||
dialect string
|
||||
exec []string
|
||||
comments []string
|
||||
args [][]any
|
||||
want string
|
||||
}{
|
||||
{
|
||||
dialect.MySQL,
|
||||
[]string{
|
||||
"UPDATE `test`.`users` SET `a` = ?",
|
||||
"UPDATE `test`.`users` SET `b` = ?",
|
||||
},
|
||||
[]string{
|
||||
"Comment 1.",
|
||||
"Comment 2.",
|
||||
},
|
||||
[][]any{
|
||||
{1},
|
||||
{2},
|
||||
},
|
||||
"-- Comment 1.\nUPDATE `test`.`users` SET `a` = 1;\n-- Comment 2.\nUPDATE `test`.`users` SET `b` = 2;\n",
|
||||
},
|
||||
{
|
||||
dialect.Postgres,
|
||||
[]string{
|
||||
"INSERT INTO \"users\" (\"name\", \"email\") VALUES ($1, $2) RETURNING \"id\"",
|
||||
"INSERT INTO \"groups\" (\"name\") VALUES ($1) RETURNING \"id\"",
|
||||
},
|
||||
[]string{
|
||||
"Seed users table",
|
||||
"Seed groups table",
|
||||
},
|
||||
[][]any{
|
||||
{"masseelch", "j@ariga.io"},
|
||||
{"admins"},
|
||||
},
|
||||
strings.Join([]string{
|
||||
"-- Seed users table\nINSERT INTO \"users\" (\"name\", \"email\") VALUES ('masseelch', 'j@ariga.io');\n",
|
||||
"-- Seed groups table\nINSERT INTO \"groups\" (\"name\") VALUES ('admins');\n",
|
||||
}, ""),
|
||||
},
|
||||
{
|
||||
dialect.SQLite,
|
||||
[]string{
|
||||
"INSERT INTO `users` (`name`, `email`) VALUES (?, ?) RETURNING `id`",
|
||||
"INSERT INTO `groups` (`name`) VALUES (?) RETURNING `id`",
|
||||
},
|
||||
[]string{
|
||||
"Seed users table",
|
||||
"Seed groups table",
|
||||
},
|
||||
[][]any{
|
||||
{"masseelch", "j@ariga.io"},
|
||||
{"admins"},
|
||||
},
|
||||
strings.Join([]string{
|
||||
"-- Seed users table\nINSERT INTO `users` (`name`, `email`) VALUES ('masseelch', 'j@ariga.io');\n",
|
||||
"-- Seed groups table\nINSERT INTO `groups` (`name`) VALUES ('admins');\n",
|
||||
}, ""),
|
||||
},
|
||||
} {
|
||||
t.Run(tt.dialect, func(t *testing.T) {
|
||||
var (
|
||||
p = t.TempDir()
|
||||
dir = func() migrate.Dir {
|
||||
d, err := migrate.NewLocalDir(p)
|
||||
require.NoError(t, err)
|
||||
return d
|
||||
}()
|
||||
w = &DirWriter{Dir: dir}
|
||||
drv = NewWriteDriver(tt.dialect, w)
|
||||
)
|
||||
for i := range tt.exec {
|
||||
require.NoError(t, drv.Exec(context.Background(), tt.exec[i], tt.args[i], nil))
|
||||
w.Change(tt.comments[i])
|
||||
}
|
||||
require.NoError(t, w.Flush("migration_file"))
|
||||
files, err := os.ReadDir(p)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 2)
|
||||
require.Contains(t, files[0].Name(), "_migration_file.sql")
|
||||
buf, err := os.ReadFile(filepath.Join(p, files[0].Name()))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, string(buf))
|
||||
require.Equal(t, "atlas.sum", files[1].Name())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user