mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
Initial commit
fbshipit-source-id: c79a38536e3c128dce1b2948615b72ec9779ed22
This commit is contained in:
1309
dialect/sql/builder.go
Normal file
1309
dialect/sql/builder.go
Normal file
File diff suppressed because it is too large
Load Diff
382
dialect/sql/builder_test.go
Normal file
382
dialect/sql/builder_test.go
Normal file
@@ -0,0 +1,382 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuilder(t *testing.T) {
|
||||
tests := []struct {
|
||||
input Node
|
||||
wantQuery string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
{
|
||||
input: CreateTable("users").
|
||||
Columns(
|
||||
Column("id").Type("int").Attr("auto_increment"),
|
||||
Column("name").Type("varchar(255)"),
|
||||
).
|
||||
PrimaryKey("id"),
|
||||
wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`))",
|
||||
},
|
||||
{
|
||||
input: CreateTable("users").
|
||||
IfNotExists().
|
||||
Columns(
|
||||
Column("id").Type("int").Attr("auto_increment"),
|
||||
).
|
||||
PrimaryKey("id", "name"),
|
||||
wantQuery: "CREATE TABLE IF NOT EXISTS `users`(`id` int auto_increment, PRIMARY KEY(`id`, `name`))",
|
||||
},
|
||||
{
|
||||
input: CreateTable("users").
|
||||
IfNotExists().
|
||||
Columns(
|
||||
Column("id").Type("int").Attr("auto_increment"),
|
||||
Column("card_id").Type("int"),
|
||||
).
|
||||
PrimaryKey("id", "name").
|
||||
ForeignKeys(ForeignKey().Columns("card_id").
|
||||
Reference(Reference().Table("cards").Columns("id")).OnDelete("SET NULL")),
|
||||
wantQuery: "CREATE TABLE IF NOT EXISTS `users`(`id` int auto_increment, `card_id` int, PRIMARY KEY(`id`, `name`), FOREIGN KEY(`card_id`) REFERENCES `cards`(`id`) ON DELETE SET NULL)",
|
||||
},
|
||||
{
|
||||
input: AlterTable("users").
|
||||
AddColumn(Column("group_id").Type("int").Attr("UNIQUE")).
|
||||
AddForeignKey(ForeignKey().Columns("group_id").
|
||||
Reference(Reference().Table("groups").Columns("id")).
|
||||
OnDelete("CASCADE"),
|
||||
),
|
||||
wantQuery: "ALTER TABLE `users` ADD `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`) ON DELETE CASCADE",
|
||||
},
|
||||
{
|
||||
input: AlterTable("users").
|
||||
AddColumn(Column("group_id").Type("int").Attr("UNIQUE")).
|
||||
AddForeignKey(ForeignKey().Columns("group_id").
|
||||
Reference(Reference().Table("groups").Columns("id")),
|
||||
),
|
||||
wantQuery: "ALTER TABLE `users` ADD `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`)",
|
||||
},
|
||||
{
|
||||
input: AlterTable("users").
|
||||
AddColumn(Column("age").Type("int")).
|
||||
AddColumn(Column("name").Type("varchar(255)")),
|
||||
wantQuery: "ALTER TABLE `users` ADD `age` int, ADD `name` varchar(255)",
|
||||
},
|
||||
{
|
||||
input: AlterTable("users").
|
||||
AddForeignKey(ForeignKey().Columns("group_id").
|
||||
Reference(Reference().Table("groups").Columns("id")),
|
||||
).
|
||||
AddForeignKey(ForeignKey().Columns("location_id").
|
||||
Reference(Reference().Table("locations").Columns("id")),
|
||||
),
|
||||
wantQuery: "ALTER TABLE `users` ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`), ADD CONSTRAINT FOREIGN KEY(`location_id`) REFERENCES `locations`(`id`)",
|
||||
},
|
||||
{
|
||||
input: Insert("users").Columns("age").Values(1),
|
||||
wantQuery: "INSERT INTO `users` (`age`) VALUES (?)",
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
{
|
||||
input: Insert("users").Columns("name", "age").Values("a8m", 10),
|
||||
wantQuery: "INSERT INTO `users` (`name`, `age`) VALUES (?, ?)",
|
||||
wantArgs: []interface{}{"a8m", 10},
|
||||
},
|
||||
{
|
||||
input: Insert("users").Columns("name", "age").Values("a8m", 10).Values("foo", 20),
|
||||
wantQuery: "INSERT INTO `users` (`name`, `age`) VALUES (?, ?), (?, ?)",
|
||||
wantArgs: []interface{}{"a8m", 10, "foo", 20},
|
||||
},
|
||||
{
|
||||
input: Update("users").Set("name", "foo"),
|
||||
wantQuery: "UPDATE `users` SET `name` = ?",
|
||||
wantArgs: []interface{}{"foo"},
|
||||
},
|
||||
{
|
||||
input: Update("users").Set("name", "foo").Set("age", 10),
|
||||
wantQuery: "UPDATE `users` SET `name` = ?, `age` = ?",
|
||||
wantArgs: []interface{}{"foo", 10},
|
||||
},
|
||||
{
|
||||
input: Update("users").Set("name", "foo").Where(EQ("name", "bar")),
|
||||
wantQuery: "UPDATE `users` SET `name` = ? WHERE `name` = ?",
|
||||
wantArgs: []interface{}{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
input: Update("users").Set("name", "foo").SetNull("spouse_id"),
|
||||
wantQuery: "UPDATE `users` SET `spouse_id` = NULL, `name` = ?",
|
||||
wantArgs: []interface{}{"foo"},
|
||||
},
|
||||
{
|
||||
input: Update("users").Set("name", "foo").
|
||||
Where(EQ("name", "bar")).
|
||||
Where(EQ("age", 20)),
|
||||
wantQuery: "UPDATE `users` SET `name` = ? WHERE `name` = ? AND `age` = ?",
|
||||
wantArgs: []interface{}{"foo", "bar", 20},
|
||||
},
|
||||
{
|
||||
input: Update("users").
|
||||
Set("name", "foo").
|
||||
Set("age", 10).
|
||||
Where(EQ("name", "bar").Or().EQ("name", "baz")),
|
||||
wantQuery: "UPDATE `users` SET `name` = ?, `age` = ? WHERE `name` = ? OR `name` = ?",
|
||||
wantArgs: []interface{}{"foo", 10, "bar", "baz"},
|
||||
},
|
||||
{
|
||||
input: Update("users").
|
||||
Set("name", "foo").
|
||||
Set("age", 10).
|
||||
Where(P().EQ("name", "foo")),
|
||||
wantQuery: "UPDATE `users` SET `name` = ?, `age` = ? WHERE `name` = ?",
|
||||
wantArgs: []interface{}{"foo", 10, "foo"},
|
||||
},
|
||||
{
|
||||
input: Update("users").
|
||||
Set("name", "foo").
|
||||
Where(In("name", "bar", "baz").And().NotIn("age", 1, 2)),
|
||||
wantQuery: "UPDATE `users` SET `name` = ? WHERE `name` IN (?, ?) AND `age` NOT IN (?, ?)",
|
||||
wantArgs: []interface{}{"foo", "bar", "baz", 1, 2},
|
||||
},
|
||||
{
|
||||
input: Update("users").
|
||||
Set("name", "foo").
|
||||
Where(HasPrefix("nickname", "a8m").And().Contains("lastname", "mash")),
|
||||
wantQuery: "UPDATE `users` SET `name` = ? WHERE `nickname` LIKE ? AND `lastname` LIKE ?",
|
||||
wantArgs: []interface{}{"foo", "a8m%", "%mash%"},
|
||||
},
|
||||
{
|
||||
input: Update("users").
|
||||
Set("name", "foo").
|
||||
Set("age", 10).
|
||||
Where(P().EQ("name", "foo").And().EQ("age", 20)),
|
||||
wantQuery: "UPDATE `users` SET `name` = ?, `age` = ? WHERE `name` = ? AND `age` = ?",
|
||||
wantArgs: []interface{}{"foo", 10, "foo", 20},
|
||||
},
|
||||
{
|
||||
input: Delete("users").
|
||||
Where(NotNull("parent_id")),
|
||||
wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NOT NULL",
|
||||
},
|
||||
{
|
||||
input: Delete("users").
|
||||
Where(False().And().False()),
|
||||
wantQuery: "DELETE FROM `users` WHERE FALSE AND FALSE",
|
||||
},
|
||||
{
|
||||
input: Delete("users").
|
||||
Where(NotNull("parent_id").Or().EQ("parent_id", 10)),
|
||||
wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NOT NULL OR `parent_id` = ?",
|
||||
wantArgs: []interface{}{10},
|
||||
},
|
||||
{
|
||||
input: Delete("users").
|
||||
Where(
|
||||
Or(
|
||||
EQ("name", "foo").And().EQ("age", 10),
|
||||
EQ("name", "bar").And().EQ("age", 20),
|
||||
And(
|
||||
EQ("name", "qux"),
|
||||
EQ("age", 1).Or().EQ("age", 2),
|
||||
),
|
||||
),
|
||||
),
|
||||
wantQuery: "DELETE FROM `users` WHERE (`name` = ? AND `age` = ?) OR (`name` = ? AND `age` = ?) OR ((`name` = ?) AND (`age` = ? OR `age` = ?))",
|
||||
wantArgs: []interface{}{"foo", 10, "bar", 20, "qux", 1, 2},
|
||||
},
|
||||
{
|
||||
input: Select().From(Table("users")),
|
||||
wantQuery: "SELECT * FROM `users`",
|
||||
},
|
||||
{
|
||||
input: Select().From(Table("users").As("u")),
|
||||
wantQuery: "SELECT * FROM `users` AS `u`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
t1 := Table("users").As("u")
|
||||
t2 := Table("groups").As("g")
|
||||
return Select(t1.C("id"), t2.C("name")).From(t1).Join(t2)
|
||||
}(),
|
||||
wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN `groups` AS `g`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
t1 := Table("users").As("u")
|
||||
t2 := Table("groups").As("g")
|
||||
return Select(t1.C("id"), t2.C("name")).
|
||||
From(t1).
|
||||
Join(t2).
|
||||
On(t1.C("id"), t2.C("user_id"))
|
||||
}(),
|
||||
wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN `groups` AS `g` ON `u`.`id` = `g`.`user_id`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
t1 := Table("users").As("u")
|
||||
t2 := Table("groups").As("g")
|
||||
return Select(t1.C("id"), t2.C("name")).
|
||||
From(t1).
|
||||
Join(t2).
|
||||
On(t1.C("id"), t2.C("user_id")).
|
||||
Where(EQ(t1.C("name"), "bar").And().NotNull(t2.C("name")))
|
||||
}(),
|
||||
wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN `groups` AS `g` ON `u`.`id` = `g`.`user_id` WHERE `u`.`name` = ? AND `g`.`name` IS NOT NULL",
|
||||
wantArgs: []interface{}{"bar"},
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
t1 := Table("users").As("u")
|
||||
return Select(t1.Columns("name", "age")...).From(t1)
|
||||
}(),
|
||||
wantQuery: "SELECT `u`.`name`, `u`.`age` FROM `users` AS `u`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
t1 := Table("users").As("u")
|
||||
t2 := Select().From(Table("groups")).Where(EQ("user_id", 10)).As("g")
|
||||
return Select(t1.C("id"), t2.C("name")).
|
||||
From(t1).
|
||||
Join(t2).
|
||||
On(t1.C("id"), t2.C("user_id"))
|
||||
}(),
|
||||
wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN (SELECT * FROM `groups` WHERE `user_id` = ?) AS `g` ON `u`.`id` = `g`.`user_id`",
|
||||
wantArgs: []interface{}{10},
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
selector := Select().Where(EQ("name", "foo").Or().EQ("name", "bar"))
|
||||
return Delete("users").FromSelect(selector)
|
||||
}(),
|
||||
wantQuery: "DELETE FROM `users` WHERE `name` = ? OR `name` = ?",
|
||||
wantArgs: []interface{}{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
selector := Select().From(Table("users")).As("t")
|
||||
return selector.Select(selector.C("name"))
|
||||
}(),
|
||||
wantQuery: "SELECT `t`.`name` FROM `users`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
selector := Select().From(Table("groups")).Where(EQ("name", "foo"))
|
||||
return Delete("users").FromSelect(selector)
|
||||
}(),
|
||||
wantQuery: "DELETE FROM `groups` WHERE `name` = ?",
|
||||
wantArgs: []interface{}{"foo"},
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
selector := Select()
|
||||
return Delete("users").FromSelect(selector)
|
||||
}(),
|
||||
wantQuery: "DELETE FROM `users`",
|
||||
},
|
||||
{
|
||||
input: Select().From(Table("users")).Where(Not(EQ("name", "foo").And().EQ("age", "bar"))),
|
||||
wantQuery: "SELECT * FROM `users` WHERE NOT (`name` = ? AND `age` = ?)",
|
||||
wantArgs: []interface{}{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
t1 := Table("users")
|
||||
return Select().
|
||||
From(t1).
|
||||
Where(In(t1.C("id"), Select("owner_id").From(Table("pets")).Where(EQ("name", "pedro"))))
|
||||
}(),
|
||||
wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `owner_id` FROM `pets` WHERE `name` = ?)",
|
||||
wantArgs: []interface{}{"pedro"},
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
t1 := Table("users")
|
||||
return Select().
|
||||
From(t1).
|
||||
Where(Not(In(t1.C("id"), Select("owner_id").From(Table("pets")).Where(EQ("name", "pedro")))))
|
||||
}(),
|
||||
wantQuery: "SELECT * FROM `users` WHERE NOT (`users`.`id` IN (SELECT `owner_id` FROM `pets` WHERE `name` = ?))",
|
||||
wantArgs: []interface{}{"pedro"},
|
||||
},
|
||||
{
|
||||
input: Select().Count().From(Table("users")),
|
||||
wantQuery: "SELECT COUNT(*) FROM `users`",
|
||||
},
|
||||
{
|
||||
input: Select().Count(Distinct("id")).From(Table("users")),
|
||||
wantQuery: "SELECT COUNT(DISTINCT `id`) FROM `users`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
t1 := Table("users")
|
||||
t2 := Select().From(Table("groups"))
|
||||
t3 := Select().Count().From(t1).Join(t1).On(t2.C("id"), t1.C("blocked_id"))
|
||||
return t3.Count(Distinct(t3.Columns("id", "name")...))
|
||||
}(),
|
||||
wantQuery: "SELECT COUNT(DISTINCT `t0`.`id`, `t0`.`name`) FROM `users` AS `t0` JOIN `users` AS `t0` ON `groups`.`id` = `t0`.`blocked_id`",
|
||||
},
|
||||
{
|
||||
input: Select(Sum("age"), Min("age")).From(Table("users")),
|
||||
wantQuery: "SELECT SUM(`age`), MIN(`age`) FROM `users`",
|
||||
},
|
||||
{
|
||||
input: func() Node {
|
||||
t1 := Table("users").As("u")
|
||||
return Select(As(Max(t1.C("age")), "max_age")).From(t1)
|
||||
}(),
|
||||
wantQuery: "SELECT MAX(`u`.`age`) AS `max_age` FROM `users` AS `u`",
|
||||
},
|
||||
{
|
||||
input: Select("name", Count("*")).
|
||||
From(Table("users")).
|
||||
GroupBy("name"),
|
||||
wantQuery: "SELECT `name`, COUNT(*) FROM `users` GROUP BY `name`",
|
||||
},
|
||||
{
|
||||
input: Select("name", Count("*")).
|
||||
From(Table("users")).
|
||||
GroupBy("name").
|
||||
OrderBy("name"),
|
||||
wantQuery: "SELECT `name`, COUNT(*) FROM `users` GROUP BY `name` ORDER BY `name`",
|
||||
},
|
||||
{
|
||||
input: Select("name", "age", Count("*")).
|
||||
From(Table("users")).
|
||||
GroupBy("name", "age").
|
||||
OrderBy(Desc("name"), "age"),
|
||||
wantQuery: "SELECT `name`, `age`, COUNT(*) FROM `users` GROUP BY `name`, `age` ORDER BY `name` DESC, `age`",
|
||||
},
|
||||
{
|
||||
input: Select("*").From(Table("users")).Limit(1),
|
||||
wantQuery: "SELECT * FROM `users` LIMIT ?",
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
{
|
||||
input: Select("age").Distinct().From(Table("users")),
|
||||
wantQuery: "SELECT DISTINCT `age` FROM `users`",
|
||||
},
|
||||
{
|
||||
input: Select("age", "name").From(Table("users")).Distinct().OrderBy("name"),
|
||||
wantQuery: "SELECT DISTINCT `age`, `name` FROM `users` ORDER BY `name`",
|
||||
},
|
||||
{
|
||||
input: Select("age").From(Table("users")).Where(EQ("name", "foo")).Or().Where(EQ("name", "bar")),
|
||||
wantQuery: "SELECT `age` FROM `users` WHERE (`name` = ?) OR (`name` = ?)",
|
||||
wantArgs: []interface{}{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
input: Nodes{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))},
|
||||
wantQuery: "WITH users_view AS (SELECT * FROM `users`) SELECT * FROM `users_view`",
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
query, args := tt.input.Query()
|
||||
require.Equal(t, tt.wantQuery, query)
|
||||
require.Equal(t, tt.wantArgs, args)
|
||||
})
|
||||
}
|
||||
}
|
||||
150
dialect/sql/driver.go
Normal file
150
dialect/sql/driver.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"fbc/ent/dialect"
|
||||
)
|
||||
|
||||
// Driver is a dialect.Driver implementation for SQL based databases.
|
||||
type Driver struct {
|
||||
conn
|
||||
dialect string
|
||||
}
|
||||
|
||||
// Open wraps the database/sql.Open method and returns a dialect.Driver that implements the an ent/dialect.Driver interface.
|
||||
func Open(driver, source string) (*Driver, error) {
|
||||
db, err := sql.Open(driver, source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Driver{conn{db}, driver}, nil
|
||||
}
|
||||
|
||||
// OpenDB wraps the given database/sql.DB method with a Driver.
|
||||
func OpenDB(driver string, db *sql.DB) *Driver {
|
||||
return &Driver{conn{db}, driver}
|
||||
}
|
||||
|
||||
// Dialect implements the dialect.Dialect method.
|
||||
func (d Driver) Dialect() string { return d.dialect }
|
||||
|
||||
// Tx starts and returns a transaction.
|
||||
func (d *Driver) Tx(ctx context.Context) (dialect.Tx, error) {
|
||||
tx, err := d.ExecQuerier.(*sql.DB).BeginTx(ctx, &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{conn{tx}}, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying connection.
|
||||
func (d *Driver) Close() error { return d.ExecQuerier.(*sql.DB).Close() }
|
||||
|
||||
// Tx wraps the sql.Tx for implementing the dialect.Tx interface.
|
||||
type Tx struct {
|
||||
conn
|
||||
}
|
||||
|
||||
// Commit commits the transaction.
|
||||
func (t *Tx) Commit() error { return t.ExecQuerier.(*sql.Tx).Commit() }
|
||||
|
||||
// Rollback rollback the transaction.
|
||||
func (t *Tx) Rollback() error { return t.ExecQuerier.(*sql.Tx).Rollback() }
|
||||
|
||||
// ExecQuerier wraps the standard Exec and Query methods.
|
||||
type ExecQuerier interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// shared connection ExecQuerier between Gremlin and Tx.
|
||||
type conn struct {
|
||||
ExecQuerier
|
||||
}
|
||||
|
||||
// Exec implements the dialect.Exec method.
|
||||
func (c *conn) Exec(ctx context.Context, query string, args interface{}, v interface{}) error {
|
||||
vr, ok := v.(*sql.Result)
|
||||
if !ok {
|
||||
return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Result", v)
|
||||
}
|
||||
argv, ok := args.([]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("dialect/sql: invalid type %T. expect []interface{} for args", v)
|
||||
}
|
||||
res, err := c.ExecContext(ctx, query, argv...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*vr = res
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exec implements the dialect.Query method.
|
||||
func (c *conn) Query(ctx context.Context, query string, args interface{}, v interface{}) error {
|
||||
vr, ok := v.(*Rows)
|
||||
if !ok {
|
||||
return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Rows", v)
|
||||
}
|
||||
argv, ok := args.([]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("dialect/sql: invalid type %T. expect []interface{} for args", args)
|
||||
}
|
||||
rows, err := c.QueryContext(ctx, query, argv...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*vr = Rows{rows}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ dialect.Driver = (*Driver)(nil)
|
||||
|
||||
type (
|
||||
// Rows wraps the sql.Rows to avoid locks copy.
|
||||
Rows struct{ *sql.Rows }
|
||||
// Result is an alias to sql.Result.
|
||||
Result = sql.Result
|
||||
// NullBool is an alias to sql.NullBool.
|
||||
NullBool = sql.NullBool
|
||||
// NullInt64 is an alias to sql.NullInt64.
|
||||
NullInt64 = sql.NullInt64
|
||||
// NullString is an alias to sql.NullString.
|
||||
NullString = sql.NullString
|
||||
// NullFloat64 is an alias to sql.NullFloat64.
|
||||
NullFloat64 = sql.NullFloat64
|
||||
)
|
||||
|
||||
// Note:
|
||||
// NullTime is a modified copy of database/sql.NullTime from Go 1.13,
|
||||
// It should be replaced with standard library code when Go 1.13 is released.
|
||||
|
||||
// NullTime represents a time.Time that may be null.
|
||||
// NullTime implements the Scanner interface so
|
||||
// it can be used as a scan destination, similar to NullString.
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (n *NullTime) Scan(v interface{}) error {
|
||||
if v, ok := v.(time.Time); ok {
|
||||
n.Time = v
|
||||
n.Valid = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (n NullTime) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.Time, nil
|
||||
}
|
||||
123
dialect/sql/scan.go
Normal file
123
dialect/sql/scan.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ColumnScanner is the interface that wraps the
|
||||
// three sql.Rows methods used for scanning.
|
||||
type ColumnScanner interface {
|
||||
Next() bool
|
||||
Scan(...interface{}) error
|
||||
Columns() ([]string, error)
|
||||
}
|
||||
|
||||
// ScanSlice scans the given ColumnScanner (basically, sql.Rows or sql.Rows) into the given slice.
|
||||
func ScanSlice(rows ColumnScanner, v interface{}) error {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return fmt.Errorf("sql/scan: failed getting column names: %v", err)
|
||||
}
|
||||
rv := reflect.Indirect(reflect.ValueOf(v))
|
||||
if k := rv.Kind(); k != reflect.Slice {
|
||||
return fmt.Errorf("sql/scan: invalid type %s. expected slice as an argument", k)
|
||||
}
|
||||
var (
|
||||
scan *rowScan
|
||||
typ = rv.Type().Elem()
|
||||
)
|
||||
switch k := typ.Kind(); {
|
||||
case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64:
|
||||
scan = &rowScan{
|
||||
columns: []reflect.Type{typ},
|
||||
value: func(v ...interface{}) reflect.Value {
|
||||
return reflect.Indirect(reflect.ValueOf(v[0]))
|
||||
},
|
||||
}
|
||||
case k == reflect.Ptr:
|
||||
typ = typ.Elem()
|
||||
if scan, err = scanStruct(typ, columns); err != nil {
|
||||
return err
|
||||
}
|
||||
wrap := scan.value
|
||||
scan.value = func(vs ...interface{}) reflect.Value {
|
||||
v := wrap(vs...)
|
||||
pt := reflect.PtrTo(v.Type())
|
||||
pv := reflect.New(pt.Elem())
|
||||
pv.Elem().Set(v)
|
||||
return pv
|
||||
}
|
||||
case k == reflect.Struct:
|
||||
if scan, err = scanStruct(typ, columns); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("sql/scan: unsupported type ([]%s)", k)
|
||||
}
|
||||
if n, m := len(columns), len(scan.columns); n > m {
|
||||
return fmt.Errorf("sql/scan: columns do not match (%d > %d)", n, m)
|
||||
}
|
||||
for rows.Next() {
|
||||
values := scan.values()
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
return fmt.Errorf("sql/scan: failed scanning rows: %v", err)
|
||||
}
|
||||
vv := reflect.Append(rv, scan.value(values...))
|
||||
rv.Set(vv)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// rowScan is the configuration for scanning one sql.Row.
|
||||
type rowScan struct {
|
||||
// column types of a row.
|
||||
columns []reflect.Type
|
||||
// value functions that converts the row columns (result) to a reflect.Value.
|
||||
value func(v ...interface{}) reflect.Value
|
||||
}
|
||||
|
||||
// values returns a []interface{} from the configured column types.
|
||||
func (r *rowScan) values() []interface{} {
|
||||
values := make([]interface{}, len(r.columns))
|
||||
for i := range r.columns {
|
||||
values[i] = reflect.New(r.columns[i]).Interface()
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// scanStruct returns the a configuration for scanning an sql.Row into a struct.
|
||||
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
|
||||
var (
|
||||
scan = &rowScan{}
|
||||
names = make(map[string]int)
|
||||
idx = make([]int, 0, typ.NumField())
|
||||
)
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
name := strings.ToLower(f.Name)
|
||||
if tag, ok := f.Tag.Lookup("json"); ok {
|
||||
name = strings.Split(tag, ",")[0]
|
||||
}
|
||||
names[name] = i
|
||||
}
|
||||
for _, c := range columns {
|
||||
// normalize columns if necessary, for example: COUNT(*) => count.
|
||||
name := strings.ToLower(strings.Split(c, "(")[0])
|
||||
i, ok := names[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sql/scan: missing struct field for column: %s (%s)", c, name)
|
||||
}
|
||||
idx = append(idx, i)
|
||||
scan.columns = append(scan.columns, typ.Field(i).Type)
|
||||
}
|
||||
scan.value = func(vs ...interface{}) reflect.Value {
|
||||
st := reflect.New(typ).Elem()
|
||||
for i, v := range vs {
|
||||
st.Field(idx[i]).Set(reflect.Indirect(reflect.ValueOf(v)))
|
||||
}
|
||||
return st
|
||||
}
|
||||
return scan, nil
|
||||
}
|
||||
91
dialect/sql/scan_test.go
Normal file
91
dialect/sql/scan_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestScanSlice(t *testing.T) {
|
||||
rows := &mockRows{
|
||||
columns: []string{"name"},
|
||||
values: [][]interface{}{{"foo"}, {"bar"}},
|
||||
}
|
||||
var v0 []string
|
||||
require.NoError(t, ScanSlice(rows, &v0))
|
||||
require.Equal(t, []string{"foo", "bar"}, v0)
|
||||
|
||||
rows = &mockRows{
|
||||
columns: []string{"age"},
|
||||
values: [][]interface{}{{1}, {2}},
|
||||
}
|
||||
var v1 []int
|
||||
require.NoError(t, ScanSlice(rows, &v1))
|
||||
require.Equal(t, []int{1, 2}, v1)
|
||||
|
||||
rows = &mockRows{
|
||||
columns: []string{"name", "COUNT(*)"},
|
||||
values: [][]interface{}{{"foo", 1}, {"bar", 2}},
|
||||
}
|
||||
var v2 []struct {
|
||||
Name string
|
||||
Count int
|
||||
}
|
||||
require.NoError(t, ScanSlice(rows, &v2))
|
||||
require.Equal(t, "foo", v2[0].Name)
|
||||
require.Equal(t, "bar", v2[1].Name)
|
||||
require.Equal(t, 1, v2[0].Count)
|
||||
require.Equal(t, 2, v2[1].Count)
|
||||
|
||||
rows = &mockRows{
|
||||
columns: []string{"nick_name", "COUNT(*)"},
|
||||
values: [][]interface{}{{"foo", 1}, {"bar", 2}},
|
||||
}
|
||||
var v3 []struct {
|
||||
Count int
|
||||
Name string `json:"nick_name"`
|
||||
}
|
||||
require.NoError(t, ScanSlice(rows, &v3))
|
||||
require.Equal(t, "foo", v3[0].Name)
|
||||
require.Equal(t, "bar", v3[1].Name)
|
||||
require.Equal(t, 1, v3[0].Count)
|
||||
require.Equal(t, 2, v3[1].Count)
|
||||
|
||||
rows = &mockRows{
|
||||
columns: []string{"nick_name", "COUNT(*)"},
|
||||
values: [][]interface{}{{"foo", 1}, {"bar", 2}},
|
||||
}
|
||||
var v4 []*struct {
|
||||
Count int
|
||||
Name string `json:"nick_name"`
|
||||
Ignored string `json:"string"`
|
||||
}
|
||||
require.NoError(t, ScanSlice(rows, &v4))
|
||||
require.Equal(t, "foo", v4[0].Name)
|
||||
require.Equal(t, "bar", v4[1].Name)
|
||||
require.Equal(t, 1, v4[0].Count)
|
||||
require.Equal(t, 2, v4[1].Count)
|
||||
}
|
||||
|
||||
type mockRows struct {
|
||||
columns []string
|
||||
values [][]interface{}
|
||||
}
|
||||
|
||||
func (m mockRows) Columns() ([]string, error) { return m.columns, nil }
|
||||
|
||||
func (m mockRows) Next() bool { return len(m.values) > 0 }
|
||||
|
||||
func (m *mockRows) Scan(vs ...interface{}) error {
|
||||
if len(m.values) == 0 {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
row := m.values[0]
|
||||
m.values = m.values[1:]
|
||||
for i := range vs {
|
||||
reflect.Indirect(reflect.ValueOf(vs[i])).Set(reflect.ValueOf(row[i]))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
115
dialect/sql/schema/mysql.go
Normal file
115
dialect/sql/schema/mysql.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
|
||||
"fbc/ent/dialect"
|
||||
"fbc/ent/dialect/sql"
|
||||
)
|
||||
|
||||
// MySQL is a mysql migration driver.
|
||||
type MySQL struct {
|
||||
dialect.Driver
|
||||
}
|
||||
|
||||
// Create creates all tables resources in the database.
|
||||
func (d *MySQL) Create(ctx context.Context, tables ...*Table) error {
|
||||
tx, err := d.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, t := range tables {
|
||||
exist, err := d.tableExist(ctx, tx, t.Name)
|
||||
if err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
if exist {
|
||||
continue
|
||||
}
|
||||
query, args := t.DSL().Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return rollback(tx, fmt.Errorf("sql/mysql: create table %q: %v", t.Name, err))
|
||||
}
|
||||
}
|
||||
// create foreign keys after table was created, because circular foreign-key constraints are possible.
|
||||
for _, t := range tables {
|
||||
if len(t.ForeignKeys) == 0 {
|
||||
continue
|
||||
}
|
||||
fks := make([]*ForeignKey, 0, len(t.ForeignKeys))
|
||||
for _, fk := range t.ForeignKeys {
|
||||
fk.Symbol = symbol(fk.Symbol)
|
||||
exist, err := d.fkExist(ctx, tx, fk.Symbol)
|
||||
if err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
if !exist {
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
}
|
||||
if len(fks) == 0 {
|
||||
continue
|
||||
}
|
||||
b := sql.AlterTable(t.Name)
|
||||
for _, fk := range fks {
|
||||
b.AddForeignKey(fk.DSL())
|
||||
}
|
||||
query, args := b.Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return rollback(tx, fmt.Errorf("sql/mysql: create foreign keys for %q: %v", t.Name, err))
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
return d.exist(
|
||||
ctx,
|
||||
tx,
|
||||
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = (SELECT DATABASE()) AND TABLE_NAME = ?",
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
return d.exist(
|
||||
ctx,
|
||||
tx,
|
||||
`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE TABLE_SCHEMA=(SELECT DATABASE()) AND CONSTRAINT_TYPE="FOREIGN KEY" AND CONSTRAINT_NAME = ?`,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
func (d *MySQL) exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return false, fmt.Errorf("dialect/mysql: reading schema information %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, fmt.Errorf("dialect/mysql: no rows returned")
|
||||
}
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
return false, fmt.Errorf("dialect/mysql: scanning count")
|
||||
}
|
||||
return n > 0, nil
|
||||
}
|
||||
|
||||
// symbol makes sure the symbol length is not longer than the maxlength in MySQL standard (64).
|
||||
func symbol(name string) string {
|
||||
if len(name) <= 64 {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("%s_%x", name[:31], md5.Sum([]byte(name)))
|
||||
}
|
||||
|
||||
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
|
||||
func rollback(tx dialect.Tx, err error) error {
|
||||
if rerr := tx.Rollback(); rerr != nil {
|
||||
err = fmt.Errorf("%s: %v", err.Error(), rerr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
266
dialect/sql/schema/schema.go
Normal file
266
dialect/sql/schema/schema.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"fbc/ent/dialect/sql"
|
||||
"fbc/ent/field"
|
||||
)
|
||||
|
||||
// Table schema definition for SQL dialects.
|
||||
type Table struct {
|
||||
Name string
|
||||
Columns []*Column
|
||||
Indexes []*Index
|
||||
PrimaryKey []*Column
|
||||
ForeignKeys []*ForeignKey
|
||||
}
|
||||
|
||||
// NewTable returns a new table with the given name.
|
||||
func NewTable(name string) *Table { return &Table{Name: name} }
|
||||
|
||||
// AddPrimary adds a new primary key to the table.
|
||||
func (t *Table) AddPrimary(c *Column) *Table {
|
||||
t.Columns = append(t.Columns, c)
|
||||
t.PrimaryKey = append(t.PrimaryKey, c)
|
||||
return t
|
||||
}
|
||||
|
||||
// AddForeignKey adds a foreign key to the table.
|
||||
func (t *Table) AddForeignKey(fk *ForeignKey) *Table {
|
||||
t.ForeignKeys = append(t.ForeignKeys, fk)
|
||||
return t
|
||||
}
|
||||
|
||||
// DSL returns the default DSL query for table creation.
|
||||
func (t *Table) DSL() *sql.TableBuilder {
|
||||
b := sql.CreateTable(t.Name).IfNotExists()
|
||||
for _, c := range t.Columns {
|
||||
b.Column(c.DSL())
|
||||
}
|
||||
for _, pk := range t.PrimaryKey {
|
||||
b.PrimaryKey(pk.Name)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// SQLite returns the SQLite query for table creation.
|
||||
func (t *Table) SQLite() *sql.TableBuilder {
|
||||
b := sql.CreateTable(t.Name)
|
||||
for _, c := range t.Columns {
|
||||
b.Column(c.SQLite())
|
||||
}
|
||||
// Unlike in MySQL, we're not able to add foreign-key constraints to table
|
||||
// after it was created, and adding them to the `CREATE TABLE` statement is
|
||||
// not always valid (because circular foreign-keys situation is possible).
|
||||
// We stay consistent by not using constraints at all, and just defining the
|
||||
// foreign keys in the `CREATE TABLE` statement.
|
||||
for _, fk := range t.ForeignKeys {
|
||||
b.ForeignKeys(fk.DSL())
|
||||
}
|
||||
// if it's an ID based primary key, we add the `PRIMARY KEY`
|
||||
// clause to the column declaration.
|
||||
if len(t.PrimaryKey) == 1 {
|
||||
return b
|
||||
}
|
||||
for _, pk := range t.PrimaryKey {
|
||||
b.PrimaryKey(pk.Name)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Column schema definition for SQL dialects.
|
||||
type Column struct {
|
||||
Name string // column name.
|
||||
Type field.Type // column type.
|
||||
Attr string // extra attributes.
|
||||
Default string // default value.
|
||||
Nullable *bool // null or not null attribute.
|
||||
Size int // max size parameter for string, blob, etc.
|
||||
Key string // key definition (PRI, UNI or MUL).
|
||||
Unique bool // column with unique constraint.
|
||||
Increment bool // auto increment attribute.
|
||||
}
|
||||
|
||||
// UniqueKey returns boolean indicates if this column is a unique key.
|
||||
// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects.
|
||||
func (c *Column) UniqueKey() bool { return c.Key == "UNI" }
|
||||
|
||||
// PrimaryKey returns boolean indicates if this column is on of the primary key columns.
|
||||
// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects.
|
||||
func (c *Column) PrimaryKey() bool { return c.Key == "PRI" }
|
||||
|
||||
// DSL returns the default DSL query for table creation.
|
||||
func (c *Column) DSL() *sql.ColumnBuilder {
|
||||
b := sql.Column(c.Name).Type(c.MySQLType()).Attr(c.Attr)
|
||||
c.unique(b)
|
||||
if c.Increment {
|
||||
b.Attr("AUTO_INCREMENT")
|
||||
}
|
||||
c.nullable(b)
|
||||
return b
|
||||
}
|
||||
|
||||
// SQLite returns a SQLite DSL node for this column.
|
||||
func (c *Column) SQLite() *sql.ColumnBuilder {
|
||||
b := sql.Column(c.Name).Type(c.SQLiteType()).Attr(c.Attr)
|
||||
c.unique(b)
|
||||
if c.Increment {
|
||||
b.Attr("PRIMARY KEY AUTOINCREMENT")
|
||||
}
|
||||
c.nullable(b)
|
||||
return b
|
||||
}
|
||||
|
||||
// MySQLType returns the MySQL string type for this column.
|
||||
func (c *Column) MySQLType() (t string) {
|
||||
switch c.Type {
|
||||
case field.TypeBool:
|
||||
t = "boolean"
|
||||
case field.TypeInt8:
|
||||
t = "tinyint"
|
||||
case field.TypeUint8:
|
||||
t = "tinyint unsigned"
|
||||
case field.TypeInt64:
|
||||
t = "bigint"
|
||||
case field.TypeUint64:
|
||||
t = "bigint unsigned"
|
||||
case field.TypeInt, field.TypeInt16, field.TypeInt32:
|
||||
t = "int"
|
||||
case field.TypeUint, field.TypeUint16, field.TypeUint32:
|
||||
t = "int unsigned"
|
||||
case field.TypeString:
|
||||
size := c.Size
|
||||
if size == 0 {
|
||||
size = 255
|
||||
}
|
||||
if size < 1<<16 {
|
||||
t = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
t = "longtext"
|
||||
}
|
||||
case field.TypeFloat32, field.TypeFloat64:
|
||||
t = "double"
|
||||
case field.TypeTime:
|
||||
t = "timestamp"
|
||||
// in MySQL timestamp columns are `NOT NULL by default, and assigning NULL
|
||||
// assigns the current_timestamp(). We avoid this if not set otherwise.
|
||||
if c.Nullable == nil {
|
||||
nullable := true
|
||||
c.Nullable = &nullable
|
||||
}
|
||||
default:
|
||||
panic("unsupported type " + c.Type.String())
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// SQLiteType returns the SQLite string type for this column.
|
||||
func (c *Column) SQLiteType() (t string) {
|
||||
switch c.Type {
|
||||
case field.TypeBool:
|
||||
t = "bool"
|
||||
case field.TypeInt8, field.TypeUint8, field.TypeInt, field.TypeInt16, field.TypeInt32, field.TypeUint, field.TypeUint16, field.TypeUint32:
|
||||
t = "integer"
|
||||
case field.TypeInt64, field.TypeUint64:
|
||||
t = "bigint"
|
||||
case field.TypeString:
|
||||
size := c.Size
|
||||
if size == 0 {
|
||||
size = 255
|
||||
}
|
||||
// sqlite has no size limit on varchar.
|
||||
t = fmt.Sprintf("varchar(%d)", size)
|
||||
case field.TypeFloat32, field.TypeFloat64:
|
||||
t = "real"
|
||||
case field.TypeTime:
|
||||
t = "datetime"
|
||||
default:
|
||||
panic("unsupported type " + c.Type.String())
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// unique adds the `UNIQUE` attribute if the column is a unique type.
|
||||
// it is exist in a different function to share the common declaration
|
||||
// between the two dialects.
|
||||
func (c *Column) unique(b *sql.ColumnBuilder) {
|
||||
if c.Unique {
|
||||
b.Attr("UNIQUE")
|
||||
}
|
||||
}
|
||||
|
||||
// nullable adds the `NULL`/`NOT NULL` attribute to the column. it is exist in
|
||||
// a different function to share the common declaration between the two dialects.
|
||||
func (c *Column) nullable(b *sql.ColumnBuilder) {
|
||||
if c.Nullable != nil {
|
||||
attr := "NULL"
|
||||
if !*c.Nullable {
|
||||
attr = "NOT " + attr
|
||||
}
|
||||
b.Attr(attr)
|
||||
}
|
||||
}
|
||||
|
||||
// ForeignKey definition for creation.
|
||||
type ForeignKey struct {
|
||||
Symbol string // foreign-key name. Generated if empty.
|
||||
Columns []*Column // table column
|
||||
RefTable *Table // referenced table.
|
||||
RefColumns []*Column // referenced columns.
|
||||
OnUpdate ReferenceOption // action on update.
|
||||
OnDelete ReferenceOption // action on delete.
|
||||
}
|
||||
|
||||
// DSL returns a default DSL query for a foreign-key.
|
||||
func (fk ForeignKey) DSL() *sql.ForeignKeyBuilder {
|
||||
cols := make([]string, len(fk.Columns))
|
||||
refs := make([]string, len(fk.RefColumns))
|
||||
for i, c := range fk.Columns {
|
||||
cols[i] = c.Name
|
||||
}
|
||||
for i, c := range fk.RefColumns {
|
||||
refs[i] = c.Name
|
||||
}
|
||||
dsl := sql.ForeignKey().Symbol(fk.Symbol).
|
||||
Columns(cols...).
|
||||
Reference(sql.Reference().Table(fk.RefTable.Name).Columns(refs...))
|
||||
if action := string(fk.OnDelete); action != "" {
|
||||
dsl.OnDelete(action)
|
||||
}
|
||||
if action := string(fk.OnUpdate); action != "" {
|
||||
dsl.OnUpdate(action)
|
||||
}
|
||||
return dsl
|
||||
}
|
||||
|
||||
// ReferenceOption for constraint actions.
|
||||
type ReferenceOption string
|
||||
|
||||
// Reference options.
|
||||
const (
|
||||
NoAction ReferenceOption = "NO ACTION"
|
||||
Restrict ReferenceOption = "RESTRICT"
|
||||
Cascade ReferenceOption = "CASCADE"
|
||||
SetNull ReferenceOption = "SET NULL"
|
||||
SetDefault ReferenceOption = "SET DEFAULT"
|
||||
)
|
||||
|
||||
// ConstName returns the constant name of a reference option. It's used by entc for printing the constant name in templates.
|
||||
func (r ReferenceOption) ConstName() string {
|
||||
if r == NoAction {
|
||||
return ""
|
||||
}
|
||||
return strings.ReplaceAll(strings.Title(strings.ToLower(string(r))), " ", "")
|
||||
}
|
||||
|
||||
// Index definition for table index.
|
||||
type Index struct {
|
||||
Key string // key name.
|
||||
Column string // column name.
|
||||
}
|
||||
|
||||
// Primary indicates if this index is a primary key.
|
||||
// Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects.
|
||||
func (i *Index) Primary() bool { return i.Key == "PRIMARY" }
|
||||
74
dialect/sql/schema/sqlite.go
Normal file
74
dialect/sql/schema/sqlite.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"fbc/ent/dialect"
|
||||
"fbc/ent/dialect/sql"
|
||||
)
|
||||
|
||||
// SQLite is an SQLite migration driver.
|
||||
type SQLite struct {
|
||||
dialect.Driver
|
||||
}
|
||||
|
||||
// Create creates all tables resources in the database.
|
||||
func (d *SQLite) Create(ctx context.Context, tables ...*Table) error {
|
||||
tx, err := d.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
on, err := d.fkEnabled(ctx, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sql/sqlite: check foreign_keys pragma: %v", err)
|
||||
}
|
||||
if !on {
|
||||
// foreign_keys pragma is off, either enable it by execute "PRAGMA foreign_keys=ON"
|
||||
// or add the following parameter in the connection string "_fk=1".
|
||||
return fmt.Errorf("sql/sqlite: foreign_keys pragma is off: missing %q is the connection string", "_fk=1")
|
||||
}
|
||||
for _, t := range tables {
|
||||
exist, err := d.tableExist(ctx, tx, t.Name)
|
||||
if err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
if exist {
|
||||
continue
|
||||
}
|
||||
query, args := t.SQLite().Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
err = fmt.Errorf("sql/sqlite: create table %q: %v", t.Name, err)
|
||||
return rollback(tx, err)
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (d *SQLite) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
|
||||
query, args := sql.Select().Count().
|
||||
From(sql.Table("sqlite_master")).
|
||||
Where(sql.EQ("type", "table").And().EQ("name", name)).
|
||||
Query()
|
||||
return d.exist(ctx, tx, query, args...)
|
||||
}
|
||||
|
||||
func (d *SQLite) fkEnabled(ctx context.Context, tx dialect.Tx) (bool, error) {
|
||||
return d.exist(ctx, tx, "PRAGMA foreign_keys")
|
||||
}
|
||||
|
||||
func (d *SQLite) exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return false, fmt.Errorf("dialect/sqlite: reading schema information %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, fmt.Errorf("dialect/sqlite: no rows returned")
|
||||
}
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
return false, fmt.Errorf("dialect/sqlite: scanning count")
|
||||
}
|
||||
return n > 0, nil
|
||||
}
|
||||
Reference in New Issue
Block a user