add support for all int types in schema

Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/8

Reviewed By: alexsn

Differential Revision: D16131257

fbshipit-source-id: 7b362740053c684f70ec69188b2fcee898605436
This commit is contained in:
Ariel Mashraki
2019-07-10 08:30:42 -07:00
committed by Facebook Github Bot
parent 1cf3c3a117
commit 4b176495e8
60 changed files with 10170 additions and 101 deletions

View File

@@ -16,7 +16,11 @@ type MySQL struct {
}
// Create creates all schema resources in the database. It works in an "append-only"
// mode, which means, it won't delete or change any existing resource in the database.
// mode, which means, it only create tables, append column to tables or modifying column type.
//
// Column can be modified by turning into a NULL from NOT NULL, or having a type conversion not
// resulting data altering. From example, changing varchar(255) to varchar(120) is invalid, but
// changing varchar(120) to varchar(255) is valid. For more info, see the convert function below.
func (d *MySQL) Create(ctx context.Context, tables ...*Table) error {
tx, err := d.Tx(ctx)
if err != nil {
@@ -42,7 +46,7 @@ func (d *MySQL) create(ctx context.Context, tx dialect.Tx, tables ...*Table) err
if err != nil {
return err
}
change, err := changeSet(curr, t)
change, err := changeSet(curr, t, version)
if err != nil {
return err
}
@@ -178,7 +182,7 @@ type changes struct {
// changeSet returns a changes object to be applied on existing table.
// It fails if one of the changes is invalid.
func changeSet(curr, new *Table) (*changes, error) {
func changeSet(curr, new *Table, version string) (*changes, error) {
change := &changes{}
// pks.
if len(curr.PrimaryKey) != len(new.PrimaryKey) {
@@ -196,10 +200,13 @@ func changeSet(curr, new *Table) (*changes, error) {
switch c2, ok := curr.column(c1.Name); {
case !ok:
change.add = append(change.add, c1)
case c1.Type != c2.Type:
return nil, fmt.Errorf("changing column type for %q is invalid (%s != %s)", c1.Name, c1.Type, c2.Type)
case c1.Unique != c2.Unique:
return nil, fmt.Errorf("changing column cardinality for %q is invalid", c1.Name)
case c1.MySQLType(version) != c2.MySQLType(version):
if !c2.ConvertibleTo(c1) {
return nil, fmt.Errorf("changing column type for %q is invalid (%s != %s)", c1.Name, c1.MySQLType(version), c2.MySQLType(version))
}
fallthrough
case c1.Charset != "" && c1.Charset != c2.Charset || c1.Collation != "" && c1.Charset != c2.Collation:
change.modify = append(change.modify, c1)
}

View File

@@ -59,7 +59,7 @@ func TestMySQL_Create(t *testing.T) {
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `name` varchar(255) CHARSET utf8 NULL, `age` int, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT, `name` varchar(255) CHARSET utf8 NULL, `age` bigint, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
@@ -86,7 +86,7 @@ func TestMySQL_Create(t *testing.T) {
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `age` int, `name` varchar(191) UNIQUE, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT, `age` bigint, `name` varchar(191) UNIQUE, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
@@ -134,12 +134,12 @@ func TestMySQL_Create(t *testing.T) {
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `name` varchar(255) NULL, `created_at` timestamp NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT, `name` varchar(255) NULL, `created_at` timestamp NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("pets").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`(`id` int AUTO_INCREMENT, `name` varchar(255), `owner_id` int, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`(`id` bigint AUTO_INCREMENT, `name` varchar(255), `owner_id` bigint, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")).
WithArgs("FOREIGN KEY", "pets_owner").
@@ -174,9 +174,9 @@ func TestMySQL_Create(t *testing.T) {
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment", "", "").
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "").
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", ""))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` int")).
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` bigint")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
@@ -206,9 +206,9 @@ func TestMySQL_Create(t *testing.T) {
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment", "", "").
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "").
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "").
AddRow("age", "int(11)", "NO", "NO", "NULL", "", "", ""))
AddRow("age", "bigint(20)", "NO", "NO", "NULL", "", "", ""))
mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `name` varchar(255) CHARSET utf8 NULL")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
@@ -250,9 +250,9 @@ func TestMySQL_Create(t *testing.T) {
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment", "", "").
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "").
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", ""))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` int")).
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` bigint")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))

View File

@@ -158,14 +158,18 @@ func (c *Column) MySQLType(version string) (t string) {
t = "tinyint"
case field.TypeUint8:
t = "tinyint unsigned"
case field.TypeInt64:
t = "bigint"
case field.TypeUint64:
t = "bigint unsigned"
case field.TypeInt, field.TypeInt16, field.TypeInt32:
case field.TypeInt16:
t = "smallint"
case field.TypeUint16:
t = "smallint unsigned"
case field.TypeInt32:
t = "int"
case field.TypeUint, field.TypeUint16, field.TypeUint32:
case field.TypeUint32:
t = "int unsigned"
case field.TypeInt, field.TypeInt64:
t = "bigint"
case field.TypeUint, field.TypeUint64:
t = "bigint unsigned"
case field.TypeString:
size := c.Size
if size == 0 {
@@ -187,7 +191,7 @@ func (c *Column) MySQLType(version string) (t string) {
c.Nullable = &nullable
}
default:
panic("unsupported type " + c.Type.String())
panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name))
}
return t
}
@@ -241,11 +245,17 @@ func (c *Column) ScanMySQL(rows *sql.Rows) error {
return r == '(' || r == ')' || r == ' '
}); parts[0] {
case "int":
c.Type = field.TypeInt
case "double":
c.Type = field.TypeFloat64
case "timestamp":
c.Type = field.TypeTime
c.Type = field.TypeInt32
case "smallint":
c.Type = field.TypeInt16
if len(parts) == 3 { // smallint(5) unsigned.
c.Type = field.TypeUint16
}
case "bigint":
c.Type = field.TypeInt64
if len(parts) == 3 { // bigint(20) unsigned.
c.Type = field.TypeUint64
}
case "tinyint":
size, err := strconv.Atoi(parts[1])
if err != nil {
@@ -259,6 +269,10 @@ func (c *Column) ScanMySQL(rows *sql.Rows) error {
default:
c.Type = field.TypeInt8
}
case "double":
c.Type = field.TypeFloat64
case "timestamp":
c.Type = field.TypeTime
case "varchar":
c.Type = field.TypeString
size, err := strconv.Atoi(parts[1])
@@ -270,6 +284,29 @@ func (c *Column) ScanMySQL(rows *sql.Rows) error {
return nil
}
// ConvertibleTo reports whether a column can be converted to the new column without altering its data.
func (c *Column) ConvertibleTo(d *Column) bool {
switch {
case c.Type == d.Type:
return c.Size <= d.Size
case c.IntType() && d.IntType() || c.UintType() && d.UintType():
return c.Type <= d.Type
case c.UintType() && d.IntType():
// uintX can not be converted to intY, when X > Y.
return c.Type-field.TypeUint8 <= d.Type-field.TypeInt8
}
return c.FloatType() && d.FloatType()
}
// IntType reports whether the column is an int type (int8 ... int64).
func (c Column) IntType() bool { return c.Type >= field.TypeInt8 && c.Type <= field.TypeInt64 }
// UintType reports of the given type is a uint type (int8 ... int64).
func (c Column) UintType() bool { return c.Type >= field.TypeUint8 && c.Type <= field.TypeUint64 }
// FloatType reports of the given type is a float type (float32, float64).
func (c Column) FloatType() bool { return c.Type == field.TypeFloat32 || c.Type == field.TypeFloat64 }
// unique adds the `UNIQUE` attribute if the column is a unique type.
// it is exist in a different function to share the common declaration
// between the two dialects.

View File

@@ -0,0 +1,51 @@
package schema
import (
"testing"
"fbc/ent/field"
"github.com/stretchr/testify/require"
)
func TestColumn_ConvertibleTo(t *testing.T) {
c1 := &Column{Type: field.TypeString, Size: 10}
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 10}))
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 255}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 9}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat32}))
c1 = &Column{Type: field.TypeFloat32}
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat32}))
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat64}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeString}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint}))
c1 = &Column{Type: field.TypeFloat64}
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat32}))
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat64}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeString}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint}))
c1 = &Column{Type: field.TypeUint}
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeUint}))
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeInt}))
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeInt64}))
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeUint64}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeInt8}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint8}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint16}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint32}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeString}))
c1 = &Column{Type: field.TypeInt}
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeInt}))
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeInt64}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeInt8}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeInt32}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint8}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint16}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint32}))
require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeString}))
}