From 5f4a55ea1ebfde70ec1d0ceddc2347b993088904 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Sun, 20 Oct 2019 05:18:49 -0700 Subject: [PATCH] sql/dialect/schema: load postgres table Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/107 Reviewed By: alexsn Differential Revision: D18031837 fbshipit-source-id: 863f8db50a0547a7cb5e9ac560066fd6ee4e9c26 --- .golangci.yml | 3 + dialect/sql/schema/postgres.go | 115 +++++++++++++++++++++++++-------- dialect/sql/schema/schema.go | 13 ++-- 3 files changed, 96 insertions(+), 35 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 15432128b..70f11e1ba 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -69,3 +69,6 @@ issues: - path: dialect/sql/schema/postgres.go linters: - unused + - path: dialect/sql/schema/schema.go + linters: + - unused diff --git a/dialect/sql/schema/postgres.go b/dialect/sql/schema/postgres.go index fd07ff70b..a32927f8e 100644 --- a/dialect/sql/schema/postgres.go +++ b/dialect/sql/schema/postgres.go @@ -25,22 +25,22 @@ type Postgres struct { func (d *Postgres) init(ctx context.Context, tx dialect.Tx) error { rows := &sql.Rows{} if err := tx.Query(ctx, "SHOW server_version_num", []interface{}{}, rows); err != nil { - return fmt.Errorf("postgres: querying server version %v", err) + return fmt.Errorf("querying server version %v", err) } defer rows.Close() if !rows.Next() { - return fmt.Errorf("postgres: server_version_num variable was not found") + return fmt.Errorf("server_version_num variable was not found") } var version string if err := rows.Scan(&version); err != nil { - return fmt.Errorf("postgres: scanning version: %v", err) + return fmt.Errorf("scanning version: %v", err) } if len(version) < 6 { - return fmt.Errorf("postgres: malformed version: %s", version) + return fmt.Errorf("malformed version: %s", version) } d.version = fmt.Sprintf("%s.%s.%s", version[:2], version[2:4], version[4:]) if compareVersions(d.version, "10.0.0") == -1 { - return fmt.Errorf("postgres: unsupported version: %s", d.version) + return fmt.Errorf("unsupported postgres version: %s", d.version) } return nil } @@ -82,36 +82,97 @@ func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Tabl for rows.Next() { c := &Column{} if err := d.scanColumn(c, rows); err != nil { - return nil, fmt.Errorf("postgres: %v", err) + return nil, err } t.AddColumn(c) } if err := rows.Close(); err != nil { - return nil, fmt.Errorf("postgres: closing rows %v", err) + return nil, fmt.Errorf("closing rows %v", err) + } + idxs, err := d.indexes(ctx, tx, name) + if err != nil { + return nil, err + } + // Populate the index information to the table and its columns. + // We do it manually, because PK and uniqueness information does + // not exist when querying the INFORMATION_SCHEMA.COLUMNS above. + for _, idx := range idxs { + switch { + case idx.primary: + for _, name := range idx.columns { + c, ok := t.column(name) + if !ok { + return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name) + } + c.Key = PrimaryKey + t.PrimaryKey = append(t.PrimaryKey, c) + } + case idx.Unique && len(idx.columns) == 1: + name := idx.columns[0] + c, ok := t.column(name) + if !ok { + return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name) + } + c.Key = UniqueKey + c.Unique = true + default: + t.AddIndex(idx.Name, idx.Unique, idx.columns) + } } - // TODO: populate PK/UNI information for columns and tables and scan indexes. - // - // Get PK and UNI columns of a table: - // - // SELECT a.attname AS column, - // format_type(a.atttypid, a.atttypmod) AS data_type, - // i.indisprimary AS primary, - // i.indisunique AS unique - // FROM pg_index i - // join pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY (i.indkey) - // join pg_stat_user_tables t ON t.relid = i.indrelid - // WHERE t.schemaname = CURRENT_SCHEMA() - // AND i.indrelid = '' :: regclass; - // - // column | data_type | primary | unique - // --------+-----------+---------+-------- - // a1 | integer | t | t - // a2 | integer | t | t - // a0 | integer | f | t - // return t, nil } +// indexesQuery holds a query format for retrieving +// table indexes of the current schema. +const indexesQuery = ` +SELECT i.relname AS index_name, + a.attname AS column_name, + idx.indisprimary AS primary, + idx.indisunique AS unique +FROM pg_class t, + pg_class i, + pg_index idx, + pg_attribute a, + pg_namespace n +WHERE t.oid = idx.indrelid + AND i.oid = idx.indexrelid + AND n.oid = t.relnamespace + AND a.attrelid = t.oid + AND a.attnum = ANY(idx.indkey) + AND t.relkind = 'r' + AND n.nspname = CURRENT_SCHEMA() + AND t.relname = '%s'; +` + +func (d *Postgres) indexes(ctx context.Context, tx dialect.Tx, table string) (Indexes, error) { + rows := &sql.Rows{} + if err := tx.Query(ctx, fmt.Sprintf(indexesQuery, table), []interface{}{}, rows); err != nil { + return nil, fmt.Errorf("querying indexes for table %s", table) + } + defer rows.Close() + var ( + idxs Indexes + names = make(map[string]*Index) + ) + for rows.Next() { + var ( + name, column string + unique, primary bool + ) + if err := rows.Scan(&name, &column, &primary, &unique); err != nil { + return nil, fmt.Errorf("scanning index description: %v", err) + } + idx, ok := names[name] + if !ok { + idx = &Index{Name: name, Unique: unique, primary: primary} + idxs = append(idxs, idx) + names[name] = idx + } + idx.columns = append(idx.columns, column) + } + return idxs, nil +} + // maxCharSize defines the maximum size of limited character types in Postgres (10 MB). const maxCharSize = 10 << 20 diff --git a/dialect/sql/schema/schema.go b/dialect/sql/schema/schema.go index 90792a60e..e241b5075 100644 --- a/dialect/sql/schema/schema.go +++ b/dialect/sql/schema/schema.go @@ -593,12 +593,9 @@ type Index struct { Unique bool // uniqueness. Columns []*Column // actual table columns. columns []string // columns loaded from query scan. + primary bool // primary key index. } -// Primary indicates if this index is a primary key. -// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects. -func (i *Index) Primary() bool { return i.Name == "PRIMARY" } - // Builder returns the query builder for index creation. The DSL is identical in all dialects. func (i *Index) Builder(table string) *sql.IndexBuilder { idx := sql.CreateIndex(i.Name).Table(table) @@ -646,13 +643,13 @@ func (i *Indexes) ScanMySQL(rows *sql.Rows) error { if err := rows.Scan(&name, &column, &nonuniq, &seqindex); err != nil { return fmt.Errorf("scanning index description: %v", err) } + // ignore primary keys. + if name == "PRIMARY" { + continue + } idx, ok := names[name] if !ok { idx = &Index{Name: name, Unique: !nonuniq} - // ignore primary keys. - if idx.Primary() { - continue - } *i = append(*i, idx) names[name] = idx }