diaelct/sql/schema: postgres read columns

Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/104

Reviewed By: alexsn

Differential Revision: D17980080

fbshipit-source-id: 341092a17798d008b91389263bf3bdc24b2571b2
This commit is contained in:
Ariel Mashraki
2019-10-17 09:09:48 -07:00
committed by Facebook Github Bot
parent 2b6c8eada3
commit 0241a969b4

View File

@@ -7,9 +7,11 @@ package schema
import (
"context"
"fmt"
"strings"
"github.com/facebookincubator/ent/dialect"
"github.com/facebookincubator/ent/dialect/sql"
"github.com/facebookincubator/ent/schema/field"
)
// Postgres is a postgres migration driver.
@@ -63,3 +65,100 @@ func (d *Postgres) fkExist(ctx context.Context, tx dialect.Tx, name string) (boo
func (d *Postgres) setRange(ctx context.Context, tx dialect.Tx, name string, value int) error {
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN id RESTART WITH %d", name, value), []interface{}{}, new(sql.Result))
}
// table loads the current table description from the database.
func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
rows := &sql.Rows{}
query, args := sql.Dialect(dialect.Postgres).
Select("column_name", "data_type", "character_maximum_length", "is_nullable", "column_default").
From(sql.Table("INFORMATION_SCHEMA.COLUMNS").Unquote()).
Where(sql.EQ("table_schema", sql.Raw("CURRENT_SCHEMA()")).And().EQ("table_name", name)).Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("postgres: reading table description %v", err)
}
// call `Close` in cases of failures (`Close` is idempotent).
defer rows.Close()
t := NewTable(name)
for rows.Next() {
c := &Column{}
if err := d.scanColumn(c, rows); err != nil {
return nil, fmt.Errorf("postgres: %v", err)
}
t.AddColumn(c)
}
if err := rows.Close(); err != nil {
return nil, fmt.Errorf("postgres: closing rows %v", err)
}
// TODO: populate PK/UNI information for columns and tables and scan indexes.
//
// Get PK and UNI columns of a table:
//
// SELECT a.attname AS COLUMN,
// format_type(a.atttypid, a.atttypmod) AS data_type,
// i.indisprimary AS PRIMARY,
// i.indisunique AS unique
// FROM pg_index i
// join pg_attribute a
// ON a.attrelid = i.indrelid
// AND a.attnum = ANY ( i.indkey )
// WHERE i.indrelid = '<TABLE>' :: regclass;
//
// column | data_type | primary | unique
// --------+-----------+---------+--------
// a1 | integer | t | t
// a2 | integer | t | t
// a0 | integer | f | t
//
return t, nil
}
// maxCharSize defines the maximum size of limited character types in Postgres (10 MB).
const maxCharSize = 10 << 20
// scanColumn scans the information a column from column description.
func (d *Postgres) scanColumn(c *Column, rows *sql.Rows) error {
var (
maxlen sql.NullInt64
nullable sql.NullString
defaults sql.NullString
)
if err := rows.Scan(&c.Name, &c.typ, &maxlen, &nullable, &defaults); err != nil {
return fmt.Errorf("scanning column description: %v", err)
}
switch c.typ {
case "boolean":
c.Type = field.TypeBool
case "smallint":
c.Type = field.TypeInt16
case "integer":
c.Type = field.TypeInt32
case "bigint":
c.Type = field.TypeInt64
case "real":
c.Type = field.TypeFloat32
case "double precision":
c.Type = field.TypeFloat64
case "text":
c.Type = field.TypeString
c.Size = maxCharSize + 1
case "character":
c.Type = field.TypeString
c.Size = maxlen.Int64
case "timestamp with time zone":
c.Type = field.TypeTime
case "bytea":
c.Type = field.TypeBytes
case "jsonb":
c.Type = field.TypeJSON
}
switch {
case !defaults.Valid:
return nil
case strings.Contains(defaults.String, "::"):
parts := strings.Split(defaults.String, "::")
defaults.String = strings.Trim(parts[0], "'")
fallthrough
default:
return c.ScanDefault(defaults.String)
}
}