dialect/sql/schema: atlas engine is now default (#2698)

* atlas engine is default, enabled diff by replay

* Apply suggestions from code review

* docs

* apply CR
This commit is contained in:
Jannik Clausen
2022-07-05 12:29:15 +02:00
committed by GitHub
parent 91b643091f
commit 5b67bdab4f
19 changed files with 828 additions and 1475 deletions

View File

@@ -6,23 +6,227 @@ package schema
import (
"context"
"crypto/md5"
"database/sql"
"errors"
"fmt"
"io/fs"
"io/ioutil"
"net/url"
"sort"
"strings"
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/schema"
"ariga.io/atlas/sql/sqlclient"
"ariga.io/atlas/sql/sqltool"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
)
// Atlas atlas migration engine.
type Atlas struct {
atDriver migrate.Driver
sqlDialect sqlDialect
legacy bool // if the legacy migration engine instead of Atlas should be used
withFixture bool // deprecated: with fks rename fixture
sum bool // deprecated: sum file generation will be required
universalID bool // global unique ids
dropColumns bool // drop deleted columns
dropIndexes bool // drop deleted indexes
withForeignKeys bool // with foreign keys
mode Mode
hooks []Hook // hooks to apply before creation
diffHooks []DiffHook // diff hooks to run when diffing current and desired
applyHook []ApplyHook // apply hooks to run when applying the plan
skip ChangeKind // what changes to skip and not apply
dir migrate.Dir // the migration directory to read from
fmt migrate.Formatter // how to format the plan into migration files
driver dialect.Driver // driver passed in when not using an atlas URL
url *url.URL // url of database connection
dialect string // Ent dialect to use when generating migration files
types []string // pre-existing pk range allocation for global unique id
}
// Diff compares the state read from a database connection or migration directory with the state defined by the Ent
// schema. Changes will be written to new migration files.
func Diff(ctx context.Context, u, name string, tables []*Table, opts ...MigrateOption) (err error) {
m, err := NewMigrateURL(u, opts...)
if err != nil {
return err
}
return m.NamedDiff(ctx, name, tables...)
}
// NewMigrate creates a new Atlas form the given dialect.Driver.
func NewMigrate(drv dialect.Driver, opts ...MigrateOption) (*Atlas, error) {
a := &Atlas{driver: drv, withForeignKeys: true, mode: ModeInspect, sum: true}
for _, opt := range opts {
opt(a)
}
a.dialect = a.driver.Dialect()
if err := a.init(); err != nil {
return nil, err
}
return a, nil
}
// NewMigrateURL create a new Atlas from the given url.
func NewMigrateURL(u string, opts ...MigrateOption) (*Atlas, error) {
parsed, err := url.Parse(u)
if err != nil {
return nil, err
}
a := &Atlas{url: parsed, withForeignKeys: true, mode: ModeInspect, sum: true}
for _, opt := range opts {
opt(a)
}
if a.dialect == "" {
a.dialect = parsed.Scheme
}
if err := a.init(); err != nil {
return nil, err
}
return a, nil
}
// Create creates all schema resources in the database. It works in an "append-only"
// mode, which means, it only creates tables, appends columns to tables or modifies column types.
//
// Column can be modified by turning into a NULL from NOT NULL, or having a type conversion not
// resulting data altering. From example, changing varchar(255) to varchar(120) is invalid, but
// changing varchar(120) to varchar(255) is valid. For more info, see the convert function below.
func (a *Atlas) Create(ctx context.Context, tables ...*Table) (err error) {
a.setupTables(tables)
var creator Creator = CreateFunc(a.create)
if a.legacy {
m, err := a.legacyMigrate()
if err != nil {
return err
}
creator = CreateFunc(m.create)
}
for i := len(a.hooks) - 1; i >= 0; i-- {
creator = a.hooks[i](creator)
}
return creator.Create(ctx, tables...)
}
// Diff compares the state read from the connected database with the state defined by Ent.
// Changes will be written to migration files by the configured Planner.
func (a *Atlas) Diff(ctx context.Context, tables ...*Table) error {
return a.NamedDiff(ctx, "changes", tables...)
}
// NamedDiff compares the state read from the connected database with the state defined by Ent.
// Changes will be written to migration files by the configured Planner.
func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) error {
if a.dir == nil {
return errors.New("no migration directory given")
}
opts := []migrate.PlannerOption{migrate.WithFormatter(a.fmt)}
if a.sum {
// Validate the migration directory before proceeding.
if err := migrate.Validate(a.dir); err != nil {
return fmt.Errorf("validating migration directory: %w", err)
}
} else {
opts = append(opts, migrate.DisableChecksum())
}
a.setupTables(tables)
// Set up connections.
if a.driver != nil {
var err error
a.sqlDialect, err = a.entDialect(a.driver)
if err != nil {
return err
}
a.atDriver, err = a.sqlDialect.atOpen(a.sqlDialect)
if err != nil {
return err
}
} else {
c, err := sqlclient.OpenURL(ctx, a.url)
if err != nil {
return err
}
defer c.Close()
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
if err != nil {
return err
}
a.atDriver = c.Driver
}
defer func() {
a.sqlDialect = nil
a.atDriver = nil
}()
if err := a.sqlDialect.init(ctx, a.sqlDialect); err != nil {
return err
}
if a.universalID {
tables = append(tables, NewTable(TypeTable).
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}),
)
}
switch a.mode {
case ModeInspect:
// Do nothing here, simply inspect later on.
case ModeReplay:
// We consider a database clean if there are no tables in the connected schema.
s, err := a.atDriver.InspectSchema(ctx, "", nil)
if err != nil {
return err
}
if len(s.Tables) > 0 {
return migrate.ErrNotClean
}
// Clean up once done.
defer func() {
// We clean a database by dropping all tables inside the connected schema.
s, err = a.atDriver.InspectSchema(ctx, "", nil)
if err != nil {
return
}
tbls := make([]schema.Change, len(s.Tables))
for i, t := range s.Tables {
tbls[i] = &schema.DropTable{T: t}
}
if err2 := a.atDriver.ApplyChanges(ctx, tbls); err2 != nil {
if err != nil {
err = fmt.Errorf("%v: %w", err2, err)
return
}
err = err2
return
}
}()
// Replay the migration directory on the database.
ex, err := migrate.NewExecutor(a.atDriver, a.dir, &migrate.NopRevisionReadWriter{})
if err != nil {
return err
}
if err := ex.ExecuteN(ctx, 0); err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) {
return err
}
default:
return fmt.Errorf("unknown migration mode: %q", a.mode)
}
plan, err := a.plan(ctx, a.sqlDialect, name, tables)
if err != nil {
return err
}
// Skip if the plan has no changes.
if len(plan.Changes) == 0 {
return nil
}
return migrate.NewPlanner(nil, a.dir, opts...).WritePlan(plan)
}
type (
// Differ is the interface that wraps the Diff method.
Differ interface {
@@ -59,8 +263,8 @@ func (f DiffFunc) Diff(current, desired *schema.Schema) ([]schema.Change, error)
// })
//
func WithDiffHook(hooks ...DiffHook) MigrateOption {
return func(m *Migrate) {
m.atlas.diff = append(m.atlas.diff, hooks...)
return func(a *Atlas) {
a.diffHooks = append(a.diffHooks, hooks...)
}
}
@@ -70,8 +274,8 @@ func WithDiffHook(hooks ...DiffHook) MigrateOption {
// SkipChanges(schema.DropTable|schema.DropColumn)
//
func WithSkipChanges(skip ChangeKind) MigrateOption {
return func(m *Migrate) {
m.atlas.skip = skip
return func(a *Atlas) {
a.skip = skip
}
}
@@ -237,146 +441,180 @@ func (f ApplyFunc) Apply(ctx context.Context, conn dialect.ExecQuerier, plan *mi
// })
//
func WithApplyHook(hooks ...ApplyHook) MigrateOption {
return func(m *Migrate) {
m.atlas.apply = append(m.atlas.apply, hooks...)
return func(a *Atlas) {
a.applyHook = append(a.applyHook, hooks...)
}
}
// WithAtlas is an opt-in option for v0.10 indicates the migration
// should be executed using Atlas engine (i.e. https://atlasgo.io).
// Note, in future versions, this option is going to be replaced
// from opt-in to opt-out and the deprecation of this package.
// WithAtlas is an opt-out option for v0.11 indicating the migration
// should be executed using the deprecated legacy engine.
// Note, in future versions, this option is going to be removed
// and the Atlas (https://atlasgo.io) based migration engine should be used.
//
// Deprecated: The legacy engine will be removed.
func WithAtlas(b bool) MigrateOption {
return func(m *Migrate) {
m.atlas.enabled = b
return func(a *Atlas) {
a.legacy = !b
}
}
// WithDir sets the atlas migration directory to use to store migration files.
func WithDir(dir migrate.Dir) MigrateOption {
return func(m *Migrate) {
m.atlas.dir = dir
return func(a *Atlas) {
a.dir = dir
}
}
// WithFormatter sets atlas formatter to use to write changes to migration files.
func WithFormatter(fmt migrate.Formatter) MigrateOption {
return func(m *Migrate) {
m.atlas.fmt = fmt
return func(a *Atlas) {
a.fmt = fmt
}
}
// WithSumFile instructs atlas to generate a migration directory integrity sum file as well.
func WithSumFile() MigrateOption {
return func(m *Migrate) {
m.atlas.genSum = true
// WithDialect configures the Ent dialect to use when migrating for an Atlas supported dialect flavor.
// As an example, Ent can work with TiDB in MySQL dialect and Atlas can handle TiDB migrations.
func WithDialect(d string) MigrateOption {
return func(a *Atlas) {
a.dialect = d
}
}
// WithUniversalID instructs atlas to use a file based type store when
// global unique ids are enabled. For more information see the setupAtlas method on Migrate.
// WithSumFile instructs atlas to generate a migration directory integrity sum file.
//
// ATTENTION:
// The file based PK range store is not backward compatible, since the allocated ranges were computed
// dynamically when computing the diff between a deployed database and the current schema. In cases where there
// exist multiple deployments, the allocated ranges for the same type might be different from each other,
// depending on when the deployment took part.
func WithUniversalID() MigrateOption {
return func(m *Migrate) {
m.universalID = true
m.atlas.typeStoreConsent = true
// Deprecated: generating the sum file is now opt-out. This method will be removed in future versions.
func WithSumFile() MigrateOption {
return func(a *Atlas) {}
}
// DisableChecksum instructs atlas to skip migration directory integrity sum file generation.
//
// Deprecated: generating the sum file will no longer be optional in future versions.
func DisableChecksum() MigrateOption {
return func(a *Atlas) {
a.sum = false
}
}
type (
// atlasOptions describes the options for atlas.
atlasOptions struct {
enabled bool
diff []DiffHook
apply []ApplyHook
skip ChangeKind
dir migrate.Dir
fmt migrate.Formatter
genSum bool
typeStoreConsent bool
// WithMigrationMode instructs atlas how to compute the current state of the schema. This can be done by either
// replaying (ModeReplay) the migration directory on the connected database, or by inspecting (ModeInspect) the
// connection. Currently, ModeReplay is opt-in, and ModeInspect is the default. In future versions, ModeReplay will
// become the default behavior. This option has no effect when using online migrations.
func WithMigrationMode(mode Mode) MigrateOption {
return func(a *Atlas) {
a.mode = mode
}
}
// atBuilder must be implemented by the different drivers in
// order to convert a dialect/sql/schema to atlas/sql/schema.
atBuilder interface {
atOpen(dialect.ExecQuerier) (migrate.Driver, error)
atTable(*Table, *schema.Table)
atTypeC(*Column, *schema.Column) error
atUniqueC(*Table, *Column, *schema.Table, *schema.Column)
atIncrementC(*schema.Table, *schema.Column)
atIncrementT(*schema.Table, int64)
atIndex(*Index, *schema.Table, *schema.Index) error
atTypeRangeSQL(t ...string) string
}
// Mode to compute the current state.
type Mode uint
const (
// ModeReplay computes the current state by replaying the migration directory on the connected database.
ModeReplay = iota
// ModeInspect computes the current state by inspecting the connected database.
ModeInspect
)
var errConsent = errors.New("sql/schema: use WithUniversalID() instead of WithGlobalUniqueID(true) when using WithDir(): https://entgo.io/docs/migrate#universal-ids")
// StateReader returns an atlas migrate.StateReader returning the state as described by the Ent table slice.
func (a *Atlas) StateReader(tables ...*Table) migrate.StateReaderFunc {
return func(context.Context) (*schema.Realm, error) {
ts, err := a.tables(tables)
if err != nil {
return nil, err
}
return &schema.Realm{Schemas: []*schema.Schema{{Tables: ts}}}, nil
}
}
func (m *Migrate) setupAtlas() error {
// Using one of the Atlas options, opt-in to Atlas migration.
if !m.atlas.enabled && (m.atlas.skip != NoChange || len(m.atlas.diff) > 0 || len(m.atlas.apply) > 0) || m.atlas.dir != nil {
m.atlas.enabled = true
}
if !m.atlas.enabled {
return nil
}
if m.withFixture {
return errors.New("sql/schema: WithFixture(true) does not work in Atlas migration")
}
// atBuilder must be implemented by the different drivers in
// order to convert a dialect/sql/schema to atlas/sql/schema.
type atBuilder interface {
atOpen(dialect.ExecQuerier) (migrate.Driver, error)
atTable(*Table, *schema.Table)
atTypeC(*Column, *schema.Column) error
atUniqueC(*Table, *Column, *schema.Table, *schema.Column)
atIncrementC(*schema.Table, *schema.Column)
atIncrementT(*schema.Table, int64)
atIndex(*Index, *schema.Table, *schema.Index) error
atTypeRangeSQL(t ...string) string
}
// init initializes the configuration object based on the options passed in.
func (a *Atlas) init() error {
skip := DropIndex | DropColumn
if m.atlas.skip != NoChange {
skip = m.atlas.skip
if a.skip != NoChange {
skip = a.skip
}
if m.dropIndexes {
if a.dropIndexes {
skip &= ^DropIndex
}
if m.dropColumns {
if a.dropColumns {
skip &= ^DropColumn
}
if skip != NoChange {
m.atlas.diff = append(m.atlas.diff, filterChanges(skip))
a.diffHooks = append(a.diffHooks, filterChanges(skip))
}
if !m.withForeignKeys {
m.atlas.diff = append(m.atlas.diff, withoutForeignKeys)
if !a.withForeignKeys {
a.diffHooks = append(a.diffHooks, withoutForeignKeys)
}
if m.atlas.dir != nil && m.atlas.fmt == nil {
m.atlas.fmt = sqltool.GolangMigrateFormatter
if a.dir != nil && a.fmt == nil {
a.fmt = sqltool.GolangMigrateFormatter
}
if m.universalID && m.atlas.dir != nil {
// If global unique ids and a migration directory is given, enable the file based type store for pk ranges.
m.typeStore = &dirTypeStore{dir: m.atlas.dir}
// To guard the user against a possible bug due to backward incompatibility, the file based type store must
// be enabled by an option. For more information see the comment of WithUniversalID function.
if !m.atlas.typeStoreConsent {
return errConsent
if a.mode == ModeReplay {
// ModeReplay requires a migration directory.
if a.dir == nil {
return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires versioned migrations: WithDir()")
}
// ModeReplay requires sum file generation.
if !a.sum {
return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires migration directory integrity file")
}
m.atlas.diff = append(m.atlas.diff, m.ensureTypeTable)
}
return nil
}
func (m *Migrate) atCreate(ctx context.Context, tables ...*Table) error {
// create is the Atlas engine based online migration.
func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
if a.universalID {
tables = append(tables, NewTable(TypeTable).
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}),
)
}
if a.driver != nil {
a.sqlDialect, err = a.entDialect(a.driver)
if err != nil {
return err
}
} else {
c, err := sqlclient.OpenURL(ctx, a.url)
if err != nil {
return err
}
defer c.Close()
a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB))
if err != nil {
return err
}
}
defer func() { a.sqlDialect = nil }()
// Open a transaction for backwards compatibility,
// even if the migration is not transactional.
tx, err := m.Tx(ctx)
tx, err := a.sqlDialect.Tx(ctx)
if err != nil {
return err
}
if err := a.sqlDialect.init(ctx, tx); err != nil {
return err
}
a.atDriver, err = a.sqlDialect.atOpen(tx)
if err != nil {
return err
}
defer func() { a.atDriver = nil }()
if err := func() error {
if err := m.init(ctx, tx); err != nil {
return err
}
if m.universalID {
if err := m.types(ctx, tx); err != nil {
return err
}
}
plan, err := m.atDiff(ctx, tx, "", tables...)
plan, err := a.plan(ctx, tx, "changes", tables)
if err != nil {
return err
}
@@ -392,22 +630,25 @@ func (m *Migrate) atCreate(ctx context.Context, tables ...*Table) error {
}
return nil
})
for i := len(m.atlas.apply) - 1; i >= 0; i-- {
applier = m.atlas.apply[i](applier)
for i := len(a.applyHook) - 1; i >= 0; i-- {
applier = a.applyHook[i](applier)
}
return applier.Apply(ctx, tx, plan)
}(); err != nil {
return rollback(tx, err)
err = fmt.Errorf("sql/schema: %w", err)
if rerr := tx.Rollback(); rerr != nil {
err = fmt.Errorf("%w: %v", err, rerr)
}
return err
}
return tx.Commit()
}
func (m *Migrate) atDiff(ctx context.Context, conn dialect.ExecQuerier, name string, tables ...*Table) (*migrate.Plan, error) {
drv, err := m.atOpen(conn)
if err != nil {
return nil, err
}
current, err := drv.InspectSchema(ctx, "", &schema.InspectOptions{
// plan creates the current state by inspecting the connected database, computing the current state of the Ent schema
// and proceeds to diff the changes to create a migration plan.
// before diffing.
func (a *Atlas) plan(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) {
current, err := a.atDriver.InspectSchema(ctx, "", &schema.InspectOptions{
Tables: func() (t []string) {
for i := range tables {
t = append(t, tables[i].Name)
@@ -418,21 +659,55 @@ func (m *Migrate) atDiff(ctx context.Context, conn dialect.ExecQuerier, name str
if err != nil {
return nil, err
}
tt, err := m.aTables(ctx, m, conn, tables)
var types []string
if a.universalID {
// Fetch pre-existing type allocations.
exists, err := a.sqlDialect.tableExist(ctx, conn, TypeTable)
if err != nil {
return nil, err
}
if exists {
rows := &entsql.Rows{}
query, args := entsql.Dialect(a.dialect).
Select("type").From(entsql.Table(TypeTable)).OrderBy(entsql.Asc("id")).Query()
if err := conn.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("query types table: %w", err)
}
defer rows.Close()
a.types = nil
if err := entsql.ScanSlice(rows, &a.types); err != nil {
return nil, err
}
}
types = a.types
}
desired, err := a.StateReader(tables...).ReadState(ctx)
if err != nil {
return nil, err
}
// Diff changes.
var differ Differ = DiffFunc(drv.SchemaDiff)
for i := len(m.atlas.diff) - 1; i >= 0; i-- {
differ = m.atlas.diff[i](differ)
}
changes, err := differ.Diff(current, &schema.Schema{Name: current.Name, Attrs: current.Attrs, Tables: tt})
changes, err := (&diffDriver{a.atDriver, a.diffHooks}).SchemaDiff(current, &schema.Schema{
Name: current.Name,
Attrs: current.Attrs,
Tables: desired.Schemas[0].Tables,
})
if err != nil {
return nil, err
}
// Plan changes.
return drv.PlanChanges(ctx, name, changes)
plan, err := a.atDriver.PlanChanges(ctx, name, changes)
if err != nil {
return nil, err
}
// Insert new types.
newTypes := a.types[len(types):]
if len(newTypes) > 0 {
plan.Changes = append(plan.Changes, &migrate.Change{
Cmd: a.sqlDialect.atTypeRangeSQL(newTypes...),
Comment: fmt.Sprintf("add pk ranges for %s tables", strings.Join(newTypes, ",")),
})
}
return plan, nil
}
type db struct{ dialect.ExecQuerier }
@@ -453,28 +728,29 @@ func (d *db) ExecContext(ctx context.Context, query string, args ...interface{})
return r, nil
}
func (m *Migrate) aTables(ctx context.Context, b atBuilder, conn dialect.ExecQuerier, tables1 []*Table) ([]*schema.Table, error) {
tables2 := make([]*schema.Table, len(tables1))
for i, t1 := range tables1 {
t2 := schema.NewTable(t1.Name)
b.atTable(t1, t2)
if m.universalID {
r, err := m.pkRange(ctx, conn, t1)
// tables converts an Ent table slice to an atlas table slice
func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) {
ts := make([]*schema.Table, len(tables))
for i, et := range tables {
at := schema.NewTable(et.Name)
a.sqlDialect.atTable(et, at)
if a.universalID && et.Name != TypeTable {
r, err := a.pkRange(et)
if err != nil {
return nil, err
}
b.atIncrementT(t2, r)
a.sqlDialect.atIncrementT(at, r)
}
if err := m.aColumns(b, t1, t2); err != nil {
if err := a.aColumns(et, at); err != nil {
return nil, err
}
if err := m.aIndexes(b, t1, t2); err != nil {
if err := a.aIndexes(et, at); err != nil {
return nil, err
}
tables2[i] = t2
ts[i] = at
}
for i, t1 := range tables1 {
t2 := tables2[i]
for i, t1 := range tables {
t2 := ts[i]
for _, fk1 := range t1.ForeignKeys {
fk2 := schema.NewForeignKey(fk1.Symbol).
SetTable(t2).
@@ -488,7 +764,7 @@ func (m *Migrate) aTables(ctx context.Context, b atBuilder, conn dialect.ExecQue
fk2.AddColumns(c2)
}
var refT *schema.Table
for _, t2 := range tables2 {
for _, t2 := range ts {
if t2.Name == fk1.RefTable.Name {
refT = t2
break
@@ -508,17 +784,17 @@ func (m *Migrate) aTables(ctx context.Context, b atBuilder, conn dialect.ExecQue
t2.AddForeignKeys(fk2)
}
}
return tables2, nil
return ts, nil
}
func (m *Migrate) aColumns(b atBuilder, t1 *Table, t2 *schema.Table) error {
for _, c1 := range t1.Columns {
func (a *Atlas) aColumns(et *Table, at *schema.Table) error {
for _, c1 := range et.Columns {
c2 := schema.NewColumn(c1.Name).
SetNull(c1.Nullable)
if c1.Collation != "" {
c2.SetCollation(c1.Collation)
}
if err := b.atTypeC(c1, c2); err != nil {
if err := a.sqlDialect.atTypeC(c1, c2); err != nil {
return err
}
if c1.Default != nil && c1.supportDefault() {
@@ -530,106 +806,128 @@ func (m *Migrate) aColumns(b atBuilder, t1 *Table, t2 *schema.Table) error {
}
c2.SetDefault(&schema.RawExpr{X: x})
}
if c1.Unique && (len(t1.PrimaryKey) != 1 || t1.PrimaryKey[0] != c1) {
b.atUniqueC(t1, c1, t2, c2)
if c1.Unique && (len(et.PrimaryKey) != 1 || et.PrimaryKey[0] != c1) {
a.sqlDialect.atUniqueC(et, c1, at, c2)
}
if c1.Increment {
b.atIncrementC(t2, c2)
a.sqlDialect.atIncrementC(at, c2)
}
t2.AddColumns(c2)
at.AddColumns(c2)
}
return nil
}
func (m *Migrate) aIndexes(b atBuilder, t1 *Table, t2 *schema.Table) error {
func (a *Atlas) aIndexes(et *Table, at *schema.Table) error {
// Primary-key index.
pk := make([]*schema.Column, 0, len(t1.PrimaryKey))
for _, c1 := range t1.PrimaryKey {
c2, ok := t2.Column(c1.Name)
pk := make([]*schema.Column, 0, len(et.PrimaryKey))
for _, c1 := range et.PrimaryKey {
c2, ok := at.Column(c1.Name)
if !ok {
return fmt.Errorf("unexpected primary-key column: %q", c1.Name)
}
pk = append(pk, c2)
}
t2.SetPrimaryKey(schema.NewPrimaryKey(pk...))
at.SetPrimaryKey(schema.NewPrimaryKey(pk...))
// Rest of indexes.
for _, idx1 := range t1.Indexes {
for _, idx1 := range et.Indexes {
idx2 := schema.NewIndex(idx1.Name).
SetUnique(idx1.Unique)
if err := b.atIndex(idx1, t2, idx2); err != nil {
if err := a.sqlDialect.atIndex(idx1, at, idx2); err != nil {
return err
}
desc := descIndexes(idx1)
for _, p := range idx2.Parts {
p.Desc = desc[p.C.Name]
}
t2.AddIndexes(idx2)
at.AddIndexes(idx2)
}
return nil
}
func (m *Migrate) ensureTypeTable(next Differ) Differ {
return DiffFunc(func(current, desired *schema.Schema) ([]schema.Change, error) {
// If there is a types table but no types file yet, the user most likely
// switched from online migration to migration files.
if len(m.dbTypeRanges) == 0 {
var (
at = schema.NewTable(TypeTable)
et = NewTable(TypeTable).
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true})
)
m.atTable(et, at)
if err := m.aColumns(m, et, at); err != nil {
return nil, err
// setupTables ensures the table is configured properly, like table columns
// are linked to their indexes, and PKs columns are defined.
func (a *Atlas) setupTables(tables []*Table) {
for _, t := range tables {
if t.columns == nil {
t.columns = make(map[string]*Column, len(t.Columns))
}
for _, c := range t.Columns {
t.columns[c.Name] = c
}
for _, idx := range t.Indexes {
idx.Name = a.symbol(idx.Name)
for _, c := range idx.Columns {
c.indexes.append(idx)
}
if err := m.aIndexes(m, et, at); err != nil {
return nil, err
}
for _, pk := range t.PrimaryKey {
c := t.columns[pk.Name]
c.Key = PrimaryKey
pk.Key = PrimaryKey
}
for _, fk := range t.ForeignKeys {
fk.Symbol = a.symbol(fk.Symbol)
for i := range fk.Columns {
fk.Columns[i].foreign = fk
}
desired.Tables = append(desired.Tables, at)
}
// If there is a drift between the types stored in the database and the ones stored in the file,
// stop diffing, as this is potentially destructive. This will most likely happen on the first diffing
// after moving from online-migration to versioned migrations if the "old" ent types are not in sync with
// the deterministic ones computed by the new engine.
if len(m.dbTypeRanges) > 0 && len(m.fileTypeRanges) > 0 && !equal(m.fileTypeRanges, m.dbTypeRanges) {
return nil, fmt.Errorf(
"type allocation range drift detected: %v <> %v: see %s for more information",
m.dbTypeRanges, m.fileTypeRanges,
"https://entgo.io/docs/versioned-migrations#moving-from-auto-migration-to-versioned-migrations",
)
}
changes, err := next.Diff(current, desired)
if err != nil {
return nil, err
}
if len(m.dbTypeRanges) > 0 && len(m.fileTypeRanges) == 0 {
// Override the types file created in the diff process with the "old" allocated types ranges.
if err := m.typeStore.(*dirTypeStore).save(m.dbTypeRanges); err != nil {
return nil, err
}
// Change the type range allocations since they will be added to the migration files when
// writing the migration plan to migration files.
m.typeRanges = m.dbTypeRanges
}
return changes, nil
})
}
}
func setAtChecks(t1 *Table, t2 *schema.Table) {
if check := t1.Annotation.Check; check != "" {
t2.AddChecks(&schema.Check{
// symbol makes sure the symbol length is not longer than the maxlength in the dialect.
func (a *Atlas) symbol(name string) string {
size := 64
if a.dialect == dialect.Postgres {
size = 63
}
if len(name) <= size {
return name
}
return fmt.Sprintf("%s_%x", name[:size-33], md5.Sum([]byte(name)))
}
// entDialect returns the Ent dialect as configured by the dialect option.
func (a *Atlas) entDialect(drv dialect.Driver) (sqlDialect, error) {
switch a.dialect {
case dialect.MySQL:
return &MySQL{Driver: drv}, nil
case dialect.SQLite:
return &SQLite{Driver: drv, WithForeignKeys: a.withForeignKeys}, nil
case dialect.Postgres:
return &Postgres{Driver: drv}, nil
default:
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect)
}
}
func (a *Atlas) pkRange(et *Table) (int64, error) {
idx := indexOf(a.types, et.Name)
// If the table re-created, re-use its range from
// the past. Otherwise, allocate a new id-range.
if idx == -1 {
if len(a.types) > MaxTypes {
return 0, fmt.Errorf("max number of types exceeded: %d", MaxTypes)
}
idx = len(a.types)
a.types = append(a.types, et.Name)
}
return int64(idx << 32), nil
}
func setAtChecks(et *Table, at *schema.Table) {
if check := et.Annotation.Check; check != "" {
at.AddChecks(&schema.Check{
Expr: check,
})
}
if checks := t1.Annotation.Checks; len(t1.Annotation.Checks) > 0 {
if checks := et.Annotation.Checks; len(et.Annotation.Checks) > 0 {
names := make([]string, 0, len(checks))
for name := range checks {
names = append(names, name)
}
sort.Strings(names)
for _, name := range names {
t2.AddChecks(&schema.Check{
at.AddChecks(&schema.Check{
Name: name,
Expr: checks[name],
})
@@ -654,61 +952,49 @@ func descIndexes(idx *Index) map[string]bool {
return descs
}
const entTypes = ".ent_types"
// dirTypeStore stores and read pk information from a text file stored alongside generated versioned migrations.
// This behaviour is enabled automatically when using versioned migrations.
type dirTypeStore struct {
dir migrate.Dir
// driver decorates the atlas migrate.Driver and adds "diff hooking" and functionality.
type diffDriver struct {
migrate.Driver
hooks []DiffHook // hooks to apply
}
const atlasDirective = "atlas:sum ignore\n"
// load the types from the types file.
func (s *dirTypeStore) load(context.Context, dialect.ExecQuerier) ([]string, error) {
f, err := s.dir.Open(entTypes)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("reading types file: %w", err)
}
if errors.Is(err, fs.ErrNotExist) {
return nil, nil
}
defer f.Close()
c, err := ioutil.ReadAll(f)
if err != nil {
return nil, fmt.Errorf("reading types file: %w", err)
}
return strings.Split(strings.TrimPrefix(string(c), atlasDirective), ","), nil
// RealmDiff creates the diff between two realms. Since Ent does not care about Realms,
// not even schema changes, calling this method raises an error.
func (r *diffDriver) RealmDiff(_, _ *schema.Realm) ([]schema.Change, error) {
return nil, errors.New("sqlDialect does not support working with realms")
}
// add a new type entry to the types file.
func (s *dirTypeStore) add(ctx context.Context, conn dialect.ExecQuerier, t string) error {
ts, err := s.load(ctx, conn)
if err != nil {
return fmt.Errorf("adding type %q: %w", t, err)
// SchemaDiff creates the diff between two schemas, but includes "diff hooks".
func (r *diffDriver) SchemaDiff(from, to *schema.Schema) ([]schema.Change, error) {
var d Differ = DiffFunc(r.Driver.SchemaDiff)
for i := len(r.hooks) - 1; i >= 0; i-- {
d = r.hooks[i](d)
}
return s.save(append(ts, t))
return d.Diff(from, to)
}
// save takes the given allocation range and writes them to the types file.
// The types file will be overridden.
func (s *dirTypeStore) save(ts []string) error {
if err := s.dir.WriteFile(entTypes, []byte(atlasDirective+strings.Join(ts, ","))); err != nil {
return fmt.Errorf("writing types file: %w", err)
// legacyMigrate returns a configured legacy migration engine (before Atlas) to keep backwards compatibility.
//
// Deprecated: Will be removed alongside legacy migration support.
func (a *Atlas) legacyMigrate() (*Migrate, error) {
m := &Migrate{
universalID: a.universalID,
dropColumns: a.dropColumns,
dropIndexes: a.dropIndexes,
withFixture: a.withFixture,
withForeignKeys: a.withForeignKeys,
hooks: a.hooks,
atlas: a,
}
return nil
}
var _ typeStore = (*dirTypeStore)(nil)
func equal(s1, s2 []string) bool {
if len(s1) != len(s2) {
return false
}
for i := range s1 {
if s1[i] != s2[i] {
return false
}
}
return true
switch a.dialect {
case dialect.MySQL:
m.sqlDialect = &MySQL{Driver: a.driver}
case dialect.SQLite:
m.sqlDialect = &SQLite{Driver: a.driver, WithForeignKeys: a.withForeignKeys}
case dialect.Postgres:
m.sqlDialect = &Postgres{Driver: a.driver}
default:
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect)
}
return m, nil
}

View File

@@ -1,33 +0,0 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package schema
import (
"context"
"os"
"path/filepath"
"testing"
"ariga.io/atlas/sql/migrate"
"github.com/stretchr/testify/require"
)
func TestDirTypeStore(t *testing.T) {
ex := []string{"a", "b", "c"}
p := t.TempDir()
d, err := migrate.NewLocalDir(p)
require.NoError(t, err)
s := &dirTypeStore{d}
require.NoError(t, s.save(ex))
require.FileExists(t, filepath.Join(p, entTypes))
c, err := os.ReadFile(filepath.Join(p, entTypes))
require.NoError(t, err)
require.Contains(t, string(c), atlasDirective)
ac, err := s.load(context.Background(), nil)
require.NoError(t, err)
require.Equal(t, ex, ac)
}

View File

@@ -6,13 +6,9 @@ package schema
import (
"context"
"crypto/md5"
"errors"
"fmt"
"math"
"strings"
"ariga.io/atlas/sql/migrate"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
@@ -21,57 +17,61 @@ import (
const (
// TypeTable defines the table name holding the type information.
TypeTable = "ent_types"
// MaxTypes defines the max number of types can be created when
// defining universal ids. The left 16-bits are reserved.
MaxTypes = math.MaxUint16
)
// MigrateOption allows for managing schema configuration using functional options.
type MigrateOption func(*Migrate)
// MigrateOption allows configuring Atlas using functional arguments.
type MigrateOption func(*Atlas)
// WithGlobalUniqueID sets the universal ids options to the migration.
// Defaults to false.
func WithGlobalUniqueID(b bool) MigrateOption {
return func(m *Migrate) {
m.universalID = b
return func(a *Atlas) {
a.universalID = b
}
}
// WithDropColumn sets the columns dropping option to the migration.
// Defaults to false.
func WithDropColumn(b bool) MigrateOption {
return func(m *Migrate) {
m.dropColumns = b
return func(a *Atlas) {
a.dropColumns = b
}
}
// WithDropIndex sets the indexes dropping option to the migration.
// Defaults to false.
func WithDropIndex(b bool) MigrateOption {
return func(m *Migrate) {
m.dropIndexes = b
return func(a *Atlas) {
a.dropIndexes = b
}
}
// WithFixture sets the foreign-key renaming option to the migration when upgrading
// ent from v0.1.0 (issue-#285). Defaults to false.
// sqlDialect from v0.1.0 (issue-#285). Defaults to false.
//
// Deprecated: This option is no longer needed with the Atlas based
// migration engine, which now is the default.
func WithFixture(b bool) MigrateOption {
return func(m *Migrate) {
m.withFixture = b
return func(a *Atlas) {
a.withFixture = b
}
}
// WithForeignKeys enables creating foreign-key in ddl. Defaults to true.
func WithForeignKeys(b bool) MigrateOption {
return func(m *Migrate) {
m.withForeignKeys = b
return func(a *Atlas) {
a.withForeignKeys = b
}
}
// WithHooks adds a list of hooks to the schema migration.
func WithHooks(hooks ...Hook) MigrateOption {
return func(m *Migrate) {
m.hooks = append(m.hooks, hooks...)
return func(a *Atlas) {
a.hooks = append(a.hooks, hooks...)
}
}
@@ -105,42 +105,19 @@ func (f CreateFunc) Create(ctx context.Context, tables ...*Table) error {
}
// Migrate runs the migration logic for the SQL dialects.
//
// Deprecated: Use the new Atlas struct instead.
type Migrate struct {
sqlDialect
universalID bool // global unique ids.
dropColumns bool // drop deleted columns.
dropIndexes bool // drop deleted indexes.
withFixture bool // with fks rename fixture.
withForeignKeys bool // with foreign keys
atlas *atlasOptions // migrate with atlas.
typeRanges []string // types order by their range.
hooks []Hook // hooks to apply before creation
typeStore typeStore // the typeStore to read and save type ranges
fileTypeRanges []string // used internally by ensureTypeTable hook
dbTypeRanges []string // used internally by ensureTypeTable hook
}
atlas *Atlas // Atlas this Migrate is based on
// NewMigrate create a migration structure for the given SQL driver.
func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) {
m := &Migrate{withForeignKeys: true, atlas: &atlasOptions{}}
for _, opt := range opts {
opt(m)
}
switch d.Dialect() {
case dialect.MySQL:
m.sqlDialect = &MySQL{Driver: d}
case dialect.SQLite:
m.sqlDialect = &SQLite{Driver: d, WithForeignKeys: m.withForeignKeys}
case dialect.Postgres:
m.sqlDialect = &Postgres{Driver: d}
default:
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", d.Dialect())
}
m.typeStore = &dbTypeStore{m.sqlDialect}
if err := m.setupAtlas(); err != nil {
return nil, err
}
return m, nil
universalID bool // global unique ids
dropColumns bool // drop deleted columns
dropIndexes bool // drop deleted indexes
withFixture bool // with fks rename fixture
withForeignKeys bool // with foreign keys
typeRanges []string // types order by their range
hooks []Hook // hooks to apply before creation
}
// Create creates all schema resources in the database. It works in an "append-only"
@@ -155,82 +132,12 @@ func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) {
func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
m.setupTables(tables)
var creator Creator = CreateFunc(m.create)
if m.atlas.enabled {
creator = CreateFunc(m.atCreate)
}
for i := len(m.hooks) - 1; i >= 0; i-- {
creator = m.hooks[i](creator)
}
return creator.Create(ctx, tables...)
}
// Diff compares the state read from the connected database with the state defined by Ent.
// Changes will be written to migration files by the configured Planner.
func (m *Migrate) Diff(ctx context.Context, tables ...*Table) error {
return m.NamedDiff(ctx, "changes", tables...)
}
// NamedDiff compares the state read from the connected database with the state defined by Ent.
// Changes will be written to migration files by the configured Planner.
func (m *Migrate) NamedDiff(ctx context.Context, name string, tables ...*Table) error {
if m.atlas.dir == nil {
return errors.New("no migration directory given")
}
opts := []migrate.PlannerOption{
migrate.WithFormatter(m.atlas.fmt),
}
if m.atlas.genSum {
// Validate the migration directory before proceeding.
if err := migrate.Validate(m.atlas.dir); err != nil {
return fmt.Errorf("validating migration directory: %w", err)
}
} else {
opts = append(opts, migrate.DisableChecksum())
}
if err := m.init(ctx, m); err != nil {
return err
}
if m.universalID {
if err := m.types(ctx, m); err != nil {
return err
}
m.fileTypeRanges = m.typeRanges
ex, err := m.tableExist(ctx, m, TypeTable)
if err != nil {
return err
}
if ex {
m.dbTypeRanges, err = (&dbTypeStore{m}).load(ctx, m)
if err != nil {
return err
}
}
defer func() {
m.fileTypeRanges = nil
m.dbTypeRanges = nil
}()
}
m.setupTables(tables)
plan, err := m.atDiff(ctx, m, name, tables...)
if err != nil {
return err
}
if m.universalID {
newTypes := m.typeRanges[len(m.dbTypeRanges):]
if len(newTypes) > 0 {
plan.Changes = append(plan.Changes, &migrate.Change{
Cmd: m.atTypeRangeSQL(newTypes...),
Comment: fmt.Sprintf("add pk ranges for %s tables", strings.Join(newTypes, ",")),
})
}
}
// Skip if the plan has no changes.
if len(plan.Changes) == 0 {
return nil
}
return migrate.NewPlanner(nil, m.atlas.dir, opts...).WritePlan(plan)
}
func (m *Migrate) create(ctx context.Context, tables ...*Table) error {
tx, err := m.Tx(ctx)
if err != nil {
@@ -328,7 +235,7 @@ func (m *Migrate) txCreate(ctx context.Context, tx dialect.Tx, tables ...*Table)
return nil
}
// apply applies changes on the given table.
// apply changes on the given table.
func (m *Migrate) apply(ctx context.Context, tx dialect.Tx, table string, change *changes) error {
// Constraints should be dropped before dropping columns, because if a column
// is a part of multi-column constraints (like, unique index), ALTER TABLE
@@ -419,7 +326,7 @@ func (m *Migrate) changeSet(curr, new *Table) (*changes, error) {
case c1.Unique && !c2.Unique:
// Make sure the table does not have unique index for this column
// before adding it to the changeset, because there are 2 ways to
// configure uniqueness on ent.Field (using the Unique modifier or
// configure uniqueness on sqlDialect.Field (using the Unique modifier or
// adding rule on the Indexes option).
if idx, ok := curr.index(c1.Name); !ok || !idx.Unique {
change.index.add.append(&Index{
@@ -580,10 +487,30 @@ func (m *Migrate) verify(ctx context.Context, tx dialect.Tx, t *Table) error {
return vr.verifyRange(ctx, tx, t, int64(id<<32))
}
// types loads the type list from the type store.
func (m *Migrate) types(ctx context.Context, tx dialect.ExecQuerier) (err error) {
m.typeRanges, err = m.typeStore.load(ctx, tx)
return
// types loads the type list from the type store. It will create the types table, if it does not exist yet.
func (m *Migrate) types(ctx context.Context, tx dialect.ExecQuerier) error {
exists, err := m.tableExist(ctx, tx, TypeTable)
if err != nil {
return err
}
if !exists {
t := NewTable(TypeTable).
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true})
query, args := m.tBuilder(t).Query()
if err := tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("create types table: %w", err)
}
return nil
}
rows := &sql.Rows{}
query, args := sql.Dialect(m.Dialect()).
Select("type").From(sql.Table(TypeTable)).OrderBy(sql.Asc("id")).Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return fmt.Errorf("query types table: %w", err)
}
defer rows.Close()
return sql.ScanSlice(rows, &m.typeRanges)
}
func (m *Migrate) allocPKRange(ctx context.Context, conn dialect.ExecQuerier, t *Table) error {
@@ -602,8 +529,9 @@ func (m *Migrate) pkRange(ctx context.Context, conn dialect.ExecQuerier, t *Tabl
if len(m.typeRanges) > MaxTypes {
return 0, fmt.Errorf("max number of types exceeded: %d", MaxTypes)
}
if err := m.typeStore.add(ctx, conn, t.Name); err != nil {
return 0, fmt.Errorf("store type range: %w", err)
query, args := sql.Dialect(m.Dialect()).Insert(TypeTable).Columns("type").Values(t.Name).Query()
if err := conn.Exec(ctx, query, args, nil); err != nil {
return 0, fmt.Errorf("insert into ent_types: %w", err)
}
id = len(m.typeRanges)
m.typeRanges = append(m.typeRanges, t.Name)
@@ -641,45 +569,7 @@ func (m *Migrate) fkColumn(ctx context.Context, tx dialect.Tx, fk *ForeignKey) (
// setup ensures the table is configured properly, like table columns
// are linked to their indexes, and PKs columns are defined.
func (m *Migrate) setupTables(tables []*Table) {
for _, t := range tables {
if t.columns == nil {
t.columns = make(map[string]*Column, len(t.Columns))
}
for _, c := range t.Columns {
t.columns[c.Name] = c
}
for _, idx := range t.Indexes {
idx.Name = m.symbol(idx.Name)
for _, c := range idx.Columns {
c.indexes.append(idx)
}
}
for _, pk := range t.PrimaryKey {
c := t.columns[pk.Name]
c.Key = PrimaryKey
pk.Key = PrimaryKey
}
for _, fk := range t.ForeignKeys {
fk.Symbol = m.symbol(fk.Symbol)
for i := range fk.Columns {
fk.Columns[i].foreign = fk
}
}
}
}
// symbol makes sure the symbol length is not longer than the maxlength in the dialect.
func (m *Migrate) symbol(name string) string {
size := 64
if m.Dialect() == dialect.Postgres {
size = 63
}
if len(name) <= size {
return name
}
return fmt.Sprintf("%s_%x", name[:size-33], md5.Sum([]byte(name)))
}
func (m *Migrate) setupTables(tables []*Table) { m.atlas.setupTables(tables) }
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
func rollback(tx dialect.Tx, err error) error {
@@ -747,55 +637,3 @@ type fkRenamer interface {
type verifyRanger interface {
verifyRange(context.Context, dialect.Tx, *Table, int64) error
}
// typeStore wraps methods for loading and storing pk range information for types.
type typeStore interface {
load(context.Context, dialect.ExecQuerier) ([]string, error)
add(context.Context, dialect.ExecQuerier, string) error
}
// dbTypeStore stores and read pk information in a database table.
// This is the "old" behaviour before the typeStore interface was added.
type dbTypeStore struct {
drv sqlDialect
}
// load the types from the database. If the table does not exist, it will be created.
func (s *dbTypeStore) load(ctx context.Context, conn dialect.ExecQuerier) ([]string, error) {
exists, err := s.drv.tableExist(ctx, conn, TypeTable)
if err != nil {
return nil, err
}
if !exists {
t := NewTable(TypeTable).
AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}).
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true})
query, args := s.drv.tBuilder(t).Query()
if err := conn.Exec(ctx, query, args, nil); err != nil {
return nil, fmt.Errorf("create types table: %w", err)
}
return nil, nil
}
rows := &sql.Rows{}
query, args := sql.Dialect(s.drv.Dialect()).
Select("type").From(sql.Table(TypeTable)).OrderBy(sql.Asc("id")).Query()
if err := conn.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("query types table: %w", err)
}
defer rows.Close()
var types []string
return types, sql.ScanSlice(rows, &types)
}
// add a new type entry to the database table. since load is called first,
// there is no need to check for the tables' existence.
func (s *dbTypeStore) add(ctx context.Context, conn dialect.ExecQuerier, t string) error {
query, args := sql.Dialect(s.drv.Dialect()).
Insert(TypeTable).Columns("type").Values(t).Query()
if err := conn.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("insert into ent_types: %w", err)
}
return nil
}
var _ typeStore = (*dbTypeStore)(nil)

View File

@@ -16,10 +16,9 @@ import (
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
@@ -42,7 +41,7 @@ func TestMigrateHookOmitTable(t *testing.T) {
return CreateFunc(func(ctx context.Context, tables ...*Table) error {
return next.Create(ctx, tables[1])
})
}))
}), WithAtlas(false))
require.NoError(t, err)
err = m.Create(context.Background(), tables...)
require.NoError(t, err)
@@ -67,13 +66,15 @@ func TestMigrateHookAddTable(t *testing.T) {
return CreateFunc(func(ctx context.Context, tables ...*Table) error {
return next.Create(ctx, tables[0], &Table{Name: "pets"})
})
}))
}), WithAtlas(false))
require.NoError(t, err)
err = m.Create(context.Background(), tables...)
require.NoError(t, err)
}
func TestMigrate_Diff(t *testing.T) {
ctx := context.Background()
db, err := sql.Open(dialect.SQLite, "file:test?mode=memory&_fk=1")
require.NoError(t, err)
@@ -83,11 +84,11 @@ func TestMigrate_Diff(t *testing.T) {
m, err := NewMigrate(db, WithDir(d))
require.NoError(t, err)
require.NoError(t, m.Diff(context.Background(), &Table{Name: "users"}))
require.NoError(t, m.Diff(ctx, &Table{Name: "users"}))
v := time.Now().UTC().Format("20060102150405")
requireFileEqual(t, filepath.Join(p, v+"_changes.up.sql"), "-- create \"users\" table\nCREATE TABLE `users` (, PRIMARY KEY ());\n")
requireFileEqual(t, filepath.Join(p, v+"_changes.down.sql"), "-- reverse: create \"users\" table\nDROP TABLE `users`;\n")
require.NoFileExists(t, filepath.Join(p, "atlas.sum"))
require.FileExists(t, filepath.Join(p, migrate.HashFileName))
// Test integrity file.
p = t.TempDir()
@@ -95,14 +96,13 @@ func TestMigrate_Diff(t *testing.T) {
require.NoError(t, err)
m, err = NewMigrate(db, WithDir(d), WithSumFile())
require.NoError(t, err)
require.NoError(t, m.Diff(context.Background(), &Table{Name: "users"}))
require.NoError(t, m.Diff(ctx, &Table{Name: "users"}))
requireFileEqual(t, filepath.Join(p, v+"_changes.up.sql"), "-- create \"users\" table\nCREATE TABLE `users` (, PRIMARY KEY ());\n")
requireFileEqual(t, filepath.Join(p, v+"_changes.down.sql"), "-- reverse: create \"users\" table\nDROP TABLE `users`;\n")
require.FileExists(t, filepath.Join(p, "atlas.sum"))
require.FileExists(t, filepath.Join(p, migrate.HashFileName))
require.NoError(t, d.WriteFile("tmp.sql", nil))
require.ErrorIs(t, m.Diff(context.Background(), &Table{Name: "users"}), migrate.ErrChecksumMismatch)
require.ErrorIs(t, m.Diff(ctx, &Table{Name: "users"}), migrate.ErrChecksumMismatch)
// Test type store.
idCol := []*Column{{Name: "id", Type: field.TypeInt, Increment: true}}
p = t.TempDir()
d, err = migrate.NewLocalDir(p)
@@ -115,22 +115,19 @@ func TestMigrate_Diff(t *testing.T) {
)
require.NoError(t, err)
// If using global unique ID and versioned migrations,
// consent for the file based type store has to be given explicitly.
_, err = NewMigrate(db, WithDir(d), WithGlobalUniqueID(true))
require.ErrorIs(t, err, errConsent)
require.Contains(t, err.Error(), "WithUniversalID")
require.Contains(t, err.Error(), "WithGlobalUniqueID")
require.Contains(t, err.Error(), "WithDir")
m, err = NewMigrate(db, WithFormatter(f), WithDir(d), WithUniversalID(), WithSumFile())
m, err = NewMigrate(db, WithFormatter(f), WithDir(d), WithGlobalUniqueID(true))
require.NoError(t, err)
require.IsType(t, &dirTypeStore{}, m.typeStore)
require.NoError(t, m.Diff(context.Background(),
require.NoError(t, m.Diff(ctx,
&Table{Name: "users", Columns: idCol, PrimaryKey: idCol},
&Table{Name: "groups", Columns: idCol, PrimaryKey: idCol, Indexes: []*Index{{Name: "short", Columns: idCol}, {Name: "long_" + strings.Repeat("_", 60), Columns: idCol}}},
&Table{
Name: "groups",
Columns: idCol,
PrimaryKey: idCol,
Indexes: []*Index{
{Name: "short", Columns: idCol},
{Name: "long_" + strings.Repeat("_", 60), Columns: idCol},
}},
))
requireFileEqual(t, filepath.Join(p, ".ent_types"), atlasDirective+"users,groups")
changesSQL := strings.Join([]string{
"CREATE TABLE `users` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT);",
"CREATE TABLE `groups` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT);",
@@ -143,17 +140,10 @@ func TestMigrate_Diff(t *testing.T) {
}, "\n")
requireFileEqual(t, filepath.Join(p, "changes.sql"), changesSQL)
// types file cannot be part of the sum file.
require.FileExists(t, filepath.Join(p, "atlas.sum"))
sum, err := os.ReadFile(filepath.Join(p, "atlas.sum"))
require.NoError(t, err)
require.NotContains(t, string(sum), ".ent_types")
// Adding another node will result in a new entry to the TypeTable (without actually creating it).
_, err = db.ExecContext(context.Background(), changesSQL, nil, nil)
_, err = db.ExecContext(ctx, changesSQL, nil, nil)
require.NoError(t, err)
require.NoError(t, m.NamedDiff(context.Background(), "changes_2", &Table{Name: "pets", Columns: idCol, PrimaryKey: idCol}))
requireFileEqual(t, filepath.Join(p, ".ent_types"), atlasDirective+"users,groups,pets")
require.NoError(t, m.NamedDiff(ctx, "changes_2", &Table{Name: "pets", Columns: idCol, PrimaryKey: idCol}))
requireFileEqual(t,
filepath.Join(p, "changes_2.sql"), strings.Join([]string{
"CREATE TABLE `pets` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT);",
@@ -161,48 +151,8 @@ func TestMigrate_Diff(t *testing.T) {
"INSERT INTO `ent_types` (`type`) VALUES ('pets');", "",
}, "\n"))
// types file cannot be part of the sum file.
require.FileExists(t, filepath.Join(p, "atlas.sum"))
sum, err = os.ReadFile(filepath.Join(p, "atlas.sum"))
require.NoError(t, err)
require.NotContains(t, string(sum), ".ent_types")
// Checksum will be updated as well.
require.NoError(t, migrate.Validate(d))
// Running diff against an existing database without having a types file yet
// will result in the types file respect the "old" order of pk allocations.
switchAllocs := func(one, two string) {
for _, stmt := range []string{
"DELETE FROM `ent_types`;",
fmt.Sprintf("INSERT INTO `ent_types` (`type`) VALUES ('%s'), ('%s');", one, two),
} {
_, err = db.ExecContext(context.Background(), stmt)
require.NoError(t, err)
}
}
switchAllocs("groups", "users")
p = t.TempDir()
d, err = migrate.NewLocalDir(p)
require.NoError(t, err)
m, err = NewMigrate(db, WithFormatter(f), WithDir(d), WithUniversalID())
require.NoError(t, err)
require.NoError(t, m.Diff(context.Background(),
&Table{Name: "users", Columns: idCol, PrimaryKey: idCol},
&Table{Name: "groups", Columns: idCol, PrimaryKey: idCol},
))
requireFileEqual(t, filepath.Join(p, ".ent_types"), atlasDirective+"groups,users")
require.NoFileExists(t, filepath.Join(p, "changes.sql"))
// Drifts in the types file and types database will be detected,
switchAllocs("users", "groups")
require.ErrorContains(t, m.Diff(context.Background()), fmt.Sprintf(
"type allocation range drift detected: %v <> %v: see %s for more information",
[]string{"users", "groups"},
[]string{"groups", "users"},
"https://entgo.io/docs/versioned-migrations#moving-from-auto-migration-to-versioned-migrations",
))
}
func requireFileEqual(t *testing.T, name, contents string) {

View File

@@ -1362,7 +1362,7 @@ func TestMySQL_Create(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
tt.before(mysqlMock{mock})
migrate, err := NewMigrate(sql.OpenDB("mysql", db), tt.options...)
migrate, err := NewMigrate(sql.OpenDB("mysql", db), append(tt.options, WithAtlas(false))...)
require.NoError(t, err)
err = migrate.Create(context.Background(), tt.tables...)
require.Equal(t, tt.wantErr, err != nil, err)

View File

@@ -997,7 +997,7 @@ func TestPostgres_Create(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
tt.before(pgMock{mock})
migrate, err := NewMigrate(sql.OpenDB("postgres", db), tt.options...)
migrate, err := NewMigrate(sql.OpenDB("postgres", db), append(tt.options, WithAtlas(false))...)
require.NoError(t, err)
err = migrate.Create(context.Background(), tt.tables...)
require.Equal(t, tt.wantErr, err != nil, err)

View File

@@ -437,7 +437,7 @@ func TestSQLite_Create(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
tt.before(sqliteMock{mock})
migrate, err := NewMigrate(sql.OpenDB("sqlite3", db), tt.options...)
migrate, err := NewMigrate(sql.OpenDB("sqlite3", db), append(tt.options, WithAtlas(false))...)
require.NoError(t, err)
err = migrate.Create(context.Background(), tt.tables...)
require.Equal(t, tt.wantErr, err != nil, err)