dialect/sql/schema: add migration hooks (#1150)

This commit is contained in:
Ruben de Vries
2021-01-07 19:23:01 +01:00
committed by GitHub
parent a7f899339b
commit a9c39bb952
2 changed files with 123 additions and 2 deletions

View File

@@ -66,6 +66,43 @@ func WithForeignKeys(b bool) MigrateOption {
}
}
type (
// Creator is the interface that wraps the Create method.
Creator interface {
// Create creates tables.
Create(context.Context, ...*Table) error
}
// The CreateFunc type is an adapter to allow the use of ordinary
// function as Creator. If f is a function with the appropriate signature,
// CreateFunc(f) is a Creator that calls f.
CreateFunc func(context.Context, ...*Table) error
// Hook defines the "create middleware". A function that gets a Creator
// and returns a Creator. For example:
//
// hook := func(next schema.Creator) schema.Creator {
// return schema.CreateFunc(func(ctx context.Context, tables ...*Table) error {
// fmt.Println("Tables:", tables)
// return next.Create(ctx, tables...)
// })
// }
//
Hook func(Creator) Creator
)
// Create calls f(ctx, tables...).
func (f CreateFunc) Create(ctx context.Context, tables ...*Table) error {
return f(ctx, tables...)
}
// WithHook adds a create hook.
func WithHook(hook Hook) MigrateOption {
return func(m *Migrate) {
m.hooks = append(m.hooks, hook)
}
}
// Migrate runs the migrations logic for the SQL dialects.
type Migrate struct {
sqlDialect
@@ -75,6 +112,7 @@ type Migrate struct {
withFixture bool // with fks rename fixture.
withForeignKeys bool // with foreign keys
typeRanges []string // types order by their range.
hooks []Hook // hooks to apply before creation
}
// NewMigrate create a migration structure for the given SQL driver.
@@ -83,6 +121,7 @@ func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) {
for _, opt := range opts {
opt(m)
}
switch d.Dialect() {
case dialect.MySQL:
m.sqlDialect = &MySQL{Driver: d}
@@ -106,6 +145,15 @@ func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) {
// Note that SQLite dialect does not support (this moment) the "append-only" mode describe above,
// since it's used only for testing.
func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
var creator Creator = CreateFunc(m.create)
for i := len(m.hooks) - 1; i >= 0; i-- {
creator = m.hooks[i](creator)
}
return creator.Create(ctx, tables...)
}
func (m *Migrate) create(ctx context.Context, tables ...*Table) error {
tx, err := m.Tx(ctx)
if err != nil {
return err
@@ -118,13 +166,13 @@ func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
return rollback(tx, err)
}
}
if err := m.create(ctx, tx, tables...); err != nil {
if err := m.createInTx(ctx, tx, tables...); err != nil {
return rollback(tx, err)
}
return tx.Commit()
}
func (m *Migrate) create(ctx context.Context, tx dialect.Tx, tables ...*Table) error {
func (m *Migrate) createInTx(ctx context.Context, tx dialect.Tx, tables ...*Table) error {
for _, t := range tables {
m.setupTable(t)
switch exist, err := m.tableExist(ctx, tx, t.Name); {

View File

@@ -0,0 +1,73 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package schema
import (
"context"
"github.com/DATA-DOG/go-sqlmock"
"github.com/facebook/ent/dialect/sql"
"github.com/stretchr/testify/require"
"testing"
)
func TestMigrateHookOmitTable(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
tables := []*Table{
{Name: "users"},
{Name: "pets"},
}
myMock := mysqlMock{mock}
myMock.start("5.7.23")
myMock.tableExists("pets", false)
myMock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")).
WillReturnResult(sqlmock.NewResult(0, 1))
myMock.ExpectCommit()
migrate, err := NewMigrate(sql.OpenDB("mysql", db), WithHook(func(next Creator) Creator {
return CreateFunc(func(ctx context.Context, tables ...*Table) error {
return next.Create(ctx, tables[1])
})
}))
require.NoError(t, err)
err = migrate.Create(context.Background(), tables...)
require.NoError(t, err)
}
func TestMigrateHookAddTable(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
tables := []*Table{
{Name: "users"},
{Name: "pets"},
}
myMock := mysqlMock{mock}
myMock.start("5.7.23")
myMock.tableExists("users", false)
myMock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")).
WillReturnResult(sqlmock.NewResult(0, 1))
myMock.tableExists("pets", false)
myMock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")).
WillReturnResult(sqlmock.NewResult(0, 1))
myMock.ExpectCommit()
migrate, err := NewMigrate(sql.OpenDB("mysql", db), WithHook(func(next Creator) Creator {
return CreateFunc(func(ctx context.Context, tables ...*Table) error {
return next.Create(ctx, tables[0], &Table{Name: "pets"})
})
}))
require.NoError(t, err)
err = migrate.Create(context.Background(), tables...)
require.NoError(t, err)
}