dialect/sql/schema: alter column for postgres

Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/117

Reviewed By: alexsn

Differential Revision: D18083914

fbshipit-source-id: a5f6993cfe9a260a84c0d4ab868e3e797b3a5776
This commit is contained in:
Ariel Mashraki
2019-10-23 05:37:42 -07:00
committed by Facebook Github Bot
parent a0c7ee77dc
commit c414cd9a82
8 changed files with 83 additions and 17 deletions

View File

@@ -54,10 +54,10 @@ func (c *ColumnBuilder) Attr(attr string) *ColumnBuilder {
// Query returns query representation of a Column.
func (c *ColumnBuilder) Query() (string, []interface{}) {
c.Ident(c.name)
if c.postgres() && c.modify {
c.Pad().WriteString("TYPE")
}
if c.typ != "" {
if c.postgres() && c.modify {
c.Pad().WriteString("TYPE")
}
c.Pad().WriteString(c.typ)
}
if c.attr != "" {
@@ -225,7 +225,7 @@ func (t *TableAlter) AddColumn(c *ColumnBuilder) *TableAlter {
return t
}
// Modify appends the `MODIFY COLUMN` clause to the given `ALTER TABLE` statement.
// Modify appends the `MODIFY/ALTER COLUMN` clause to the given `ALTER TABLE` statement.
func (t *TableAlter) ModifyColumn(c *ColumnBuilder) *TableAlter {
switch {
case t.postgres():
@@ -237,6 +237,14 @@ func (t *TableAlter) ModifyColumn(c *ColumnBuilder) *TableAlter {
return t
}
// ModifyColumns calls ModifyColumn with each of the given builders.
func (t *TableAlter) ModifyColumns(cs ...*ColumnBuilder) *TableAlter {
for _, c := range cs {
t.ModifyColumn(c)
}
return t
}
// DropColumn appends the `DROP COLUMN` clause to the given `ALTER TABLE` statement.
func (t *TableAlter) DropColumn(c *ColumnBuilder) *TableAlter {
t.Queries = append(t.Queries, &Wrapper{"DROP COLUMN %s", c})

View File

@@ -172,6 +172,13 @@ func TestBuilder(t *testing.T) {
DropColumn(Column("name")),
wantQuery: `ALTER TABLE "users" ALTER COLUMN "age" TYPE int, DROP COLUMN "name"`,
},
{
input: Dialect(dialect.Postgres).AlterTable("users").
ModifyColumn(Column("age").Type("int")).
ModifyColumn(Column("age").Attr("SET NOT NULL")).
ModifyColumn(Column("name").Attr("DROP NOT NULL")),
wantQuery: `ALTER TABLE "users" ALTER COLUMN "age" TYPE int, ALTER COLUMN "age" SET NOT NULL, ALTER COLUMN "name" DROP NOT NULL`,
},
{
input: Dialect(dialect.Postgres).AlterTable("users").
AddColumn(Column("boring").Type("varchar")).

View File

@@ -193,10 +193,10 @@ func (m *Migrate) apply(ctx context.Context, tx dialect.Tx, table string, change
}
b := sql.Dialect(m.Dialect()).AlterTable(table)
for _, c := range change.column.add {
b.AddColumn(m.cBuilder(c))
b.AddColumn(m.addColumn(c))
}
for _, c := range change.column.modify {
b.ModifyColumn(m.cBuilder(c))
b.ModifyColumns(m.alterColumn(c)...)
}
if m.dropColumn {
for _, c := range change.column.drop {
@@ -412,5 +412,6 @@ type sqlDialect interface {
// table, column and index builder per dialect.
cType(*Column) string
tBuilder(*Table) *sql.TableBuilder
cBuilder(*Column) *sql.ColumnBuilder
addColumn(*Column) *sql.ColumnBuilder
alterColumn(*Column) []*sql.ColumnBuilder
}

View File

@@ -105,6 +105,9 @@ func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, name string, value
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", name, value), []interface{}{}, new(sql.Result))
}
func (d *MySQL) cType(c *Column) string { return c.MySQLType(d.version) }
func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder { return t.MySQL(d.version) }
func (d *MySQL) cBuilder(c *Column) *sql.ColumnBuilder { return c.MySQL(d.version) }
func (d *MySQL) cType(c *Column) string { return c.MySQLType(d.version) }
func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder { return t.MySQL(d.version) }
func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder { return c.MySQL(d.version) }
func (d *MySQL) alterColumn(c *Column) []*sql.ColumnBuilder {
return []*sql.ColumnBuilder{c.MySQL(d.version)}
}

View File

@@ -524,7 +524,7 @@ func TestMySQL_Create(t *testing.T) {
},
},
{
name: "modify column",
name: "modify column to nullable",
tables: []*Table{
{
Name: "users",

View File

@@ -230,7 +230,7 @@ func (d *Postgres) tBuilder(t *Table) *sql.TableBuilder {
b := sql.Dialect(dialect.Postgres).
CreateTable(t.Name).IfNotExists()
for _, c := range t.Columns {
b.Column(d.cBuilder(c))
b.Column(d.addColumn(c))
}
for _, pk := range t.PrimaryKey {
b.PrimaryKey(pk.Name)
@@ -274,8 +274,8 @@ func (d *Postgres) cType(c *Column) (t string) {
return t
}
// cBuilder returns the ColumnBuilder for the given column.
func (d *Postgres) cBuilder(c *Column) *sql.ColumnBuilder {
// addColumn returns the ColumnBuilder for adding the given column to a table.
func (d *Postgres) addColumn(c *Column) *sql.ColumnBuilder {
b := sql.Dialect(dialect.Postgres).
Column(c.Name).Type(d.cType(c)).Attr(c.Attr)
c.unique(b)
@@ -286,3 +286,15 @@ func (d *Postgres) cBuilder(c *Column) *sql.ColumnBuilder {
c.defaultValue(b)
return b
}
// alterColumn returns list of ColumnBuilder for applying in order to alter a column.
func (d *Postgres) alterColumn(c *Column) (ops []*sql.ColumnBuilder) {
b := sql.Dialect(dialect.Postgres)
ops = append(ops, b.Column(c.Name).Type(d.cType(c)))
if c.Nullable {
ops = append(ops, b.Column(c.Name).Attr("DROP NOT NULL"))
} else {
ops = append(ops, b.Column(c.Name).Attr("SET NOT NULL"))
}
return ops
}

View File

@@ -379,6 +379,40 @@ func TestPostgres_Create(t *testing.T) {
mock.ExpectCommit()
},
},
{
name: "modify column to nullable",
tables: []*Table{
{
Name: "users",
Columns: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "name", Type: field.TypeString, Nullable: true},
},
PrimaryKey: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
},
},
},
before: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery(escape("SHOW server_version_num")).
WillReturnRows(sqlmock.NewRows([]string{"server_version_num"}).AddRow("120000"))
mock.ExpectQuery(escape(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default" FROM INFORMATION_SCHEMA.COLUMNS WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)).
WithArgs("users").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default"}).
AddRow("id", "bigint(20)", "NO", "NULL").
AddRow("name", "character", "NO", "NULL"))
mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))).
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique"}).
AddRow("users_pkey", "id", "t", "t"))
mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "name" TYPE varchar, ALTER COLUMN "name" DROP NOT NULL`)).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -61,9 +61,10 @@ func (d *SQLite) setRange(ctx context.Context, tx dialect.Tx, name string, value
return tx.Exec(ctx, query, args, new(sql.Result))
}
func (*SQLite) cType(c *Column) string { return c.SQLiteType() }
func (*SQLite) tBuilder(t *Table) *sql.TableBuilder { return t.SQLite() }
func (*SQLite) cBuilder(c *Column) *sql.ColumnBuilder { return c.SQLite() }
func (*SQLite) cType(c *Column) string { return c.SQLiteType() }
func (*SQLite) tBuilder(t *Table) *sql.TableBuilder { return t.SQLite() }
func (*SQLite) addColumn(c *Column) *sql.ColumnBuilder { return c.SQLite() }
func (*SQLite) alterColumn(c *Column) []*sql.ColumnBuilder { return []*sql.ColumnBuilder{c.SQLite()} }
// fkExist returns always tru to disable foreign-keys creation after the table was created.
func (d *SQLite) fkExist(context.Context, dialect.Tx, string) (bool, error) { return true, nil }