diff --git a/dialect/sql/driver.go b/dialect/sql/driver.go index 456fa214b..f7c09c638 100644 --- a/dialect/sql/driver.go +++ b/dialect/sql/driver.go @@ -21,22 +21,22 @@ type Driver struct { } // NewDriver creates a new Driver with the given Conn and dialect. -func NewDriver(c Conn, d string) *Driver { - return &Driver{c, d} +func NewDriver(dialect string, c Conn) *Driver { + return &Driver{dialect: dialect, Conn: c} } // Open wraps the database/sql.Open method and returns a dialect.Driver that implements the an ent/dialect.Driver interface. -func Open(driver, source string) (*Driver, error) { - db, err := sql.Open(driver, source) +func Open(dialect, source string) (*Driver, error) { + db, err := sql.Open(dialect, source) if err != nil { return nil, err } - return NewDriver(Conn{db}, driver), nil + return NewDriver(dialect, Conn{db}), nil } // OpenDB wraps the given database/sql.DB method with a Driver. -func OpenDB(driver string, db *sql.DB) *Driver { - return NewDriver(Conn{db}, driver) +func OpenDB(dialect string, db *sql.DB) *Driver { + return NewDriver(dialect, Conn{db}) } // DB returns the underlying *sql.DB instance. @@ -46,7 +46,7 @@ func (d Driver) DB() *sql.DB { // Dialect implements the dialect.Dialect method. func (d Driver) Dialect() string { - // If the underlying driver is wrapped with opencensus driver. + // If the underlying driver is wrapped with a telemetry driver. for _, name := range []string{dialect.MySQL, dialect.SQLite, dialect.Postgres} { if strings.HasPrefix(d.dialect, name) { return name diff --git a/doc/md/dialects.md b/doc/md/dialects.md index feda919cc..43b057eb7 100755 --- a/doc/md/dialects.md +++ b/doc/md/dialects.md @@ -28,7 +28,7 @@ supported by default by SQLite, and will be added in the future using a [tempora Gremlin does not support migration nor indexes, and **it's considered experimental**. -## TiDB - **preview** +## TiDB **(preview)** TiDB support is in preview and requires the [Atlas migration engine](#atlas-integration). TiDB is MySQL compatible and thus any feature that works on MySQL _should_ work on TiDB as well. diff --git a/doc/md/migrate.md b/doc/md/migrate.md index c0c6d2eb6..093543a66 100755 --- a/doc/md/migrate.md +++ b/doc/md/migrate.md @@ -366,6 +366,8 @@ func main() { } ``` +#### `Diff` Hook Example + In case a field was renamed in the `ent/schema`, Ent won't detect this change as renaming and will propose `DropColumn` and `AddColumn` changes in the diff stage. One way to get over this is to use the [StorageKey](schema-fields.md#storage-key) option on the field and keep the old column name in the database table. @@ -415,4 +417,65 @@ func renameColumnHook(next schema.Differ) schema.Differ { return changes, nil }) } +``` + +#### `Apply` Hook Example + +The `Apply` hook allows accessing and mutating the migration plan and its raw changes (SQL statements), but in addition +to that it is also useful for executing custom SQL statements before or after the plan is applied. For example, changing +a nullable column to non-nullable without a default value is not allowed by default. However, we can work around this +using an `Apply` hook that `UPDATE`s all rows that contain `NULL` value in this column: + +```go +func main() { + client, err := ent.Open("mysql", "root:pass@tcp(localhost:3306)/test") + if err != nil { + log.Fatalf("failed connecting to mysql: %v", err) + } + defer client.Close() + // ... + if err := client.Schema.Create(ctx, schema.WithApplyHook(fillNulls)); err != nil { + log.Fatalf("failed creating schema resources: %v", err) + } +} + +func fillNulls(next schema.Applier) schema.Applier { + return schema.ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error { + // There are three ways to UPDATE the NULL values to "Unknown" in this stage. + // Append a custom migrate.Change to the plan, execute an SQL statement directly + // on the dialect.ExecQuerier, or use the ent.Client used by the project. + + // Execute a custom SQL statement. + query, args := sql.Dialect(dialect.MySQL). + Update(user.Table). + Set(user.FieldDropOptional, "Unknown"). + Where(sql.IsNull(user.FieldDropOptional)). + Query() + if err := conn.Exec(ctx, query, args, nil); err != nil { + return err + } + + // Append a custom statement to migrate.Plan. + // + // plan.Changes = append([]*migrate.Change{ + // { + // Cmd: fmt.Sprintf("UPDATE users SET %[1]s = '%[2]s' WHERE %[1]s IS NULL", user.FieldDropOptional, "Unknown"), + // }, + // }, plan.Changes...) + + // Use the ent.Client used by the project. + // + // drv := sql.NewDriver(dialect.MySQL, sql.Conn{ExecQuerier: conn.(*sql.Tx)}) + // if err := ent.NewClient(ent.Driver(drv)). + // User. + // Update(). + // SetDropOptional("Unknown"). + // Where(/* Add predicate to filter only rows with NULL values */). + // Exec(ctx); err != nil { + // return fmt.Errorf("fix default values to uppercase: %w", err) + // } + + return next.Apply(ctx, conn, plan) + }) +} ``` \ No newline at end of file diff --git a/entc/integration/migrate/entv1/migrate/schema.go b/entc/integration/migrate/entv1/migrate/schema.go index f151b1174..6f363bddf 100644 --- a/entc/integration/migrate/entv1/migrate/schema.go +++ b/entc/integration/migrate/entv1/migrate/schema.go @@ -76,6 +76,7 @@ var ( {Name: "state", Type: field.TypeEnum, Nullable: true, Enums: []string{"logged_in", "logged_out"}, Default: "logged_in"}, {Name: "status", Type: field.TypeString, Nullable: true}, {Name: "workplace", Type: field.TypeString, Nullable: true, Size: 30}, + {Name: "drop_optional", Type: field.TypeString, Nullable: true}, {Name: "user_children", Type: field.TypeInt, Nullable: true}, {Name: "user_spouse", Type: field.TypeInt, Unique: true, Nullable: true}, } @@ -87,13 +88,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "users_users_children", - Columns: []*schema.Column{UsersColumns[12]}, + Columns: []*schema.Column{UsersColumns[13]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "users_users_spouse", - Columns: []*schema.Column{UsersColumns[13]}, + Columns: []*schema.Column{UsersColumns[14]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.SetNull, }, diff --git a/entc/integration/migrate/entv1/mutation.go b/entc/integration/migrate/entv1/mutation.go index ea6f3a761..9a2cfe521 100644 --- a/entc/integration/migrate/entv1/mutation.go +++ b/entc/integration/migrate/entv1/mutation.go @@ -1897,6 +1897,7 @@ type UserMutation struct { state *user.State status *string workplace *string + drop_optional *string clearedFields map[string]struct{} parent *int clearedparent bool @@ -2523,6 +2524,55 @@ func (m *UserMutation) ResetWorkplace() { delete(m.clearedFields, user.FieldWorkplace) } +// SetDropOptional sets the "drop_optional" field. +func (m *UserMutation) SetDropOptional(s string) { + m.drop_optional = &s +} + +// DropOptional returns the value of the "drop_optional" field in the mutation. +func (m *UserMutation) DropOptional() (r string, exists bool) { + v := m.drop_optional + if v == nil { + return + } + return *v, true +} + +// OldDropOptional returns the old "drop_optional" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldDropOptional(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDropOptional is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDropOptional requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDropOptional: %w", err) + } + return oldValue.DropOptional, nil +} + +// ClearDropOptional clears the value of the "drop_optional" field. +func (m *UserMutation) ClearDropOptional() { + m.drop_optional = nil + m.clearedFields[user.FieldDropOptional] = struct{}{} +} + +// DropOptionalCleared returns if the "drop_optional" field was cleared in this mutation. +func (m *UserMutation) DropOptionalCleared() bool { + _, ok := m.clearedFields[user.FieldDropOptional] + return ok +} + +// ResetDropOptional resets all changes to the "drop_optional" field. +func (m *UserMutation) ResetDropOptional() { + m.drop_optional = nil + delete(m.clearedFields, user.FieldDropOptional) +} + // SetParentID sets the "parent" edge to the User entity by id. func (m *UserMutation) SetParentID(id int) { m.parent = &id @@ -2713,7 +2763,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 11) + fields := make([]string, 0, 12) if m.age != nil { fields = append(fields, user.FieldAge) } @@ -2747,6 +2797,9 @@ func (m *UserMutation) Fields() []string { if m.workplace != nil { fields = append(fields, user.FieldWorkplace) } + if m.drop_optional != nil { + fields = append(fields, user.FieldDropOptional) + } return fields } @@ -2777,6 +2830,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.Status() case user.FieldWorkplace: return m.Workplace() + case user.FieldDropOptional: + return m.DropOptional() } return nil, false } @@ -2808,6 +2863,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldStatus(ctx) case user.FieldWorkplace: return m.OldWorkplace(ctx) + case user.FieldDropOptional: + return m.OldDropOptional(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -2894,6 +2951,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetWorkplace(v) return nil + case user.FieldDropOptional: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDropOptional(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -2960,6 +3024,9 @@ func (m *UserMutation) ClearedFields() []string { if m.FieldCleared(user.FieldWorkplace) { fields = append(fields, user.FieldWorkplace) } + if m.FieldCleared(user.FieldDropOptional) { + fields = append(fields, user.FieldDropOptional) + } return fields } @@ -2995,6 +3062,9 @@ func (m *UserMutation) ClearField(name string) error { case user.FieldWorkplace: m.ClearWorkplace() return nil + case user.FieldDropOptional: + m.ClearDropOptional() + return nil } return fmt.Errorf("unknown User nullable field %s", name) } @@ -3036,6 +3106,9 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldWorkplace: m.ResetWorkplace() return nil + case user.FieldDropOptional: + m.ResetDropOptional() + return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/entc/integration/migrate/entv1/schema/user.go b/entc/integration/migrate/entv1/schema/user.go index 35aaf28d6..dfadebb07 100644 --- a/entc/integration/migrate/entv1/schema/user.go +++ b/entc/integration/migrate/entv1/schema/user.go @@ -49,6 +49,8 @@ func (User) Fields() []ent.Field { field.String("workplace"). MaxLen(30). Optional(), + field.String("drop_optional"). + Optional(), } } diff --git a/entc/integration/migrate/entv1/user.go b/entc/integration/migrate/entv1/user.go index 08e6cb336..462f03ed8 100644 --- a/entc/integration/migrate/entv1/user.go +++ b/entc/integration/migrate/entv1/user.go @@ -42,6 +42,8 @@ type User struct { Status string `json:"status,omitempty"` // Workplace holds the value of the "workplace" field. Workplace string `json:"workplace,omitempty"` + // DropOptional holds the value of the "drop_optional" field. + DropOptional string `json:"drop_optional,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. Edges UserEdges `json:"edges"` @@ -124,7 +126,7 @@ func (*User) scanValues(columns []string) ([]interface{}, error) { values[i] = new([]byte) case user.FieldID, user.FieldAge: values[i] = new(sql.NullInt64) - case user.FieldName, user.FieldDescription, user.FieldNickname, user.FieldAddress, user.FieldRenamed, user.FieldOldToken, user.FieldState, user.FieldStatus, user.FieldWorkplace: + case user.FieldName, user.FieldDescription, user.FieldNickname, user.FieldAddress, user.FieldRenamed, user.FieldOldToken, user.FieldState, user.FieldStatus, user.FieldWorkplace, user.FieldDropOptional: values[i] = new(sql.NullString) case user.ForeignKeys[0]: // user_children values[i] = new(sql.NullInt64) @@ -217,6 +219,12 @@ func (u *User) assignValues(columns []string, values []interface{}) error { } else if value.Valid { u.Workplace = value.String } + case user.FieldDropOptional: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field drop_optional", values[i]) + } else if value.Valid { + u.DropOptional = value.String + } case user.ForeignKeys[0]: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for edge-field user_children", value) @@ -301,6 +309,8 @@ func (u *User) String() string { builder.WriteString(u.Status) builder.WriteString(", workplace=") builder.WriteString(u.Workplace) + builder.WriteString(", drop_optional=") + builder.WriteString(u.DropOptional) builder.WriteByte(')') return builder.String() } diff --git a/entc/integration/migrate/entv1/user/user.go b/entc/integration/migrate/entv1/user/user.go index c4bb94e5b..a37733739 100644 --- a/entc/integration/migrate/entv1/user/user.go +++ b/entc/integration/migrate/entv1/user/user.go @@ -37,6 +37,8 @@ const ( FieldStatus = "status" // FieldWorkplace holds the string denoting the workplace field in the database. FieldWorkplace = "workplace" + // FieldDropOptional holds the string denoting the drop_optional field in the database. + FieldDropOptional = "drop_optional" // EdgeParent holds the string denoting the parent edge name in mutations. EdgeParent = "parent" // EdgeChildren holds the string denoting the children edge name in mutations. @@ -84,6 +86,7 @@ var Columns = []string{ FieldState, FieldStatus, FieldWorkplace, + FieldDropOptional, } // ForeignKeys holds the SQL foreign-keys that are owned by the "users" diff --git a/entc/integration/migrate/entv1/user/where.go b/entc/integration/migrate/entv1/user/where.go index a0b550aa5..f998fdd9c 100644 --- a/entc/integration/migrate/entv1/user/where.go +++ b/entc/integration/migrate/entv1/user/where.go @@ -165,6 +165,13 @@ func Workplace(v string) predicate.User { }) } +// DropOptional applies equality check predicate on the "drop_optional" field. It's identical to DropOptionalEQ. +func DropOptional(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldDropOptional), v)) + }) +} + // AgeEQ applies the EQ predicate on the "age" field. func AgeEQ(v int32) predicate.User { return predicate.User(func(s *sql.Selector) { @@ -1351,6 +1358,131 @@ func WorkplaceContainsFold(v string) predicate.User { }) } +// DropOptionalEQ applies the EQ predicate on the "drop_optional" field. +func DropOptionalEQ(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalNEQ applies the NEQ predicate on the "drop_optional" field. +func DropOptionalNEQ(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalIn applies the In predicate on the "drop_optional" field. +func DropOptionalIn(vs ...string) predicate.User { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldDropOptional), v...)) + }) +} + +// DropOptionalNotIn applies the NotIn predicate on the "drop_optional" field. +func DropOptionalNotIn(vs ...string) predicate.User { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldDropOptional), v...)) + }) +} + +// DropOptionalGT applies the GT predicate on the "drop_optional" field. +func DropOptionalGT(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalGTE applies the GTE predicate on the "drop_optional" field. +func DropOptionalGTE(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalLT applies the LT predicate on the "drop_optional" field. +func DropOptionalLT(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalLTE applies the LTE predicate on the "drop_optional" field. +func DropOptionalLTE(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalContains applies the Contains predicate on the "drop_optional" field. +func DropOptionalContains(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.Contains(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalHasPrefix applies the HasPrefix predicate on the "drop_optional" field. +func DropOptionalHasPrefix(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.HasPrefix(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalHasSuffix applies the HasSuffix predicate on the "drop_optional" field. +func DropOptionalHasSuffix(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.HasSuffix(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalIsNil applies the IsNil predicate on the "drop_optional" field. +func DropOptionalIsNil() predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.IsNull(s.C(FieldDropOptional))) + }) +} + +// DropOptionalNotNil applies the NotNil predicate on the "drop_optional" field. +func DropOptionalNotNil() predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.NotNull(s.C(FieldDropOptional))) + }) +} + +// DropOptionalEqualFold applies the EqualFold predicate on the "drop_optional" field. +func DropOptionalEqualFold(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EqualFold(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalContainsFold applies the ContainsFold predicate on the "drop_optional" field. +func DropOptionalContainsFold(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.ContainsFold(s.C(FieldDropOptional), v)) + }) +} + // HasParent applies the HasEdge predicate on the "parent" edge. func HasParent() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/entc/integration/migrate/entv1/user_create.go b/entc/integration/migrate/entv1/user_create.go index 4ce37509f..05e196054 100644 --- a/entc/integration/migrate/entv1/user_create.go +++ b/entc/integration/migrate/entv1/user_create.go @@ -146,6 +146,20 @@ func (uc *UserCreate) SetNillableWorkplace(s *string) *UserCreate { return uc } +// SetDropOptional sets the "drop_optional" field. +func (uc *UserCreate) SetDropOptional(s string) *UserCreate { + uc.mutation.SetDropOptional(s) + return uc +} + +// SetNillableDropOptional sets the "drop_optional" field if the given value is not nil. +func (uc *UserCreate) SetNillableDropOptional(s *string) *UserCreate { + if s != nil { + uc.SetDropOptional(*s) + } + return uc +} + // SetID sets the "id" field. func (uc *UserCreate) SetID(i int) *UserCreate { uc.mutation.SetID(i) @@ -460,6 +474,14 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { }) _node.Workplace = value } + if value, ok := uc.mutation.DropOptional(); ok { + _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: user.FieldDropOptional, + }) + _node.DropOptional = value + } if nodes := uc.mutation.ParentIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/entc/integration/migrate/entv1/user_update.go b/entc/integration/migrate/entv1/user_update.go index 385b04302..aea634bda 100644 --- a/entc/integration/migrate/entv1/user_update.go +++ b/entc/integration/migrate/entv1/user_update.go @@ -203,6 +203,26 @@ func (uu *UserUpdate) ClearWorkplace() *UserUpdate { return uu } +// SetDropOptional sets the "drop_optional" field. +func (uu *UserUpdate) SetDropOptional(s string) *UserUpdate { + uu.mutation.SetDropOptional(s) + return uu +} + +// SetNillableDropOptional sets the "drop_optional" field if the given value is not nil. +func (uu *UserUpdate) SetNillableDropOptional(s *string) *UserUpdate { + if s != nil { + uu.SetDropOptional(*s) + } + return uu +} + +// ClearDropOptional clears the value of the "drop_optional" field. +func (uu *UserUpdate) ClearDropOptional() *UserUpdate { + uu.mutation.ClearDropOptional() + return uu +} + // SetParentID sets the "parent" edge to the User entity by ID. func (uu *UserUpdate) SetParentID(id int) *UserUpdate { uu.mutation.SetParentID(id) @@ -548,6 +568,19 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { Column: user.FieldWorkplace, }) } + if value, ok := uu.mutation.DropOptional(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: user.FieldDropOptional, + }) + } + if uu.mutation.DropOptionalCleared() { + _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Column: user.FieldDropOptional, + }) + } if uu.mutation.ParentCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -897,6 +930,26 @@ func (uuo *UserUpdateOne) ClearWorkplace() *UserUpdateOne { return uuo } +// SetDropOptional sets the "drop_optional" field. +func (uuo *UserUpdateOne) SetDropOptional(s string) *UserUpdateOne { + uuo.mutation.SetDropOptional(s) + return uuo +} + +// SetNillableDropOptional sets the "drop_optional" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableDropOptional(s *string) *UserUpdateOne { + if s != nil { + uuo.SetDropOptional(*s) + } + return uuo +} + +// ClearDropOptional clears the value of the "drop_optional" field. +func (uuo *UserUpdateOne) ClearDropOptional() *UserUpdateOne { + uuo.mutation.ClearDropOptional() + return uuo +} + // SetParentID sets the "parent" edge to the User entity by ID. func (uuo *UserUpdateOne) SetParentID(id int) *UserUpdateOne { uuo.mutation.SetParentID(id) @@ -1266,6 +1319,19 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) Column: user.FieldWorkplace, }) } + if value, ok := uuo.mutation.DropOptional(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: user.FieldDropOptional, + }) + } + if uuo.mutation.DropOptionalCleared() { + _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Column: user.FieldDropOptional, + }) + } if uuo.mutation.ParentCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/entc/integration/migrate/entv2/migrate/schema.go b/entc/integration/migrate/entv2/migrate/schema.go index c88da1d2c..04da9e3bd 100644 --- a/entc/integration/migrate/entv2/migrate/schema.go +++ b/entc/integration/migrate/entv2/migrate/schema.go @@ -147,6 +147,7 @@ var ( {Name: "status", Type: field.TypeEnum, Nullable: true, Enums: []string{"done", "pending"}}, {Name: "workplace", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime, Default: "CURRENT_TIMESTAMP"}, + {Name: "drop_optional", Type: field.TypeString}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ diff --git a/entc/integration/migrate/entv2/mutation.go b/entc/integration/migrate/entv2/mutation.go index 18061d4f6..3364dc2a1 100644 --- a/entc/integration/migrate/entv2/mutation.go +++ b/entc/integration/migrate/entv2/mutation.go @@ -2993,6 +2993,7 @@ type UserMutation struct { status *user.Status workplace *string created_at *time.Time + drop_optional *string clearedFields map[string]struct{} car map[int]struct{} removedcar map[int]struct{} @@ -3798,6 +3799,42 @@ func (m *UserMutation) ResetCreatedAt() { m.created_at = nil } +// SetDropOptional sets the "drop_optional" field. +func (m *UserMutation) SetDropOptional(s string) { + m.drop_optional = &s +} + +// DropOptional returns the value of the "drop_optional" field in the mutation. +func (m *UserMutation) DropOptional() (r string, exists bool) { + v := m.drop_optional + if v == nil { + return + } + return *v, true +} + +// OldDropOptional returns the old "drop_optional" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldDropOptional(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDropOptional is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDropOptional requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDropOptional: %w", err) + } + return oldValue.DropOptional, nil +} + +// ResetDropOptional resets all changes to the "drop_optional" field. +func (m *UserMutation) ResetDropOptional() { + m.drop_optional = nil +} + // AddCarIDs adds the "car" edge to the Car entity by ids. func (m *UserMutation) AddCarIDs(ids ...int) { if m.car == nil { @@ -3964,7 +4001,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 16) + fields := make([]string, 0, 17) if m.mixed_string != nil { fields = append(fields, user.FieldMixedString) } @@ -4013,6 +4050,9 @@ func (m *UserMutation) Fields() []string { if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } + if m.drop_optional != nil { + fields = append(fields, user.FieldDropOptional) + } return fields } @@ -4053,6 +4093,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.Workplace() case user.FieldCreatedAt: return m.CreatedAt() + case user.FieldDropOptional: + return m.DropOptional() } return nil, false } @@ -4094,6 +4136,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldWorkplace(ctx) case user.FieldCreatedAt: return m.OldCreatedAt(ctx) + case user.FieldDropOptional: + return m.OldDropOptional(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -4215,6 +4259,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetCreatedAt(v) return nil + case user.FieldDropOptional: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDropOptional(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -4372,6 +4423,9 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldCreatedAt: m.ResetCreatedAt() return nil + case user.FieldDropOptional: + m.ResetDropOptional() + return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/entc/integration/migrate/entv2/runtime.go b/entc/integration/migrate/entv2/runtime.go index 1ce242b69..01e8387e8 100644 --- a/entc/integration/migrate/entv2/runtime.go +++ b/entc/integration/migrate/entv2/runtime.go @@ -54,4 +54,8 @@ func init() { userDescCreatedAt := userFields[14].Descriptor() // user.DefaultCreatedAt holds the default value on creation for the created_at field. user.DefaultCreatedAt = userDescCreatedAt.Default.(func() time.Time) + // userDescDropOptional is the schema descriptor for drop_optional field. + userDescDropOptional := userFields[15].Descriptor() + // user.DefaultDropOptional holds the default value on creation for the drop_optional field. + user.DefaultDropOptional = userDescDropOptional.Default.(func() string) } diff --git a/entc/integration/migrate/entv2/schema/user.go b/entc/integration/migrate/entv2/schema/user.go index daedce57b..f91845d44 100644 --- a/entc/integration/migrate/entv2/schema/user.go +++ b/entc/integration/migrate/entv2/schema/user.go @@ -102,6 +102,10 @@ func (User) Fields() []ent.Field { Annotations(&entsql.Annotation{ Default: "CURRENT_TIMESTAMP", }), + // nullable field was changed to non-nullable without a static + // default value, and it requires apply hook to fix this. + field.String("drop_optional"). + DefaultFunc(uuid.NewString), // deleting the `address` column. } } diff --git a/entc/integration/migrate/entv2/user.go b/entc/integration/migrate/entv2/user.go index 37d22b4ff..1f54cd96b 100644 --- a/entc/integration/migrate/entv2/user.go +++ b/entc/integration/migrate/entv2/user.go @@ -53,6 +53,8 @@ type User struct { Workplace string `json:"workplace,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` + // DropOptional holds the value of the "drop_optional" field. + DropOptional string `json:"drop_optional,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. Edges UserEdges `json:"edges"` @@ -112,7 +114,7 @@ func (*User) scanValues(columns []string) ([]interface{}, error) { values[i] = new([]byte) case user.FieldID, user.FieldAge: values[i] = new(sql.NullInt64) - case user.FieldMixedString, user.FieldMixedEnum, user.FieldName, user.FieldDescription, user.FieldNickname, user.FieldPhone, user.FieldTitle, user.FieldNewName, user.FieldNewToken, user.FieldState, user.FieldStatus, user.FieldWorkplace: + case user.FieldMixedString, user.FieldMixedEnum, user.FieldName, user.FieldDescription, user.FieldNickname, user.FieldPhone, user.FieldTitle, user.FieldNewName, user.FieldNewToken, user.FieldState, user.FieldStatus, user.FieldWorkplace, user.FieldDropOptional: values[i] = new(sql.NullString) case user.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -233,6 +235,12 @@ func (u *User) assignValues(columns []string, values []interface{}) error { } else if value.Valid { u.CreatedAt = value.Time } + case user.FieldDropOptional: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field drop_optional", values[i]) + } else if value.Valid { + u.DropOptional = value.String + } } } return nil @@ -308,6 +316,8 @@ func (u *User) String() string { builder.WriteString(u.Workplace) builder.WriteString(", created_at=") builder.WriteString(u.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", drop_optional=") + builder.WriteString(u.DropOptional) builder.WriteByte(')') return builder.String() } diff --git a/entc/integration/migrate/entv2/user/user.go b/entc/integration/migrate/entv2/user/user.go index 78e27c8db..b9d0832cd 100644 --- a/entc/integration/migrate/entv2/user/user.go +++ b/entc/integration/migrate/entv2/user/user.go @@ -48,6 +48,8 @@ const ( FieldWorkplace = "workplace" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" + // FieldDropOptional holds the string denoting the drop_optional field in the database. + FieldDropOptional = "drop_optional" // EdgeCar holds the string denoting the car edge name in mutations. EdgeCar = "car" // EdgePets holds the string denoting the pets edge name in mutations. @@ -97,6 +99,7 @@ var Columns = []string{ FieldStatus, FieldWorkplace, FieldCreatedAt, + FieldDropOptional, } var ( @@ -132,6 +135,8 @@ var ( BlobValidator func([]byte) error // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time + // DefaultDropOptional holds the default value on creation for the "drop_optional" field. + DefaultDropOptional func() string ) // MixedEnum defines the type for the "mixed_enum" enum field. diff --git a/entc/integration/migrate/entv2/user/where.go b/entc/integration/migrate/entv2/user/where.go index d4ce62625..cf9491206 100644 --- a/entc/integration/migrate/entv2/user/where.go +++ b/entc/integration/migrate/entv2/user/where.go @@ -188,6 +188,13 @@ func CreatedAt(v time.Time) predicate.User { }) } +// DropOptional applies equality check predicate on the "drop_optional" field. It's identical to DropOptionalEQ. +func DropOptional(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldDropOptional), v)) + }) +} + // MixedStringEQ applies the EQ predicate on the "mixed_string" field. func MixedStringEQ(v string) predicate.User { return predicate.User(func(s *sql.Selector) { @@ -1733,6 +1740,117 @@ func CreatedAtLTE(v time.Time) predicate.User { }) } +// DropOptionalEQ applies the EQ predicate on the "drop_optional" field. +func DropOptionalEQ(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalNEQ applies the NEQ predicate on the "drop_optional" field. +func DropOptionalNEQ(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalIn applies the In predicate on the "drop_optional" field. +func DropOptionalIn(vs ...string) predicate.User { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldDropOptional), v...)) + }) +} + +// DropOptionalNotIn applies the NotIn predicate on the "drop_optional" field. +func DropOptionalNotIn(vs ...string) predicate.User { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldDropOptional), v...)) + }) +} + +// DropOptionalGT applies the GT predicate on the "drop_optional" field. +func DropOptionalGT(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalGTE applies the GTE predicate on the "drop_optional" field. +func DropOptionalGTE(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalLT applies the LT predicate on the "drop_optional" field. +func DropOptionalLT(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalLTE applies the LTE predicate on the "drop_optional" field. +func DropOptionalLTE(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalContains applies the Contains predicate on the "drop_optional" field. +func DropOptionalContains(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.Contains(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalHasPrefix applies the HasPrefix predicate on the "drop_optional" field. +func DropOptionalHasPrefix(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.HasPrefix(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalHasSuffix applies the HasSuffix predicate on the "drop_optional" field. +func DropOptionalHasSuffix(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.HasSuffix(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalEqualFold applies the EqualFold predicate on the "drop_optional" field. +func DropOptionalEqualFold(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.EqualFold(s.C(FieldDropOptional), v)) + }) +} + +// DropOptionalContainsFold applies the ContainsFold predicate on the "drop_optional" field. +func DropOptionalContainsFold(v string) predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.ContainsFold(s.C(FieldDropOptional), v)) + }) +} + // HasCar applies the HasEdge predicate on the "car" edge. func HasCar() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/entc/integration/migrate/entv2/user_create.go b/entc/integration/migrate/entv2/user_create.go index da3ea1fc9..db430fc91 100644 --- a/entc/integration/migrate/entv2/user_create.go +++ b/entc/integration/migrate/entv2/user_create.go @@ -210,6 +210,20 @@ func (uc *UserCreate) SetNillableCreatedAt(t *time.Time) *UserCreate { return uc } +// SetDropOptional sets the "drop_optional" field. +func (uc *UserCreate) SetDropOptional(s string) *UserCreate { + uc.mutation.SetDropOptional(s) + return uc +} + +// SetNillableDropOptional sets the "drop_optional" field if the given value is not nil. +func (uc *UserCreate) SetNillableDropOptional(s *string) *UserCreate { + if s != nil { + uc.SetDropOptional(*s) + } + return uc +} + // SetID sets the "id" field. func (uc *UserCreate) SetID(i int) *UserCreate { uc.mutation.SetID(i) @@ -368,6 +382,10 @@ func (uc *UserCreate) defaults() { v := user.DefaultCreatedAt() uc.mutation.SetCreatedAt(v) } + if _, ok := uc.mutation.DropOptional(); !ok { + v := user.DefaultDropOptional() + uc.mutation.SetDropOptional(v) + } } // check runs all checks and user-defined validators on the builder. @@ -424,6 +442,9 @@ func (uc *UserCreate) check() error { if _, ok := uc.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`entv2: missing required field "User.created_at"`)} } + if _, ok := uc.mutation.DropOptional(); !ok { + return &ValidationError{Name: "drop_optional", err: errors.New(`entv2: missing required field "User.drop_optional"`)} + } return nil } @@ -585,6 +606,14 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { }) _node.CreatedAt = value } + if value, ok := uc.mutation.DropOptional(); ok { + _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: user.FieldDropOptional, + }) + _node.DropOptional = value + } if nodes := uc.mutation.CarIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/entc/integration/migrate/entv2/user_update.go b/entc/integration/migrate/entv2/user_update.go index 3e0edd97e..474b3c663 100644 --- a/entc/integration/migrate/entv2/user_update.go +++ b/entc/integration/migrate/entv2/user_update.go @@ -267,6 +267,20 @@ func (uu *UserUpdate) SetNillableCreatedAt(t *time.Time) *UserUpdate { return uu } +// SetDropOptional sets the "drop_optional" field. +func (uu *UserUpdate) SetDropOptional(s string) *UserUpdate { + uu.mutation.SetDropOptional(s) + return uu +} + +// SetNillableDropOptional sets the "drop_optional" field if the given value is not nil. +func (uu *UserUpdate) SetNillableDropOptional(s *string) *UserUpdate { + if s != nil { + uu.SetDropOptional(*s) + } + return uu +} + // AddCarIDs adds the "car" edge to the Car entity by IDs. func (uu *UserUpdate) AddCarIDs(ids ...int) *UserUpdate { uu.mutation.AddCarIDs(ids...) @@ -638,6 +652,13 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { Column: user.FieldCreatedAt, }) } + if value, ok := uu.mutation.DropOptional(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: user.FieldDropOptional, + }) + } if uu.mutation.CarCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1033,6 +1054,20 @@ func (uuo *UserUpdateOne) SetNillableCreatedAt(t *time.Time) *UserUpdateOne { return uuo } +// SetDropOptional sets the "drop_optional" field. +func (uuo *UserUpdateOne) SetDropOptional(s string) *UserUpdateOne { + uuo.mutation.SetDropOptional(s) + return uuo +} + +// SetNillableDropOptional sets the "drop_optional" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableDropOptional(s *string) *UserUpdateOne { + if s != nil { + uuo.SetDropOptional(*s) + } + return uuo +} + // AddCarIDs adds the "car" edge to the Car entity by IDs. func (uuo *UserUpdateOne) AddCarIDs(ids ...int) *UserUpdateOne { uuo.mutation.AddCarIDs(ids...) @@ -1428,6 +1463,13 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) Column: user.FieldCreatedAt, }) } + if value, ok := uuo.mutation.DropOptional(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: user.FieldDropOptional, + }) + } if uuo.mutation.CarCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/entc/integration/migrate/migrate_test.go b/entc/integration/migrate/migrate_test.go index ca0d2711d..d2ac0a02b 100644 --- a/entc/integration/migrate/migrate_test.go +++ b/entc/integration/migrate/migrate_test.go @@ -26,6 +26,7 @@ import ( "entgo.io/ent/entc/integration/migrate/entv2/conversion" "entgo.io/ent/entc/integration/migrate/entv2/customtype" migratev2 "entgo.io/ent/entc/integration/migrate/entv2/migrate" + "entgo.io/ent/entc/integration/migrate/entv2/predicate" "entgo.io/ent/entc/integration/migrate/entv2/user" "entgo.io/ent/entc/integration/migrate/versioned" @@ -210,7 +211,7 @@ func V1ToV2(t *testing.T, dialect string, clientv1 *entv1.Client, clientv2 *entv SanityV1(t, dialect, clientv1) // Run migration and execute queries on v2. - require.NoError(t, clientv2.Schema.Create(ctx, migratev2.WithGlobalUniqueID(true), migratev2.WithDropIndex(true), migratev2.WithDropColumn(true), schema.WithAtlas(true), schema.WithDiffHook(renameTokenColumn))) + require.NoError(t, clientv2.Schema.Create(ctx, migratev2.WithGlobalUniqueID(true), migratev2.WithDropIndex(true), migratev2.WithDropColumn(true), schema.WithAtlas(true), schema.WithDiffHook(renameTokenColumn), schema.WithApplyHook(fillNulls(dialect)))) require.NoError(t, clientv2.Schema.Create(ctx, migratev2.WithGlobalUniqueID(true), migratev2.WithDropIndex(true), migratev2.WithDropColumn(true), schema.WithAtlas(true)), "should not create additional resources on multiple runs") SanityV2(t, dialect, clientv2) @@ -483,3 +484,23 @@ func renameTokenColumn(next schema.Differ) schema.Differ { return changes, nil }) } + +func fillNulls(dbdialect string) schema.ApplyHook { + return func(next schema.Applier) schema.Applier { + return schema.ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error { + // There are three ways to UPDATE the NULL values to "Unknown" in this stage. + // Append a custom migrate.Change to the plan, execute an SQL statement directly + // on the dialect.ExecQuerier, or use the ent.Client used by the project. + drv := sql.NewDriver(dbdialect, sql.Conn{ExecQuerier: conn.(*sql.Tx)}) + client := entv2.NewClient(entv2.Driver(drv)) + if err := client.User. + Update(). + SetDropOptional("Unknown"). + Where(predicate.User(userv1.DropOptionalIsNil())). + Exec(ctx); err != nil { + return fmt.Errorf("fix default values to uppercase: %w", err) + } + return next.Apply(ctx, conn, plan) + }) + } +}