mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/schema: disable foreign keys before opening a transaction (#2966)
* dialect/sql/schema: disable foreign keys before opening a transaction * dialect/sql/schema: disable foreign keys before opening a transaction * fix tests * add test for bug * apply CR
This commit is contained in:
@@ -164,7 +164,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
|
||||
a.sqlDialect = nil
|
||||
a.atDriver = nil
|
||||
}()
|
||||
if err := a.sqlDialect.init(ctx, a.sqlDialect); err != nil {
|
||||
if err := a.sqlDialect.init(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if a.universalID {
|
||||
@@ -656,15 +656,15 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
|
||||
}
|
||||
}
|
||||
defer func() { a.sqlDialect = nil }()
|
||||
if err := a.sqlDialect.init(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
// Open a transaction for backwards compatibility,
|
||||
// even if the migration is not transactional.
|
||||
tx, err := a.sqlDialect.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.sqlDialect.init(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
a.atDriver, err = a.sqlDialect.atOpen(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -139,13 +139,13 @@ func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
|
||||
}
|
||||
|
||||
func (m *Migrate) create(ctx context.Context, tables ...*Table) error {
|
||||
if err := m.init(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
tx, err := m.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.init(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
if m.universalID {
|
||||
if err := m.types(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
@@ -185,7 +185,7 @@ func (m *Migrate) txCreate(ctx context.Context, tx dialect.Tx, tables ...*Table)
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("create table %q: %w", t.Name, err)
|
||||
}
|
||||
// If global unique identifier is enabled and it's not
|
||||
// If global unique identifier is enabled, and it's not
|
||||
// a relation table, allocate a range for the table pk.
|
||||
if m.universalID && len(t.PrimaryKey) == 1 {
|
||||
if err := m.allocPKRange(ctx, tx, t); err != nil {
|
||||
@@ -606,7 +606,7 @@ func indexOf(a []string, s string) int {
|
||||
type sqlDialect interface {
|
||||
atBuilder
|
||||
dialect.Driver
|
||||
init(context.Context, dialect.ExecQuerier) error
|
||||
init(context.Context) error
|
||||
table(context.Context, dialect.Tx, string) (*Table, error)
|
||||
tableExist(context.Context, dialect.ExecQuerier, string) (bool, error)
|
||||
fkExist(context.Context, dialect.Tx, string) (bool, error)
|
||||
|
||||
@@ -29,9 +29,9 @@ type MySQL struct {
|
||||
}
|
||||
|
||||
// init loads the MySQL version from the database for later use in the migration process.
|
||||
func (d *MySQL) init(ctx context.Context, conn dialect.ExecQuerier) error {
|
||||
func (d *MySQL) init(ctx context.Context) error {
|
||||
rows := &sql.Rows{}
|
||||
if err := conn.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil {
|
||||
if err := d.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil {
|
||||
return fmt.Errorf("mysql: querying mysql version %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -1375,9 +1375,9 @@ type mysqlMock struct {
|
||||
}
|
||||
|
||||
func (m mysqlMock) start(version string) {
|
||||
m.ExpectBegin()
|
||||
m.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", version))
|
||||
m.ExpectBegin()
|
||||
}
|
||||
|
||||
func (m mysqlMock) tableExists(table string, exists bool) {
|
||||
|
||||
@@ -29,9 +29,9 @@ type Postgres struct {
|
||||
|
||||
// init loads the Postgres version from the database for later use in the migration process.
|
||||
// It returns an error if the server version is lower than v10.
|
||||
func (d *Postgres) init(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
func (d *Postgres) init(ctx context.Context) error {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil {
|
||||
if err := d.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil {
|
||||
return fmt.Errorf("querying server version %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -1010,9 +1010,9 @@ type pgMock struct {
|
||||
}
|
||||
|
||||
func (m pgMock) start(version string) {
|
||||
m.ExpectBegin()
|
||||
m.ExpectQuery(escape("SHOW server_version_num")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"server_version_num"}).AddRow(version))
|
||||
m.ExpectBegin()
|
||||
}
|
||||
|
||||
func (m pgMock) tableExists(table string, exists bool) {
|
||||
|
||||
@@ -6,6 +6,7 @@ package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
stdsql "database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -19,15 +20,51 @@ import (
|
||||
"ariga.io/atlas/sql/sqlite"
|
||||
)
|
||||
|
||||
// SQLite is an SQLite migration driver.
|
||||
type SQLite struct {
|
||||
dialect.Driver
|
||||
WithForeignKeys bool
|
||||
type (
|
||||
// SQLite is an SQLite migration driver.
|
||||
SQLite struct {
|
||||
dialect.Driver
|
||||
WithForeignKeys bool
|
||||
}
|
||||
// SQLiteTx implements dialect.Tx.
|
||||
SQLiteTx struct {
|
||||
dialect.Tx
|
||||
commit func() error // Override Commit to toggle foreign keys back on after Commit.
|
||||
rollback func() error // Override Rollback to toggle foreign keys back on after Rollback.
|
||||
}
|
||||
)
|
||||
|
||||
// Tx implements opens a transaction.
|
||||
func (d *SQLite) Tx(ctx context.Context) (dialect.Tx, error) {
|
||||
db := &db{d}
|
||||
if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = off"); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: set 'foreign_keys = off': %w", err)
|
||||
}
|
||||
t, err := d.Driver.Tx(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tx := &tx{t}
|
||||
cm, err := sqlite.CommitFunc(ctx, db, tx, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SQLiteTx{Tx: t, commit: cm, rollback: sqlite.RollbackFunc(ctx, db, tx, true)}, nil
|
||||
}
|
||||
|
||||
// Commit ensures foreign keys are toggled back on after commit.
|
||||
func (tx *SQLiteTx) Commit() error {
|
||||
return tx.commit()
|
||||
}
|
||||
|
||||
// Rollback ensures foreign keys are toggled back on after rollback.
|
||||
func (tx *SQLiteTx) Rollback() error {
|
||||
return tx.rollback()
|
||||
}
|
||||
|
||||
// init makes sure that foreign_keys support is enabled.
|
||||
func (d *SQLite) init(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
on, err := exist(ctx, tx, "PRAGMA foreign_keys")
|
||||
func (d *SQLite) init(ctx context.Context) error {
|
||||
on, err := exist(ctx, d, "PRAGMA foreign_keys")
|
||||
if err != nil {
|
||||
return fmt.Errorf("sqlite: check foreign_keys pragma: %w", err)
|
||||
}
|
||||
@@ -453,9 +490,29 @@ func (d *SQLite) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (SQLite) atTypeRangeSQL(ts ...string) string {
|
||||
func (*SQLite) atTypeRangeSQL(ts ...string) string {
|
||||
for i := range ts {
|
||||
ts[i] = fmt.Sprintf("('%s')", ts[i])
|
||||
}
|
||||
return fmt.Sprintf("INSERT INTO `%s` (`type`) VALUES %s", TypeTable, strings.Join(ts, ", "))
|
||||
}
|
||||
|
||||
type tx struct {
|
||||
dialect.Tx
|
||||
}
|
||||
|
||||
func (tx *tx) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rows.ColumnScanner.(*stdsql.Rows), nil
|
||||
}
|
||||
|
||||
func (tx *tx) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) {
|
||||
var r stdsql.Result
|
||||
if err := tx.Exec(ctx, query, args, &r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func TestSQLite_Create(t *testing.T) {
|
||||
name: "no tables",
|
||||
before: func(mock sqliteMock) {
|
||||
mock.start()
|
||||
mock.ExpectCommit()
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -73,7 +73,7 @@ func TestSQLite_Create(t *testing.T) {
|
||||
mock.tableExists("users", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `age` integer NOT NULL, `doc` json NULL, `uuid` uuid NULL, `decimal` decimal(6,2) NOT NULL)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -120,7 +120,7 @@ func TestSQLite_Create(t *testing.T) {
|
||||
mock.tableExists("pets", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL, FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -170,7 +170,7 @@ func TestSQLite_Create(t *testing.T) {
|
||||
mock.tableExists("pets", false)
|
||||
mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL)")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -204,7 +204,7 @@ func TestSQLite_Create(t *testing.T) {
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -234,7 +234,7 @@ func TestSQLite_Create(t *testing.T) {
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `updated_at` datetime NULL")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -275,7 +275,7 @@ func TestSQLite_Create(t *testing.T) {
|
||||
mock.ExpectExec(escape(fmt.Sprintf("ALTER TABLE `blobs` ADD COLUMN `new_%s` blob NOT NULL", c))).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
}
|
||||
mock.ExpectCommit()
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -306,7 +306,7 @@ func TestSQLite_Create(t *testing.T) {
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `active` bool NOT NULL DEFAULT false")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -347,7 +347,7 @@ func TestSQLite_Create(t *testing.T) {
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` integer NULL CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE CASCADE")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -389,7 +389,7 @@ func TestSQLite_Create(t *testing.T) {
|
||||
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
|
||||
WithArgs("groups", 1<<32).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -428,7 +428,7 @@ func TestSQLite_Create(t *testing.T) {
|
||||
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
|
||||
WithArgs("groups", 1<<32).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.commit()
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -450,9 +450,21 @@ type sqliteMock struct {
|
||||
}
|
||||
|
||||
func (m sqliteMock) start() {
|
||||
m.ExpectBegin()
|
||||
m.ExpectQuery("PRAGMA foreign_keys").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(1))
|
||||
m.ExpectExec("PRAGMA foreign_keys = off").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
m.ExpectBegin()
|
||||
m.ExpectQuery("PRAGMA foreign_key_check").
|
||||
WillReturnRows(sqlmock.NewRows([]string{})) // empty
|
||||
}
|
||||
|
||||
func (m sqliteMock) commit() {
|
||||
m.ExpectQuery("PRAGMA foreign_key_check").
|
||||
WillReturnRows(sqlmock.NewRows([]string{})) // empty
|
||||
m.ExpectCommit()
|
||||
m.ExpectExec("PRAGMA foreign_keys = on").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
}
|
||||
|
||||
func (m sqliteMock) tableExists(table string, exists bool) {
|
||||
|
||||
Reference in New Issue
Block a user