dialect/sql/schema: fix sqlite indexes matching (#471)

This commit is contained in:
Ariel Mashraki
2020-05-07 10:05:17 +03:00
committed by GitHub
parent bcb579106a
commit 7a0b530b89
2 changed files with 34 additions and 13 deletions

View File

@@ -7,6 +7,7 @@ package schema
import (
"context"
"fmt"
"strings"
"github.com/facebookincubator/ent/dialect"
"github.com/facebookincubator/ent/dialect/sql"
@@ -176,7 +177,20 @@ func (d *SQLite) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
}
// Add and link indexes to table columns.
for _, idx := range indexes {
t.AddIndex(idx.Name, idx.Unique, idx.columns)
switch {
case idx.primary:
case idx.Unique && len(idx.columns) == 1:
name := idx.columns[0]
c, ok := t.column(name)
if !ok {
return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name)
}
c.Key = UniqueKey
c.Unique = true
fallthrough
default:
t.addIndex(idx)
}
}
return t, nil
}
@@ -184,7 +198,7 @@ func (d *SQLite) table(ctx context.Context, tx dialect.Tx, name string) (*Table,
// table loads the table indexes from the database.
func (d *SQLite) indexes(ctx context.Context, tx dialect.Tx, name string) (Indexes, error) {
rows := &sql.Rows{}
query, args := sql.Select("name", "unique").
query, args := sql.Select("name", "unique", "origin").
From(sql.Table(fmt.Sprintf("pragma_index_list('%s')", name)).Unquote()).
Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
@@ -194,9 +208,11 @@ func (d *SQLite) indexes(ctx context.Context, tx dialect.Tx, name string) (Index
var idx Indexes
for rows.Next() {
i := &Index{}
if err := rows.Scan(&i.Name, &i.Unique); err != nil {
origin := sql.NullString{}
if err := rows.Scan(&i.Name, &i.Unique, &origin); err != nil {
return nil, fmt.Errorf("scanning index description %v", err)
}
i.primary = origin.String == "pk"
idx = append(idx, i)
}
if err := rows.Close(); err != nil {
@@ -208,6 +224,11 @@ func (d *SQLite) indexes(ctx context.Context, tx dialect.Tx, name string) (Index
return nil, err
}
idx[i].columns = columns
// Normalize implicit index names to ent naming convention. See:
// https://github.com/sqlite/sqlite/blob/e937df8/src/build.c#L3583
if len(columns) == 1 && strings.HasPrefix(idx[i].Name, "sqlite_autoindex_"+name) {
idx[i].Name = columns[0]
}
}
return idx, nil
}

View File

@@ -150,9 +150,9 @@ func TestSQLite_Create(t *testing.T) {
AddRow("text", "text", 0, "NULL", 0).
AddRow("uuid", "uuid", 0, "Null", 0).
AddRow("id", "integer", 1, "NULL", 1))
mock.ExpectQuery(escape("SELECT `name`, `unique` FROM pragma_index_list('users')")).
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "unique"}))
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
@@ -181,9 +181,9 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("created_at", "datetime", 0, nil, 0).
AddRow("id", "integer", 1, "NULL", 1))
mock.ExpectQuery(escape("SELECT `name`, `unique` FROM pragma_index_list('users')")).
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "unique"}))
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `updated_at` datetime NULL")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
@@ -221,9 +221,9 @@ func TestSQLite_Create(t *testing.T) {
AddRow("old_medium", "blob", 1, nil, 0).
AddRow("old_long", "blob", 1, nil, 0).
AddRow("id", "integer", 1, "NULL", 1))
mock.ExpectQuery(escape("SELECT `name`, `unique` FROM pragma_index_list('blobs')")).
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('blobs')")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "unique"}))
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"}))
for _, c := range []string{"tiny", "blob", "medium", "long"} {
mock.ExpectExec(escape(fmt.Sprintf("ALTER TABLE `blobs` ADD COLUMN `new_%s` blob NOT NULL", c))).
WillReturnResult(sqlmock.NewResult(0, 1))
@@ -253,9 +253,9 @@ func TestSQLite_Create(t *testing.T) {
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("id", "integer", 1, "NULL", 1))
mock.ExpectQuery(escape("SELECT `name`, `unique` FROM pragma_index_list('users')")).
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "unique"}))
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `name` varchar(255) NOT NULL DEFAULT 'unknown'")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `active` bool NOT NULL DEFAULT false")).
@@ -297,9 +297,9 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("name", "varchar(255)", 1, "NULL", 0).
AddRow("id", "integer", 1, "NULL", 1))
mock.ExpectQuery(escape("SELECT `name`, `unique` FROM pragma_index_list('users')")).
mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "unique"}))
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` integer NULL CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE CASCADE")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()