mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/schema: move MySQL logic to its own file
Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/128 Reviewed By: alexsn Differential Revision: D18164283 fbshipit-source-id: da6b4d6df89ae4172d8f47a7790c4dac3a8ffe93
This commit is contained in:
committed by
Facebook Github Bot
parent
b2ac0fe2e7
commit
23cbf325c0
@@ -7,9 +7,14 @@ package schema
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/facebookincubator/ent/dialect"
|
||||
"github.com/facebookincubator/ent/dialect/sql"
|
||||
"github.com/facebookincubator/ent/schema/field"
|
||||
)
|
||||
|
||||
// MySQL is a mysql migration driver.
|
||||
@@ -62,7 +67,7 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
|
||||
t := NewTable(name)
|
||||
for rows.Next() {
|
||||
c := &Column{}
|
||||
if err := c.ScanMySQL(rows); err != nil {
|
||||
if err := d.scanColumn(c, rows); err != nil {
|
||||
return nil, fmt.Errorf("mysql: %v", err)
|
||||
}
|
||||
if c.PrimaryKey() {
|
||||
@@ -94,8 +99,8 @@ func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, name string) ([]*Ind
|
||||
return nil, fmt.Errorf("mysql: reading index description %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var idx Indexes
|
||||
if err := idx.ScanMySQL(rows); err != nil {
|
||||
idx, err := d.scanIndexes(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: %v", err)
|
||||
}
|
||||
return idx, nil
|
||||
@@ -105,11 +110,108 @@ func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, name string, value
|
||||
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", name, value), []interface{}{}, new(sql.Result))
|
||||
}
|
||||
|
||||
func (d *MySQL) cType(c *Column) string { return c.MySQLType(d.version) }
|
||||
func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder { return t.MySQL(d.version) }
|
||||
func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder { return c.MySQL(d.version) }
|
||||
// tBuilder returns the MySQL DSL query for table creation.
|
||||
func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder {
|
||||
b := sql.CreateTable(t.Name).IfNotExists()
|
||||
for _, c := range t.Columns {
|
||||
b.Column(d.addColumn(c))
|
||||
}
|
||||
for _, pk := range t.PrimaryKey {
|
||||
b.PrimaryKey(pk.Name)
|
||||
}
|
||||
// default charset / collation on MySQL table.
|
||||
// columns can be override using the Charset / Collate fields.
|
||||
b.Charset("utf8mb4").Collate("utf8mb4_bin")
|
||||
return b
|
||||
}
|
||||
|
||||
// cType returns the MySQL string type for the given column.
|
||||
func (d *MySQL) cType(c *Column) (t string) {
|
||||
switch c.Type {
|
||||
case field.TypeBool:
|
||||
t = "boolean"
|
||||
case field.TypeInt8:
|
||||
t = "tinyint"
|
||||
case field.TypeUint8:
|
||||
t = "tinyint unsigned"
|
||||
case field.TypeInt16:
|
||||
t = "smallint"
|
||||
case field.TypeUint16:
|
||||
t = "smallint unsigned"
|
||||
case field.TypeInt32:
|
||||
t = "int"
|
||||
case field.TypeUint32:
|
||||
t = "int unsigned"
|
||||
case field.TypeInt, field.TypeInt64:
|
||||
t = "bigint"
|
||||
case field.TypeUint, field.TypeUint64:
|
||||
t = "bigint unsigned"
|
||||
case field.TypeBytes:
|
||||
size := int64(math.MaxUint16)
|
||||
if c.Size > 0 {
|
||||
size = c.Size
|
||||
}
|
||||
switch {
|
||||
case size <= math.MaxUint8:
|
||||
t = "tinyblob"
|
||||
case size <= math.MaxUint16:
|
||||
t = "blob"
|
||||
case size < 1<<24:
|
||||
t = "mediumblob"
|
||||
case size <= math.MaxUint32:
|
||||
t = "longblob"
|
||||
}
|
||||
case field.TypeJSON:
|
||||
t = "json"
|
||||
if compareVersions(d.version, "5.7.8") == -1 {
|
||||
t = "longblob"
|
||||
}
|
||||
case field.TypeString:
|
||||
size := c.Size
|
||||
if size == 0 {
|
||||
size = c.defaultSize(d.version)
|
||||
}
|
||||
if size <= math.MaxUint16 {
|
||||
t = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
t = "longtext"
|
||||
}
|
||||
case field.TypeFloat32, field.TypeFloat64:
|
||||
t = "double"
|
||||
case field.TypeTime:
|
||||
t = "timestamp"
|
||||
// in MySQL timestamp columns are `NOT NULL by default, and assigning NULL
|
||||
// assigns the current_timestamp(). We avoid this if not set otherwise.
|
||||
c.Nullable = true
|
||||
case field.TypeEnum:
|
||||
values := make([]string, len(c.Enums))
|
||||
for i, e := range c.Enums {
|
||||
values[i] = fmt.Sprintf("'%s'", e)
|
||||
}
|
||||
sort.Strings(values)
|
||||
t = fmt.Sprintf("enum(%s)", strings.Join(values, ", "))
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// addColumn returns the DSL query for adding the given column to a table.
|
||||
// The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable].
|
||||
func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder {
|
||||
b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr)
|
||||
c.unique(b)
|
||||
if c.Increment {
|
||||
b.Attr("AUTO_INCREMENT")
|
||||
}
|
||||
c.nullable(b)
|
||||
c.defaultValue(b)
|
||||
return b
|
||||
}
|
||||
|
||||
// alterColumn returns the DSL query for modifying the given column.
|
||||
func (d *MySQL) alterColumn(c *Column) []*sql.ColumnBuilder {
|
||||
return []*sql.ColumnBuilder{c.MySQL(d.version)}
|
||||
return []*sql.ColumnBuilder{d.addColumn(c)}
|
||||
}
|
||||
|
||||
// addIndex returns the querying for adding an index to MySQL.
|
||||
@@ -121,3 +223,118 @@ func (d *MySQL) addIndex(i *Index, table string) *sql.IndexBuilder {
|
||||
func (d *MySQL) dropIndex(i *Index, table string) *sql.DropIndexBuilder {
|
||||
return i.DropBuilder(table)
|
||||
}
|
||||
|
||||
// scanColumn scans the column information from MySQL column description.
|
||||
func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
|
||||
var (
|
||||
nullable sql.NullString
|
||||
defaults sql.NullString
|
||||
)
|
||||
if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &sql.NullString{}, &sql.NullString{}); err != nil {
|
||||
return fmt.Errorf("scanning column description: %v", err)
|
||||
}
|
||||
c.Unique = c.UniqueKey()
|
||||
if nullable.Valid {
|
||||
c.Nullable = nullable.String == "YES"
|
||||
}
|
||||
switch parts := strings.FieldsFunc(c.typ, func(r rune) bool {
|
||||
return r == '(' || r == ')' || r == ' ' || r == ','
|
||||
}); parts[0] {
|
||||
case "int":
|
||||
c.Type = field.TypeInt32
|
||||
case "smallint":
|
||||
c.Type = field.TypeInt16
|
||||
if len(parts) == 3 { // smallint(5) unsigned.
|
||||
c.Type = field.TypeUint16
|
||||
}
|
||||
case "bigint":
|
||||
c.Type = field.TypeInt64
|
||||
if len(parts) == 3 { // bigint(20) unsigned.
|
||||
c.Type = field.TypeUint64
|
||||
}
|
||||
case "tinyint":
|
||||
size, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting varchar size to int: %v", err)
|
||||
}
|
||||
switch {
|
||||
case size == 1:
|
||||
c.Type = field.TypeBool
|
||||
case len(parts) == 3: // tinyint(3) unsigned.
|
||||
c.Type = field.TypeUint8
|
||||
default:
|
||||
c.Type = field.TypeInt8
|
||||
}
|
||||
case "double":
|
||||
c.Type = field.TypeFloat64
|
||||
case "timestamp", "datetime":
|
||||
c.Type = field.TypeTime
|
||||
case "tinyblob":
|
||||
c.Size = math.MaxUint8
|
||||
c.Type = field.TypeBytes
|
||||
case "blob":
|
||||
c.Size = math.MaxUint16
|
||||
c.Type = field.TypeBytes
|
||||
case "mediumblob":
|
||||
c.Size = 1<<24 - 1
|
||||
c.Type = field.TypeBytes
|
||||
case "longblob":
|
||||
c.Size = math.MaxUint32
|
||||
c.Type = field.TypeBytes
|
||||
case "varchar":
|
||||
c.Type = field.TypeString
|
||||
size, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting varchar size to int: %v", err)
|
||||
}
|
||||
c.Size = size
|
||||
case "longtext":
|
||||
c.Size = math.MaxInt32
|
||||
c.Type = field.TypeString
|
||||
case "json":
|
||||
c.Type = field.TypeJSON
|
||||
case "enum":
|
||||
c.Type = field.TypeEnum
|
||||
c.Enums = make([]string, len(parts)-1)
|
||||
for i, e := range parts[1:] {
|
||||
c.Enums[i] = strings.Trim(e, "'")
|
||||
}
|
||||
}
|
||||
if defaults.Valid {
|
||||
return c.ScanDefault(defaults.String)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scanIndexes scans sql.Rows into an Indexes list. The query for returning the rows,
|
||||
// should return the following 4 columns: INDEX_NAME, COLUMN_NAME, NON_UNIQUE, SEQ_IN_INDEX.
|
||||
// SEQ_IN_INDEX specifies the position of the column in the index columns.
|
||||
func (d *MySQL) scanIndexes(rows *sql.Rows) (Indexes, error) {
|
||||
var (
|
||||
i Indexes
|
||||
names = make(map[string]*Index)
|
||||
)
|
||||
for rows.Next() {
|
||||
var (
|
||||
name string
|
||||
column string
|
||||
nonuniq bool
|
||||
seqindex int
|
||||
)
|
||||
if err := rows.Scan(&name, &column, &nonuniq, &seqindex); err != nil {
|
||||
return nil, fmt.Errorf("scanning index description: %v", err)
|
||||
}
|
||||
// ignore primary keys.
|
||||
if name == "PRIMARY" {
|
||||
continue
|
||||
}
|
||||
idx, ok := names[name]
|
||||
if !ok {
|
||||
idx = &Index{Name: name, Unique: !nonuniq}
|
||||
i = append(i, idx)
|
||||
names[name] = idx
|
||||
}
|
||||
idx.columns = append(idx.columns, column)
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
@@ -7,8 +7,6 @@ package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -105,21 +103,6 @@ func (t *Table) setup() {
|
||||
}
|
||||
}
|
||||
|
||||
// MySQL returns the MySQL DSL query for table creation.
|
||||
func (t *Table) MySQL(version string) *sql.TableBuilder {
|
||||
b := sql.CreateTable(t.Name).IfNotExists()
|
||||
for _, c := range t.Columns {
|
||||
b.Column(c.MySQL(version))
|
||||
}
|
||||
for _, pk := range t.PrimaryKey {
|
||||
b.PrimaryKey(pk.Name)
|
||||
}
|
||||
// default charset / collation on MySQL table.
|
||||
// columns can be override using the Charset / Collate fields.
|
||||
b.Charset("utf8mb4").Collate("utf8mb4_bin")
|
||||
return b
|
||||
}
|
||||
|
||||
// SQLite returns the SQLite query for table creation.
|
||||
func (t *Table) SQLite() *sql.TableBuilder {
|
||||
b := sql.CreateTable(t.Name)
|
||||
@@ -201,19 +184,6 @@ func (c *Column) UniqueKey() bool { return c.Key == UniqueKey }
|
||||
// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects.
|
||||
func (c *Column) PrimaryKey() bool { return c.Key == PrimaryKey }
|
||||
|
||||
// MySQL returns the MySQL DSL query for table creation.
|
||||
// The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable].
|
||||
func (c *Column) MySQL(version string) *sql.ColumnBuilder {
|
||||
b := sql.Column(c.Name).Type(c.MySQLType(version)).Attr(c.Attr)
|
||||
c.unique(b)
|
||||
if c.Increment {
|
||||
b.Attr("AUTO_INCREMENT")
|
||||
}
|
||||
c.nullable(b)
|
||||
c.defaultValue(b)
|
||||
return b
|
||||
}
|
||||
|
||||
// SQLite returns a SQLite DSL node for this column.
|
||||
func (c *Column) SQLite() *sql.ColumnBuilder {
|
||||
b := sql.Column(c.Name).Type(c.SQLiteType()).Attr(c.Attr)
|
||||
@@ -226,77 +196,6 @@ func (c *Column) SQLite() *sql.ColumnBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// MySQLType returns the MySQL string type for this column.
|
||||
func (c *Column) MySQLType(version string) (t string) {
|
||||
switch c.Type {
|
||||
case field.TypeBool:
|
||||
t = "boolean"
|
||||
case field.TypeInt8:
|
||||
t = "tinyint"
|
||||
case field.TypeUint8:
|
||||
t = "tinyint unsigned"
|
||||
case field.TypeInt16:
|
||||
t = "smallint"
|
||||
case field.TypeUint16:
|
||||
t = "smallint unsigned"
|
||||
case field.TypeInt32:
|
||||
t = "int"
|
||||
case field.TypeUint32:
|
||||
t = "int unsigned"
|
||||
case field.TypeInt, field.TypeInt64:
|
||||
t = "bigint"
|
||||
case field.TypeUint, field.TypeUint64:
|
||||
t = "bigint unsigned"
|
||||
case field.TypeBytes:
|
||||
size := int64(math.MaxUint16)
|
||||
if c.Size > 0 {
|
||||
size = c.Size
|
||||
}
|
||||
switch {
|
||||
case size <= math.MaxUint8:
|
||||
t = "tinyblob"
|
||||
case size <= math.MaxUint16:
|
||||
t = "blob"
|
||||
case size < 1<<24:
|
||||
t = "mediumblob"
|
||||
case size <= math.MaxUint32:
|
||||
t = "longblob"
|
||||
}
|
||||
case field.TypeJSON:
|
||||
t = "json"
|
||||
if compareVersions(version, "5.7.8") == -1 {
|
||||
t = "longblob"
|
||||
}
|
||||
case field.TypeString:
|
||||
size := c.Size
|
||||
if size == 0 {
|
||||
size = c.defaultSize(version)
|
||||
}
|
||||
if size <= math.MaxUint16 {
|
||||
t = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
t = "longtext"
|
||||
}
|
||||
case field.TypeFloat32, field.TypeFloat64:
|
||||
t = "double"
|
||||
case field.TypeTime:
|
||||
t = "timestamp"
|
||||
// in MySQL timestamp columns are `NOT NULL by default, and assigning NULL
|
||||
// assigns the current_timestamp(). We avoid this if not set otherwise.
|
||||
c.Nullable = true
|
||||
case field.TypeEnum:
|
||||
values := make([]string, len(c.Enums))
|
||||
for i, e := range c.Enums {
|
||||
values[i] = fmt.Sprintf("'%s'", e)
|
||||
}
|
||||
sort.Strings(values)
|
||||
t = fmt.Sprintf("enum(%s)", strings.Join(values, ", "))
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// SQLiteType returns the SQLite string type for this column.
|
||||
func (c *Column) SQLiteType() (t string) {
|
||||
switch c.Type {
|
||||
@@ -327,88 +226,6 @@ func (c *Column) SQLiteType() (t string) {
|
||||
return t
|
||||
}
|
||||
|
||||
// ScanMySQL scans the information from MySQL column description.
|
||||
func (c *Column) ScanMySQL(rows *sql.Rows) error {
|
||||
var (
|
||||
nullable sql.NullString
|
||||
defaults sql.NullString
|
||||
)
|
||||
if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &sql.NullString{}, &sql.NullString{}); err != nil {
|
||||
return fmt.Errorf("scanning column description: %v", err)
|
||||
}
|
||||
c.Unique = c.UniqueKey()
|
||||
if nullable.Valid {
|
||||
c.Nullable = nullable.String == "YES"
|
||||
}
|
||||
switch parts := strings.FieldsFunc(c.typ, func(r rune) bool {
|
||||
return r == '(' || r == ')' || r == ' ' || r == ','
|
||||
}); parts[0] {
|
||||
case "int":
|
||||
c.Type = field.TypeInt32
|
||||
case "smallint":
|
||||
c.Type = field.TypeInt16
|
||||
if len(parts) == 3 { // smallint(5) unsigned.
|
||||
c.Type = field.TypeUint16
|
||||
}
|
||||
case "bigint":
|
||||
c.Type = field.TypeInt64
|
||||
if len(parts) == 3 { // bigint(20) unsigned.
|
||||
c.Type = field.TypeUint64
|
||||
}
|
||||
case "tinyint":
|
||||
size, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting varchar size to int: %v", err)
|
||||
}
|
||||
switch {
|
||||
case size == 1:
|
||||
c.Type = field.TypeBool
|
||||
case len(parts) == 3: // tinyint(3) unsigned.
|
||||
c.Type = field.TypeUint8
|
||||
default:
|
||||
c.Type = field.TypeInt8
|
||||
}
|
||||
case "double":
|
||||
c.Type = field.TypeFloat64
|
||||
case "timestamp", "datetime":
|
||||
c.Type = field.TypeTime
|
||||
case "tinyblob":
|
||||
c.Size = math.MaxUint8
|
||||
c.Type = field.TypeBytes
|
||||
case "blob":
|
||||
c.Size = math.MaxUint16
|
||||
c.Type = field.TypeBytes
|
||||
case "mediumblob":
|
||||
c.Size = 1<<24 - 1
|
||||
c.Type = field.TypeBytes
|
||||
case "longblob":
|
||||
c.Size = math.MaxUint32
|
||||
c.Type = field.TypeBytes
|
||||
case "varchar":
|
||||
c.Type = field.TypeString
|
||||
size, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting varchar size to int: %v", err)
|
||||
}
|
||||
c.Size = size
|
||||
case "longtext":
|
||||
c.Size = math.MaxInt32
|
||||
c.Type = field.TypeString
|
||||
case "json":
|
||||
c.Type = field.TypeJSON
|
||||
case "enum":
|
||||
c.Type = field.TypeEnum
|
||||
c.Enums = make([]string, len(parts)-1)
|
||||
for i, e := range parts[1:] {
|
||||
c.Enums[i] = strings.Trim(e, "'")
|
||||
}
|
||||
}
|
||||
if defaults.Valid {
|
||||
return c.ScanDefault(defaults.String)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConvertibleTo reports whether a column can be converted to the new column without altering its data.
|
||||
func (c *Column) ConvertibleTo(d *Column) bool {
|
||||
switch {
|
||||
@@ -632,36 +449,6 @@ func (i *Indexes) append(idx1 *Index) {
|
||||
*i = append(*i, idx1)
|
||||
}
|
||||
|
||||
// ScanMySQL scans sql.Rows into an Indexes list. The query for returning the rows,
|
||||
// should return the following 4 columns: INDEX_NAME, COLUMN_NAME, NON_UNIQUE, SEQ_IN_INDEX.
|
||||
// SEQ_IN_INDEX specifies the position of the column in the index columns.
|
||||
func (i *Indexes) ScanMySQL(rows *sql.Rows) error {
|
||||
names := make(map[string]*Index)
|
||||
for rows.Next() {
|
||||
var (
|
||||
name string
|
||||
column string
|
||||
nonuniq bool
|
||||
seqindex int
|
||||
)
|
||||
if err := rows.Scan(&name, &column, &nonuniq, &seqindex); err != nil {
|
||||
return fmt.Errorf("scanning index description: %v", err)
|
||||
}
|
||||
// ignore primary keys.
|
||||
if name == "PRIMARY" {
|
||||
continue
|
||||
}
|
||||
idx, ok := names[name]
|
||||
if !ok {
|
||||
idx = &Index{Name: name, Unique: !nonuniq}
|
||||
*i = append(*i, idx)
|
||||
names[name] = idx
|
||||
}
|
||||
idx.columns = append(idx.columns, column)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// compareVersions returns an integer comparing the 2 versions.
|
||||
func compareVersions(v1, v2 string) int {
|
||||
pv1, ok1 := parseVersion(v1)
|
||||
|
||||
@@ -93,20 +93,3 @@ func TestColumn_ScanDefault(t *testing.T) {
|
||||
require.Equal(t, false, c1.Default)
|
||||
require.Error(t, c1.ScanDefault("foo"))
|
||||
}
|
||||
|
||||
func TestColumn_MySQLType(t *testing.T) {
|
||||
c1 := &Column{Type: field.TypeString, Unique: true}
|
||||
require.Equal(t, "varchar(191)", c1.MySQLType("5.5"))
|
||||
require.Equal(t, "varchar(191)", c1.MySQLType("5.6.1"))
|
||||
require.Equal(t, "varchar(191)", c1.MySQLType("5.6.8"))
|
||||
require.Equal(t, "varchar(255)", c1.MySQLType("5.7"))
|
||||
require.Equal(t, "varchar(255)", c1.MySQLType("5.7.0"))
|
||||
require.Equal(t, "varchar(255)", c1.MySQLType("5.7.26-log"))
|
||||
require.Equal(t, "varchar(255)", c1.MySQLType("8-log"))
|
||||
|
||||
c1 = &Column{Type: field.TypeJSON}
|
||||
require.Equal(t, "json", c1.MySQLType("5.7.8"))
|
||||
require.Equal(t, "json", c1.MySQLType("5.7.8-log"))
|
||||
require.Equal(t, "longblob", c1.MySQLType("5.5"))
|
||||
require.Equal(t, "longblob", c1.MySQLType("5.7"))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user