dialect/sql/sqlgraph: minor refactor changes

This commit is contained in:
Ariel Mashraki
2021-11-07 12:17:41 +02:00
committed by Ariel Mashraki
parent 4720063afd
commit 8f88f58713

View File

@@ -369,7 +369,8 @@ type (
}
)
// CreateNode applies the CreateSpec on the graph.
// CreateNode applies the CreateSpec on the graph. The operation creates a new
// record in the database, and connects it to other nodes specified in spec.Edges.
func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error {
gr := graph{tx: drv, builder: sql.Dialect(drv.Dialect())}
cr := &creator{CreateSpec: spec, graph: gr}
@@ -383,7 +384,7 @@ func BatchCreate(ctx context.Context, drv dialect.Driver, spec *BatchCreateSpec)
return err
}
gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())}
cr := &creator{BatchCreateSpec: spec, graph: gr}
cr := &batchCreator{BatchCreateSpec: spec, graph: gr}
if err := cr.nodes(ctx, tx); err != nil {
return rollback(tx, err)
}
@@ -674,9 +675,8 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
return err
}
if !update.Empty() {
var res sql.Result
query, args := update.Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
if err := tx.Exec(ctx, query, args, nil); err != nil {
return err
}
}
@@ -865,7 +865,6 @@ func (u *updater) scan(rows *sql.Rows) error {
type creator struct {
graph
*CreateSpec
*BatchCreateSpec
}
func (c *creator) node(ctx context.Context, drv dialect.Driver) error {
@@ -907,7 +906,55 @@ func (c *creator) mayTx(ctx context.Context, drv dialect.Driver, edges map[Rel][
return tx, nil
}
func (c *creator) nodes(ctx context.Context, tx dialect.ExecQuerier) error {
// setTableColumns sets the table columns and foreign_keys used in insert.
func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*EdgeSpec) error {
err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) {
insert.Set(column, value)
})
return err
}
// insert inserts the node to its table and sets its ID if it was not provided by the user.
func (c *creator) insert(ctx context.Context, insert *sql.InsertBuilder) error {
if opts := c.CreateSpec.OnConflict; len(opts) > 0 {
insert.OnConflict(opts...)
c.ensureLastInsertID(insert)
}
// If the id field was provided by the user.
if c.ID.Value != nil {
insert.Set(c.ID.Column, c.ID.Value)
// In case of "ON CONFLICT", the record may exists in the
// database, and we need to get back the database id field.
if len(c.CreateSpec.OnConflict) == 0 {
query, args := insert.Query()
return c.tx.Exec(ctx, query, args, nil)
}
}
return c.insertLastID(ctx, insert.Returning(c.ID.Column))
}
// ensureLastInsertID ensures the LAST_INSERT_ID was added to the
// 'ON DUPLICATE .. UPDATE' clause in it was not provided.
func (c *creator) ensureLastInsertID(insert *sql.InsertBuilder) {
if !c.ID.Type.Numeric() || c.ID.Value != nil || insert.Dialect() != dialect.MySQL {
return
}
insert.OnConflict(sql.ResolveWith(func(s *sql.UpdateSet) {
for _, column := range s.UpdateColumns() {
if column == c.ID.Column {
return
}
}
s.Set(c.ID.Column, sql.Expr(fmt.Sprintf("LAST_INSERT_ID(%s)", s.Table().C(c.ID.Column))))
}))
}
type batchCreator struct {
graph
*BatchCreateSpec
}
func (c *batchCreator) nodes(ctx context.Context, tx dialect.ExecQuerier) error {
if len(c.Nodes) == 0 {
return nil
}
@@ -971,52 +1018,8 @@ func (c *creator) nodes(ctx context.Context, tx dialect.ExecQuerier) error {
return nil
}
// setTableColumns sets the table columns and foreign_keys used in insert.
func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*EdgeSpec) error {
err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) {
insert.Set(column, value)
})
return err
}
// insert inserts the node to its table and sets its ID if it wasn't provided by the user.
func (c *creator) insert(ctx context.Context, insert *sql.InsertBuilder) error {
if opts := c.CreateSpec.OnConflict; len(opts) > 0 {
insert.OnConflict(opts...)
c.ensureLastInsertID(insert)
}
var res sql.Result
// If the id field was provided by the user.
if c.ID.Value != nil {
insert.Set(c.ID.Column, c.ID.Value)
// In case of "ON CONFLICT", the record may exists in the
// database, and we need to get back the database id field.
if len(c.CreateSpec.OnConflict) == 0 {
query, args := insert.Query()
return c.tx.Exec(ctx, query, args, &res)
}
}
return c.insertLastID(ctx, insert.Returning(c.ID.Column))
}
// ensureLastInsertID ensures the LAST_INSERT_ID was added to the
// 'ON DUPLICATE .. UPDATE' clause in it was not provided.
func (c *creator) ensureLastInsertID(insert *sql.InsertBuilder) {
if !c.ID.Type.Numeric() || c.ID.Value != nil || insert.Dialect() != dialect.MySQL {
return
}
insert.OnConflict(sql.ResolveWith(func(s *sql.UpdateSet) {
for _, column := range s.UpdateColumns() {
if column == c.ID.Column {
return
}
}
s.Set(c.ID.Column, sql.Expr(fmt.Sprintf("LAST_INSERT_ID(%s)", s.Table().C(c.ID.Column))))
}))
}
// batchInsert inserts a batch of nodes to their table and sets their ID if it was not provided by the user.
func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
func (c *batchCreator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
if opts := c.BatchCreateSpec.OnConflict; len(opts) > 0 {
insert.OnConflict(opts...)
}
@@ -1065,12 +1068,9 @@ type graph struct {
}
func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error {
var (
res sql.Result
// Remove all M2M edges from the same type at once.
// The EdgeSpec is the same for all members in a group.
tables = edges.GroupTable()
)
// Remove all M2M edges from the same type at once.
// The EdgeSpec is the same for all members in a group.
tables := edges.GroupTable()
for _, table := range edgeKeys(tables) {
edges := tables[table]
preds := make([]*sql.Predicate, 0, len(edges))
@@ -1101,7 +1101,7 @@ func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges Edg
deleter.Schema(edges[0].Schema)
}
query, args := deleter.Query()
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
if err := g.tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("remove m2m edge for table %s: %w", table, err)
}
}
@@ -1109,12 +1109,9 @@ func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges Edg
}
func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error {
var (
res sql.Result
// Insert all M2M edges from the same type at once.
// The EdgeSpec is the same for all members in a group.
tables = edges.GroupTable()
)
// Insert all M2M edges from the same type at once.
// The EdgeSpec is the same for all members in a group.
tables := edges.GroupTable()
for _, table := range edgeKeys(tables) {
edges := tables[table]
insert := g.builder.Insert(table).Columns(edges[0].Columns...)
@@ -1136,7 +1133,7 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS
}
}
query, args := insert.Query()
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
if err := g.tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("add m2m edge for table %s: %w", table, err)
}
}
@@ -1175,11 +1172,8 @@ func (g *graph) batchAddM2M(ctx context.Context, spec *BatchCreateSpec) error {
}
}
for _, table := range insertKeys(tables) {
var (
res sql.Result
query, args = tables[table].Query()
)
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
query, args := tables[table].Query()
if err := g.tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("add m2m edge for table %s: %w", table, err)
}
}
@@ -1201,8 +1195,7 @@ func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*E
SetNull(edge.Columns[0]).
Where(pred).
Query()
var res sql.Result
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
if err := g.tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("add %s edge for table %s: %w", edge.Rel, edge.Table, err)
}
}
@@ -1340,7 +1333,7 @@ func (c *creator) insertLastID(ctx context.Context, insert *sql.InsertBuilder) e
}
// insertLastIDs invokes the batch insert query on the transaction and returns the LastInsertID of all entities.
func (c *creator) insertLastIDs(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
func (c *batchCreator) insertLastIDs(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
query, args := insert.Query()
if err := insert.Err(); err != nil {
return err