mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/schema: accept default values for enum fields (#646)
Fixed #644
This commit is contained in:
@@ -61,7 +61,7 @@ func TestPostgres_Create(t *testing.T) {
|
||||
{Name: "name", Type: field.TypeString, Nullable: true},
|
||||
{Name: "age", Type: field.TypeInt},
|
||||
{Name: "doc", Type: field.TypeJSON, Nullable: true},
|
||||
{Name: "enums", Type: field.TypeEnum, Enums: []string{"a", "b"}},
|
||||
{Name: "enums", Type: field.TypeEnum, Enums: []string{"a", "b"}, Default: "a"},
|
||||
{Name: "uuid", Type: field.TypeUUID},
|
||||
{Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.Postgres: "numeric(5,2)"}},
|
||||
},
|
||||
@@ -70,7 +70,7 @@ func TestPostgres_Create(t *testing.T) {
|
||||
before: func(mock pgMock) {
|
||||
mock.start("120000")
|
||||
mock.tableExists("users", false)
|
||||
mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NULL, "age" bigint NOT NULL, "doc" jsonb NULL, "enums" varchar NOT NULL, "uuid" uuid NOT NULL, "price" numeric(5,2) NOT NULL, PRIMARY KEY("id"))`)).
|
||||
mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NULL, "age" bigint NOT NULL, "doc" jsonb NULL, "enums" varchar NOT NULL DEFAULT 'a', "uuid" uuid NOT NULL, "price" numeric(5,2) NOT NULL, PRIMARY KEY("id"))`)).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
|
||||
@@ -215,7 +215,7 @@ func (c *Column) ScanDefault(value string) error {
|
||||
return fmt.Errorf("scanning bool value for column %q: %v", c.Name, err)
|
||||
}
|
||||
c.Default = v.Bool
|
||||
case c.Type == field.TypeString:
|
||||
case c.Type == field.TypeString || c.Type == field.TypeEnum:
|
||||
v := &sql.NullString{}
|
||||
if err := v.Scan(value); err != nil {
|
||||
return fmt.Errorf("scanning string value for column %q: %v", c.Name, err)
|
||||
@@ -228,7 +228,7 @@ func (c *Column) ScanDefault(value string) error {
|
||||
}
|
||||
c.Default = v.String
|
||||
default:
|
||||
return fmt.Errorf("unsupported type: %v", c.Type)
|
||||
return fmt.Errorf("unsupported default type: %v", c.Type)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -256,7 +256,7 @@ func (c *Column) defaultValue(b *sql.ColumnBuilder) {
|
||||
// supportDefault reports if the column type supports default value.
|
||||
func (c Column) supportDefault() bool {
|
||||
switch {
|
||||
case c.Type == field.TypeString:
|
||||
case c.Type == field.TypeString || c.Type == field.TypeEnum:
|
||||
return c.Size < 1<<16 // not a text.
|
||||
case c.Type.Numeric(), c.Type == field.TypeBool:
|
||||
return true
|
||||
|
||||
Reference in New Issue
Block a user