mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
Initial commit
fbshipit-source-id: c79a38536e3c128dce1b2948615b72ec9779ed22
This commit is contained in:
115
dialect/sql/schema/mysql.go
Normal file
115
dialect/sql/schema/mysql.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
|
||||
"fbc/ent/dialect"
|
||||
"fbc/ent/dialect/sql"
|
||||
)
|
||||
|
||||
// MySQL is a mysql migration driver.
|
||||
type MySQL struct {
|
||||
dialect.Driver
|
||||
}
|
||||
|
||||
// Create creates all tables resources in the database.
|
||||
func (d *MySQL) Create(ctx context.Context, tables ...*Table) error {
|
||||
tx, err := d.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, t := range tables {
|
||||
exist, err := d.tableExist(ctx, tx, t.Name)
|
||||
if err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
if exist {
|
||||
continue
|
||||
}
|
||||
query, args := t.DSL().Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return rollback(tx, fmt.Errorf("sql/mysql: create table %q: %v", t.Name, err))
|
||||
}
|
||||
}
|
||||
// create foreign keys after table was created, because circular foreign-key constraints are possible.
|
||||
for _, t := range tables {
|
||||
if len(t.ForeignKeys) == 0 {
|
||||
continue
|
||||
}
|
||||
fks := make([]*ForeignKey, 0, len(t.ForeignKeys))
|
||||
for _, fk := range t.ForeignKeys {
|
||||
fk.Symbol = symbol(fk.Symbol)
|
||||
exist, err := d.fkExist(ctx, tx, fk.Symbol)
|
||||
if err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
if !exist {
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
}
|
||||
if len(fks) == 0 {
|
||||
continue
|
||||
}
|
||||
b := sql.AlterTable(t.Name)
|
||||
for _, fk := range fks {
|
||||
b.AddForeignKey(fk.DSL())
|
||||
}
|
||||
query, args := b.Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return rollback(tx, fmt.Errorf("sql/mysql: create foreign keys for %q: %v", t.Name, err))
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
return d.exist(
|
||||
ctx,
|
||||
tx,
|
||||
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = (SELECT DATABASE()) AND TABLE_NAME = ?",
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
return d.exist(
|
||||
ctx,
|
||||
tx,
|
||||
`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE TABLE_SCHEMA=(SELECT DATABASE()) AND CONSTRAINT_TYPE="FOREIGN KEY" AND CONSTRAINT_NAME = ?`,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
func (d *MySQL) exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return false, fmt.Errorf("dialect/mysql: reading schema information %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, fmt.Errorf("dialect/mysql: no rows returned")
|
||||
}
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
return false, fmt.Errorf("dialect/mysql: scanning count")
|
||||
}
|
||||
return n > 0, nil
|
||||
}
|
||||
|
||||
// symbol makes sure the symbol length is not longer than the maxlength in MySQL standard (64).
|
||||
func symbol(name string) string {
|
||||
if len(name) <= 64 {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("%s_%x", name[:31], md5.Sum([]byte(name)))
|
||||
}
|
||||
|
||||
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
|
||||
func rollback(tx dialect.Tx, err error) error {
|
||||
if rerr := tx.Rollback(); rerr != nil {
|
||||
err = fmt.Errorf("%s: %v", err.Error(), rerr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
266
dialect/sql/schema/schema.go
Normal file
266
dialect/sql/schema/schema.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"fbc/ent/dialect/sql"
|
||||
"fbc/ent/field"
|
||||
)
|
||||
|
||||
// Table schema definition for SQL dialects.
|
||||
type Table struct {
|
||||
Name string
|
||||
Columns []*Column
|
||||
Indexes []*Index
|
||||
PrimaryKey []*Column
|
||||
ForeignKeys []*ForeignKey
|
||||
}
|
||||
|
||||
// NewTable returns a new table with the given name.
|
||||
func NewTable(name string) *Table { return &Table{Name: name} }
|
||||
|
||||
// AddPrimary adds a new primary key to the table.
|
||||
func (t *Table) AddPrimary(c *Column) *Table {
|
||||
t.Columns = append(t.Columns, c)
|
||||
t.PrimaryKey = append(t.PrimaryKey, c)
|
||||
return t
|
||||
}
|
||||
|
||||
// AddForeignKey adds a foreign key to the table.
|
||||
func (t *Table) AddForeignKey(fk *ForeignKey) *Table {
|
||||
t.ForeignKeys = append(t.ForeignKeys, fk)
|
||||
return t
|
||||
}
|
||||
|
||||
// DSL returns the default DSL query for table creation.
|
||||
func (t *Table) DSL() *sql.TableBuilder {
|
||||
b := sql.CreateTable(t.Name).IfNotExists()
|
||||
for _, c := range t.Columns {
|
||||
b.Column(c.DSL())
|
||||
}
|
||||
for _, pk := range t.PrimaryKey {
|
||||
b.PrimaryKey(pk.Name)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// SQLite returns the SQLite query for table creation.
|
||||
func (t *Table) SQLite() *sql.TableBuilder {
|
||||
b := sql.CreateTable(t.Name)
|
||||
for _, c := range t.Columns {
|
||||
b.Column(c.SQLite())
|
||||
}
|
||||
// Unlike in MySQL, we're not able to add foreign-key constraints to table
|
||||
// after it was created, and adding them to the `CREATE TABLE` statement is
|
||||
// not always valid (because circular foreign-keys situation is possible).
|
||||
// We stay consistent by not using constraints at all, and just defining the
|
||||
// foreign keys in the `CREATE TABLE` statement.
|
||||
for _, fk := range t.ForeignKeys {
|
||||
b.ForeignKeys(fk.DSL())
|
||||
}
|
||||
// if it's an ID based primary key, we add the `PRIMARY KEY`
|
||||
// clause to the column declaration.
|
||||
if len(t.PrimaryKey) == 1 {
|
||||
return b
|
||||
}
|
||||
for _, pk := range t.PrimaryKey {
|
||||
b.PrimaryKey(pk.Name)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Column schema definition for SQL dialects.
|
||||
type Column struct {
|
||||
Name string // column name.
|
||||
Type field.Type // column type.
|
||||
Attr string // extra attributes.
|
||||
Default string // default value.
|
||||
Nullable *bool // null or not null attribute.
|
||||
Size int // max size parameter for string, blob, etc.
|
||||
Key string // key definition (PRI, UNI or MUL).
|
||||
Unique bool // column with unique constraint.
|
||||
Increment bool // auto increment attribute.
|
||||
}
|
||||
|
||||
// UniqueKey returns boolean indicates if this column is a unique key.
|
||||
// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects.
|
||||
func (c *Column) UniqueKey() bool { return c.Key == "UNI" }
|
||||
|
||||
// PrimaryKey returns boolean indicates if this column is on of the primary key columns.
|
||||
// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects.
|
||||
func (c *Column) PrimaryKey() bool { return c.Key == "PRI" }
|
||||
|
||||
// DSL returns the default DSL query for table creation.
|
||||
func (c *Column) DSL() *sql.ColumnBuilder {
|
||||
b := sql.Column(c.Name).Type(c.MySQLType()).Attr(c.Attr)
|
||||
c.unique(b)
|
||||
if c.Increment {
|
||||
b.Attr("AUTO_INCREMENT")
|
||||
}
|
||||
c.nullable(b)
|
||||
return b
|
||||
}
|
||||
|
||||
// SQLite returns a SQLite DSL node for this column.
|
||||
func (c *Column) SQLite() *sql.ColumnBuilder {
|
||||
b := sql.Column(c.Name).Type(c.SQLiteType()).Attr(c.Attr)
|
||||
c.unique(b)
|
||||
if c.Increment {
|
||||
b.Attr("PRIMARY KEY AUTOINCREMENT")
|
||||
}
|
||||
c.nullable(b)
|
||||
return b
|
||||
}
|
||||
|
||||
// MySQLType returns the MySQL string type for this column.
|
||||
func (c *Column) MySQLType() (t string) {
|
||||
switch c.Type {
|
||||
case field.TypeBool:
|
||||
t = "boolean"
|
||||
case field.TypeInt8:
|
||||
t = "tinyint"
|
||||
case field.TypeUint8:
|
||||
t = "tinyint unsigned"
|
||||
case field.TypeInt64:
|
||||
t = "bigint"
|
||||
case field.TypeUint64:
|
||||
t = "bigint unsigned"
|
||||
case field.TypeInt, field.TypeInt16, field.TypeInt32:
|
||||
t = "int"
|
||||
case field.TypeUint, field.TypeUint16, field.TypeUint32:
|
||||
t = "int unsigned"
|
||||
case field.TypeString:
|
||||
size := c.Size
|
||||
if size == 0 {
|
||||
size = 255
|
||||
}
|
||||
if size < 1<<16 {
|
||||
t = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
t = "longtext"
|
||||
}
|
||||
case field.TypeFloat32, field.TypeFloat64:
|
||||
t = "double"
|
||||
case field.TypeTime:
|
||||
t = "timestamp"
|
||||
// in MySQL timestamp columns are `NOT NULL by default, and assigning NULL
|
||||
// assigns the current_timestamp(). We avoid this if not set otherwise.
|
||||
if c.Nullable == nil {
|
||||
nullable := true
|
||||
c.Nullable = &nullable
|
||||
}
|
||||
default:
|
||||
panic("unsupported type " + c.Type.String())
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// SQLiteType returns the SQLite string type for this column.
|
||||
func (c *Column) SQLiteType() (t string) {
|
||||
switch c.Type {
|
||||
case field.TypeBool:
|
||||
t = "bool"
|
||||
case field.TypeInt8, field.TypeUint8, field.TypeInt, field.TypeInt16, field.TypeInt32, field.TypeUint, field.TypeUint16, field.TypeUint32:
|
||||
t = "integer"
|
||||
case field.TypeInt64, field.TypeUint64:
|
||||
t = "bigint"
|
||||
case field.TypeString:
|
||||
size := c.Size
|
||||
if size == 0 {
|
||||
size = 255
|
||||
}
|
||||
// sqlite has no size limit on varchar.
|
||||
t = fmt.Sprintf("varchar(%d)", size)
|
||||
case field.TypeFloat32, field.TypeFloat64:
|
||||
t = "real"
|
||||
case field.TypeTime:
|
||||
t = "datetime"
|
||||
default:
|
||||
panic("unsupported type " + c.Type.String())
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// unique adds the `UNIQUE` attribute if the column is a unique type.
|
||||
// it is exist in a different function to share the common declaration
|
||||
// between the two dialects.
|
||||
func (c *Column) unique(b *sql.ColumnBuilder) {
|
||||
if c.Unique {
|
||||
b.Attr("UNIQUE")
|
||||
}
|
||||
}
|
||||
|
||||
// nullable adds the `NULL`/`NOT NULL` attribute to the column. it is exist in
|
||||
// a different function to share the common declaration between the two dialects.
|
||||
func (c *Column) nullable(b *sql.ColumnBuilder) {
|
||||
if c.Nullable != nil {
|
||||
attr := "NULL"
|
||||
if !*c.Nullable {
|
||||
attr = "NOT " + attr
|
||||
}
|
||||
b.Attr(attr)
|
||||
}
|
||||
}
|
||||
|
||||
// ForeignKey definition for creation.
|
||||
type ForeignKey struct {
|
||||
Symbol string // foreign-key name. Generated if empty.
|
||||
Columns []*Column // table column
|
||||
RefTable *Table // referenced table.
|
||||
RefColumns []*Column // referenced columns.
|
||||
OnUpdate ReferenceOption // action on update.
|
||||
OnDelete ReferenceOption // action on delete.
|
||||
}
|
||||
|
||||
// DSL returns a default DSL query for a foreign-key.
|
||||
func (fk ForeignKey) DSL() *sql.ForeignKeyBuilder {
|
||||
cols := make([]string, len(fk.Columns))
|
||||
refs := make([]string, len(fk.RefColumns))
|
||||
for i, c := range fk.Columns {
|
||||
cols[i] = c.Name
|
||||
}
|
||||
for i, c := range fk.RefColumns {
|
||||
refs[i] = c.Name
|
||||
}
|
||||
dsl := sql.ForeignKey().Symbol(fk.Symbol).
|
||||
Columns(cols...).
|
||||
Reference(sql.Reference().Table(fk.RefTable.Name).Columns(refs...))
|
||||
if action := string(fk.OnDelete); action != "" {
|
||||
dsl.OnDelete(action)
|
||||
}
|
||||
if action := string(fk.OnUpdate); action != "" {
|
||||
dsl.OnUpdate(action)
|
||||
}
|
||||
return dsl
|
||||
}
|
||||
|
||||
// ReferenceOption for constraint actions.
|
||||
type ReferenceOption string
|
||||
|
||||
// Reference options.
|
||||
const (
|
||||
NoAction ReferenceOption = "NO ACTION"
|
||||
Restrict ReferenceOption = "RESTRICT"
|
||||
Cascade ReferenceOption = "CASCADE"
|
||||
SetNull ReferenceOption = "SET NULL"
|
||||
SetDefault ReferenceOption = "SET DEFAULT"
|
||||
)
|
||||
|
||||
// ConstName returns the constant name of a reference option. It's used by entc for printing the constant name in templates.
|
||||
func (r ReferenceOption) ConstName() string {
|
||||
if r == NoAction {
|
||||
return ""
|
||||
}
|
||||
return strings.ReplaceAll(strings.Title(strings.ToLower(string(r))), " ", "")
|
||||
}
|
||||
|
||||
// Index definition for table index.
|
||||
type Index struct {
|
||||
Key string // key name.
|
||||
Column string // column name.
|
||||
}
|
||||
|
||||
// Primary indicates if this index is a primary key.
|
||||
// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects.
|
||||
func (i *Index) Primary() bool { return i.Key == "PRIMARY" }
|
||||
74
dialect/sql/schema/sqlite.go
Normal file
74
dialect/sql/schema/sqlite.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"fbc/ent/dialect"
|
||||
"fbc/ent/dialect/sql"
|
||||
)
|
||||
|
||||
// SQLite is an SQLite migration driver.
|
||||
type SQLite struct {
|
||||
dialect.Driver
|
||||
}
|
||||
|
||||
// Create creates all tables resources in the database.
|
||||
func (d *SQLite) Create(ctx context.Context, tables ...*Table) error {
|
||||
tx, err := d.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
on, err := d.fkEnabled(ctx, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sql/sqlite: check foreign_keys pragma: %v", err)
|
||||
}
|
||||
if !on {
|
||||
// foreign_keys pragma is off, either enable it by execute "PRAGMA foreign_keys=ON"
|
||||
// or add the following parameter in the connection string "_fk=1".
|
||||
return fmt.Errorf("sql/sqlite: foreign_keys pragma is off: missing %q is the connection string", "_fk=1")
|
||||
}
|
||||
for _, t := range tables {
|
||||
exist, err := d.tableExist(ctx, tx, t.Name)
|
||||
if err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
if exist {
|
||||
continue
|
||||
}
|
||||
query, args := t.SQLite().Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
err = fmt.Errorf("sql/sqlite: create table %q: %v", t.Name, err)
|
||||
return rollback(tx, err)
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (d *SQLite) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
query, args := sql.Select().Count().
|
||||
From(sql.Table("sqlite_master")).
|
||||
Where(sql.EQ("type", "table").And().EQ("name", name)).
|
||||
Query()
|
||||
return d.exist(ctx, tx, query, args...)
|
||||
}
|
||||
|
||||
func (d *SQLite) fkEnabled(ctx context.Context, tx dialect.Tx) (bool, error) {
|
||||
return d.exist(ctx, tx, "PRAGMA foreign_keys")
|
||||
}
|
||||
|
||||
func (d *SQLite) exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return false, fmt.Errorf("dialect/sqlite: reading schema information %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, fmt.Errorf("dialect/sqlite: no rows returned")
|
||||
}
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
return false, fmt.Errorf("dialect/sqlite: scanning count")
|
||||
}
|
||||
return n > 0, nil
|
||||
}
|
||||
Reference in New Issue
Block a user