mirror of
https://github.com/ent/ent.git
synced 2026-05-05 00:50:54 +03:00
190 lines
6.1 KiB
Go
190 lines
6.1 KiB
Go
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package schema
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"entgo.io/ent/dialect"
|
|
"entgo.io/ent/dialect/sql"
|
|
"entgo.io/ent/dialect/sql/sqljson"
|
|
|
|
"ariga.io/atlas/sql/migrate"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestWriteDriver(t *testing.T) {
|
|
b := &bytes.Buffer{}
|
|
w := NewWriteDriver(dialect.MySQL, b)
|
|
ctx := context.Background()
|
|
tx, err := w.Tx(ctx)
|
|
require.NoError(t, err)
|
|
err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil)
|
|
require.EqualError(t, err, "query is not supported by the WriteDriver")
|
|
err = tx.Exec(ctx, "ALTER TABLE `users` ADD COLUMN `age` int", nil, nil)
|
|
require.NoError(t, err)
|
|
err = tx.Exec(ctx, "ALTER TABLE `users` ADD COLUMN `NAME` varchar(100);", nil, nil)
|
|
require.NoError(t, err)
|
|
require.NoError(t, tx.Commit())
|
|
lines := strings.Split(b.String(), "\n")
|
|
require.Len(t, lines, 3)
|
|
require.Equal(t, "ALTER TABLE `users` ADD COLUMN `age` int;", lines[0])
|
|
require.Equal(t, "ALTER TABLE `users` ADD COLUMN `NAME` varchar(100);", lines[1])
|
|
require.Empty(t, lines[2], "file ends with blank line")
|
|
|
|
b.Reset()
|
|
query, args := sql.Update("users").Schema("test").Set("a", 1).Set("b", "a").Set("c", "'c'").Set("d", true).Where(sql.EQ("p", 0.2)).Query()
|
|
err = w.Exec(ctx, query, args, nil)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "UPDATE `test`.`users` SET `a` = 1, `b` = 'a', `c` = '''c''', `d` = 1 WHERE `p` = 0.2;\n", b.String())
|
|
|
|
b.Reset()
|
|
query, args = sql.Dialect(dialect.MySQL).Update("users").Schema("test").Set("a", "{}").Where(sqljson.ValueIsNull("a")).Query()
|
|
err = w.Exec(ctx, query, args, nil)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "UPDATE `test`.`users` SET `a` = '{}' WHERE JSON_CONTAINS(`a`, 'null', '$');\n", b.String())
|
|
|
|
b.Reset()
|
|
w = NewWriteDriver(dialect.Postgres, b)
|
|
query, args = sql.Dialect(dialect.Postgres).Update("users").Set("id", uuid.Nil).Set("a", 1).Set("b", time.Now()).Query()
|
|
err = w.Exec(ctx, query, args, nil)
|
|
require.NoError(t, err)
|
|
require.Equal(t, `UPDATE "users" SET "id" = '00000000-0000-0000-0000-000000000000', "a" = 1, "b" = {{ TIME_VALUE }};`+"\n", b.String())
|
|
|
|
b.Reset()
|
|
err = w.Exec(ctx, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id`, nil, nil)
|
|
require.NoError(t, err)
|
|
require.Equal(t, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id;`+"\n", b.String())
|
|
|
|
// batchCreator uses tx.Query when doing an insert
|
|
b.Reset()
|
|
err = w.Query(ctx, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id`, nil, nil)
|
|
require.NoError(t, err)
|
|
require.Equal(t, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id;`+"\n", b.String())
|
|
|
|
// correct columns are extracted from a returning clause and returned by sql.ColumnScanner.
|
|
for q, cols := range map[string][]string{
|
|
`INSERT INTO "users" (name) VALUES("a8m") RETURNING id`: {"id"},
|
|
`INSERT INTO "users" (name) VALUES("a8m") RETURNING id, "name"`: {"id", `"name"`},
|
|
`INSERT INTO "users" (name) VALUES("a8m") RETURNING "id", "name"`: {`"id"`, `"name"`},
|
|
`INSERT INTO "users" (name) VALUES("a8m") RETURNING "id", "name"; DROP "groups"`: {`"id"`, `"name"`},
|
|
} {
|
|
var rows sql.Rows
|
|
err = w.Query(ctx, q, nil, &rows)
|
|
require.NoError(t, err)
|
|
require.True(t, rows.Next())
|
|
c, err := rows.Columns()
|
|
require.NoError(t, err)
|
|
require.Equal(t, cols, c)
|
|
require.NoError(t, rows.Scan())
|
|
}
|
|
b.Reset()
|
|
}
|
|
|
|
func TestDirWriter(t *testing.T) {
|
|
for _, tt := range []struct {
|
|
dialect string
|
|
exec []string
|
|
comments []string
|
|
args [][]any
|
|
want string
|
|
}{
|
|
{
|
|
dialect.MySQL,
|
|
[]string{
|
|
"UPDATE `test`.`users` SET `a` = ?",
|
|
"UPDATE `test`.`users` SET `b` = ?",
|
|
},
|
|
[]string{
|
|
"Comment 1.",
|
|
"Comment 2.",
|
|
},
|
|
[][]any{
|
|
{1},
|
|
{2},
|
|
},
|
|
"-- Comment 1.\nUPDATE `test`.`users` SET `a` = 1;\n-- Comment 2.\nUPDATE `test`.`users` SET `b` = 2;\n",
|
|
},
|
|
{
|
|
dialect.Postgres,
|
|
[]string{
|
|
"INSERT INTO \"users\" (\"name\", \"email\") VALUES ($1, $2) RETURNING \"id\"",
|
|
"INSERT INTO \"groups\" (\"name\") VALUES ($1) RETURNING \"id\"",
|
|
},
|
|
[]string{
|
|
"Seed users table",
|
|
"Seed groups table",
|
|
},
|
|
[][]any{
|
|
{"masseelch", "j@ariga.io"},
|
|
{"admins"},
|
|
},
|
|
strings.Join([]string{
|
|
"-- Seed users table\nINSERT INTO \"users\" (\"name\", \"email\") VALUES ('masseelch', 'j@ariga.io');\n",
|
|
"-- Seed groups table\nINSERT INTO \"groups\" (\"name\") VALUES ('admins');\n",
|
|
}, ""),
|
|
},
|
|
{
|
|
dialect.SQLite,
|
|
[]string{
|
|
"INSERT INTO `users` (`name`, `email`) VALUES (?, ?) RETURNING `id`",
|
|
"INSERT INTO `groups` (`name`) VALUES (?) RETURNING `id`",
|
|
},
|
|
[]string{
|
|
"Seed users table",
|
|
"Seed groups table",
|
|
},
|
|
[][]any{
|
|
{"masseelch", "j@ariga.io"},
|
|
{"admins"},
|
|
},
|
|
strings.Join([]string{
|
|
"-- Seed users table\nINSERT INTO `users` (`name`, `email`) VALUES ('masseelch', 'j@ariga.io');\n",
|
|
"-- Seed groups table\nINSERT INTO `groups` (`name`) VALUES ('admins');\n",
|
|
}, ""),
|
|
},
|
|
{
|
|
dialect.SQLite + " no space",
|
|
[]string{"INSERT INTO `users` (`name`) VALUES (?)RETURNING `id`"},
|
|
[]string{"Seed users table"},
|
|
[][]any{{"masseelch"}},
|
|
"-- Seed users table\nINSERT INTO `users` (`name`) VALUES ('masseelch');\n",
|
|
},
|
|
} {
|
|
t.Run(tt.dialect, func(t *testing.T) {
|
|
var (
|
|
p = t.TempDir()
|
|
dir = func() migrate.Dir {
|
|
d, err := migrate.NewLocalDir(p)
|
|
require.NoError(t, err)
|
|
return d
|
|
}()
|
|
w = &DirWriter{Dir: dir}
|
|
drv = NewWriteDriver(tt.dialect, w)
|
|
)
|
|
for i := range tt.exec {
|
|
require.NoError(t, drv.Exec(context.Background(), tt.exec[i], tt.args[i], nil))
|
|
w.Change(tt.comments[i])
|
|
}
|
|
require.NoError(t, w.Flush("migration_file"))
|
|
files, err := os.ReadDir(p)
|
|
require.NoError(t, err)
|
|
require.Len(t, files, 2)
|
|
require.Contains(t, files[0].Name(), "_migration_file.sql")
|
|
buf, err := os.ReadFile(filepath.Join(p, files[0].Name()))
|
|
require.NoError(t, err)
|
|
require.Equal(t, tt.want, string(buf))
|
|
require.Equal(t, "atlas.sum", files[1].Name())
|
|
})
|
|
}
|
|
}
|