dialect/sql/schema: disable foreign keys before opening a transaction (#2966)

* dialect/sql/schema: disable foreign keys before opening a transaction

* dialect/sql/schema: disable foreign keys before opening a transaction

* fix tests

* add test for bug

* apply CR
This commit is contained in:
Jannik Clausen
2022-09-28 07:41:49 +02:00
committed by GitHub
parent e02622a064
commit c41d223733
12 changed files with 190 additions and 36 deletions

View File

@@ -164,7 +164,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
a.sqlDialect = nil
a.atDriver = nil
}()
if err := a.sqlDialect.init(ctx, a.sqlDialect); err != nil {
if err := a.sqlDialect.init(ctx); err != nil {
return err
}
if a.universalID {
@@ -656,15 +656,15 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
}
}
defer func() { a.sqlDialect = nil }()
if err := a.sqlDialect.init(ctx); err != nil {
return err
}
// Open a transaction for backwards compatibility,
// even if the migration is not transactional.
tx, err := a.sqlDialect.Tx(ctx)
if err != nil {
return err
}
if err := a.sqlDialect.init(ctx, tx); err != nil {
return err
}
a.atDriver, err = a.sqlDialect.atOpen(tx)
if err != nil {
return err

View File

@@ -139,13 +139,13 @@ func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
}
func (m *Migrate) create(ctx context.Context, tables ...*Table) error {
if err := m.init(ctx); err != nil {
return err
}
tx, err := m.Tx(ctx)
if err != nil {
return err
}
if err := m.init(ctx, tx); err != nil {
return rollback(tx, err)
}
if m.universalID {
if err := m.types(ctx, tx); err != nil {
return rollback(tx, err)
@@ -185,7 +185,7 @@ func (m *Migrate) txCreate(ctx context.Context, tx dialect.Tx, tables ...*Table)
if err := tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("create table %q: %w", t.Name, err)
}
// If global unique identifier is enabled and it's not
// If global unique identifier is enabled, and it's not
// a relation table, allocate a range for the table pk.
if m.universalID && len(t.PrimaryKey) == 1 {
if err := m.allocPKRange(ctx, tx, t); err != nil {
@@ -606,7 +606,7 @@ func indexOf(a []string, s string) int {
type sqlDialect interface {
atBuilder
dialect.Driver
init(context.Context, dialect.ExecQuerier) error
init(context.Context) error
table(context.Context, dialect.Tx, string) (*Table, error)
tableExist(context.Context, dialect.ExecQuerier, string) (bool, error)
fkExist(context.Context, dialect.Tx, string) (bool, error)

View File

@@ -29,9 +29,9 @@ type MySQL struct {
}
// init loads the MySQL version from the database for later use in the migration process.
func (d *MySQL) init(ctx context.Context, conn dialect.ExecQuerier) error {
func (d *MySQL) init(ctx context.Context) error {
rows := &sql.Rows{}
if err := conn.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil {
if err := d.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil {
return fmt.Errorf("mysql: querying mysql version %w", err)
}
defer rows.Close()

View File

@@ -1375,9 +1375,9 @@ type mysqlMock struct {
}
func (m mysqlMock) start(version string) {
m.ExpectBegin()
m.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", version))
m.ExpectBegin()
}
func (m mysqlMock) tableExists(table string, exists bool) {

View File

@@ -29,9 +29,9 @@ type Postgres struct {
// init loads the Postgres version from the database for later use in the migration process.
// It returns an error if the server version is lower than v10.
func (d *Postgres) init(ctx context.Context, tx dialect.ExecQuerier) error {
func (d *Postgres) init(ctx context.Context) error {
rows := &sql.Rows{}
if err := tx.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil {
if err := d.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil {
return fmt.Errorf("querying server version %w", err)
}
defer rows.Close()

View File

@@ -1010,9 +1010,9 @@ type pgMock struct {
}
func (m pgMock) start(version string) {
m.ExpectBegin()
m.ExpectQuery(escape("SHOW server_version_num")).
WillReturnRows(sqlmock.NewRows([]string{"server_version_num"}).AddRow(version))
m.ExpectBegin()
}
func (m pgMock) tableExists(table string, exists bool) {

View File

@@ -6,6 +6,7 @@ package schema
import (
"context"
stdsql "database/sql"
"fmt"
"strconv"
"strings"
@@ -19,15 +20,51 @@ import (
"ariga.io/atlas/sql/sqlite"
)
// SQLite is an SQLite migration driver.
type SQLite struct {
dialect.Driver
WithForeignKeys bool
type (
// SQLite is an SQLite migration driver.
SQLite struct {
dialect.Driver
WithForeignKeys bool
}
// SQLiteTx implements dialect.Tx.
SQLiteTx struct {
dialect.Tx
commit func() error // Override Commit to toggle foreign keys back on after Commit.
rollback func() error // Override Rollback to toggle foreign keys back on after Rollback.
}
)
// Tx implements opens a transaction.
func (d *SQLite) Tx(ctx context.Context) (dialect.Tx, error) {
db := &db{d}
if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = off"); err != nil {
return nil, fmt.Errorf("sqlite: set 'foreign_keys = off': %w", err)
}
t, err := d.Driver.Tx(ctx)
if err != nil {
return nil, err
}
tx := &tx{t}
cm, err := sqlite.CommitFunc(ctx, db, tx, true)
if err != nil {
return nil, err
}
return &SQLiteTx{Tx: t, commit: cm, rollback: sqlite.RollbackFunc(ctx, db, tx, true)}, nil
}
// Commit ensures foreign keys are toggled back on after commit.
func (tx *SQLiteTx) Commit() error {
return tx.commit()
}
// Rollback ensures foreign keys are toggled back on after rollback.
func (tx *SQLiteTx) Rollback() error {
return tx.rollback()
}
// init makes sure that foreign_keys support is enabled.
func (d *SQLite) init(ctx context.Context, tx dialect.ExecQuerier) error {
on, err := exist(ctx, tx, "PRAGMA foreign_keys")
func (d *SQLite) init(ctx context.Context) error {
on, err := exist(ctx, d, "PRAGMA foreign_keys")
if err != nil {
return fmt.Errorf("sqlite: check foreign_keys pragma: %w", err)
}
@@ -453,9 +490,29 @@ func (d *SQLite) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) erro
return nil
}
func (SQLite) atTypeRangeSQL(ts ...string) string {
func (*SQLite) atTypeRangeSQL(ts ...string) string {
for i := range ts {
ts[i] = fmt.Sprintf("('%s')", ts[i])
}
return fmt.Sprintf("INSERT INTO `%s` (`type`) VALUES %s", TypeTable, strings.Join(ts, ", "))
}
type tx struct {
dialect.Tx
}
func (tx *tx) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) {
rows := &sql.Rows{}
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, err
}
return rows.ColumnScanner.(*stdsql.Rows), nil
}
func (tx *tx) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) {
var r stdsql.Result
if err := tx.Exec(ctx, query, args, &r); err != nil {
return nil, err
}
return r, nil
}

View File

@@ -47,7 +47,7 @@ func TestSQLite_Create(t *testing.T) {
name: "no tables",
before: func(mock sqliteMock) {
mock.start()
mock.ExpectCommit()
mock.commit()
},
},
{
@@ -73,7 +73,7 @@ func TestSQLite_Create(t *testing.T) {
mock.tableExists("users", false)
mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `age` integer NOT NULL, `doc` json NULL, `uuid` uuid NULL, `decimal` decimal(6,2) NOT NULL)")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
@@ -120,7 +120,7 @@ func TestSQLite_Create(t *testing.T) {
mock.tableExists("pets", false)
mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL, FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
@@ -170,7 +170,7 @@ func TestSQLite_Create(t *testing.T) {
mock.tableExists("pets", false)
mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL)")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
@@ -204,7 +204,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
@@ -234,7 +234,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `updated_at` datetime NULL")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
@@ -275,7 +275,7 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectExec(escape(fmt.Sprintf("ALTER TABLE `blobs` ADD COLUMN `new_%s` blob NOT NULL", c))).
WillReturnResult(sqlmock.NewResult(0, 1))
}
mock.ExpectCommit()
mock.commit()
},
},
{
@@ -306,7 +306,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `active` bool NOT NULL DEFAULT false")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
@@ -347,7 +347,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` integer NULL CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE CASCADE")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
@@ -389,7 +389,7 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
WithArgs("groups", 1<<32).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
@@ -428,7 +428,7 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
WithArgs("groups", 1<<32).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
}
@@ -450,9 +450,21 @@ type sqliteMock struct {
}
func (m sqliteMock) start() {
m.ExpectBegin()
m.ExpectQuery("PRAGMA foreign_keys").
WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(1))
m.ExpectExec("PRAGMA foreign_keys = off").
WillReturnResult(sqlmock.NewResult(0, 1))
m.ExpectBegin()
m.ExpectQuery("PRAGMA foreign_key_check").
WillReturnRows(sqlmock.NewRows([]string{})) // empty
}
func (m sqliteMock) commit() {
m.ExpectQuery("PRAGMA foreign_key_check").
WillReturnRows(sqlmock.NewRows([]string{})) // empty
m.ExpectCommit()
m.ExpectExec("PRAGMA foreign_keys = on").
WillReturnResult(sqlmock.NewResult(0, 1))
}
func (m sqliteMock) tableExists(table string, exists bool) {