add charset support for fields

Summary: Basically, adding support for Hebrew characters.

Reviewed By: alexsn

Differential Revision: D16068537

fbshipit-source-id: 4e934da5ea97c9e804317f746556ab1d51faebcc
This commit is contained in:
Ariel Mashraki
2019-07-01 08:06:04 -07:00
committed by Facebook Github Bot
parent 37ae2b744e
commit e8e96f014f
24 changed files with 417 additions and 182 deletions

View File

@@ -63,8 +63,13 @@ func (b *Builder) AppendComma(s ...string) *Builder {
// Arg appends an argument to the builder.
func (b *Builder) Arg(a interface{}) *Builder {
b.WriteString("?")
b.args = append(b.args, a)
switch a := a.(type) {
case *raw:
b.WriteString(a.s)
default:
b.WriteString("?")
b.args = append(b.args, a)
}
return b
}
@@ -168,6 +173,8 @@ type TableBuilder struct {
b Builder
name string // table name.
exists bool // check existence.
charset string // table charset.
collation string // table collation.
columns []*ColumnBuilder // table columns.
primary []string // primary key.
constraints []Node // foreign keys and indices.
@@ -230,6 +237,18 @@ func (t *TableBuilder) Constraints(fks ...*ForeignKeyBuilder) *TableBuilder {
return t
}
// Charset appends the `CHARACTER SET` clause to the statement. MySQL only.
func (t *TableBuilder) Charset(s string) *TableBuilder {
t.charset = s
return t
}
// Collate appends the `COLLATE` clause to the statement. MySQL only.
func (t *TableBuilder) Collate(s string) *TableBuilder {
t.collation = s
return t
}
// Query returns query representation of a `CREATE TABLE` statement.
func (t *TableBuilder) Query() (string, []interface{}) {
t.b.WriteString("CREATE TABLE ")
@@ -254,6 +273,12 @@ func (t *TableBuilder) Query() (string, []interface{}) {
b.Comma().JoinComma(t.constraints...)
}
})
if t.charset != "" {
t.b.WriteString(" CHARACTER SET " + t.charset)
}
if t.collation != "" {
t.b.WriteString(" COLLATE " + t.collation)
}
return t.b.String(), t.b.args
}
@@ -297,6 +322,12 @@ func (t *TableAlter) AddColumn(c *ColumnBuilder) *TableAlter {
return t
}
// Modify appends the `MODIFY COLUMN` clause to the given `ALTER TABLE` statement.
func (t *TableAlter) ModifyColumn(c *ColumnBuilder) *TableAlter {
t.nodes = append(t.nodes, &Wrapper{"MODIFY COLUMN %s", c})
return t
}
// AddForeignKey adds a foreign key constraint to the `ALTER TABLE` statement.
func (t *TableAlter) AddForeignKey(fk *ForeignKeyBuilder) *TableAlter {
t.nodes = append(t.nodes, &Wrapper{"ADD CONSTRAINT %s", fk})
@@ -938,8 +969,9 @@ func Distinct(columns ...string) string {
// SelectTable is a table selector.
type SelectTable struct {
name string
as string
quote bool
name string
as string
}
// Table returns a new table selector.
@@ -948,7 +980,7 @@ type SelectTable struct {
// return Select(t1.C("name"))
//
func Table(name string) *SelectTable {
return &SelectTable{name: name}
return &SelectTable{quote: true, name: name}
}
// As adds the AS clause to the table selector.
@@ -975,12 +1007,24 @@ func (s *SelectTable) Columns(columns ...string) []string {
return names
}
// Unquote makes the table name to be formatted as raw string (unquoted).
// It is useful whe you don't want to query tables under the current database.
// For example: "INFORMATION_SCHEMA.TABLE_CONSTRAINTS" in MySQL.
func (s *SelectTable) Unquote() *SelectTable {
s.quote = false
return s
}
// ref returns the table reference.
func (s *SelectTable) ref() string {
if s.as == "" {
switch {
case !s.quote:
return s.name
case s.as == "":
return fmt.Sprintf("`%s`", s.name)
default:
return fmt.Sprintf("`%s` AS `%s`", s.name, s.as)
}
return fmt.Sprintf("`%s` AS `%s`", s.name, s.as)
}
// implement the table view.
@@ -1202,7 +1246,7 @@ func (s *Selector) Query() (string, []interface{}) {
b.WriteString("*")
}
b.WriteString(" FROM ")
b.Append(s.from.ref())
b.WriteString(s.from.ref())
if len(s.joins) > 0 {
for _, join := range s.joins {
b.WriteString(fmt.Sprintf(" %s ", join.kind))
@@ -1305,6 +1349,13 @@ func (w *Wrapper) Query() (string, []interface{}) {
return fmt.Sprintf(w.format, query), args
}
// Raw returns a raw sql node that is placed as-is in the query.
func Raw(s string) Node { return &raw{s} }
type raw struct{ s string }
func (r *raw) Query() (string, []interface{}) { return r.s, nil }
func isFunc(s string) bool {
return strings.Contains(s, "(") && strings.Contains(s, ")")
}

View File

@@ -26,6 +26,27 @@ func TestBuilder(t *testing.T) {
PrimaryKey("id"),
wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`))",
},
{
input: CreateTable("users").
Columns(
Column("id").Type("int").Attr("auto_increment"),
Column("name").Type("varchar(255)"),
).
PrimaryKey("id").
Charset("utf8mb4"),
wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4",
},
{
input: CreateTable("users").
Columns(
Column("id").Type("int").Attr("auto_increment"),
Column("name").Type("varchar(255)"),
).
PrimaryKey("id").
Charset("utf8mb4").
Collate("utf8mb4_general_ci"),
wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci",
},
{
input: CreateTable("users").
IfNotExists().
@@ -80,6 +101,11 @@ func TestBuilder(t *testing.T) {
),
wantQuery: "ALTER TABLE `users` ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`), ADD CONSTRAINT FOREIGN KEY(`location_id`) REFERENCES `locations`(`id`)",
},
{
input: AlterTable("users").
ModifyColumn(Column("age").Type("int")),
wantQuery: "ALTER TABLE `users` MODIFY COLUMN `age` int",
},
{
input: Insert("users").Columns("age").Values(1),
wantQuery: "INSERT INTO `users` (`age`) VALUES (?)",
@@ -195,6 +221,10 @@ func TestBuilder(t *testing.T) {
input: Select().From(Table("users")),
wantQuery: "SELECT * FROM `users`",
},
{
input: Select().From(Table("users").Unquote()),
wantQuery: "SELECT * FROM users",
},
{
input: Select().From(Table("users").As("u")),
wantQuery: "SELECT * FROM `users` AS `u`",

View File

@@ -29,6 +29,10 @@ func (d *MySQL) Create(ctx context.Context, tables ...*Table) error {
}
func (d *MySQL) create(ctx context.Context, tx dialect.Tx, tables ...*Table) error {
version, err := d.version(ctx, tx)
if err != nil {
return err
}
for _, t := range tables {
switch exist, err := d.tableExist(ctx, tx, t.Name); {
case err != nil:
@@ -38,25 +42,28 @@ func (d *MySQL) create(ctx context.Context, tx dialect.Tx, tables ...*Table) err
if err != nil {
return err
}
changes, err := changeSet(curr, t)
change, err := changeSet(curr, t)
if err != nil {
return err
}
if len(changes.Columns) > 0 {
if len(change.add) != 0 || len(change.modify) != 0 {
b := sql.AlterTable(curr.Name)
for _, c := range changes.Columns {
b.AddColumn(c.DSL())
for _, c := range change.add {
b.AddColumn(c.MySQL(version))
}
for _, c := range change.modify {
b.ModifyColumn(c.MySQL(version))
}
query, args := b.Query()
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
return fmt.Errorf("alter table %q: %v", t.Name, err)
}
}
if len(changes.Indexes) > 0 {
if len(change.indexes) > 0 {
panic("missing implementation")
}
default: // !exist
query, args := t.DSL().Query()
query, args := t.MySQL(version).Query()
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
return fmt.Errorf("create table %q: %v", t.Name, err)
}
@@ -94,22 +101,32 @@ func (d *MySQL) create(ctx context.Context, tx dialect.Tx, tables ...*Table) err
return nil
}
func (d *MySQL) version(ctx context.Context, tx dialect.Tx) (string, error) {
rows := &sql.Rows{}
if err := tx.Query(ctx, "SHOW VARIABLES LIKE 'version'", []interface{}{}, rows); err != nil {
return "", fmt.Errorf("dialect/mysql: querying mysql version %v", err)
}
defer rows.Close()
if !rows.Next() {
return "", fmt.Errorf("dialect/mysql: version variable was not found")
}
version := make([]string, 2)
if err := rows.Scan(&version[0], &version[1]); err != nil {
return "", fmt.Errorf("dialect/mysql: scanning mysql version: %v", err)
}
return version[1], nil
}
func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
return d.exist(
ctx,
tx,
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = (SELECT DATABASE()) AND TABLE_NAME = ?",
name,
)
query, args := sql.Select(sql.Count("*")).From(sql.Table("INFORMATION_SCHEMA.TABLES").Unquote()).
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).Query()
return d.exist(ctx, tx, query, args...)
}
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
return d.exist(
ctx,
tx,
`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE TABLE_SCHEMA=(SELECT DATABASE()) AND CONSTRAINT_TYPE="FOREIGN KEY" AND CONSTRAINT_NAME = ?`,
name,
)
query, args := sql.Select(sql.Count("*")).From(sql.Table("INFORMATION_SCHEMA.TABLE_CONSTRAINTS").Unquote()).
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("CONSTRAINT_TYPE", "FOREIGN KEY").And().EQ("CONSTRAINT_NAME", name)).Query()
return d.exist(ctx, tx, query, args...)
}
func (d *MySQL) exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) {
@@ -131,7 +148,9 @@ func (d *MySQL) exist(ctx context.Context, tx dialect.Tx, query string, args ...
// table loads the current table description from the database.
func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
rows := &sql.Rows{}
query, args := sql.Describe(name).Query()
query, args := sql.Select("column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name").
From(sql.Table("INFORMATION_SCHEMA.COLUMNS").Unquote()).
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("dialect/mysql: reading table description %v", err)
}
@@ -150,10 +169,17 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
return t, nil
}
// changeSet returns a dummy table represents the change set that need
// to be applied on the table. it fails if one of the changes is invalid.
func changeSet(curr, new *Table) (*Table, error) {
changes := &Table{}
// changes to apply on existing table.
type changes struct {
add []*Column
modify []*Column
indexes []*Index
}
// changeSet returns a changes object to be applied on existing table.
// It fails if one of the changes is invalid.
func changeSet(curr, new *Table) (*changes, error) {
change := &changes{}
// pks.
if len(curr.PrimaryKey) != len(new.PrimaryKey) {
return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
@@ -169,23 +195,25 @@ func changeSet(curr, new *Table) (*Table, error) {
for _, c1 := range new.Columns {
switch c2, ok := curr.column(c1.Name); {
case !ok:
changes.Columns = append(changes.Columns, c1)
change.add = append(change.add, c1)
case c1.Type != c2.Type:
return nil, fmt.Errorf("changing column type for %q is invalid (%s != %s)", c1.Name, c1.Type, c2.Type)
case c1.Unique != c2.Unique:
return nil, fmt.Errorf("changing column cardinality for %q is invalid", c1.Name)
case c1.Charset != "" && c1.Charset != c2.Charset || c1.Collation != "" && c1.Charset != c2.Collation:
change.modify = append(change.modify, c1)
}
}
// indexes.
for _, idx1 := range new.Indexes {
switch idx2, ok := curr.index(idx1.Name); {
case !ok:
changes.Indexes = append(changes.Indexes, idx1)
change.indexes = append(change.indexes, idx1)
case idx1.Unique != idx2.Unique:
return nil, fmt.Errorf("changing index %q uniqness is invalid", idx1.Name)
}
}
return changes, nil
return change, nil
}
// symbol makes sure the symbol length is not longer than the maxlength in MySQL standard (64).

View File

@@ -32,6 +32,8 @@ func TestMySQL_Create(t *testing.T) {
name: "no tables",
before: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
mock.ExpectCommit()
},
},
@@ -45,20 +47,46 @@ func TestMySQL_Create(t *testing.T) {
},
Columns: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "name", Type: field.TypeString, Nullable: &null},
{Name: "name", Type: field.TypeString, Nullable: &null, Charset: "utf8"},
{Name: "age", Type: field.TypeInt},
},
},
},
before: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery(escape(`SELECT COUNT(*)
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = (SELECT DATABASE())
AND TABLE_NAME = ?`)).
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `name` varchar(255) NULL, `age` int, PRIMARY KEY(`id`))")).
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `name` varchar(255) CHARSET utf8 NULL, `age` int, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
},
{
name: "create new table 5.6",
tables: []*Table{
{
Name: "users",
PrimaryKey: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
},
Columns: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "age", Type: field.TypeInt},
{Name: "name", Type: field.TypeString, Unique: true},
},
},
},
before: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.6.35"))
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `age` int, `name` varchar(191) UNIQUE, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
@@ -101,27 +129,20 @@ func TestMySQL_Create(t *testing.T) {
}(),
before: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery(escape(`SELECT COUNT(*)
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = (SELECT DATABASE())
AND TABLE_NAME = ?`)).
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `name` varchar(255) NULL, `created_at` timestamp NULL, PRIMARY KEY(`id`))")).
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `name` varchar(255) NULL, `created_at` timestamp NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectQuery(escape(`SELECT COUNT(*)
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = (SELECT DATABASE())
AND TABLE_NAME = ?`)).
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("pets").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`(`id` int AUTO_INCREMENT, `name` varchar(255), `owner_id` int, PRIMARY KEY(`id`))")).
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`(`id` int AUTO_INCREMENT, `name` varchar(255), `owner_id` int, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectQuery(escape(`SELECT COUNT(*)
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS
WHERE TABLE_SCHEMA=(SELECT DATABASE())
AND CONSTRAINT_TYPE="FOREIGN KEY"
AND CONSTRAINT_NAME = ?`)).
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")).
WithArgs("FOREIGN KEY", "pets_owner").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec(escape("ALTER TABLE `pets` ADD CONSTRAINT `pets_owner` FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE")).
WillReturnResult(sqlmock.NewResult(0, 1))
@@ -145,21 +166,54 @@ func TestMySQL_Create(t *testing.T) {
},
before: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery(escape(`SELECT COUNT(*)
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = (SELECT DATABASE())
AND TABLE_NAME = ?`)).
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
mock.ExpectQuery("DESCRIBE `users`").
WillReturnRows(sqlmock.NewRows([]string{"Field", "Type", "Null", "Key", "Default", "Extra"}).
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment").
AddRow("name", "varchar(255)", "NO", "YES", "NULL", ""))
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment", "", "").
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", ""))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` int")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
},
{
name: "modify column",
tables: []*Table{
{
Name: "users",
Columns: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "age", Type: field.TypeInt},
{Name: "name", Type: field.TypeString, Nullable: &null, Charset: "utf8"},
},
PrimaryKey: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
},
},
},
before: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment", "", "").
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "").
AddRow("age", "int(11)", "NO", "NO", "NULL", "", "", ""))
mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `name` varchar(255) CHARSET utf8 NULL")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
},
{
name: "add edge to table",
tables: func() []*Table {
@@ -188,23 +242,19 @@ func TestMySQL_Create(t *testing.T) {
}(),
before: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery(escape(`SELECT COUNT(*)
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = (SELECT DATABASE())
AND TABLE_NAME = ?`)).
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
mock.ExpectQuery("DESCRIBE `users`").
WillReturnRows(sqlmock.NewRows([]string{"Field", "Type", "Null", "Key", "Default", "Extra"}).
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment").
AddRow("name", "varchar(255)", "NO", "YES", "NULL", ""))
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment", "", "").
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", ""))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` int")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectQuery(escape(`SELECT COUNT(*)
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS
WHERE TABLE_SCHEMA=(SELECT DATABASE())
AND CONSTRAINT_TYPE="FOREIGN KEY"
AND CONSTRAINT_NAME = ?`)).
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec("ALTER TABLE `users` ADD CONSTRAINT `.{64}` FOREIGN KEY\\(`spouse_id`\\) REFERENCES `users`\\(`id`\\) ON DELETE CASCADE").
WillReturnResult(sqlmock.NewResult(0, 1))

View File

@@ -34,15 +34,18 @@ func (t *Table) AddForeignKey(fk *ForeignKey) *Table {
return t
}
// DSL returns the default DSL query for table creation.
func (t *Table) DSL() *sql.TableBuilder {
// MySQL returns the MySQL DSL query for table creation.
func (t *Table) MySQL(version string) *sql.TableBuilder {
b := sql.CreateTable(t.Name).IfNotExists()
for _, c := range t.Columns {
b.Column(c.DSL())
b.Column(c.MySQL(version))
}
for _, pk := range t.PrimaryKey {
b.PrimaryKey(pk.Name)
}
// default character set to MySQL table.
// columns can be override using the "Charset" field.
b.Charset("utf8mb4")
return b
}
@@ -105,6 +108,8 @@ type Column struct {
Increment bool // auto increment attribute.
Nullable *bool // null or not null attribute.
Default string // default value.
Charset string // column character set.
Collation string // column collation.
}
// UniqueKey returns boolean indicates if this column is a unique key.
@@ -115,13 +120,20 @@ func (c *Column) UniqueKey() bool { return c.Key == "UNI" }
// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects.
func (c *Column) PrimaryKey() bool { return c.Key == "PRI" }
// DSL returns the default DSL query for table creation.
func (c *Column) DSL() *sql.ColumnBuilder {
b := sql.Column(c.Name).Type(c.MySQLType()).Attr(c.Attr)
// MySQL returns the MySQL DSL query for table creation.
// The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable].
func (c *Column) MySQL(version string) *sql.ColumnBuilder {
b := sql.Column(c.Name).Type(c.MySQLType(version)).Attr(c.Attr)
if c.Charset != "" {
b.Attr("CHARSET " + c.Charset)
}
c.unique(b)
if c.Increment {
b.Attr("AUTO_INCREMENT")
}
if c.Collation != "" {
b.Attr("COLLATE " + c.Collation)
}
c.nullable(b)
return b
}
@@ -138,7 +150,7 @@ func (c *Column) SQLite() *sql.ColumnBuilder {
}
// MySQLType returns the MySQL string type for this column.
func (c *Column) MySQLType() (t string) {
func (c *Column) MySQLType(version string) (t string) {
switch c.Type {
case field.TypeBool:
t = "boolean"
@@ -157,7 +169,7 @@ func (c *Column) MySQLType() (t string) {
case field.TypeString:
size := c.Size
if size == 0 {
size = 255
size = c.defaultSize(version)
}
if size < 1<<16 {
t = fmt.Sprintf("varchar(%d)", size)
@@ -209,14 +221,18 @@ func (c *Column) SQLiteType() (t string) {
// ScanMySQL scans the information from MySQL column description.
func (c *Column) ScanMySQL(rows *sql.Rows) error {
var (
charset sql.NullString
collate sql.NullString
nullable sql.NullString
defaults sql.NullString
)
if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr); err != nil {
if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &charset, &collate); err != nil {
return fmt.Errorf("scanning column description: %v", err)
}
c.Unique = c.UniqueKey()
c.Charset = charset.String
c.Default = defaults.String
c.Collation = collate.String
if nullable.Valid {
null := nullable.String == "YES"
c.Nullable = &null
@@ -275,6 +291,21 @@ func (c *Column) nullable(b *sql.ColumnBuilder) {
}
}
// defaultSize returns the default size for MySQL varchar
// type based on column size, charset and table indexes.
func (c *Column) defaultSize(version string) int {
size := 255
parts := strings.Split(version, ".")
// non-unique or invalid version.
if !c.Unique || len(parts) == 1 || parts[0] == "" || parts[1] == "" {
return size
}
if major, minor := parts[0], parts[1]; major > "5" || minor > "6" {
return size
}
return 191
}
// ForeignKey definition for creation.
type ForeignKey struct {
Symbol string // foreign-key name. Generated if empty.