mirror of
https://github.com/ent/ent.git
synced 2026-05-06 01:20:56 +03:00
diaelct/sql/schema: postgres read columns
Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/104 Reviewed By: alexsn Differential Revision: D17980080 fbshipit-source-id: 341092a17798d008b91389263bf3bdc24b2571b2
This commit is contained in:
committed by
Facebook Github Bot
parent
2b6c8eada3
commit
0241a969b4
@@ -7,9 +7,11 @@ package schema
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/facebookincubator/ent/dialect"
|
||||
"github.com/facebookincubator/ent/dialect/sql"
|
||||
"github.com/facebookincubator/ent/schema/field"
|
||||
)
|
||||
|
||||
// Postgres is a postgres migration driver.
|
||||
@@ -63,3 +65,100 @@ func (d *Postgres) fkExist(ctx context.Context, tx dialect.Tx, name string) (boo
|
||||
func (d *Postgres) setRange(ctx context.Context, tx dialect.Tx, name string, value int) error {
|
||||
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN id RESTART WITH %d", name, value), []interface{}{}, new(sql.Result))
|
||||
}
|
||||
|
||||
// table loads the current table description from the database.
|
||||
func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
|
||||
rows := &sql.Rows{}
|
||||
query, args := sql.Dialect(dialect.Postgres).
|
||||
Select("column_name", "data_type", "character_maximum_length", "is_nullable", "column_default").
|
||||
From(sql.Table("INFORMATION_SCHEMA.COLUMNS").Unquote()).
|
||||
Where(sql.EQ("table_schema", sql.Raw("CURRENT_SCHEMA()")).And().EQ("table_name", name)).Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("postgres: reading table description %v", err)
|
||||
}
|
||||
// call `Close` in cases of failures (`Close` is idempotent).
|
||||
defer rows.Close()
|
||||
t := NewTable(name)
|
||||
for rows.Next() {
|
||||
c := &Column{}
|
||||
if err := d.scanColumn(c, rows); err != nil {
|
||||
return nil, fmt.Errorf("postgres: %v", err)
|
||||
}
|
||||
t.AddColumn(c)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, fmt.Errorf("postgres: closing rows %v", err)
|
||||
}
|
||||
// TODO: populate PK/UNI information for columns and tables and scan indexes.
|
||||
//
|
||||
// Get PK and UNI columns of a table:
|
||||
//
|
||||
// SELECT a.attname AS COLUMN,
|
||||
// format_type(a.atttypid, a.atttypmod) AS data_type,
|
||||
// i.indisprimary AS PRIMARY,
|
||||
// i.indisunique AS unique
|
||||
// FROM pg_index i
|
||||
// join pg_attribute a
|
||||
// ON a.attrelid = i.indrelid
|
||||
// AND a.attnum = ANY ( i.indkey )
|
||||
// WHERE i.indrelid = '<TABLE>' :: regclass;
|
||||
//
|
||||
// column | data_type | primary | unique
|
||||
// --------+-----------+---------+--------
|
||||
// a1 | integer | t | t
|
||||
// a2 | integer | t | t
|
||||
// a0 | integer | f | t
|
||||
//
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// maxCharSize defines the maximum size of limited character types in Postgres (10 MB).
|
||||
const maxCharSize = 10 << 20
|
||||
|
||||
// scanColumn scans the information a column from column description.
|
||||
func (d *Postgres) scanColumn(c *Column, rows *sql.Rows) error {
|
||||
var (
|
||||
maxlen sql.NullInt64
|
||||
nullable sql.NullString
|
||||
defaults sql.NullString
|
||||
)
|
||||
if err := rows.Scan(&c.Name, &c.typ, &maxlen, &nullable, &defaults); err != nil {
|
||||
return fmt.Errorf("scanning column description: %v", err)
|
||||
}
|
||||
switch c.typ {
|
||||
case "boolean":
|
||||
c.Type = field.TypeBool
|
||||
case "smallint":
|
||||
c.Type = field.TypeInt16
|
||||
case "integer":
|
||||
c.Type = field.TypeInt32
|
||||
case "bigint":
|
||||
c.Type = field.TypeInt64
|
||||
case "real":
|
||||
c.Type = field.TypeFloat32
|
||||
case "double precision":
|
||||
c.Type = field.TypeFloat64
|
||||
case "text":
|
||||
c.Type = field.TypeString
|
||||
c.Size = maxCharSize + 1
|
||||
case "character":
|
||||
c.Type = field.TypeString
|
||||
c.Size = maxlen.Int64
|
||||
case "timestamp with time zone":
|
||||
c.Type = field.TypeTime
|
||||
case "bytea":
|
||||
c.Type = field.TypeBytes
|
||||
case "jsonb":
|
||||
c.Type = field.TypeJSON
|
||||
}
|
||||
switch {
|
||||
case !defaults.Valid:
|
||||
return nil
|
||||
case strings.Contains(defaults.String, "::"):
|
||||
parts := strings.Split(defaults.String, "::")
|
||||
defaults.String = strings.Trim(parts[0], "'")
|
||||
fallthrough
|
||||
default:
|
||||
return c.ScanDefault(defaults.String)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user