dialect/sql/schema: support mysql latest numeric type format (#328)

This commit is contained in:
qystishere
2020-02-08 21:37:35 +08:00
committed by GitHub
parent afc8bd3eab
commit 374b5dd3b8
2 changed files with 47 additions and 24 deletions

View File

@@ -321,33 +321,31 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
if nullable.Valid {
c.Nullable = nullable.String == "YES"
}
switch parts := strings.FieldsFunc(c.typ, func(r rune) bool {
return r == '(' || r == ')' || r == ' ' || r == ','
}); parts[0] {
parts, size, unsigned, err := parseColumn(c.typ)
if err != nil {
return err
}
switch parts[0] {
case "int":
c.Type = field.TypeInt32
if len(parts) == 3 { // int(10) unsigned.
if unsigned {
c.Type = field.TypeUint32
}
case "smallint":
c.Type = field.TypeInt16
if len(parts) == 3 { // smallint(5) unsigned.
if unsigned {
c.Type = field.TypeUint16
}
case "bigint":
c.Type = field.TypeInt64
if len(parts) == 3 { // bigint(20) unsigned.
if unsigned {
c.Type = field.TypeUint64
}
case "tinyint":
size, err := strconv.Atoi(parts[1])
if err != nil {
return fmt.Errorf("converting varchar size to int: %v", err)
}
switch {
case size == 1:
c.Type = field.TypeBool
case len(parts) == 3: // tinyint(3) unsigned.
case unsigned:
c.Type = field.TypeUint8
default:
c.Type = field.TypeInt8
@@ -370,17 +368,9 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
c.Type = field.TypeBytes
case "varbinary":
c.Type = field.TypeBytes
size, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return fmt.Errorf("converting varbinary size to int: %v", err)
}
c.Size = size
case "varchar":
c.Type = field.TypeString
size, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return fmt.Errorf("converting varchar size to int: %v", err)
}
c.Size = size
case "longtext":
c.Size = math.MaxInt32
@@ -394,10 +384,6 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
c.Enums[i] = strings.Trim(e, "'")
}
case "char":
size, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return fmt.Errorf("converting char size to int: %v", err)
}
// UUID field has length of 36 characters (32 alphanumeric characters and 4 hyphens).
if size != 36 {
return fmt.Errorf("unknown char(%d) type (not a uuid)", size)
@@ -479,6 +465,30 @@ func (d *MySQL) tableSchema() sql.Querier {
return sql.Raw("(SELECT DATABASE())")
}
// parseColumn returns column parts, size and signedness by mysql type
func parseColumn(typ string) (parts []string, size int64, unsigned bool, err error) {
switch parts = strings.FieldsFunc(typ, func(r rune) bool {
return r == '(' || r == ')' || r == ' ' || r == ','
}); parts[0] {
case "int", "smallint", "bigint", "tinyint":
switch {
case len(parts) == 2 && parts[1] == "unsigned": // int unsigned
unsigned = true
case len(parts) == 3: // int(10) unsigned
unsigned = true
fallthrough
case len(parts) == 2: // int(10)
size, err = strconv.ParseInt(parts[1], 10, 0)
}
case "varbinary", "varchar", "char":
size, err = strconv.ParseInt(parts[1], 10, 64)
}
if err != nil {
return parts, size, unsigned, fmt.Errorf("converting %s size to int: %v", parts[0], err)
}
return parts, size, unsigned, nil
}
// fkNames returns the foreign-key names of a column.
func fkNames(ctx context.Context, tx dialect.Tx, table, column string) ([]string, error) {
query, args := sql.Select("CONSTRAINT_NAME").From(sql.Table("INFORMATION_SCHEMA.KEY_COLUMN_USAGE").Unquote()).

View File

@@ -169,6 +169,12 @@ func TestMySQL_Create(t *testing.T) {
{Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32},
{Name: "uuid", Type: field.TypeUUID, Nullable: true},
{Name: "age", Type: field.TypeInt},
{Name: "tiny", Type: field.TypeInt8},
{Name: "tiny_unsigned", Type: field.TypeUint8},
{Name: "small", Type: field.TypeInt16},
{Name: "small_unsigned", Type: field.TypeUint16},
{Name: "big", Type: field.TypeInt64},
{Name: "big_unsigned", Type: field.TypeUint64},
},
PrimaryKey: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
@@ -188,7 +194,14 @@ func TestMySQL_Create(t *testing.T) {
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "").
AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "").
AddRow("text", "longtext", "YES", "YES", "NULL", "", "", "").
AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin"))
AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin").
// 8.0.19: new int column type formats
AddRow("tiny", "tinyint", "NO", "YES", "NULL", "", "", "").
AddRow("tiny_unsigned", "tinyint unsigned", "NO", "YES", "NULL", "", "", "").
AddRow("small", "smallint", "NO", "YES", "NULL", "", "", "").
AddRow("small_unsigned", "smallint unsigned", "NO", "YES", "NULL", "", "", "").
AddRow("big", "bigint", "NO", "YES", "NULL", "", "", "").
AddRow("big_unsigned", "bigint unsigned", "NO", "YES", "NULL", "", "", ""))
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM INFORMATION_SCHEMA.STATISTICS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}).