mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/schema: skip parsing expression default for postgres
Fixed https://github.com/ent/ent/issues/1962
This commit is contained in:
committed by
Ariel Mashraki
parent
3b41914013
commit
5c7c36bf29
@@ -295,7 +295,7 @@ func (d *Postgres) scanColumn(c *Column, rows *sql.Rows) error {
|
||||
c.SchemaType = map[string]string{dialect.Postgres: udt.String}
|
||||
}
|
||||
switch {
|
||||
case !defaults.Valid || c.Type == field.TypeTime || seqfunc(defaults.String):
|
||||
case !defaults.Valid || c.Type == field.TypeTime || callExpr(defaults.String):
|
||||
return nil
|
||||
case strings.Contains(defaults.String, "::"):
|
||||
parts := strings.Split(defaults.String, "::")
|
||||
@@ -374,13 +374,32 @@ func (d *Postgres) addColumn(c *Column) *sql.ColumnBuilder {
|
||||
b.Attr("GENERATED BY DEFAULT AS IDENTITY")
|
||||
}
|
||||
c.nullable(b)
|
||||
c.defaultValue(b)
|
||||
d.writeDefault(b, c)
|
||||
if c.Collation != "" {
|
||||
b.Attr("COLLATE " + strconv.Quote(c.Collation))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// writeDefault writes the `DEFAULT` clause to column builder
|
||||
// if exists and supported by the driver.
|
||||
func (d *Postgres) writeDefault(b *sql.ColumnBuilder, c *Column) {
|
||||
if c.Default == nil || !c.supportDefault() {
|
||||
return
|
||||
}
|
||||
attr := fmt.Sprint(c.Default)
|
||||
switch v := c.Default.(type) {
|
||||
case bool:
|
||||
attr = strconv.FormatBool(v)
|
||||
case string:
|
||||
if t := c.Type; t != field.TypeUUID && t != field.TypeTime && !t.Numeric() {
|
||||
// Escape single quote by replacing each with 2.
|
||||
attr = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''"))
|
||||
}
|
||||
}
|
||||
b.Attr("DEFAULT " + attr)
|
||||
}
|
||||
|
||||
// alterColumn returns list of ColumnBuilder for applying in order to alter a column.
|
||||
func (d *Postgres) alterColumn(c *Column) (ops []*sql.ColumnBuilder) {
|
||||
b := sql.Dialect(dialect.Postgres)
|
||||
@@ -522,14 +541,25 @@ func (d *Postgres) needsConversion(old, new *Column) bool {
|
||||
return oldT != newT && (oldT != "ARRAY" || !arrayType(newT))
|
||||
}
|
||||
|
||||
// seqfunc reports if the given string is a sequence function.
|
||||
func seqfunc(defaults string) bool {
|
||||
for _, fn := range [...]string{"currval", "lastval", "setval", "nextval"} {
|
||||
if strings.HasPrefix(defaults, fn+"(") && strings.HasSuffix(defaults, ")") {
|
||||
return true
|
||||
// callExpr reports if the given string ~looks like a function call expression.
|
||||
func callExpr(s string) bool {
|
||||
if parts := strings.Split(s, "::"); !strings.HasSuffix(s, ")") && strings.HasSuffix(parts[0], ")") {
|
||||
s = parts[0]
|
||||
}
|
||||
i, j := strings.IndexByte(s, '('), strings.LastIndexByte(s, ')')
|
||||
if i == -1 || i > j || j != len(s)-1 {
|
||||
return false
|
||||
}
|
||||
for i, r := range s[:i] {
|
||||
if !isAlpha(r, i > 0) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
return true
|
||||
}
|
||||
|
||||
func isAlpha(r rune, digit bool) bool {
|
||||
return 'a' <= r && r <= 'z' || 'A' <= r && r <= 'Z' || r == '_' || digit && '0' <= r && r <= '9'
|
||||
}
|
||||
|
||||
// arrayType reports if the given string is an array type (e.g. int[], text[2]).
|
||||
|
||||
@@ -59,6 +59,7 @@ func TestPostgres_Create(t *testing.T) {
|
||||
},
|
||||
Columns: []*Column{
|
||||
{Name: "id", Type: field.TypeUUID, Default: "uuid_generate_v4()"},
|
||||
{Name: "block_size", Type: field.TypeInt, Default: "current_setting('block_size')::bigint"},
|
||||
{Name: "name", Type: field.TypeString, Nullable: true, Collation: "he_IL"},
|
||||
{Name: "age", Type: field.TypeInt},
|
||||
{Name: "doc", Type: field.TypeJSON, Nullable: true},
|
||||
@@ -77,7 +78,7 @@ func TestPostgres_Create(t *testing.T) {
|
||||
before: func(mock pgMock) {
|
||||
mock.start("120000")
|
||||
mock.tableExists("users", false)
|
||||
mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" uuid NOT NULL DEFAULT uuid_generate_v4(), "name" varchar NULL COLLATE "he_IL", "age" bigint NOT NULL, "doc" jsonb NULL, "enums" varchar NOT NULL DEFAULT 'a', "price" numeric(5,2) NOT NULL, "strings" text[] NULL, PRIMARY KEY("id"), CHECK (price > 0), CONSTRAINT "valid_name" CHECK (name <> ''))`)).
|
||||
mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" uuid NOT NULL DEFAULT uuid_generate_v4(), "block_size" bigint NOT NULL DEFAULT current_setting('block_size')::bigint, "name" varchar NULL COLLATE "he_IL", "age" bigint NOT NULL, "doc" jsonb NULL, "enums" varchar NOT NULL DEFAULT 'a', "price" numeric(5,2) NOT NULL, "strings" text[] NULL, PRIMARY KEY("id"), CHECK (price > 0), CONSTRAINT "valid_name" CHECK (name <> ''))`)).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
@@ -184,12 +185,13 @@ func TestPostgres_Create(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "scan table with default set to serial",
|
||||
name: "scan table with default",
|
||||
tables: []*Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "block_size", Type: field.TypeInt, Default: "current_setting('block_size')::bigint"},
|
||||
},
|
||||
PrimaryKey: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
@@ -202,7 +204,8 @@ func TestPostgres_Create(t *testing.T) {
|
||||
mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale"}).
|
||||
AddRow("id", "bigint", "NO", "nextval('users_colname_seq'::regclass)", "int4", nil, nil))
|
||||
AddRow("id", "bigint", "NO", "nextval('users_colname_seq'::regclass)", "int4", nil, nil).
|
||||
AddRow("block_size", "bigint", "NO", "current_setting('block_size')::bigint", "int4", nil, nil))
|
||||
mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}).
|
||||
AddRow("users_pkey", "id", "t", "t", 0))
|
||||
|
||||
Reference in New Issue
Block a user