mirror of
https://github.com/ent/ent.git
synced 2026-04-28 05:30:56 +03:00
dialect/sql/sqlgraph: minor refactor changes
This commit is contained in:
committed by
Ariel Mashraki
parent
4720063afd
commit
8f88f58713
@@ -369,7 +369,8 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
// CreateNode applies the CreateSpec on the graph.
|
||||
// CreateNode applies the CreateSpec on the graph. The operation creates a new
|
||||
// record in the database, and connects it to other nodes specified in spec.Edges.
|
||||
func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error {
|
||||
gr := graph{tx: drv, builder: sql.Dialect(drv.Dialect())}
|
||||
cr := &creator{CreateSpec: spec, graph: gr}
|
||||
@@ -383,7 +384,7 @@ func BatchCreate(ctx context.Context, drv dialect.Driver, spec *BatchCreateSpec)
|
||||
return err
|
||||
}
|
||||
gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())}
|
||||
cr := &creator{BatchCreateSpec: spec, graph: gr}
|
||||
cr := &batchCreator{BatchCreateSpec: spec, graph: gr}
|
||||
if err := cr.nodes(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
@@ -674,9 +675,8 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
return err
|
||||
}
|
||||
if !update.Empty() {
|
||||
var res sql.Result
|
||||
query, args := update.Query()
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
if err := tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -865,7 +865,6 @@ func (u *updater) scan(rows *sql.Rows) error {
|
||||
type creator struct {
|
||||
graph
|
||||
*CreateSpec
|
||||
*BatchCreateSpec
|
||||
}
|
||||
|
||||
func (c *creator) node(ctx context.Context, drv dialect.Driver) error {
|
||||
@@ -907,7 +906,55 @@ func (c *creator) mayTx(ctx context.Context, drv dialect.Driver, edges map[Rel][
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
func (c *creator) nodes(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
// setTableColumns sets the table columns and foreign_keys used in insert.
|
||||
func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*EdgeSpec) error {
|
||||
err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) {
|
||||
insert.Set(column, value)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// insert inserts the node to its table and sets its ID if it was not provided by the user.
|
||||
func (c *creator) insert(ctx context.Context, insert *sql.InsertBuilder) error {
|
||||
if opts := c.CreateSpec.OnConflict; len(opts) > 0 {
|
||||
insert.OnConflict(opts...)
|
||||
c.ensureLastInsertID(insert)
|
||||
}
|
||||
// If the id field was provided by the user.
|
||||
if c.ID.Value != nil {
|
||||
insert.Set(c.ID.Column, c.ID.Value)
|
||||
// In case of "ON CONFLICT", the record may exists in the
|
||||
// database, and we need to get back the database id field.
|
||||
if len(c.CreateSpec.OnConflict) == 0 {
|
||||
query, args := insert.Query()
|
||||
return c.tx.Exec(ctx, query, args, nil)
|
||||
}
|
||||
}
|
||||
return c.insertLastID(ctx, insert.Returning(c.ID.Column))
|
||||
}
|
||||
|
||||
// ensureLastInsertID ensures the LAST_INSERT_ID was added to the
|
||||
// 'ON DUPLICATE .. UPDATE' clause in it was not provided.
|
||||
func (c *creator) ensureLastInsertID(insert *sql.InsertBuilder) {
|
||||
if !c.ID.Type.Numeric() || c.ID.Value != nil || insert.Dialect() != dialect.MySQL {
|
||||
return
|
||||
}
|
||||
insert.OnConflict(sql.ResolveWith(func(s *sql.UpdateSet) {
|
||||
for _, column := range s.UpdateColumns() {
|
||||
if column == c.ID.Column {
|
||||
return
|
||||
}
|
||||
}
|
||||
s.Set(c.ID.Column, sql.Expr(fmt.Sprintf("LAST_INSERT_ID(%s)", s.Table().C(c.ID.Column))))
|
||||
}))
|
||||
}
|
||||
|
||||
type batchCreator struct {
|
||||
graph
|
||||
*BatchCreateSpec
|
||||
}
|
||||
|
||||
func (c *batchCreator) nodes(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
if len(c.Nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -971,52 +1018,8 @@ func (c *creator) nodes(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setTableColumns sets the table columns and foreign_keys used in insert.
|
||||
func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*EdgeSpec) error {
|
||||
err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) {
|
||||
insert.Set(column, value)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// insert inserts the node to its table and sets its ID if it wasn't provided by the user.
|
||||
func (c *creator) insert(ctx context.Context, insert *sql.InsertBuilder) error {
|
||||
if opts := c.CreateSpec.OnConflict; len(opts) > 0 {
|
||||
insert.OnConflict(opts...)
|
||||
c.ensureLastInsertID(insert)
|
||||
}
|
||||
var res sql.Result
|
||||
// If the id field was provided by the user.
|
||||
if c.ID.Value != nil {
|
||||
insert.Set(c.ID.Column, c.ID.Value)
|
||||
// In case of "ON CONFLICT", the record may exists in the
|
||||
// database, and we need to get back the database id field.
|
||||
if len(c.CreateSpec.OnConflict) == 0 {
|
||||
query, args := insert.Query()
|
||||
return c.tx.Exec(ctx, query, args, &res)
|
||||
}
|
||||
}
|
||||
return c.insertLastID(ctx, insert.Returning(c.ID.Column))
|
||||
}
|
||||
|
||||
// ensureLastInsertID ensures the LAST_INSERT_ID was added to the
|
||||
// 'ON DUPLICATE .. UPDATE' clause in it was not provided.
|
||||
func (c *creator) ensureLastInsertID(insert *sql.InsertBuilder) {
|
||||
if !c.ID.Type.Numeric() || c.ID.Value != nil || insert.Dialect() != dialect.MySQL {
|
||||
return
|
||||
}
|
||||
insert.OnConflict(sql.ResolveWith(func(s *sql.UpdateSet) {
|
||||
for _, column := range s.UpdateColumns() {
|
||||
if column == c.ID.Column {
|
||||
return
|
||||
}
|
||||
}
|
||||
s.Set(c.ID.Column, sql.Expr(fmt.Sprintf("LAST_INSERT_ID(%s)", s.Table().C(c.ID.Column))))
|
||||
}))
|
||||
}
|
||||
|
||||
// batchInsert inserts a batch of nodes to their table and sets their ID if it was not provided by the user.
|
||||
func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
|
||||
func (c *batchCreator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
|
||||
if opts := c.BatchCreateSpec.OnConflict; len(opts) > 0 {
|
||||
insert.OnConflict(opts...)
|
||||
}
|
||||
@@ -1065,12 +1068,9 @@ type graph struct {
|
||||
}
|
||||
|
||||
func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error {
|
||||
var (
|
||||
res sql.Result
|
||||
// Remove all M2M edges from the same type at once.
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables = edges.GroupTable()
|
||||
)
|
||||
// Remove all M2M edges from the same type at once.
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables := edges.GroupTable()
|
||||
for _, table := range edgeKeys(tables) {
|
||||
edges := tables[table]
|
||||
preds := make([]*sql.Predicate, 0, len(edges))
|
||||
@@ -1101,7 +1101,7 @@ func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges Edg
|
||||
deleter.Schema(edges[0].Schema)
|
||||
}
|
||||
query, args := deleter.Query()
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
if err := g.tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("remove m2m edge for table %s: %w", table, err)
|
||||
}
|
||||
}
|
||||
@@ -1109,12 +1109,9 @@ func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges Edg
|
||||
}
|
||||
|
||||
func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error {
|
||||
var (
|
||||
res sql.Result
|
||||
// Insert all M2M edges from the same type at once.
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables = edges.GroupTable()
|
||||
)
|
||||
// Insert all M2M edges from the same type at once.
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables := edges.GroupTable()
|
||||
for _, table := range edgeKeys(tables) {
|
||||
edges := tables[table]
|
||||
insert := g.builder.Insert(table).Columns(edges[0].Columns...)
|
||||
@@ -1136,7 +1133,7 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS
|
||||
}
|
||||
}
|
||||
query, args := insert.Query()
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
if err := g.tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("add m2m edge for table %s: %w", table, err)
|
||||
}
|
||||
}
|
||||
@@ -1175,11 +1172,8 @@ func (g *graph) batchAddM2M(ctx context.Context, spec *BatchCreateSpec) error {
|
||||
}
|
||||
}
|
||||
for _, table := range insertKeys(tables) {
|
||||
var (
|
||||
res sql.Result
|
||||
query, args = tables[table].Query()
|
||||
)
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
query, args := tables[table].Query()
|
||||
if err := g.tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("add m2m edge for table %s: %w", table, err)
|
||||
}
|
||||
}
|
||||
@@ -1201,8 +1195,7 @@ func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*E
|
||||
SetNull(edge.Columns[0]).
|
||||
Where(pred).
|
||||
Query()
|
||||
var res sql.Result
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
if err := g.tx.Exec(ctx, query, args, nil); err != nil {
|
||||
return fmt.Errorf("add %s edge for table %s: %w", edge.Rel, edge.Table, err)
|
||||
}
|
||||
}
|
||||
@@ -1340,7 +1333,7 @@ func (c *creator) insertLastID(ctx context.Context, insert *sql.InsertBuilder) e
|
||||
}
|
||||
|
||||
// insertLastIDs invokes the batch insert query on the transaction and returns the LastInsertID of all entities.
|
||||
func (c *creator) insertLastIDs(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
|
||||
func (c *batchCreator) insertLastIDs(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
|
||||
query, args := insert.Query()
|
||||
if err := insert.Err(); err != nil {
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user