Files
ent/dialect/sql/schema/inspect_test.go
Naor Matania 8847a1ac55 dialect/sql/schema: fix migrating Postgres fields with max character varchar custom type (#2162)
* Support migration of varying length varchar postgres

* update type

* fix type

* improve type

* remove not needed change

* fix failing test

* add inspect test

* add more tests

* go fmt

* fix pr comment
2021-11-28 07:57:57 +02:00

334 lines
18 KiB
Go

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package schema
import (
"context"
"fmt"
"math"
"path"
"testing"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
func TestInspector_Tables(t *testing.T) {
tests := []struct {
name string
options []InspectOption
before map[string]func(mysqlMock)
tables func(drv string) []*Table
wantErr bool
}{
{
name: "default schema",
before: map[string]func(mysqlMock){
dialect.MySQL: func(mock mysqlMock) {
mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE())")).
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
},
dialect.SQLite: func(mock mysqlMock) {
mock.ExpectQuery(escape("SELECT `name` FROM `sqlite_schema` WHERE `type` = ?")).
WithArgs("table").
WillReturnRows(sqlmock.NewRows([]string{"name"}))
},
dialect.Postgres: func(mock mysqlMock) {
mock.ExpectQuery(escape(`SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = CURRENT_SCHEMA()`)).
WillReturnRows(sqlmock.NewRows([]string{"name"}))
},
},
tables: func(drv string) []*Table {
return nil
},
},
{
name: "custom schema",
options: []InspectOption{WithSchema("public")},
before: map[string]func(mysqlMock){
dialect.MySQL: func(mock mysqlMock) {
mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = ?")).
WithArgs("public").
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).
AddRow("users").
AddRow("pets").
AddRow("groups").
AddRow("user_groups"))
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")).
WithArgs("public", "users").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}).
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil).
AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil).
AddRow("text", "longtext", "YES", "YES", "NULL", "", "", "", nil, nil).
AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin", nil, nil).
AddRow("price", "decimal(6, 4)", "NO", "YES", "NULL", "", "", "", "6", "4").
AddRow("bank_id", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil))
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
WithArgs("public", "users").
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}).
AddRow("PRIMARY", "id", nil, "0", "1"))
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")).
WithArgs("public", "pets").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}).
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil).
AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil).
AddRow("user_pets", "bigint(20)", "YES", "YES", "NULL", "", "", "", nil, nil))
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
WithArgs("public", "pets").
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}).
AddRow("PRIMARY", "id", nil, "0", "1"))
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")).
WithArgs("public", "groups").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}).
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil).
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil))
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
WithArgs("public", "groups").
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}).
AddRow("PRIMARY", "id", nil, "0", "1"))
mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")).
WithArgs("public", "user_groups").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}).
AddRow("user_id", "bigint(20)", "NO", "YES", "NULL", "", "", "", nil, nil).
AddRow("group_id", "bigint(20)", "NO", "YES", "NULL", "", "", "", nil, nil))
mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
WithArgs("public", "user_groups").
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}))
},
dialect.SQLite: func(mock mysqlMock) {
mock.ExpectQuery(escape("SELECT `name` FROM `sqlite_schema` WHERE `type` = ?")).
WithArgs("table").
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow("users").
AddRow("pets").
AddRow("groups").
AddRow("user_groups"))
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("id", "integer", 1, "NULL", 1).
AddRow("name", "varchar(255)", 0, "NULL", 0).
AddRow("text", "text", 0, "NULL", 0).
AddRow("uuid", "uuid", 0, "NULL", 0).
AddRow("price", "real", 1, "NULL", 0).
AddRow("bank_id", "varchar(255)", 1, "NULL", 0))
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")).
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"}))
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('pets') ORDER BY `pk`")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("id", "integer", 1, "NULL", 1).
AddRow("name", "varchar(255)", 0, "NULL", 0).
AddRow("user_pets", "integer", 0, "NULL", 0))
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('pets')")).
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"}))
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('groups') ORDER BY `pk`")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("id", "integer", 1, "NULL", 1).
AddRow("name", "varchar(255)", 1, "NULL", 0))
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('groups')")).
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"}))
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('user_groups') ORDER BY `pk`")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("user_id", "integer", 1, "NULL", 0).
AddRow("group_id", "integer", 1, "NULL", 0))
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('user_groups')")).
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"}))
},
dialect.Postgres: func(mock mysqlMock) {
mock.ExpectQuery(escape(`SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = $1`)).
WithArgs("public").
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow("users").
AddRow("pets").
AddRow("groups").
AddRow("user_groups"))
mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)).
WithArgs("public", "users").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}).
AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil).
AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil).
AddRow("text", "text", "YES", "NULL", "text", nil, nil, nil).
AddRow("uuid", "uuid", "YES", "NULL", "uuid", nil, nil, nil).
AddRow("price", "numeric", "NO", "NULL", "numeric", "6", "4", nil).
AddRow("bank_id", "character", "NO", "NULL", "bpchar", nil, nil, 20))
mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "users"))).
WithArgs("public").
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}).
AddRow("users_pkey", "id", "t", "t", 0))
mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)).
WithArgs("public", "pets").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}).
AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil).
AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil).
AddRow("user_pets", "bigint", "YES", "NULL", "int8", nil, nil, nil))
mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "pets"))).
WithArgs("public").
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}).
AddRow("pets_pkey", "id", "t", "t", 0))
mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)).
WithArgs("public", "groups").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}).
AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil).
AddRow("name", "character", "NO", "NULL", "bpchar", nil, nil, nil))
mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "groups"))).
WithArgs("public").
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}).
AddRow("groups_pkey", "id", "t", "t", 0))
mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)).
WithArgs("public", "user_groups").
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}).
AddRow("user_id", "bigint", "NO", "NULL", "int8", nil, nil, nil).
AddRow("group_id", "bigint", "NO", "NULL", "int8", nil, nil, nil))
mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "user_groups"))).
WithArgs("public").
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}))
mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "users"))).
WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"}))
mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "pets"))).
WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"}).
AddRow("public", "pet_users_pets", "pets", "user_pets", "public", "users", "id"))
mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "groups"))).
WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"}))
mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "user_groups"))).
WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"}).
AddRow("public", "user_groups_group_id", "user_groups", "group_id", "public", "groups", "id").
AddRow("public", "user_groups_user_id", "user_groups", "user_id", "public", "users", "id"))
},
},
tables: func(drv string) []*Table {
var (
c1 = []*Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "name", Type: field.TypeString, Size: 255, Nullable: true},
{Name: "text", Type: field.TypeString, Size: math.MaxInt32, Nullable: true},
{Name: "uuid", Type: field.TypeUUID, Nullable: true},
{Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{
dialect.MySQL: "decimal(6,4)",
dialect.Postgres: "numeric(6,4)",
}},
{Name: "bank_id", Type: field.TypeString, SchemaType: map[string]string{
dialect.Postgres: "varchar(20)",
}},
}
t1 = &Table{
Name: "users",
Columns: c1,
PrimaryKey: c1[0:1],
}
c2 = []*Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "name", Type: field.TypeString, Size: 255, Nullable: true},
{Name: "user_pets", Type: field.TypeInt64, Nullable: true},
}
t2 = &Table{
Name: "pets",
Columns: c2,
PrimaryKey: c2[0:1],
}
c3 = []*Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "name", Type: field.TypeString},
}
t3 = &Table{
Name: "groups",
Columns: c3,
PrimaryKey: c3[0:1],
}
c4 = []*Column{
{Name: "user_id", Type: field.TypeInt64},
{Name: "group_id", Type: field.TypeInt64},
}
t4 = &Table{
Name: "user_groups",
Columns: c4,
}
)
// Only postgres currently supports foreign key inspection
if drv == dialect.Postgres {
t2.ForeignKeys = []*ForeignKey{
{
Symbol: "pet_users_pets",
Columns: []*Column{c2[2]},
RefTable: t1,
RefColumns: []*Column{c1[0]},
},
}
t4.ForeignKeys = []*ForeignKey{
{
Symbol: "user_groups_group_id",
Columns: []*Column{c4[1]},
RefTable: t3,
RefColumns: []*Column{c3[0]},
},
{
Symbol: "user_groups_user_id",
Columns: []*Column{c4[0]},
RefTable: t1,
RefColumns: []*Column{c1[0]},
},
}
}
return []*Table{t1, t2, t3, t4}
},
},
}
for _, tt := range tests {
for drv := range tt.before {
t.Run(path.Join(drv, tt.name), func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
tt.before[drv](mysqlMock{mock})
inspect, err := NewInspect(sql.OpenDB(drv, db), tt.options...)
require.NoError(t, err)
tables, err := inspect.Tables(context.Background())
require.Equal(t, tt.wantErr, err != nil, err)
tablesMatch(t, drv, tables, tt.tables(drv))
})
}
}
}
func tablesMatch(t *testing.T, drv string, got, expected []*Table) {
require.Equal(t, len(expected), len(got))
for i := range got {
columnsMatch(t, drv, got[i].Columns, expected[i].Columns)
columnsMatch(t, drv, got[i].PrimaryKey, expected[i].PrimaryKey)
foreignKeysMatch(t, drv, got[i].ForeignKeys, expected[i].ForeignKeys)
}
}
func columnsMatch(t *testing.T, drv string, got, expected []*Column) {
require.Equal(t, len(expected), len(got))
for i := range got {
c1, c2 := got[i], expected[i]
require.Equal(t, c2.Name, c1.Name)
require.Equal(t, c2.Nullable, c1.Nullable)
require.True(t, c1.Type == c2.Type || c1.ConvertibleTo(c2), "mismatched types: %s - %s", c1.Type, c2.Type)
if c2.SchemaType[drv] != "" {
require.Equal(t, c2.SchemaType[drv], c1.SchemaType[drv])
}
}
}
func foreignKeysMatch(t *testing.T, drv string, expected []*ForeignKey, got []*ForeignKey) {
require.Equal(t, len(expected), len(got))
for i := range got {
fk1, fk2 := got[i], expected[i]
require.Equal(t, fk2.Symbol, fk1.Symbol)
require.Equal(t, fk2.RefTable.Name, fk1.RefTable.Name)
columnsMatch(t, drv, fk1.Columns, fk2.Columns)
columnsMatch(t, drv, fk1.RefColumns, fk2.RefColumns)
}
}