From 374b5dd3b85df7efbfc40a0b3c6244f987713027 Mon Sep 17 00:00:00 2001 From: qystishere Date: Sat, 8 Feb 2020 21:37:35 +0800 Subject: [PATCH] dialect/sql/schema: support mysql latest numeric type format (#328) --- dialect/sql/schema/mysql.go | 56 +++++++++++++++++++------------- dialect/sql/schema/mysql_test.go | 15 ++++++++- 2 files changed, 47 insertions(+), 24 deletions(-) diff --git a/dialect/sql/schema/mysql.go b/dialect/sql/schema/mysql.go index 86898db89..9ac433072 100644 --- a/dialect/sql/schema/mysql.go +++ b/dialect/sql/schema/mysql.go @@ -321,33 +321,31 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error { if nullable.Valid { c.Nullable = nullable.String == "YES" } - switch parts := strings.FieldsFunc(c.typ, func(r rune) bool { - return r == '(' || r == ')' || r == ' ' || r == ',' - }); parts[0] { + parts, size, unsigned, err := parseColumn(c.typ) + if err != nil { + return err + } + switch parts[0] { case "int": c.Type = field.TypeInt32 - if len(parts) == 3 { // int(10) unsigned. + if unsigned { c.Type = field.TypeUint32 } case "smallint": c.Type = field.TypeInt16 - if len(parts) == 3 { // smallint(5) unsigned. + if unsigned { c.Type = field.TypeUint16 } case "bigint": c.Type = field.TypeInt64 - if len(parts) == 3 { // bigint(20) unsigned. + if unsigned { c.Type = field.TypeUint64 } case "tinyint": - size, err := strconv.Atoi(parts[1]) - if err != nil { - return fmt.Errorf("converting varchar size to int: %v", err) - } switch { case size == 1: c.Type = field.TypeBool - case len(parts) == 3: // tinyint(3) unsigned. + case unsigned: c.Type = field.TypeUint8 default: c.Type = field.TypeInt8 @@ -370,17 +368,9 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error { c.Type = field.TypeBytes case "varbinary": c.Type = field.TypeBytes - size, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - return fmt.Errorf("converting varbinary size to int: %v", err) - } c.Size = size case "varchar": c.Type = field.TypeString - size, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - return fmt.Errorf("converting varchar size to int: %v", err) - } c.Size = size case "longtext": c.Size = math.MaxInt32 @@ -394,10 +384,6 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error { c.Enums[i] = strings.Trim(e, "'") } case "char": - size, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - return fmt.Errorf("converting char size to int: %v", err) - } // UUID field has length of 36 characters (32 alphanumeric characters and 4 hyphens). if size != 36 { return fmt.Errorf("unknown char(%d) type (not a uuid)", size) @@ -479,6 +465,30 @@ func (d *MySQL) tableSchema() sql.Querier { return sql.Raw("(SELECT DATABASE())") } +// parseColumn returns column parts, size and signedness by mysql type +func parseColumn(typ string) (parts []string, size int64, unsigned bool, err error) { + switch parts = strings.FieldsFunc(typ, func(r rune) bool { + return r == '(' || r == ')' || r == ' ' || r == ',' + }); parts[0] { + case "int", "smallint", "bigint", "tinyint": + switch { + case len(parts) == 2 && parts[1] == "unsigned": // int unsigned + unsigned = true + case len(parts) == 3: // int(10) unsigned + unsigned = true + fallthrough + case len(parts) == 2: // int(10) + size, err = strconv.ParseInt(parts[1], 10, 0) + } + case "varbinary", "varchar", "char": + size, err = strconv.ParseInt(parts[1], 10, 64) + } + if err != nil { + return parts, size, unsigned, fmt.Errorf("converting %s size to int: %v", parts[0], err) + } + return parts, size, unsigned, nil +} + // fkNames returns the foreign-key names of a column. func fkNames(ctx context.Context, tx dialect.Tx, table, column string) ([]string, error) { query, args := sql.Select("CONSTRAINT_NAME").From(sql.Table("INFORMATION_SCHEMA.KEY_COLUMN_USAGE").Unquote()). diff --git a/dialect/sql/schema/mysql_test.go b/dialect/sql/schema/mysql_test.go index 886449898..8ac08ecc0 100644 --- a/dialect/sql/schema/mysql_test.go +++ b/dialect/sql/schema/mysql_test.go @@ -169,6 +169,12 @@ func TestMySQL_Create(t *testing.T) { {Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "age", Type: field.TypeInt}, + {Name: "tiny", Type: field.TypeInt8}, + {Name: "tiny_unsigned", Type: field.TypeUint8}, + {Name: "small", Type: field.TypeInt16}, + {Name: "small_unsigned", Type: field.TypeUint16}, + {Name: "big", Type: field.TypeInt64}, + {Name: "big_unsigned", Type: field.TypeUint64}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, @@ -188,7 +194,14 @@ func TestMySQL_Create(t *testing.T) { AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", ""). AddRow("text", "longtext", "YES", "YES", "NULL", "", "", ""). - AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin")) + AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin"). + // 8.0.19: new int column type formats + AddRow("tiny", "tinyint", "NO", "YES", "NULL", "", "", ""). + AddRow("tiny_unsigned", "tinyint unsigned", "NO", "YES", "NULL", "", "", ""). + AddRow("small", "smallint", "NO", "YES", "NULL", "", "", ""). + AddRow("small_unsigned", "smallint unsigned", "NO", "YES", "NULL", "", "", ""). + AddRow("big", "bigint", "NO", "YES", "NULL", "", "", ""). + AddRow("big_unsigned", "bigint unsigned", "NO", "YES", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM INFORMATION_SCHEMA.STATISTICS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}).