mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/schema: strip returning from seed (#3367)
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -42,7 +43,9 @@ type (
|
||||
)
|
||||
|
||||
// Write implements the io.Writer interface.
|
||||
func (d *DirWriter) Write(p []byte) (int, error) { return d.b.Write(p) }
|
||||
func (d *DirWriter) Write(p []byte) (int, error) {
|
||||
return d.b.Write(trimReturning(p))
|
||||
}
|
||||
|
||||
// Change converts all written statement so far into a migration
|
||||
// change with the given comment.
|
||||
@@ -229,6 +232,60 @@ func (w *WriteDriver) formatArg(v any) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var reReturning = regexp.MustCompile(`(?i)^\s?RETURNING`)
|
||||
|
||||
// trimReturning trims any RETURNING suffix from INSERT/UPDATE queries.
|
||||
// Note, that the output may be incorrect or unsafe SQL and require manual changes.
|
||||
func trimReturning(query []byte) []byte {
|
||||
var returning = []byte("returning")
|
||||
var (
|
||||
b bytes.Buffer
|
||||
skipQuoted = func(query []byte, idx int) ([]byte, int) {
|
||||
for j := idx + 1; j < len(query); j++ {
|
||||
switch query[j] {
|
||||
case '\\':
|
||||
j++
|
||||
case query[idx]:
|
||||
return query[idx : j+1], j
|
||||
}
|
||||
}
|
||||
// Unexpected EOS.
|
||||
return query, -1
|
||||
}
|
||||
)
|
||||
loop:
|
||||
for i := 0; i < len(query); i++ {
|
||||
switch q := query[i]; {
|
||||
case q == '\'', q == '"', q == '`': // string or identifier
|
||||
s, skip := skipQuoted(query, i)
|
||||
if skip == -1 {
|
||||
return query
|
||||
}
|
||||
b.Write(s)
|
||||
i = skip
|
||||
continue
|
||||
case reReturning.Match(query[i:]):
|
||||
// Forward until next unquoted ';' appears.
|
||||
for j := i + len(returning); j < len(query); j++ {
|
||||
switch query[j] {
|
||||
case '\'', '"', '`': // string or identifier
|
||||
_, skip := skipQuoted(query, j)
|
||||
if skip == -1 {
|
||||
return query
|
||||
}
|
||||
j = skip
|
||||
case ';':
|
||||
b.WriteString(";")
|
||||
i += j
|
||||
continue loop
|
||||
}
|
||||
}
|
||||
}
|
||||
b.WriteByte(query[i])
|
||||
}
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
// Tx writes the transaction start.
|
||||
func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) {
|
||||
return dialect.NopTx(w), nil
|
||||
|
||||
Reference in New Issue
Block a user