mirror of
https://github.com/ent/ent.git
synced 2026-05-01 23:20:53 +03:00
dialect/sql/schema: hello ariga.io/atlas (#2279)
This commit is contained in:
@@ -7,11 +7,16 @@ package schema
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/schema/field"
|
||||
|
||||
"ariga.io/atlas/sql/migrate"
|
||||
"ariga.io/atlas/sql/schema"
|
||||
"ariga.io/atlas/sql/sqlite"
|
||||
)
|
||||
|
||||
// SQLite is an SQLite migration driver.
|
||||
@@ -21,7 +26,7 @@ type SQLite struct {
|
||||
}
|
||||
|
||||
// init makes sure that foreign_keys support is enabled.
|
||||
func (d *SQLite) init(ctx context.Context, tx dialect.Tx) error {
|
||||
func (d *SQLite) init(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
on, err := exist(ctx, tx, "PRAGMA foreign_keys")
|
||||
if err != nil {
|
||||
return fmt.Errorf("sqlite: check foreign_keys pragma: %w", err)
|
||||
@@ -34,7 +39,7 @@ func (d *SQLite) init(ctx context.Context, tx dialect.Tx) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *SQLite) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
func (d *SQLite) tableExist(ctx context.Context, conn dialect.ExecQuerier, name string) (bool, error) {
|
||||
query, args := sql.Select().Count().
|
||||
From(sql.Table("sqlite_master")).
|
||||
Where(sql.And(
|
||||
@@ -42,7 +47,7 @@ func (d *SQLite) tableExist(ctx context.Context, tx dialect.Tx, name string) (bo
|
||||
sql.EQ("name", name),
|
||||
)).
|
||||
Query()
|
||||
return exist(ctx, tx, query, args...)
|
||||
return exist(ctx, conn, query, args...)
|
||||
}
|
||||
|
||||
// setRange sets the start value of table PK.
|
||||
@@ -50,12 +55,12 @@ func (d *SQLite) tableExist(ctx context.Context, tx dialect.Tx, name string) (bo
|
||||
// whenever a table that contains an AUTOINCREMENT column is created. However, it populates to it a rows (for tables)
|
||||
// only after the first insertion. Therefore, we check. If a record (for the given table) already exists in the "sqlite_sequence"
|
||||
// table, we updated it. Otherwise, we insert a new value.
|
||||
func (d *SQLite) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error {
|
||||
func (d *SQLite) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error {
|
||||
query, args := sql.Select().Count().
|
||||
From(sql.Table("sqlite_sequence")).
|
||||
Where(sql.EQ("name", t.Name)).
|
||||
Query()
|
||||
exists, err := exist(ctx, tx, query, args...)
|
||||
exists, err := exist(ctx, conn, query, args...)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
@@ -64,7 +69,7 @@ func (d *SQLite) setRange(ctx context.Context, tx dialect.Tx, t *Table, value in
|
||||
default: // !exists
|
||||
query, args = sql.Insert("sqlite_sequence").Columns("name", "seq").Values(t.Name, value).Query()
|
||||
}
|
||||
return tx.Exec(ctx, query, args, nil)
|
||||
return conn.Exec(ctx, query, args, nil)
|
||||
}
|
||||
|
||||
func (d *SQLite) tBuilder(t *Table) *sql.TableBuilder {
|
||||
@@ -345,3 +350,102 @@ func (d *SQLite) needsConversion(old, new *Column) bool {
|
||||
c1, c2 := d.cType(old), d.cType(new)
|
||||
return c1 != c2 && old.typ != c2
|
||||
}
|
||||
|
||||
// Atlas integration.
|
||||
|
||||
func (d *SQLite) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) {
|
||||
return sqlite.Open(&db{ExecQuerier: conn})
|
||||
}
|
||||
|
||||
func (d *SQLite) atTable(t1 *Table, t2 *schema.Table) {
|
||||
if t1.Annotation != nil {
|
||||
setAtChecks(t1, t2)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *SQLite) atTypeC(c1 *Column, c2 *schema.Column) error {
|
||||
if c1.SchemaType != nil && c1.SchemaType[dialect.SQLite] != "" {
|
||||
t, err := sqlite.ParseType(strings.ToLower(c1.SchemaType[dialect.SQLite]))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c2.Type.Type = t
|
||||
return nil
|
||||
}
|
||||
var t schema.Type
|
||||
switch c1.Type {
|
||||
case field.TypeBool:
|
||||
t = &schema.BoolType{T: "bool"}
|
||||
case field.TypeInt8, field.TypeUint8, field.TypeInt16, field.TypeUint16, field.TypeInt32,
|
||||
field.TypeUint32, field.TypeUint, field.TypeInt, field.TypeInt64, field.TypeUint64:
|
||||
t = &schema.IntegerType{T: sqlite.TypeInteger}
|
||||
case field.TypeBytes:
|
||||
t = &schema.BinaryType{T: sqlite.TypeBlob}
|
||||
case field.TypeString, field.TypeEnum:
|
||||
// SQLite does not impose any length restrictions on
|
||||
// the length of strings, BLOBs or numeric values.
|
||||
t = &schema.StringType{T: sqlite.TypeText}
|
||||
case field.TypeFloat32, field.TypeFloat64:
|
||||
t = &schema.FloatType{T: sqlite.TypeReal}
|
||||
case field.TypeTime:
|
||||
t = &schema.TimeType{T: "datetime"}
|
||||
case field.TypeJSON:
|
||||
t = &schema.JSONType{T: "json"}
|
||||
case field.TypeUUID:
|
||||
t = &sqlite.UUIDType{T: "uuid"}
|
||||
case field.TypeOther:
|
||||
t = &schema.UnsupportedType{T: c1.typ}
|
||||
default:
|
||||
t, err := sqlite.ParseType(strings.ToLower(c1.typ))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c2.Type.Type = t
|
||||
}
|
||||
c2.Type.Type = t
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *SQLite) atUniqueC(t1 *Table, c1 *Column, t2 *schema.Table, c2 *schema.Column) {
|
||||
// For UNIQUE columns, SQLite create an implicit index named
|
||||
// "sqlite_autoindex_<table>_<i>". Ent uses the MySQL approach
|
||||
// in its migration, and name these indexes as the columns.
|
||||
for _, idx := range t1.Indexes {
|
||||
// Index also defined explicitly, and will be add in atIndexes.
|
||||
if idx.Unique && d.atImplicitIndexName(idx, t1, c1) {
|
||||
return
|
||||
}
|
||||
}
|
||||
t2.AddIndexes(schema.NewUniqueIndex(c1.Name).AddColumns(c2))
|
||||
}
|
||||
|
||||
func (d *SQLite) atImplicitIndexName(idx *Index, t1 *Table, c1 *Column) bool {
|
||||
if idx.Name == c1.Name {
|
||||
return true
|
||||
}
|
||||
p := fmt.Sprintf("sqlite_autoindex_%s_", t1.Name)
|
||||
if !strings.HasPrefix(idx.Name, p) {
|
||||
return false
|
||||
}
|
||||
i, err := strconv.ParseInt(strings.TrimPrefix(idx.Name, p), 10, 64)
|
||||
return err == nil && i > 0
|
||||
}
|
||||
|
||||
func (d *SQLite) atIncrementC(_ *schema.Table, c *schema.Column) {
|
||||
c.AddAttrs(&sqlite.AutoIncrement{})
|
||||
}
|
||||
|
||||
func (d *SQLite) atIncrementT(t *schema.Table, v int64) {
|
||||
t.AddAttrs(&sqlite.AutoIncrement{Seq: v})
|
||||
}
|
||||
|
||||
func (d *SQLite) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) error {
|
||||
for _, c1 := range idx1.Columns {
|
||||
c2, ok := t2.Column(c1.Name)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected index %q column: %q", idx1.Name, c1.Name)
|
||||
}
|
||||
idx2.AddParts(&schema.IndexPart{C: c2})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user