mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/schema: support setting schema name for migration (#4327)
This commit is contained in:
@@ -29,6 +29,7 @@ type Atlas struct {
|
||||
atDriver migrate.Driver
|
||||
sqlDialect sqlDialect
|
||||
|
||||
schema string // schema to use
|
||||
indent string // plan indentation
|
||||
errNoPlan bool // no plan error enabled
|
||||
universalID bool // global unique ids
|
||||
@@ -662,13 +663,15 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
|
||||
// planInspect creates the current state by inspecting the connected database, computing the current state of the Ent schema
|
||||
// and proceeds to diff the changes to create a migration plan.
|
||||
func (a *Atlas) planInspect(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) {
|
||||
current, err := a.atDriver.InspectSchema(ctx, "", &schema.InspectOptions{
|
||||
current, err := a.atDriver.InspectSchema(ctx, a.schema, &schema.InspectOptions{
|
||||
Tables: func() (t []string) {
|
||||
for i := range tables {
|
||||
t = append(t, tables[i].Name)
|
||||
}
|
||||
return t
|
||||
}(),
|
||||
// Ent supports table-level inspection only.
|
||||
Mode: schema.InspectSchemas | schema.InspectTables,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -692,7 +695,7 @@ func (a *Atlas) planInspect(ctx context.Context, conn dialect.ExecQuerier, name
|
||||
|
||||
func (a *Atlas) planReplay(ctx context.Context, name string, tables []*Table) (*migrate.Plan, error) {
|
||||
// We consider a database clean if there are no tables in the connected schema.
|
||||
s, err := a.atDriver.InspectSchema(ctx, "", nil)
|
||||
s, err := a.atDriver.InspectSchema(ctx, a.schema, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -705,21 +708,21 @@ func (a *Atlas) planReplay(ctx context.Context, name string, tables []*Table) (*
|
||||
return nil, err
|
||||
}
|
||||
if err := ex.ExecuteN(ctx, 0); err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) {
|
||||
return nil, a.cleanSchema(ctx, "", err)
|
||||
return nil, a.cleanSchema(ctx, a.schema, err)
|
||||
}
|
||||
// Inspect the current schema (migration directory).
|
||||
current, err := a.atDriver.InspectSchema(ctx, "", nil)
|
||||
current, err := a.atDriver.InspectSchema(ctx, a.schema, nil)
|
||||
if err != nil {
|
||||
return nil, a.cleanSchema(ctx, "", err)
|
||||
return nil, a.cleanSchema(ctx, a.schema, err)
|
||||
}
|
||||
var types []string
|
||||
if a.universalID {
|
||||
if types, err = a.loadTypes(ctx, a.sqlDialect); err != nil && !errors.Is(err, errTypeTableNotFound) {
|
||||
return nil, a.cleanSchema(ctx, "", err)
|
||||
return nil, a.cleanSchema(ctx, a.schema, err)
|
||||
}
|
||||
a.types = types
|
||||
}
|
||||
if err := a.cleanSchema(ctx, "", nil); err != nil {
|
||||
if err := a.cleanSchema(ctx, a.schema, nil); err != nil {
|
||||
return nil, fmt.Errorf("clean schemas after migration replaying: %w", err)
|
||||
}
|
||||
desired, err := a.tables(tables)
|
||||
|
||||
@@ -57,6 +57,14 @@ func WithErrNoPlan(b bool) MigrateOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithSchemaName sets the database schema for the migration.
|
||||
// If not set, the CURRENT_SCHEMA() is used.
|
||||
func WithSchemaName(ns string) MigrateOption {
|
||||
return func(a *Atlas) {
|
||||
a.schema = ns
|
||||
}
|
||||
}
|
||||
|
||||
// WithDropColumn sets the columns dropping option to the migration.
|
||||
// Defaults to false.
|
||||
func WithDropColumn(b bool) MigrateOption {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"text/template"
|
||||
@@ -28,6 +29,54 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigrate_SchemaName(t *testing.T) {
|
||||
db, mk, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
mk.ExpectQuery(escape("SHOW server_version_num")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"server_version_num"}).AddRow("130000"))
|
||||
mk.ExpectQuery(escape("SELECT current_setting('server_version_num'), current_setting('default_table_access_method', true), current_setting('crdb_version', true)")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"current_setting", "current_setting", "current_setting"}).AddRow("130000", "heap", ""))
|
||||
mk.ExpectQuery("SELECT nspname AS schema_name,.+").
|
||||
WithArgs("public"). // Schema "public" param is used.
|
||||
WillReturnRows(sqlmock.NewRows([]string{"schema_name", "comment"}).AddRow("public", "default schema"))
|
||||
mk.ExpectQuery("SELECT t3.oid, t1.table_schema,.+").
|
||||
WillReturnRows(sqlmock.NewRows([]string{}))
|
||||
m, err := NewMigrate(sql.OpenDB("postgres", db), WithSchemaName("public"), WithDiffHook(func(next Differ) Differ {
|
||||
return DiffFunc(func(current, desired *schema.Schema) ([]schema.Change, error) {
|
||||
return nil, nil // Noop.
|
||||
})
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, m.Create(context.Background()))
|
||||
require.NoError(t, mk.ExpectationsWereMet())
|
||||
|
||||
// Without schema name the CURRENT_SCHEMA is used.
|
||||
mk.ExpectQuery(escape("SHOW server_version_num")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"server_version_num"}).AddRow("130000"))
|
||||
mk.ExpectQuery(escape("SELECT current_setting('server_version_num'), current_setting('default_table_access_method', true), current_setting('crdb_version', true)")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"current_setting", "current_setting", "current_setting"}).AddRow("130000", "heap", ""))
|
||||
mk.ExpectQuery("SELECT nspname AS schema_name,.+CURRENT_SCHEMA().+").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"schema_name", "comment"}).AddRow("public", "default schema"))
|
||||
mk.ExpectQuery("SELECT t3.oid, t1.table_schema,.+").
|
||||
WillReturnRows(sqlmock.NewRows([]string{}))
|
||||
m, err = NewMigrate(sql.OpenDB("postgres", db), WithDiffHook(func(next Differ) Differ {
|
||||
return DiffFunc(func(current, desired *schema.Schema) ([]schema.Change, error) {
|
||||
return nil, nil // Noop.
|
||||
})
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, m.Create(context.Background()))
|
||||
}
|
||||
|
||||
func escape(query string) string {
|
||||
rows := strings.Split(query, "\n")
|
||||
for i := range rows {
|
||||
rows[i] = strings.TrimPrefix(rows[i], " ")
|
||||
}
|
||||
query = strings.Join(rows, " ")
|
||||
return strings.TrimSpace(regexp.QuoteMeta(query)) + "$"
|
||||
}
|
||||
|
||||
func TestMigrate_Formatter(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -30,6 +30,9 @@ type MySQL struct {
|
||||
|
||||
// init loads the MySQL version from the database for later use in the migration process.
|
||||
func (d *MySQL) init(ctx context.Context) error {
|
||||
if d.version != "" {
|
||||
return nil // already initialized.
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
if err := d.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil {
|
||||
return fmt.Errorf("mysql: querying mysql version %w", err)
|
||||
|
||||
@@ -30,6 +30,9 @@ type Postgres struct {
|
||||
// init loads the Postgres version from the database for later use in the migration process.
|
||||
// It returns an error if the server version is lower than v10.
|
||||
func (d *Postgres) init(ctx context.Context) error {
|
||||
if d.version != "" {
|
||||
return nil // already initialized.
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
if err := d.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil {
|
||||
return fmt.Errorf("querying server version %w", err)
|
||||
|
||||
Reference in New Issue
Block a user