mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/schema: support mariadb json fields on migration (#1011)
This commit is contained in:
@@ -6,6 +6,7 @@ package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
@@ -102,6 +103,11 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
|
||||
for _, idx := range indexes {
|
||||
t.AddIndex(idx.Name, idx.Unique, idx.columns)
|
||||
}
|
||||
if d.mariadb() {
|
||||
if err := d.normalizeJSON(ctx, tx, t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
@@ -521,6 +527,57 @@ func (d *MySQL) alterColumns(table string, add, modify, drop []*Column) sql.Quer
|
||||
return sql.Queries{b}
|
||||
}
|
||||
|
||||
// normalizeJSON normalize MariaDB longtext columns to type JSON.
|
||||
func (d *MySQL) normalizeJSON(ctx context.Context, tx dialect.Tx, t *Table) error {
|
||||
var (
|
||||
names []driver.Value
|
||||
columns = make(map[string]*Column)
|
||||
)
|
||||
for _, c := range t.Columns {
|
||||
if c.typ == "longtext" {
|
||||
columns[c.Name] = c
|
||||
names = append(names, c.Name)
|
||||
}
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Select("CONSTRAINT_NAME", "CHECK_CLAUSE").
|
||||
From(sql.Table("INFORMATION_SCHEMA.CHECK_CONSTRAINTS").Unquote()).
|
||||
Where(sql.And(
|
||||
sql.EQ("CONSTRAINT_SCHEMA", sql.Raw("(SELECT DATABASE())")),
|
||||
sql.EQ("TABLE_NAME", t.Name),
|
||||
sql.InValues("CONSTRAINT_NAME", names...),
|
||||
)).
|
||||
Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return fmt.Errorf("mysql: query table constraints %v", err)
|
||||
}
|
||||
// Call Close in cases of failures (Close is idempotent).
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var name, check string
|
||||
if err := rows.Scan(&name, &check); err != nil {
|
||||
return fmt.Errorf("mysql: scan table constraints")
|
||||
}
|
||||
c, ok := columns[name]
|
||||
if !ok || !strings.HasPrefix(check, "json_valid") {
|
||||
continue
|
||||
}
|
||||
c.Type = field.TypeJSON
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return rows.Close()
|
||||
}
|
||||
|
||||
// mariadb reports if the migration runs on MariaDB.
|
||||
func (d *MySQL) mariadb() bool {
|
||||
return strings.Contains(d.version, "MariaDB")
|
||||
}
|
||||
|
||||
// parseColumn returns column parts, size and signed-info from a MySQL type.
|
||||
func parseColumn(typ string) (parts []string, size int64, unsigned bool, err error) {
|
||||
switch parts = strings.FieldsFunc(typ, func(r rune) bool {
|
||||
|
||||
@@ -150,7 +150,7 @@ func TestMySQL_Create(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create new table with foreign key diabled",
|
||||
name: "create new table with foreign key disabled",
|
||||
options: []MigrateOption{
|
||||
WithForeignKeys(false),
|
||||
},
|
||||
@@ -1032,6 +1032,44 @@ func TestMySQL_Create(t *testing.T) {
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
// MariaDB specific tests.
|
||||
{
|
||||
name: "mariadb/json columns",
|
||||
tables: []*Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "name", Type: field.TypeString, Nullable: true},
|
||||
{Name: "json", Type: field.TypeJSON, Nullable: true},
|
||||
{Name: "longtext", Type: field.TypeString, Nullable: true, Size: math.MaxInt32},
|
||||
},
|
||||
PrimaryKey: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
before: func(mock mysqlMock) {
|
||||
mock.start("10.5.8-MariaDB-1:10.5.8+maria~focal")
|
||||
mock.tableExists("users", true)
|
||||
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
|
||||
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "").
|
||||
AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "").
|
||||
AddRow("json", "longtext", "YES", "YES", "NULL", "", "utf8mb4", "utf8mb4_bin").
|
||||
AddRow("longtext", "longtext", "YES", "YES", "NULL", "", "utf8mb4", "utf8mb4_bin"))
|
||||
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM INFORMATION_SCHEMA.STATISTICS WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}).
|
||||
AddRow("PRIMARY", "id", "0", "1"))
|
||||
mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME`, `CHECK_CLAUSE` FROM INFORMATION_SCHEMA.CHECK_CONSTRAINTS WHERE `CONSTRAINT_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? AND `CONSTRAINT_NAME` IN (?, ?)")).
|
||||
WithArgs("users", "json", "longtext").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME", "CHECK_CLAUSE"}).
|
||||
AddRow("json", "json_valid(`json`)"))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user