diff --git a/dialect/sql/schema/migrate.go b/dialect/sql/schema/migrate.go index c3d54f8c8..f2899f390 100644 --- a/dialect/sql/schema/migrate.go +++ b/dialect/sql/schema/migrate.go @@ -66,6 +66,43 @@ func WithForeignKeys(b bool) MigrateOption { } } +type ( + // Creator is the interface that wraps the Create method. + Creator interface { + // Create creates tables. + Create(context.Context, ...*Table) error + } + + // The CreateFunc type is an adapter to allow the use of ordinary + // function as Creator. If f is a function with the appropriate signature, + // CreateFunc(f) is a Creator that calls f. + CreateFunc func(context.Context, ...*Table) error + + // Hook defines the "create middleware". A function that gets a Creator + // and returns a Creator. For example: + // + // hook := func(next schema.Creator) schema.Creator { + // return schema.CreateFunc(func(ctx context.Context, tables ...*Table) error { + // fmt.Println("Tables:", tables) + // return next.Create(ctx, tables...) + // }) + // } + // + Hook func(Creator) Creator +) + +// Create calls f(ctx, tables...). +func (f CreateFunc) Create(ctx context.Context, tables ...*Table) error { + return f(ctx, tables...) +} + +// WithHook adds a create hook. +func WithHook(hook Hook) MigrateOption { + return func(m *Migrate) { + m.hooks = append(m.hooks, hook) + } +} + // Migrate runs the migrations logic for the SQL dialects. type Migrate struct { sqlDialect @@ -75,6 +112,7 @@ type Migrate struct { withFixture bool // with fks rename fixture. withForeignKeys bool // with foreign keys typeRanges []string // types order by their range. + hooks []Hook // hooks to apply before creation } // NewMigrate create a migration structure for the given SQL driver. @@ -83,6 +121,7 @@ func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) { for _, opt := range opts { opt(m) } + switch d.Dialect() { case dialect.MySQL: m.sqlDialect = &MySQL{Driver: d} @@ -106,6 +145,15 @@ func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) { // Note that SQLite dialect does not support (this moment) the "append-only" mode describe above, // since it's used only for testing. func (m *Migrate) Create(ctx context.Context, tables ...*Table) error { + var creator Creator = CreateFunc(m.create) + for i := len(m.hooks) - 1; i >= 0; i-- { + creator = m.hooks[i](creator) + } + + return creator.Create(ctx, tables...) +} + +func (m *Migrate) create(ctx context.Context, tables ...*Table) error { tx, err := m.Tx(ctx) if err != nil { return err @@ -118,13 +166,13 @@ func (m *Migrate) Create(ctx context.Context, tables ...*Table) error { return rollback(tx, err) } } - if err := m.create(ctx, tx, tables...); err != nil { + if err := m.createInTx(ctx, tx, tables...); err != nil { return rollback(tx, err) } return tx.Commit() } -func (m *Migrate) create(ctx context.Context, tx dialect.Tx, tables ...*Table) error { +func (m *Migrate) createInTx(ctx context.Context, tx dialect.Tx, tables ...*Table) error { for _, t := range tables { m.setupTable(t) switch exist, err := m.tableExist(ctx, tx, t.Name); { diff --git a/dialect/sql/schema/migrate_test.go b/dialect/sql/schema/migrate_test.go new file mode 100644 index 000000000..0f73d121d --- /dev/null +++ b/dialect/sql/schema/migrate_test.go @@ -0,0 +1,73 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package schema + +import ( + "context" + "github.com/DATA-DOG/go-sqlmock" + "github.com/facebook/ent/dialect/sql" + "github.com/stretchr/testify/require" + "testing" +) + +func TestMigrateHookOmitTable(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + + tables := []*Table{ + {Name: "users"}, + {Name: "pets"}, + } + + myMock := mysqlMock{mock} + myMock.start("5.7.23") + + myMock.tableExists("pets", false) + myMock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). + WillReturnResult(sqlmock.NewResult(0, 1)) + + myMock.ExpectCommit() + + migrate, err := NewMigrate(sql.OpenDB("mysql", db), WithHook(func(next Creator) Creator { + return CreateFunc(func(ctx context.Context, tables ...*Table) error { + return next.Create(ctx, tables[1]) + }) + })) + require.NoError(t, err) + err = migrate.Create(context.Background(), tables...) + require.NoError(t, err) +} + +func TestMigrateHookAddTable(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + + tables := []*Table{ + {Name: "users"}, + {Name: "pets"}, + } + + myMock := mysqlMock{mock} + myMock.start("5.7.23") + + myMock.tableExists("users", false) + myMock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). + WillReturnResult(sqlmock.NewResult(0, 1)) + + myMock.tableExists("pets", false) + myMock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). + WillReturnResult(sqlmock.NewResult(0, 1)) + + myMock.ExpectCommit() + + migrate, err := NewMigrate(sql.OpenDB("mysql", db), WithHook(func(next Creator) Creator { + return CreateFunc(func(ctx context.Context, tables ...*Table) error { + return next.Create(ctx, tables[0], &Table{Name: "pets"}) + }) + })) + require.NoError(t, err) + err = migrate.Create(context.Background(), tables...) + require.NoError(t, err) +}