dialect/sql/schema: hello ariga.io/atlas (#2279)

This commit is contained in:
Ariel Mashraki
2022-01-20 17:20:50 +02:00
committed by GitHub
parent 05590433a7
commit 60e03285d0
13 changed files with 1643 additions and 73 deletions

495
dialect/sql/schema/atlas.go Normal file
View File

@@ -0,0 +1,495 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package schema
import (
"context"
"database/sql"
"errors"
"fmt"
"sort"
"strings"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/schema"
)
type (
// Differ is the interface that wraps the Diff method.
Differ interface {
// Diff creates the given tables in the database.
Diff(current, desired *schema.Schema) ([]schema.Change, error)
}
// The DiffFunc type is an adapter to allow the use of ordinary function as Differ.
// If f is a function with the appropriate signature, DiffFunc(f) is a Differ that calls f.
DiffFunc func(current, desired *schema.Schema) ([]schema.Change, error)
// DiffHook defines the "diff middleware". A function that gets a Differ and returns a Differ.
DiffHook func(Differ) Differ
)
// Diff calls f(current, desired).
func (f DiffFunc) Diff(current, desired *schema.Schema) ([]schema.Change, error) {
return f(current, desired)
}
// WithDiffHook adds a list of DiffHook to the schema migration.
//
// schema.WithDiffHook(func(next schema.Differ) schema.Differ {
// return schema.DiffFunc(func(current, desired *atlas.Schema) ([]atlas.Change, error) {
// // Code before standard diff.
// changes, err := next.Diff(current, desired)
// if err != nil {
// return nil, err
// }
// // After diff, you can filter
// // changes or return new ones.
// return changes, nil
// })
// })
//
func WithDiffHook(hooks ...DiffHook) MigrateOption {
return func(m *Migrate) {
m.atlas.diff = append(m.atlas.diff, hooks...)
}
}
// SkipChanges allows skipping/filtering list of changes
// returned by the differ before executing migration planning.
//
// SkipChanges(schema.DropTable|schema.DropColumn)
//
func WithSkipChanges(skip ChangeKind) MigrateOption {
return func(m *Migrate) {
m.atlas.skip = skip
}
}
// A Change of schema.
type ChangeKind uint
// List of change types.
const (
NoChange ChangeKind = 0
AddSchema ChangeKind = 1 << (iota - 1)
ModifySchema
DropSchema
AddTable
ModifyTable
DropTable
AddColumn
ModifyColumn
DropColumn
AddIndex
ModifyIndex
DropIndex
AddForeignKey
ModifyForeignKey
DropForeignKey
AddCheck
ModifyCheck
DropCheck
)
// Is reports whether c is match the given change king.
func (k ChangeKind) Is(c ChangeKind) bool {
return k == c || k&c != 0
}
// filterChanges is a DiffHook for filtering changes before plan.
func filterChanges(skip ChangeKind) DiffHook {
return func(next Differ) Differ {
return DiffFunc(func(current, desired *schema.Schema) ([]schema.Change, error) {
var f func([]schema.Change) []schema.Change
f = func(changes []schema.Change) (keep []schema.Change) {
var k ChangeKind
for _, c := range changes {
switch c := c.(type) {
case *schema.AddSchema:
k = AddSchema
case *schema.ModifySchema:
k = ModifySchema
if !skip.Is(k) {
c.Changes = f(c.Changes)
}
case *schema.DropSchema:
k = DropSchema
case *schema.AddTable:
k = AddTable
case *schema.ModifyTable:
k = ModifyTable
if !skip.Is(k) {
c.Changes = f(c.Changes)
}
case *schema.DropTable:
k = DropTable
case *schema.AddColumn:
k = AddColumn
case *schema.ModifyColumn:
k = ModifyColumn
case *schema.DropColumn:
k = DropColumn
case *schema.AddIndex:
k = AddIndex
case *schema.ModifyIndex:
k = ModifyIndex
case *schema.DropIndex:
k = DropIndex
case *schema.AddForeignKey:
k = AddIndex
case *schema.ModifyForeignKey:
k = ModifyForeignKey
case *schema.DropForeignKey:
k = DropForeignKey
case *schema.AddCheck:
k = AddCheck
case *schema.ModifyCheck:
k = ModifyCheck
case *schema.DropCheck:
k = DropCheck
}
if !skip.Is(k) {
keep = append(keep, c)
}
}
return
}
changes, err := next.Diff(current, desired)
if err != nil {
return nil, err
}
return f(changes), nil
})
}
}
type (
// Applier is the interface that wraps the Apply method.
Applier interface {
// Diff creates the given tables in the database.
Apply(context.Context, dialect.ExecQuerier, *migrate.Plan) error
}
// The ApplyFunc type is an adapter to allow the use of ordinary function as Applier.
// If f is a function with the appropriate signature, ApplyFunc(f) is a Applier that calls f.
ApplyFunc func(context.Context, dialect.ExecQuerier, *migrate.Plan) error
// ApplyHook defines the "migration applying middleware". A function that gets a Applier and returns a Applier.
ApplyHook func(Applier) Applier
)
// Apply calls f(ctx, tables...).
func (f ApplyFunc) Apply(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error {
return f(ctx, conn, plan)
}
// func adds a list of ApplyHook to the schema migration.
//
// schema.WithApplyHook(func(next schema.Applier) schema.Applier {
// return schema.ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error {
// // Example to hook into the apply process, or implement
// // a custom applier.
// //
// // for _, c := range plan.Changes {
// // fmt.Printf("%s: %s", c.Comment, c.Cmd)
// // }
// //
// return next.Apply(ctx, conn, plan)
// })
// })
//
func WithApplyHook(hooks ...ApplyHook) MigrateOption {
return func(m *Migrate) {
m.atlas.apply = append(m.atlas.apply, hooks...)
}
}
// WithAtlas is an opt-in option for v0.10 indicates the migration
// should be executed using Atlas engine (i.e. https://atlasgo.io).
// Note, in future versions, this option is going to be replaced
// from opt-in to opt-out and the deprecation of this package.
func WithAtlas(b bool) MigrateOption {
return func(m *Migrate) {
m.atlas.enabled = b
}
}
type (
// atlasOptions describes the options for atlas.
atlasOptions struct {
enabled bool
diff []DiffHook
apply []ApplyHook
skip ChangeKind
}
// atBuilder must be implemented by the different drivers in
// order to convert a dialect/sql/schema to atlas/sql/schema.
atBuilder interface {
atOpen(dialect.ExecQuerier) (migrate.Driver, error)
atTable(*Table, *schema.Table)
atTypeC(*Column, *schema.Column) error
atUniqueC(*Table, *Column, *schema.Table, *schema.Column)
atIncrementC(*schema.Table, *schema.Column)
atIncrementT(*schema.Table, int64)
atIndex(*Index, *schema.Table, *schema.Index) error
}
)
func (m *Migrate) setupAtlas() error {
// Using one of the Atlas options, opt-in to Atlas migration.
if !m.atlas.enabled && (m.atlas.skip != NoChange || len(m.atlas.diff) > 0 || len(m.atlas.apply) > 0) {
m.atlas.enabled = true
}
if !m.atlas.enabled {
return nil
}
if !m.withForeignKeys {
return errors.New("sql/schema: WithForeignKeys(false) does not work in Atlas migration")
}
if m.withFixture {
return errors.New("sql/schema: WithFixture(true) does not work in Atlas migration")
}
k := DropIndex | DropColumn
if m.atlas.skip != NoChange {
k = m.atlas.skip
}
if m.dropIndexes {
k |= ^DropIndex
}
if m.dropColumns {
k |= ^DropColumn
}
if k == NoChange {
m.atlas.diff = append(m.atlas.diff, filterChanges(k))
}
return nil
}
func (m *Migrate) atCreate(ctx context.Context, tables ...*Table) error {
// Open a transaction for backwards compatibility,
// even if the migration is not transactional.
tx, err := m.Tx(ctx)
if err != nil {
return err
}
if err := func() error {
if err := m.init(ctx, tx); err != nil {
return err
}
if m.universalID {
if err := m.types(ctx, tx); err != nil {
return err
}
}
drv, err := m.atOpen(tx)
if err != nil {
return err
}
current, err := drv.InspectSchema(ctx, "", &schema.InspectOptions{
Tables: func() (t []string) {
for i := range tables {
t = append(t, tables[i].Name)
}
return t
}(),
})
if err != nil {
return err
}
tt, err := m.aTables(ctx, m, tx, tables)
if err != nil {
return err
}
// Diff changes.
var differ Differ = DiffFunc(drv.SchemaDiff)
for i := len(m.atlas.diff) - 1; i >= 0; i-- {
differ = m.atlas.diff[i](differ)
}
changes, err := differ.Diff(current, &schema.Schema{Name: current.Name, Attrs: current.Attrs, Tables: tt})
if err != nil {
return err
}
// Plan changes.
plan, err := drv.PlanChanges(ctx, "plan", changes)
if err != nil {
return err
}
// Apply plan (changes).
var applier Applier = ApplyFunc(func(ctx context.Context, tx dialect.ExecQuerier, plan *migrate.Plan) error {
for _, c := range plan.Changes {
if err := tx.Exec(ctx, c.Cmd, c.Args, nil); err != nil {
if c.Comment != "" {
err = fmt.Errorf("%s: %w", c.Comment, err)
}
return err
}
}
return nil
})
for i := len(m.atlas.apply) - 1; i >= 0; i-- {
applier = m.atlas.apply[i](applier)
}
return applier.Apply(ctx, tx, plan)
}(); err != nil {
return rollback(tx, err)
}
return tx.Commit()
}
type db struct{ dialect.ExecQuerier }
func (d *db) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
rows := &entsql.Rows{}
if err := d.ExecQuerier.Query(ctx, query, args, rows); err != nil {
return nil, err
}
return rows.ColumnScanner.(*sql.Rows), nil
}
func (d *db) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
var r sql.Result
if err := d.ExecQuerier.Exec(ctx, query, args, &r); err != nil {
return nil, err
}
return r, nil
}
func (m *Migrate) aTables(ctx context.Context, b atBuilder, conn dialect.ExecQuerier, tables1 []*Table) ([]*schema.Table, error) {
tables2 := make([]*schema.Table, len(tables1))
for i, t1 := range tables1 {
t2 := schema.NewTable(t1.Name)
b.atTable(t1, t2)
if m.universalID {
r, err := m.pkRange(ctx, conn, t1)
if err != nil {
return nil, err
}
b.atIncrementT(t2, r)
}
if err := m.aColumns(b, t1, t2); err != nil {
return nil, err
}
if err := m.aIndexes(b, t1, t2); err != nil {
return nil, err
}
tables2[i] = t2
}
for i, t1 := range tables1 {
t2 := tables2[i]
for _, fk1 := range t1.ForeignKeys {
fk2 := schema.NewForeignKey(fk1.Symbol).
SetTable(t2).
SetOnUpdate(schema.ReferenceOption(fk1.OnUpdate)).
SetOnDelete(schema.ReferenceOption(fk1.OnDelete))
for _, c1 := range fk1.Columns {
c2, ok := t2.Column(c1.Name)
if !ok {
return nil, fmt.Errorf("unexpected fk %q column: %q", fk1.Symbol, c1.Name)
}
fk2.AddColumns(c2)
}
var refT *schema.Table
for _, t2 := range tables2 {
if t2.Name == fk1.RefTable.Name {
refT = t2
break
}
}
if refT == nil {
return nil, fmt.Errorf("unexpected fk %q ref-table: %q", fk1.Symbol, fk1.RefTable.Name)
}
fk2.SetRefTable(refT)
for _, c1 := range fk1.RefColumns {
c2, ok := refT.Column(c1.Name)
if !ok {
return nil, fmt.Errorf("unexpected fk %q ref-column: %q", fk1.Symbol, c1.Name)
}
fk2.AddRefColumns(c2)
}
t2.AddForeignKeys(fk2)
}
}
return tables2, nil
}
func (m *Migrate) aColumns(b atBuilder, t1 *Table, t2 *schema.Table) error {
for _, c1 := range t1.Columns {
c2 := schema.NewColumn(c1.Name).
SetNull(c1.Nullable)
if c1.Collation != "" {
c2.SetCollation(c1.Collation)
}
if err := b.atTypeC(c1, c2); err != nil {
return err
}
if c1.Default != nil && c1.supportDefault() {
// Has default and the database supports adding this default.
x := fmt.Sprint(c1.Default)
if v, ok := c1.Default.(string); ok && c1.Type != field.TypeUUID && c1.Type != field.TypeTime {
// Escape single quote by replacing each with 2.
x = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''"))
}
c2.SetDefault(&schema.RawExpr{X: x})
}
if c1.Unique {
b.atUniqueC(t1, c1, t2, c2)
}
if c1.Increment {
b.atIncrementC(t2, c2)
}
t2.AddColumns(c2)
}
return nil
}
func (m *Migrate) aIndexes(b atBuilder, t1 *Table, t2 *schema.Table) error {
// Primary-key index.
pk := make([]*schema.Column, 0, len(t1.PrimaryKey))
for _, c1 := range t1.PrimaryKey {
c2, ok := t2.Column(c1.Name)
if !ok {
return fmt.Errorf("unexpected primary-key column: %q", c1.Name)
}
pk = append(pk, c2)
}
t2.SetPrimaryKey(schema.NewPrimaryKey(pk...))
// Rest of indexes.
for _, idx1 := range t1.Indexes {
idx2 := schema.NewIndex(idx1.Name).
SetUnique(idx1.Unique)
if err := b.atIndex(idx1, t2, idx2); err != nil {
return err
}
t2.AddIndexes(idx2)
}
return nil
}
func setAtChecks(t1 *Table, t2 *schema.Table) {
if check := t1.Annotation.Check; check != "" {
t2.AddChecks(&schema.Check{
Expr: check,
})
}
if checks := t1.Annotation.Checks; len(t1.Annotation.Checks) > 0 {
names := make([]string, 0, len(checks))
for name := range checks {
names = append(names, name)
}
sort.Strings(names)
for _, name := range names {
t2.AddChecks(&schema.Check{
Name: name,
Expr: checks[name],
})
}
}
}

