mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/schema: atlas engine is now default (#2698)
* atlas engine is default, enabled diff by replay * Apply suggestions from code review * docs * apply CR
This commit is contained in:
@@ -6,23 +6,227 @@ package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"ariga.io/atlas/sql/migrate"
|
||||
"ariga.io/atlas/sql/schema"
|
||||
"ariga.io/atlas/sql/sqlclient"
|
||||
"ariga.io/atlas/sql/sqltool"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/schema/field"
|
||||
)
|
||||
|
||||
// Atlas atlas migration engine.
|
||||
type Atlas struct {
|
||||
atDriver migrate.Driver
|
||||
sqlDialect sqlDialect
|
||||
|
||||
legacy bool // if the legacy migration engine instead of Atlas should be used
|
||||
withFixture bool // deprecated: with fks rename fixture
|
||||
sum bool // deprecated: sum file generation will be required
|
||||
|
||||
universalID bool // global unique ids
|
||||
dropColumns bool // drop deleted columns
|
||||
dropIndexes bool // drop deleted indexes
|
||||
withForeignKeys bool // with foreign keys
|
||||
mode Mode
|
||||
hooks []Hook // hooks to apply before creation
|
||||
diffHooks []DiffHook // diff hooks to run when diffing current and desired
|
||||
applyHook []ApplyHook // apply hooks to run when applying the plan
|
||||
skip ChangeKind // what changes to skip and not apply
|
||||
dir migrate.Dir // the migration directory to read from
|
||||
fmt migrate.Formatter // how to format the plan into migration files
|
||||
|
||||
driver dialect.Driver // driver passed in when not using an atlas URL
|
||||
url *url.URL // url of database connection
|
||||
dialect string // Ent dialect to use when generating migration files
|
||||
|
||||
types []string // pre-existing pk range allocation for global unique id
|
||||
}
|
||||
|
||||
// Diff compares the state read from a database connection or migration directory with the state defined by the Ent
|
||||
// schema. Changes will be written to new migration files.
|
||||
func Diff(ctx context.Context, u, name string, tables []*Table, opts ...MigrateOption) (err error) {
|
||||
m, err := NewMigrateURL(u, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return m.NamedDiff(ctx, name, tables...)
|
||||
}
|
||||
|
||||
// NewMigrate creates a new Atlas form the given dialect.Driver.
|
||||
func NewMigrate(drv dialect.Driver, opts ...MigrateOption) (*Atlas, error) {
|
||||
a := &Atlas{driver: drv, withForeignKeys: true, mode: ModeInspect, sum: true}
|
||||
for _, opt := range opts {
|
||||
opt(a)
|
||||
}
|
||||
a.dialect = a.driver.Dialect()
|
||||
if err := a.init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// NewMigrateURL create a new Atlas from the given url.
|
||||
func NewMigrateURL(u string, opts ...MigrateOption) (*Atlas, error) {
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a := &Atlas{url: parsed, withForeignKeys: true, mode: ModeInspect, sum: true}
|
||||
for _, opt := range opts {
|
||||
opt(a)
|
||||
}
|
||||
if a.dialect == "" {
|
||||
a.dialect = parsed.Scheme
|
||||
}
|
||||
if err := a.init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Create creates all schema resources in the database. It works in an "append-only"
|
||||
// mode, which means, it only creates tables, appends columns to tables or modifies column types.
|
||||
//
|
||||
// Column can be modified by turning into a NULL from NOT NULL, or having a type conversion not
|
||||
// resulting data altering. From example, changing varchar(255) to varchar(120) is invalid, but
|
||||
// changing varchar(120) to varchar(255) is valid. For more info, see the convert function below.
|
||||
func (a *Atlas) Create(ctx context.Context, tables ...*Table) (err error) {
|
||||
a.setupTables(tables)
|
||||
var creator Creator = CreateFunc(a.create)
|
||||
if a.legacy {
|
||||
m, err := a.legacyMigrate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
creator = CreateFunc(m.create)
|
||||
}
|
||||
for i := len(a.hooks) - 1; i >= 0; i-- {
|
||||
creator = a.hooks[i](creator)
|
||||
}
|
||||
return creator.Create(ctx, tables...)
|
||||
}
|
||||
|
||||
// Diff compares the state read from the connected database with the state defined by Ent.
|
||||
// Changes will be written to migration files by the configured Planner.
|
||||
func (a *Atlas) Diff(ctx context.Context, tables ...*Table) error {
|
||||
return a.NamedDiff(ctx, "changes", tables...)
|
||||
}
|
||||
|
||||
// NamedDiff compares the state read from the connected database with the state defined by Ent.
|
||||
// Changes will be written to migration files by the configured Planner.
|
||||
func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) error {
|
||||
if a.dir == nil {
|
||||
return errors.New("no migration directory given")
|
||||
}
|
||||
opts := []migrate.PlannerOption{migrate.WithFormatter(a.fmt)}
|
||||
if a.sum {
|
||||
// Validate the migration directory before proceeding.
|
||||
if err := migrate.Validate(a.dir); err != nil {
|
||||
return fmt.Errorf("validating migration directory: %w", err)
|
||||
}
|
||||
} else {
|
||||
opts = append(opts, migrate.DisableChecksum())
|
||||
}
|
||||
a.setupTables(tables)
|
||||
// Set up connections.
|
||||
if a.driver != nil {
|
||||
var err error
|
||||
a.sqlDialect, err = a.entDialect(a.driver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.atDriver, err = a.sqlDialect.atOpen(a.sqlDialect)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
c, err := sqlclient.OpenURL(ctx, a.url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer c.Close()
|
||||
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.atDriver = c.Driver
|
||||
}
|
||||
defer func() {
|
||||
a.sqlDialect = nil
|
||||
a.atDriver = nil
|
||||
}()
|
||||
if err := a.sqlDialect.init(ctx, a.sqlDialect); err != nil {
|
||||
return err
|
||||
}
|
||||
if a.universalID {
|
||||
tables = append(tables, NewTable(TypeTable).
|
||||
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
|
||||
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}),
|
||||
)
|
||||
}
|
||||
switch a.mode {
|
||||
case ModeInspect:
|
||||
// Do nothing here, simply inspect later on.
|
||||
case ModeReplay:
|
||||
// We consider a database clean if there are no tables in the connected schema.
|
||||
s, err := a.atDriver.InspectSchema(ctx, "", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(s.Tables) > 0 {
|
||||
return migrate.ErrNotClean
|
||||
}
|
||||
// Clean up once done.
|
||||
defer func() {
|
||||
// We clean a database by dropping all tables inside the connected schema.
|
||||
s, err = a.atDriver.InspectSchema(ctx, "", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tbls := make([]schema.Change, len(s.Tables))
|
||||
for i, t := range s.Tables {
|
||||
tbls[i] = &schema.DropTable{T: t}
|
||||
}
|
||||
if err2 := a.atDriver.ApplyChanges(ctx, tbls); err2 != nil {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %w", err2, err)
|
||||
return
|
||||
}
|
||||
err = err2
|
||||
return
|
||||
}
|
||||
}()
|
||||
// Replay the migration directory on the database.
|
||||
ex, err := migrate.NewExecutor(a.atDriver, a.dir, &migrate.NopRevisionReadWriter{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ex.ExecuteN(ctx, 0); err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown migration mode: %q", a.mode)
|
||||
}
|
||||
plan, err := a.plan(ctx, a.sqlDialect, name, tables)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Skip if the plan has no changes.
|
||||
if len(plan.Changes) == 0 {
|
||||
return nil
|
||||
}
|
||||
return migrate.NewPlanner(nil, a.dir, opts...).WritePlan(plan)
|
||||
}
|
||||
|
||||
type (
|
||||
// Differ is the interface that wraps the Diff method.
|
||||
Differ interface {
|
||||
@@ -59,8 +263,8 @@ func (f DiffFunc) Diff(current, desired *schema.Schema) ([]schema.Change, error)
|
||||
// })
|
||||
//
|
||||
func WithDiffHook(hooks ...DiffHook) MigrateOption {
|
||||
return func(m *Migrate) {
|
||||
m.atlas.diff = append(m.atlas.diff, hooks...)
|
||||
return func(a *Atlas) {
|
||||
a.diffHooks = append(a.diffHooks, hooks...)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,8 +274,8 @@ func WithDiffHook(hooks ...DiffHook) MigrateOption {
|
||||
// SkipChanges(schema.DropTable|schema.DropColumn)
|
||||
//
|
||||
func WithSkipChanges(skip ChangeKind) MigrateOption {
|
||||
return func(m *Migrate) {
|
||||
m.atlas.skip = skip
|
||||
return func(a *Atlas) {
|
||||
a.skip = skip
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,146 +441,180 @@ func (f ApplyFunc) Apply(ctx context.Context, conn dialect.ExecQuerier, plan *mi
|
||||
// })
|
||||
//
|
||||
func WithApplyHook(hooks ...ApplyHook) MigrateOption {
|
||||
return func(m *Migrate) {
|
||||
m.atlas.apply = append(m.atlas.apply, hooks...)
|
||||
return func(a *Atlas) {
|
||||
a.applyHook = append(a.applyHook, hooks...)
|
||||
}
|
||||
}
|
||||
|
||||
// WithAtlas is an opt-in option for v0.10 indicates the migration
|
||||
// should be executed using Atlas engine (i.e. https://atlasgo.io).
|
||||
// Note, in future versions, this option is going to be replaced
|
||||
// from opt-in to opt-out and the deprecation of this package.
|
||||
// WithAtlas is an opt-out option for v0.11 indicating the migration
|
||||
// should be executed using the deprecated legacy engine.
|
||||
// Note, in future versions, this option is going to be removed
|
||||
// and the Atlas (https://atlasgo.io) based migration engine should be used.
|
||||
//
|
||||
// Deprecated: The legacy engine will be removed.
|
||||
func WithAtlas(b bool) MigrateOption {
|
||||
return func(m *Migrate) {
|
||||
m.atlas.enabled = b
|
||||
return func(a *Atlas) {
|
||||
a.legacy = !b
|
||||
}
|
||||
}
|
||||
|
||||
// WithDir sets the atlas migration directory to use to store migration files.
|
||||
func WithDir(dir migrate.Dir) MigrateOption {
|
||||
return func(m *Migrate) {
|
||||
m.atlas.dir = dir
|
||||
return func(a *Atlas) {
|
||||
a.dir = dir
|
||||
}
|
||||
}
|
||||
|
||||
// WithFormatter sets atlas formatter to use to write changes to migration files.
|
||||
func WithFormatter(fmt migrate.Formatter) MigrateOption {
|
||||
return func(m *Migrate) {
|
||||
m.atlas.fmt = fmt
|
||||
return func(a *Atlas) {
|
||||
a.fmt = fmt
|
||||
}
|
||||
}
|
||||
|
||||
// WithSumFile instructs atlas to generate a migration directory integrity sum file as well.
|
||||
func WithSumFile() MigrateOption {
|
||||
return func(m *Migrate) {
|
||||
m.atlas.genSum = true
|
||||
// WithDialect configures the Ent dialect to use when migrating for an Atlas supported dialect flavor.
|
||||
// As an example, Ent can work with TiDB in MySQL dialect and Atlas can handle TiDB migrations.
|
||||
func WithDialect(d string) MigrateOption {
|
||||
return func(a *Atlas) {
|
||||
a.dialect = d
|
||||
}
|
||||
}
|
||||
|
||||
// WithUniversalID instructs atlas to use a file based type store when
|
||||
// global unique ids are enabled. For more information see the setupAtlas method on Migrate.
|
||||
// WithSumFile instructs atlas to generate a migration directory integrity sum file.
|
||||
//
|
||||
// ATTENTION:
|
||||
// The file based PK range store is not backward compatible, since the allocated ranges were computed
|
||||
// dynamically when computing the diff between a deployed database and the current schema. In cases where there
|
||||
// exist multiple deployments, the allocated ranges for the same type might be different from each other,
|
||||
// depending on when the deployment took part.
|
||||
func WithUniversalID() MigrateOption {
|
||||
return func(m *Migrate) {
|
||||
m.universalID = true
|
||||
m.atlas.typeStoreConsent = true
|
||||
// Deprecated: generating the sum file is now opt-out. This method will be removed in future versions.
|
||||
func WithSumFile() MigrateOption {
|
||||
return func(a *Atlas) {}
|
||||
}
|
||||
|
||||
// DisableChecksum instructs atlas to skip migration directory integrity sum file generation.
|
||||
//
|
||||
// Deprecated: generating the sum file will no longer be optional in future versions.
|
||||
func DisableChecksum() MigrateOption {
|
||||
return func(a *Atlas) {
|
||||
a.sum = false
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
// atlasOptions describes the options for atlas.
|
||||
atlasOptions struct {
|
||||
enabled bool
|
||||
diff []DiffHook
|
||||
apply []ApplyHook
|
||||
skip ChangeKind
|
||||
dir migrate.Dir
|
||||
fmt migrate.Formatter
|
||||
genSum bool
|
||||
typeStoreConsent bool
|
||||
// WithMigrationMode instructs atlas how to compute the current state of the schema. This can be done by either
|
||||
// replaying (ModeReplay) the migration directory on the connected database, or by inspecting (ModeInspect) the
|
||||
// connection. Currently, ModeReplay is opt-in, and ModeInspect is the default. In future versions, ModeReplay will
|
||||
// become the default behavior. This option has no effect when using online migrations.
|
||||
func WithMigrationMode(mode Mode) MigrateOption {
|
||||
return func(a *Atlas) {
|
||||
a.mode = mode
|
||||
}
|
||||
}
|
||||
|
||||
// atBuilder must be implemented by the different drivers in
|
||||
// order to convert a dialect/sql/schema to atlas/sql/schema.
|
||||
atBuilder interface {
|
||||
atOpen(dialect.ExecQuerier) (migrate.Driver, error)
|
||||
atTable(*Table, *schema.Table)
|
||||
atTypeC(*Column, *schema.Column) error
|
||||
atUniqueC(*Table, *Column, *schema.Table, *schema.Column)
|
||||
atIncrementC(*schema.Table, *schema.Column)
|
||||
atIncrementT(*schema.Table, int64)
|
||||
atIndex(*Index, *schema.Table, *schema.Index) error
|
||||
atTypeRangeSQL(t ...string) string
|
||||
}
|
||||
// Mode to compute the current state.
|
||||
type Mode uint
|
||||
|
||||
const (
|
||||
// ModeReplay computes the current state by replaying the migration directory on the connected database.
|
||||
ModeReplay = iota
|
||||
// ModeInspect computes the current state by inspecting the connected database.
|
||||
ModeInspect
|
||||
)
|
||||
|
||||
var errConsent = errors.New("sql/schema: use WithUniversalID() instead of WithGlobalUniqueID(true) when using WithDir(): https://entgo.io/docs/migrate#universal-ids")
|
||||
// StateReader returns an atlas migrate.StateReader returning the state as described by the Ent table slice.
|
||||
func (a *Atlas) StateReader(tables ...*Table) migrate.StateReaderFunc {
|
||||
return func(context.Context) (*schema.Realm, error) {
|
||||
ts, err := a.tables(tables)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &schema.Realm{Schemas: []*schema.Schema{{Tables: ts}}}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Migrate) setupAtlas() error {
|
||||
// Using one of the Atlas options, opt-in to Atlas migration.
|
||||
if !m.atlas.enabled && (m.atlas.skip != NoChange || len(m.atlas.diff) > 0 || len(m.atlas.apply) > 0) || m.atlas.dir != nil {
|
||||
m.atlas.enabled = true
|
||||
}
|
||||
if !m.atlas.enabled {
|
||||
return nil
|
||||
}
|
||||
if m.withFixture {
|
||||
return errors.New("sql/schema: WithFixture(true) does not work in Atlas migration")
|
||||
}
|
||||
// atBuilder must be implemented by the different drivers in
|
||||
// order to convert a dialect/sql/schema to atlas/sql/schema.
|
||||
type atBuilder interface {
|
||||
atOpen(dialect.ExecQuerier) (migrate.Driver, error)
|
||||
atTable(*Table, *schema.Table)
|
||||
atTypeC(*Column, *schema.Column) error
|
||||
atUniqueC(*Table, *Column, *schema.Table, *schema.Column)
|
||||
atIncrementC(*schema.Table, *schema.Column)
|
||||
atIncrementT(*schema.Table, int64)
|
||||
atIndex(*Index, *schema.Table, *schema.Index) error
|
||||
atTypeRangeSQL(t ...string) string
|
||||
}
|
||||
|
||||
// init initializes the configuration object based on the options passed in.
|
||||
func (a *Atlas) init() error {
|
||||
skip := DropIndex | DropColumn
|
||||
if m.atlas.skip != NoChange {
|
||||
skip = m.atlas.skip
|
||||
if a.skip != NoChange {
|
||||
skip = a.skip
|
||||
}
|
||||
if m.dropIndexes {
|
||||
if a.dropIndexes {
|
||||
skip &= ^DropIndex
|
||||
}
|
||||
if m.dropColumns {
|
||||
if a.dropColumns {
|
||||
skip &= ^DropColumn
|
||||
}
|
||||
if skip != NoChange {
|
||||
m.atlas.diff = append(m.atlas.diff, filterChanges(skip))
|
||||
a.diffHooks = append(a.diffHooks, filterChanges(skip))
|
||||
}
|
||||
if !m.withForeignKeys {
|
||||
m.atlas.diff = append(m.atlas.diff, withoutForeignKeys)
|
||||
if !a.withForeignKeys {
|
||||
a.diffHooks = append(a.diffHooks, withoutForeignKeys)
|
||||
}
|
||||
if m.atlas.dir != nil && m.atlas.fmt == nil {
|
||||
m.atlas.fmt = sqltool.GolangMigrateFormatter
|
||||
if a.dir != nil && a.fmt == nil {
|
||||
a.fmt = sqltool.GolangMigrateFormatter
|
||||
}
|
||||
if m.universalID && m.atlas.dir != nil {
|
||||
// If global unique ids and a migration directory is given, enable the file based type store for pk ranges.
|
||||
m.typeStore = &dirTypeStore{dir: m.atlas.dir}
|
||||
// To guard the user against a possible bug due to backward incompatibility, the file based type store must
|
||||
// be enabled by an option. For more information see the comment of WithUniversalID function.
|
||||
if !m.atlas.typeStoreConsent {
|
||||
return errConsent
|
||||
if a.mode == ModeReplay {
|
||||
// ModeReplay requires a migration directory.
|
||||
if a.dir == nil {
|
||||
return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires versioned migrations: WithDir()")
|
||||
}
|
||||
// ModeReplay requires sum file generation.
|
||||
if !a.sum {
|
||||
return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires migration directory integrity file")
|
||||
}
|
||||
m.atlas.diff = append(m.atlas.diff, m.ensureTypeTable)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrate) atCreate(ctx context.Context, tables ...*Table) error {
|
||||
// create is the Atlas engine based online migration.
|
||||
func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
|
||||
if a.universalID {
|
||||
tables = append(tables, NewTable(TypeTable).
|
||||
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
|
||||
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}),
|
||||
)
|
||||
}
|
||||
if a.driver != nil {
|
||||
a.sqlDialect, err = a.entDialect(a.driver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
c, err := sqlclient.OpenURL(ctx, a.url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer c.Close()
|
||||
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
defer func() { a.sqlDialect = nil }()
|
||||
// Open a transaction for backwards compatibility,
|
||||
// even if the migration is not transactional.
|
||||
tx, err := m.Tx(ctx)
|
||||
tx, err := a.sqlDialect.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.sqlDialect.init(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
a.atDriver, err = a.sqlDialect.atOpen(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { a.atDriver = nil }()
|
||||
if err := func() error {
|
||||
if err := m.init(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.universalID {
|
||||
if err := m.types(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
plan, err := m.atDiff(ctx, tx, "", tables...)
|
||||
plan, err := a.plan(ctx, tx, "changes", tables)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -392,22 +630,25 @@ func (m *Migrate) atCreate(ctx context.Context, tables ...*Table) error {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
for i := len(m.atlas.apply) - 1; i >= 0; i-- {
|
||||
applier = m.atlas.apply[i](applier)
|
||||
for i := len(a.applyHook) - 1; i >= 0; i-- {
|
||||
applier = a.applyHook[i](applier)
|
||||
}
|
||||
return applier.Apply(ctx, tx, plan)
|
||||
}(); err != nil {
|
||||
return rollback(tx, err)
|
||||
err = fmt.Errorf("sql/schema: %w", err)
|
||||
if rerr := tx.Rollback(); rerr != nil {
|
||||
err = fmt.Errorf("%w: %v", err, rerr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (m *Migrate) atDiff(ctx context.Context, conn dialect.ExecQuerier, name string, tables ...*Table) (*migrate.Plan, error) {
|
||||
drv, err := m.atOpen(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
current, err := drv.InspectSchema(ctx, "", &schema.InspectOptions{
|
||||
// plan creates the current state by inspecting the connected database, computing the current state of the Ent schema
|
||||
// and proceeds to diff the changes to create a migration plan.
|
||||
// before diffing.
|
||||
func (a *Atlas) plan(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) {
|
||||
current, err := a.atDriver.InspectSchema(ctx, "", &schema.InspectOptions{
|
||||
Tables: func() (t []string) {
|
||||
for i := range tables {
|
||||
t = append(t, tables[i].Name)
|
||||
@@ -418,21 +659,55 @@ func (m *Migrate) atDiff(ctx context.Context, conn dialect.ExecQuerier, name str
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt, err := m.aTables(ctx, m, conn, tables)
|
||||
var types []string
|
||||
if a.universalID {
|
||||
// Fetch pre-existing type allocations.
|
||||
exists, err := a.sqlDialect.tableExist(ctx, conn, TypeTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
rows := &entsql.Rows{}
|
||||
query, args := entsql.Dialect(a.dialect).
|
||||
Select("type").From(entsql.Table(TypeTable)).OrderBy(entsql.Asc("id")).Query()
|
||||
if err := conn.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("query types table: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
a.types = nil
|
||||
if err := entsql.ScanSlice(rows, &a.types); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
types = a.types
|
||||
}
|
||||
desired, err := a.StateReader(tables...).ReadState(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Diff changes.
|
||||
var differ Differ = DiffFunc(drv.SchemaDiff)
|
||||
for i := len(m.atlas.diff) - 1; i >= 0; i-- {
|
||||
differ = m.atlas.diff[i](differ)
|
||||
}
|
||||
changes, err := differ.Diff(current, &schema.Schema{Name: current.Name, Attrs: current.Attrs, Tables: tt})
|
||||
changes, err := (&diffDriver{a.atDriver, a.diffHooks}).SchemaDiff(current, &schema.Schema{
|
||||
Name: current.Name,
|
||||
Attrs: current.Attrs,
|
||||
Tables: desired.Schemas[0].Tables,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Plan changes.
|
||||
return drv.PlanChanges(ctx, name, changes)
|
||||
plan, err := a.atDriver.PlanChanges(ctx, name, changes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Insert new types.
|
||||
newTypes := a.types[len(types):]
|
||||
if len(newTypes) > 0 {
|
||||
plan.Changes = append(plan.Changes, &migrate.Change{
|
||||
Cmd: a.sqlDialect.atTypeRangeSQL(newTypes...),
|
||||
Comment: fmt.Sprintf("add pk ranges for %s tables", strings.Join(newTypes, ",")),
|
||||
})
|
||||
}
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
type db struct{ dialect.ExecQuerier }
|
||||
@@ -453,28 +728,29 @@ func (d *db) ExecContext(ctx context.Context, query string, args ...interface{})
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (m *Migrate) aTables(ctx context.Context, b atBuilder, conn dialect.ExecQuerier, tables1 []*Table) ([]*schema.Table, error) {
|
||||
tables2 := make([]*schema.Table, len(tables1))
|
||||
for i, t1 := range tables1 {
|
||||
t2 := schema.NewTable(t1.Name)
|
||||
b.atTable(t1, t2)
|
||||
if m.universalID {
|
||||
r, err := m.pkRange(ctx, conn, t1)
|
||||
// tables converts an Ent table slice to an atlas table slice
|
||||
func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) {
|
||||
ts := make([]*schema.Table, len(tables))
|
||||
for i, et := range tables {
|
||||
at := schema.NewTable(et.Name)
|
||||
a.sqlDialect.atTable(et, at)
|
||||
if a.universalID && et.Name != TypeTable {
|
||||
r, err := a.pkRange(et)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.atIncrementT(t2, r)
|
||||
a.sqlDialect.atIncrementT(at, r)
|
||||
}
|
||||
if err := m.aColumns(b, t1, t2); err != nil {
|
||||
if err := a.aColumns(et, at); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.aIndexes(b, t1, t2); err != nil {
|
||||
if err := a.aIndexes(et, at); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tables2[i] = t2
|
||||
ts[i] = at
|
||||
}
|
||||
for i, t1 := range tables1 {
|
||||
t2 := tables2[i]
|
||||
for i, t1 := range tables {
|
||||
t2 := ts[i]
|
||||
for _, fk1 := range t1.ForeignKeys {
|
||||
fk2 := schema.NewForeignKey(fk1.Symbol).
|
||||
SetTable(t2).
|
||||
@@ -488,7 +764,7 @@ func (m *Migrate) aTables(ctx context.Context, b atBuilder, conn dialect.ExecQue
|
||||
fk2.AddColumns(c2)
|
||||
}
|
||||
var refT *schema.Table
|
||||
for _, t2 := range tables2 {
|
||||
for _, t2 := range ts {
|
||||
if t2.Name == fk1.RefTable.Name {
|
||||
refT = t2
|
||||
break
|
||||
@@ -508,17 +784,17 @@ func (m *Migrate) aTables(ctx context.Context, b atBuilder, conn dialect.ExecQue
|
||||
t2.AddForeignKeys(fk2)
|
||||
}
|
||||
}
|
||||
return tables2, nil
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
func (m *Migrate) aColumns(b atBuilder, t1 *Table, t2 *schema.Table) error {
|
||||
for _, c1 := range t1.Columns {
|
||||
func (a *Atlas) aColumns(et *Table, at *schema.Table) error {
|
||||
for _, c1 := range et.Columns {
|
||||
c2 := schema.NewColumn(c1.Name).
|
||||
SetNull(c1.Nullable)
|
||||
if c1.Collation != "" {
|
||||
c2.SetCollation(c1.Collation)
|
||||
}
|
||||
if err := b.atTypeC(c1, c2); err != nil {
|
||||
if err := a.sqlDialect.atTypeC(c1, c2); err != nil {
|
||||
return err
|
||||
}
|
||||
if c1.Default != nil && c1.supportDefault() {
|
||||
@@ -530,106 +806,128 @@ func (m *Migrate) aColumns(b atBuilder, t1 *Table, t2 *schema.Table) error {
|
||||
}
|
||||
c2.SetDefault(&schema.RawExpr{X: x})
|
||||
}
|
||||
if c1.Unique && (len(t1.PrimaryKey) != 1 || t1.PrimaryKey[0] != c1) {
|
||||
b.atUniqueC(t1, c1, t2, c2)
|
||||
if c1.Unique && (len(et.PrimaryKey) != 1 || et.PrimaryKey[0] != c1) {
|
||||
a.sqlDialect.atUniqueC(et, c1, at, c2)
|
||||
}
|
||||
if c1.Increment {
|
||||
b.atIncrementC(t2, c2)
|
||||
a.sqlDialect.atIncrementC(at, c2)
|
||||
}
|
||||
t2.AddColumns(c2)
|
||||
at.AddColumns(c2)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrate) aIndexes(b atBuilder, t1 *Table, t2 *schema.Table) error {
|
||||
func (a *Atlas) aIndexes(et *Table, at *schema.Table) error {
|
||||
// Primary-key index.
|
||||
pk := make([]*schema.Column, 0, len(t1.PrimaryKey))
|
||||
for _, c1 := range t1.PrimaryKey {
|
||||
c2, ok := t2.Column(c1.Name)
|
||||
pk := make([]*schema.Column, 0, len(et.PrimaryKey))
|
||||
for _, c1 := range et.PrimaryKey {
|
||||
c2, ok := at.Column(c1.Name)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected primary-key column: %q", c1.Name)
|
||||
}
|
||||
pk = append(pk, c2)
|
||||
}
|
||||
t2.SetPrimaryKey(schema.NewPrimaryKey(pk...))
|
||||
at.SetPrimaryKey(schema.NewPrimaryKey(pk...))
|
||||
// Rest of indexes.
|
||||
for _, idx1 := range t1.Indexes {
|
||||
for _, idx1 := range et.Indexes {
|
||||
idx2 := schema.NewIndex(idx1.Name).
|
||||
SetUnique(idx1.Unique)
|
||||
if err := b.atIndex(idx1, t2, idx2); err != nil {
|
||||
if err := a.sqlDialect.atIndex(idx1, at, idx2); err != nil {
|
||||
return err
|
||||
}
|
||||
desc := descIndexes(idx1)
|
||||
for _, p := range idx2.Parts {
|
||||
p.Desc = desc[p.C.Name]
|
||||
}
|
||||
t2.AddIndexes(idx2)
|
||||
at.AddIndexes(idx2)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrate) ensureTypeTable(next Differ) Differ {
|
||||
return DiffFunc(func(current, desired *schema.Schema) ([]schema.Change, error) {
|
||||
// If there is a types table but no types file yet, the user most likely
|
||||
// switched from online migration to migration files.
|
||||
if len(m.dbTypeRanges) == 0 {
|
||||
var (
|
||||
at = schema.NewTable(TypeTable)
|
||||
et = NewTable(TypeTable).
|
||||
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
|
||||
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true})
|
||||
)
|
||||
m.atTable(et, at)
|
||||
if err := m.aColumns(m, et, at); err != nil {
|
||||
return nil, err
|
||||
// setupTables ensures the table is configured properly, like table columns
|
||||
// are linked to their indexes, and PKs columns are defined.
|
||||
func (a *Atlas) setupTables(tables []*Table) {
|
||||
for _, t := range tables {
|
||||
if t.columns == nil {
|
||||
t.columns = make(map[string]*Column, len(t.Columns))
|
||||
}
|
||||
for _, c := range t.Columns {
|
||||
t.columns[c.Name] = c
|
||||
}
|
||||
for _, idx := range t.Indexes {
|
||||
idx.Name = a.symbol(idx.Name)
|
||||
for _, c := range idx.Columns {
|
||||
c.indexes.append(idx)
|
||||
}
|
||||
if err := m.aIndexes(m, et, at); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, pk := range t.PrimaryKey {
|
||||
c := t.columns[pk.Name]
|
||||
c.Key = PrimaryKey
|
||||
pk.Key = PrimaryKey
|
||||
}
|
||||
for _, fk := range t.ForeignKeys {
|
||||
fk.Symbol = a.symbol(fk.Symbol)
|
||||
for i := range fk.Columns {
|
||||
fk.Columns[i].foreign = fk
|
||||
}
|
||||
desired.Tables = append(desired.Tables, at)
|
||||
}
|
||||
// If there is a drift between the types stored in the database and the ones stored in the file,
|
||||
// stop diffing, as this is potentially destructive. This will most likely happen on the first diffing
|
||||
// after moving from online-migration to versioned migrations if the "old" ent types are not in sync with
|
||||
// the deterministic ones computed by the new engine.
|
||||
if len(m.dbTypeRanges) > 0 && len(m.fileTypeRanges) > 0 && !equal(m.fileTypeRanges, m.dbTypeRanges) {
|
||||
return nil, fmt.Errorf(
|
||||
"type allocation range drift detected: %v <> %v: see %s for more information",
|
||||
m.dbTypeRanges, m.fileTypeRanges,
|
||||
"https://entgo.io/docs/versioned-migrations#moving-from-auto-migration-to-versioned-migrations",
|
||||
)
|
||||
}
|
||||
changes, err := next.Diff(current, desired)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(m.dbTypeRanges) > 0 && len(m.fileTypeRanges) == 0 {
|
||||
// Override the types file created in the diff process with the "old" allocated types ranges.
|
||||
if err := m.typeStore.(*dirTypeStore).save(m.dbTypeRanges); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Change the type range allocations since they will be added to the migration files when
|
||||
// writing the migration plan to migration files.
|
||||
m.typeRanges = m.dbTypeRanges
|
||||
}
|
||||
return changes, nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setAtChecks(t1 *Table, t2 *schema.Table) {
|
||||
if check := t1.Annotation.Check; check != "" {
|
||||
t2.AddChecks(&schema.Check{
|
||||
// symbol makes sure the symbol length is not longer than the maxlength in the dialect.
|
||||
func (a *Atlas) symbol(name string) string {
|
||||
size := 64
|
||||
if a.dialect == dialect.Postgres {
|
||||
size = 63
|
||||
}
|
||||
if len(name) <= size {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("%s_%x", name[:size-33], md5.Sum([]byte(name)))
|
||||
}
|
||||
|
||||
// entDialect returns the Ent dialect as configured by the dialect option.
|
||||
func (a *Atlas) entDialect(drv dialect.Driver) (sqlDialect, error) {
|
||||
switch a.dialect {
|
||||
case dialect.MySQL:
|
||||
return &MySQL{Driver: drv}, nil
|
||||
case dialect.SQLite:
|
||||
return &SQLite{Driver: drv, WithForeignKeys: a.withForeignKeys}, nil
|
||||
case dialect.Postgres:
|
||||
return &Postgres{Driver: drv}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Atlas) pkRange(et *Table) (int64, error) {
|
||||
idx := indexOf(a.types, et.Name)
|
||||
// If the table re-created, re-use its range from
|
||||
// the past. Otherwise, allocate a new id-range.
|
||||
if idx == -1 {
|
||||
if len(a.types) > MaxTypes {
|
||||
return 0, fmt.Errorf("max number of types exceeded: %d", MaxTypes)
|
||||
}
|
||||
idx = len(a.types)
|
||||
a.types = append(a.types, et.Name)
|
||||
}
|
||||
return int64(idx << 32), nil
|
||||
}
|
||||
|
||||
func setAtChecks(et *Table, at *schema.Table) {
|
||||
if check := et.Annotation.Check; check != "" {
|
||||
at.AddChecks(&schema.Check{
|
||||
Expr: check,
|
||||
})
|
||||
}
|
||||
if checks := t1.Annotation.Checks; len(t1.Annotation.Checks) > 0 {
|
||||
if checks := et.Annotation.Checks; len(et.Annotation.Checks) > 0 {
|
||||
names := make([]string, 0, len(checks))
|
||||
for name := range checks {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
t2.AddChecks(&schema.Check{
|
||||
at.AddChecks(&schema.Check{
|
||||
Name: name,
|
||||
Expr: checks[name],
|
||||
})
|
||||
@@ -654,61 +952,49 @@ func descIndexes(idx *Index) map[string]bool {
|
||||
return descs
|
||||
}
|
||||
|
||||
const entTypes = ".ent_types"
|
||||
|
||||
// dirTypeStore stores and read pk information from a text file stored alongside generated versioned migrations.
|
||||
// This behaviour is enabled automatically when using versioned migrations.
|
||||
type dirTypeStore struct {
|
||||
dir migrate.Dir
|
||||
// driver decorates the atlas migrate.Driver and adds "diff hooking" and functionality.
|
||||
type diffDriver struct {
|
||||
migrate.Driver
|
||||
hooks []DiffHook // hooks to apply
|
||||
}
|
||||
|
||||
const atlasDirective = "atlas:sum ignore\n"
|
||||
|
||||
// load the types from the types file.
|
||||
func (s *dirTypeStore) load(context.Context, dialect.ExecQuerier) ([]string, error) {
|
||||
f, err := s.dir.Open(entTypes)
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, fmt.Errorf("reading types file: %w", err)
|
||||
}
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, nil
|
||||
}
|
||||
defer f.Close()
|
||||
c, err := ioutil.ReadAll(f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading types file: %w", err)
|
||||
}
|
||||
return strings.Split(strings.TrimPrefix(string(c), atlasDirective), ","), nil
|
||||
// RealmDiff creates the diff between two realms. Since Ent does not care about Realms,
|
||||
// not even schema changes, calling this method raises an error.
|
||||
func (r *diffDriver) RealmDiff(_, _ *schema.Realm) ([]schema.Change, error) {
|
||||
return nil, errors.New("sqlDialect does not support working with realms")
|
||||
}
|
||||
|
||||
// add a new type entry to the types file.
|
||||
func (s *dirTypeStore) add(ctx context.Context, conn dialect.ExecQuerier, t string) error {
|
||||
ts, err := s.load(ctx, conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding type %q: %w", t, err)
|
||||
// SchemaDiff creates the diff between two schemas, but includes "diff hooks".
|
||||
func (r *diffDriver) SchemaDiff(from, to *schema.Schema) ([]schema.Change, error) {
|
||||
var d Differ = DiffFunc(r.Driver.SchemaDiff)
|
||||
for i := len(r.hooks) - 1; i >= 0; i-- {
|
||||
d = r.hooks[i](d)
|
||||
}
|
||||
return s.save(append(ts, t))
|
||||
return d.Diff(from, to)
|
||||
}
|
||||
|
||||
// save takes the given allocation range and writes them to the types file.
|
||||
// The types file will be overridden.
|
||||
func (s *dirTypeStore) save(ts []string) error {
|
||||
if err := s.dir.WriteFile(entTypes, []byte(atlasDirective+strings.Join(ts, ","))); err != nil {
|
||||
return fmt.Errorf("writing types file: %w", err)
|
||||
// legacyMigrate returns a configured legacy migration engine (before Atlas) to keep backwards compatibility.
|
||||
//
|
||||
// Deprecated: Will be removed alongside legacy migration support.
|
||||
func (a *Atlas) legacyMigrate() (*Migrate, error) {
|
||||
m := &Migrate{
|
||||
universalID: a.universalID,
|
||||
dropColumns: a.dropColumns,
|
||||
dropIndexes: a.dropIndexes,
|
||||
withFixture: a.withFixture,
|
||||
withForeignKeys: a.withForeignKeys,
|
||||
hooks: a.hooks,
|
||||
atlas: a,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ typeStore = (*dirTypeStore)(nil)
|
||||
|
||||
func equal(s1, s2 []string) bool {
|
||||
if len(s1) != len(s2) {
|
||||
return false
|
||||
}
|
||||
for i := range s1 {
|
||||
if s1[i] != s2[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
switch a.dialect {
|
||||
case dialect.MySQL:
|
||||
m.sqlDialect = &MySQL{Driver: a.driver}
|
||||
case dialect.SQLite:
|
||||
m.sqlDialect = &SQLite{Driver: a.driver, WithForeignKeys: a.withForeignKeys}
|
||||
case dialect.Postgres:
|
||||
m.sqlDialect = &Postgres{Driver: a.driver}
|
||||
default:
|
||||
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user