Files
ent/dialect/sql/schema/mysql.go
Jannik Clausen 7017cbc898 dialect/sql/schema: file based type store (#2644)
* dialect/sql/schema: file based type store

This PR adds support for a file based type storage when using versioned migrations. The file called `.ent_types` is written to the migration directory alongside the migration files and will be kept in sync for every migration file generation run.

In order to not break existing code, where the type storage might differ for different deployments, global unique ID must be enabled by using a new option. This will also be raised as an error to the user when attempting to use versioned migrations and global unique ID.

Documentation will be added to this PR once feedback on the code is gathered.

* apply CR

* fix tests

* change format of types file to exclude it from atlas.sum file

* docs and drift test

* apply CR
2022-06-15 16:10:15 +02:00

980 lines
28 KiB
Go

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package schema
import (
"context"
"fmt"
"math"
"strconv"
"strings"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/mysql"
"ariga.io/atlas/sql/schema"
)
// MySQL is a MySQL migration driver.
type MySQL struct {
	dialect.Driver
	// schema restricts INFORMATION_SCHEMA queries to a specific database;
	// when empty, the connection's current database (SELECT DATABASE())
	// is used instead. See matchSchema.
	schema string
	// version is the server version string reported by the
	// "SHOW VARIABLES LIKE 'version'" query executed in init.
	version string
}
// init loads the MySQL version from the database for later use in the migration process.
func (d *MySQL) init(ctx context.Context, conn dialect.ExecQuerier) error {
	rows := &sql.Rows{}
	if err := conn.Query(ctx, "SHOW VARIABLES LIKE 'version'", []interface{}{}, rows); err != nil {
		return fmt.Errorf("mysql: querying mysql version %w", err)
	}
	defer rows.Close()
	if !rows.Next() {
		if err := rows.Err(); err != nil {
			return err
		}
		return fmt.Errorf("mysql: version variable was not found")
	}
	// The result set has two columns: the variable name and its value.
	var name, value string
	if err := rows.Scan(&name, &value); err != nil {
		return fmt.Errorf("mysql: scanning mysql version: %w", err)
	}
	d.version = value
	return nil
}
// tableExist reports whether a table with the given name exists in the schema.
func (d *MySQL) tableExist(ctx context.Context, conn dialect.ExecQuerier, name string) (bool, error) {
	pred := sql.And(
		d.matchSchema(),
		sql.EQ("TABLE_NAME", name),
	)
	query, args := sql.Select(sql.Count("*")).
		From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
		Where(pred).
		Query()
	return exist(ctx, conn, query, args...)
}
// fkExist reports whether a foreign-key constraint with the given name exists in the schema.
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
	pred := sql.And(
		d.matchSchema(),
		sql.EQ("CONSTRAINT_TYPE", "FOREIGN KEY"),
		sql.EQ("CONSTRAINT_NAME", name),
	)
	query, args := sql.Select(sql.Count("*")).
		From(sql.Table("TABLE_CONSTRAINTS").Schema("INFORMATION_SCHEMA")).
		Where(pred).
		Query()
	return exist(ctx, tx, query, args...)
}
// table loads the current table description from the database.
func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
	rows := &sql.Rows{}
	// Column order here must match the Scan order in scanColumn.
	query, args := sql.Select(
		"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name",
		"numeric_precision", "numeric_scale",
	).
		From(sql.Table("COLUMNS").Schema("INFORMATION_SCHEMA")).
		Where(sql.And(
			d.matchSchema(),
			sql.EQ("TABLE_NAME", name)),
		).Query()
	if err := tx.Query(ctx, query, args, rows); err != nil {
		return nil, fmt.Errorf("mysql: reading table description %w", err)
	}
	// Call Close in cases of failures (Close is idempotent).
	defer rows.Close()
	t := NewTable(name)
	for rows.Next() {
		c := &Column{}
		if err := d.scanColumn(c, rows); err != nil {
			return nil, fmt.Errorf("mysql: %w", err)
		}
		t.AddColumn(c)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	// Close explicitly before issuing the next query on the same transaction.
	if err := rows.Close(); err != nil {
		return nil, fmt.Errorf("mysql: closing rows %w", err)
	}
	indexes, err := d.indexes(ctx, tx, t)
	if err != nil {
		return nil, err
	}
	// Add and link indexes to table columns.
	for _, idx := range indexes {
		t.addIndex(idx)
	}
	// MariaDB reports JSON columns as longtext; normalize them
	// back to TypeJSON using their CHECK constraints.
	if _, ok := d.mariadb(); ok {
		if err := d.normalizeJSON(ctx, tx, t); err != nil {
			return nil, err
		}
	}
	return t, nil
}
// indexes loads the table indexes from the database.
func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, t *Table) ([]*Index, error) {
	// Order by (index_name, seq_in_index) so multi-column indexes
	// arrive grouped and with their columns in position order.
	query, args := sql.Select("index_name", "column_name", "sub_part", "non_unique", "seq_in_index").
		From(sql.Table("STATISTICS").Schema("INFORMATION_SCHEMA")).
		Where(sql.And(d.matchSchema(), sql.EQ("TABLE_NAME", t.Name))).
		OrderBy("index_name", "seq_in_index").
		Query()
	rows := &sql.Rows{}
	if err := tx.Query(ctx, query, args, rows); err != nil {
		return nil, fmt.Errorf("mysql: reading index description %w", err)
	}
	defer rows.Close()
	indexes, err := d.scanIndexes(rows, t)
	if err != nil {
		return nil, fmt.Errorf("mysql: %w", err)
	}
	return indexes, nil
}
// setRange sets the AUTO_INCREMENT counter of the table to the given value.
func (d *MySQL) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error {
	stmt := fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", t.Name, value)
	return conn.Exec(ctx, stmt, []interface{}{}, nil)
}
// verifyRange checks that the table's AUTO_INCREMENT counter is configured,
// and resets it to the expected value when it is not. A zero expected value
// disables the verification.
func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, t *Table, expected int64) error {
	if expected == 0 {
		return nil
	}
	rows := &sql.Rows{}
	query, args := sql.Select("AUTO_INCREMENT").
		From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
		Where(sql.And(
			d.matchSchema(),
			sql.EQ("TABLE_NAME", t.Name),
		)).
		Query()
	if err := tx.Query(ctx, query, args, rows); err != nil {
		return fmt.Errorf("mysql: query auto_increment %w", err)
	}
	// Call Close in cases of failures (Close is idempotent).
	defer rows.Close()
	// NullInt64 because AUTO_INCREMENT is NULL when not configured.
	actual := &sql.NullInt64{}
	if err := sql.ScanOne(rows, actual); err != nil {
		return fmt.Errorf("mysql: scan auto_increment %w", err)
	}
	// Close explicitly before issuing setRange on the same transaction.
	if err := rows.Close(); err != nil {
		return err
	}
	// Table is empty and auto-increment is not configured. This can happen
	// because MySQL (< 8.0) stores the auto-increment counter in main memory
	// (not persistent), and the value is reset on restart (if table is empty).
	if actual.Int64 <= 1 {
		return d.setRange(ctx, tx, t, expected)
	}
	return nil
}
// tBuilder returns the MySQL DSL query for table creation.
func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder {
	b := sql.CreateTable(t.Name).IfNotExists()
	for _, c := range t.Columns {
		b.Column(d.addColumn(c))
	}
	for _, pk := range t.PrimaryKey {
		b.PrimaryKey(pk.Name)
	}
	// Default charset and collation for the table; either may be
	// overridden below by the entsql annotation.
	b.Charset("utf8mb4").Collate("utf8mb4_bin")
	if ant := t.Annotation; ant != nil {
		if ant.Charset != "" {
			b.Charset(ant.Charset)
		}
		if ant.Collation != "" {
			b.Collate(ant.Collation)
		}
		if ant.Options != "" {
			b.Options(ant.Options)
		}
		addChecks(b, ant)
	}
	return b
}
// cType returns the MySQL string type for the given column.
func (d *MySQL) cType(c *Column) (t string) {
	// An explicit, user-defined schema type always wins.
	if c.SchemaType != nil && c.SchemaType[dialect.MySQL] != "" {
		// MySQL returns the column type lower cased.
		return strings.ToLower(c.SchemaType[dialect.MySQL])
	}
	switch c.Type {
	case field.TypeBool:
		t = "boolean"
	case field.TypeInt8:
		t = "tinyint"
	case field.TypeUint8:
		t = "tinyint unsigned"
	case field.TypeInt16:
		t = "smallint"
	case field.TypeUint16:
		t = "smallint unsigned"
	case field.TypeInt32:
		t = "int"
	case field.TypeUint32:
		t = "int unsigned"
	case field.TypeInt, field.TypeInt64:
		t = "bigint"
	case field.TypeUint, field.TypeUint64:
		t = "bigint unsigned"
	case field.TypeBytes:
		// Pick the smallest blob type that fits the requested size
		// (defaults to 64KB-1, i.e. a plain "blob").
		size := int64(math.MaxUint16)
		if c.Size > 0 {
			size = c.Size
		}
		switch {
		case size <= math.MaxUint8:
			t = "tinyblob"
		case size <= math.MaxUint16:
			t = "blob"
		case size < 1<<24:
			t = "mediumblob"
		case size <= math.MaxUint32:
			t = "longblob"
		}
	case field.TypeJSON:
		t = "json"
		// The native JSON type was introduced in MySQL 5.7.8.
		if compareVersions(d.version, "5.7.8") == -1 {
			t = "longblob"
		}
	case field.TypeString:
		size := c.Size
		if size == 0 {
			size = d.defaultSize(c)
		}
		switch {
		// Keep the text type reported by the database as-is.
		case c.typ == "tinytext", c.typ == "text":
			t = c.typ
		case size <= math.MaxUint16:
			t = fmt.Sprintf("varchar(%d)", size)
		case size == 1<<24-1:
			t = "mediumtext"
		default:
			t = "longtext"
		}
	case field.TypeFloat32, field.TypeFloat64:
		t = c.scanTypeOr("double")
	case field.TypeTime:
		t = c.scanTypeOr("timestamp")
		// In MariaDB or in MySQL < v8.0.2, the TIMESTAMP column has both `DEFAULT CURRENT_TIMESTAMP`
		// and `ON UPDATE CURRENT_TIMESTAMP` if neither is specified explicitly. this behavior is
		// suppressed if the column is defined with a `DEFAULT` clause or with the `NULL` attribute.
		if _, maria := d.mariadb(); maria || compareVersions(d.version, "8.0.2") == -1 && c.Default == nil {
			c.Nullable = c.Attr == ""
		}
	case field.TypeEnum:
		// NOTE(review): enum values are not quote-escaped here; a value
		// containing a single quote would break the generated DDL — confirm
		// such values are rejected earlier in the pipeline.
		values := make([]string, len(c.Enums))
		for i, e := range c.Enums {
			values[i] = fmt.Sprintf("'%s'", e)
		}
		t = fmt.Sprintf("enum(%s)", strings.Join(values, ", "))
	case field.TypeUUID:
		t = "char(36) binary"
	case field.TypeOther:
		t = c.typ
	default:
		panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name))
	}
	return t
}
// addColumn returns the DSL query for adding the given column to a table.
// The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable].
func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder {
	b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr)
	c.unique(b)
	if c.Increment {
		b.Attr("AUTO_INCREMENT")
	}
	c.nullable(b)
	c.defaultValue(b)
	if c.Collation != "" {
		b.Attr("COLLATE " + c.Collation)
	}
	// Older MariaDB releases do not include the implicit JSON_VALID
	// constraint (added automatically since 10.4.3), so attach it manually.
	version, maria := d.mariadb()
	if c.Type == field.TypeJSON && maria && compareVersions(version, "10.4.3") == -1 {
		b.Check(func(b *sql.Builder) {
			b.WriteString("JSON_VALID(").Ident(c.Name).WriteByte(')')
		})
	}
	return b
}
// addIndex returns the query for adding an index to MySQL.
func (d *MySQL) addIndex(i *Index, table string) *sql.IndexBuilder {
	b := sql.CreateIndex(i.Name).Table(table)
	if i.Unique {
		b.Unique()
	}
	parts := indexParts(i)
	for _, c := range i.Columns {
		// A positive sub-part means only a prefix of the column is indexed.
		switch part := parts[c.Name]; {
		case part > 0:
			b.Column(fmt.Sprintf("%s(%d)", b.Builder.Quote(c.Name), part))
		default:
			b.Column(c.Name)
		}
	}
	return b
}
// dropIndex drops a MySQL index.
func (d *MySQL) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error {
	builder := idx.DropBuilder(table)
	query, args := builder.Query()
	return tx.Exec(ctx, query, args, nil)
}
// prepare runs preparation work that needs to be done to apply the change-set.
// Specifically, it handles the foreign-key constraints and indexes that are
// implicitly tied to single-column indexes scheduled for dropping.
func (d *MySQL) prepare(ctx context.Context, tx dialect.Tx, change *changes, table string) error {
	for _, idx := range change.index.drop {
		switch n := len(idx.columns); {
		case n == 0:
			return fmt.Errorf("index %q has no columns", idx.Name)
		case n > 1:
			continue // not a foreign-key index.
		}
		var qr sql.Querier
	Switch:
		switch col, ok := change.dropColumn(idx.columns[0]); {
		// If both the index and the column need to be dropped, the foreign-key
		// constraint that is associated with them need to be dropped as well.
		case ok:
			names, err := d.fkNames(ctx, tx, table, col.Name)
			if err != nil {
				return err
			}
			// Only act when the column has exactly one foreign-key;
			// otherwise the association is ambiguous.
			if len(names) == 1 {
				qr = sql.AlterTable(table).DropForeignKey(names[0])
			}
		// If the uniqueness was dropped from a foreign-key column,
		// create a "simple index" if no other index exist for it.
		case !ok && idx.Unique && len(idx.Columns) > 0:
			col := idx.Columns[0]
			for _, idx2 := range col.indexes {
				if idx2 != idx && len(idx2.columns) == 1 {
					// Another single-column index already covers this
					// column; no replacement index is needed.
					break Switch
				}
			}
			names, err := d.fkNames(ctx, tx, table, col.Name)
			if err != nil {
				return err
			}
			if len(names) == 1 {
				qr = sql.CreateIndex(names[0]).Table(table).Columns(col.Name)
			}
		}
		if qr != nil {
			query, args := qr.Query()
			if err := tx.Exec(ctx, query, args, nil); err != nil {
				return err
			}
		}
	}
	return nil
}
// scanColumn scans the column information from MySQL column description.
func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
	var (
		nullable         sql.NullString
		defaults         sql.NullString
		numericPrecision sql.NullInt64
		numericScale     sql.NullInt64
	)
	// Scan order must match the SELECT list built in MySQL.table. Charset
	// and collation are scanned into throwaway values as they are unused here.
	if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &sql.NullString{}, &sql.NullString{}, &numericPrecision, &numericScale); err != nil {
		return fmt.Errorf("scanning column description: %w", err)
	}
	c.Unique = c.UniqueKey()
	if nullable.Valid {
		c.Nullable = nullable.String == "YES"
	}
	if c.typ == "" {
		return fmt.Errorf("missing type information for column %q", c.Name)
	}
	parts, size, unsigned, err := parseColumn(c.typ)
	if err != nil {
		return err
	}
	switch parts[0] {
	case "mediumint", "int":
		c.Type = field.TypeInt32
		if unsigned {
			c.Type = field.TypeUint32
		}
	case "smallint":
		c.Type = field.TypeInt16
		if unsigned {
			c.Type = field.TypeUint16
		}
	case "bigint":
		c.Type = field.TypeInt64
		if unsigned {
			c.Type = field.TypeUint64
		}
	case "tinyint":
		switch {
		// tinyint(1) is MySQL's canonical representation of BOOLEAN.
		case size == 1:
			c.Type = field.TypeBool
		case unsigned:
			c.Type = field.TypeUint8
		default:
			c.Type = field.TypeInt8
		}
	case "double", "float":
		c.Type = field.TypeFloat64
	case "numeric", "decimal":
		c.Type = field.TypeFloat64
		// If precision is specified then we should take that into account.
		if numericPrecision.Valid {
			schemaType := fmt.Sprintf("%s(%d,%d)", parts[0], numericPrecision.Int64, numericScale.Int64)
			c.SchemaType = map[string]string{dialect.MySQL: schemaType}
		}
	case "time", "timestamp", "date", "datetime":
		c.Type = field.TypeTime
		// The mapping from schema defaults to database
		// defaults is not supported for TypeTime fields.
		defaults = sql.NullString{}
	case "tinyblob":
		c.Size = math.MaxUint8
		c.Type = field.TypeBytes
	case "blob":
		c.Size = math.MaxUint16
		c.Type = field.TypeBytes
	case "mediumblob":
		c.Size = 1<<24 - 1
		c.Type = field.TypeBytes
	case "longblob":
		c.Size = math.MaxUint32
		c.Type = field.TypeBytes
	case "binary", "varbinary":
		c.Type = field.TypeBytes
		c.Size = size
	case "varchar":
		c.Type = field.TypeString
		c.Size = size
	case "text":
		c.Size = math.MaxUint16
		c.Type = field.TypeString
	case "mediumtext":
		c.Size = 1<<24 - 1
		c.Type = field.TypeString
	case "longtext":
		c.Size = math.MaxInt32
		c.Type = field.TypeString
	case "json":
		c.Type = field.TypeJSON
	case "enum":
		c.Type = field.TypeEnum
		// Parse the enum values according to the MySQL format.
		// github.com/mysql/mysql-server/blob/8.0/sql/field.cc#Field_enum::sql_type
		values := strings.TrimSuffix(strings.TrimPrefix(c.typ, "enum("), ")")
		if values == "" {
			return fmt.Errorf("mysql: unexpected enum type: %q", c.typ)
		}
		// NOTE(review): splitting on "','" assumes enum values contain no
		// embedded/escaped quotes; such values would be mis-parsed — confirm
		// they are rejected at schema-definition time.
		parts := strings.Split(values, "','")
		for i := range parts {
			c.Enums = append(c.Enums, strings.Trim(parts[i], "'"))
		}
	case "char":
		c.Type = field.TypeOther
		// UUID field has length of 36 characters (32 alphanumeric characters and 4 hyphens).
		if size == 36 {
			c.Type = field.TypeUUID
		}
	case "point", "geometry", "linestring", "polygon":
		c.Type = field.TypeOther
	default:
		return fmt.Errorf("unknown column type %q for version %q", parts[0], d.version)
	}
	if defaults.Valid {
		return c.ScanDefault(defaults.String)
	}
	return nil
}
// scanIndexes scans sql.Rows into an Indexes list. The query for returning the rows,
// should return the following 5 columns: INDEX_NAME, COLUMN_NAME, SUB_PART, NON_UNIQUE,
// SEQ_IN_INDEX. SEQ_IN_INDEX specifies the position of the column in the index columns.
func (d *MySQL) scanIndexes(rows *sql.Rows, t *Table) (Indexes, error) {
	var (
		i     Indexes
		names = make(map[string]*Index)
	)
	for rows.Next() {
		var (
			name     string
			column   string
			nonuniq  bool
			seqindex int
			subpart  sql.NullInt64
		)
		if err := rows.Scan(&name, &column, &subpart, &nonuniq, &seqindex); err != nil {
			return nil, fmt.Errorf("scanning index description: %w", err)
		}
		// Skip primary keys; their columns are recorded on the table instead.
		if name == "PRIMARY" {
			c, ok := t.column(column)
			if !ok {
				return nil, fmt.Errorf("missing primary-key column: %q", column)
			}
			t.PrimaryKey = append(t.PrimaryKey, c)
			continue
		}
		// Rows of a multi-column index arrive grouped (ordered by
		// index_name, seq_in_index); reuse the index created for the
		// first row of each group.
		idx, ok := names[name]
		if !ok {
			idx = &Index{Name: name, Unique: !nonuniq, Annotation: &entsql.IndexAnnotation{}}
			i = append(i, idx)
			names[name] = idx
		}
		idx.columns = append(idx.columns, column)
		// A positive SUB_PART means only a prefix of the column is indexed;
		// record it on the annotation for later comparison.
		if subpart.Int64 > 0 {
			if idx.Annotation.PrefixColumns == nil {
				idx.Annotation.PrefixColumns = make(map[string]uint)
			}
			idx.Annotation.PrefixColumns[column] = uint(subpart.Int64)
		}
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return i, nil
}
// isImplicitIndex reports if the index was created implicitly for the unique column.
func (d *MySQL) isImplicitIndex(idx *Index, col *Column) bool {
	// On MySQL < 8.0 columns are modified with `CHANGE COLUMN`, which
	// auto-creates the new index; the old one is dropped in `changeSet`.
	if compareVersions(d.version, "8.0.0") == -1 {
		return false
	}
	return idx.Name == col.Name && col.Unique
}
// renameColumn returns the statement for renaming a column in
// MySQL based on its version.
func (d *MySQL) renameColumn(t *Table, old, new *Column) sql.Querier {
	b := sql.AlterTable(t.Name)
	// RENAME COLUMN is only available since MySQL 8.0; fall back
	// to CHANGE COLUMN (with the full definition) on older servers.
	if compareVersions(d.version, "8.0.0") == -1 {
		return b.ChangeColumn(old.Name, d.addColumn(new))
	}
	return b.RenameColumn(old.Name, new.Name)
}
// renameIndex returns the statement for renaming an index.
func (d *MySQL) renameIndex(t *Table, old, new *Index) sql.Querier {
	b := sql.AlterTable(t.Name)
	// RENAME INDEX is supported since MySQL 5.7; emulate it with
	// DROP INDEX + ADD INDEX on older servers.
	if compareVersions(d.version, "5.7.0") == -1 {
		return b.DropIndex(old.Name).AddIndex(new.Builder(t.Name))
	}
	return b.RenameIndex(old.Name, new.Name)
}
// matchSchema returns the predicate for matching table schema.
// The column defaults to TABLE_SCHEMA unless overridden by the caller.
func (d *MySQL) matchSchema(columns ...string) *sql.Predicate {
	column := "TABLE_SCHEMA"
	if len(columns) > 0 {
		column = columns[0]
	}
	if d.schema == "" {
		// No explicit schema; fall back to the connection's current database.
		return sql.EQ(column, sql.Raw("(SELECT DATABASE())"))
	}
	return sql.EQ(column, d.schema)
}
// tables returns the query for getting the table names in the schema.
func (d *MySQL) tables() sql.Querier {
	from := sql.Table("TABLES").Schema("INFORMATION_SCHEMA")
	return sql.Select("TABLE_NAME").From(from).Where(d.matchSchema())
}
// alterColumns returns the queries for applying the columns change-set.
func (d *MySQL) alterColumns(table string, add, modify, drop []*Column) sql.Queries {
	builder := sql.Dialect(dialect.MySQL)
	b := builder.AlterTable(table)
	for _, c := range add {
		b.AddColumn(d.addColumn(c))
	}
	for _, c := range modify {
		b.ModifyColumn(d.addColumn(c))
	}
	for _, c := range drop {
		b.DropColumn(builder.Column(c.Name))
	}
	// Emit a single ALTER TABLE only when there is something to change.
	if len(b.Queries) > 0 {
		return sql.Queries{b}
	}
	return nil
}
// normalizeJSON normalize MariaDB longtext columns to type JSON.
// MariaDB stores JSON columns as longtext with an automatic
// json_valid(...) CHECK constraint; columns whose constraint matches
// that pattern are mapped back to TypeJSON.
func (d *MySQL) normalizeJSON(ctx context.Context, tx dialect.Tx, t *Table) error {
	// Collect the longtext candidates by column name.
	columns := make(map[string]*Column)
	for _, c := range t.Columns {
		if c.typ == "longtext" {
			columns[c.Name] = c
		}
	}
	if len(columns) == 0 {
		return nil
	}
	rows := &sql.Rows{}
	query, args := sql.Select("CONSTRAINT_NAME").
		From(sql.Table("CHECK_CONSTRAINTS").Schema("INFORMATION_SCHEMA")).
		Where(sql.And(
			d.matchSchema("CONSTRAINT_SCHEMA"),
			sql.EQ("TABLE_NAME", t.Name),
			sql.Like("CHECK_CLAUSE", "json_valid(%)"),
		)).
		Query()
	if err := tx.Query(ctx, query, args, rows); err != nil {
		return fmt.Errorf("mysql: query table constraints %w", err)
	}
	// Call Close in cases of failures (Close is idempotent).
	defer rows.Close()
	names := make([]string, 0, len(columns))
	if err := sql.ScanSlice(rows, &names); err != nil {
		return fmt.Errorf("mysql: scan table constraints: %w", err)
	}
	if err := rows.Err(); err != nil {
		return err
	}
	if err := rows.Close(); err != nil {
		return err
	}
	// The implicit constraint is named after its column, so a matching
	// constraint name identifies a JSON column.
	for _, name := range names {
		c, ok := columns[name]
		if ok {
			c.Type = field.TypeJSON
		}
	}
	return nil
}
// mariadb reports if the migration runs on MariaDB and returns the semver string.
// MariaDB servers report versions such as "10.6.4-MariaDB" or
// "5.5.5-10.2.1-MariaDB"; the returned value is everything before the
// separator character preceding the "MariaDB" marker.
func (d *MySQL) mariadb() (string, bool) {
	idx := strings.Index(d.version, "MariaDB")
	// Guard idx < 1: the slice expression below would panic if the
	// version string started with "MariaDB" (no leading semver part).
	if idx < 1 {
		return "", false
	}
	return d.version[:idx-1], true
}
// parseColumn returns column parts, size and signed-info from a MySQL type,
// e.g. "int(10) unsigned" yields parts=["int","10","unsigned"], size=10,
// unsigned=true. An empty or delimiter-only type string is reported as an
// error instead of panicking on the empty parts slice.
func parseColumn(typ string) (parts []string, size int64, unsigned bool, err error) {
	parts = strings.FieldsFunc(typ, func(r rune) bool {
		return r == '(' || r == ')' || r == ' ' || r == ','
	})
	if len(parts) == 0 {
		return nil, 0, false, fmt.Errorf("mysql: unexpected column type %q", typ)
	}
	switch parts[0] {
	case "tinyint", "smallint", "mediumint", "int", "bigint":
		switch {
		case len(parts) == 2 && parts[1] == "unsigned": // int unsigned
			unsigned = true
		case len(parts) == 3: // int(10) unsigned
			unsigned = true
			fallthrough
		case len(parts) == 2: // int(10)
			size, err = strconv.ParseInt(parts[1], 10, 64)
		}
	case "varbinary", "varchar", "char", "binary":
		if len(parts) > 1 {
			size, err = strconv.ParseInt(parts[1], 10, 64)
		}
	}
	if err != nil {
		return parts, size, unsigned, fmt.Errorf("converting %s size to int: %w", parts[0], err)
	}
	return parts, size, unsigned, nil
}
// fkNames returns the foreign-key names of a column.
func (d *MySQL) fkNames(ctx context.Context, tx dialect.Tx, table, column string) ([]string, error) {
	query, args := sql.Select("CONSTRAINT_NAME").
		From(sql.Table("KEY_COLUMN_USAGE").Schema("INFORMATION_SCHEMA")).
		Where(sql.And(
			sql.EQ("TABLE_NAME", table),
			sql.EQ("COLUMN_NAME", column),
			// NULL for unique and primary-key constraints.
			sql.NotNull("POSITION_IN_UNIQUE_CONSTRAINT"),
			d.matchSchema(),
		)).
		Query()
	rows := &sql.Rows{}
	if err := tx.Query(ctx, query, args, rows); err != nil {
		return nil, fmt.Errorf("mysql: reading constraint names %w", err)
	}
	defer rows.Close()
	var names []string
	if err := sql.ScanSlice(rows, &names); err != nil {
		return nil, err
	}
	return names, nil
}
// defaultSize returns the default size for MySQL/MariaDB varchar type
// based on column size, charset and table indexes, in order to avoid
// index prefix key limit (767) for older versions of MySQL/MariaDB.
func (d *MySQL) defaultSize(c *Column) int64 {
	version, checked := d.version, "5.7.0"
	if v, ok := d.mariadb(); ok {
		version, checked = v, "10.2.2"
	}
	// MySQL >= 5.7 and MariaDB >= 10.2.2 support a large enough
	// index prefix for the default length.
	if compareVersions(version, checked) != -1 {
		return DefaultStringLen
	}
	// Columns not covered by any index cannot hit error 1071.
	if !c.Unique && len(c.indexes) == 0 && !c.PrimaryKey() {
		return DefaultStringLen
	}
	return 191
}
// needsConversion reports if column "old" needs to be converted
// (by table altering) to column "new".
func (d *MySQL) needsConversion(old, new *Column) bool {
	oldType, newType := d.cType(old), d.cType(new)
	return oldType != newType
}
// indexModified used by the migration differ to check if the index was modified.
// Two indexes differ when their prefix (sub_part) mappings are not equal.
func (d *MySQL) indexModified(old, new *Index) bool {
	prev, next := indexParts(old), indexParts(new)
	if len(prev) != len(next) {
		return true
	}
	for column, prevPart := range prev {
		nextPart, ok := next[column]
		if !ok || prevPart != nextPart {
			return true
		}
	}
	return false
}
// indexParts returns a map holding the sub_part mapping if exist.
func indexParts(idx *Index) map[string]uint {
	parts := make(map[string]uint)
	ant := idx.Annotation
	if ant == nil {
		return parts
	}
	// A name-less prefix on the annotation maps to the
	// single column of a one-column index.
	if len(idx.Columns) == 1 && ant.Prefix > 0 {
		parts[idx.Columns[0].Name] = ant.Prefix
	}
	for column, part := range ant.PrefixColumns {
		parts[column] = part
	}
	return parts
}
// Atlas integration.

// atOpen opens an Atlas migration driver on top of the given connection.
func (d *MySQL) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) {
	drv := &db{ExecQuerier: conn}
	return mysql.Open(drv)
}
// atTable copies the ent table attributes (charset, collation, options
// and checks) onto the Atlas table representation.
func (d *MySQL) atTable(t1 *Table, t2 *schema.Table) {
	t2.SetCharset("utf8mb4").SetCollation("utf8mb4_bin")
	ant := t1.Annotation
	if ant == nil {
		return
	}
	if ant.Charset != "" {
		t2.SetCharset(ant.Charset)
	}
	if ant.Collation != "" {
		t2.SetCollation(ant.Collation)
	}
	if ant.Options != "" {
		t2.AddAttrs(&mysql.CreateOptions{V: ant.Options})
	}
	// Check if the connected database supports the CHECK clause:
	// MySQL >= 8.0.16, or MariaDB >= 10.2.1.
	version, checked := d.version, "8.0.16"
	if v, ok := d.mariadb(); ok {
		version, checked = v, "10.2.1"
	}
	if compareVersions(version, checked) >= 0 {
		setAtChecks(t1, t2)
	}
}
// atTypeC sets the Atlas column type on c2 based on the ent column c1.
// A user-defined schema type takes precedence; otherwise the ent field
// type is mapped to its MySQL equivalent.
func (d *MySQL) atTypeC(c1 *Column, c2 *schema.Column) error {
	if c1.SchemaType != nil && c1.SchemaType[dialect.MySQL] != "" {
		t, err := mysql.ParseType(strings.ToLower(c1.SchemaType[dialect.MySQL]))
		if err != nil {
			return err
		}
		c2.Type.Type = t
		return nil
	}
	var t schema.Type
	switch c1.Type {
	case field.TypeBool:
		t = &schema.BoolType{T: "boolean"}
	case field.TypeInt8:
		t = &schema.IntegerType{T: mysql.TypeTinyInt}
	case field.TypeUint8:
		t = &schema.IntegerType{T: mysql.TypeTinyInt, Unsigned: true}
	case field.TypeInt16:
		t = &schema.IntegerType{T: mysql.TypeSmallInt}
	case field.TypeUint16:
		t = &schema.IntegerType{T: mysql.TypeSmallInt, Unsigned: true}
	case field.TypeInt32:
		t = &schema.IntegerType{T: mysql.TypeInt}
	case field.TypeUint32:
		t = &schema.IntegerType{T: mysql.TypeInt, Unsigned: true}
	case field.TypeInt, field.TypeInt64:
		t = &schema.IntegerType{T: mysql.TypeBigInt}
	case field.TypeUint, field.TypeUint64:
		t = &schema.IntegerType{T: mysql.TypeBigInt, Unsigned: true}
	case field.TypeBytes:
		// Pick the smallest blob type that fits the requested size
		// (defaults to 64KB-1, i.e. a plain "blob").
		size := int64(math.MaxUint16)
		if c1.Size > 0 {
			size = c1.Size
		}
		switch {
		case size <= math.MaxUint8:
			t = &schema.BinaryType{T: mysql.TypeTinyBlob}
		case size <= math.MaxUint16:
			t = &schema.BinaryType{T: mysql.TypeBlob}
		case size < 1<<24:
			t = &schema.BinaryType{T: mysql.TypeMediumBlob}
		case size <= math.MaxUint32:
			t = &schema.BinaryType{T: mysql.TypeLongBlob}
		}
	case field.TypeJSON:
		t = &schema.JSONType{T: mysql.TypeJSON}
		// The native JSON type was introduced in MySQL 5.7.8.
		if compareVersions(d.version, "5.7.8") == -1 {
			t = &schema.BinaryType{T: mysql.TypeLongBlob}
		}
	case field.TypeString:
		size := c1.Size
		if size == 0 {
			size = d.defaultSize(c1)
		}
		switch {
		case c1.typ == "tinytext", c1.typ == "text":
			t = &schema.StringType{T: c1.typ}
		case size <= math.MaxUint16:
			t = &schema.StringType{T: mysql.TypeVarchar, Size: int(size)}
		case size == 1<<24-1:
			t = &schema.StringType{T: mysql.TypeMediumText}
		default:
			t = &schema.StringType{T: mysql.TypeLongText}
		}
	case field.TypeFloat32, field.TypeFloat64:
		t = &schema.FloatType{T: c1.scanTypeOr(mysql.TypeDouble)}
	case field.TypeTime:
		t = &schema.TimeType{T: c1.scanTypeOr(mysql.TypeTimestamp)}
		// In MariaDB or in MySQL < v8.0.2, the TIMESTAMP column has both `DEFAULT CURRENT_TIMESTAMP`
		// and `ON UPDATE CURRENT_TIMESTAMP` if neither is specified explicitly. this behavior is
		// suppressed if the column is defined with a `DEFAULT` clause or with the `NULL` attribute.
		if _, maria := d.mariadb(); maria || compareVersions(d.version, "8.0.2") == -1 && c1.Default == nil {
			c2.SetNull(c1.Attr == "")
		}
	case field.TypeEnum:
		t = &schema.EnumType{T: mysql.TypeEnum, Values: c1.Enums}
	case field.TypeUUID:
		// "CHAR(X) BINARY" is treated as "CHAR(X) COLLATE latin1_bin" in MySQL < 8,
		// and "COLLATE utf8mb4_bin" in MySQL >= 8. The table collation is already
		// utf8mb4_bin, so pin the column collation explicitly.
		t = &schema.StringType{T: mysql.TypeChar, Size: 36}
		c2.SetCollation("utf8mb4_bin")
	default:
		// Fall back to parsing the raw database type. Assign to the outer t
		// rather than declaring a new variable: a short declaration here would
		// shadow t, and the assignment below would then overwrite c2.Type.Type
		// with a nil schema.Type.
		pt, err := mysql.ParseType(strings.ToLower(c1.typ))
		if err != nil {
			return err
		}
		t = pt
	}
	c2.Type.Type = t
	return nil
}
// atUniqueC adds a unique index for the column to the Atlas table, unless
// an equivalent implicit index is already declared on the ent table.
// MySQL names the implicit index after the column, appending a numeric
// suffix when the name is taken (<c>, <c_2>, <c_3>, ...); such explicitly
// declared indexes are added in atIndexes instead.
func (d *MySQL) atUniqueC(t1 *Table, c1 *Column, t2 *schema.Table, c2 *schema.Column) {
	for i := range t1.Indexes {
		if idx := t1.Indexes[i]; idx.Unique && d.atImplicitIndexName(idx, c1) {
			return
		}
	}
	t2.AddIndexes(schema.NewUniqueIndex(c1.Name).AddColumns(c2))
}
// atIncrementC marks the Atlas column as auto-incrementing.
func (d *MySQL) atIncrementC(_ *schema.Table, c *schema.Column) {
	c.AddAttrs(&mysql.AutoIncrement{})
}
// atIncrementT sets the initial AUTO_INCREMENT value on the Atlas table.
func (d *MySQL) atIncrementT(t *schema.Table, v int64) {
	t.AddAttrs(&mysql.AutoIncrement{V: v})
}
// atImplicitIndexName reports whether idx carries the name MySQL generates
// for the implicit index of the unique column c1: either the column name
// itself, or the column name followed by "_<n>" with n > 1 when the plain
// name is already taken.
func (d *MySQL) atImplicitIndexName(idx *Index, c1 *Column) bool {
	if idx.Name == c1.Name {
		return true
	}
	if !strings.HasPrefix(idx.Name, c1.Name+"_") {
		return false
	}
	// Use TrimPrefix, not TrimLeft: TrimLeft treats its argument as a set
	// of characters and would also strip suffix digits that happen to occur
	// in the column name (e.g. column "col2" with index "col2_22").
	i, err := strconv.ParseInt(strings.TrimPrefix(idx.Name, c1.Name+"_"), 10, 64)
	return err == nil && i > 1
}
// atIndex converts the ent index idx1 into Atlas index parts on idx2,
// carrying over column prefix lengths and the index type annotation.
func (d *MySQL) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) error {
	prefixes := indexParts(idx1)
	for _, c1 := range idx1.Columns {
		c2, ok := t2.Column(c1.Name)
		if !ok {
			return fmt.Errorf("unexpected index %q column: %q", idx1.Name, c1.Name)
		}
		part := &schema.IndexPart{C: c2}
		if n, ok := prefixes[c1.Name]; ok {
			part.AddAttrs(&mysql.SubPart{Len: int(n)})
		}
		idx2.AddParts(part)
	}
	if t, ok := indexType(idx1, dialect.MySQL); ok {
		idx2.AddAttrs(&mysql.IndexType{T: t})
	}
	return nil
}
// indexType returns the index type configured on the annotation for the
// given dialect, preferring a per-dialect entry over the generic one.
func indexType(idx *Index, d string) (string, bool) {
	ant := idx.Annotation
	switch {
	case ant == nil:
		return "", false
	case ant.Types[d] != "":
		// Indexing a nil map yields the zero value, so no nil check is needed.
		return ant.Types[d], true
	case ant.Type != "":
		return ant.Type, true
	default:
		return "", false
	}
}
// atTypeRangeSQL returns the statement for inserting the type names into the
// global types table. The values are built into a fresh slice instead of
// mutating the caller's variadic argument in place, which was observable
// when the function was invoked with a spread of an existing slice.
func (MySQL) atTypeRangeSQL(ts ...string) string {
	values := make([]string, len(ts))
	for i, t := range ts {
		values[i] = fmt.Sprintf("('%s')", t)
	}
	return fmt.Sprintf("INSERT INTO `%s` (`type`) VALUES %s", TypeTable, strings.Join(values, ", "))
}