mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/schema: support mysql latest numeric type format (#328)
This commit is contained in:
@@ -321,33 +321,31 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
|
||||
if nullable.Valid {
|
||||
c.Nullable = nullable.String == "YES"
|
||||
}
|
||||
switch parts := strings.FieldsFunc(c.typ, func(r rune) bool {
|
||||
return r == '(' || r == ')' || r == ' ' || r == ','
|
||||
}); parts[0] {
|
||||
parts, size, unsigned, err := parseColumn(c.typ)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch parts[0] {
|
||||
case "int":
|
||||
c.Type = field.TypeInt32
|
||||
if len(parts) == 3 { // int(10) unsigned.
|
||||
if unsigned {
|
||||
c.Type = field.TypeUint32
|
||||
}
|
||||
case "smallint":
|
||||
c.Type = field.TypeInt16
|
||||
if len(parts) == 3 { // smallint(5) unsigned.
|
||||
if unsigned {
|
||||
c.Type = field.TypeUint16
|
||||
}
|
||||
case "bigint":
|
||||
c.Type = field.TypeInt64
|
||||
if len(parts) == 3 { // bigint(20) unsigned.
|
||||
if unsigned {
|
||||
c.Type = field.TypeUint64
|
||||
}
|
||||
case "tinyint":
|
||||
size, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting varchar size to int: %v", err)
|
||||
}
|
||||
switch {
|
||||
case size == 1:
|
||||
c.Type = field.TypeBool
|
||||
case len(parts) == 3: // tinyint(3) unsigned.
|
||||
case unsigned:
|
||||
c.Type = field.TypeUint8
|
||||
default:
|
||||
c.Type = field.TypeInt8
|
||||
@@ -370,17 +368,9 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
|
||||
c.Type = field.TypeBytes
|
||||
case "varbinary":
|
||||
c.Type = field.TypeBytes
|
||||
size, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting varbinary size to int: %v", err)
|
||||
}
|
||||
c.Size = size
|
||||
case "varchar":
|
||||
c.Type = field.TypeString
|
||||
size, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting varchar size to int: %v", err)
|
||||
}
|
||||
c.Size = size
|
||||
case "longtext":
|
||||
c.Size = math.MaxInt32
|
||||
@@ -394,10 +384,6 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
|
||||
c.Enums[i] = strings.Trim(e, "'")
|
||||
}
|
||||
case "char":
|
||||
size, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting char size to int: %v", err)
|
||||
}
|
||||
// UUID field has length of 36 characters (32 alphanumeric characters and 4 hyphens).
|
||||
if size != 36 {
|
||||
return fmt.Errorf("unknown char(%d) type (not a uuid)", size)
|
||||
@@ -479,6 +465,30 @@ func (d *MySQL) tableSchema() sql.Querier {
|
||||
return sql.Raw("(SELECT DATABASE())")
|
||||
}
|
||||
|
||||
// parseColumn returns column parts, size and signedness by mysql type
|
||||
func parseColumn(typ string) (parts []string, size int64, unsigned bool, err error) {
|
||||
switch parts = strings.FieldsFunc(typ, func(r rune) bool {
|
||||
return r == '(' || r == ')' || r == ' ' || r == ','
|
||||
}); parts[0] {
|
||||
case "int", "smallint", "bigint", "tinyint":
|
||||
switch {
|
||||
case len(parts) == 2 && parts[1] == "unsigned": // int unsigned
|
||||
unsigned = true
|
||||
case len(parts) == 3: // int(10) unsigned
|
||||
unsigned = true
|
||||
fallthrough
|
||||
case len(parts) == 2: // int(10)
|
||||
size, err = strconv.ParseInt(parts[1], 10, 0)
|
||||
}
|
||||
case "varbinary", "varchar", "char":
|
||||
size, err = strconv.ParseInt(parts[1], 10, 64)
|
||||
}
|
||||
if err != nil {
|
||||
return parts, size, unsigned, fmt.Errorf("converting %s size to int: %v", parts[0], err)
|
||||
}
|
||||
return parts, size, unsigned, nil
|
||||
}
|
||||
|
||||
// fkNames returns the foreign-key names of a column.
|
||||
func fkNames(ctx context.Context, tx dialect.Tx, table, column string) ([]string, error) {
|
||||
query, args := sql.Select("CONSTRAINT_NAME").From(sql.Table("INFORMATION_SCHEMA.KEY_COLUMN_USAGE").Unquote()).
|
||||
|
||||
@@ -169,6 +169,12 @@ func TestMySQL_Create(t *testing.T) {
|
||||
{Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32},
|
||||
{Name: "uuid", Type: field.TypeUUID, Nullable: true},
|
||||
{Name: "age", Type: field.TypeInt},
|
||||
{Name: "tiny", Type: field.TypeInt8},
|
||||
{Name: "tiny_unsigned", Type: field.TypeUint8},
|
||||
{Name: "small", Type: field.TypeInt16},
|
||||
{Name: "small_unsigned", Type: field.TypeUint16},
|
||||
{Name: "big", Type: field.TypeInt64},
|
||||
{Name: "big_unsigned", Type: field.TypeUint64},
|
||||
},
|
||||
PrimaryKey: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
@@ -188,7 +194,14 @@ func TestMySQL_Create(t *testing.T) {
|
||||
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "").
|
||||
AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "").
|
||||
AddRow("text", "longtext", "YES", "YES", "NULL", "", "", "").
|
||||
AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin"))
|
||||
AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin").
|
||||
// 8.0.19: new int column type formats
|
||||
AddRow("tiny", "tinyint", "NO", "YES", "NULL", "", "", "").
|
||||
AddRow("tiny_unsigned", "tinyint unsigned", "NO", "YES", "NULL", "", "", "").
|
||||
AddRow("small", "smallint", "NO", "YES", "NULL", "", "", "").
|
||||
AddRow("small_unsigned", "smallint unsigned", "NO", "YES", "NULL", "", "", "").
|
||||
AddRow("big", "bigint", "NO", "YES", "NULL", "", "", "").
|
||||
AddRow("big_unsigned", "bigint unsigned", "NO", "YES", "NULL", "", "", ""))
|
||||
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM INFORMATION_SCHEMA.STATISTICS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}).
|
||||
|
||||
Reference in New Issue
Block a user