mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
add charset support for fields
Summary: Basically, adding support for Hebrew characters. Reviewed By: alexsn Differential Revision: D16068537 fbshipit-source-id: 4e934da5ea97c9e804317f746556ab1d51faebcc
This commit is contained in:
committed by
Facebook Github Bot
parent
37ae2b744e
commit
e8e96f014f
@@ -63,8 +63,13 @@ func (b *Builder) AppendComma(s ...string) *Builder {
|
||||
|
||||
// Arg appends an argument to the builder.
|
||||
func (b *Builder) Arg(a interface{}) *Builder {
|
||||
b.WriteString("?")
|
||||
b.args = append(b.args, a)
|
||||
switch a := a.(type) {
|
||||
case *raw:
|
||||
b.WriteString(a.s)
|
||||
default:
|
||||
b.WriteString("?")
|
||||
b.args = append(b.args, a)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -168,6 +173,8 @@ type TableBuilder struct {
|
||||
b Builder
|
||||
name string // table name.
|
||||
exists bool // check existence.
|
||||
charset string // table charset.
|
||||
collation string // table collation.
|
||||
columns []*ColumnBuilder // table columns.
|
||||
primary []string // primary key.
|
||||
constraints []Node // foreign keys and indices.
|
||||
@@ -230,6 +237,18 @@ func (t *TableBuilder) Constraints(fks ...*ForeignKeyBuilder) *TableBuilder {
|
||||
return t
|
||||
}
|
||||
|
||||
// Charset appends the `CHARACTER SET` clause to the statement. MySQL only.
|
||||
func (t *TableBuilder) Charset(s string) *TableBuilder {
|
||||
t.charset = s
|
||||
return t
|
||||
}
|
||||
|
||||
// Collate appends the `COLLATE` clause to the statement. MySQL only.
|
||||
func (t *TableBuilder) Collate(s string) *TableBuilder {
|
||||
t.collation = s
|
||||
return t
|
||||
}
|
||||
|
||||
// Query returns query representation of a `CREATE TABLE` statement.
|
||||
func (t *TableBuilder) Query() (string, []interface{}) {
|
||||
t.b.WriteString("CREATE TABLE ")
|
||||
@@ -254,6 +273,12 @@ func (t *TableBuilder) Query() (string, []interface{}) {
|
||||
b.Comma().JoinComma(t.constraints...)
|
||||
}
|
||||
})
|
||||
if t.charset != "" {
|
||||
t.b.WriteString(" CHARACTER SET " + t.charset)
|
||||
}
|
||||
if t.collation != "" {
|
||||
t.b.WriteString(" COLLATE " + t.collation)
|
||||
}
|
||||
return t.b.String(), t.b.args
|
||||
}
|
||||
|
||||
@@ -297,6 +322,12 @@ func (t *TableAlter) AddColumn(c *ColumnBuilder) *TableAlter {
|
||||
return t
|
||||
}
|
||||
|
||||
// Modify appends the `MODIFY COLUMN` clause to the given `ALTER TABLE` statement.
|
||||
func (t *TableAlter) ModifyColumn(c *ColumnBuilder) *TableAlter {
|
||||
t.nodes = append(t.nodes, &Wrapper{"MODIFY COLUMN %s", c})
|
||||
return t
|
||||
}
|
||||
|
||||
// AddForeignKey adds a foreign key constraint to the `ALTER TABLE` statement.
|
||||
func (t *TableAlter) AddForeignKey(fk *ForeignKeyBuilder) *TableAlter {
|
||||
t.nodes = append(t.nodes, &Wrapper{"ADD CONSTRAINT %s", fk})
|
||||
@@ -938,8 +969,9 @@ func Distinct(columns ...string) string {
|
||||
|
||||
// SelectTable is a table selector.
|
||||
type SelectTable struct {
|
||||
name string
|
||||
as string
|
||||
quote bool
|
||||
name string
|
||||
as string
|
||||
}
|
||||
|
||||
// Table returns a new table selector.
|
||||
@@ -948,7 +980,7 @@ type SelectTable struct {
|
||||
// return Select(t1.C("name"))
|
||||
//
|
||||
func Table(name string) *SelectTable {
|
||||
return &SelectTable{name: name}
|
||||
return &SelectTable{quote: true, name: name}
|
||||
}
|
||||
|
||||
// As adds the AS clause to the table selector.
|
||||
@@ -975,12 +1007,24 @@ func (s *SelectTable) Columns(columns ...string) []string {
|
||||
return names
|
||||
}
|
||||
|
||||
// Unquote makes the table name to be formatted as raw string (unquoted).
|
||||
// It is useful whe you don't want to query tables under the current database.
|
||||
// For example: "INFORMATION_SCHEMA.TABLE_CONSTRAINTS" in MySQL.
|
||||
func (s *SelectTable) Unquote() *SelectTable {
|
||||
s.quote = false
|
||||
return s
|
||||
}
|
||||
|
||||
// ref returns the table reference.
|
||||
func (s *SelectTable) ref() string {
|
||||
if s.as == "" {
|
||||
switch {
|
||||
case !s.quote:
|
||||
return s.name
|
||||
case s.as == "":
|
||||
return fmt.Sprintf("`%s`", s.name)
|
||||
default:
|
||||
return fmt.Sprintf("`%s` AS `%s`", s.name, s.as)
|
||||
}
|
||||
return fmt.Sprintf("`%s` AS `%s`", s.name, s.as)
|
||||
}
|
||||
|
||||
// implement the table view.
|
||||
@@ -1202,7 +1246,7 @@ func (s *Selector) Query() (string, []interface{}) {
|
||||
b.WriteString("*")
|
||||
}
|
||||
b.WriteString(" FROM ")
|
||||
b.Append(s.from.ref())
|
||||
b.WriteString(s.from.ref())
|
||||
if len(s.joins) > 0 {
|
||||
for _, join := range s.joins {
|
||||
b.WriteString(fmt.Sprintf(" %s ", join.kind))
|
||||
@@ -1305,6 +1349,13 @@ func (w *Wrapper) Query() (string, []interface{}) {
|
||||
return fmt.Sprintf(w.format, query), args
|
||||
}
|
||||
|
||||
// Raw returns a raw sql node that is placed as-is in the query.
|
||||
func Raw(s string) Node { return &raw{s} }
|
||||
|
||||
type raw struct{ s string }
|
||||
|
||||
func (r *raw) Query() (string, []interface{}) { return r.s, nil }
|
||||
|
||||
func isFunc(s string) bool {
|
||||
return strings.Contains(s, "(") && strings.Contains(s, ")")
|
||||
}
|
||||
|
||||
@@ -26,6 +26,27 @@ func TestBuilder(t *testing.T) {
|
||||
PrimaryKey("id"),
|
||||
wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`))",
|
||||
},
|
||||
{
|
||||
input: CreateTable("users").
|
||||
Columns(
|
||||
Column("id").Type("int").Attr("auto_increment"),
|
||||
Column("name").Type("varchar(255)"),
|
||||
).
|
||||
PrimaryKey("id").
|
||||
Charset("utf8mb4"),
|
||||
wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4",
|
||||
},
|
||||
{
|
||||
input: CreateTable("users").
|
||||
Columns(
|
||||
Column("id").Type("int").Attr("auto_increment"),
|
||||
Column("name").Type("varchar(255)"),
|
||||
).
|
||||
PrimaryKey("id").
|
||||
Charset("utf8mb4").
|
||||
Collate("utf8mb4_general_ci"),
|
||||
wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci",
|
||||
},
|
||||
{
|
||||
input: CreateTable("users").
|
||||
IfNotExists().
|
||||
@@ -80,6 +101,11 @@ func TestBuilder(t *testing.T) {
|
||||
),
|
||||
wantQuery: "ALTER TABLE `users` ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`), ADD CONSTRAINT FOREIGN KEY(`location_id`) REFERENCES `locations`(`id`)",
|
||||
},
|
||||
{
|
||||
input: AlterTable("users").
|
||||
ModifyColumn(Column("age").Type("int")),
|
||||
wantQuery: "ALTER TABLE `users` MODIFY COLUMN `age` int",
|
||||
},
|
||||
{
|
||||
input: Insert("users").Columns("age").Values(1),
|
||||
wantQuery: "INSERT INTO `users` (`age`) VALUES (?)",
|
||||
@@ -195,6 +221,10 @@ func TestBuilder(t *testing.T) {
|
||||
input: Select().From(Table("users")),
|
||||
wantQuery: "SELECT * FROM `users`",
|
||||
},
|
||||
{
|
||||
input: Select().From(Table("users").Unquote()),
|
||||
wantQuery: "SELECT * FROM users",
|
||||
},
|
||||
{
|
||||
input: Select().From(Table("users").As("u")),
|
||||
wantQuery: "SELECT * FROM `users` AS `u`",
|
||||
|
||||
@@ -29,6 +29,10 @@ func (d *MySQL) Create(ctx context.Context, tables ...*Table) error {
|
||||
}
|
||||
|
||||
func (d *MySQL) create(ctx context.Context, tx dialect.Tx, tables ...*Table) error {
|
||||
version, err := d.version(ctx, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, t := range tables {
|
||||
switch exist, err := d.tableExist(ctx, tx, t.Name); {
|
||||
case err != nil:
|
||||
@@ -38,25 +42,28 @@ func (d *MySQL) create(ctx context.Context, tx dialect.Tx, tables ...*Table) err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
changes, err := changeSet(curr, t)
|
||||
change, err := changeSet(curr, t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(changes.Columns) > 0 {
|
||||
if len(change.add) != 0 || len(change.modify) != 0 {
|
||||
b := sql.AlterTable(curr.Name)
|
||||
for _, c := range changes.Columns {
|
||||
b.AddColumn(c.DSL())
|
||||
for _, c := range change.add {
|
||||
b.AddColumn(c.MySQL(version))
|
||||
}
|
||||
for _, c := range change.modify {
|
||||
b.ModifyColumn(c.MySQL(version))
|
||||
}
|
||||
query, args := b.Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return fmt.Errorf("alter table %q: %v", t.Name, err)
|
||||
}
|
||||
}
|
||||
if len(changes.Indexes) > 0 {
|
||||
if len(change.indexes) > 0 {
|
||||
panic("missing implementation")
|
||||
}
|
||||
default: // !exist
|
||||
query, args := t.DSL().Query()
|
||||
query, args := t.MySQL(version).Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return fmt.Errorf("create table %q: %v", t.Name, err)
|
||||
}
|
||||
@@ -94,22 +101,32 @@ func (d *MySQL) create(ctx context.Context, tx dialect.Tx, tables ...*Table) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *MySQL) version(ctx context.Context, tx dialect.Tx) (string, error) {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, "SHOW VARIABLES LIKE 'version'", []interface{}{}, rows); err != nil {
|
||||
return "", fmt.Errorf("dialect/mysql: querying mysql version %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return "", fmt.Errorf("dialect/mysql: version variable was not found")
|
||||
}
|
||||
version := make([]string, 2)
|
||||
if err := rows.Scan(&version[0], &version[1]); err != nil {
|
||||
return "", fmt.Errorf("dialect/mysql: scanning mysql version: %v", err)
|
||||
}
|
||||
return version[1], nil
|
||||
}
|
||||
|
||||
func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
return d.exist(
|
||||
ctx,
|
||||
tx,
|
||||
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = (SELECT DATABASE()) AND TABLE_NAME = ?",
|
||||
name,
|
||||
)
|
||||
query, args := sql.Select(sql.Count("*")).From(sql.Table("INFORMATION_SCHEMA.TABLES").Unquote()).
|
||||
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).Query()
|
||||
return d.exist(ctx, tx, query, args...)
|
||||
}
|
||||
|
||||
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
return d.exist(
|
||||
ctx,
|
||||
tx,
|
||||
`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE TABLE_SCHEMA=(SELECT DATABASE()) AND CONSTRAINT_TYPE="FOREIGN KEY" AND CONSTRAINT_NAME = ?`,
|
||||
name,
|
||||
)
|
||||
query, args := sql.Select(sql.Count("*")).From(sql.Table("INFORMATION_SCHEMA.TABLE_CONSTRAINTS").Unquote()).
|
||||
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("CONSTRAINT_TYPE", "FOREIGN KEY").And().EQ("CONSTRAINT_NAME", name)).Query()
|
||||
return d.exist(ctx, tx, query, args...)
|
||||
}
|
||||
|
||||
func (d *MySQL) exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) {
|
||||
@@ -131,7 +148,9 @@ func (d *MySQL) exist(ctx context.Context, tx dialect.Tx, query string, args ...
|
||||
// table loads the current table description from the database.
|
||||
func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Describe(name).Query()
|
||||
query, args := sql.Select("column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name").
|
||||
From(sql.Table("INFORMATION_SCHEMA.COLUMNS").Unquote()).
|
||||
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("dialect/mysql: reading table description %v", err)
|
||||
}
|
||||
@@ -150,10 +169,17 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// changeSet returns a dummy table represents the change set that need
|
||||
// to be applied on the table. it fails if one of the changes is invalid.
|
||||
func changeSet(curr, new *Table) (*Table, error) {
|
||||
changes := &Table{}
|
||||
// changes to apply on existing table.
|
||||
type changes struct {
|
||||
add []*Column
|
||||
modify []*Column
|
||||
indexes []*Index
|
||||
}
|
||||
|
||||
// changeSet returns a changes object to be applied on existing table.
|
||||
// It fails if one of the changes is invalid.
|
||||
func changeSet(curr, new *Table) (*changes, error) {
|
||||
change := &changes{}
|
||||
// pks.
|
||||
if len(curr.PrimaryKey) != len(new.PrimaryKey) {
|
||||
return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name)
|
||||
@@ -169,23 +195,25 @@ func changeSet(curr, new *Table) (*Table, error) {
|
||||
for _, c1 := range new.Columns {
|
||||
switch c2, ok := curr.column(c1.Name); {
|
||||
case !ok:
|
||||
changes.Columns = append(changes.Columns, c1)
|
||||
change.add = append(change.add, c1)
|
||||
case c1.Type != c2.Type:
|
||||
return nil, fmt.Errorf("changing column type for %q is invalid (%s != %s)", c1.Name, c1.Type, c2.Type)
|
||||
case c1.Unique != c2.Unique:
|
||||
return nil, fmt.Errorf("changing column cardinality for %q is invalid", c1.Name)
|
||||
case c1.Charset != "" && c1.Charset != c2.Charset || c1.Collation != "" && c1.Charset != c2.Collation:
|
||||
change.modify = append(change.modify, c1)
|
||||
}
|
||||
}
|
||||
// indexes.
|
||||
for _, idx1 := range new.Indexes {
|
||||
switch idx2, ok := curr.index(idx1.Name); {
|
||||
case !ok:
|
||||
changes.Indexes = append(changes.Indexes, idx1)
|
||||
change.indexes = append(change.indexes, idx1)
|
||||
case idx1.Unique != idx2.Unique:
|
||||
return nil, fmt.Errorf("changing index %q uniqness is invalid", idx1.Name)
|
||||
}
|
||||
}
|
||||
return changes, nil
|
||||
return change, nil
|
||||
}
|
||||
|
||||
// symbol makes sure the symbol length is not longer than the maxlength in MySQL standard (64).
|
||||
|
||||
@@ -32,6 +32,8 @@ func TestMySQL_Create(t *testing.T) {
|
||||
name: "no tables",
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
@@ -45,20 +47,46 @@ func TestMySQL_Create(t *testing.T) {
|
||||
},
|
||||
Columns: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "name", Type: field.TypeString, Nullable: &null},
|
||||
{Name: "name", Type: field.TypeString, Nullable: &null, Charset: "utf8"},
|
||||
{Name: "age", Type: field.TypeInt},
|
||||
},
|
||||
},
|
||||
},
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(escape(`SELECT COUNT(*)
|
||||
FROM INFORMATION_SCHEMA.TABLES
|
||||
WHERE TABLE_SCHEMA = (SELECT DATABASE())
|
||||
AND TABLE_NAME = ?`)).
|
||||
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `name` varchar(255) NULL, `age` int, PRIMARY KEY(`id`))")).
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `name` varchar(255) CHARSET utf8 NULL, `age` int, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create new table 5.6",
|
||||
tables: []*Table{
|
||||
{
|
||||
Name: "users",
|
||||
PrimaryKey: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
},
|
||||
Columns: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "age", Type: field.TypeInt},
|
||||
{Name: "name", Type: field.TypeString, Unique: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.6.35"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `age` int, `name` varchar(191) UNIQUE, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
@@ -101,27 +129,20 @@ func TestMySQL_Create(t *testing.T) {
|
||||
}(),
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(escape(`SELECT COUNT(*)
|
||||
FROM INFORMATION_SCHEMA.TABLES
|
||||
WHERE TABLE_SCHEMA = (SELECT DATABASE())
|
||||
AND TABLE_NAME = ?`)).
|
||||
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `name` varchar(255) NULL, `created_at` timestamp NULL, PRIMARY KEY(`id`))")).
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` int AUTO_INCREMENT, `name` varchar(255) NULL, `created_at` timestamp NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape(`SELECT COUNT(*)
|
||||
FROM INFORMATION_SCHEMA.TABLES
|
||||
WHERE TABLE_SCHEMA = (SELECT DATABASE())
|
||||
AND TABLE_NAME = ?`)).
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("pets").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`(`id` int AUTO_INCREMENT, `name` varchar(255), `owner_id` int, PRIMARY KEY(`id`))")).
|
||||
mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`(`id` int AUTO_INCREMENT, `name` varchar(255), `owner_id` int, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape(`SELECT COUNT(*)
|
||||
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS
|
||||
WHERE TABLE_SCHEMA=(SELECT DATABASE())
|
||||
AND CONSTRAINT_TYPE="FOREIGN KEY"
|
||||
AND CONSTRAINT_NAME = ?`)).
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")).
|
||||
WithArgs("FOREIGN KEY", "pets_owner").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec(escape("ALTER TABLE `pets` ADD CONSTRAINT `pets_owner` FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
@@ -145,21 +166,54 @@ func TestMySQL_Create(t *testing.T) {
|
||||
},
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(escape(`SELECT COUNT(*)
|
||||
FROM INFORMATION_SCHEMA.TABLES
|
||||
WHERE TABLE_SCHEMA = (SELECT DATABASE())
|
||||
AND TABLE_NAME = ?`)).
|
||||
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
mock.ExpectQuery("DESCRIBE `users`").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Field", "Type", "Null", "Key", "Default", "Extra"}).
|
||||
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment").
|
||||
AddRow("name", "varchar(255)", "NO", "YES", "NULL", ""))
|
||||
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
|
||||
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment", "", "").
|
||||
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", ""))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` int")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modify column",
|
||||
tables: []*Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "age", Type: field.TypeInt},
|
||||
{Name: "name", Type: field.TypeString, Nullable: &null, Charset: "utf8"},
|
||||
},
|
||||
PrimaryKey: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
|
||||
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment", "", "").
|
||||
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "").
|
||||
AddRow("age", "int(11)", "NO", "NO", "NULL", "", "", ""))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `name` varchar(255) CHARSET utf8 NULL")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add edge to table",
|
||||
tables: func() []*Table {
|
||||
@@ -188,23 +242,19 @@ func TestMySQL_Create(t *testing.T) {
|
||||
}(),
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(escape(`SELECT COUNT(*)
|
||||
FROM INFORMATION_SCHEMA.TABLES
|
||||
WHERE TABLE_SCHEMA = (SELECT DATABASE())
|
||||
AND TABLE_NAME = ?`)).
|
||||
mock.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.23"))
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
mock.ExpectQuery("DESCRIBE `users`").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"Field", "Type", "Null", "Key", "Default", "Extra"}).
|
||||
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment").
|
||||
AddRow("name", "varchar(255)", "NO", "YES", "NULL", ""))
|
||||
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
|
||||
AddRow("id", "int(11)", "NO", "PRI", "NULL", "auto_increment", "", "").
|
||||
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", ""))
|
||||
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` int")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(escape(`SELECT COUNT(*)
|
||||
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS
|
||||
WHERE TABLE_SCHEMA=(SELECT DATABASE())
|
||||
AND CONSTRAINT_TYPE="FOREIGN KEY"
|
||||
AND CONSTRAINT_NAME = ?`)).
|
||||
mock.ExpectQuery(escape("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec("ALTER TABLE `users` ADD CONSTRAINT `.{64}` FOREIGN KEY\\(`spouse_id`\\) REFERENCES `users`\\(`id`\\) ON DELETE CASCADE").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
@@ -34,15 +34,18 @@ func (t *Table) AddForeignKey(fk *ForeignKey) *Table {
|
||||
return t
|
||||
}
|
||||
|
||||
// DSL returns the default DSL query for table creation.
|
||||
func (t *Table) DSL() *sql.TableBuilder {
|
||||
// MySQL returns the MySQL DSL query for table creation.
|
||||
func (t *Table) MySQL(version string) *sql.TableBuilder {
|
||||
b := sql.CreateTable(t.Name).IfNotExists()
|
||||
for _, c := range t.Columns {
|
||||
b.Column(c.DSL())
|
||||
b.Column(c.MySQL(version))
|
||||
}
|
||||
for _, pk := range t.PrimaryKey {
|
||||
b.PrimaryKey(pk.Name)
|
||||
}
|
||||
// default character set to MySQL table.
|
||||
// columns can be override using the "Charset" field.
|
||||
b.Charset("utf8mb4")
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -105,6 +108,8 @@ type Column struct {
|
||||
Increment bool // auto increment attribute.
|
||||
Nullable *bool // null or not null attribute.
|
||||
Default string // default value.
|
||||
Charset string // column character set.
|
||||
Collation string // column collation.
|
||||
}
|
||||
|
||||
// UniqueKey returns boolean indicates if this column is a unique key.
|
||||
@@ -115,13 +120,20 @@ func (c *Column) UniqueKey() bool { return c.Key == "UNI" }
|
||||
// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects.
|
||||
func (c *Column) PrimaryKey() bool { return c.Key == "PRI" }
|
||||
|
||||
// DSL returns the default DSL query for table creation.
|
||||
func (c *Column) DSL() *sql.ColumnBuilder {
|
||||
b := sql.Column(c.Name).Type(c.MySQLType()).Attr(c.Attr)
|
||||
// MySQL returns the MySQL DSL query for table creation.
|
||||
// The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable].
|
||||
func (c *Column) MySQL(version string) *sql.ColumnBuilder {
|
||||
b := sql.Column(c.Name).Type(c.MySQLType(version)).Attr(c.Attr)
|
||||
if c.Charset != "" {
|
||||
b.Attr("CHARSET " + c.Charset)
|
||||
}
|
||||
c.unique(b)
|
||||
if c.Increment {
|
||||
b.Attr("AUTO_INCREMENT")
|
||||
}
|
||||
if c.Collation != "" {
|
||||
b.Attr("COLLATE " + c.Collation)
|
||||
}
|
||||
c.nullable(b)
|
||||
return b
|
||||
}
|
||||
@@ -138,7 +150,7 @@ func (c *Column) SQLite() *sql.ColumnBuilder {
|
||||
}
|
||||
|
||||
// MySQLType returns the MySQL string type for this column.
|
||||
func (c *Column) MySQLType() (t string) {
|
||||
func (c *Column) MySQLType(version string) (t string) {
|
||||
switch c.Type {
|
||||
case field.TypeBool:
|
||||
t = "boolean"
|
||||
@@ -157,7 +169,7 @@ func (c *Column) MySQLType() (t string) {
|
||||
case field.TypeString:
|
||||
size := c.Size
|
||||
if size == 0 {
|
||||
size = 255
|
||||
size = c.defaultSize(version)
|
||||
}
|
||||
if size < 1<<16 {
|
||||
t = fmt.Sprintf("varchar(%d)", size)
|
||||
@@ -209,14 +221,18 @@ func (c *Column) SQLiteType() (t string) {
|
||||
// ScanMySQL scans the information from MySQL column description.
|
||||
func (c *Column) ScanMySQL(rows *sql.Rows) error {
|
||||
var (
|
||||
charset sql.NullString
|
||||
collate sql.NullString
|
||||
nullable sql.NullString
|
||||
defaults sql.NullString
|
||||
)
|
||||
if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr); err != nil {
|
||||
if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &charset, &collate); err != nil {
|
||||
return fmt.Errorf("scanning column description: %v", err)
|
||||
}
|
||||
c.Unique = c.UniqueKey()
|
||||
c.Charset = charset.String
|
||||
c.Default = defaults.String
|
||||
c.Collation = collate.String
|
||||
if nullable.Valid {
|
||||
null := nullable.String == "YES"
|
||||
c.Nullable = &null
|
||||
@@ -275,6 +291,21 @@ func (c *Column) nullable(b *sql.ColumnBuilder) {
|
||||
}
|
||||
}
|
||||
|
||||
// defaultSize returns the default size for MySQL varchar
|
||||
// type based on column size, charset and table indexes.
|
||||
func (c *Column) defaultSize(version string) int {
|
||||
size := 255
|
||||
parts := strings.Split(version, ".")
|
||||
// non-unique or invalid version.
|
||||
if !c.Unique || len(parts) == 1 || parts[0] == "" || parts[1] == "" {
|
||||
return size
|
||||
}
|
||||
if major, minor := parts[0], parts[1]; major > "5" || minor > "6" {
|
||||
return size
|
||||
}
|
||||
return 191
|
||||
}
|
||||
|
||||
// ForeignKey definition for creation.
|
||||
type ForeignKey struct {
|
||||
Symbol string // foreign-key name. Generated if empty.
|
||||
|
||||
Reference in New Issue
Block a user