From 1e5f68646f2652ca934831bddaa9d441a37edbee Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Fri, 11 Nov 2022 21:47:10 +0200 Subject: [PATCH] dialect/sql/entsql: support setting expression as column default value Fixed https://github.com/ent/ent/issues/3069 --- dialect/entsql/annotation.go | 93 ++++++- dialect/sql/schema/atlas.go | 45 ++-- dialect/sql/schema/mysql.go | 10 +- dialect/sql/schema/schema.go | 4 + entc/gen/func.go | 1 + entc/gen/template/migrate/schema.tmpl | 13 +- entc/gen/type.go | 16 +- .../migrate/entv2/migrate/schema.go | 4 +- entc/integration/migrate/entv2/mutation.go | 148 ++++++++++- entc/integration/migrate/entv2/runtime.go | 4 +- entc/integration/migrate/entv2/schema/user.go | 10 + entc/integration/migrate/entv2/user.go | 24 +- entc/integration/migrate/entv2/user/user.go | 6 + entc/integration/migrate/entv2/user/where.go | 240 ++++++++++++++++++ entc/integration/migrate/entv2/user_create.go | 36 +++ entc/integration/migrate/entv2/user_update.go | 104 ++++++++ entc/integration/migrate/migrate_test.go | 19 ++ 17 files changed, 744 insertions(+), 33 deletions(-) diff --git a/dialect/entsql/annotation.go b/dialect/entsql/annotation.go index dbf816820..388db46a4 100644 --- a/dialect/entsql/annotation.go +++ b/dialect/entsql/annotation.go @@ -35,19 +35,44 @@ type Annotation struct { // Collation string `json:"collation,omitempty"` - // Default specifies the default value of a column. Note that using this option - // will override the default behavior of the code-generation. For example: + // Default specifies a literal default value of a column. Note that using + // this option overrides the default behavior of the code-generation. // // entsql.Annotation{ - // Default: "CURRENT_TIMESTAMP", - // } - // - // entsql.Annotation{ - // Default: "uuid_generate_v4()", + // Default: `{"key":"value"}`, // } // Default string `json:"default,omitempty"` + // DefaultExpr specifies an expression default value of a column. Using this option, + // users can define custom expressions to be set as database default values. Note that + // using this option overrides the default behavior of the code-generation. + // + // entsql.Annotation{ + // DefaultExpr: "CURRENT_TIMESTAMP", + // } + // + // entsql.Annotation{ + // DefaultExpr: "uuid_generate_v4()", + // } + // + // entsql.Annotation{ + // DefaultExpr: "(a + b)", + // } + // + DefaultExpr string `json:"default_expr,omitempty"` + + // DefaultExpr specifies an expression default value of a column per dialect. + // See, DefaultExpr for full doc. + // + // entsql.Annotation{ + // DefaultExprs: map[string]string{ + // dialect.MySQL: "uuid()", + // dialect.Postgres: "uuid_generate_v4", + // } + // + DefaultExprs map[string]string `json:"default_exprs,omitempty"` + // Options defines the additional table options. For example: // // entsql.Annotation{ @@ -111,6 +136,39 @@ func (Annotation) Name() string { return "EntSQL" } +// DefaultExpr specifies an expression default value for the annotated column. +// Using this option, users can define custom expressions to be set as database +// default values.Note that using this option overrides the default behavior of +// the code-generation. +// +// field.UUID("id", uuid.Nil). +// Default(uuid.New). +// Annotations( +// entsql.DefaultExpr("uuid_generate_v4()"), +// ) +func DefaultExpr(expr string) *Annotation { + return &Annotation{ + DefaultExpr: expr, + } +} + +// DefaultExprs specifies an expression default value for the annotated +// column per dialect. See, DefaultExpr for full doc. +// +// field.UUID("id", uuid.Nil). +// Default(uuid.New). +// Annotations( +// entsql.DefaultExprs(map[string]string{ +// dialect.MySQL: "uuid()", +// dialect.Postgres: "uuid_generate_v4()", +// }), +// ) +func DefaultExprs(exprs map[string]string) *Annotation { + return &Annotation{ + DefaultExprs: exprs, + } +} + // Merge implements the schema.Merger interface. func (a Annotation) Merge(other schema.Annotation) schema.Annotation { var ant Annotation @@ -133,6 +191,20 @@ func (a Annotation) Merge(other schema.Annotation) schema.Annotation { if c := ant.Collation; c != "" { a.Collation = c } + if d := ant.Default; d != "" { + a.Default = d + } + if d := ant.DefaultExpr; d != "" { + a.DefaultExpr = d + } + if d := ant.DefaultExprs; d != nil { + if a.DefaultExprs == nil { + a.DefaultExprs = make(map[string]string) + } + for dialect, x := range d { + a.DefaultExprs[dialect] = x + } + } if o := ant.Options; o != "" { a.Options = o } @@ -514,7 +586,12 @@ func (a IndexAnnotation) Merge(other schema.Annotation) schema.Annotation { a.Type = ant.Type } if ant.Types != nil { - a.Types = ant.Types + if a.Types == nil { + a.Types = make(map[string]string) + } + for dialect, t := range ant.Types { + a.Types[dialect] = t + } } if ant.Where != "" { a.Where = ant.Where diff --git a/dialect/sql/schema/atlas.go b/dialect/sql/schema/atlas.go index d51421f22..717c9f5c2 100644 --- a/dialect/sql/schema/atlas.go +++ b/dialect/sql/schema/atlas.go @@ -882,26 +882,41 @@ func (a *Atlas) aColumns(et *Table, at *schema.Table) error { } func (a *Atlas) atDefault(c1 *Column, c2 *schema.Column) error { - if x, ok := c1.Default.(*schema.RawExpr); ok { - c2.SetDefault(x) + if c1.Default == nil || !a.sqlDialect.supportsDefault(c1) { return nil } - switch { - case c1.Default == nil: - case c1.Type == field.TypeJSON && a.sqlDialect.supportsDefault(c1): - s, ok := c1.Default.(string) + switch x := c1.Default.(type) { + case Expr: + if len(x) > 1 && (x[0] != '(' || x[len(x)-1] != ')') { + x = "(" + x + ")" + } + c2.SetDefault(&schema.RawExpr{X: string(x)}) + case map[string]Expr: + d, ok := x[a.driver.Dialect()] if !ok { - return fmt.Errorf("invalid default value for JSON column %q: %v", c1.Name, c1.Default) + return nil } - c2.SetDefault(&schema.Literal{V: strings.ReplaceAll(s, "'", "''")}) - case c1.supportDefault(): - // Keep backwards compatibility with the old default value format. - x := fmt.Sprint(c1.Default) - if v, ok := c1.Default.(string); ok && c1.Type != field.TypeUUID && c1.Type != field.TypeTime { - // Escape single quote by replacing each with 2. - x = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''")) + if len(d) > 1 && (d[0] != '(' || d[len(d)-1] != ')') { + d = "(" + d + ")" + } + c2.SetDefault(&schema.RawExpr{X: string(d)}) + default: + switch { + case c1.Type == field.TypeJSON: + s, ok := c1.Default.(string) + if !ok { + return fmt.Errorf("invalid default value for JSON column %q: %v", c1.Name, c1.Default) + } + c2.SetDefault(&schema.Literal{V: strings.ReplaceAll(s, "'", "''")}) + default: + // Keep backwards compatibility with the old default value format. + x := fmt.Sprint(c1.Default) + if v, ok := c1.Default.(string); ok && c1.Type != field.TypeUUID && c1.Type != field.TypeTime { + // Escape single quote by replacing each with 2. + x = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''")) + } + c2.SetDefault(&schema.RawExpr{X: x}) } - c2.SetDefault(&schema.RawExpr{X: x}) } return nil } diff --git a/dialect/sql/schema/mysql.go b/dialect/sql/schema/mysql.go index 586f9dea6..010af4d78 100644 --- a/dialect/sql/schema/mysql.go +++ b/dialect/sql/schema/mysql.go @@ -816,7 +816,15 @@ func (d *MySQL) atTable(t1 *Table, t2 *schema.Table) { func (d *MySQL) supportsDefault(c *Column) bool { _, maria := d.mariadb() - return c.supportDefault() || maria + switch c.Default.(type) { + case Expr, map[string]Expr: + if maria { + return compareVersions(d.version, "10.2.0") >= 0 + } + return c.supportDefault() && compareVersions(d.version, "8.0.0") >= 0 + default: + return c.supportDefault() || maria + } } func (d *MySQL) atTypeC(c1 *Column, c2 *schema.Column) error { diff --git a/dialect/sql/schema/schema.go b/dialect/sql/schema/schema.go index 3577e322b..4670579d8 100644 --- a/dialect/sql/schema/schema.go +++ b/dialect/sql/schema/schema.go @@ -297,6 +297,10 @@ type Column struct { foreign *ForeignKey // linked foreign-key. } +// Expr represents a raw expression. It is used to distinguish between +// literal values and raw expressions when defining default values. +type Expr string + // UniqueKey returns boolean indicates if this column is a unique key. // Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects. func (c *Column) UniqueKey() bool { return c.Key == UniqueKey } diff --git a/entc/gen/func.go b/entc/gen/func.go index aaed0d8f3..1b48404bf 100644 --- a/entc/gen/func.go +++ b/entc/gen/func.go @@ -320,6 +320,7 @@ func aggregate() map[string]bool { // keys returns the given map keys. func keys(v reflect.Value) ([]string, error) { + v = indirect(v) if k := v.Type().Kind(); k != reflect.Map { return nil, fmt.Errorf("expect map for keys, got: %s", k) } diff --git a/entc/gen/template/migrate/schema.tmpl b/entc/gen/template/migrate/schema.tmpl index a6c40a50f..48e767eaf 100644 --- a/entc/gen/template/migrate/schema.tmpl +++ b/entc/gen/template/migrate/schema.tmpl @@ -38,9 +38,18 @@ var ( {{- with $c.Size }} Size: {{ . }},{{ end }} {{- with $c.Attr }} Attr: "{{ . }}",{{ end }} {{- with $c.Enums }} Enums: []string{ {{ range $e := . }}"{{ $e }}",{{ end }} },{{ end }} - {{- if not (isNil $c.Default) }} Default: {{ quote $c.Default }},{{ end }} + {{- if not (isNil $c.Default) -}} + {{- $t := printf "%T" $c.Default -}} + {{- if eq $t "schema.Expr" -}} + Default: schema.Expr("{{ $c.Default }}"), + {{- else if eq $t "map[string]schema.Expr" -}} + Default: map[string]schema.Expr{ {{ range $k := keys $c.Default }} "{{ $k }}": "{{ index $c.Default $k }}",{{ end }} }, + {{- else -}} + Default: {{ quote $c.Default }}, + {{- end -}} + {{- end }} {{- if $c.Collation }} Collation: "{{ $c.Collation }}",{{ end }} - {{- with $c.SchemaType }} SchemaType: map[string]string{ {{ range $k, $v := . }}"{{ $k }}": "{{ $v }}",{{ end }}}{{ end }}}, + {{- with $c.SchemaType }} SchemaType: map[string]string{ {{ range $k := keys . }}"{{ $k }}": "{{ index $c.SchemaType $k }}",{{ end }}}{{ end }}}, {{- end }} } {{- $table := pascal $t.Name | printf "%sTable" }} diff --git a/entc/gen/type.go b/entc/gen/type.go index 403aaee28..6265f0087 100644 --- a/entc/gen/type.go +++ b/entc/gen/type.go @@ -905,7 +905,7 @@ func ValidSchemaName(name string) error { // checkField checks the schema field. func (t *Type) checkField(tf *Field, f *load.Field) (err error) { - switch { + switch ant := tf.EntSQL(); { case f.Name == "": err = fmt.Errorf("field name cannot be empty") case f.Info == nil || !f.Info.Valid(): @@ -923,6 +923,8 @@ func (t *Type) checkField(tf *Field, f *load.Field) (err error) { } case tf.Validators > 0 && !tf.ConvertedToBasic(): err = fmt.Errorf("GoType %q for field %q must be converted to the basic %q type for validators", tf.Type, f.Name, tf.Type.Type) + case ant != nil && ant.Default != "" && (ant.DefaultExpr != "" || ant.DefaultExprs != nil): + err = fmt.Errorf("field %q cannot have both default value and default expression annotations", f.Name) } return err } @@ -1336,8 +1338,18 @@ func (f Field) Column() *schema.Column { } // Override the default-value defined in the // schema if it was provided by an annotation. - if ant := f.EntSQL(); ant != nil && ant.Default != "" { + switch ant := f.EntSQL(); { + case ant == nil: + case ant.Default != "": c.Default = ant.Default + case ant.DefaultExpr != "": + c.Default = schema.Expr(ant.DefaultExpr) + case ant.DefaultExprs != nil: + x := make(map[string]schema.Expr) + for k, v := range ant.DefaultExprs { + x[k] = schema.Expr(v) + } + c.Default = x } // Override the collation defined in the // schema if it was provided by an annotation. diff --git a/entc/integration/migrate/entv2/migrate/schema.go b/entc/integration/migrate/entv2/migrate/schema.go index cc82590b2..266a190eb 100644 --- a/entc/integration/migrate/entv2/migrate/schema.go +++ b/entc/integration/migrate/entv2/migrate/schema.go @@ -158,6 +158,8 @@ var ( {Name: "status", Type: field.TypeEnum, Nullable: true, Enums: []string{"done", "pending"}}, {Name: "workplace", Type: field.TypeString, Nullable: true}, {Name: "roles", Type: field.TypeJSON, Nullable: true, Default: "[]"}, + {Name: "default_expr", Type: field.TypeString, Nullable: true, Default: schema.Expr("lower('hello')")}, + {Name: "default_exprs", Type: field.TypeString, Nullable: true, Default: map[string]schema.Expr{"mysql": "TO_BASE64('ent')", "postgres": "md5('ent')", "sqlite3": "hex('ent')"}}, {Name: "created_at", Type: field.TypeTime, Default: "CURRENT_TIMESTAMP"}, {Name: "drop_optional", Type: field.TypeString}, {Name: "blog_admins", Type: field.TypeInt, Nullable: true, SchemaType: map[string]string{"postgres": "serial"}}, @@ -170,7 +172,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "users_blogs_admins", - Columns: []*schema.Column{UsersColumns[20]}, + Columns: []*schema.Column{UsersColumns[22]}, RefColumns: []*schema.Column{BlogsColumns[0]}, OnDelete: schema.SetNull, }, diff --git a/entc/integration/migrate/entv2/mutation.go b/entc/integration/migrate/entv2/mutation.go index 6675eeee2..b3e20a51b 100644 --- a/entc/integration/migrate/entv2/mutation.go +++ b/entc/integration/migrate/entv2/mutation.go @@ -3341,6 +3341,8 @@ type UserMutation struct { workplace *string roles *[]string appendroles []string + default_expr *string + default_exprs *string created_at *time.Time drop_optional *string clearedFields map[string]struct{} @@ -4213,6 +4215,104 @@ func (m *UserMutation) ResetRoles() { delete(m.clearedFields, user.FieldRoles) } +// SetDefaultExpr sets the "default_expr" field. +func (m *UserMutation) SetDefaultExpr(s string) { + m.default_expr = &s +} + +// DefaultExpr returns the value of the "default_expr" field in the mutation. +func (m *UserMutation) DefaultExpr() (r string, exists bool) { + v := m.default_expr + if v == nil { + return + } + return *v, true +} + +// OldDefaultExpr returns the old "default_expr" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldDefaultExpr(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDefaultExpr is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDefaultExpr requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDefaultExpr: %w", err) + } + return oldValue.DefaultExpr, nil +} + +// ClearDefaultExpr clears the value of the "default_expr" field. +func (m *UserMutation) ClearDefaultExpr() { + m.default_expr = nil + m.clearedFields[user.FieldDefaultExpr] = struct{}{} +} + +// DefaultExprCleared returns if the "default_expr" field was cleared in this mutation. +func (m *UserMutation) DefaultExprCleared() bool { + _, ok := m.clearedFields[user.FieldDefaultExpr] + return ok +} + +// ResetDefaultExpr resets all changes to the "default_expr" field. +func (m *UserMutation) ResetDefaultExpr() { + m.default_expr = nil + delete(m.clearedFields, user.FieldDefaultExpr) +} + +// SetDefaultExprs sets the "default_exprs" field. +func (m *UserMutation) SetDefaultExprs(s string) { + m.default_exprs = &s +} + +// DefaultExprs returns the value of the "default_exprs" field in the mutation. +func (m *UserMutation) DefaultExprs() (r string, exists bool) { + v := m.default_exprs + if v == nil { + return + } + return *v, true +} + +// OldDefaultExprs returns the old "default_exprs" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldDefaultExprs(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDefaultExprs is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDefaultExprs requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDefaultExprs: %w", err) + } + return oldValue.DefaultExprs, nil +} + +// ClearDefaultExprs clears the value of the "default_exprs" field. +func (m *UserMutation) ClearDefaultExprs() { + m.default_exprs = nil + m.clearedFields[user.FieldDefaultExprs] = struct{}{} +} + +// DefaultExprsCleared returns if the "default_exprs" field was cleared in this mutation. +func (m *UserMutation) DefaultExprsCleared() bool { + _, ok := m.clearedFields[user.FieldDefaultExprs] + return ok +} + +// ResetDefaultExprs resets all changes to the "default_exprs" field. +func (m *UserMutation) ResetDefaultExprs() { + m.default_exprs = nil + delete(m.clearedFields, user.FieldDefaultExprs) +} + // SetCreatedAt sets the "created_at" field. func (m *UserMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -4451,7 +4551,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 19) + fields := make([]string, 0, 21) if m.mixed_string != nil { fields = append(fields, user.FieldMixedString) } @@ -4503,6 +4603,12 @@ func (m *UserMutation) Fields() []string { if m.roles != nil { fields = append(fields, user.FieldRoles) } + if m.default_expr != nil { + fields = append(fields, user.FieldDefaultExpr) + } + if m.default_exprs != nil { + fields = append(fields, user.FieldDefaultExprs) + } if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -4551,6 +4657,10 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.Workplace() case user.FieldRoles: return m.Roles() + case user.FieldDefaultExpr: + return m.DefaultExpr() + case user.FieldDefaultExprs: + return m.DefaultExprs() case user.FieldCreatedAt: return m.CreatedAt() case user.FieldDropOptional: @@ -4598,6 +4708,10 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldWorkplace(ctx) case user.FieldRoles: return m.OldRoles(ctx) + case user.FieldDefaultExpr: + return m.OldDefaultExpr(ctx) + case user.FieldDefaultExprs: + return m.OldDefaultExprs(ctx) case user.FieldCreatedAt: return m.OldCreatedAt(ctx) case user.FieldDropOptional: @@ -4730,6 +4844,20 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetRoles(v) return nil + case user.FieldDefaultExpr: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDefaultExpr(v) + return nil + case user.FieldDefaultExprs: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDefaultExprs(v) + return nil case user.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -4813,6 +4941,12 @@ func (m *UserMutation) ClearedFields() []string { if m.FieldCleared(user.FieldRoles) { fields = append(fields, user.FieldRoles) } + if m.FieldCleared(user.FieldDefaultExpr) { + fields = append(fields, user.FieldDefaultExpr) + } + if m.FieldCleared(user.FieldDefaultExprs) { + fields = append(fields, user.FieldDefaultExprs) + } return fields } @@ -4851,6 +4985,12 @@ func (m *UserMutation) ClearField(name string) error { case user.FieldRoles: m.ClearRoles() return nil + case user.FieldDefaultExpr: + m.ClearDefaultExpr() + return nil + case user.FieldDefaultExprs: + m.ClearDefaultExprs() + return nil } return fmt.Errorf("unknown User nullable field %s", name) } @@ -4910,6 +5050,12 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldRoles: m.ResetRoles() return nil + case user.FieldDefaultExpr: + m.ResetDefaultExpr() + return nil + case user.FieldDefaultExprs: + m.ResetDefaultExprs() + return nil case user.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/entc/integration/migrate/entv2/runtime.go b/entc/integration/migrate/entv2/runtime.go index 2dbb41346..2d215d6f0 100644 --- a/entc/integration/migrate/entv2/runtime.go +++ b/entc/integration/migrate/entv2/runtime.go @@ -55,11 +55,11 @@ func init() { // user.BlobValidator is a validator for the "blob" field. It is called by the builders before save. user.BlobValidator = userDescBlob.Validators[0].(func([]byte) error) // userDescCreatedAt is the schema descriptor for created_at field. - userDescCreatedAt := userFields[16].Descriptor() + userDescCreatedAt := userFields[18].Descriptor() // user.DefaultCreatedAt holds the default value on creation for the created_at field. user.DefaultCreatedAt = userDescCreatedAt.Default.(func() time.Time) // userDescDropOptional is the schema descriptor for drop_optional field. - userDescDropOptional := userFields[17].Descriptor() + userDescDropOptional := userFields[19].Descriptor() // user.DefaultDropOptional holds the default value on creation for the drop_optional field. user.DefaultDropOptional = userDescDropOptional.Default.(func() string) } diff --git a/entc/integration/migrate/entv2/schema/user.go b/entc/integration/migrate/entv2/schema/user.go index d398191af..63257839c 100644 --- a/entc/integration/migrate/entv2/schema/user.go +++ b/entc/integration/migrate/entv2/schema/user.go @@ -103,6 +103,16 @@ func (User) Fields() []ent.Field { field.Strings("roles"). Optional(). Annotations(entsql.Annotation{Default: `[]`}), + field.String("default_expr"). + Optional(). + Annotations(entsql.DefaultExpr("lower('hello')")), + field.String("default_exprs"). + Optional(). + Annotations(entsql.DefaultExprs(map[string]string{ + dialect.MySQL: "TO_BASE64('ent')", + dialect.SQLite: "hex('ent')", + dialect.Postgres: "md5('ent')", + })), // add a new column with generated values by the database. field.Time("created_at"). Default(time.Now). diff --git a/entc/integration/migrate/entv2/user.go b/entc/integration/migrate/entv2/user.go index bc9d9535c..86c6f90b7 100644 --- a/entc/integration/migrate/entv2/user.go +++ b/entc/integration/migrate/entv2/user.go @@ -56,6 +56,10 @@ type User struct { Workplace string `json:"workplace,omitempty"` // Roles holds the value of the "roles" field. Roles []string `json:"roles,omitempty"` + // DefaultExpr holds the value of the "default_expr" field. + DefaultExpr string `json:"default_expr,omitempty"` + // DefaultExprs holds the value of the "default_exprs" field. + DefaultExprs string `json:"default_exprs,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // DropOptional holds the value of the "drop_optional" field. @@ -121,7 +125,7 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case user.FieldID, user.FieldAge: values[i] = new(sql.NullInt64) - case user.FieldMixedString, user.FieldMixedEnum, user.FieldName, user.FieldDescription, user.FieldNickname, user.FieldPhone, user.FieldTitle, user.FieldNewName, user.FieldNewToken, user.FieldState, user.FieldStatus, user.FieldWorkplace, user.FieldDropOptional: + case user.FieldMixedString, user.FieldMixedEnum, user.FieldName, user.FieldDescription, user.FieldNickname, user.FieldPhone, user.FieldTitle, user.FieldNewName, user.FieldNewToken, user.FieldState, user.FieldStatus, user.FieldWorkplace, user.FieldDefaultExpr, user.FieldDefaultExprs, user.FieldDropOptional: values[i] = new(sql.NullString) case user.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -252,6 +256,18 @@ func (u *User) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field roles: %w", err) } } + case user.FieldDefaultExpr: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field default_expr", values[i]) + } else if value.Valid { + u.DefaultExpr = value.String + } + case user.FieldDefaultExprs: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field default_exprs", values[i]) + } else if value.Valid { + u.DefaultExprs = value.String + } case user.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -365,6 +381,12 @@ func (u *User) String() string { builder.WriteString("roles=") builder.WriteString(fmt.Sprintf("%v", u.Roles)) builder.WriteString(", ") + builder.WriteString("default_expr=") + builder.WriteString(u.DefaultExpr) + builder.WriteString(", ") + builder.WriteString("default_exprs=") + builder.WriteString(u.DefaultExprs) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(u.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") diff --git a/entc/integration/migrate/entv2/user/user.go b/entc/integration/migrate/entv2/user/user.go index 3924a3802..b8cb9f49e 100644 --- a/entc/integration/migrate/entv2/user/user.go +++ b/entc/integration/migrate/entv2/user/user.go @@ -50,6 +50,10 @@ const ( FieldWorkplace = "workplace" // FieldRoles holds the string denoting the roles field in the database. FieldRoles = "roles" + // FieldDefaultExpr holds the string denoting the default_expr field in the database. + FieldDefaultExpr = "default_expr" + // FieldDefaultExprs holds the string denoting the default_exprs field in the database. + FieldDefaultExprs = "default_exprs" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldDropOptional holds the string denoting the drop_optional field in the database. @@ -104,6 +108,8 @@ var Columns = []string{ FieldStatus, FieldWorkplace, FieldRoles, + FieldDefaultExpr, + FieldDefaultExprs, FieldCreatedAt, FieldDropOptional, } diff --git a/entc/integration/migrate/entv2/user/where.go b/entc/integration/migrate/entv2/user/where.go index edf88fdf9..5c1ebab87 100644 --- a/entc/integration/migrate/entv2/user/where.go +++ b/entc/integration/migrate/entv2/user/where.go @@ -176,6 +176,20 @@ func Workplace(v string) predicate.User { }) } +// DefaultExpr applies equality check predicate on the "default_expr" field. It's identical to DefaultExprEQ. +func DefaultExpr(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldDefaultExpr), v)) + }) +} + +// DefaultExprs applies equality check predicate on the "default_exprs" field. It's identical to DefaultExprsEQ. +func DefaultExprs(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldDefaultExprs), v)) + }) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.User { return predicate.User(func(s *sql.Selector) { @@ -1507,6 +1521,232 @@ func RolesNotNil() predicate.User { }) } +// DefaultExprEQ applies the EQ predicate on the "default_expr" field. +func DefaultExprEQ(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldDefaultExpr), v)) + }) +} + +// DefaultExprNEQ applies the NEQ predicate on the "default_expr" field. +func DefaultExprNEQ(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldDefaultExpr), v)) + }) +} + +// DefaultExprIn applies the In predicate on the "default_expr" field. +func DefaultExprIn(vs ...string) predicate.User { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User(func(s *sql.Selector) { + s.Where(sql.In(s.C(FieldDefaultExpr), v...)) + }) +} + +// DefaultExprNotIn applies the NotIn predicate on the "default_expr" field. +func DefaultExprNotIn(vs ...string) predicate.User { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User(func(s *sql.Selector) { + s.Where(sql.NotIn(s.C(FieldDefaultExpr), v...)) + }) +} + +// DefaultExprGT applies the GT predicate on the "default_expr" field. +func DefaultExprGT(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldDefaultExpr), v)) + }) +} + +// DefaultExprGTE applies the GTE predicate on the "default_expr" field. +func DefaultExprGTE(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldDefaultExpr), v)) + }) +} + +// DefaultExprLT applies the LT predicate on the "default_expr" field. +func DefaultExprLT(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldDefaultExpr), v)) + }) +} + +// DefaultExprLTE applies the LTE predicate on the "default_expr" field. +func DefaultExprLTE(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldDefaultExpr), v)) + }) +} + +// DefaultExprContains applies the Contains predicate on the "default_expr" field. +func DefaultExprContains(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.Contains(s.C(FieldDefaultExpr), v)) + }) +} + +// DefaultExprHasPrefix applies the HasPrefix predicate on the "default_expr" field. +func DefaultExprHasPrefix(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.HasPrefix(s.C(FieldDefaultExpr), v)) + }) +} + +// DefaultExprHasSuffix applies the HasSuffix predicate on the "default_expr" field. +func DefaultExprHasSuffix(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.HasSuffix(s.C(FieldDefaultExpr), v)) + }) +} + +// DefaultExprIsNil applies the IsNil predicate on the "default_expr" field. +func DefaultExprIsNil() predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.IsNull(s.C(FieldDefaultExpr))) + }) +} + +// DefaultExprNotNil applies the NotNil predicate on the "default_expr" field. +func DefaultExprNotNil() predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.NotNull(s.C(FieldDefaultExpr))) + }) +} + +// DefaultExprEqualFold applies the EqualFold predicate on the "default_expr" field. +func DefaultExprEqualFold(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EqualFold(s.C(FieldDefaultExpr), v)) + }) +} + +// DefaultExprContainsFold applies the ContainsFold predicate on the "default_expr" field. +func DefaultExprContainsFold(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.ContainsFold(s.C(FieldDefaultExpr), v)) + }) +} + +// DefaultExprsEQ applies the EQ predicate on the "default_exprs" field. +func DefaultExprsEQ(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldDefaultExprs), v)) + }) +} + +// DefaultExprsNEQ applies the NEQ predicate on the "default_exprs" field. +func DefaultExprsNEQ(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldDefaultExprs), v)) + }) +} + +// DefaultExprsIn applies the In predicate on the "default_exprs" field. +func DefaultExprsIn(vs ...string) predicate.User { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User(func(s *sql.Selector) { + s.Where(sql.In(s.C(FieldDefaultExprs), v...)) + }) +} + +// DefaultExprsNotIn applies the NotIn predicate on the "default_exprs" field. +func DefaultExprsNotIn(vs ...string) predicate.User { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User(func(s *sql.Selector) { + s.Where(sql.NotIn(s.C(FieldDefaultExprs), v...)) + }) +} + +// DefaultExprsGT applies the GT predicate on the "default_exprs" field. +func DefaultExprsGT(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldDefaultExprs), v)) + }) +} + +// DefaultExprsGTE applies the GTE predicate on the "default_exprs" field. +func DefaultExprsGTE(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldDefaultExprs), v)) + }) +} + +// DefaultExprsLT applies the LT predicate on the "default_exprs" field. +func DefaultExprsLT(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldDefaultExprs), v)) + }) +} + +// DefaultExprsLTE applies the LTE predicate on the "default_exprs" field. +func DefaultExprsLTE(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldDefaultExprs), v)) + }) +} + +// DefaultExprsContains applies the Contains predicate on the "default_exprs" field. +func DefaultExprsContains(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.Contains(s.C(FieldDefaultExprs), v)) + }) +} + +// DefaultExprsHasPrefix applies the HasPrefix predicate on the "default_exprs" field. +func DefaultExprsHasPrefix(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.HasPrefix(s.C(FieldDefaultExprs), v)) + }) +} + +// DefaultExprsHasSuffix applies the HasSuffix predicate on the "default_exprs" field. +func DefaultExprsHasSuffix(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.HasSuffix(s.C(FieldDefaultExprs), v)) + }) +} + +// DefaultExprsIsNil applies the IsNil predicate on the "default_exprs" field. +func DefaultExprsIsNil() predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.IsNull(s.C(FieldDefaultExprs))) + }) +} + +// DefaultExprsNotNil applies the NotNil predicate on the "default_exprs" field. +func DefaultExprsNotNil() predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.NotNull(s.C(FieldDefaultExprs))) + }) +} + +// DefaultExprsEqualFold applies the EqualFold predicate on the "default_exprs" field. +func DefaultExprsEqualFold(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EqualFold(s.C(FieldDefaultExprs), v)) + }) +} + +// DefaultExprsContainsFold applies the ContainsFold predicate on the "default_exprs" field. +func DefaultExprsContainsFold(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.ContainsFold(s.C(FieldDefaultExprs), v)) + }) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/entc/integration/migrate/entv2/user_create.go b/entc/integration/migrate/entv2/user_create.go index e9d0062f7..0039ec5e9 100644 --- a/entc/integration/migrate/entv2/user_create.go +++ b/entc/integration/migrate/entv2/user_create.go @@ -216,6 +216,34 @@ func (uc *UserCreate) SetRoles(s []string) *UserCreate { return uc } +// SetDefaultExpr sets the "default_expr" field. +func (uc *UserCreate) SetDefaultExpr(s string) *UserCreate { + uc.mutation.SetDefaultExpr(s) + return uc +} + +// SetNillableDefaultExpr sets the "default_expr" field if the given value is not nil. +func (uc *UserCreate) SetNillableDefaultExpr(s *string) *UserCreate { + if s != nil { + uc.SetDefaultExpr(*s) + } + return uc +} + +// SetDefaultExprs sets the "default_exprs" field. +func (uc *UserCreate) SetDefaultExprs(s string) *UserCreate { + uc.mutation.SetDefaultExprs(s) + return uc +} + +// SetNillableDefaultExprs sets the "default_exprs" field if the given value is not nil. +func (uc *UserCreate) SetNillableDefaultExprs(s *string) *UserCreate { + if s != nil { + uc.SetDefaultExprs(*s) + } + return uc +} + // SetCreatedAt sets the "created_at" field. func (uc *UserCreate) SetCreatedAt(t time.Time) *UserCreate { uc.mutation.SetCreatedAt(t) @@ -579,6 +607,14 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldRoles, field.TypeJSON, value) _node.Roles = value } + if value, ok := uc.mutation.DefaultExpr(); ok { + _spec.SetField(user.FieldDefaultExpr, field.TypeString, value) + _node.DefaultExpr = value + } + if value, ok := uc.mutation.DefaultExprs(); ok { + _spec.SetField(user.FieldDefaultExprs, field.TypeString, value) + _node.DefaultExprs = value + } if value, ok := uc.mutation.CreatedAt(); ok { _spec.SetField(user.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value diff --git a/entc/integration/migrate/entv2/user_update.go b/entc/integration/migrate/entv2/user_update.go index 4d7fbde6a..60ce10b8b 100644 --- a/entc/integration/migrate/entv2/user_update.go +++ b/entc/integration/migrate/entv2/user_update.go @@ -286,6 +286,46 @@ func (uu *UserUpdate) ClearRoles() *UserUpdate { return uu } +// SetDefaultExpr sets the "default_expr" field. +func (uu *UserUpdate) SetDefaultExpr(s string) *UserUpdate { + uu.mutation.SetDefaultExpr(s) + return uu +} + +// SetNillableDefaultExpr sets the "default_expr" field if the given value is not nil. +func (uu *UserUpdate) SetNillableDefaultExpr(s *string) *UserUpdate { + if s != nil { + uu.SetDefaultExpr(*s) + } + return uu +} + +// ClearDefaultExpr clears the value of the "default_expr" field. +func (uu *UserUpdate) ClearDefaultExpr() *UserUpdate { + uu.mutation.ClearDefaultExpr() + return uu +} + +// SetDefaultExprs sets the "default_exprs" field. +func (uu *UserUpdate) SetDefaultExprs(s string) *UserUpdate { + uu.mutation.SetDefaultExprs(s) + return uu +} + +// SetNillableDefaultExprs sets the "default_exprs" field if the given value is not nil. +func (uu *UserUpdate) SetNillableDefaultExprs(s *string) *UserUpdate { + if s != nil { + uu.SetDefaultExprs(*s) + } + return uu +} + +// ClearDefaultExprs clears the value of the "default_exprs" field. +func (uu *UserUpdate) ClearDefaultExprs() *UserUpdate { + uu.mutation.ClearDefaultExprs() + return uu +} + // SetCreatedAt sets the "created_at" field. func (uu *UserUpdate) SetCreatedAt(t time.Time) *UserUpdate { uu.mutation.SetCreatedAt(t) @@ -607,6 +647,18 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { if uu.mutation.RolesCleared() { _spec.ClearField(user.FieldRoles, field.TypeJSON) } + if value, ok := uu.mutation.DefaultExpr(); ok { + _spec.SetField(user.FieldDefaultExpr, field.TypeString, value) + } + if uu.mutation.DefaultExprCleared() { + _spec.ClearField(user.FieldDefaultExpr, field.TypeString) + } + if value, ok := uu.mutation.DefaultExprs(); ok { + _spec.SetField(user.FieldDefaultExprs, field.TypeString, value) + } + if uu.mutation.DefaultExprsCleared() { + _spec.ClearField(user.FieldDefaultExprs, field.TypeString) + } if value, ok := uu.mutation.CreatedAt(); ok { _spec.SetField(user.FieldCreatedAt, field.TypeTime, value) } @@ -1026,6 +1078,46 @@ func (uuo *UserUpdateOne) ClearRoles() *UserUpdateOne { return uuo } +// SetDefaultExpr sets the "default_expr" field. +func (uuo *UserUpdateOne) SetDefaultExpr(s string) *UserUpdateOne { + uuo.mutation.SetDefaultExpr(s) + return uuo +} + +// SetNillableDefaultExpr sets the "default_expr" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableDefaultExpr(s *string) *UserUpdateOne { + if s != nil { + uuo.SetDefaultExpr(*s) + } + return uuo +} + +// ClearDefaultExpr clears the value of the "default_expr" field. +func (uuo *UserUpdateOne) ClearDefaultExpr() *UserUpdateOne { + uuo.mutation.ClearDefaultExpr() + return uuo +} + +// SetDefaultExprs sets the "default_exprs" field. +func (uuo *UserUpdateOne) SetDefaultExprs(s string) *UserUpdateOne { + uuo.mutation.SetDefaultExprs(s) + return uuo +} + +// SetNillableDefaultExprs sets the "default_exprs" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableDefaultExprs(s *string) *UserUpdateOne { + if s != nil { + uuo.SetDefaultExprs(*s) + } + return uuo +} + +// ClearDefaultExprs clears the value of the "default_exprs" field. +func (uuo *UserUpdateOne) ClearDefaultExprs() *UserUpdateOne { + uuo.mutation.ClearDefaultExprs() + return uuo +} + // SetCreatedAt sets the "created_at" field. func (uuo *UserUpdateOne) SetCreatedAt(t time.Time) *UserUpdateOne { uuo.mutation.SetCreatedAt(t) @@ -1377,6 +1469,18 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) if uuo.mutation.RolesCleared() { _spec.ClearField(user.FieldRoles, field.TypeJSON) } + if value, ok := uuo.mutation.DefaultExpr(); ok { + _spec.SetField(user.FieldDefaultExpr, field.TypeString, value) + } + if uuo.mutation.DefaultExprCleared() { + _spec.ClearField(user.FieldDefaultExpr, field.TypeString) + } + if value, ok := uuo.mutation.DefaultExprs(); ok { + _spec.SetField(user.FieldDefaultExprs, field.TypeString, value) + } + if uuo.mutation.DefaultExprsCleared() { + _spec.ClearField(user.FieldDefaultExprs, field.TypeString) + } if value, ok := uuo.mutation.CreatedAt(); ok { _spec.SetField(user.FieldCreatedAt, field.TypeTime, value) } diff --git a/entc/integration/migrate/migrate_test.go b/entc/integration/migrate/migrate_test.go index b7b578113..1bfba4da9 100644 --- a/entc/integration/migrate/migrate_test.go +++ b/entc/integration/migrate/migrate_test.go @@ -66,6 +66,7 @@ func TestMySQL(t *testing.T) { V1ToV2(t, drv.Dialect(), clientv1, clientv2) if version == "8" { CheckConstraint(t, clientv2) + DefaultExpr(t, drv, "SELECT column_default FROM information_schema.columns WHERE table_name = 'users' AND column_name = ?", "lower(_utf8mb4\\'hello\\')", "to_base64(_utf8mb4\\'ent\\')") } NicknameSearch(t, clientv2) TimePrecision(t, drv, "SELECT datetime_precision FROM information_schema.columns WHERE table_name = ? AND column_name = ?") @@ -134,6 +135,7 @@ func TestPostgres(t *testing.T) { TimePrecision(t, drv, "SELECT datetime_precision FROM information_schema.columns WHERE table_name = $1 AND column_name = $2") PartialIndexes(t, drv, "select indexdef from pg_indexes where indexname=$1", "CREATE INDEX user_phone ON public.users USING btree (phone) WHERE active") JSONDefault(t, drv, `SELECT column_default FROM information_schema.columns WHERE table_name = 'users' AND column_name = $1`) + DefaultExpr(t, drv, `SELECT column_default FROM information_schema.columns WHERE table_name = 'users' AND column_name = $1`, "lower('hello'::text)", "md5('ent'::text)") IndexOpClass(t, drv) if version != "10" { IncludeColumns(t, drv) @@ -213,6 +215,7 @@ func TestSQLite(t *testing.T) { idRange(t, u.ID, 7<<32-1, 8<<32) PartialIndexes(t, drv, "select sql from sqlite_master where name=?", "CREATE INDEX `user_phone` ON `users` (`phone`) WHERE active") JSONDefault(t, drv, "SELECT `dflt_value` FROM `pragma_table_info`('users') WHERE `name` = ?") + DefaultExpr(t, drv, "SELECT `dflt_value` FROM `pragma_table_info`('users') WHERE `name` = ?", "lower('hello')", "hex('ent')") // Override the default behavior of LIKE in SQLite. // https://www.sqlite.org/pragma.html#pragma_case_sensitive_like @@ -673,6 +676,22 @@ func JSONDefault(t *testing.T, drv *sql.Driver, query string) { require.NotEmpty(t, s) } +func DefaultExpr(t *testing.T, drv *sql.Driver, query string, expected1, expected2 string) { + ctx := context.Background() + rows, err := drv.QueryContext(ctx, query, user.FieldDefaultExpr) + require.NoError(t, err) + s, err := sql.ScanString(rows) + require.NoError(t, err) + require.NoError(t, rows.Close()) + require.Equal(t, expected1, s) + rows, err = drv.QueryContext(ctx, query, user.FieldDefaultExprs) + require.NoError(t, err) + s, err = sql.ScanString(rows) + require.NoError(t, err) + require.NoError(t, rows.Close()) + require.Equal(t, expected2, s) +} + func IncludeColumns(t *testing.T, drv *sql.Driver) { rows, err := drv.QueryContext(context.Background(), "select indexdef from pg_indexes where indexname='user_workplace'") require.NoError(t, err)