From cfb8f5c4a95956610e751bbc11dec8ef647f703a Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Thu, 7 Jan 2021 20:37:47 +0200 Subject: [PATCH] dialect/sql/schema: minor style changes (#1152) --- dialect/sql/schema/migrate.go | 32 +++++++++--------- dialect/sql/schema/migrate_test.go | 52 ++++++++++++------------------ 2 files changed, 35 insertions(+), 49 deletions(-) diff --git a/dialect/sql/schema/migrate.go b/dialect/sql/schema/migrate.go index f2899f390..3c379b324 100644 --- a/dialect/sql/schema/migrate.go +++ b/dialect/sql/schema/migrate.go @@ -66,23 +66,29 @@ func WithForeignKeys(b bool) MigrateOption { } } +// WithHooks adds a list of hooks to the schema migration. +func WithHooks(hooks ...Hook) MigrateOption { + return func(m *Migrate) { + m.hooks = append(m.hooks, hooks...) + } +} + type ( // Creator is the interface that wraps the Create method. Creator interface { - // Create creates tables. + // Create creates the given tables in the database. See Migrate.Create for more details. Create(context.Context, ...*Table) error } - // The CreateFunc type is an adapter to allow the use of ordinary - // function as Creator. If f is a function with the appropriate signature, - // CreateFunc(f) is a Creator that calls f. + // The CreateFunc type is an adapter to allow the use of ordinary function as Creator. + // If f is a function with the appropriate signature, CreateFunc(f) is a Creator that calls f. CreateFunc func(context.Context, ...*Table) error - // Hook defines the "create middleware". A function that gets a Creator - // and returns a Creator. For example: + // Hook defines the "create middleware". A function that gets a Creator and returns a Creator. + // For example: // // hook := func(next schema.Creator) schema.Creator { - // return schema.CreateFunc(func(ctx context.Context, tables ...*Table) error { + // return schema.CreateFunc(func(ctx context.Context, tables ...*schema.Table) error { // fmt.Println("Tables:", tables) // return next.Create(ctx, tables...) // }) @@ -96,13 +102,6 @@ func (f CreateFunc) Create(ctx context.Context, tables ...*Table) error { return f(ctx, tables...) } -// WithHook adds a create hook. -func WithHook(hook Hook) MigrateOption { - return func(m *Migrate) { - m.hooks = append(m.hooks, hook) - } -} - // Migrate runs the migrations logic for the SQL dialects. type Migrate struct { sqlDialect @@ -121,7 +120,6 @@ func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) { for _, opt := range opts { opt(m) } - switch d.Dialect() { case dialect.MySQL: m.sqlDialect = &MySQL{Driver: d} @@ -166,13 +164,13 @@ func (m *Migrate) create(ctx context.Context, tables ...*Table) error { return rollback(tx, err) } } - if err := m.createInTx(ctx, tx, tables...); err != nil { + if err := m.txCreate(ctx, tx, tables...); err != nil { return rollback(tx, err) } return tx.Commit() } -func (m *Migrate) createInTx(ctx context.Context, tx dialect.Tx, tables ...*Table) error { +func (m *Migrate) txCreate(ctx context.Context, tx dialect.Tx, tables ...*Table) error { for _, t := range tables { m.setupTable(t) switch exist, err := m.tableExist(ctx, tx, t.Name); { diff --git a/dialect/sql/schema/migrate_test.go b/dialect/sql/schema/migrate_test.go index 0f73d121d..5b8710017 100644 --- a/dialect/sql/schema/migrate_test.go +++ b/dialect/sql/schema/migrate_test.go @@ -6,31 +6,26 @@ package schema import ( "context" + "testing" + "github.com/DATA-DOG/go-sqlmock" "github.com/facebook/ent/dialect/sql" "github.com/stretchr/testify/require" - "testing" ) func TestMigrateHookOmitTable(t *testing.T) { - db, mock, err := sqlmock.New() + db, mk, err := sqlmock.New() require.NoError(t, err) - tables := []*Table{ - {Name: "users"}, - {Name: "pets"}, - } - - myMock := mysqlMock{mock} - myMock.start("5.7.23") - - myMock.tableExists("pets", false) - myMock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). + tables := []*Table{{Name: "users"}, {Name: "pets"}} + mock := mysqlMock{mk} + mock.start("5.7.23") + mock.tableExists("pets", false) + mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() - myMock.ExpectCommit() - - migrate, err := NewMigrate(sql.OpenDB("mysql", db), WithHook(func(next Creator) Creator { + migrate, err := NewMigrate(sql.OpenDB("mysql", db), WithHooks(func(next Creator) Creator { return CreateFunc(func(ctx context.Context, tables ...*Table) error { return next.Create(ctx, tables[1]) }) @@ -41,28 +36,21 @@ func TestMigrateHookOmitTable(t *testing.T) { } func TestMigrateHookAddTable(t *testing.T) { - db, mock, err := sqlmock.New() + db, mk, err := sqlmock.New() require.NoError(t, err) - tables := []*Table{ - {Name: "users"}, - {Name: "pets"}, - } - - myMock := mysqlMock{mock} - myMock.start("5.7.23") - - myMock.tableExists("users", false) - myMock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). + tables := []*Table{{Name: "users"}} + mock := mysqlMock{mk} + mock.start("5.7.23") + mock.tableExists("users", false) + mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) - - myMock.tableExists("pets", false) - myMock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). + mock.tableExists("pets", false) + mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() - myMock.ExpectCommit() - - migrate, err := NewMigrate(sql.OpenDB("mysql", db), WithHook(func(next Creator) Creator { + migrate, err := NewMigrate(sql.OpenDB("mysql", db), WithHooks(func(next Creator) Creator { return CreateFunc(func(ctx context.Context, tables ...*Table) error { return next.Create(ctx, tables[0], &Table{Name: "pets"}) })