From 8b798d2714bc4ca98850c4fabeec54858bad74f4 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Sun, 17 Jan 2021 16:41:07 +0200 Subject: [PATCH] dialect/sql/schema: add tables inspection capabilities (#1178) --- dialect/sql/schema/inspect.go | 86 +++++++++++++++++++++ dialect/sql/schema/inspect_test.go | 119 +++++++++++++++++++++++++++++ dialect/sql/schema/mysql.go | 7 ++ dialect/sql/schema/mysql_test.go | 3 +- dialect/sql/schema/postgres.go | 8 ++ dialect/sql/schema/sqlite.go | 7 ++ 6 files changed, 229 insertions(+), 1 deletion(-) create mode 100644 dialect/sql/schema/inspect.go create mode 100644 dialect/sql/schema/inspect_test.go diff --git a/dialect/sql/schema/inspect.go b/dialect/sql/schema/inspect.go new file mode 100644 index 000000000..5a5f48501 --- /dev/null +++ b/dialect/sql/schema/inspect.go @@ -0,0 +1,86 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package schema + +import ( + "context" + "fmt" + + "github.com/facebook/ent/dialect" + "github.com/facebook/ent/dialect/sql" +) + +// MigrateOption allows for managing schema configuration using functional options. +type InspectOption func(inspect *Inspector) + +// WithSchema provides a schema (named-database) for reading the tables from. +func WithSchema(schema string) InspectOption { + return func(m *Inspector) { + m.schema = schema + } +} + +// An Inspector provides methods for inspecting database tables. +type Inspector struct { + sqlDialect + schema string +} + +// NewInspect returns an inspector for the given SQL driver. +func NewInspect(d dialect.Driver, opts ...InspectOption) (*Inspector, error) { + i := &Inspector{} + for _, opt := range opts { + opt(i) + } + switch d.Dialect() { + case dialect.MySQL: + i.sqlDialect = &MySQL{Driver: d, schema: i.schema} + case dialect.SQLite: + i.sqlDialect = &SQLite{Driver: d} + case dialect.Postgres: + i.sqlDialect = &Postgres{Driver: d, schema: i.schema} + default: + return nil, fmt.Errorf("sql/schema: unsupported dialect %q", d.Dialect()) + } + return i, nil +} + +// Tables returns the tables in the schema. +func (i *Inspector) Tables(ctx context.Context) ([]*Table, error) { + names, err := i.tables(ctx) + if err != nil { + return nil, err + } + tx := dialect.NopTx(i.sqlDialect) + tables := make([]*Table, 0, len(names)) + for _, name := range names { + t, err := i.table(ctx, tx, name) + if err != nil { + return nil, err + } + tables = append(tables, t) + } + return tables, nil +} + +func (i *Inspector) tables(ctx context.Context) ([]string, error) { + t, ok := i.sqlDialect.(interface{ tables() sql.Querier }) + if !ok { + return nil, fmt.Errorf("sql/schema: %q driver does not support inspection", i.Dialect()) + } + query, args := t.tables().Query() + var ( + names []string + rows = &sql.Rows{} + ) + if err := i.Query(ctx, query, args, rows); err != nil { + return nil, fmt.Errorf("mysql: reading table names %w", err) + } + defer rows.Close() + if err := sql.ScanSlice(rows, &names); err != nil { + return nil, err + } + return names, nil +} diff --git a/dialect/sql/schema/inspect_test.go b/dialect/sql/schema/inspect_test.go new file mode 100644 index 000000000..875d305ac --- /dev/null +++ b/dialect/sql/schema/inspect_test.go @@ -0,0 +1,119 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package schema + +import ( + "context" + "math" + "path" + "testing" + + "github.com/facebook/ent/dialect" + "github.com/facebook/ent/dialect/sql" + "github.com/facebook/ent/schema/field" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestInspector_Tables(t *testing.T) { + tests := []struct { + name string + options []InspectOption + before map[string]func(mysqlMock) + tables []*Table + wantErr bool + }{ + { + name: "default schema", + before: map[string]func(mysqlMock){ + dialect.MySQL: func(mock mysqlMock) { + mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE())")). + WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"})) + }, + dialect.SQLite: func(mock mysqlMock) { + mock.ExpectQuery(escape("SELECT `name` FROM `sqlite_schema` WHERE `type` = ?")). + WithArgs("table"). + WillReturnRows(sqlmock.NewRows([]string{"name"})) + }, + dialect.Postgres: func(mock mysqlMock) { + mock.ExpectQuery(escape(`SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = CURRENT_SCHEMA()`)). + WillReturnRows(sqlmock.NewRows([]string{"name"})) + }, + }, + }, + { + name: "custom schema", + options: []InspectOption{WithSchema("public")}, + before: map[string]func(mysqlMock){ + dialect.MySQL: func(mock mysqlMock) { + mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = ?")). + WithArgs("public"). + WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}). + AddRow("users")) + mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")). + WithArgs("public", "users"). + WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). + AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). + AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", ""). + AddRow("text", "longtext", "YES", "YES", "NULL", "", "", ""). + AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin")) + mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). + WithArgs("public", "users"). + WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). + AddRow("PRIMARY", "id", "0", "1")) + }, + }, + tables: func() []*Table { + var ( + c1 = []*Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "name", Type: field.TypeString, Size: 255, Nullable: true}, + {Name: "text", Type: field.TypeString, Size: math.MaxInt32, Nullable: true}, + {Name: "uuid", Type: field.TypeUUID, Nullable: true}, + } + t1 = &Table{ + Name: "users", + Columns: c1, + PrimaryKey: c1[0:1], + } + ) + return []*Table{t1} + }(), + }, + } + for _, tt := range tests { + for drv := range tt.before { + t.Run(path.Join(drv, tt.name), func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + tt.before[drv](mysqlMock{mock}) + inspect, err := NewInspect(sql.OpenDB(drv, db), tt.options...) + require.NoError(t, err) + tables, err := inspect.Tables(context.Background()) + require.Equal(t, tt.wantErr, err != nil, err) + tablesMatch(t, tables, tt.tables) + }) + } + } +} + +func tablesMatch(t *testing.T, got, expected []*Table) { + require.Equal(t, len(expected), len(got)) + for i := range got { + columnsMatch(t, got[i].Columns, got[i].Columns) + columnsMatch(t, got[i].PrimaryKey, got[i].PrimaryKey) + } +} + +func columnsMatch(t *testing.T, got, expected []*Column) { + for i := range got { + c1, c2 := got[i], expected[i] + require.Equal(t, c1.Name, c2.Name) + require.Equal(t, c1.Type, c2.Type) + require.Equal(t, c1.Size, c2.Size) + require.Equal(t, c1.Nullable, c2.Nullable) + } +} diff --git a/dialect/sql/schema/mysql.go b/dialect/sql/schema/mysql.go index 3e695a1be..4131699b3 100644 --- a/dialect/sql/schema/mysql.go +++ b/dialect/sql/schema/mysql.go @@ -520,6 +520,13 @@ func (d *MySQL) tableSchema() sql.Querier { return sql.Raw("(SELECT DATABASE())") } +// tables returns the query for getting the in the schema. +func (d *MySQL) tables() sql.Querier { + return sql.Select("TABLE_NAME"). + From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")). + Where(sql.EQ("TABLE_SCHEMA", d.tableSchema())) +} + // alterColumns returns the queries for applying the columns change-set. func (d *MySQL) alterColumns(table string, add, modify, drop []*Column) sql.Queries { b := sql.Dialect(dialect.MySQL).AlterTable(table) diff --git a/dialect/sql/schema/mysql_test.go b/dialect/sql/schema/mysql_test.go index 888cf0577..1451f0be6 100644 --- a/dialect/sql/schema/mysql_test.go +++ b/dialect/sql/schema/mysql_test.go @@ -31,7 +31,8 @@ func TestMySQL_Create(t *testing.T) { { name: "tx failed", before: func(mock mysqlMock) { - mock.ExpectBegin().WillReturnError(sqlmock.ErrCancelled) + mock.ExpectBegin(). + WillReturnError(sqlmock.ErrCancelled) }, wantErr: true, }, diff --git a/dialect/sql/schema/postgres.go b/dialect/sql/schema/postgres.go index 999c7bb56..738ab8794 100644 --- a/dialect/sql/schema/postgres.go +++ b/dialect/sql/schema/postgres.go @@ -452,6 +452,14 @@ func (d *Postgres) tableSchema() sql.Querier { return sql.Raw("CURRENT_SCHEMA()") } +// tables returns the query for getting the in the schema. +func (d *Postgres) tables() sql.Querier { + return sql.Dialect(dialect.Postgres). + Select("table_name"). + From(sql.Table("tables").Schema("information_schema")). + Where(sql.EQ("table_schema", d.tableSchema())) +} + // alterColumns returns the queries for applying the columns change-set. func (d *Postgres) alterColumns(table string, add, modify, drop []*Column) sql.Queries { b := sql.Dialect(dialect.Postgres).AlterTable(table) diff --git a/dialect/sql/schema/sqlite.go b/dialect/sql/schema/sqlite.go index 464216ea6..f98b82983 100644 --- a/dialect/sql/schema/sqlite.go +++ b/dialect/sql/schema/sqlite.go @@ -321,3 +321,10 @@ func (d *SQLite) alterColumns(table string, add, _, _ []*Column) sql.Queries { // will support https://www.sqlite.org/lang_altertable.html#otheralter return queries } + +// tables returns the query for getting the in the schema. +func (d *SQLite) tables() sql.Querier { + return sql.Select("name"). + From(sql.Table("sqlite_schema")). + Where(sql.EQ("type", "table")) +}