mirror of
https://github.com/ent/ent.git
synced 2026-05-04 00:20:58 +03:00
* dialect/sql/schema: file based type store This PR adds support for a file based type storage when using versioned migrations. The file called `.ent_types` is written to the migration directory alongside the migration files and will be kept in sync for every migration file generation run. In order to not break existing code, where the type storage might differ for different deployments, global unique ID must be enabled by using a new option. This will also be raised as an error to the user when attempting to use versioned migrations and global unique ID. Documentation will be added to this PR once feedback on the code is gathered. * apply CR * fix tests * change format of types file to exclude it from atlas.sum file * docs and drift test * apply CR
980 lines
28 KiB
Go
980 lines
28 KiB
Go
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package schema
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"entgo.io/ent/dialect"
|
|
"entgo.io/ent/dialect/entsql"
|
|
"entgo.io/ent/dialect/sql"
|
|
"entgo.io/ent/schema/field"
|
|
|
|
"ariga.io/atlas/sql/migrate"
|
|
"ariga.io/atlas/sql/mysql"
|
|
"ariga.io/atlas/sql/schema"
|
|
)
|
|
|
|
// MySQL is a MySQL migration driver.
|
|
type MySQL struct {
|
|
dialect.Driver
|
|
schema string
|
|
version string
|
|
}
|
|
|
|
// init loads the MySQL version from the database for later use in the migration process.
|
|
func (d *MySQL) init(ctx context.Context, conn dialect.ExecQuerier) error {
|
|
rows := &sql.Rows{}
|
|
if err := conn.Query(ctx, "SHOW VARIABLES LIKE 'version'", []interface{}{}, rows); err != nil {
|
|
return fmt.Errorf("mysql: querying mysql version %w", err)
|
|
}
|
|
defer rows.Close()
|
|
if !rows.Next() {
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
return fmt.Errorf("mysql: version variable was not found")
|
|
}
|
|
version := make([]string, 2)
|
|
if err := rows.Scan(&version[0], &version[1]); err != nil {
|
|
return fmt.Errorf("mysql: scanning mysql version: %w", err)
|
|
}
|
|
d.version = version[1]
|
|
return nil
|
|
}
|
|
|
|
func (d *MySQL) tableExist(ctx context.Context, conn dialect.ExecQuerier, name string) (bool, error) {
|
|
query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
|
|
Where(sql.And(
|
|
d.matchSchema(),
|
|
sql.EQ("TABLE_NAME", name),
|
|
)).Query()
|
|
return exist(ctx, conn, query, args...)
|
|
}
|
|
|
|
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
|
query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLE_CONSTRAINTS").Schema("INFORMATION_SCHEMA")).
|
|
Where(sql.And(
|
|
d.matchSchema(),
|
|
sql.EQ("CONSTRAINT_TYPE", "FOREIGN KEY"),
|
|
sql.EQ("CONSTRAINT_NAME", name),
|
|
)).Query()
|
|
return exist(ctx, tx, query, args...)
|
|
}
|
|
|
|
// table loads the current table description from the database.
|
|
func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
|
|
rows := &sql.Rows{}
|
|
query, args := sql.Select(
|
|
"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name",
|
|
"numeric_precision", "numeric_scale",
|
|
).
|
|
From(sql.Table("COLUMNS").Schema("INFORMATION_SCHEMA")).
|
|
Where(sql.And(
|
|
d.matchSchema(),
|
|
sql.EQ("TABLE_NAME", name)),
|
|
).Query()
|
|
if err := tx.Query(ctx, query, args, rows); err != nil {
|
|
return nil, fmt.Errorf("mysql: reading table description %w", err)
|
|
}
|
|
// Call Close in cases of failures (Close is idempotent).
|
|
defer rows.Close()
|
|
t := NewTable(name)
|
|
for rows.Next() {
|
|
c := &Column{}
|
|
if err := d.scanColumn(c, rows); err != nil {
|
|
return nil, fmt.Errorf("mysql: %w", err)
|
|
}
|
|
t.AddColumn(c)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := rows.Close(); err != nil {
|
|
return nil, fmt.Errorf("mysql: closing rows %w", err)
|
|
}
|
|
indexes, err := d.indexes(ctx, tx, t)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Add and link indexes to table columns.
|
|
for _, idx := range indexes {
|
|
t.addIndex(idx)
|
|
}
|
|
if _, ok := d.mariadb(); ok {
|
|
if err := d.normalizeJSON(ctx, tx, t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
// table loads the table indexes from the database.
|
|
func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, t *Table) ([]*Index, error) {
|
|
rows := &sql.Rows{}
|
|
query, args := sql.Select("index_name", "column_name", "sub_part", "non_unique", "seq_in_index").
|
|
From(sql.Table("STATISTICS").Schema("INFORMATION_SCHEMA")).
|
|
Where(sql.And(
|
|
d.matchSchema(),
|
|
sql.EQ("TABLE_NAME", t.Name),
|
|
)).
|
|
OrderBy("index_name", "seq_in_index").
|
|
Query()
|
|
if err := tx.Query(ctx, query, args, rows); err != nil {
|
|
return nil, fmt.Errorf("mysql: reading index description %w", err)
|
|
}
|
|
defer rows.Close()
|
|
idx, err := d.scanIndexes(rows, t)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("mysql: %w", err)
|
|
}
|
|
return idx, nil
|
|
}
|
|
|
|
func (d *MySQL) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error {
|
|
return conn.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", t.Name, value), []interface{}{}, nil)
|
|
}
|
|
|
|
func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, t *Table, expected int64) error {
|
|
if expected == 0 {
|
|
return nil
|
|
}
|
|
rows := &sql.Rows{}
|
|
query, args := sql.Select("AUTO_INCREMENT").
|
|
From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
|
|
Where(sql.And(
|
|
d.matchSchema(),
|
|
sql.EQ("TABLE_NAME", t.Name),
|
|
)).
|
|
Query()
|
|
if err := tx.Query(ctx, query, args, rows); err != nil {
|
|
return fmt.Errorf("mysql: query auto_increment %w", err)
|
|
}
|
|
// Call Close in cases of failures (Close is idempotent).
|
|
defer rows.Close()
|
|
actual := &sql.NullInt64{}
|
|
if err := sql.ScanOne(rows, actual); err != nil {
|
|
return fmt.Errorf("mysql: scan auto_increment %w", err)
|
|
}
|
|
if err := rows.Close(); err != nil {
|
|
return err
|
|
}
|
|
// Table is empty and auto-increment is not configured. This can happen
|
|
// because MySQL (< 8.0) stores the auto-increment counter in main memory
|
|
// (not persistent), and the value is reset on restart (if table is empty).
|
|
if actual.Int64 <= 1 {
|
|
return d.setRange(ctx, tx, t, expected)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// tBuilder returns the MySQL DSL query for table creation.
|
|
func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder {
|
|
b := sql.CreateTable(t.Name).IfNotExists()
|
|
for _, c := range t.Columns {
|
|
b.Column(d.addColumn(c))
|
|
}
|
|
for _, pk := range t.PrimaryKey {
|
|
b.PrimaryKey(pk.Name)
|
|
}
|
|
// Charset and collation config on MySQL table.
|
|
// These options can be overridden by the entsql annotation.
|
|
b.Charset("utf8mb4").Collate("utf8mb4_bin")
|
|
if t.Annotation != nil {
|
|
if charset := t.Annotation.Charset; charset != "" {
|
|
b.Charset(charset)
|
|
}
|
|
if collate := t.Annotation.Collation; collate != "" {
|
|
b.Collate(collate)
|
|
}
|
|
if opts := t.Annotation.Options; opts != "" {
|
|
b.Options(opts)
|
|
}
|
|
addChecks(b, t.Annotation)
|
|
}
|
|
return b
|
|
}
|
|
|
|
// cType returns the MySQL string type for the given column.
|
|
func (d *MySQL) cType(c *Column) (t string) {
|
|
if c.SchemaType != nil && c.SchemaType[dialect.MySQL] != "" {
|
|
// MySQL returns the column type lower cased.
|
|
return strings.ToLower(c.SchemaType[dialect.MySQL])
|
|
}
|
|
switch c.Type {
|
|
case field.TypeBool:
|
|
t = "boolean"
|
|
case field.TypeInt8:
|
|
t = "tinyint"
|
|
case field.TypeUint8:
|
|
t = "tinyint unsigned"
|
|
case field.TypeInt16:
|
|
t = "smallint"
|
|
case field.TypeUint16:
|
|
t = "smallint unsigned"
|
|
case field.TypeInt32:
|
|
t = "int"
|
|
case field.TypeUint32:
|
|
t = "int unsigned"
|
|
case field.TypeInt, field.TypeInt64:
|
|
t = "bigint"
|
|
case field.TypeUint, field.TypeUint64:
|
|
t = "bigint unsigned"
|
|
case field.TypeBytes:
|
|
size := int64(math.MaxUint16)
|
|
if c.Size > 0 {
|
|
size = c.Size
|
|
}
|
|
switch {
|
|
case size <= math.MaxUint8:
|
|
t = "tinyblob"
|
|
case size <= math.MaxUint16:
|
|
t = "blob"
|
|
case size < 1<<24:
|
|
t = "mediumblob"
|
|
case size <= math.MaxUint32:
|
|
t = "longblob"
|
|
}
|
|
case field.TypeJSON:
|
|
t = "json"
|
|
if compareVersions(d.version, "5.7.8") == -1 {
|
|
t = "longblob"
|
|
}
|
|
case field.TypeString:
|
|
size := c.Size
|
|
if size == 0 {
|
|
size = d.defaultSize(c)
|
|
}
|
|
switch {
|
|
case c.typ == "tinytext", c.typ == "text":
|
|
t = c.typ
|
|
case size <= math.MaxUint16:
|
|
t = fmt.Sprintf("varchar(%d)", size)
|
|
case size == 1<<24-1:
|
|
t = "mediumtext"
|
|
default:
|
|
t = "longtext"
|
|
}
|
|
case field.TypeFloat32, field.TypeFloat64:
|
|
t = c.scanTypeOr("double")
|
|
case field.TypeTime:
|
|
t = c.scanTypeOr("timestamp")
|
|
// In MariaDB or in MySQL < v8.0.2, the TIMESTAMP column has both `DEFAULT CURRENT_TIMESTAMP`
|
|
// and `ON UPDATE CURRENT_TIMESTAMP` if neither is specified explicitly. this behavior is
|
|
// suppressed if the column is defined with a `DEFAULT` clause or with the `NULL` attribute.
|
|
if _, maria := d.mariadb(); maria || compareVersions(d.version, "8.0.2") == -1 && c.Default == nil {
|
|
c.Nullable = c.Attr == ""
|
|
}
|
|
case field.TypeEnum:
|
|
values := make([]string, len(c.Enums))
|
|
for i, e := range c.Enums {
|
|
values[i] = fmt.Sprintf("'%s'", e)
|
|
}
|
|
t = fmt.Sprintf("enum(%s)", strings.Join(values, ", "))
|
|
case field.TypeUUID:
|
|
t = "char(36) binary"
|
|
case field.TypeOther:
|
|
t = c.typ
|
|
default:
|
|
panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name))
|
|
}
|
|
return t
|
|
}
|
|
|
|
// addColumn returns the DSL query for adding the given column to a table.
|
|
// The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable].
|
|
func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder {
|
|
b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr)
|
|
c.unique(b)
|
|
if c.Increment {
|
|
b.Attr("AUTO_INCREMENT")
|
|
}
|
|
c.nullable(b)
|
|
c.defaultValue(b)
|
|
if c.Collation != "" {
|
|
b.Attr("COLLATE " + c.Collation)
|
|
}
|
|
if c.Type == field.TypeJSON {
|
|
// Manually add a `CHECK` clause for older versions of MariaDB for validating the
|
|
// JSON documents. This constraint is automatically included from version 10.4.3.
|
|
if version, ok := d.mariadb(); ok && compareVersions(version, "10.4.3") == -1 {
|
|
b.Check(func(b *sql.Builder) {
|
|
b.WriteString("JSON_VALID(").Ident(c.Name).WriteByte(')')
|
|
})
|
|
}
|
|
}
|
|
return b
|
|
}
|
|
|
|
// addIndex returns the querying for adding an index to MySQL.
|
|
func (d *MySQL) addIndex(i *Index, table string) *sql.IndexBuilder {
|
|
idx := sql.CreateIndex(i.Name).Table(table)
|
|
if i.Unique {
|
|
idx.Unique()
|
|
}
|
|
parts := indexParts(i)
|
|
for _, c := range i.Columns {
|
|
part, ok := parts[c.Name]
|
|
if !ok || part == 0 {
|
|
idx.Column(c.Name)
|
|
} else {
|
|
idx.Column(fmt.Sprintf("%s(%d)", idx.Builder.Quote(c.Name), part))
|
|
}
|
|
}
|
|
return idx
|
|
}
|
|
|
|
// dropIndex drops a MySQL index.
|
|
func (d *MySQL) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error {
|
|
query, args := idx.DropBuilder(table).Query()
|
|
return tx.Exec(ctx, query, args, nil)
|
|
}
|
|
|
|
// prepare runs preparation work that must be done before the change-set of
// table can be applied. It only looks at single-column indexes scheduled for
// dropping (change.index.drop):
//   - If the index's column is dropped as well, the foreign-key constraint on
//     that column is dropped first (only when exactly one is found).
//   - If the index is unique, its column is kept, and no other single-column
//     index exists for that column, a plain index named after the column's
//     foreign key is created (only when exactly one foreign key is found).
//     NOTE(review): presumably so the foreign key keeps a supporting index,
//     which MySQL requires for foreign-key columns.
//
// Indexes with more than one column are skipped; an index without columns
// is reported as an error.
func (d *MySQL) prepare(ctx context.Context, tx dialect.Tx, change *changes, table string) error {
	for _, idx := range change.index.drop {
		switch n := len(idx.columns); {
		case n == 0:
			return fmt.Errorf("index %q has no columns", idx.Name)
		case n > 1:
			continue // not a foreign-key index.
		}
		// qr stays nil when no statement needs to be executed for this index.
		var qr sql.Querier
	Switch:
		switch col, ok := change.dropColumn(idx.columns[0]); {
		// If both the index and the column need to be dropped, the foreign-key
		// constraint that is associated with them needs to be dropped as well.
		case ok:
			names, err := d.fkNames(ctx, tx, table, col.Name)
			if err != nil {
				return err
			}
			if len(names) == 1 {
				qr = sql.AlterTable(table).DropForeignKey(names[0])
			}
		// If the uniqueness was dropped from a foreign-key column,
		// create a "simple index" if no other index exists for it.
		case !ok && idx.Unique && len(idx.Columns) > 0:
			col := idx.Columns[0]
			for _, idx2 := range col.indexes {
				// Another single-column index already covers the column:
				// leave the switch without creating anything.
				if idx2 != idx && len(idx2.columns) == 1 {
					break Switch
				}
			}
			names, err := d.fkNames(ctx, tx, table, col.Name)
			if err != nil {
				return err
			}
			if len(names) == 1 {
				qr = sql.CreateIndex(names[0]).Table(table).Columns(col.Name)
			}
		}
		if qr != nil {
			query, args := qr.Query()
			if err := tx.Exec(ctx, query, args, nil); err != nil {
				return err
			}
		}
	}
	return nil
}
|
|
|
|
// scanColumn fills c from the current row of rows. The row is expected to hold
// the ten columns selected by table, in order: column_name, column_type,
// is_nullable, column_key, column_default, extra, character_set_name,
// collation_name, numeric_precision, numeric_scale.
// It maps the MySQL column type to an ent field type (plus size, enum values or
// schema type where relevant) and, when a default exists, parses it with
// c.ScanDefault. Unknown column types are reported as errors.
func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
	var (
		nullable         sql.NullString
		defaults         sql.NullString
		numericPrecision sql.NullInt64
		numericScale     sql.NullInt64
	)
	// character_set_name and collation_name are scanned into throwaway values.
	if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &sql.NullString{}, &sql.NullString{}, &numericPrecision, &numericScale); err != nil {
		return fmt.Errorf("scanning column description: %w", err)
	}
	c.Unique = c.UniqueKey()
	if nullable.Valid {
		c.Nullable = nullable.String == "YES"
	}
	if c.typ == "" {
		return fmt.Errorf("missing type information for column %q", c.Name)
	}
	parts, size, unsigned, err := parseColumn(c.typ)
	if err != nil {
		return err
	}
	switch parts[0] {
	// mediumint has no dedicated field type and is read as a 32-bit integer.
	case "mediumint", "int":
		c.Type = field.TypeInt32
		if unsigned {
			c.Type = field.TypeUint32
		}
	case "smallint":
		c.Type = field.TypeInt16
		if unsigned {
			c.Type = field.TypeUint16
		}
	case "bigint":
		c.Type = field.TypeInt64
		if unsigned {
			c.Type = field.TypeUint64
		}
	case "tinyint":
		// A tinyint with display width 1 is treated as a boolean.
		switch {
		case size == 1:
			c.Type = field.TypeBool
		case unsigned:
			c.Type = field.TypeUint8
		default:
			c.Type = field.TypeInt8
		}
	case "double", "float":
		c.Type = field.TypeFloat64
	case "numeric", "decimal":
		c.Type = field.TypeFloat64
		// If precision is specified then we should take that into account.
		if numericPrecision.Valid {
			schemaType := fmt.Sprintf("%s(%d,%d)", parts[0], numericPrecision.Int64, numericScale.Int64)
			c.SchemaType = map[string]string{dialect.MySQL: schemaType}
		}
	case "time", "timestamp", "date", "datetime":
		c.Type = field.TypeTime
		// The mapping from schema defaults to database
		// defaults is not supported for TypeTime fields.
		defaults = sql.NullString{}
	case "tinyblob":
		c.Size = math.MaxUint8
		c.Type = field.TypeBytes
	case "blob":
		c.Size = math.MaxUint16
		c.Type = field.TypeBytes
	case "mediumblob":
		c.Size = 1<<24 - 1
		c.Type = field.TypeBytes
	case "longblob":
		c.Size = math.MaxUint32
		c.Type = field.TypeBytes
	case "binary", "varbinary":
		c.Type = field.TypeBytes
		c.Size = size
	case "varchar":
		c.Type = field.TypeString
		c.Size = size
	case "text":
		c.Size = math.MaxUint16
		c.Type = field.TypeString
	case "mediumtext":
		c.Size = 1<<24 - 1
		c.Type = field.TypeString
	case "longtext":
		// NOTE(review): recorded as math.MaxInt32; cType maps this size back
		// to "longtext".
		c.Size = math.MaxInt32
		c.Type = field.TypeString
	case "json":
		c.Type = field.TypeJSON
	case "enum":
		c.Type = field.TypeEnum
		// Parse the enum values according to the MySQL format.
		// github.com/mysql/mysql-server/blob/8.0/sql/field.cc#Field_enum::sql_type
		values := strings.TrimSuffix(strings.TrimPrefix(c.typ, "enum("), ")")
		if values == "" {
			return fmt.Errorf("mysql: unexpected enum type: %q", c.typ)
		}
		// Values are single-quoted and comma-separated: 'a','b','c'.
		parts := strings.Split(values, "','")
		for i := range parts {
			c.Enums = append(c.Enums, strings.Trim(parts[i], "'"))
		}
	case "char":
		c.Type = field.TypeOther
		// UUID field has length of 36 characters (32 alphanumeric characters and 4 hyphens).
		if size == 36 {
			c.Type = field.TypeUUID
		}
	case "point", "geometry", "linestring", "polygon":
		c.Type = field.TypeOther
	default:
		return fmt.Errorf("unknown column type %q for version %q", parts[0], d.version)
	}
	if defaults.Valid {
		return c.ScanDefault(defaults.String)
	}
	return nil
}
|
|
|
|
// scanIndexes scans sql.Rows into an Indexes list. The query for returning the rows,
|
|
// should return the following 5 columns: INDEX_NAME, COLUMN_NAME, SUB_PART, NON_UNIQUE,
|
|
// SEQ_IN_INDEX. SEQ_IN_INDEX specifies the position of the column in the index columns.
|
|
func (d *MySQL) scanIndexes(rows *sql.Rows, t *Table) (Indexes, error) {
|
|
var (
|
|
i Indexes
|
|
names = make(map[string]*Index)
|
|
)
|
|
for rows.Next() {
|
|
var (
|
|
name string
|
|
column string
|
|
nonuniq bool
|
|
seqindex int
|
|
subpart sql.NullInt64
|
|
)
|
|
if err := rows.Scan(&name, &column, &subpart, &nonuniq, &seqindex); err != nil {
|
|
return nil, fmt.Errorf("scanning index description: %w", err)
|
|
}
|
|
// Skip primary keys.
|
|
if name == "PRIMARY" {
|
|
c, ok := t.column(column)
|
|
if !ok {
|
|
return nil, fmt.Errorf("missing primary-key column: %q", column)
|
|
}
|
|
t.PrimaryKey = append(t.PrimaryKey, c)
|
|
continue
|
|
}
|
|
idx, ok := names[name]
|
|
if !ok {
|
|
idx = &Index{Name: name, Unique: !nonuniq, Annotation: &entsql.IndexAnnotation{}}
|
|
i = append(i, idx)
|
|
names[name] = idx
|
|
}
|
|
idx.columns = append(idx.columns, column)
|
|
if subpart.Int64 > 0 {
|
|
if idx.Annotation.PrefixColumns == nil {
|
|
idx.Annotation.PrefixColumns = make(map[string]uint)
|
|
}
|
|
idx.Annotation.PrefixColumns[column] = uint(subpart.Int64)
|
|
}
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return i, nil
|
|
}
|
|
|
|
// isImplicitIndex reports if the index was created implicitly for the unique column.
|
|
func (d *MySQL) isImplicitIndex(idx *Index, col *Column) bool {
|
|
// We execute `CHANGE COLUMN` on older versions of MySQL (<8.0), which
|
|
// auto create the new index. The old one, will be dropped in `changeSet`.
|
|
if compareVersions(d.version, "8.0.0") >= 0 {
|
|
return idx.Name == col.Name && col.Unique
|
|
}
|
|
return false
|
|
}
|
|
|
|
// renameColumn returns the statement for renaming a column in
|
|
// MySQL based on its version.
|
|
func (d *MySQL) renameColumn(t *Table, old, new *Column) sql.Querier {
|
|
q := sql.AlterTable(t.Name)
|
|
if compareVersions(d.version, "8.0.0") >= 0 {
|
|
return q.RenameColumn(old.Name, new.Name)
|
|
}
|
|
return q.ChangeColumn(old.Name, d.addColumn(new))
|
|
}
|
|
|
|
// renameIndex returns the statement for renaming an index.
|
|
func (d *MySQL) renameIndex(t *Table, old, new *Index) sql.Querier {
|
|
q := sql.AlterTable(t.Name)
|
|
if compareVersions(d.version, "5.7.0") >= 0 {
|
|
return q.RenameIndex(old.Name, new.Name)
|
|
}
|
|
return q.DropIndex(old.Name).AddIndex(new.Builder(t.Name))
|
|
}
|
|
|
|
// matchSchema returns the predicate for matching table schema.
|
|
func (d *MySQL) matchSchema(columns ...string) *sql.Predicate {
|
|
column := "TABLE_SCHEMA"
|
|
if len(columns) > 0 {
|
|
column = columns[0]
|
|
}
|
|
if d.schema != "" {
|
|
return sql.EQ(column, d.schema)
|
|
}
|
|
return sql.EQ(column, sql.Raw("(SELECT DATABASE())"))
|
|
}
|
|
|
|
// tables returns the query for getting the in the schema.
|
|
func (d *MySQL) tables() sql.Querier {
|
|
return sql.Select("TABLE_NAME").
|
|
From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
|
|
Where(d.matchSchema())
|
|
}
|
|
|
|
// alterColumns returns the queries for applying the columns change-set.
|
|
func (d *MySQL) alterColumns(table string, add, modify, drop []*Column) sql.Queries {
|
|
b := sql.Dialect(dialect.MySQL).AlterTable(table)
|
|
for _, c := range add {
|
|
b.AddColumn(d.addColumn(c))
|
|
}
|
|
for _, c := range modify {
|
|
b.ModifyColumn(d.addColumn(c))
|
|
}
|
|
for _, c := range drop {
|
|
b.DropColumn(sql.Dialect(dialect.MySQL).Column(c.Name))
|
|
}
|
|
if len(b.Queries) == 0 {
|
|
return nil
|
|
}
|
|
return sql.Queries{b}
|
|
}
|
|
|
|
// normalizeJSON normalize MariaDB longtext columns to type JSON.
|
|
func (d *MySQL) normalizeJSON(ctx context.Context, tx dialect.Tx, t *Table) error {
|
|
columns := make(map[string]*Column)
|
|
for _, c := range t.Columns {
|
|
if c.typ == "longtext" {
|
|
columns[c.Name] = c
|
|
}
|
|
}
|
|
if len(columns) == 0 {
|
|
return nil
|
|
}
|
|
rows := &sql.Rows{}
|
|
query, args := sql.Select("CONSTRAINT_NAME").
|
|
From(sql.Table("CHECK_CONSTRAINTS").Schema("INFORMATION_SCHEMA")).
|
|
Where(sql.And(
|
|
d.matchSchema("CONSTRAINT_SCHEMA"),
|
|
sql.EQ("TABLE_NAME", t.Name),
|
|
sql.Like("CHECK_CLAUSE", "json_valid(%)"),
|
|
)).
|
|
Query()
|
|
if err := tx.Query(ctx, query, args, rows); err != nil {
|
|
return fmt.Errorf("mysql: query table constraints %w", err)
|
|
}
|
|
// Call Close in cases of failures (Close is idempotent).
|
|
defer rows.Close()
|
|
names := make([]string, 0, len(columns))
|
|
if err := sql.ScanSlice(rows, &names); err != nil {
|
|
return fmt.Errorf("mysql: scan table constraints: %w", err)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
if err := rows.Close(); err != nil {
|
|
return err
|
|
}
|
|
for _, name := range names {
|
|
c, ok := columns[name]
|
|
if ok {
|
|
c.Type = field.TypeJSON
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// mariadb reports if the migration runs on MariaDB and returns the semver string.
|
|
func (d *MySQL) mariadb() (string, bool) {
|
|
idx := strings.Index(d.version, "MariaDB")
|
|
if idx == -1 {
|
|
return "", false
|
|
}
|
|
return d.version[:idx-1], true
|
|
}
|
|
|
|
// parseColumn splits a MySQL column type (as found in COLUMN_TYPE, e.g.
// "int(10) unsigned" or "varchar(255)") on parentheses, spaces and commas.
// It returns the parts, the size and whether the type is unsigned.
//
// For integer types, size is the optional display width that directly follows
// the type name. unsigned reports whether an "unsigned" modifier appears
// anywhere in the type, so trailing modifiers such as "zerofill"
// ("int(10) unsigned zerofill", "int unsigned zerofill") are handled.
// For char/varchar/binary/varbinary, size is the declared length.
// An empty type, or a size that is not a number, is reported as an error.
func parseColumn(typ string) (parts []string, size int64, unsigned bool, err error) {
	parts = strings.FieldsFunc(typ, func(r rune) bool {
		return r == '(' || r == ')' || r == ' ' || r == ','
	})
	if len(parts) == 0 {
		// Guard against indexing parts[0] below.
		return parts, 0, false, fmt.Errorf("unexpected empty column type %q", typ)
	}
	switch parts[0] {
	case "tinyint", "smallint", "mediumint", "int", "bigint":
		for i, p := range parts[1:] {
			switch {
			case p == "unsigned":
				unsigned = true
			case p == "zerofill":
				// ZEROFILL implies UNSIGNED, which MySQL lists separately.
			case i == 0:
				// The display width, when present, directly follows the type name.
				size, err = strconv.ParseInt(p, 10, 64)
			}
		}
	case "varbinary", "varchar", "char", "binary":
		if len(parts) > 1 {
			size, err = strconv.ParseInt(parts[1], 10, 64)
		}
	}
	if err != nil {
		return parts, size, unsigned, fmt.Errorf("converting %s size to int: %w", parts[0], err)
	}
	return parts, size, unsigned, nil
}
|
|
|
|
// fkNames returns the foreign-key names of a column.
|
|
func (d *MySQL) fkNames(ctx context.Context, tx dialect.Tx, table, column string) ([]string, error) {
|
|
query, args := sql.Select("CONSTRAINT_NAME").From(sql.Table("KEY_COLUMN_USAGE").Schema("INFORMATION_SCHEMA")).
|
|
Where(sql.And(
|
|
sql.EQ("TABLE_NAME", table),
|
|
sql.EQ("COLUMN_NAME", column),
|
|
// NULL for unique and primary-key constraints.
|
|
sql.NotNull("POSITION_IN_UNIQUE_CONSTRAINT"),
|
|
d.matchSchema(),
|
|
)).
|
|
Query()
|
|
var (
|
|
names []string
|
|
rows = &sql.Rows{}
|
|
)
|
|
if err := tx.Query(ctx, query, args, rows); err != nil {
|
|
return nil, fmt.Errorf("mysql: reading constraint names %w", err)
|
|
}
|
|
defer rows.Close()
|
|
if err := sql.ScanSlice(rows, &names); err != nil {
|
|
return nil, err
|
|
}
|
|
return names, nil
|
|
}
|
|
|
|
// defaultSize returns the default size for MySQL/MariaDB varchar type
|
|
// based on column size, charset and table indexes, in order to avoid
|
|
// index prefix key limit (767) for older versions of MySQL/MariaDB.
|
|
func (d *MySQL) defaultSize(c *Column) int64 {
|
|
size := DefaultStringLen
|
|
version, checked := d.version, "5.7.0"
|
|
if v, ok := d.mariadb(); ok {
|
|
version, checked = v, "10.2.2"
|
|
}
|
|
switch {
|
|
// Version is >= 5.7 for MySQL, or >= 10.2.2 for MariaDB.
|
|
case compareVersions(version, checked) != -1:
|
|
// Column is non-unique, or not part of any index (reaching
|
|
// the error 1071).
|
|
case !c.Unique && len(c.indexes) == 0 && !c.PrimaryKey():
|
|
default:
|
|
size = 191
|
|
}
|
|
return size
|
|
}
|
|
|
|
// needsConversion reports if column "old" needs to be converted
|
|
// (by table altering) to column "new".
|
|
func (d *MySQL) needsConversion(old, new *Column) bool {
|
|
return d.cType(old) != d.cType(new)
|
|
}
|
|
|
|
// indexModified used by the migration differ to check if the index was modified.
|
|
func (d *MySQL) indexModified(old, new *Index) bool {
|
|
oldParts, newParts := indexParts(old), indexParts(new)
|
|
if len(oldParts) != len(newParts) {
|
|
return true
|
|
}
|
|
for column, oldPart := range oldParts {
|
|
newPart, ok := newParts[column]
|
|
if !ok || oldPart != newPart {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// indexParts returns a map holding the sub_part mapping if exist.
|
|
func indexParts(idx *Index) map[string]uint {
|
|
parts := make(map[string]uint)
|
|
if idx.Annotation == nil {
|
|
return parts
|
|
}
|
|
// If prefix (without a name) was defined on the
|
|
// annotation, map it to the single column index.
|
|
if idx.Annotation.Prefix > 0 && len(idx.Columns) == 1 {
|
|
parts[idx.Columns[0].Name] = idx.Annotation.Prefix
|
|
}
|
|
for column, part := range idx.Annotation.PrefixColumns {
|
|
parts[column] = part
|
|
}
|
|
return parts
|
|
}
|
|
|
|
// Atlas integration.
|
|
|
|
func (d *MySQL) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) {
|
|
return mysql.Open(&db{ExecQuerier: conn})
|
|
}
|
|
|
|
func (d *MySQL) atTable(t1 *Table, t2 *schema.Table) {
|
|
t2.SetCharset("utf8mb4").SetCollation("utf8mb4_bin")
|
|
if t1.Annotation == nil {
|
|
return
|
|
}
|
|
if charset := t1.Annotation.Charset; charset != "" {
|
|
t2.SetCharset(charset)
|
|
}
|
|
if collate := t1.Annotation.Collation; collate != "" {
|
|
t2.SetCollation(collate)
|
|
}
|
|
if opts := t1.Annotation.Options; opts != "" {
|
|
t2.AddAttrs(&mysql.CreateOptions{
|
|
V: opts,
|
|
})
|
|
}
|
|
// Check if the connected database supports the CHECK clause.
|
|
// For MySQL, is >= "8.0.16" and for MariaDB it is "10.2.1".
|
|
v1, v2 := d.version, "8.0.16"
|
|
if v, ok := d.mariadb(); ok {
|
|
v1, v2 = v, "10.2.1"
|
|
}
|
|
if compareVersions(v1, v2) >= 0 {
|
|
setAtChecks(t1, t2)
|
|
}
|
|
}
|
|
|
|
func (d *MySQL) atTypeC(c1 *Column, c2 *schema.Column) error {
|
|
if c1.SchemaType != nil && c1.SchemaType[dialect.MySQL] != "" {
|
|
t, err := mysql.ParseType(strings.ToLower(c1.SchemaType[dialect.MySQL]))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c2.Type.Type = t
|
|
return nil
|
|
}
|
|
var t schema.Type
|
|
switch c1.Type {
|
|
case field.TypeBool:
|
|
t = &schema.BoolType{T: "boolean"}
|
|
case field.TypeInt8:
|
|
t = &schema.IntegerType{T: mysql.TypeTinyInt}
|
|
case field.TypeUint8:
|
|
t = &schema.IntegerType{T: mysql.TypeTinyInt, Unsigned: true}
|
|
case field.TypeInt16:
|
|
t = &schema.IntegerType{T: mysql.TypeSmallInt}
|
|
case field.TypeUint16:
|
|
t = &schema.IntegerType{T: mysql.TypeSmallInt, Unsigned: true}
|
|
case field.TypeInt32:
|
|
t = &schema.IntegerType{T: mysql.TypeInt}
|
|
case field.TypeUint32:
|
|
t = &schema.IntegerType{T: mysql.TypeInt, Unsigned: true}
|
|
case field.TypeInt, field.TypeInt64:
|
|
t = &schema.IntegerType{T: mysql.TypeBigInt}
|
|
case field.TypeUint, field.TypeUint64:
|
|
t = &schema.IntegerType{T: mysql.TypeBigInt, Unsigned: true}
|
|
case field.TypeBytes:
|
|
size := int64(math.MaxUint16)
|
|
if c1.Size > 0 {
|
|
size = c1.Size
|
|
}
|
|
switch {
|
|
case size <= math.MaxUint8:
|
|
t = &schema.BinaryType{T: mysql.TypeTinyBlob}
|
|
case size <= math.MaxUint16:
|
|
t = &schema.BinaryType{T: mysql.TypeBlob}
|
|
case size < 1<<24:
|
|
t = &schema.BinaryType{T: mysql.TypeMediumBlob}
|
|
case size <= math.MaxUint32:
|
|
t = &schema.BinaryType{T: mysql.TypeLongBlob}
|
|
}
|
|
case field.TypeJSON:
|
|
t = &schema.JSONType{T: mysql.TypeJSON}
|
|
if compareVersions(d.version, "5.7.8") == -1 {
|
|
t = &schema.BinaryType{T: mysql.TypeLongBlob}
|
|
}
|
|
case field.TypeString:
|
|
size := c1.Size
|
|
if size == 0 {
|
|
size = d.defaultSize(c1)
|
|
}
|
|
switch {
|
|
case c1.typ == "tinytext", c1.typ == "text":
|
|
t = &schema.StringType{T: c1.typ}
|
|
case size <= math.MaxUint16:
|
|
t = &schema.StringType{T: mysql.TypeVarchar, Size: int(size)}
|
|
case size == 1<<24-1:
|
|
t = &schema.StringType{T: mysql.TypeMediumText}
|
|
default:
|
|
t = &schema.StringType{T: mysql.TypeLongText}
|
|
}
|
|
case field.TypeFloat32, field.TypeFloat64:
|
|
t = &schema.FloatType{T: c1.scanTypeOr(mysql.TypeDouble)}
|
|
case field.TypeTime:
|
|
t = &schema.TimeType{T: c1.scanTypeOr(mysql.TypeTimestamp)}
|
|
// In MariaDB or in MySQL < v8.0.2, the TIMESTAMP column has both `DEFAULT CURRENT_TIMESTAMP`
|
|
// and `ON UPDATE CURRENT_TIMESTAMP` if neither is specified explicitly. this behavior is
|
|
// suppressed if the column is defined with a `DEFAULT` clause or with the `NULL` attribute.
|
|
if _, maria := d.mariadb(); maria || compareVersions(d.version, "8.0.2") == -1 && c1.Default == nil {
|
|
c2.SetNull(c1.Attr == "")
|
|
}
|
|
case field.TypeEnum:
|
|
t = &schema.EnumType{T: mysql.TypeEnum, Values: c1.Enums}
|
|
case field.TypeUUID:
|
|
// "CHAR(X) BINARY" is treated as "CHAR(X) COLLATE latin1_bin", and in MySQL < 8,
|
|
// and "COLLATE utf8mb4_bin" in MySQL >= 8. However we already set the table to
|
|
t = &schema.StringType{T: mysql.TypeChar, Size: 36}
|
|
c2.SetCollation("utf8mb4_bin")
|
|
default:
|
|
t, err := mysql.ParseType(strings.ToLower(c1.typ))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c2.Type.Type = t
|
|
}
|
|
c2.Type.Type = t
|
|
return nil
|
|
}
|
|
|
|
func (d *MySQL) atUniqueC(t1 *Table, c1 *Column, t2 *schema.Table, c2 *schema.Column) {
|
|
// For UNIQUE columns, MySQL create an implicit index
|
|
// named as the column with an extra index in case the
|
|
// name is already taken (<e.g. c>, <c_2>, <c_3>, ...).
|
|
for _, idx := range t1.Indexes {
|
|
// Index also defined explicitly, and will be add in atIndexes.
|
|
if idx.Unique && d.atImplicitIndexName(idx, c1) {
|
|
return
|
|
}
|
|
}
|
|
t2.AddIndexes(schema.NewUniqueIndex(c1.Name).AddColumns(c2))
|
|
}
|
|
|
|
func (d *MySQL) atIncrementC(_ *schema.Table, c *schema.Column) {
|
|
c.AddAttrs(&mysql.AutoIncrement{})
|
|
}
|
|
|
|
func (d *MySQL) atIncrementT(t *schema.Table, v int64) {
|
|
t.AddAttrs(&mysql.AutoIncrement{V: v})
|
|
}
|
|
|
|
func (d *MySQL) atImplicitIndexName(idx *Index, c1 *Column) bool {
|
|
if idx.Name == c1.Name {
|
|
return true
|
|
}
|
|
if !strings.HasPrefix(idx.Name, c1.Name+"_") {
|
|
return false
|
|
}
|
|
i, err := strconv.ParseInt(strings.TrimLeft(idx.Name, c1.Name+"_"), 10, 64)
|
|
return err == nil && i > 1
|
|
}
|
|
|
|
func (d *MySQL) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) error {
|
|
prefix := indexParts(idx1)
|
|
for _, c1 := range idx1.Columns {
|
|
c2, ok := t2.Column(c1.Name)
|
|
if !ok {
|
|
return fmt.Errorf("unexpected index %q column: %q", idx1.Name, c1.Name)
|
|
}
|
|
part := &schema.IndexPart{C: c2}
|
|
if v, ok := prefix[c1.Name]; ok {
|
|
part.AddAttrs(&mysql.SubPart{Len: int(v)})
|
|
}
|
|
idx2.AddParts(part)
|
|
}
|
|
if t, ok := indexType(idx1, dialect.MySQL); ok {
|
|
idx2.AddAttrs(&mysql.IndexType{T: t})
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func indexType(idx *Index, d string) (string, bool) {
|
|
ant := idx.Annotation
|
|
if ant == nil {
|
|
return "", false
|
|
}
|
|
if ant.Types != nil && ant.Types[d] != "" {
|
|
return ant.Types[d], true
|
|
}
|
|
if ant.Type != "" {
|
|
return ant.Type, true
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
func (MySQL) atTypeRangeSQL(ts ...string) string {
|
|
for i := range ts {
|
|
ts[i] = fmt.Sprintf("('%s')", ts[i])
|
|
}
|
|
return fmt.Sprintf("INSERT INTO `%s` (`type`) VALUES %s", TypeTable, strings.Join(ts, ", "))
|
|
}
|