From 23cbf325c0d78846f63f3614ffd387237d4adf7c Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Sun, 27 Oct 2019 08:57:38 -0700 Subject: [PATCH] dialect/sql/schema: move MySQL logic to its own file Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/128 Reviewed By: alexsn Differential Revision: D18164283 fbshipit-source-id: da6b4d6df89ae4172d8f47a7790c4dac3a8ffe93 --- dialect/sql/schema/mysql.go | 231 +++++++++++++++++++++++++++++- dialect/sql/schema/schema.go | 213 --------------------------- dialect/sql/schema/schema_test.go | 17 --- 3 files changed, 224 insertions(+), 237 deletions(-) diff --git a/dialect/sql/schema/mysql.go b/dialect/sql/schema/mysql.go index 24412cf37..a36c52620 100644 --- a/dialect/sql/schema/mysql.go +++ b/dialect/sql/schema/mysql.go @@ -7,9 +7,14 @@ package schema import ( "context" "fmt" + "math" + "sort" + "strconv" + "strings" "github.com/facebookincubator/ent/dialect" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/schema/field" ) // MySQL is a mysql migration driver. @@ -62,7 +67,7 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, t := NewTable(name) for rows.Next() { c := &Column{} - if err := c.ScanMySQL(rows); err != nil { + if err := d.scanColumn(c, rows); err != nil { return nil, fmt.Errorf("mysql: %v", err) } if c.PrimaryKey() { @@ -94,8 +99,8 @@ func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, name string) ([]*Ind return nil, fmt.Errorf("mysql: reading index description %v", err) } defer rows.Close() - var idx Indexes - if err := idx.ScanMySQL(rows); err != nil { + idx, err := d.scanIndexes(rows) + if err != nil { return nil, fmt.Errorf("mysql: %v", err) } return idx, nil @@ -105,11 +110,108 @@ func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, name string, value return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", name, value), []interface{}{}, new(sql.Result)) } -func (d *MySQL) cType(c *Column) string { return c.MySQLType(d.version) } -func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder { return t.MySQL(d.version) } -func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder { return c.MySQL(d.version) } +// tBuilder returns the MySQL DSL query for table creation. +func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder { + b := sql.CreateTable(t.Name).IfNotExists() + for _, c := range t.Columns { + b.Column(d.addColumn(c)) + } + for _, pk := range t.PrimaryKey { + b.PrimaryKey(pk.Name) + } + // default charset / collation on MySQL table. + // columns can be override using the Charset / Collate fields. + b.Charset("utf8mb4").Collate("utf8mb4_bin") + return b +} + +// cType returns the MySQL string type for the given column. +func (d *MySQL) cType(c *Column) (t string) { + switch c.Type { + case field.TypeBool: + t = "boolean" + case field.TypeInt8: + t = "tinyint" + case field.TypeUint8: + t = "tinyint unsigned" + case field.TypeInt16: + t = "smallint" + case field.TypeUint16: + t = "smallint unsigned" + case field.TypeInt32: + t = "int" + case field.TypeUint32: + t = "int unsigned" + case field.TypeInt, field.TypeInt64: + t = "bigint" + case field.TypeUint, field.TypeUint64: + t = "bigint unsigned" + case field.TypeBytes: + size := int64(math.MaxUint16) + if c.Size > 0 { + size = c.Size + } + switch { + case size <= math.MaxUint8: + t = "tinyblob" + case size <= math.MaxUint16: + t = "blob" + case size < 1<<24: + t = "mediumblob" + case size <= math.MaxUint32: + t = "longblob" + } + case field.TypeJSON: + t = "json" + if compareVersions(d.version, "5.7.8") == -1 { + t = "longblob" + } + case field.TypeString: + size := c.Size + if size == 0 { + size = c.defaultSize(d.version) + } + if size <= math.MaxUint16 { + t = fmt.Sprintf("varchar(%d)", size) + } else { + t = "longtext" + } + case field.TypeFloat32, field.TypeFloat64: + t = "double" + case field.TypeTime: + t = "timestamp" + // in MySQL timestamp columns are `NOT NULL by default, and assigning NULL + // assigns the current_timestamp(). We avoid this if not set otherwise. + c.Nullable = true + case field.TypeEnum: + values := make([]string, len(c.Enums)) + for i, e := range c.Enums { + values[i] = fmt.Sprintf("'%s'", e) + } + sort.Strings(values) + t = fmt.Sprintf("enum(%s)", strings.Join(values, ", ")) + default: + panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name)) + } + return t +} + +// addColumn returns the DSL query for adding the given column to a table. +// The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable]. +func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder { + b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr) + c.unique(b) + if c.Increment { + b.Attr("AUTO_INCREMENT") + } + c.nullable(b) + c.defaultValue(b) + return b +} + +// alterColumn returns the DSL query for modifying the given column. func (d *MySQL) alterColumn(c *Column) []*sql.ColumnBuilder { - return []*sql.ColumnBuilder{c.MySQL(d.version)} + return []*sql.ColumnBuilder{d.addColumn(c)} } // addIndex returns the querying for adding an index to MySQL. @@ -121,3 +223,118 @@ func (d *MySQL) addIndex(i *Index, table string) *sql.IndexBuilder { func (d *MySQL) dropIndex(i *Index, table string) *sql.DropIndexBuilder { return i.DropBuilder(table) } + +// scanColumn scans the column information from MySQL column description. +func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error { + var ( + nullable sql.NullString + defaults sql.NullString + ) + if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &sql.NullString{}, &sql.NullString{}); err != nil { + return fmt.Errorf("scanning column description: %v", err) + } + c.Unique = c.UniqueKey() + if nullable.Valid { + c.Nullable = nullable.String == "YES" + } + switch parts := strings.FieldsFunc(c.typ, func(r rune) bool { + return r == '(' || r == ')' || r == ' ' || r == ',' + }); parts[0] { + case "int": + c.Type = field.TypeInt32 + case "smallint": + c.Type = field.TypeInt16 + if len(parts) == 3 { // smallint(5) unsigned. + c.Type = field.TypeUint16 + } + case "bigint": + c.Type = field.TypeInt64 + if len(parts) == 3 { // bigint(20) unsigned. + c.Type = field.TypeUint64 + } + case "tinyint": + size, err := strconv.Atoi(parts[1]) + if err != nil { + return fmt.Errorf("converting varchar size to int: %v", err) + } + switch { + case size == 1: + c.Type = field.TypeBool + case len(parts) == 3: // tinyint(3) unsigned. + c.Type = field.TypeUint8 + default: + c.Type = field.TypeInt8 + } + case "double": + c.Type = field.TypeFloat64 + case "timestamp", "datetime": + c.Type = field.TypeTime + case "tinyblob": + c.Size = math.MaxUint8 + c.Type = field.TypeBytes + case "blob": + c.Size = math.MaxUint16 + c.Type = field.TypeBytes + case "mediumblob": + c.Size = 1<<24 - 1 + c.Type = field.TypeBytes + case "longblob": + c.Size = math.MaxUint32 + c.Type = field.TypeBytes + case "varchar": + c.Type = field.TypeString + size, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return fmt.Errorf("converting varchar size to int: %v", err) + } + c.Size = size + case "longtext": + c.Size = math.MaxInt32 + c.Type = field.TypeString + case "json": + c.Type = field.TypeJSON + case "enum": + c.Type = field.TypeEnum + c.Enums = make([]string, len(parts)-1) + for i, e := range parts[1:] { + c.Enums[i] = strings.Trim(e, "'") + } + } + if defaults.Valid { + return c.ScanDefault(defaults.String) + } + return nil +} + +// scanIndexes scans sql.Rows into an Indexes list. The query for returning the rows, +// should return the following 4 columns: INDEX_NAME, COLUMN_NAME, NON_UNIQUE, SEQ_IN_INDEX. +// SEQ_IN_INDEX specifies the position of the column in the index columns. +func (d *MySQL) scanIndexes(rows *sql.Rows) (Indexes, error) { + var ( + i Indexes + names = make(map[string]*Index) + ) + for rows.Next() { + var ( + name string + column string + nonuniq bool + seqindex int + ) + if err := rows.Scan(&name, &column, &nonuniq, &seqindex); err != nil { + return nil, fmt.Errorf("scanning index description: %v", err) + } + // ignore primary keys. + if name == "PRIMARY" { + continue + } + idx, ok := names[name] + if !ok { + idx = &Index{Name: name, Unique: !nonuniq} + i = append(i, idx) + names[name] = idx + } + idx.columns = append(idx.columns, column) + } + return i, nil +} diff --git a/dialect/sql/schema/schema.go b/dialect/sql/schema/schema.go index 39ba093c7..da2575f17 100644 --- a/dialect/sql/schema/schema.go +++ b/dialect/sql/schema/schema.go @@ -7,8 +7,6 @@ package schema import ( "fmt" - "math" - "sort" "strconv" "strings" @@ -105,21 +103,6 @@ func (t *Table) setup() { } } -// MySQL returns the MySQL DSL query for table creation. -func (t *Table) MySQL(version string) *sql.TableBuilder { - b := sql.CreateTable(t.Name).IfNotExists() - for _, c := range t.Columns { - b.Column(c.MySQL(version)) - } - for _, pk := range t.PrimaryKey { - b.PrimaryKey(pk.Name) - } - // default charset / collation on MySQL table. - // columns can be override using the Charset / Collate fields. - b.Charset("utf8mb4").Collate("utf8mb4_bin") - return b -} - // SQLite returns the SQLite query for table creation. func (t *Table) SQLite() *sql.TableBuilder { b := sql.CreateTable(t.Name) @@ -201,19 +184,6 @@ func (c *Column) UniqueKey() bool { return c.Key == UniqueKey } // Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects. func (c *Column) PrimaryKey() bool { return c.Key == PrimaryKey } -// MySQL returns the MySQL DSL query for table creation. -// The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable]. -func (c *Column) MySQL(version string) *sql.ColumnBuilder { - b := sql.Column(c.Name).Type(c.MySQLType(version)).Attr(c.Attr) - c.unique(b) - if c.Increment { - b.Attr("AUTO_INCREMENT") - } - c.nullable(b) - c.defaultValue(b) - return b -} - // SQLite returns a SQLite DSL node for this column. func (c *Column) SQLite() *sql.ColumnBuilder { b := sql.Column(c.Name).Type(c.SQLiteType()).Attr(c.Attr) @@ -226,77 +196,6 @@ func (c *Column) SQLite() *sql.ColumnBuilder { return b } -// MySQLType returns the MySQL string type for this column. -func (c *Column) MySQLType(version string) (t string) { - switch c.Type { - case field.TypeBool: - t = "boolean" - case field.TypeInt8: - t = "tinyint" - case field.TypeUint8: - t = "tinyint unsigned" - case field.TypeInt16: - t = "smallint" - case field.TypeUint16: - t = "smallint unsigned" - case field.TypeInt32: - t = "int" - case field.TypeUint32: - t = "int unsigned" - case field.TypeInt, field.TypeInt64: - t = "bigint" - case field.TypeUint, field.TypeUint64: - t = "bigint unsigned" - case field.TypeBytes: - size := int64(math.MaxUint16) - if c.Size > 0 { - size = c.Size - } - switch { - case size <= math.MaxUint8: - t = "tinyblob" - case size <= math.MaxUint16: - t = "blob" - case size < 1<<24: - t = "mediumblob" - case size <= math.MaxUint32: - t = "longblob" - } - case field.TypeJSON: - t = "json" - if compareVersions(version, "5.7.8") == -1 { - t = "longblob" - } - case field.TypeString: - size := c.Size - if size == 0 { - size = c.defaultSize(version) - } - if size <= math.MaxUint16 { - t = fmt.Sprintf("varchar(%d)", size) - } else { - t = "longtext" - } - case field.TypeFloat32, field.TypeFloat64: - t = "double" - case field.TypeTime: - t = "timestamp" - // in MySQL timestamp columns are `NOT NULL by default, and assigning NULL - // assigns the current_timestamp(). We avoid this if not set otherwise. - c.Nullable = true - case field.TypeEnum: - values := make([]string, len(c.Enums)) - for i, e := range c.Enums { - values[i] = fmt.Sprintf("'%s'", e) - } - sort.Strings(values) - t = fmt.Sprintf("enum(%s)", strings.Join(values, ", ")) - default: - panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name)) - } - return t -} - // SQLiteType returns the SQLite string type for this column. func (c *Column) SQLiteType() (t string) { switch c.Type { @@ -327,88 +226,6 @@ func (c *Column) SQLiteType() (t string) { return t } -// ScanMySQL scans the information from MySQL column description. -func (c *Column) ScanMySQL(rows *sql.Rows) error { - var ( - nullable sql.NullString - defaults sql.NullString - ) - if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &sql.NullString{}, &sql.NullString{}); err != nil { - return fmt.Errorf("scanning column description: %v", err) - } - c.Unique = c.UniqueKey() - if nullable.Valid { - c.Nullable = nullable.String == "YES" - } - switch parts := strings.FieldsFunc(c.typ, func(r rune) bool { - return r == '(' || r == ')' || r == ' ' || r == ',' - }); parts[0] { - case "int": - c.Type = field.TypeInt32 - case "smallint": - c.Type = field.TypeInt16 - if len(parts) == 3 { // smallint(5) unsigned. - c.Type = field.TypeUint16 - } - case "bigint": - c.Type = field.TypeInt64 - if len(parts) == 3 { // bigint(20) unsigned. - c.Type = field.TypeUint64 - } - case "tinyint": - size, err := strconv.Atoi(parts[1]) - if err != nil { - return fmt.Errorf("converting varchar size to int: %v", err) - } - switch { - case size == 1: - c.Type = field.TypeBool - case len(parts) == 3: // tinyint(3) unsigned. - c.Type = field.TypeUint8 - default: - c.Type = field.TypeInt8 - } - case "double": - c.Type = field.TypeFloat64 - case "timestamp", "datetime": - c.Type = field.TypeTime - case "tinyblob": - c.Size = math.MaxUint8 - c.Type = field.TypeBytes - case "blob": - c.Size = math.MaxUint16 - c.Type = field.TypeBytes - case "mediumblob": - c.Size = 1<<24 - 1 - c.Type = field.TypeBytes - case "longblob": - c.Size = math.MaxUint32 - c.Type = field.TypeBytes - case "varchar": - c.Type = field.TypeString - size, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - return fmt.Errorf("converting varchar size to int: %v", err) - } - c.Size = size - case "longtext": - c.Size = math.MaxInt32 - c.Type = field.TypeString - case "json": - c.Type = field.TypeJSON - case "enum": - c.Type = field.TypeEnum - c.Enums = make([]string, len(parts)-1) - for i, e := range parts[1:] { - c.Enums[i] = strings.Trim(e, "'") - } - } - if defaults.Valid { - return c.ScanDefault(defaults.String) - } - return nil -} - // ConvertibleTo reports whether a column can be converted to the new column without altering its data. func (c *Column) ConvertibleTo(d *Column) bool { switch { @@ -632,36 +449,6 @@ func (i *Indexes) append(idx1 *Index) { *i = append(*i, idx1) } -// ScanMySQL scans sql.Rows into an Indexes list. The query for returning the rows, -// should return the following 4 columns: INDEX_NAME, COLUMN_NAME, NON_UNIQUE, SEQ_IN_INDEX. -// SEQ_IN_INDEX specifies the position of the column in the index columns. -func (i *Indexes) ScanMySQL(rows *sql.Rows) error { - names := make(map[string]*Index) - for rows.Next() { - var ( - name string - column string - nonuniq bool - seqindex int - ) - if err := rows.Scan(&name, &column, &nonuniq, &seqindex); err != nil { - return fmt.Errorf("scanning index description: %v", err) - } - // ignore primary keys. - if name == "PRIMARY" { - continue - } - idx, ok := names[name] - if !ok { - idx = &Index{Name: name, Unique: !nonuniq} - *i = append(*i, idx) - names[name] = idx - } - idx.columns = append(idx.columns, column) - } - return nil -} - // compareVersions returns an integer comparing the 2 versions. func compareVersions(v1, v2 string) int { pv1, ok1 := parseVersion(v1) diff --git a/dialect/sql/schema/schema_test.go b/dialect/sql/schema/schema_test.go index 9496fac10..2d77a932d 100644 --- a/dialect/sql/schema/schema_test.go +++ b/dialect/sql/schema/schema_test.go @@ -93,20 +93,3 @@ func TestColumn_ScanDefault(t *testing.T) { require.Equal(t, false, c1.Default) require.Error(t, c1.ScanDefault("foo")) } - -func TestColumn_MySQLType(t *testing.T) { - c1 := &Column{Type: field.TypeString, Unique: true} - require.Equal(t, "varchar(191)", c1.MySQLType("5.5")) - require.Equal(t, "varchar(191)", c1.MySQLType("5.6.1")) - require.Equal(t, "varchar(191)", c1.MySQLType("5.6.8")) - require.Equal(t, "varchar(255)", c1.MySQLType("5.7")) - require.Equal(t, "varchar(255)", c1.MySQLType("5.7.0")) - require.Equal(t, "varchar(255)", c1.MySQLType("5.7.26-log")) - require.Equal(t, "varchar(255)", c1.MySQLType("8-log")) - - c1 = &Column{Type: field.TypeJSON} - require.Equal(t, "json", c1.MySQLType("5.7.8")) - require.Equal(t, "json", c1.MySQLType("5.7.8-log")) - require.Equal(t, "longblob", c1.MySQLType("5.5")) - require.Equal(t, "longblob", c1.MySQLType("5.7")) -}