dialect/sql/schema: add tables inspection capabilities (#1178)

This commit is contained in:
Ariel Mashraki
2021-01-17 16:41:07 +02:00
committed by GitHub
parent a692086309
commit 8b798d2714
6 changed files with 229 additions and 1 deletions

View File

@@ -0,0 +1,86 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package schema
import (
"context"
"fmt"
"github.com/facebook/ent/dialect"
"github.com/facebook/ent/dialect/sql"
)
// MigrateOption allows for managing schema configuration using functional options.
type InspectOption func(inspect *Inspector)
// WithSchema provides a schema (named-database) for reading the tables from.
func WithSchema(schema string) InspectOption {
return func(m *Inspector) {
m.schema = schema
}
}
// An Inspector provides methods for inspecting database tables.
type Inspector struct {
sqlDialect
schema string
}
// NewInspect returns an inspector for the given SQL driver.
func NewInspect(d dialect.Driver, opts ...InspectOption) (*Inspector, error) {
i := &Inspector{}
for _, opt := range opts {
opt(i)
}
switch d.Dialect() {
case dialect.MySQL:
i.sqlDialect = &MySQL{Driver: d, schema: i.schema}
case dialect.SQLite:
i.sqlDialect = &SQLite{Driver: d}
case dialect.Postgres:
i.sqlDialect = &Postgres{Driver: d, schema: i.schema}
default:
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", d.Dialect())
}
return i, nil
}
// Tables returns the tables in the schema.
func (i *Inspector) Tables(ctx context.Context) ([]*Table, error) {
names, err := i.tables(ctx)
if err != nil {
return nil, err
}
tx := dialect.NopTx(i.sqlDialect)
tables := make([]*Table, 0, len(names))
for _, name := range names {
t, err := i.table(ctx, tx, name)
if err != nil {
return nil, err
}
tables = append(tables, t)
}
return tables, nil
}
func (i *Inspector) tables(ctx context.Context) ([]string, error) {
t, ok := i.sqlDialect.(interface{ tables() sql.Querier })
if !ok {
return nil, fmt.Errorf("sql/schema: %q driver does not support inspection", i.Dialect())
}
query, args := t.tables().Query()
var (
names []string
rows = &sql.Rows{}
)
if err := i.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("mysql: reading table names %w", err)
}
defer rows.Close()
if err := sql.ScanSlice(rows, &names); err != nil {
return nil, err
}
return names, nil
}

View File

@@ -0,0 +1,119 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package schema
import (
"context"
"math"
"path"
"testing"
"github.com/facebook/ent/dialect"
"github.com/facebook/ent/dialect/sql"
"github.com/facebook/ent/schema/field"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
func TestInspector_Tables(t *testing.T) {
tests := []struct {
name string
options []InspectOption
before map[string]func(mysqlMock)
tables []*Table
wantErr bool
}{
{
name: "default schema",
before: map[string]func(mysqlMock){
dialect.MySQL: func(mock mysqlMock) {
mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE())")).
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
},
dialect.SQLite: func(mock mysqlMock) {
mock.ExpectQuery(escape("SELECT `name` FROM `sqlite_schema` WHERE `type` = ?")).
WithArgs("table").
WillReturnRows(sqlmock.NewRows([]string{"name"}))
},
dialect.Postgres: func(mock mysqlMock) {
mock.ExpectQuery(escape(`SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = CURRENT_SCHEMA()`)).
WillReturnRows(sqlmock.NewRows([]string{"name"}))
},
},
},
{
name: "custom schema",
options: []InspectOption{WithSchema("public")},
before: map[string]func(mysqlMock){
dialect.MySQL: func(mock mysqlMock) {
mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = ?")).
WithArgs("public").
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).
AddRow("users"))
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")).
WithArgs("public", "users").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "").
AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "").
AddRow("text", "longtext", "YES", "YES", "NULL", "", "", "").
AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin"))
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
WithArgs("public", "users").
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}).
AddRow("PRIMARY", "id", "0", "1"))
},
},
tables: func() []*Table {
var (
c1 = []*Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "name", Type: field.TypeString, Size: 255, Nullable: true},
{Name: "text", Type: field.TypeString, Size: math.MaxInt32, Nullable: true},
{Name: "uuid", Type: field.TypeUUID, Nullable: true},
}
t1 = &Table{
Name: "users",
Columns: c1,
PrimaryKey: c1[0:1],
}
)
return []*Table{t1}
}(),
},
}
for _, tt := range tests {
for drv := range tt.before {
t.Run(path.Join(drv, tt.name), func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
tt.before[drv](mysqlMock{mock})
inspect, err := NewInspect(sql.OpenDB(drv, db), tt.options...)
require.NoError(t, err)
tables, err := inspect.Tables(context.Background())
require.Equal(t, tt.wantErr, err != nil, err)
tablesMatch(t, tables, tt.tables)
})
}
}
}
func tablesMatch(t *testing.T, got, expected []*Table) {
require.Equal(t, len(expected), len(got))
for i := range got {
columnsMatch(t, got[i].Columns, got[i].Columns)
columnsMatch(t, got[i].PrimaryKey, got[i].PrimaryKey)
}
}
func columnsMatch(t *testing.T, got, expected []*Column) {
for i := range got {
c1, c2 := got[i], expected[i]
require.Equal(t, c1.Name, c2.Name)
require.Equal(t, c1.Type, c2.Type)
require.Equal(t, c1.Size, c2.Size)
require.Equal(t, c1.Nullable, c2.Nullable)
}
}

View File

@@ -520,6 +520,13 @@ func (d *MySQL) tableSchema() sql.Querier {
return sql.Raw("(SELECT DATABASE())")
}
// tables returns the query for getting the in the schema.
func (d *MySQL) tables() sql.Querier {
return sql.Select("TABLE_NAME").
From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
Where(sql.EQ("TABLE_SCHEMA", d.tableSchema()))
}
// alterColumns returns the queries for applying the columns change-set.
func (d *MySQL) alterColumns(table string, add, modify, drop []*Column) sql.Queries {
b := sql.Dialect(dialect.MySQL).AlterTable(table)

View File

@@ -31,7 +31,8 @@ func TestMySQL_Create(t *testing.T) {
{
name: "tx failed",
before: func(mock mysqlMock) {
mock.ExpectBegin().WillReturnError(sqlmock.ErrCancelled)
mock.ExpectBegin().
WillReturnError(sqlmock.ErrCancelled)
},
wantErr: true,
},

View File

@@ -452,6 +452,14 @@ func (d *Postgres) tableSchema() sql.Querier {
return sql.Raw("CURRENT_SCHEMA()")
}
// tables returns the query for getting the in the schema.
func (d *Postgres) tables() sql.Querier {
return sql.Dialect(dialect.Postgres).
Select("table_name").
From(sql.Table("tables").Schema("information_schema")).
Where(sql.EQ("table_schema", d.tableSchema()))
}
// alterColumns returns the queries for applying the columns change-set.
func (d *Postgres) alterColumns(table string, add, modify, drop []*Column) sql.Queries {
b := sql.Dialect(dialect.Postgres).AlterTable(table)

View File

@@ -321,3 +321,10 @@ func (d *SQLite) alterColumns(table string, add, _, _ []*Column) sql.Queries {
// will support https://www.sqlite.org/lang_altertable.html#otheralter
return queries
}
// tables returns the query for getting the in the schema.
func (d *SQLite) tables() sql.Querier {
return sql.Select("name").
From(sql.Table("sqlite_schema")).
Where(sql.EQ("type", "table"))
}