mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/schema: mock result in write driver if returning clause is present (#3369)
This commit is contained in:
@@ -111,7 +111,46 @@ func (w *WriteDriver) Query(ctx context.Context, query string, args, res any) er
|
||||
return err
|
||||
}
|
||||
if rr, ok := res.(*sql.Rows); ok {
|
||||
*rr = sql.Rows{ColumnScanner: noRows{}}
|
||||
cols := func() []string {
|
||||
// If the query has a RETURNING clause, mock the result.
|
||||
var clause string
|
||||
outer:
|
||||
for i := 0; i < len(query); i++ {
|
||||
switch q := query[i]; {
|
||||
case q == '\'', q == '"', q == '`': // string or identifier
|
||||
_, skip := skipQuoted(query, i)
|
||||
if skip == -1 {
|
||||
return nil // malformed SQL
|
||||
}
|
||||
i = skip
|
||||
continue
|
||||
case reReturning.MatchString(query[i:]):
|
||||
var j int
|
||||
inner:
|
||||
// Forward until next unquoted ';' appears, or we reach the end of the query.
|
||||
for j = i; j < len(query); j++ {
|
||||
switch query[j] {
|
||||
case '\'', '"', '`': // string or identifier
|
||||
_, skip := skipQuoted(query, j)
|
||||
if skip == -1 {
|
||||
return nil // malformed RETURNING clause
|
||||
}
|
||||
j = skip
|
||||
case ';':
|
||||
break inner
|
||||
}
|
||||
}
|
||||
clause = query[i:j]
|
||||
break outer
|
||||
}
|
||||
}
|
||||
cols := strings.Split(reReturning.ReplaceAllString(clause, ""), ",")
|
||||
for i := range cols {
|
||||
cols[i] = strings.TrimSpace(cols[i])
|
||||
}
|
||||
return cols
|
||||
}()
|
||||
*rr = sql.Rows{ColumnScanner: &noRows{cols: cols}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -237,22 +276,7 @@ var reReturning = regexp.MustCompile(`(?i)^\s?RETURNING`)
|
||||
// trimReturning trims any RETURNING suffix from INSERT/UPDATE queries.
|
||||
// Note, that the output may be incorrect or unsafe SQL and require manual changes.
|
||||
func trimReturning(query []byte) []byte {
|
||||
var returning = []byte("returning")
|
||||
var (
|
||||
b bytes.Buffer
|
||||
skipQuoted = func(query []byte, idx int) ([]byte, int) {
|
||||
for j := idx + 1; j < len(query); j++ {
|
||||
switch query[j] {
|
||||
case '\\':
|
||||
j++
|
||||
case query[idx]:
|
||||
return query[idx : j+1], j
|
||||
}
|
||||
}
|
||||
// Unexpected EOS.
|
||||
return query, -1
|
||||
}
|
||||
)
|
||||
var b bytes.Buffer
|
||||
loop:
|
||||
for i := 0; i < len(query); i++ {
|
||||
switch q := query[i]; {
|
||||
@@ -266,7 +290,7 @@ loop:
|
||||
continue
|
||||
case reReturning.Match(query[i:]):
|
||||
// Forward until next unquoted ';' appears.
|
||||
for j := i + len(returning); j < len(query); j++ {
|
||||
for j := i; j < len(query); j++ { // skip "RETURNING"
|
||||
switch query[j] {
|
||||
case '\'', '"', '`': // string or identifier
|
||||
_, skip := skipQuoted(query, j)
|
||||
@@ -286,6 +310,19 @@ loop:
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func skipQuoted[T []byte | string](query T, idx int) (T, int) {
|
||||
for j := idx + 1; j < len(query); j++ {
|
||||
switch query[j] {
|
||||
case '\\':
|
||||
j++
|
||||
case query[idx]:
|
||||
return query[idx : j+1], j
|
||||
}
|
||||
}
|
||||
// Unexpected EOS.
|
||||
return query, -1
|
||||
}
|
||||
|
||||
// Tx writes the transaction start.
|
||||
func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) {
|
||||
return dialect.NopTx(w), nil
|
||||
@@ -298,11 +335,23 @@ func (noResult) LastInsertId() (int64, error) { return 0, nil }
|
||||
func (noResult) RowsAffected() (int64, error) { return 0, nil }
|
||||
|
||||
// noRows represents no rows.
|
||||
type noRows struct{ sql.ColumnScanner }
|
||||
type noRows struct {
|
||||
sql.ColumnScanner
|
||||
cols []string
|
||||
done bool
|
||||
}
|
||||
|
||||
func (noRows) Close() error { return nil }
|
||||
func (noRows) Err() error { return nil }
|
||||
func (noRows) Next() bool { return false }
|
||||
func (*noRows) Close() error { return nil }
|
||||
func (*noRows) Err() error { return nil }
|
||||
func (r *noRows) Next() bool {
|
||||
if !r.done {
|
||||
r.done = true
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (r *noRows) Columns() ([]string, error) { return r.cols, nil }
|
||||
func (*noRows) Scan(...any) error { return nil }
|
||||
|
||||
type nopDriver struct {
|
||||
dialect.Driver
|
||||
|
||||
Reference in New Issue
Block a user