mirror of
https://github.com/ent/ent.git
synced 2026-05-05 00:50:54 +03:00
366 lines
9.0 KiB
Go
366 lines
9.0 KiB
Go
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package schema
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unicode"
|
|
|
|
"entgo.io/ent/dialect"
|
|
"entgo.io/ent/dialect/sql"
|
|
|
|
"ariga.io/atlas/sql/migrate"
|
|
)
|
|
|
|
type (
|
|
// WriteDriver is a driver that writes all driver exec operations to its writer.
|
|
// Note that this driver is used only for printing or writing statements to SQL
|
|
// files, and may require manual changes to the generated SQL statements.
|
|
WriteDriver struct {
|
|
dialect.Driver // optional driver for query calls.
|
|
io.Writer // target for exec statements.
|
|
FormatFunc func(string) (string, error)
|
|
}
|
|
// DirWriter implements the io.Writer interface
|
|
// for writing to an Atlas managed directory.
|
|
DirWriter struct {
|
|
Dir migrate.Dir // target directory.
|
|
Formatter migrate.Formatter // optional formatter.
|
|
b bytes.Buffer // working buffer.
|
|
changes []*migrate.Change // changes to flush.
|
|
}
|
|
)
|
|
|
|
// Write implements the io.Writer interface.
|
|
func (d *DirWriter) Write(p []byte) (int, error) {
|
|
return d.b.Write(trimReturning(p))
|
|
}
|
|
|
|
// Change converts all written statement so far into a migration
|
|
// change with the given comment.
|
|
func (d *DirWriter) Change(comment string) {
|
|
// Trim semicolon and new line, because formatter adds it.
|
|
d.changes = append(d.changes, &migrate.Change{Comment: comment, Cmd: strings.TrimRight(d.b.String(), ";\n")})
|
|
d.b.Reset()
|
|
}
|
|
|
|
// Flush flushes the written statements to the directory.
|
|
func (d *DirWriter) Flush(name string) error {
|
|
switch {
|
|
case d.b.Len() != 0:
|
|
return fmt.Errorf("writer has undocumented change. Use Change or FlushChange instead")
|
|
case len(d.changes) == 0:
|
|
return errors.New("writer has no changes to flush")
|
|
default:
|
|
return migrate.NewPlanner(nil, d.Dir, migrate.PlanFormat(d.Formatter)).
|
|
WritePlan(&migrate.Plan{
|
|
Name: name,
|
|
Changes: d.changes,
|
|
})
|
|
}
|
|
}
|
|
|
|
// FlushChange combines Change and Flush.
|
|
func (d *DirWriter) FlushChange(name, comment string) error {
|
|
d.Change(comment)
|
|
return d.Flush(name)
|
|
}
|
|
|
|
// NewWriteDriver creates a dialect.Driver that writes all driver exec statement to its writer.
|
|
func NewWriteDriver(dialect string, w io.Writer) *WriteDriver {
|
|
return &WriteDriver{
|
|
Writer: w,
|
|
Driver: nopDriver{dialect: dialect},
|
|
}
|
|
}
|
|
|
|
// Exec implements the dialect.Driver.Exec method.
|
|
func (w *WriteDriver) Exec(_ context.Context, query string, args, res any) error {
|
|
if rr, ok := res.(*sql.Result); ok {
|
|
*rr = noResult{}
|
|
}
|
|
if !strings.HasSuffix(query, ";") {
|
|
query += ";"
|
|
}
|
|
if args != nil {
|
|
args, ok := args.([]any)
|
|
if !ok {
|
|
return fmt.Errorf("unexpected args type: %T", args)
|
|
}
|
|
query = w.expandArgs(query, args)
|
|
}
|
|
_, err := io.WriteString(w, query+"\n")
|
|
return err
|
|
}
|
|
|
|
// Query implements the dialect.Driver.Query method.
|
|
func (w *WriteDriver) Query(ctx context.Context, query string, args, res any) error {
|
|
if strings.HasPrefix(query, "INSERT") || strings.HasPrefix(query, "UPDATE") {
|
|
if err := w.Exec(ctx, query, args, nil); err != nil {
|
|
return err
|
|
}
|
|
if rr, ok := res.(*sql.Rows); ok {
|
|
cols := func() []string {
|
|
// If the query has a RETURNING clause, mock the result.
|
|
var clause string
|
|
outer:
|
|
for i := 0; i < len(query); i++ {
|
|
switch q := query[i]; {
|
|
case q == '\'', q == '"', q == '`': // string or identifier
|
|
_, skip := skipQuoted(query, i)
|
|
if skip == -1 {
|
|
return nil // malformed SQL
|
|
}
|
|
i = skip
|
|
continue
|
|
case reReturning.MatchString(query[i:]):
|
|
var j int
|
|
inner:
|
|
// Forward until next unquoted ';' appears, or we reach the end of the query.
|
|
for j = i; j < len(query); j++ {
|
|
switch query[j] {
|
|
case '\'', '"', '`': // string or identifier
|
|
_, skip := skipQuoted(query, j)
|
|
if skip == -1 {
|
|
return nil // malformed RETURNING clause
|
|
}
|
|
j = skip
|
|
case ';':
|
|
break inner
|
|
}
|
|
}
|
|
clause = query[i:j]
|
|
break outer
|
|
}
|
|
}
|
|
cols := strings.Split(reReturning.ReplaceAllString(clause, ""), ",")
|
|
for i := range cols {
|
|
cols[i] = strings.TrimSpace(cols[i])
|
|
}
|
|
return cols
|
|
}()
|
|
*rr = sql.Rows{ColumnScanner: &noRows{cols: cols}}
|
|
}
|
|
return nil
|
|
}
|
|
switch w.Driver.(type) {
|
|
case nil, nopDriver:
|
|
return errors.New("query is not supported by the WriteDriver")
|
|
default:
|
|
return w.Driver.Query(ctx, query, args, res)
|
|
}
|
|
}
|
|
|
|
// expandArgs combines to arguments and statement into a single statement to
|
|
// print or write into a file (before editing).
|
|
// Note, the output may be incorrect or unsafe SQL and require manual changes.
|
|
func (w *WriteDriver) expandArgs(query string, args []any) string {
|
|
var (
|
|
b strings.Builder
|
|
p = w.placeholder()
|
|
scan = w.scanPlaceholder()
|
|
)
|
|
for i := 0; i < len(query); i++ {
|
|
Top:
|
|
switch query[i] {
|
|
case p:
|
|
idx, size := scan(query[i+1:])
|
|
// Unrecognized placeholder.
|
|
if idx < 0 || idx >= len(args) {
|
|
return query
|
|
}
|
|
i += size
|
|
v, err := w.formatArg(args[idx])
|
|
if err != nil {
|
|
// Unexpected formatting error.
|
|
return query
|
|
}
|
|
b.WriteString(v)
|
|
// String or identifier.
|
|
case '\'', '"', '`':
|
|
for j := i + 1; j < len(query); j++ {
|
|
switch query[j] {
|
|
case '\\':
|
|
j++
|
|
case query[i]:
|
|
b.WriteString(query[i : j+1])
|
|
i = j
|
|
break Top
|
|
}
|
|
}
|
|
// Unexpected EOS.
|
|
return query
|
|
default:
|
|
b.WriteByte(query[i])
|
|
}
|
|
}
|
|
return b.String()
|
|
}
|
|
|
|
func (w *WriteDriver) scanPlaceholder() func(string) (int, int) {
|
|
switch w.Dialect() {
|
|
case dialect.Postgres:
|
|
return func(s string) (int, int) {
|
|
var i int
|
|
for i < len(s) && unicode.IsDigit(rune(s[i])) {
|
|
i++
|
|
}
|
|
idx, err := strconv.ParseInt(s[:i], 10, 64)
|
|
if err != nil {
|
|
return -1, 0
|
|
}
|
|
// Placeholders are 1-based.
|
|
return int(idx) - 1, i
|
|
}
|
|
default:
|
|
idx := -1
|
|
return func(string) (int, int) {
|
|
idx++
|
|
return idx, 0
|
|
}
|
|
}
|
|
}
|
|
|
|
func (w *WriteDriver) placeholder() byte {
|
|
if w.Dialect() == dialect.Postgres {
|
|
return '$'
|
|
}
|
|
return '?'
|
|
}
|
|
|
|
func (w *WriteDriver) formatArg(v any) (string, error) {
|
|
if w.FormatFunc != nil {
|
|
return w.FormatFunc(fmt.Sprint(v))
|
|
}
|
|
switch v := v.(type) {
|
|
case nil:
|
|
return "NULL", nil
|
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
|
return fmt.Sprintf("%d", v), nil
|
|
case float32, float64:
|
|
return fmt.Sprintf("%g", v), nil
|
|
case bool:
|
|
if v {
|
|
return "1", nil
|
|
} else {
|
|
return "0", nil
|
|
}
|
|
case string:
|
|
return "'" + strings.ReplaceAll(v, "'", "''") + "'", nil
|
|
case json.RawMessage:
|
|
return "'" + strings.ReplaceAll(string(v), "'", "''") + "'", nil
|
|
case []byte:
|
|
return "{{ BINARY_VALUE }}", nil
|
|
case time.Time:
|
|
return "{{ TIME_VALUE }}", nil
|
|
case fmt.Stringer:
|
|
return "'" + strings.ReplaceAll(v.String(), "'", "''") + "'", nil
|
|
default:
|
|
return "{{ VALUE }}", nil
|
|
}
|
|
}
|
|
|
|
// reReturning matches a RETURNING keyword at the start of its input. The
// match is case-insensitive and may be preceded by one whitespace character.
// The trailing \b keeps it from matching identifiers that merely start with
// "returning", such as an unquoted returning_id column.
var reReturning = regexp.MustCompile(`(?i)^\s?RETURNING\b`)
|
|
|
|
// trimReturning trims any RETURNING suffix from INSERT/UPDATE queries.
|
|
// Note, that the output may be incorrect or unsafe SQL and require manual changes.
|
|
func trimReturning(query []byte) []byte {
|
|
var b bytes.Buffer
|
|
loop:
|
|
for i := 0; i < len(query); i++ {
|
|
switch q := query[i]; {
|
|
case q == '\'', q == '"', q == '`': // string or identifier
|
|
s, skip := skipQuoted(query, i)
|
|
if skip == -1 {
|
|
return query
|
|
}
|
|
b.Write(s)
|
|
i = skip
|
|
continue
|
|
case reReturning.Match(query[i:]):
|
|
// Forward until next unquoted ';' appears.
|
|
for j := i; j < len(query); j++ { // skip "RETURNING"
|
|
switch query[j] {
|
|
case '\'', '"', '`': // string or identifier
|
|
_, skip := skipQuoted(query, j)
|
|
if skip == -1 {
|
|
return query
|
|
}
|
|
j = skip
|
|
case ';':
|
|
b.WriteString(";")
|
|
i += j
|
|
continue loop
|
|
}
|
|
}
|
|
}
|
|
b.WriteByte(query[i])
|
|
}
|
|
return b.Bytes()
|
|
}
|
|
|
|
// skipQuoted scans the quoted string or identifier that opens at query[idx].
// The opening character is also the closing delimiter, and a backslash
// escapes the byte that follows it.
//   - If the quote is closed, it returns the quoted span (both quotes
//     included) and the index of the closing quote.
//   - Otherwise it returns query unchanged and -1.
func skipQuoted[T []byte | string](query T, idx int) (T, int) {
	quote := query[idx]
	for end := idx + 1; end < len(query); end++ {
		if c := query[end]; c == '\\' {
			end++ // an escaped byte can never close the quote
		} else if c == quote {
			return query[idx : end+1], end
		}
	}
	// Unexpected end of statement.
	return query, -1
}
|
|
|
|
// Tx writes the transaction start.
|
|
func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) {
|
|
return dialect.NopTx(w), nil
|
|
}
|
|
|
|
// noResult is the result reported by WriteDriver.Exec. Since nothing is
// executed, it never has an insert ID and never affects any rows.
type noResult struct{}

// LastInsertId always returns 0 and no error.
func (noResult) LastInsertId() (int64, error) {
	return 0, nil
}

// RowsAffected always returns 0 and no error.
func (noResult) RowsAffected() (int64, error) {
	return 0, nil
}
|
|
|
|
// noRows represents no rows.
|
|
type noRows struct {
|
|
sql.ColumnScanner
|
|
cols []string
|
|
done bool
|
|
}
|
|
|
|
func (*noRows) Close() error { return nil }
|
|
func (*noRows) Err() error { return nil }
|
|
func (r *noRows) Next() bool {
|
|
if !r.done {
|
|
r.done = true
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
func (r *noRows) Columns() ([]string, error) { return r.cols, nil }
|
|
func (*noRows) Scan(...any) error { return nil }
|
|
|
|
type nopDriver struct {
|
|
dialect.Driver
|
|
dialect string
|
|
}
|
|
|
|
func (d nopDriver) Dialect() string { return d.dialect }
|
|
|
|
func (nopDriver) Query(context.Context, string, any, any) error {
|
|
return nil
|
|
}
|