diff --git a/dialect/sql/schema/migrate.go b/dialect/sql/schema/migrate.go index 36a39bfae..ec89c6da2 100644 --- a/dialect/sql/schema/migrate.go +++ b/dialect/sql/schema/migrate.go @@ -9,7 +9,6 @@ import ( "crypto/md5" "fmt" "math" - "sort" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" @@ -322,8 +321,6 @@ func (m *Migrate) changeSet(curr, new *Table) (*changes, error) { if len(curr.PrimaryKey) != len(new.PrimaryKey) { return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name) } - sort.Slice(new.PrimaryKey, func(i, j int) bool { return new.PrimaryKey[i].Name < new.PrimaryKey[j].Name }) - sort.Slice(curr.PrimaryKey, func(i, j int) bool { return curr.PrimaryKey[i].Name < curr.PrimaryKey[j].Name }) for i := range curr.PrimaryKey { if curr.PrimaryKey[i].Name != new.PrimaryKey[i].Name { return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name) diff --git a/dialect/sql/schema/mysql.go b/dialect/sql/schema/mysql.go index b5f1c83c3..e23edee10 100644 --- a/dialect/sql/schema/mysql.go +++ b/dialect/sql/schema/mysql.go @@ -87,9 +87,6 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, if err := d.scanColumn(c, rows); err != nil { return nil, fmt.Errorf("mysql: %w", err) } - if c.PrimaryKey() { - t.PrimaryKey = append(t.PrimaryKey, c) - } t.AddColumn(c) } if err := rows.Err(); err != nil { @@ -98,7 +95,7 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, if err := rows.Close(); err != nil { return nil, fmt.Errorf("mysql: closing rows %w", err) } - indexes, err := d.indexes(ctx, tx, name) + indexes, err := d.indexes(ctx, tx, t) if err != nil { return nil, err } @@ -115,13 +112,13 @@ func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, } // table loads the table indexes from the database. -func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, name string) ([]*Index, error) { +func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, t *Table) ([]*Index, error) { rows := &sql.Rows{} query, args := sql.Select("index_name", "column_name", "sub_part", "non_unique", "seq_in_index"). From(sql.Table("STATISTICS").Schema("INFORMATION_SCHEMA")). Where(sql.And( d.matchSchema(), - sql.EQ("TABLE_NAME", name), + sql.EQ("TABLE_NAME", t.Name), )). OrderBy("index_name", "seq_in_index"). Query() @@ -129,7 +126,7 @@ func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, name string) ([]*Ind return nil, fmt.Errorf("mysql: reading index description %w", err) } defer rows.Close() - idx, err := d.scanIndexes(rows) + idx, err := d.scanIndexes(rows, t) if err != nil { return nil, fmt.Errorf("mysql: %w", err) } @@ -506,7 +503,7 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error { // scanIndexes scans sql.Rows into an Indexes list. The query for returning the rows, // should return the following 5 columns: INDEX_NAME, COLUMN_NAME, SUB_PART, NON_UNIQUE, // SEQ_IN_INDEX. SEQ_IN_INDEX specifies the position of the column in the index columns. -func (d *MySQL) scanIndexes(rows *sql.Rows) (Indexes, error) { +func (d *MySQL) scanIndexes(rows *sql.Rows, t *Table) (Indexes, error) { var ( i Indexes names = make(map[string]*Index) @@ -522,8 +519,13 @@ func (d *MySQL) scanIndexes(rows *sql.Rows) (Indexes, error) { if err := rows.Scan(&name, &column, &subpart, &nonuniq, &seqindex); err != nil { return nil, fmt.Errorf("scanning index description: %w", err) } - // Ignore primary keys. + // Skip primary keys. if name == "PRIMARY" { + c, ok := t.column(column) + if !ok { + return nil, fmt.Errorf("missing primary-key column: %q", column) + } + t.PrimaryKey = append(t.PrimaryKey, c) continue } idx, ok := names[name]