entc/gen: generate sql builders with dialect option

Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/130

Reviewed By: alexsn

Differential Revision: D18164397

fbshipit-source-id: 2858d69d3ff85c06b51382c01c3d4369ee2c3bdb
This commit is contained in:
Ariel Mashraki
2019-10-27 21:52:31 -07:00
committed by Facebook Github Bot
parent ea479ea527
commit c259aee24b
241 changed files with 3899 additions and 2394 deletions

View File

@@ -375,8 +375,9 @@ func HasOwner() predicate.Card {
func HasOwnerWith(preds ...predicate.User) predicate.Card {
return predicate.Card(
func(s *sql.Selector) {
builder := sql.Dialect(s.Dialect())
t1 := s.Table()
t2 := sql.Select(FieldID).From(sql.Table(OwnerInverseTable))
t2 := builder.Select(FieldID).From(builder.Table(OwnerInverseTable))
for _, p := range preds {
p(t2)
}

View File

@@ -78,36 +78,31 @@ func (cc *CardCreate) SaveX(ctx context.Context) *Card {
func (cc *CardCreate) sqlSave(ctx context.Context) (*Card, error) {
var (
res sql.Result
c = &Card{config: cc.config}
res sql.Result
builder = sql.Dialect(cc.driver.Dialect())
c = &Card{config: cc.config}
)
tx, err := cc.driver.Tx(ctx)
if err != nil {
return nil, err
}
builder := sql.Dialect(cc.driver.Dialect()).
Insert(card.Table).
Default()
insert := builder.Insert(card.Table).Default()
if value := cc.expired; value != nil {
builder.Set(card.FieldExpired, *value)
insert.Set(card.FieldExpired, *value)
c.Expired = *value
}
if value := cc.number; value != nil {
builder.Set(card.FieldNumber, *value)
insert.Set(card.FieldNumber, *value)
c.Number = *value
}
query, args := builder.Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
return nil, rollback(tx, err)
}
id, err := res.LastInsertId()
id, err := insertLastID(ctx, tx, insert.Returning(card.FieldID))
if err != nil {
return nil, rollback(tx, err)
}
c.ID = int(id)
if len(cc.owner) > 0 {
eid := keys(cc.owner)[0]
query, args := sql.Update(card.OwnerTable).
query, args := builder.Update(card.OwnerTable).
Set(card.OwnerColumn, eid).
Where(sql.EQ(card.FieldID, id).And().IsNull(card.OwnerColumn)).
Query()

View File

@@ -41,12 +41,15 @@ func (cd *CardDelete) ExecX(ctx context.Context) int {
}
func (cd *CardDelete) sqlExec(ctx context.Context) (int, error) {
var res sql.Result
selector := sql.Select().From(sql.Table(card.Table))
var (
res sql.Result
builder = sql.Dialect(cd.driver.Dialect())
)
selector := builder.Select().From(sql.Table(card.Table))
for _, p := range cd.predicates {
p(selector)
}
query, args := sql.Delete(card.Table).FromSelect(selector).Query()
query, args := builder.Delete(card.Table).FromSelect(selector).Query()
if err := cd.driver.Exec(ctx, query, args, &res); err != nil {
return 0, err
}

View File

@@ -57,10 +57,12 @@ func (cq *CardQuery) Order(o ...Order) *CardQuery {
// QueryOwner chains the current query on the owner edge.
func (cq *CardQuery) QueryOwner() *UserQuery {
query := &UserQuery{config: cq.config}
t1 := sql.Table(user.Table)
builder := sql.Dialect(cq.driver.Dialect())
t1 := builder.Table(user.Table)
t2 := cq.sqlQuery()
t2.Select(t2.C(card.OwnerColumn))
query.sql = sql.Select(t1.Columns(user.Columns...)...).
query.sql = builder.Select(t1.Columns(user.Columns...)...).
From(t1).
Join(t2).
On(t1.C(user.FieldID), t2.C(card.OwnerColumn))
@@ -328,8 +330,9 @@ func (cq *CardQuery) sqlExist(ctx context.Context) (bool, error) {
}
func (cq *CardQuery) sqlQuery() *sql.Selector {
t1 := sql.Table(card.Table)
selector := sql.Select(t1.Columns(card.Columns...)...).From(t1)
builder := sql.Dialect(cq.driver.Dialect())
t1 := builder.Table(card.Table)
selector := builder.Select(t1.Columns(card.Columns...)...).From(t1)
if cq.sql != nil {
selector = cq.sql
selector.Select(selector.Columns(card.Columns...)...)
@@ -598,5 +601,6 @@ func (cs *CardSelect) sqlScan(ctx context.Context, v interface{}) error {
func (cs *CardSelect) sqlQuery() sql.Querier {
view := "card_view"
return sql.Select(cs.fields...).From(cs.sql.As(view))
return sql.Dialect(cs.driver.Dialect()).
Select(cs.fields...).From(cs.sql.As(view))
}

View File

@@ -100,7 +100,10 @@ func (cu *CardUpdate) ExecX(ctx context.Context) {
}
func (cu *CardUpdate) sqlSave(ctx context.Context) (n int, err error) {
selector := sql.Select(card.FieldID).From(sql.Table(card.Table))
var (
builder = sql.Dialect(cu.driver.Dialect())
selector = builder.Select(card.FieldID).From(builder.Table(card.Table))
)
for _, p := range cu.predicates {
p(selector)
}
@@ -128,22 +131,22 @@ func (cu *CardUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
var (
res sql.Result
builder = sql.Update(card.Table).Where(sql.InInts(card.FieldID, ids...))
updater = builder.Update(card.Table).Where(sql.InInts(card.FieldID, ids...))
)
if value := cu.expired; value != nil {
builder.Set(card.FieldExpired, *value)
updater.Set(card.FieldExpired, *value)
}
if value := cu.number; value != nil {
builder.Set(card.FieldNumber, *value)
updater.Set(card.FieldNumber, *value)
}
if !builder.Empty() {
query, args := builder.Query()
if !updater.Empty() {
query, args := updater.Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
return 0, rollback(tx, err)
}
}
if cu.clearedOwner {
query, args := sql.Update(card.OwnerTable).
query, args := builder.Update(card.OwnerTable).
SetNull(card.OwnerColumn).
Where(sql.InInts(user.FieldID, ids...)).
Query()
@@ -154,7 +157,7 @@ func (cu *CardUpdate) sqlSave(ctx context.Context) (n int, err error) {
if len(cu.owner) > 0 {
for _, id := range ids {
eid := keys(cu.owner)[0]
query, args := sql.Update(card.OwnerTable).
query, args := builder.Update(card.OwnerTable).
Set(card.OwnerColumn, eid).
Where(sql.EQ(card.FieldID, id).And().IsNull(card.OwnerColumn)).
Query()
@@ -252,7 +255,10 @@ func (cuo *CardUpdateOne) ExecX(ctx context.Context) {
}
func (cuo *CardUpdateOne) sqlSave(ctx context.Context) (c *Card, err error) {
selector := sql.Select(card.Columns...).From(sql.Table(card.Table))
var (
builder = sql.Dialect(cuo.driver.Dialect())
selector = builder.Select(card.Columns...).From(builder.Table(card.Table))
)
card.ID(cuo.id)(selector)
rows := &sql.Rows{}
query, args := selector.Query()
@@ -283,24 +289,24 @@ func (cuo *CardUpdateOne) sqlSave(ctx context.Context) (c *Card, err error) {
}
var (
res sql.Result
builder = sql.Update(card.Table).Where(sql.InInts(card.FieldID, ids...))
updater = builder.Update(card.Table).Where(sql.InInts(card.FieldID, ids...))
)
if value := cuo.expired; value != nil {
builder.Set(card.FieldExpired, *value)
updater.Set(card.FieldExpired, *value)
c.Expired = *value
}
if value := cuo.number; value != nil {
builder.Set(card.FieldNumber, *value)
updater.Set(card.FieldNumber, *value)
c.Number = *value
}
if !builder.Empty() {
query, args := builder.Query()
if !updater.Empty() {
query, args := updater.Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
return nil, rollback(tx, err)
}
}
if cuo.clearedOwner {
query, args := sql.Update(card.OwnerTable).
query, args := builder.Update(card.OwnerTable).
SetNull(card.OwnerColumn).
Where(sql.InInts(user.FieldID, ids...)).
Query()
@@ -311,7 +317,7 @@ func (cuo *CardUpdateOne) sqlSave(ctx context.Context) (c *Card, err error) {
if len(cuo.owner) > 0 {
for _, id := range ids {
eid := keys(cuo.owner)[0]
query, args := sql.Update(card.OwnerTable).
query, args := builder.Update(card.OwnerTable).
Set(card.OwnerColumn, eid).
Where(sql.EQ(card.FieldID, id).And().IsNull(card.OwnerColumn)).
Query()

View File

@@ -170,11 +170,12 @@ func (c *CardClient) GetX(ctx context.Context, id int) *Card {
func (c *CardClient) QueryOwner(ca *Card) *UserQuery {
query := &UserQuery{config: c.config}
id := ca.ID
t1 := sql.Table(user.Table)
t2 := sql.Select(card.OwnerColumn).
From(sql.Table(card.OwnerTable)).
builder := sql.Dialect(ca.driver.Dialect())
t1 := builder.Table(user.Table)
t2 := builder.Select(card.OwnerColumn).
From(builder.Table(card.OwnerTable)).
Where(sql.EQ(card.FieldID, id))
query.sql = sql.Select().From(t1).Join(t2).On(t1.C(user.FieldID), t2.C(card.OwnerColumn))
query.sql = builder.Select().From(t1).Join(t2).On(t1.C(user.FieldID), t2.C(card.OwnerColumn))
return query
}
@@ -247,7 +248,8 @@ func (c *UserClient) GetX(ctx context.Context, id int) *User {
func (c *UserClient) QueryCard(u *User) *CardQuery {
query := &CardQuery{config: c.config}
id := u.ID
query.sql = sql.Select().From(sql.Table(card.Table)).
builder := sql.Dialect(u.driver.Dialect())
query.sql = builder.Select().From(builder.Table(card.Table)).
Where(sql.EQ(user.CardColumn, id))
return query

View File

@@ -7,6 +7,7 @@
package ent
import (
"context"
"fmt"
"strings"
@@ -168,9 +169,19 @@ func IsConstraintFailure(err error) bool {
}
func isSQLConstraintError(err error) (*ErrConstraintFailed, bool) {
// Error number 1062 is ER_DUP_ENTRY in mysql, and "UNIQUE constraint failed" is SQLite prefix.
if msg := err.Error(); strings.HasPrefix(msg, "Error 1062") || strings.HasPrefix(msg, "UNIQUE constraint failed") {
return &ErrConstraintFailed{msg, err}, true
var (
msg = err.Error()
// error format per dialect.
errors = [...]string{
"Error 1062", // MySQL 1062 error (ER_DUP_ENTRY).
"UNIQUE constraint failed", // SQLite.
"duplicate key value violates unique constraint", // PostgreSQL.
}
)
for i := range errors {
if strings.Contains(msg, errors[i]) {
return &ErrConstraintFailed{msg, err}, true
}
}
return nil, false
}
@@ -186,6 +197,38 @@ func rollback(tx dialect.Tx, err error) error {
return err
}
// insertLastID invokes the insert query on the transaction and returns the LastInsertID.
func insertLastID(ctx context.Context, tx dialect.Tx, insert *sql.InsertBuilder) (int64, error) {
query, args := insert.Query()
// PostgreSQL does not support the LastInsertId() method of sql.Result
// on Exec, and should be extracted manually using the `RETURNING` clause.
if insert.Dialect() == dialect.Postgres {
rows := &sql.Rows{}
if err := tx.Query(ctx, query, args, rows); err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, fmt.Errorf("no rows found for query: %v", query)
}
var id int64
if err := rows.Scan(&id); err != nil {
return 0, err
}
return id, nil
}
// MySQL, SQLite, etc.
var res sql.Result
if err := tx.Exec(ctx, query, args, &res); err != nil {
return 0, err
}
id, err := res.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
}
// keys returns the keys/ids from the edge map.
func keys(m map[int]struct{}) []int {
s := make([]int, 0, len(m))

View File

@@ -364,11 +364,12 @@ func HasCard() predicate.User {
return predicate.User(
func(s *sql.Selector) {
t1 := s.Table()
builder := sql.Dialect(s.Dialect())
s.Where(
sql.In(
t1.C(FieldID),
sql.Select(CardColumn).
From(sql.Table(CardTable)).
builder.Select(CardColumn).
From(builder.Table(CardTable)).
Where(sql.NotNull(CardColumn)),
),
)
@@ -380,8 +381,9 @@ func HasCard() predicate.User {
func HasCardWith(preds ...predicate.Card) predicate.User {
return predicate.User(
func(s *sql.Selector) {
builder := sql.Dialect(s.Dialect())
t1 := s.Table()
t2 := sql.Select(CardColumn).From(sql.Table(CardTable))
t2 := builder.Select(CardColumn).From(builder.Table(CardTable))
for _, p := range preds {
p(t2)
}

View File

@@ -83,36 +83,31 @@ func (uc *UserCreate) SaveX(ctx context.Context) *User {
func (uc *UserCreate) sqlSave(ctx context.Context) (*User, error) {
var (
res sql.Result
u = &User{config: uc.config}
res sql.Result
builder = sql.Dialect(uc.driver.Dialect())
u = &User{config: uc.config}
)
tx, err := uc.driver.Tx(ctx)
if err != nil {
return nil, err
}
builder := sql.Dialect(uc.driver.Dialect()).
Insert(user.Table).
Default()
insert := builder.Insert(user.Table).Default()
if value := uc.age; value != nil {
builder.Set(user.FieldAge, *value)
insert.Set(user.FieldAge, *value)
u.Age = *value
}
if value := uc.name; value != nil {
builder.Set(user.FieldName, *value)
insert.Set(user.FieldName, *value)
u.Name = *value
}
query, args := builder.Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
return nil, rollback(tx, err)
}
id, err := res.LastInsertId()
id, err := insertLastID(ctx, tx, insert.Returning(user.FieldID))
if err != nil {
return nil, rollback(tx, err)
}
u.ID = int(id)
if len(uc.card) > 0 {
eid := keys(uc.card)[0]
query, args := sql.Update(user.CardTable).
query, args := builder.Update(user.CardTable).
Set(user.CardColumn, id).
Where(sql.EQ(card.FieldID, eid).And().IsNull(user.CardColumn)).
Query()

View File

@@ -41,12 +41,15 @@ func (ud *UserDelete) ExecX(ctx context.Context) int {
}
func (ud *UserDelete) sqlExec(ctx context.Context) (int, error) {
var res sql.Result
selector := sql.Select().From(sql.Table(user.Table))
var (
res sql.Result
builder = sql.Dialect(ud.driver.Dialect())
)
selector := builder.Select().From(sql.Table(user.Table))
for _, p := range ud.predicates {
p(selector)
}
query, args := sql.Delete(user.Table).FromSelect(selector).Query()
query, args := builder.Delete(user.Table).FromSelect(selector).Query()
if err := ud.driver.Exec(ctx, query, args, &res); err != nil {
return 0, err
}

View File

@@ -57,10 +57,12 @@ func (uq *UserQuery) Order(o ...Order) *UserQuery {
// QueryCard chains the current query on the card edge.
func (uq *UserQuery) QueryCard() *CardQuery {
query := &CardQuery{config: uq.config}
t1 := sql.Table(card.Table)
builder := sql.Dialect(uq.driver.Dialect())
t1 := builder.Table(card.Table)
t2 := uq.sqlQuery()
t2.Select(t2.C(user.FieldID))
query.sql = sql.Select().
query.sql = builder.Select().
From(t1).
Join(t2).
On(t1.C(user.CardColumn), t2.C(user.FieldID))
@@ -328,8 +330,9 @@ func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) {
}
func (uq *UserQuery) sqlQuery() *sql.Selector {
t1 := sql.Table(user.Table)
selector := sql.Select(t1.Columns(user.Columns...)...).From(t1)
builder := sql.Dialect(uq.driver.Dialect())
t1 := builder.Table(user.Table)
selector := builder.Select(t1.Columns(user.Columns...)...).From(t1)
if uq.sql != nil {
selector = uq.sql
selector.Select(selector.Columns(user.Columns...)...)
@@ -598,5 +601,6 @@ func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error {
func (us *UserSelect) sqlQuery() sql.Querier {
view := "user_view"
return sql.Select(us.fields...).From(us.sql.As(view))
return sql.Dialect(us.driver.Dialect()).
Select(us.fields...).From(us.sql.As(view))
}

View File

@@ -116,7 +116,10 @@ func (uu *UserUpdate) ExecX(ctx context.Context) {
}
func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
selector := sql.Select(user.FieldID).From(sql.Table(user.Table))
var (
builder = sql.Dialect(uu.driver.Dialect())
selector = builder.Select(user.FieldID).From(builder.Table(user.Table))
)
for _, p := range uu.predicates {
p(selector)
}
@@ -144,25 +147,25 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
var (
res sql.Result
builder = sql.Update(user.Table).Where(sql.InInts(user.FieldID, ids...))
updater = builder.Update(user.Table).Where(sql.InInts(user.FieldID, ids...))
)
if value := uu.age; value != nil {
builder.Set(user.FieldAge, *value)
updater.Set(user.FieldAge, *value)
}
if value := uu.addage; value != nil {
builder.Add(user.FieldAge, *value)
updater.Add(user.FieldAge, *value)
}
if value := uu.name; value != nil {
builder.Set(user.FieldName, *value)
updater.Set(user.FieldName, *value)
}
if !builder.Empty() {
query, args := builder.Query()
if !updater.Empty() {
query, args := updater.Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
return 0, rollback(tx, err)
}
}
if uu.clearedCard {
query, args := sql.Update(user.CardTable).
query, args := builder.Update(user.CardTable).
SetNull(user.CardColumn).
Where(sql.InInts(card.FieldID, ids...)).
Query()
@@ -173,7 +176,7 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
if len(uu.card) > 0 {
for _, id := range ids {
eid := keys(uu.card)[0]
query, args := sql.Update(user.CardTable).
query, args := builder.Update(user.CardTable).
Set(user.CardColumn, id).
Where(sql.EQ(card.FieldID, eid).And().IsNull(user.CardColumn)).
Query()
@@ -288,7 +291,10 @@ func (uuo *UserUpdateOne) ExecX(ctx context.Context) {
}
func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) {
selector := sql.Select(user.Columns...).From(sql.Table(user.Table))
var (
builder = sql.Dialect(uuo.driver.Dialect())
selector = builder.Select(user.Columns...).From(builder.Table(user.Table))
)
user.ID(uuo.id)(selector)
rows := &sql.Rows{}
query, args := selector.Query()
@@ -319,28 +325,28 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) {
}
var (
res sql.Result
builder = sql.Update(user.Table).Where(sql.InInts(user.FieldID, ids...))
updater = builder.Update(user.Table).Where(sql.InInts(user.FieldID, ids...))
)
if value := uuo.age; value != nil {
builder.Set(user.FieldAge, *value)
updater.Set(user.FieldAge, *value)
u.Age = *value
}
if value := uuo.addage; value != nil {
builder.Add(user.FieldAge, *value)
updater.Add(user.FieldAge, *value)
u.Age += *value
}
if value := uuo.name; value != nil {
builder.Set(user.FieldName, *value)
updater.Set(user.FieldName, *value)
u.Name = *value
}
if !builder.Empty() {
query, args := builder.Query()
if !updater.Empty() {
query, args := updater.Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
return nil, rollback(tx, err)
}
}
if uuo.clearedCard {
query, args := sql.Update(user.CardTable).
query, args := builder.Update(user.CardTable).
SetNull(user.CardColumn).
Where(sql.InInts(card.FieldID, ids...)).
Query()
@@ -351,7 +357,7 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) {
if len(uuo.card) > 0 {
for _, id := range ids {
eid := keys(uuo.card)[0]
query, args := sql.Update(user.CardTable).
query, args := builder.Update(user.CardTable).
Set(user.CardColumn, id).
Where(sql.EQ(card.FieldID, eid).And().IsNull(user.CardColumn)).
Query()