diff --git a/dialect/sql/schema/mysql.go b/dialect/sql/schema/mysql.go index e11d9e6ff..7d4309edc 100644 --- a/dialect/sql/schema/mysql.go +++ b/dialect/sql/schema/mysql.go @@ -102,7 +102,7 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, for _, idx := range indexes { t.AddIndex(idx.Name, idx.Unique, idx.columns) } - if d.mariadb() { + if _, ok := d.mariadb(); ok { if err := d.normalizeJSON(ctx, tx, t); err != nil { return nil, err } @@ -281,6 +281,15 @@ func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder { } c.nullable(b) c.defaultValue(b) + if c.Type == field.TypeJSON { + // Manually add a `CHECK` clause for older versions of MariaDB for validating the + // JSON documents. This constraint is automatically included from version 10.4.3. + if version, ok := d.mariadb(); ok && compareVersions(version, "10.4.3") == -1 { + b.Check(func(b *sql.Builder) { + b.WriteString("JSON_VALID(").Ident(c.Name).WriteByte(')') + }) + } + } return b } @@ -571,9 +580,13 @@ func (d *MySQL) normalizeJSON(ctx context.Context, tx dialect.Tx, t *Table) erro return rows.Close() } -// mariadb reports if the migration runs on MariaDB. -func (d *MySQL) mariadb() bool { - return strings.Contains(d.version, "MariaDB") +// mariadb reports if the migration runs on MariaDB and returns the semver string. +func (d *MySQL) mariadb() (string, bool) { + idx := strings.Index(d.version, "MariaDB") + if idx == -1 { + return "", false + } + return d.version[:idx-1], true } // parseColumn returns column parts, size and signed-info from a MySQL type. diff --git a/dialect/sql/schema/mysql_test.go b/dialect/sql/schema/mysql_test.go index e6a6d4c3f..413df63fe 100644 --- a/dialect/sql/schema/mysql_test.go +++ b/dialect/sql/schema/mysql_test.go @@ -1032,7 +1032,51 @@ func TestMySQL_Create(t *testing.T) { }, // MariaDB specific tests. { - name: "mariadb/json columns", + name: "mariadb/10.2.32/create table", + tables: []*Table{ + { + Name: "users", + Columns: []*Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "json", Type: field.TypeJSON, Nullable: true}, + }, + PrimaryKey: []*Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + }, + }, + }, + before: func(mock mysqlMock) { + mock.start("10.2.32-MariaDB") + mock.tableExists("users", false) + mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `json` json NULL CHECK (JSON_VALID(`json`)), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + }, + }, + { + name: "mariadb/10.5.8/create table", + tables: []*Table{ + { + Name: "users", + Columns: []*Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "json", Type: field.TypeJSON, Nullable: true}, + }, + PrimaryKey: []*Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + }, + }, + }, + before: func(mock mysqlMock) { + mock.start("10.5.8-MariaDB") + mock.tableExists("users", false) + mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `json` json NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + }, + }, + { + name: "mariadb/10.5.8/table exists", tables: []*Table{ { Name: "users",