View File

@@ -104,18 +104,19 @@ func (f CreateFunc) Create(ctx context.Context, tables ...*Table) error {
// Migrate runs the migrations logic for the SQL dialects.
type Migrate struct {
sqlDialect
universalID bool // global unique ids.
dropColumns bool // drop deleted columns.
dropIndexes bool // drop deleted indexes.
withFixture bool // with fks rename fixture.
withForeignKeys bool // with foreign keys
typeRanges []string // types order by their range.
hooks []Hook // hooks to apply before creation
universalID bool // global unique ids.
dropColumns bool // drop deleted columns.
dropIndexes bool // drop deleted indexes.
withFixture bool // with fks rename fixture.
withForeignKeys bool // with foreign keys
atlas *atlasOptions // migrate with atlas.
typeRanges []string // types order by their range.
hooks []Hook // hooks to apply before creation
}
// NewMigrate create a migration structure for the given SQL driver.
func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) {
m := &Migrate{withForeignKeys: true}
m := &Migrate{withForeignKeys: true, atlas: &atlasOptions{}}
for _, opt := range opts {
opt(m)
}
@@ -129,6 +130,9 @@ func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) {
default:
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", d.Dialect())
}
if err := m.setupAtlas(); err != nil {
return nil, err
}
return m, nil
}
@@ -146,10 +150,12 @@ func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
m.setupTable(t)
}
var creator Creator = CreateFunc(m.create)
if m.atlas.enabled {
creator = CreateFunc(m.atCreate)
}
for i := len(m.hooks) - 1; i >= 0; i-- {
creator = m.hooks[i](creator)
}
return creator.Create(ctx, tables...)
}
@@ -502,12 +508,12 @@ func (m *Migrate) verify(ctx context.Context, tx dialect.Tx, t *Table) error {
if id == -1 {
return nil
}
return vr.verifyRange(ctx, tx, t, id<<32)
return vr.verifyRange(ctx, tx, t, int64(id<<32))
}
// types loads the type list from the database.
// If the table does not create, it will create one.
func (m *Migrate) types(ctx context.Context, tx dialect.Tx) error {
func (m *Migrate) types(ctx context.Context, tx dialect.ExecQuerier) error {
exists, err := m.tableExist(ctx, tx, TypeTable)
if err != nil {
return err
@@ -532,24 +538,31 @@ func (m *Migrate) types(ctx context.Context, tx dialect.Tx) error {
return sql.ScanSlice(rows, &m.typeRanges)
}
func (m *Migrate) allocPKRange(ctx context.Context, tx dialect.Tx, t *Table) error {
func (m *Migrate) allocPKRange(ctx context.Context, conn dialect.ExecQuerier, t *Table) error {
r, err := m.pkRange(ctx, conn, t)
if err != nil {
return err
}
return m.setRange(ctx, conn, t, r)
}
func (m *Migrate) pkRange(ctx context.Context, conn dialect.ExecQuerier, t *Table) (int64, error) {
id := indexOf(m.typeRanges, t.Name)
// If the table re-created, re-use its range from
// the past. otherwise, allocate a new id-range.
// the past. Otherwise, allocate a new id-range.
if id == -1 {
if len(m.typeRanges) > MaxTypes {
return fmt.Errorf("max number of types exceeded: %d", MaxTypes)
return 0, fmt.Errorf("max number of types exceeded: %d", MaxTypes)
}
query, args := sql.Dialect(m.Dialect()).
Insert(TypeTable).Columns("type").Values(t.Name).Query()
if err := tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("insert into type: %w", err)
if err := conn.Exec(ctx, query, args, nil); err != nil {
return 0, fmt.Errorf("insert into ent_types: %w", err)
}
id = len(m.typeRanges)
m.typeRanges = append(m.typeRanges, t.Name)
}
// Set the id offset for table.
return m.setRange(ctx, tx, t, id<<32)
return int64(id << 32), nil
}
// fkColumn returns the column name of a foreign-key.
@@ -630,9 +643,9 @@ func rollback(tx dialect.Tx, err error) error {
}
// exist checks if the given COUNT query returns a value >= 1.
func exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) {
func exist(ctx context.Context, conn dialect.ExecQuerier, query string, args ...interface{}) (bool, error) {
rows := &sql.Rows{}
if err := tx.Query(ctx, query, args, rows); err != nil {
if err := conn.Query(ctx, query, args, rows); err != nil {
return false, fmt.Errorf("reading schema information %w", err)
}
defer rows.Close()
@@ -653,12 +666,13 @@ func indexOf(a []string, s string) int {
}
type sqlDialect interface {
atBuilder
dialect.Driver
init(context.Context, dialect.Tx) error
init(context.Context, dialect.ExecQuerier) error
table(context.Context, dialect.Tx, string) (*Table, error)
tableExist(context.Context, dialect.Tx, string) (bool, error)
tableExist(context.Context, dialect.ExecQuerier, string) (bool, error)
fkExist(context.Context, dialect.Tx, string) (bool, error)
setRange(context.Context, dialect.Tx, *Table, int) error
setRange(context.Context, dialect.ExecQuerier, *Table, int64) error
dropIndex(context.Context, dialect.Tx, *Index, string) error
// table, column and index builder per dialect.
cType(*Column) string
@@ -683,5 +697,5 @@ type fkRenamer interface {
// verifyRanger wraps the method for verifying global-id range correctness.
type verifyRanger interface {
verifyRange(context.Context, dialect.Tx, *Table, int) error
verifyRange(context.Context, dialect.Tx, *Table, int64) error
}

View File

@@ -15,6 +15,10 @@ import (
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/mysql"
"ariga.io/atlas/sql/schema"
)
// MySQL is a MySQL migration driver.
@@ -25,9 +29,9 @@ type MySQL struct {
}
// init loads the MySQL version from the database for later use in the migration process.
func (d *MySQL) init(ctx context.Context, tx dialect.Tx) error {
func (d *MySQL) init(ctx context.Context, conn dialect.ExecQuerier) error {
rows := &sql.Rows{}
if err := tx.Query(ctx, "SHOW VARIABLES LIKE 'version'", []interface{}{}, rows); err != nil {
if err := conn.Query(ctx, "SHOW VARIABLES LIKE 'version'", []interface{}{}, rows); err != nil {
return fmt.Errorf("mysql: querying mysql version %w", err)
}
defer rows.Close()
@@ -45,13 +49,13 @@ func (d *MySQL) init(ctx context.Context, tx dialect.Tx) error {
return nil
}
func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
func (d *MySQL) tableExist(ctx context.Context, conn dialect.ExecQuerier, name string) (bool, error) {
query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
Where(sql.And(
d.matchSchema(),
sql.EQ("TABLE_NAME", name),
)).Query()
return exist(ctx, tx, query, args...)
return exist(ctx, conn, query, args...)
}
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
@@ -133,11 +137,11 @@ func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, t *Table) ([]*Index,
return idx, nil
}
func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error {
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", t.Name, value), []interface{}{}, nil)
func (d *MySQL) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error {
return conn.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", t.Name, value), []interface{}{}, nil)
}
func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, t *Table, expected int) error {
func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, t *Table, expected int64) error {
if expected == 0 {
return nil
}
@@ -776,3 +780,176 @@ func indexParts(idx *Index) map[string]uint {
}
return parts
}
// Atlas integration.
func (d *MySQL) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) {
return mysql.Open(&db{ExecQuerier: conn})
}
func (d *MySQL) atTable(t1 *Table, t2 *schema.Table) {
t2.SetCharset("utf8mb4").SetCollation("utf8mb4_bin")
if t1.Annotation == nil {
return
}
if charset := t1.Annotation.Charset; charset != "" {
t2.SetCharset(charset)
}
if collate := t1.Annotation.Collation; collate != "" {
t2.SetCollation(collate)
}
if opts := t1.Annotation.Options; opts != "" {
t2.AddAttrs(&mysql.CreateOptions{
V: opts,
})
}
// Check if the connected database supports the CHECK clause.
// For MySQL, is >= "8.0.16" and for MariaDB it is "10.2.1".
v1, v2 := d.version, "8.0.16"
if v, ok := d.mariadb(); ok {
v1, v2 = v, "10.2.1"
}
if compareVersions(v1, v2) >= 0 {
setAtChecks(t1, t2)
}
}
func (d *MySQL) atTypeC(c1 *Column, c2 *schema.Column) error {
if c1.SchemaType != nil && c1.SchemaType[dialect.MySQL] != "" {
t, err := mysql.ParseType(strings.ToLower(c1.SchemaType[dialect.MySQL]))
if err != nil {
return err
}
c2.Type.Type = t
return nil
}
var t schema.Type
switch c1.Type {
case field.TypeBool:
t = &schema.BoolType{T: "boolean"}
case field.TypeInt8:
t = &schema.IntegerType{T: mysql.TypeTinyInt}
case field.TypeUint8:
t = &schema.IntegerType{T: mysql.TypeTinyInt, Unsigned: true}
case field.TypeInt16:
t = &schema.IntegerType{T: mysql.TypeSmallInt}
case field.TypeUint16:
t = &schema.IntegerType{T: mysql.TypeSmallInt, Unsigned: true}
case field.TypeInt32:
t = &schema.IntegerType{T: mysql.TypeInt}
case field.TypeUint32:
t = &schema.IntegerType{T: mysql.TypeInt, Unsigned: true}
case field.TypeInt, field.TypeInt64:
t = &schema.IntegerType{T: mysql.TypeBigInt}
case field.TypeUint, field.TypeUint64:
t = &schema.IntegerType{T: mysql.TypeBigInt, Unsigned: true}
case field.TypeBytes:
size := int64(math.MaxUint16)
if c1.Size > 0 {
size = c1.Size
}
switch {
case size <= math.MaxUint8:
t = &schema.BinaryType{T: mysql.TypeTinyBlob}
case size <= math.MaxUint16:
t = &schema.BinaryType{T: mysql.TypeBlob}
case size < 1<<24:
t = &schema.BinaryType{T: mysql.TypeMediumBlob}
case size <= math.MaxUint32:
t = &schema.BinaryType{T: mysql.TypeLongBlob}
}
case field.TypeJSON:
t = &schema.JSONType{T: mysql.TypeJSON}
if compareVersions(d.version, "5.7.8") == -1 {
t = &schema.BinaryType{T: mysql.TypeLongBlob}
}
case field.TypeString:
size := c1.Size
if size == 0 {
size = d.defaultSize(c1)
}
switch {
case c1.typ == "tinytext", c1.typ == "text":
t = &schema.StringType{T: c1.typ}
case size <= math.MaxUint16:
t = &schema.StringType{T: mysql.TypeVarchar, Size: int(size)}
case size == 1<<24-1:
t = &schema.StringType{T: mysql.TypeMediumText}
default:
t = &schema.StringType{T: mysql.TypeLongText}
}
case field.TypeFloat32, field.TypeFloat64:
t = &schema.FloatType{T: c1.scanTypeOr(mysql.TypeDouble)}
case field.TypeTime:
t = &schema.TimeType{T: c1.scanTypeOr(mysql.TypeTimestamp)}
// In MariaDB or in MySQL < v8.0.2, the TIMESTAMP column has both `DEFAULT CURRENT_TIMESTAMP`
// and `ON UPDATE CURRENT_TIMESTAMP` if neither is specified explicitly. this behavior is
// suppressed if the column is defined with a `DEFAULT` clause or with the `NULL` attribute.
if _, maria := d.mariadb(); maria || compareVersions(d.version, "8.0.2") == -1 && c1.Default == nil {
c2.SetNull(c1.Attr == "")
}
case field.TypeEnum:
t = &schema.EnumType{T: mysql.TypeEnum, Values: c1.Enums}
case field.TypeUUID:
// "CHAR(X) BINARY" is treated as "CHAR(X) COLLATE latin1_bin", and in MySQL < 8,
// and "COLLATE utf8mb4_bin" in MySQL >= 8. However we already set the table to
t = &schema.StringType{T: mysql.TypeChar, Size: 36}
c2.SetCollation("utf8mb4_bin")
default:
t, err := mysql.ParseType(strings.ToLower(c1.typ))
if err != nil {
return err
}
c2.Type.Type = t
}
c2.Type.Type = t
return nil
}
func (d *MySQL) atUniqueC(t1 *Table, c1 *Column, t2 *schema.Table, c2 *schema.Column) {
// For UNIQUE columns, MySQL create an implicit index
// named as the column with an extra index in case the
// name is already taken (<e.g. c>, <c_2>, <c_3>, ...).
for _, idx := range t1.Indexes {
// Index also defined explicitly, and will be add in atIndexes.
if idx.Unique && d.atImplicitIndexName(idx, c1) {
return
}
}
t2.AddIndexes(schema.NewUniqueIndex(c1.Name).AddColumns(c2))
}
func (d *MySQL) atIncrementC(_ *schema.Table, c *schema.Column) {
c.AddAttrs(&mysql.AutoIncrement{})
}
func (d *MySQL) atIncrementT(t *schema.Table, v int64) {
t.AddAttrs(&mysql.AutoIncrement{V: v})
}
func (d *MySQL) atImplicitIndexName(idx *Index, c1 *Column) bool {
if idx.Name == c1.Name {
return true
}
if !strings.HasPrefix(idx.Name, c1.Name+"_") {
return false
}
i, err := strconv.ParseInt(strings.TrimLeft(idx.Name, c1.Name+"_"), 10, 64)
return err == nil && i > 1
}
func (d *MySQL) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) error {
prefix := indexParts(idx1)
for _, c1 := range idx1.Columns {
c2, ok := t2.Column(c1.Name)
if !ok {
return fmt.Errorf("unexpected index %q column: %q", idx1.Name, c1.Name)
}
part := &schema.IndexPart{C: c2}
if v, ok := prefix[c1.Name]; ok {
part.AddAttrs(&mysql.SubPart{Len: int(v)})
}
idx2.AddParts(part)
}
return nil
}

View File

@@ -14,6 +14,10 @@ import (
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/postgres"
"ariga.io/atlas/sql/schema"
)
// Postgres is a postgres migration driver.
@@ -25,7 +29,7 @@ type Postgres struct {
// init loads the Postgres version from the database for later use in the migration process.
// It returns an error if the server version is lower than v10.
func (d *Postgres) init(ctx context.Context, tx dialect.Tx) error {
func (d *Postgres) init(ctx context.Context, tx dialect.ExecQuerier) error {
rows := &sql.Rows{}
if err := tx.Query(ctx, "SHOW server_version_num", []interface{}{}, rows); err != nil {
return fmt.Errorf("querying server version %w", err)
@@ -52,14 +56,14 @@ func (d *Postgres) init(ctx context.Context, tx dialect.Tx) error {
}
// tableExist checks if a table exists in the database and current schema.
func (d *Postgres) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
func (d *Postgres) tableExist(ctx context.Context, conn dialect.ExecQuerier, name string) (bool, error) {
query, args := sql.Dialect(dialect.Postgres).
Select(sql.Count("*")).From(sql.Table("tables").Schema("information_schema")).
Where(sql.And(
d.matchSchema(),
sql.EQ("table_name", name),
)).Query()
return exist(ctx, tx, query, args...)
return exist(ctx, conn, query, args...)
}
// tableExist checks if a foreign-key exists in the current schema.
@@ -75,7 +79,7 @@ func (d *Postgres) fkExist(ctx context.Context, tx dialect.Tx, name string) (boo
}
// setRange sets restart the identity column to the given offset. Used by the universal-id option.
func (d *Postgres) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error {
func (d *Postgres) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error {
if value == 0 {
value = 1 // RESTART value cannot be < 1.
}
@@ -83,7 +87,7 @@ func (d *Postgres) setRange(ctx context.Context, tx dialect.Tx, t *Table, value
if len(t.PrimaryKey) == 1 {
pk = t.PrimaryKey[0].Name
}
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s RESTART WITH %d", t.Name, pk, value), []interface{}{}, nil)
return conn.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s RESTART WITH %d", t.Name, pk, value), []interface{}{}, nil)
}
// table loads the current table description from the database.
@@ -664,3 +668,114 @@ WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_name = '%s'
order by constraint_name, kcu.ordinal_position;
`
// Atlas integration.
func (d *Postgres) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) {
return postgres.Open(&db{ExecQuerier: conn})
}
func (d *Postgres) atTable(t1 *Table, t2 *schema.Table) {
if t1.Annotation != nil {
setAtChecks(t1, t2)
}
}
func (d *Postgres) atTypeC(c1 *Column, c2 *schema.Column) error {
if c1.SchemaType != nil && c1.SchemaType[dialect.Postgres] != "" {
t, err := postgres.ParseType(strings.ToLower(c1.SchemaType[dialect.Postgres]))
if err != nil {
return err
}
c2.Type.Type = t
return nil
}
var t schema.Type
switch c1.Type {
case field.TypeBool:
t = &schema.BoolType{T: postgres.TypeBoolean}
case field.TypeUint8, field.TypeInt8, field.TypeInt16, field.TypeUint16:
t = &schema.IntegerType{T: postgres.TypeSmallInt}
case field.TypeInt32, field.TypeUint32:
t = &schema.IntegerType{T: postgres.TypeInt}
case field.TypeInt, field.TypeUint, field.TypeInt64, field.TypeUint64:
t = &schema.IntegerType{T: postgres.TypeBigInt}
case field.TypeFloat32:
t = &schema.FloatType{T: c1.scanTypeOr(postgres.TypeReal)}
case field.TypeFloat64:
t = &schema.FloatType{T: c1.scanTypeOr(postgres.TypeDouble)}
case field.TypeBytes:
t = &schema.BinaryType{T: postgres.TypeBytea}
case field.TypeUUID:
t = &postgres.UUIDType{T: postgres.TypeUUID}
case field.TypeJSON:
t = &schema.JSONType{T: postgres.TypeJSONB}
case field.TypeString:
t = &schema.StringType{T: postgres.TypeVarChar}
if c1.Size > maxCharSize {
t = &schema.StringType{T: postgres.TypeText}
}
case field.TypeTime:
t = &schema.TimeType{T: c1.scanTypeOr(postgres.TypeTimestampWTZ)}
case field.TypeEnum:
// Although atlas supports enum types, we keep backwards compatibility
// with previous versions of ent and use varchar (see cType).
t = &schema.StringType{T: postgres.TypeVarChar}
case field.TypeOther:
t = &schema.UnsupportedType{T: c1.typ}
default:
t, err := postgres.ParseType(strings.ToLower(c1.typ))
if err != nil {
return err
}
c2.Type.Type = t
}
c2.Type.Type = t
return nil
}
func (d *Postgres) atUniqueC(t1 *Table, c1 *Column, t2 *schema.Table, c2 *schema.Column) {
// For UNIQUE columns, PostgreSQL create an implicit index named
// "<table>_<column>_key<i>".
for _, idx := range t1.Indexes {
// Index also defined explicitly, and will be add in atIndexes.
if idx.Unique && d.atImplicitIndexName(idx, t1, c1) {
return
}
}
t2.AddIndexes(schema.NewUniqueIndex(fmt.Sprintf("%s_%s_key", t1.Name, c1.Name)).AddColumns(c2))
}
func (d *Postgres) atImplicitIndexName(idx *Index, t1 *Table, c1 *Column) bool {
p := fmt.Sprintf("%s_%s_key", t1.Name, c1.Name)
if idx.Name == p {
return true
}
i, err := strconv.ParseInt(strings.TrimPrefix(idx.Name, p), 10, 64)
return err == nil && i > 0
}
func (d *Postgres) atIncrementC(t *schema.Table, c *schema.Column) {
id := &postgres.Identity{}
for _, a := range t.Attrs {
if a, ok := a.(*postgres.Identity); ok {
id = a
}
}
c.AddAttrs(id)
}
func (d *Postgres) atIncrementT(t *schema.Table, v int64) {
t.AddAttrs(&postgres.Identity{Sequence: &postgres.Sequence{Start: v}})
}
func (d *Postgres) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) error {
for _, c1 := range idx1.Columns {
c2, ok := t2.Column(c1.Name)
if !ok {
return fmt.Errorf("unexpected index %q column: %q", idx1.Name, c1.Name)
}
idx2.AddParts(&schema.IndexPart{C: c2})
}
return nil
}

View File

@@ -7,11 +7,16 @@ package schema
import (
"context"
"fmt"
"strconv"
"strings"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/schema"
"ariga.io/atlas/sql/sqlite"
)
// SQLite is an SQLite migration driver.
@@ -21,7 +26,7 @@ type SQLite struct {
}
// init makes sure that foreign_keys support is enabled.
func (d *SQLite) init(ctx context.Context, tx dialect.Tx) error {
func (d *SQLite) init(ctx context.Context, tx dialect.ExecQuerier) error {
on, err := exist(ctx, tx, "PRAGMA foreign_keys")
if err != nil {
return fmt.Errorf("sqlite: check foreign_keys pragma: %w", err)
@@ -34,7 +39,7 @@ func (d *SQLite) init(ctx context.Context, tx dialect.Tx) error {
return nil
}
func (d *SQLite) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
func (d *SQLite) tableExist(ctx context.Context, conn dialect.ExecQuerier, name string) (bool, error) {
query, args := sql.Select().Count().
From(sql.Table("sqlite_master")).
Where(sql.And(
@@ -42,7 +47,7 @@ func (d *SQLite) tableExist(ctx context.Context, tx dialect.Tx, name string) (bo
sql.EQ("name", name),
)).
Query()
return exist(ctx, tx, query, args...)
return exist(ctx, conn, query, args...)
}
// setRange sets the start value of table PK.
@@ -50,12 +55,12 @@ func (d *SQLite) tableExist(ctx context.Context, tx dialect.Tx, name string) (bo
// whenever a table that contains an AUTOINCREMENT column is created. However, it populates to it a rows (for tables)
// only after the first insertion. Therefore, we check. If a record (for the given table) already exists in the "sqlite_sequence"
// table, we updated it. Otherwise, we insert a new value.
func (d *SQLite) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error {
func (d *SQLite) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error {
query, args := sql.Select().Count().
From(sql.Table("sqlite_sequence")).
Where(sql.EQ("name", t.Name)).
Query()
exists, err := exist(ctx, tx, query, args...)
exists, err := exist(ctx, conn, query, args...)
switch {
case err != nil:
return err
@@ -64,7 +69,7 @@ func (d *SQLite) setRange(ctx context.Context, tx dialect.Tx, t *Table, value in
default: // !exists
query, args = sql.Insert("sqlite_sequence").Columns("name", "seq").Values(t.Name, value).Query()
}
return tx.Exec(ctx, query, args, nil)
return conn.Exec(ctx, query, args, nil)
}
func (d *SQLite) tBuilder(t *Table) *sql.TableBuilder {
@@ -345,3 +350,102 @@ func (d *SQLite) needsConversion(old, new *Column) bool {
c1, c2 := d.cType(old), d.cType(new)
return c1 != c2 && old.typ != c2
}
// Atlas integration.
func (d *SQLite) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) {
return sqlite.Open(&db{ExecQuerier: conn})
}
func (d *SQLite) atTable(t1 *Table, t2 *schema.Table) {
if t1.Annotation != nil {
setAtChecks(t1, t2)
}
}
func (d *SQLite) atTypeC(c1 *Column, c2 *schema.Column) error {
if c1.SchemaType != nil && c1.SchemaType[dialect.SQLite] != "" {
t, err := sqlite.ParseType(strings.ToLower(c1.SchemaType[dialect.SQLite]))
if err != nil {
return err
}
c2.Type.Type = t
return nil
}
var t schema.Type
switch c1.Type {
case field.TypeBool:
t = &schema.BoolType{T: "bool"}
case field.TypeInt8, field.TypeUint8, field.TypeInt16, field.TypeUint16, field.TypeInt32,
field.TypeUint32, field.TypeUint, field.TypeInt, field.TypeInt64, field.TypeUint64:
t = &schema.IntegerType{T: sqlite.TypeInteger}
case field.TypeBytes:
t = &schema.BinaryType{T: sqlite.TypeBlob}
case field.TypeString, field.TypeEnum:
// SQLite does not impose any length restrictions on
// the length of strings, BLOBs or numeric values.
t = &schema.StringType{T: sqlite.TypeText}
case field.TypeFloat32, field.TypeFloat64:
t = &schema.FloatType{T: sqlite.TypeReal}
case field.TypeTime:
t = &schema.TimeType{T: "datetime"}
case field.TypeJSON:
t = &schema.JSONType{T: "json"}
case field.TypeUUID:
t = &sqlite.UUIDType{T: "uuid"}
case field.TypeOther:
t = &schema.UnsupportedType{T: c1.typ}
default:
t, err := sqlite.ParseType(strings.ToLower(c1.typ))
if err != nil {
return err
}
c2.Type.Type = t
}
c2.Type.Type = t
return nil
}
func (d *SQLite) atUniqueC(t1 *Table, c1 *Column, t2 *schema.Table, c2 *schema.Column) {
// For UNIQUE columns, SQLite create an implicit index named
// "sqlite_autoindex_<table>_<i>". Ent uses the MySQL approach
// in its migration, and name these indexes as the columns.
for _, idx := range t1.Indexes {
// Index also defined explicitly, and will be add in atIndexes.
if idx.Unique && d.atImplicitIndexName(idx, t1, c1) {
return
}
}
t2.AddIndexes(schema.NewUniqueIndex(c1.Name).AddColumns(c2))
}
func (d *SQLite) atImplicitIndexName(idx *Index, t1 *Table, c1 *Column) bool {
if idx.Name == c1.Name {
return true
}
p := fmt.Sprintf("sqlite_autoindex_%s_", t1.Name)
if !strings.HasPrefix(idx.Name, p) {
return false
}
i, err := strconv.ParseInt(strings.TrimPrefix(idx.Name, p), 10, 64)
return err == nil && i > 0
}
func (d *SQLite) atIncrementC(_ *schema.Table, c *schema.Column) {
c.AddAttrs(&sqlite.AutoIncrement{})
}
func (d *SQLite) atIncrementT(t *schema.Table, v int64) {
t.AddAttrs(&sqlite.AutoIncrement{Seq: v})
}
func (d *SQLite) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) error {
for _, c1 := range idx1.Columns {
c2, ok := t2.Column(c1.Name)
if !ok {
return fmt.Errorf("unexpected index %q column: %q", idx1.Name, c1.Name)
}
idx2.AddParts(&schema.IndexPart{C: c2})
}
return nil
}