mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/schema: mock result in write driver if returning clause is present (#3369)
This commit is contained in:
@@ -111,7 +111,46 @@ func (w *WriteDriver) Query(ctx context.Context, query string, args, res any) er
|
||||
return err
|
||||
}
|
||||
if rr, ok := res.(*sql.Rows); ok {
|
||||
*rr = sql.Rows{ColumnScanner: noRows{}}
|
||||
cols := func() []string {
|
||||
// If the query has a RETURNING clause, mock the result.
|
||||
var clause string
|
||||
outer:
|
||||
for i := 0; i < len(query); i++ {
|
||||
switch q := query[i]; {
|
||||
case q == '\'', q == '"', q == '`': // string or identifier
|
||||
_, skip := skipQuoted(query, i)
|
||||
if skip == -1 {
|
||||
return nil // malformed SQL
|
||||
}
|
||||
i = skip
|
||||
continue
|
||||
case reReturning.MatchString(query[i:]):
|
||||
var j int
|
||||
inner:
|
||||
// Forward until next unquoted ';' appears, or we reach the end of the query.
|
||||
for j = i; j < len(query); j++ {
|
||||
switch query[j] {
|
||||
case '\'', '"', '`': // string or identifier
|
||||
_, skip := skipQuoted(query, j)
|
||||
if skip == -1 {
|
||||
return nil // malformed RETURNING clause
|
||||
}
|
||||
j = skip
|
||||
case ';':
|
||||
break inner
|
||||
}
|
||||
}
|
||||
clause = query[i:j]
|
||||
break outer
|
||||
}
|
||||
}
|
||||
cols := strings.Split(reReturning.ReplaceAllString(clause, ""), ",")
|
||||
for i := range cols {
|
||||
cols[i] = strings.TrimSpace(cols[i])
|
||||
}
|
||||
return cols
|
||||
}()
|
||||
*rr = sql.Rows{ColumnScanner: &noRows{cols: cols}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -237,22 +276,7 @@ var reReturning = regexp.MustCompile(`(?i)^\s?RETURNING`)
|
||||
// trimReturning trims any RETURNING suffix from INSERT/UPDATE queries.
|
||||
// Note, that the output may be incorrect or unsafe SQL and require manual changes.
|
||||
func trimReturning(query []byte) []byte {
|
||||
var returning = []byte("returning")
|
||||
var (
|
||||
b bytes.Buffer
|
||||
skipQuoted = func(query []byte, idx int) ([]byte, int) {
|
||||
for j := idx + 1; j < len(query); j++ {
|
||||
switch query[j] {
|
||||
case '\\':
|
||||
j++
|
||||
case query[idx]:
|
||||
return query[idx : j+1], j
|
||||
}
|
||||
}
|
||||
// Unexpected EOS.
|
||||
return query, -1
|
||||
}
|
||||
)
|
||||
var b bytes.Buffer
|
||||
loop:
|
||||
for i := 0; i < len(query); i++ {
|
||||
switch q := query[i]; {
|
||||
@@ -266,7 +290,7 @@ loop:
|
||||
continue
|
||||
case reReturning.Match(query[i:]):
|
||||
// Forward until next unquoted ';' appears.
|
||||
for j := i + len(returning); j < len(query); j++ {
|
||||
for j := i; j < len(query); j++ { // skip "RETURNING"
|
||||
switch query[j] {
|
||||
case '\'', '"', '`': // string or identifier
|
||||
_, skip := skipQuoted(query, j)
|
||||
@@ -286,6 +310,19 @@ loop:
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func skipQuoted[T []byte | string](query T, idx int) (T, int) {
|
||||
for j := idx + 1; j < len(query); j++ {
|
||||
switch query[j] {
|
||||
case '\\':
|
||||
j++
|
||||
case query[idx]:
|
||||
return query[idx : j+1], j
|
||||
}
|
||||
}
|
||||
// Unexpected EOS.
|
||||
return query, -1
|
||||
}
|
||||
|
||||
// Tx writes the transaction start.
|
||||
func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) {
|
||||
return dialect.NopTx(w), nil
|
||||
@@ -298,11 +335,23 @@ func (noResult) LastInsertId() (int64, error) { return 0, nil }
|
||||
func (noResult) RowsAffected() (int64, error) { return 0, nil }
|
||||
|
||||
// noRows represents no rows.
|
||||
type noRows struct{ sql.ColumnScanner }
|
||||
type noRows struct {
|
||||
sql.ColumnScanner
|
||||
cols []string
|
||||
done bool
|
||||
}
|
||||
|
||||
func (noRows) Close() error { return nil }
|
||||
func (noRows) Err() error { return nil }
|
||||
func (noRows) Next() bool { return false }
|
||||
func (*noRows) Close() error { return nil }
|
||||
func (*noRows) Err() error { return nil }
|
||||
func (r *noRows) Next() bool {
|
||||
if !r.done {
|
||||
r.done = true
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (r *noRows) Columns() ([]string, error) { return r.cols, nil }
|
||||
func (*noRows) Scan(...any) error { return nil }
|
||||
|
||||
type nopDriver struct {
|
||||
dialect.Driver
|
||||
|
||||
@@ -70,6 +70,24 @@ func TestWriteDriver(t *testing.T) {
|
||||
err = w.Query(ctx, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id`, nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id;`+"\n", b.String())
|
||||
|
||||
// correct columns are extracted from a returning clause and returned by sql.ColumnScanner.
|
||||
for q, cols := range map[string][]string{
|
||||
`INSERT INTO "users" (name) VALUES("a8m") RETURNING id`: {"id"},
|
||||
`INSERT INTO "users" (name) VALUES("a8m") RETURNING id, "name"`: {"id", `"name"`},
|
||||
`INSERT INTO "users" (name) VALUES("a8m") RETURNING "id", "name"`: {`"id"`, `"name"`},
|
||||
`INSERT INTO "users" (name) VALUES("a8m") RETURNING "id", "name"; DROP "groups"`: {`"id"`, `"name"`},
|
||||
} {
|
||||
var rows sql.Rows
|
||||
err = w.Query(ctx, q, nil, &rows)
|
||||
require.NoError(t, err)
|
||||
require.True(t, rows.Next())
|
||||
c, err := rows.Columns()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cols, c)
|
||||
require.NoError(t, rows.Scan())
|
||||
}
|
||||
b.Reset()
|
||||
}
|
||||
|
||||
func TestDirWriter(t *testing.T) {
|
||||
@@ -134,6 +152,13 @@ func TestDirWriter(t *testing.T) {
|
||||
"-- Seed groups table\nINSERT INTO `groups` (`name`) VALUES ('admins');\n",
|
||||
}, ""),
|
||||
},
|
||||
{
|
||||
dialect.SQLite + " no space",
|
||||
[]string{"INSERT INTO `users` (`name`) VALUES (?)RETURNING `id`"},
|
||||
[]string{"Seed users table"},
|
||||
[][]any{{"masseelch"}},
|
||||
"-- Seed users table\nINSERT INTO `users` (`name`) VALUES ('masseelch');\n",
|
||||
},
|
||||
} {
|
||||
t.Run(tt.dialect, func(t *testing.T) {
|
||||
var (
|
||||
|
||||
Reference in New Issue
Block a user