mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/entsql: support setting expression as column default value
Fixed https://github.com/ent/ent/issues/3069
This commit is contained in:
committed by
Ariel Mashraki
parent
3b5a535801
commit
1e5f68646f
@@ -35,19 +35,44 @@ type Annotation struct {
|
||||
//
|
||||
Collation string `json:"collation,omitempty"`
|
||||
|
||||
// Default specifies the default value of a column. Note that using this option
|
||||
// will override the default behavior of the code-generation. For example:
|
||||
// Default specifies a literal default value of a column. Note that using
|
||||
// this option overrides the default behavior of the code-generation.
|
||||
//
|
||||
// entsql.Annotation{
|
||||
// Default: "CURRENT_TIMESTAMP",
|
||||
// }
|
||||
//
|
||||
// entsql.Annotation{
|
||||
// Default: "uuid_generate_v4()",
|
||||
// Default: `{"key":"value"}`,
|
||||
// }
|
||||
//
|
||||
Default string `json:"default,omitempty"`
|
||||
|
||||
// DefaultExpr specifies an expression default value of a column. Using this option,
|
||||
// users can define custom expressions to be set as database default values. Note that
|
||||
// using this option overrides the default behavior of the code-generation.
|
||||
//
|
||||
// entsql.Annotation{
|
||||
// DefaultExpr: "CURRENT_TIMESTAMP",
|
||||
// }
|
||||
//
|
||||
// entsql.Annotation{
|
||||
// DefaultExpr: "uuid_generate_v4()",
|
||||
// }
|
||||
//
|
||||
// entsql.Annotation{
|
||||
// DefaultExpr: "(a + b)",
|
||||
// }
|
||||
//
|
||||
DefaultExpr string `json:"default_expr,omitempty"`
|
||||
|
||||
// DefaultExpr specifies an expression default value of a column per dialect.
|
||||
// See, DefaultExpr for full doc.
|
||||
//
|
||||
// entsql.Annotation{
|
||||
// DefaultExprs: map[string]string{
|
||||
// dialect.MySQL: "uuid()",
|
||||
// dialect.Postgres: "uuid_generate_v4",
|
||||
// }
|
||||
//
|
||||
DefaultExprs map[string]string `json:"default_exprs,omitempty"`
|
||||
|
||||
// Options defines the additional table options. For example:
|
||||
//
|
||||
// entsql.Annotation{
|
||||
@@ -111,6 +136,39 @@ func (Annotation) Name() string {
|
||||
return "EntSQL"
|
||||
}
|
||||
|
||||
// DefaultExpr specifies an expression default value for the annotated column.
|
||||
// Using this option, users can define custom expressions to be set as database
|
||||
// default values.Note that using this option overrides the default behavior of
|
||||
// the code-generation.
|
||||
//
|
||||
// field.UUID("id", uuid.Nil).
|
||||
// Default(uuid.New).
|
||||
// Annotations(
|
||||
// entsql.DefaultExpr("uuid_generate_v4()"),
|
||||
// )
|
||||
func DefaultExpr(expr string) *Annotation {
|
||||
return &Annotation{
|
||||
DefaultExpr: expr,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultExprs specifies an expression default value for the annotated
|
||||
// column per dialect. See, DefaultExpr for full doc.
|
||||
//
|
||||
// field.UUID("id", uuid.Nil).
|
||||
// Default(uuid.New).
|
||||
// Annotations(
|
||||
// entsql.DefaultExprs(map[string]string{
|
||||
// dialect.MySQL: "uuid()",
|
||||
// dialect.Postgres: "uuid_generate_v4()",
|
||||
// }),
|
||||
// )
|
||||
func DefaultExprs(exprs map[string]string) *Annotation {
|
||||
return &Annotation{
|
||||
DefaultExprs: exprs,
|
||||
}
|
||||
}
|
||||
|
||||
// Merge implements the schema.Merger interface.
|
||||
func (a Annotation) Merge(other schema.Annotation) schema.Annotation {
|
||||
var ant Annotation
|
||||
@@ -133,6 +191,20 @@ func (a Annotation) Merge(other schema.Annotation) schema.Annotation {
|
||||
if c := ant.Collation; c != "" {
|
||||
a.Collation = c
|
||||
}
|
||||
if d := ant.Default; d != "" {
|
||||
a.Default = d
|
||||
}
|
||||
if d := ant.DefaultExpr; d != "" {
|
||||
a.DefaultExpr = d
|
||||
}
|
||||
if d := ant.DefaultExprs; d != nil {
|
||||
if a.DefaultExprs == nil {
|
||||
a.DefaultExprs = make(map[string]string)
|
||||
}
|
||||
for dialect, x := range d {
|
||||
a.DefaultExprs[dialect] = x
|
||||
}
|
||||
}
|
||||
if o := ant.Options; o != "" {
|
||||
a.Options = o
|
||||
}
|
||||
@@ -514,7 +586,12 @@ func (a IndexAnnotation) Merge(other schema.Annotation) schema.Annotation {
|
||||
a.Type = ant.Type
|
||||
}
|
||||
if ant.Types != nil {
|
||||
a.Types = ant.Types
|
||||
if a.Types == nil {
|
||||
a.Types = make(map[string]string)
|
||||
}
|
||||
for dialect, t := range ant.Types {
|
||||
a.Types[dialect] = t
|
||||
}
|
||||
}
|
||||
if ant.Where != "" {
|
||||
a.Where = ant.Where
|
||||
|
||||
@@ -882,26 +882,41 @@ func (a *Atlas) aColumns(et *Table, at *schema.Table) error {
|
||||
}
|
||||
|
||||
func (a *Atlas) atDefault(c1 *Column, c2 *schema.Column) error {
|
||||
if x, ok := c1.Default.(*schema.RawExpr); ok {
|
||||
c2.SetDefault(x)
|
||||
if c1.Default == nil || !a.sqlDialect.supportsDefault(c1) {
|
||||
return nil
|
||||
}
|
||||
switch {
|
||||
case c1.Default == nil:
|
||||
case c1.Type == field.TypeJSON && a.sqlDialect.supportsDefault(c1):
|
||||
s, ok := c1.Default.(string)
|
||||
switch x := c1.Default.(type) {
|
||||
case Expr:
|
||||
if len(x) > 1 && (x[0] != '(' || x[len(x)-1] != ')') {
|
||||
x = "(" + x + ")"
|
||||
}
|
||||
c2.SetDefault(&schema.RawExpr{X: string(x)})
|
||||
case map[string]Expr:
|
||||
d, ok := x[a.driver.Dialect()]
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid default value for JSON column %q: %v", c1.Name, c1.Default)
|
||||
return nil
|
||||
}
|
||||
c2.SetDefault(&schema.Literal{V: strings.ReplaceAll(s, "'", "''")})
|
||||
case c1.supportDefault():
|
||||
// Keep backwards compatibility with the old default value format.
|
||||
x := fmt.Sprint(c1.Default)
|
||||
if v, ok := c1.Default.(string); ok && c1.Type != field.TypeUUID && c1.Type != field.TypeTime {
|
||||
// Escape single quote by replacing each with 2.
|
||||
x = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''"))
|
||||
if len(d) > 1 && (d[0] != '(' || d[len(d)-1] != ')') {
|
||||
d = "(" + d + ")"
|
||||
}
|
||||
c2.SetDefault(&schema.RawExpr{X: string(d)})
|
||||
default:
|
||||
switch {
|
||||
case c1.Type == field.TypeJSON:
|
||||
s, ok := c1.Default.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid default value for JSON column %q: %v", c1.Name, c1.Default)
|
||||
}
|
||||
c2.SetDefault(&schema.Literal{V: strings.ReplaceAll(s, "'", "''")})
|
||||
default:
|
||||
// Keep backwards compatibility with the old default value format.
|
||||
x := fmt.Sprint(c1.Default)
|
||||
if v, ok := c1.Default.(string); ok && c1.Type != field.TypeUUID && c1.Type != field.TypeTime {
|
||||
// Escape single quote by replacing each with 2.
|
||||
x = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''"))
|
||||
}
|
||||
c2.SetDefault(&schema.RawExpr{X: x})
|
||||
}
|
||||
c2.SetDefault(&schema.RawExpr{X: x})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -816,7 +816,15 @@ func (d *MySQL) atTable(t1 *Table, t2 *schema.Table) {
|
||||
|
||||
func (d *MySQL) supportsDefault(c *Column) bool {
|
||||
_, maria := d.mariadb()
|
||||
return c.supportDefault() || maria
|
||||
switch c.Default.(type) {
|
||||
case Expr, map[string]Expr:
|
||||
if maria {
|
||||
return compareVersions(d.version, "10.2.0") >= 0
|
||||
}
|
||||
return c.supportDefault() && compareVersions(d.version, "8.0.0") >= 0
|
||||
default:
|
||||
return c.supportDefault() || maria
|
||||
}
|
||||
}
|
||||
|
||||
func (d *MySQL) atTypeC(c1 *Column, c2 *schema.Column) error {
|
||||
|
||||
@@ -297,6 +297,10 @@ type Column struct {
|
||||
foreign *ForeignKey // linked foreign-key.
|
||||
}
|
||||
|
||||
// Expr represents a raw expression. It is used to distinguish between
|
||||
// literal values and raw expressions when defining default values.
|
||||
type Expr string
|
||||
|
||||
// UniqueKey returns boolean indicates if this column is a unique key.
|
||||
// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects.
|
||||
func (c *Column) UniqueKey() bool { return c.Key == UniqueKey }
|
||||
|
||||
Reference in New Issue
Block a user