mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/schema: alter column for postgres
Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/117 Reviewed By: alexsn Differential Revision: D18083914 fbshipit-source-id: a5f6993cfe9a260a84c0d4ab868e3e797b3a5776
This commit is contained in:
committed by
Facebook Github Bot
parent
a0c7ee77dc
commit
c414cd9a82
@@ -54,10 +54,10 @@ func (c *ColumnBuilder) Attr(attr string) *ColumnBuilder {
|
||||
// Query returns query representation of a Column.
|
||||
func (c *ColumnBuilder) Query() (string, []interface{}) {
|
||||
c.Ident(c.name)
|
||||
if c.postgres() && c.modify {
|
||||
c.Pad().WriteString("TYPE")
|
||||
}
|
||||
if c.typ != "" {
|
||||
if c.postgres() && c.modify {
|
||||
c.Pad().WriteString("TYPE")
|
||||
}
|
||||
c.Pad().WriteString(c.typ)
|
||||
}
|
||||
if c.attr != "" {
|
||||
@@ -225,7 +225,7 @@ func (t *TableAlter) AddColumn(c *ColumnBuilder) *TableAlter {
|
||||
return t
|
||||
}
|
||||
|
||||
// Modify appends the `MODIFY COLUMN` clause to the given `ALTER TABLE` statement.
|
||||
// Modify appends the `MODIFY/ALTER COLUMN` clause to the given `ALTER TABLE` statement.
|
||||
func (t *TableAlter) ModifyColumn(c *ColumnBuilder) *TableAlter {
|
||||
switch {
|
||||
case t.postgres():
|
||||
@@ -237,6 +237,14 @@ func (t *TableAlter) ModifyColumn(c *ColumnBuilder) *TableAlter {
|
||||
return t
|
||||
}
|
||||
|
||||
// ModifyColumns calls ModifyColumn with each of the given builders.
|
||||
func (t *TableAlter) ModifyColumns(cs ...*ColumnBuilder) *TableAlter {
|
||||
for _, c := range cs {
|
||||
t.ModifyColumn(c)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// DropColumn appends the `DROP COLUMN` clause to the given `ALTER TABLE` statement.
|
||||
func (t *TableAlter) DropColumn(c *ColumnBuilder) *TableAlter {
|
||||
t.Queries = append(t.Queries, &Wrapper{"DROP COLUMN %s", c})
|
||||
|
||||
@@ -172,6 +172,13 @@ func TestBuilder(t *testing.T) {
|
||||
DropColumn(Column("name")),
|
||||
wantQuery: `ALTER TABLE "users" ALTER COLUMN "age" TYPE int, DROP COLUMN "name"`,
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).AlterTable("users").
|
||||
ModifyColumn(Column("age").Type("int")).
|
||||
ModifyColumn(Column("age").Attr("SET NOT NULL")).
|
||||
ModifyColumn(Column("name").Attr("DROP NOT NULL")),
|
||||
wantQuery: `ALTER TABLE "users" ALTER COLUMN "age" TYPE int, ALTER COLUMN "age" SET NOT NULL, ALTER COLUMN "name" DROP NOT NULL`,
|
||||
},
|
||||
{
|
||||
input: Dialect(dialect.Postgres).AlterTable("users").
|
||||
AddColumn(Column("boring").Type("varchar")).
|
||||
|
||||
@@ -193,10 +193,10 @@ func (m *Migrate) apply(ctx context.Context, tx dialect.Tx, table string, change
|
||||
}
|
||||
b := sql.Dialect(m.Dialect()).AlterTable(table)
|
||||
for _, c := range change.column.add {
|
||||
b.AddColumn(m.cBuilder(c))
|
||||
b.AddColumn(m.addColumn(c))
|
||||
}
|
||||
for _, c := range change.column.modify {
|
||||
b.ModifyColumn(m.cBuilder(c))
|
||||
b.ModifyColumns(m.alterColumn(c)...)
|
||||
}
|
||||
if m.dropColumn {
|
||||
for _, c := range change.column.drop {
|
||||
@@ -412,5 +412,6 @@ type sqlDialect interface {
|
||||
// table, column and index builder per dialect.
|
||||
cType(*Column) string
|
||||
tBuilder(*Table) *sql.TableBuilder
|
||||
cBuilder(*Column) *sql.ColumnBuilder
|
||||
addColumn(*Column) *sql.ColumnBuilder
|
||||
alterColumn(*Column) []*sql.ColumnBuilder
|
||||
}
|
||||
|
||||
@@ -105,6 +105,9 @@ func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, name string, value
|
||||
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", name, value), []interface{}{}, new(sql.Result))
|
||||
}
|
||||
|
||||
func (d *MySQL) cType(c *Column) string { return c.MySQLType(d.version) }
|
||||
func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder { return t.MySQL(d.version) }
|
||||
func (d *MySQL) cBuilder(c *Column) *sql.ColumnBuilder { return c.MySQL(d.version) }
|
||||
func (d *MySQL) cType(c *Column) string { return c.MySQLType(d.version) }
|
||||
func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder { return t.MySQL(d.version) }
|
||||
func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder { return c.MySQL(d.version) }
|
||||
func (d *MySQL) alterColumn(c *Column) []*sql.ColumnBuilder {
|
||||
return []*sql.ColumnBuilder{c.MySQL(d.version)}
|
||||
}
|
||||
|
||||
@@ -524,7 +524,7 @@ func TestMySQL_Create(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modify column",
|
||||
name: "modify column to nullable",
|
||||
tables: []*Table{
|
||||
{
|
||||
Name: "users",
|
||||
|
||||
@@ -230,7 +230,7 @@ func (d *Postgres) tBuilder(t *Table) *sql.TableBuilder {
|
||||
b := sql.Dialect(dialect.Postgres).
|
||||
CreateTable(t.Name).IfNotExists()
|
||||
for _, c := range t.Columns {
|
||||
b.Column(d.cBuilder(c))
|
||||
b.Column(d.addColumn(c))
|
||||
}
|
||||
for _, pk := range t.PrimaryKey {
|
||||
b.PrimaryKey(pk.Name)
|
||||
@@ -274,8 +274,8 @@ func (d *Postgres) cType(c *Column) (t string) {
|
||||
return t
|
||||
}
|
||||
|
||||
// cBuilder returns the ColumnBuilder for the given column.
|
||||
func (d *Postgres) cBuilder(c *Column) *sql.ColumnBuilder {
|
||||
// addColumn returns the ColumnBuilder for adding the given column to a table.
|
||||
func (d *Postgres) addColumn(c *Column) *sql.ColumnBuilder {
|
||||
b := sql.Dialect(dialect.Postgres).
|
||||
Column(c.Name).Type(d.cType(c)).Attr(c.Attr)
|
||||
c.unique(b)
|
||||
@@ -286,3 +286,15 @@ func (d *Postgres) cBuilder(c *Column) *sql.ColumnBuilder {
|
||||
c.defaultValue(b)
|
||||
return b
|
||||
}
|
||||
|
||||
// alterColumn returns list of ColumnBuilder for applying in order to alter a column.
|
||||
func (d *Postgres) alterColumn(c *Column) (ops []*sql.ColumnBuilder) {
|
||||
b := sql.Dialect(dialect.Postgres)
|
||||
ops = append(ops, b.Column(c.Name).Type(d.cType(c)))
|
||||
if c.Nullable {
|
||||
ops = append(ops, b.Column(c.Name).Attr("DROP NOT NULL"))
|
||||
} else {
|
||||
ops = append(ops, b.Column(c.Name).Attr("SET NOT NULL"))
|
||||
}
|
||||
return ops
|
||||
}
|
||||
|
||||
@@ -379,6 +379,40 @@ func TestPostgres_Create(t *testing.T) {
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modify column to nullable",
|
||||
tables: []*Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
{Name: "name", Type: field.TypeString, Nullable: true},
|
||||
},
|
||||
PrimaryKey: []*Column{
|
||||
{Name: "id", Type: field.TypeInt, Increment: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
before: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(escape("SHOW server_version_num")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"server_version_num"}).AddRow("120000"))
|
||||
mock.ExpectQuery(escape(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default" FROM INFORMATION_SCHEMA.COLUMNS WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)).
|
||||
WithArgs("users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default"}).
|
||||
AddRow("id", "bigint(20)", "NO", "NULL").
|
||||
AddRow("name", "character", "NO", "NULL"))
|
||||
mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique"}).
|
||||
AddRow("users_pkey", "id", "t", "t"))
|
||||
mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "name" TYPE varchar, ALTER COLUMN "name" DROP NOT NULL`)).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectCommit()
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
@@ -61,9 +61,10 @@ func (d *SQLite) setRange(ctx context.Context, tx dialect.Tx, name string, value
|
||||
return tx.Exec(ctx, query, args, new(sql.Result))
|
||||
}
|
||||
|
||||
func (*SQLite) cType(c *Column) string { return c.SQLiteType() }
|
||||
func (*SQLite) tBuilder(t *Table) *sql.TableBuilder { return t.SQLite() }
|
||||
func (*SQLite) cBuilder(c *Column) *sql.ColumnBuilder { return c.SQLite() }
|
||||
func (*SQLite) cType(c *Column) string { return c.SQLiteType() }
|
||||
func (*SQLite) tBuilder(t *Table) *sql.TableBuilder { return t.SQLite() }
|
||||
func (*SQLite) addColumn(c *Column) *sql.ColumnBuilder { return c.SQLite() }
|
||||
func (*SQLite) alterColumn(c *Column) []*sql.ColumnBuilder { return []*sql.ColumnBuilder{c.SQLite()} }
|
||||
|
||||
// fkExist returns always tru to disable foreign-keys creation after the table was created.
|
||||
func (d *SQLite) fkExist(context.Context, dialect.Tx, string) (bool, error) { return true, nil }
|
||||
|
||||
Reference in New Issue
Block a user