dialect/sql/entsql: support setting expression as column default value

Fixed https://github.com/ent/ent/issues/3069
This commit is contained in:
Ariel Mashraki
2022-11-11 21:47:10 +02:00
committed by Ariel Mashraki
parent 3b5a535801
commit 1e5f68646f
17 changed files with 744 additions and 33 deletions

View File

@@ -882,26 +882,41 @@ func (a *Atlas) aColumns(et *Table, at *schema.Table) error {
}
func (a *Atlas) atDefault(c1 *Column, c2 *schema.Column) error {
if x, ok := c1.Default.(*schema.RawExpr); ok {
c2.SetDefault(x)
if c1.Default == nil || !a.sqlDialect.supportsDefault(c1) {
return nil
}
switch {
case c1.Default == nil:
case c1.Type == field.TypeJSON && a.sqlDialect.supportsDefault(c1):
s, ok := c1.Default.(string)
switch x := c1.Default.(type) {
case Expr:
if len(x) > 1 && (x[0] != '(' || x[len(x)-1] != ')') {
x = "(" + x + ")"
}
c2.SetDefault(&schema.RawExpr{X: string(x)})
case map[string]Expr:
d, ok := x[a.driver.Dialect()]
if !ok {
return fmt.Errorf("invalid default value for JSON column %q: %v", c1.Name, c1.Default)
return nil
}
c2.SetDefault(&schema.Literal{V: strings.ReplaceAll(s, "'", "''")})
case c1.supportDefault():
// Keep backwards compatibility with the old default value format.
x := fmt.Sprint(c1.Default)
if v, ok := c1.Default.(string); ok && c1.Type != field.TypeUUID && c1.Type != field.TypeTime {
// Escape single quote by replacing each with 2.
x = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''"))
if len(d) > 1 && (d[0] != '(' || d[len(d)-1] != ')') {
d = "(" + d + ")"
}
c2.SetDefault(&schema.RawExpr{X: string(d)})
default:
switch {
case c1.Type == field.TypeJSON:
s, ok := c1.Default.(string)
if !ok {
return fmt.Errorf("invalid default value for JSON column %q: %v", c1.Name, c1.Default)
}
c2.SetDefault(&schema.Literal{V: strings.ReplaceAll(s, "'", "''")})
default:
// Keep backwards compatibility with the old default value format.
x := fmt.Sprint(c1.Default)
if v, ok := c1.Default.(string); ok && c1.Type != field.TypeUUID && c1.Type != field.TypeTime {
// Escape single quote by replacing each with 2.
x = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''"))
}
c2.SetDefault(&schema.RawExpr{X: x})
}
c2.SetDefault(&schema.RawExpr{X: x})
}
return nil
}

View File

@@ -816,7 +816,15 @@ func (d *MySQL) atTable(t1 *Table, t2 *schema.Table) {
func (d *MySQL) supportsDefault(c *Column) bool {
_, maria := d.mariadb()
return c.supportDefault() || maria
switch c.Default.(type) {
case Expr, map[string]Expr:
if maria {
return compareVersions(d.version, "10.2.0") >= 0
}
return c.supportDefault() && compareVersions(d.version, "8.0.0") >= 0
default:
return c.supportDefault() || maria
}
}
func (d *MySQL) atTypeC(c1 *Column, c2 *schema.Column) error {

View File

@@ -297,6 +297,10 @@ type Column struct {
foreign *ForeignKey // linked foreign-key.
}
// Expr represents a raw expression. It is used to distinguish between
// literal values and raw expressions when defining default values.
type Expr string
// UniqueKey returns boolean indicates if this column is a unique key.
// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects.
func (c *Column) UniqueKey() bool { return c.Key == UniqueKey }