mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/schema: make WriteDriver friendlier (#3119)
Also, add a guide for writing and executing data migrations files.
This commit is contained in:
@@ -5,44 +5,252 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
|
||||
"ariga.io/atlas/sql/migrate"
|
||||
)
|
||||
|
||||
// WriteDriver is a driver that writes all driver exec operations to its writer.
|
||||
type WriteDriver struct {
|
||||
dialect.Driver // underlying driver.
|
||||
io.Writer // target for exec statements.
|
||||
type (
|
||||
// WriteDriver is a driver that writes all driver exec operations to its writer.
|
||||
// Note that this driver is used only for printing or writing statements to SQL
|
||||
// files, and may require manual changes to the generated SQL statements.
|
||||
WriteDriver struct {
|
||||
dialect.Driver // optional driver for query calls.
|
||||
io.Writer // target for exec statements.
|
||||
FormatFunc func(string) (string, error)
|
||||
}
|
||||
// DirWriter implements the io.Writer interface
|
||||
// for writing to an Atlas managed directory.
|
||||
DirWriter struct {
|
||||
Dir migrate.Dir // target directory.
|
||||
Formatter migrate.Formatter // optional formatter.
|
||||
b bytes.Buffer // working buffer.
|
||||
changes []*migrate.Change // changes to flush.
|
||||
}
|
||||
)
|
||||
|
||||
// Write implements the io.Writer interface.
|
||||
func (d *DirWriter) Write(p []byte) (int, error) { return d.b.Write(p) }
|
||||
|
||||
// Change converts all written statement so far into a migration
|
||||
// change with the given comment.
|
||||
func (d *DirWriter) Change(comment string) {
|
||||
// Trim semicolon and new line, because formatter adds it.
|
||||
d.changes = append(d.changes, &migrate.Change{Comment: comment, Cmd: strings.TrimRight(d.b.String(), ";\n")})
|
||||
d.b.Reset()
|
||||
}
|
||||
|
||||
// Exec writes its query and calls the underlying driver Exec method.
|
||||
func (w *WriteDriver) Exec(_ context.Context, query string, _, _ any) error {
|
||||
// Flush flushes the written statements to the directory.
|
||||
func (d *DirWriter) Flush(name string) error {
|
||||
switch {
|
||||
case d.b.Len() != 0:
|
||||
return fmt.Errorf("writer has undocumented change. Use Change or FlushChange instead")
|
||||
case len(d.changes) == 0:
|
||||
return errors.New("writer has no changes to flush")
|
||||
default:
|
||||
return migrate.NewPlanner(nil, d.Dir, migrate.PlanFormat(d.Formatter)).
|
||||
WritePlan(&migrate.Plan{
|
||||
Name: name,
|
||||
Changes: d.changes,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// FlushChange combines Change and Flush.
|
||||
func (d *DirWriter) FlushChange(name, comment string) error {
|
||||
d.Change(comment)
|
||||
return d.Flush(name)
|
||||
}
|
||||
|
||||
// NewWriteDriver creates a dialect.Driver that writes all driver exec statement to its writer.
|
||||
func NewWriteDriver(dialect string, w io.Writer) *WriteDriver {
|
||||
return &WriteDriver{
|
||||
Writer: w,
|
||||
Driver: nopDriver{dialect: dialect},
|
||||
}
|
||||
}
|
||||
|
||||
// Exec implements the dialect.Driver.Exec method.
|
||||
func (w *WriteDriver) Exec(_ context.Context, query string, args, res any) error {
|
||||
if rr, ok := res.(*sql.Result); ok {
|
||||
*rr = noResult{}
|
||||
}
|
||||
if !strings.HasSuffix(query, ";") {
|
||||
query += ";"
|
||||
}
|
||||
if args != nil {
|
||||
args, ok := args.([]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected args type: %T", args)
|
||||
}
|
||||
query = w.expandArgs(query, args)
|
||||
}
|
||||
_, err := io.WriteString(w, query+"\n")
|
||||
return err
|
||||
}
|
||||
|
||||
// Query implements the dialect.Driver.Query method.
|
||||
func (w *WriteDriver) Query(ctx context.Context, query string, args, res any) error {
|
||||
if strings.HasPrefix(query, "INSERT") || strings.HasPrefix(query, "UPDATE") {
|
||||
if err := w.Exec(ctx, query, args, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
if rr, ok := res.(*sql.Rows); ok {
|
||||
*rr = sql.Rows{ColumnScanner: noRows{}}
|
||||
}
|
||||
}
|
||||
switch w.Driver.(type) {
|
||||
case nil, nopDriver:
|
||||
return errors.New("query is not supported by the WriteDriver")
|
||||
default:
|
||||
return w.Driver.Query(ctx, query, args, res)
|
||||
}
|
||||
}
|
||||
|
||||
// expandArgs combines to arguments and statement into a single statement to
|
||||
// print or write into a file (before editing).
|
||||
// Note, the output may be incorrect or unsafe SQL and require manual changes.
|
||||
func (w *WriteDriver) expandArgs(query string, args []any) string {
|
||||
var (
|
||||
b strings.Builder
|
||||
p = w.placeholder()
|
||||
scan = w.scanPlaceholder()
|
||||
)
|
||||
for i := 0; i < len(query); i++ {
|
||||
Top:
|
||||
switch query[i] {
|
||||
case p:
|
||||
idx, size := scan(query[i+1:])
|
||||
// Unrecognized placeholder.
|
||||
if idx < 0 || idx >= len(args) {
|
||||
return query
|
||||
}
|
||||
i += size
|
||||
v, err := w.formatArg(args[idx])
|
||||
if err != nil {
|
||||
// Unexpected formatting error.
|
||||
return query
|
||||
}
|
||||
b.WriteString(v)
|
||||
// String or identifier.
|
||||
case '\'', '"', '`':
|
||||
for j := i + 1; j < len(query); j++ {
|
||||
switch query[j] {
|
||||
case '\\':
|
||||
j++
|
||||
case query[i]:
|
||||
b.WriteString(query[i : j+1])
|
||||
i = j
|
||||
break Top
|
||||
}
|
||||
}
|
||||
// Unexpected EOS.
|
||||
return query
|
||||
default:
|
||||
b.WriteByte(query[i])
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (w *WriteDriver) scanPlaceholder() func(string) (int, int) {
|
||||
switch w.Dialect() {
|
||||
case dialect.Postgres:
|
||||
return func(s string) (int, int) {
|
||||
var i int
|
||||
for i < len(s) && unicode.IsDigit(rune(s[i])) {
|
||||
i++
|
||||
}
|
||||
idx, err := strconv.ParseInt(s[:i], 10, 64)
|
||||
if err != nil {
|
||||
return -1, 0
|
||||
}
|
||||
// Placeholders are 1-based.
|
||||
return int(idx) - 1, i
|
||||
}
|
||||
default:
|
||||
idx := -1
|
||||
return func(string) (int, int) {
|
||||
idx++
|
||||
return idx, 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WriteDriver) placeholder() byte {
|
||||
if w.Dialect() == dialect.Postgres {
|
||||
return '$'
|
||||
}
|
||||
return '?'
|
||||
}
|
||||
|
||||
func (w *WriteDriver) formatArg(v any) (string, error) {
|
||||
if w.FormatFunc != nil {
|
||||
return w.FormatFunc(fmt.Sprint(v))
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case nil:
|
||||
return "NULL", nil
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", v), nil
|
||||
case float32, float64:
|
||||
return fmt.Sprintf("%g", v), nil
|
||||
case bool:
|
||||
if v {
|
||||
return "1", nil
|
||||
} else {
|
||||
return "0", nil
|
||||
}
|
||||
case string:
|
||||
return "'" + strings.ReplaceAll(v, "'", "''") + "'", nil
|
||||
case json.RawMessage:
|
||||
return "'" + strings.ReplaceAll(string(v), "'", "''") + "'", nil
|
||||
case []byte:
|
||||
return "{{ BINARY_VALUE }}", nil
|
||||
case time.Time:
|
||||
return "{{ TIME_VALUE }}", nil
|
||||
default:
|
||||
return "{{ VALUE }}", nil
|
||||
}
|
||||
}
|
||||
|
||||
// Tx writes the transaction start.
|
||||
func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) {
|
||||
if _, err := io.WriteString(w, "BEGIN;\n"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return w, nil
|
||||
return dialect.NopTx(w), nil
|
||||
}
|
||||
|
||||
// Commit writes the transaction commit.
|
||||
func (w *WriteDriver) Commit() error {
|
||||
_, err := io.WriteString(w, "COMMIT;\n")
|
||||
return err
|
||||
// noResult represents a zero result.
|
||||
type noResult struct{}
|
||||
|
||||
func (noResult) LastInsertId() (int64, error) { return 0, nil }
|
||||
func (noResult) RowsAffected() (int64, error) { return 0, nil }
|
||||
|
||||
// noRows represents no rows.
|
||||
type noRows struct{ sql.ColumnScanner }
|
||||
|
||||
func (noRows) Close() error { return nil }
|
||||
func (noRows) Err() error { return nil }
|
||||
func (noRows) Next() bool { return false }
|
||||
|
||||
type nopDriver struct {
|
||||
dialect.Driver
|
||||
dialect string
|
||||
}
|
||||
|
||||
// Rollback writes the transaction rollback.
|
||||
func (w *WriteDriver) Rollback() error {
|
||||
_, err := io.WriteString(w, "ROLLBACK;\n")
|
||||
return err
|
||||
func (d nopDriver) Dialect() string { return d.dialect }
|
||||
|
||||
func (nopDriver) Query(context.Context, string, any, any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,47 +7,82 @@ package schema
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ariga.io/atlas/sql/migrate"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqljson"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteDriver(t *testing.T) {
|
||||
b := &bytes.Buffer{}
|
||||
w := WriteDriver{Driver: nopDriver{}, Writer: b}
|
||||
w := NewWriteDriver(dialect.MySQL, b)
|
||||
ctx := context.Background()
|
||||
tx, err := w.Tx(ctx)
|
||||
require.NoError(t, err)
|
||||
err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil)
|
||||
require.NoError(t, err)
|
||||
err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.EqualError(t, err, "query is not supported by the WriteDriver")
|
||||
err = tx.Exec(ctx, "ALTER TABLE `users` ADD COLUMN `age` int", nil, nil)
|
||||
require.NoError(t, err)
|
||||
err = tx.Exec(ctx, "ALTER TABLE `users` ADD COLUMN `NAME` varchar(100);", nil, nil)
|
||||
require.NoError(t, err)
|
||||
err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tx.Commit())
|
||||
lines := strings.Split(b.String(), "\n")
|
||||
require.Equal(t, "BEGIN;", lines[0])
|
||||
require.Equal(t, "ALTER TABLE `users` ADD COLUMN `age` int;", lines[1])
|
||||
require.Equal(t, "ALTER TABLE `users` ADD COLUMN `NAME` varchar(100);", lines[2])
|
||||
require.Equal(t, "COMMIT;", lines[3])
|
||||
require.Empty(t, lines[4], "file ends with blank line")
|
||||
require.Len(t, lines, 3)
|
||||
require.Equal(t, "ALTER TABLE `users` ADD COLUMN `age` int;", lines[0])
|
||||
require.Equal(t, "ALTER TABLE `users` ADD COLUMN `NAME` varchar(100);", lines[1])
|
||||
require.Empty(t, lines[2], "file ends with blank line")
|
||||
|
||||
b.Reset()
|
||||
query, args := sql.Update("users").Schema("test").Set("a", 1).Set("b", "a").Set("c", "'c'").Set("d", true).Where(sql.EQ("p", 0.2)).Query()
|
||||
err = w.Exec(ctx, query, args, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "UPDATE `test`.`users` SET `a` = 1, `b` = 'a', `c` = '''c''', `d` = 1 WHERE `p` = 0.2;\n", b.String())
|
||||
|
||||
b.Reset()
|
||||
query, args = sql.Dialect(dialect.MySQL).Update("users").Schema("test").Set("a", "{}").Where(sqljson.ValueIsNull("a")).Query()
|
||||
err = w.Exec(ctx, query, args, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "UPDATE `test`.`users` SET `a` = '{}' WHERE JSON_CONTAINS(`a`, 'null', '$');\n", b.String())
|
||||
|
||||
b.Reset()
|
||||
w = NewWriteDriver(dialect.Postgres, b)
|
||||
query, args = sql.Dialect(dialect.Postgres).Update("users").Set("a", 1).Set("b", time.Now()).Query()
|
||||
err = w.Exec(ctx, query, args, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, `UPDATE "users" SET "a" = 1, "b" = {{ TIME_VALUE }};`+"\n", b.String())
|
||||
|
||||
b.Reset()
|
||||
err = w.Exec(ctx, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id`, nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id;`+"\n", b.String())
|
||||
}
|
||||
|
||||
type nopDriver struct {
|
||||
dialect.Driver
|
||||
}
|
||||
|
||||
func (nopDriver) Exec(context.Context, string, any, any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (nopDriver) Query(context.Context, string, any, any) error {
|
||||
return nil
|
||||
func TestDirWriter(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
dir, err := migrate.NewLocalDir(p)
|
||||
require.NoError(t, err)
|
||||
w := &DirWriter{Dir: dir}
|
||||
drv := NewWriteDriver(dialect.MySQL, w)
|
||||
require.NoError(t, drv.Exec(context.Background(), "UPDATE `test`.`users` SET `a` = ?", []any{1}, nil))
|
||||
w.Change("Comment 1.")
|
||||
require.NoError(t, drv.Exec(context.Background(), "UPDATE `test`.`users` SET `b` = ?", []any{2}, nil))
|
||||
w.Change("Comment 2.")
|
||||
require.NoError(t, w.Flush("migration_file"))
|
||||
files, err := os.ReadDir(p)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 2)
|
||||
require.Contains(t, files[0].Name(), "_migration_file.sql")
|
||||
buf, err := os.ReadFile(filepath.Join(p, files[0].Name()))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "-- Comment 1.\nUPDATE `test`.`users` SET `a` = 1;\n-- Comment 2.\nUPDATE `test`.`users` SET `b` = 2;\n", string(buf))
|
||||
require.Equal(t, "atlas.sum", files[1].Name())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user