mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
* dialect/sql/schema: add method to create a named versioned migration file * doc/md: documentation for named versioned migrations * entc/gen/template/dialect/sql/feature: add NamedDiff method to create named versioned migration files * all: go generate * doc/md: apply CR
540 lines
14 KiB
Go
540 lines
14 KiB
Go
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package schema
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
|
|
"ariga.io/atlas/sql/migrate"
|
|
"ariga.io/atlas/sql/schema"
|
|
|
|
"entgo.io/ent/dialect"
|
|
entsql "entgo.io/ent/dialect/sql"
|
|
"entgo.io/ent/schema/field"
|
|
)
|
|
|
|
type (
|
|
// Differ is the interface that wraps the Diff method.
|
|
Differ interface {
|
|
// Diff creates the given tables in the database.
|
|
Diff(current, desired *schema.Schema) ([]schema.Change, error)
|
|
}
|
|
|
|
// The DiffFunc type is an adapter to allow the use of ordinary function as Differ.
|
|
// If f is a function with the appropriate signature, DiffFunc(f) is a Differ that calls f.
|
|
DiffFunc func(current, desired *schema.Schema) ([]schema.Change, error)
|
|
|
|
// DiffHook defines the "diff middleware". A function that gets a Differ and returns a Differ.
|
|
DiffHook func(Differ) Differ
|
|
)
|
|
|
|
// Diff calls f(current, desired).
|
|
func (f DiffFunc) Diff(current, desired *schema.Schema) ([]schema.Change, error) {
|
|
return f(current, desired)
|
|
}
|
|
|
|
// WithDiffHook adds a list of DiffHook to the schema migration.
|
|
//
|
|
// schema.WithDiffHook(func(next schema.Differ) schema.Differ {
|
|
// return schema.DiffFunc(func(current, desired *atlas.Schema) ([]atlas.Change, error) {
|
|
// // Code before standard diff.
|
|
// changes, err := next.Diff(current, desired)
|
|
// if err != nil {
|
|
// return nil, err
|
|
// }
|
|
// // After diff, you can filter
|
|
// // changes or return new ones.
|
|
// return changes, nil
|
|
// })
|
|
// })
|
|
//
|
|
func WithDiffHook(hooks ...DiffHook) MigrateOption {
|
|
return func(m *Migrate) {
|
|
m.atlas.diff = append(m.atlas.diff, hooks...)
|
|
}
|
|
}
|
|
|
|
// WithSkipChanges allows skipping/filtering list of changes
|
|
// returned by the Differ before executing migration planning.
|
|
//
|
|
// SkipChanges(schema.DropTable|schema.DropColumn)
|
|
//
|
|
func WithSkipChanges(skip ChangeKind) MigrateOption {
|
|
return func(m *Migrate) {
|
|
m.atlas.skip = skip
|
|
}
|
|
}
|
|
|
|
// A ChangeKind denotes the kind of schema change. Every non-zero kind is a
// distinct bit flag, so kinds can be combined with the bitwise OR operator
// (e.g. DropIndex|DropColumn) and matched with ChangeKind.Is.
type ChangeKind uint
|
|
|
|
// List of change types.
//
// NoChange is the zero value. Every other kind occupies a single, distinct
// bit (AddSchema is 1<<0, ModifySchema is 1<<1, and so on), so kinds can be
// OR-ed together into a mask and tested with ChangeKind.Is.
//
// NOTE: the values are derived from iota and therefore from declaration
// order; append new kinds at the end to keep existing values stable.
const (
	NoChange  ChangeKind = 0
	AddSchema ChangeKind = 1 << (iota - 1)
	ModifySchema
	DropSchema
	// Table-level changes.
	AddTable
	ModifyTable
	DropTable
	// Column-level changes.
	AddColumn
	ModifyColumn
	DropColumn
	// Index changes.
	AddIndex
	ModifyIndex
	DropIndex
	// Foreign-key changes.
	AddForeignKey
	ModifyForeignKey
	DropForeignKey
	// CHECK constraint changes.
	AddCheck
	ModifyCheck
	DropCheck
)
|
|
|
|
// Is reports whether c is match the given change king.
|
|
func (k ChangeKind) Is(c ChangeKind) bool {
|
|
return k == c || k&c != 0
|
|
}
|
|
|
|
// filterChanges is a DiffHook for filtering changes before plan.
|
|
func filterChanges(skip ChangeKind) DiffHook {
|
|
return func(next Differ) Differ {
|
|
return DiffFunc(func(current, desired *schema.Schema) ([]schema.Change, error) {
|
|
var f func([]schema.Change) []schema.Change
|
|
f = func(changes []schema.Change) (keep []schema.Change) {
|
|
var k ChangeKind
|
|
for _, c := range changes {
|
|
switch c := c.(type) {
|
|
case *schema.AddSchema:
|
|
k = AddSchema
|
|
case *schema.ModifySchema:
|
|
k = ModifySchema
|
|
if !skip.Is(k) {
|
|
c.Changes = f(c.Changes)
|
|
}
|
|
case *schema.DropSchema:
|
|
k = DropSchema
|
|
case *schema.AddTable:
|
|
k = AddTable
|
|
case *schema.ModifyTable:
|
|
k = ModifyTable
|
|
if !skip.Is(k) {
|
|
c.Changes = f(c.Changes)
|
|
}
|
|
case *schema.DropTable:
|
|
k = DropTable
|
|
case *schema.AddColumn:
|
|
k = AddColumn
|
|
case *schema.ModifyColumn:
|
|
k = ModifyColumn
|
|
case *schema.DropColumn:
|
|
k = DropColumn
|
|
case *schema.AddIndex:
|
|
k = AddIndex
|
|
case *schema.ModifyIndex:
|
|
k = ModifyIndex
|
|
case *schema.DropIndex:
|
|
k = DropIndex
|
|
case *schema.AddForeignKey:
|
|
k = AddIndex
|
|
case *schema.ModifyForeignKey:
|
|
k = ModifyForeignKey
|
|
case *schema.DropForeignKey:
|
|
k = DropForeignKey
|
|
case *schema.AddCheck:
|
|
k = AddCheck
|
|
case *schema.ModifyCheck:
|
|
k = ModifyCheck
|
|
case *schema.DropCheck:
|
|
k = DropCheck
|
|
}
|
|
if !skip.Is(k) {
|
|
keep = append(keep, c)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
changes, err := next.Diff(current, desired)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return f(changes), nil
|
|
})
|
|
}
|
|
}
|
|
|
|
type (
|
|
// Applier is the interface that wraps the Apply method.
|
|
Applier interface {
|
|
// Apply applies the given migrate.Plan on the database.
|
|
Apply(context.Context, dialect.ExecQuerier, *migrate.Plan) error
|
|
}
|
|
|
|
// The ApplyFunc type is an adapter to allow the use of ordinary function as Applier.
|
|
// If f is a function with the appropriate signature, ApplyFunc(f) is an Applier that calls f.
|
|
ApplyFunc func(context.Context, dialect.ExecQuerier, *migrate.Plan) error
|
|
|
|
// ApplyHook defines the "migration applying middleware". A function that gets an Applier and returns an Applier.
|
|
ApplyHook func(Applier) Applier
|
|
)
|
|
|
|
// Apply calls f(ctx, tables...).
|
|
func (f ApplyFunc) Apply(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error {
|
|
return f(ctx, conn, plan)
|
|
}
|
|
|
|
// WithApplyHook adds a list of ApplyHook to the schema migration.
|
|
//
|
|
// schema.WithApplyHook(func(next schema.Applier) schema.Applier {
|
|
// return schema.ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error {
|
|
// // Example to hook into the apply process, or implement
|
|
// // a custom applier.
|
|
// //
|
|
// // for _, c := range plan.Changes {
|
|
// // fmt.Printf("%s: %s", c.Comment, c.Cmd)
|
|
// // }
|
|
// //
|
|
// return next.Apply(ctx, conn, plan)
|
|
// })
|
|
// })
|
|
//
|
|
func WithApplyHook(hooks ...ApplyHook) MigrateOption {
|
|
return func(m *Migrate) {
|
|
m.atlas.apply = append(m.atlas.apply, hooks...)
|
|
}
|
|
}
|
|
|
|
// WithAtlas is an opt-in option for v0.10 indicates the migration
|
|
// should be executed using Atlas engine (i.e. https://atlasgo.io).
|
|
// Note, in future versions, this option is going to be replaced
|
|
// from opt-in to opt-out and the deprecation of this package.
|
|
func WithAtlas(b bool) MigrateOption {
|
|
return func(m *Migrate) {
|
|
m.atlas.enabled = b
|
|
}
|
|
}
|
|
|
|
// WithDir sets the atlas migration directory to use to store migration files.
|
|
func WithDir(dir migrate.Dir) MigrateOption {
|
|
return func(m *Migrate) {
|
|
m.atlas.dir = dir
|
|
}
|
|
}
|
|
|
|
// WithFormatter sets atlas formatter to use to write changes to migration files.
|
|
func WithFormatter(fmt migrate.Formatter) MigrateOption {
|
|
return func(m *Migrate) {
|
|
m.atlas.fmt = fmt
|
|
}
|
|
}
|
|
|
|
type (
|
|
// atlasOptions describes the options for atlas.
|
|
atlasOptions struct {
|
|
enabled bool
|
|
diff []DiffHook
|
|
apply []ApplyHook
|
|
skip ChangeKind
|
|
dir migrate.Dir
|
|
fmt migrate.Formatter
|
|
}
|
|
|
|
// atBuilder must be implemented by the different drivers in
|
|
// order to convert a dialect/sql/schema to atlas/sql/schema.
|
|
atBuilder interface {
|
|
atOpen(dialect.ExecQuerier) (migrate.Driver, error)
|
|
atTable(*Table, *schema.Table)
|
|
atTypeC(*Column, *schema.Column) error
|
|
atUniqueC(*Table, *Column, *schema.Table, *schema.Column)
|
|
atIncrementC(*schema.Table, *schema.Column)
|
|
atIncrementT(*schema.Table, int64)
|
|
atIndex(*Index, *schema.Table, *schema.Index) error
|
|
}
|
|
)
|
|
|
|
func (m *Migrate) setupAtlas() error {
|
|
// Using one of the Atlas options, opt-in to Atlas migration.
|
|
if !m.atlas.enabled && (m.atlas.skip != NoChange || len(m.atlas.diff) > 0 || len(m.atlas.apply) > 0) || m.atlas.dir != nil {
|
|
m.atlas.enabled = true
|
|
}
|
|
if !m.atlas.enabled {
|
|
return nil
|
|
}
|
|
if !m.withForeignKeys {
|
|
return errors.New("sql/schema: WithForeignKeys(false) does not work in Atlas migration")
|
|
}
|
|
if m.withFixture {
|
|
return errors.New("sql/schema: WithFixture(true) does not work in Atlas migration")
|
|
}
|
|
skip := DropIndex | DropColumn
|
|
if m.atlas.skip != NoChange {
|
|
skip = m.atlas.skip
|
|
}
|
|
if m.dropIndexes {
|
|
skip &= ^DropIndex
|
|
}
|
|
if m.dropColumns {
|
|
skip &= ^DropColumn
|
|
}
|
|
if skip != NoChange {
|
|
m.atlas.diff = append(m.atlas.diff, filterChanges(skip))
|
|
}
|
|
if m.atlas.dir != nil && m.atlas.fmt == nil {
|
|
m.atlas.fmt = migrate.DefaultFormatter
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Migrate) atCreate(ctx context.Context, tables ...*Table) error {
|
|
// Open a transaction for backwards compatibility,
|
|
// even if the migration is not transactional.
|
|
tx, err := m.Tx(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := func() error {
|
|
if err := m.init(ctx, tx); err != nil {
|
|
return err
|
|
}
|
|
if m.universalID {
|
|
if err := m.types(ctx, tx); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
plan, err := m.atDiff(ctx, tx, "", tables...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Apply plan (changes).
|
|
var applier Applier = ApplyFunc(func(ctx context.Context, tx dialect.ExecQuerier, plan *migrate.Plan) error {
|
|
for _, c := range plan.Changes {
|
|
if err := tx.Exec(ctx, c.Cmd, c.Args, nil); err != nil {
|
|
if c.Comment != "" {
|
|
err = fmt.Errorf("%s: %w", c.Comment, err)
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
for i := len(m.atlas.apply) - 1; i >= 0; i-- {
|
|
applier = m.atlas.apply[i](applier)
|
|
}
|
|
return applier.Apply(ctx, tx, plan)
|
|
}(); err != nil {
|
|
return rollback(tx, err)
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (m *Migrate) atDiff(ctx context.Context, conn dialect.ExecQuerier, name string, tables ...*Table) (*migrate.Plan, error) {
|
|
drv, err := m.atOpen(conn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
current, err := drv.InspectSchema(ctx, "", &schema.InspectOptions{
|
|
Tables: func() (t []string) {
|
|
for i := range tables {
|
|
t = append(t, tables[i].Name)
|
|
}
|
|
return t
|
|
}(),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tt, err := m.aTables(ctx, m, conn, tables)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Diff changes.
|
|
var differ Differ = DiffFunc(drv.SchemaDiff)
|
|
for i := len(m.atlas.diff) - 1; i >= 0; i-- {
|
|
differ = m.atlas.diff[i](differ)
|
|
}
|
|
changes, err := differ.Diff(current, &schema.Schema{Name: current.Name, Attrs: current.Attrs, Tables: tt})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Plan changes.
|
|
return drv.PlanChanges(ctx, name, changes)
|
|
}
|
|
|
|
// db adapts a dialect.ExecQuerier to a database/sql-style method set
// (QueryContext and ExecContext, defined below).
// NOTE(review): presumably used to hand ent connections to Atlas drivers — confirm at the atOpen implementations.
type db struct{ dialect.ExecQuerier }
|
|
|
|
func (d *db) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
|
rows := &entsql.Rows{}
|
|
if err := d.ExecQuerier.Query(ctx, query, args, rows); err != nil {
|
|
return nil, err
|
|
}
|
|
return rows.ColumnScanner.(*sql.Rows), nil
|
|
}
|
|
|
|
func (d *db) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
var r sql.Result
|
|
if err := d.ExecQuerier.Exec(ctx, query, args, &r); err != nil {
|
|
return nil, err
|
|
}
|
|
return r, nil
|
|
}
|
|
|
|
func (m *Migrate) aTables(ctx context.Context, b atBuilder, conn dialect.ExecQuerier, tables1 []*Table) ([]*schema.Table, error) {
|
|
tables2 := make([]*schema.Table, len(tables1))
|
|
for i, t1 := range tables1 {
|
|
t2 := schema.NewTable(t1.Name)
|
|
b.atTable(t1, t2)
|
|
if m.universalID {
|
|
r, err := m.pkRange(ctx, conn, t1)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
b.atIncrementT(t2, r)
|
|
}
|
|
if err := m.aColumns(b, t1, t2); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := m.aIndexes(b, t1, t2); err != nil {
|
|
return nil, err
|
|
}
|
|
tables2[i] = t2
|
|
}
|
|
for i, t1 := range tables1 {
|
|
t2 := tables2[i]
|
|
for _, fk1 := range t1.ForeignKeys {
|
|
fk2 := schema.NewForeignKey(fk1.Symbol).
|
|
SetTable(t2).
|
|
SetOnUpdate(schema.ReferenceOption(fk1.OnUpdate)).
|
|
SetOnDelete(schema.ReferenceOption(fk1.OnDelete))
|
|
for _, c1 := range fk1.Columns {
|
|
c2, ok := t2.Column(c1.Name)
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected fk %q column: %q", fk1.Symbol, c1.Name)
|
|
}
|
|
fk2.AddColumns(c2)
|
|
}
|
|
var refT *schema.Table
|
|
for _, t2 := range tables2 {
|
|
if t2.Name == fk1.RefTable.Name {
|
|
refT = t2
|
|
break
|
|
}
|
|
}
|
|
if refT == nil {
|
|
return nil, fmt.Errorf("unexpected fk %q ref-table: %q", fk1.Symbol, fk1.RefTable.Name)
|
|
}
|
|
fk2.SetRefTable(refT)
|
|
for _, c1 := range fk1.RefColumns {
|
|
c2, ok := refT.Column(c1.Name)
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected fk %q ref-column: %q", fk1.Symbol, c1.Name)
|
|
}
|
|
fk2.AddRefColumns(c2)
|
|
}
|
|
t2.AddForeignKeys(fk2)
|
|
}
|
|
}
|
|
return tables2, nil
|
|
}
|
|
|
|
func (m *Migrate) aColumns(b atBuilder, t1 *Table, t2 *schema.Table) error {
|
|
for _, c1 := range t1.Columns {
|
|
c2 := schema.NewColumn(c1.Name).
|
|
SetNull(c1.Nullable)
|
|
if c1.Collation != "" {
|
|
c2.SetCollation(c1.Collation)
|
|
}
|
|
if err := b.atTypeC(c1, c2); err != nil {
|
|
return err
|
|
}
|
|
if c1.Default != nil && c1.supportDefault() {
|
|
// Has default and the database supports adding this default.
|
|
x := fmt.Sprint(c1.Default)
|
|
if v, ok := c1.Default.(string); ok && c1.Type != field.TypeUUID && c1.Type != field.TypeTime {
|
|
// Escape single quote by replacing each with 2.
|
|
x = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''"))
|
|
}
|
|
c2.SetDefault(&schema.RawExpr{X: x})
|
|
}
|
|
if c1.Unique {
|
|
b.atUniqueC(t1, c1, t2, c2)
|
|
}
|
|
if c1.Increment {
|
|
b.atIncrementC(t2, c2)
|
|
}
|
|
t2.AddColumns(c2)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Migrate) aIndexes(b atBuilder, t1 *Table, t2 *schema.Table) error {
|
|
// Primary-key index.
|
|
pk := make([]*schema.Column, 0, len(t1.PrimaryKey))
|
|
for _, c1 := range t1.PrimaryKey {
|
|
c2, ok := t2.Column(c1.Name)
|
|
if !ok {
|
|
return fmt.Errorf("unexpected primary-key column: %q", c1.Name)
|
|
}
|
|
pk = append(pk, c2)
|
|
}
|
|
t2.SetPrimaryKey(schema.NewPrimaryKey(pk...))
|
|
// Rest of indexes.
|
|
for _, idx1 := range t1.Indexes {
|
|
idx2 := schema.NewIndex(idx1.Name).
|
|
SetUnique(idx1.Unique)
|
|
if err := b.atIndex(idx1, t2, idx2); err != nil {
|
|
return err
|
|
}
|
|
desc := descIndexes(idx1)
|
|
for _, p := range idx2.Parts {
|
|
p.Desc = desc[p.C.Name]
|
|
}
|
|
t2.AddIndexes(idx2)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func setAtChecks(t1 *Table, t2 *schema.Table) {
|
|
if check := t1.Annotation.Check; check != "" {
|
|
t2.AddChecks(&schema.Check{
|
|
Expr: check,
|
|
})
|
|
}
|
|
if checks := t1.Annotation.Checks; len(t1.Annotation.Checks) > 0 {
|
|
names := make([]string, 0, len(checks))
|
|
for name := range checks {
|
|
names = append(names, name)
|
|
}
|
|
sort.Strings(names)
|
|
for _, name := range names {
|
|
t2.AddChecks(&schema.Check{
|
|
Name: name,
|
|
Expr: checks[name],
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
// descIndexes returns a map holding the DESC mapping if exist.
|
|
func descIndexes(idx *Index) map[string]bool {
|
|
descs := make(map[string]bool)
|
|
if idx.Annotation == nil {
|
|
return descs
|
|
}
|
|
// If DESC (without a column) was defined on the
|
|
// annotation, map it to the single column index.
|
|
if idx.Annotation.Desc && len(idx.Columns) == 1 {
|
|
descs[idx.Columns[0].Name] = idx.Annotation.Desc
|
|
}
|
|
for column, desc := range idx.Annotation.DescColumns {
|
|
descs[column] = desc
|
|
}
|
|
return descs
|
|
}
|