mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sql/schema: add tables inspection capabilities (#1178)
This commit is contained in:
86
dialect/sql/schema/inspect.go
Normal file
86
dialect/sql/schema/inspect.go
Normal file
@@ -0,0 +1,86 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/facebook/ent/dialect"
|
||||
"github.com/facebook/ent/dialect/sql"
|
||||
)
|
||||
|
||||
// MigrateOption allows for managing schema configuration using functional options.
|
||||
type InspectOption func(inspect *Inspector)
|
||||
|
||||
// WithSchema provides a schema (named-database) for reading the tables from.
|
||||
func WithSchema(schema string) InspectOption {
|
||||
return func(m *Inspector) {
|
||||
m.schema = schema
|
||||
}
|
||||
}
|
||||
|
||||
// An Inspector provides methods for inspecting database tables.
|
||||
type Inspector struct {
|
||||
sqlDialect
|
||||
schema string
|
||||
}
|
||||
|
||||
// NewInspect returns an inspector for the given SQL driver.
|
||||
func NewInspect(d dialect.Driver, opts ...InspectOption) (*Inspector, error) {
|
||||
i := &Inspector{}
|
||||
for _, opt := range opts {
|
||||
opt(i)
|
||||
}
|
||||
switch d.Dialect() {
|
||||
case dialect.MySQL:
|
||||
i.sqlDialect = &MySQL{Driver: d, schema: i.schema}
|
||||
case dialect.SQLite:
|
||||
i.sqlDialect = &SQLite{Driver: d}
|
||||
case dialect.Postgres:
|
||||
i.sqlDialect = &Postgres{Driver: d, schema: i.schema}
|
||||
default:
|
||||
return nil, fmt.Errorf("sql/schema: unsupported dialect %q", d.Dialect())
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// Tables returns the tables in the schema.
|
||||
func (i *Inspector) Tables(ctx context.Context) ([]*Table, error) {
|
||||
names, err := i.tables(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tx := dialect.NopTx(i.sqlDialect)
|
||||
tables := make([]*Table, 0, len(names))
|
||||
for _, name := range names {
|
||||
t, err := i.table(ctx, tx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tables = append(tables, t)
|
||||
}
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (i *Inspector) tables(ctx context.Context) ([]string, error) {
|
||||
t, ok := i.sqlDialect.(interface{ tables() sql.Querier })
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sql/schema: %q driver does not support inspection", i.Dialect())
|
||||
}
|
||||
query, args := t.tables().Query()
|
||||
var (
|
||||
names []string
|
||||
rows = &sql.Rows{}
|
||||
)
|
||||
if err := i.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, fmt.Errorf("mysql: reading table names %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if err := sql.ScanSlice(rows, &names); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
119
dialect/sql/schema/inspect_test.go
Normal file
119
dialect/sql/schema/inspect_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/facebook/ent/dialect"
|
||||
"github.com/facebook/ent/dialect/sql"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInspector_Tables(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options []InspectOption
|
||||
before map[string]func(mysqlMock)
|
||||
tables []*Table
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "default schema",
|
||||
before: map[string]func(mysqlMock){
|
||||
dialect.MySQL: func(mock mysqlMock) {
|
||||
mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE())")).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
|
||||
},
|
||||
dialect.SQLite: func(mock mysqlMock) {
|
||||
mock.ExpectQuery(escape("SELECT `name` FROM `sqlite_schema` WHERE `type` = ?")).
|
||||
WithArgs("table").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}))
|
||||
},
|
||||
dialect.Postgres: func(mock mysqlMock) {
|
||||
mock.ExpectQuery(escape(`SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = CURRENT_SCHEMA()`)).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom schema",
|
||||
options: []InspectOption{WithSchema("public")},
|
||||
before: map[string]func(mysqlMock){
|
||||
dialect.MySQL: func(mock mysqlMock) {
|
||||
mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = ?")).
|
||||
WithArgs("public").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).
|
||||
AddRow("users"))
|
||||
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")).
|
||||
WithArgs("public", "users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}).
|
||||
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "").
|
||||
AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "").
|
||||
AddRow("text", "longtext", "YES", "YES", "NULL", "", "", "").
|
||||
AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin"))
|
||||
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
|
||||
WithArgs("public", "users").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}).
|
||||
AddRow("PRIMARY", "id", "0", "1"))
|
||||
},
|
||||
},
|
||||
tables: func() []*Table {
|
||||
var (
|
||||
c1 = []*Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "name", Type: field.TypeString, Size: 255, Nullable: true},
|
||||
{Name: "text", Type: field.TypeString, Size: math.MaxInt32, Nullable: true},
|
||||
{Name: "uuid", Type: field.TypeUUID, Nullable: true},
|
||||
}
|
||||
t1 = &Table{
|
||||
Name: "users",
|
||||
Columns: c1,
|
||||
PrimaryKey: c1[0:1],
|
||||
}
|
||||
)
|
||||
return []*Table{t1}
|
||||
}(),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
for drv := range tt.before {
|
||||
t.Run(path.Join(drv, tt.name), func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
tt.before[drv](mysqlMock{mock})
|
||||
inspect, err := NewInspect(sql.OpenDB(drv, db), tt.options...)
|
||||
require.NoError(t, err)
|
||||
tables, err := inspect.Tables(context.Background())
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
tablesMatch(t, tables, tt.tables)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func tablesMatch(t *testing.T, got, expected []*Table) {
|
||||
require.Equal(t, len(expected), len(got))
|
||||
for i := range got {
|
||||
columnsMatch(t, got[i].Columns, got[i].Columns)
|
||||
columnsMatch(t, got[i].PrimaryKey, got[i].PrimaryKey)
|
||||
}
|
||||
}
|
||||
|
||||
func columnsMatch(t *testing.T, got, expected []*Column) {
|
||||
for i := range got {
|
||||
c1, c2 := got[i], expected[i]
|
||||
require.Equal(t, c1.Name, c2.Name)
|
||||
require.Equal(t, c1.Type, c2.Type)
|
||||
require.Equal(t, c1.Size, c2.Size)
|
||||
require.Equal(t, c1.Nullable, c2.Nullable)
|
||||
}
|
||||
}
|
||||
@@ -520,6 +520,13 @@ func (d *MySQL) tableSchema() sql.Querier {
|
||||
return sql.Raw("(SELECT DATABASE())")
|
||||
}
|
||||
|
||||
// tables returns the query for getting the in the schema.
|
||||
func (d *MySQL) tables() sql.Querier {
|
||||
return sql.Select("TABLE_NAME").
|
||||
From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")).
|
||||
Where(sql.EQ("TABLE_SCHEMA", d.tableSchema()))
|
||||
}
|
||||
|
||||
// alterColumns returns the queries for applying the columns change-set.
|
||||
func (d *MySQL) alterColumns(table string, add, modify, drop []*Column) sql.Queries {
|
||||
b := sql.Dialect(dialect.MySQL).AlterTable(table)
|
||||
|
||||
@@ -31,7 +31,8 @@ func TestMySQL_Create(t *testing.T) {
|
||||
{
|
||||
name: "tx failed",
|
||||
before: func(mock mysqlMock) {
|
||||
mock.ExpectBegin().WillReturnError(sqlmock.ErrCancelled)
|
||||
mock.ExpectBegin().
|
||||
WillReturnError(sqlmock.ErrCancelled)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
@@ -452,6 +452,14 @@ func (d *Postgres) tableSchema() sql.Querier {
|
||||
return sql.Raw("CURRENT_SCHEMA()")
|
||||
}
|
||||
|
||||
// tables returns the query for getting the in the schema.
|
||||
func (d *Postgres) tables() sql.Querier {
|
||||
return sql.Dialect(dialect.Postgres).
|
||||
Select("table_name").
|
||||
From(sql.Table("tables").Schema("information_schema")).
|
||||
Where(sql.EQ("table_schema", d.tableSchema()))
|
||||
}
|
||||
|
||||
// alterColumns returns the queries for applying the columns change-set.
|
||||
func (d *Postgres) alterColumns(table string, add, modify, drop []*Column) sql.Queries {
|
||||
b := sql.Dialect(dialect.Postgres).AlterTable(table)
|
||||
|
||||
@@ -321,3 +321,10 @@ func (d *SQLite) alterColumns(table string, add, _, _ []*Column) sql.Queries {
|
||||
// will support https://www.sqlite.org/lang_altertable.html#otheralter
|
||||
return queries
|
||||
}
|
||||
|
||||
// tables returns the query for getting the in the schema.
|
||||
func (d *SQLite) tables() sql.Querier {
|
||||
return sql.Select("name").
|
||||
From(sql.Table("sqlite_schema")).
|
||||
Where(sql.EQ("type", "table"))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user