mirror of
https://github.com/ent/ent.git
synced 2026-04-28 21:50:56 +03:00
dialect/sql/sqlgraph: initial work for batch insert (#573)
This is the first part for adding batch insert support for the framework. The second part if the codegen.
This commit is contained in:
@@ -324,14 +324,21 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
// CreateSpec holds the information for creating
|
||||
// a node in the graph.
|
||||
type CreateSpec struct {
|
||||
Table string
|
||||
ID *FieldSpec
|
||||
Fields []*FieldSpec
|
||||
Edges []*EdgeSpec
|
||||
}
|
||||
type (
|
||||
// CreateSpec holds the information for creating
|
||||
// a node in the graph.
|
||||
CreateSpec struct {
|
||||
Table string
|
||||
ID *FieldSpec
|
||||
Fields []*FieldSpec
|
||||
Edges []*EdgeSpec
|
||||
}
|
||||
// BatchCreateSpec holds the information for creating
|
||||
// multiple nodes in the graph.
|
||||
BatchCreateSpec struct {
|
||||
Nodes []*CreateSpec
|
||||
}
|
||||
)
|
||||
|
||||
// CreateNode applies the CreateSpec on the graph.
|
||||
func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error {
|
||||
@@ -347,6 +354,20 @@ func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// BatchCreate applies the BatchCreateSpec on the graph.
|
||||
func BatchCreate(ctx context.Context, drv dialect.Driver, spec *BatchCreateSpec) error {
|
||||
tx, err := drv.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())}
|
||||
cr := &creator{BatchCreateSpec: spec, graph: gr}
|
||||
if err := cr.nodes(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
type (
|
||||
// EdgeMut defines edge mutations.
|
||||
EdgeMut struct {
|
||||
@@ -737,6 +758,7 @@ func (u *updater) scan(rows *sql.Rows) error {
|
||||
type creator struct {
|
||||
graph
|
||||
*CreateSpec
|
||||
*BatchCreateSpec
|
||||
}
|
||||
|
||||
func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
@@ -760,6 +782,70 @@ func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *creator) nodes(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
if len(c.Nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
columns := make(map[string]struct{})
|
||||
values := make([]map[string]driver.Value, len(c.Nodes))
|
||||
for i, node := range c.Nodes {
|
||||
if i > 0 && node.Table != c.Nodes[i-1].Table {
|
||||
return fmt.Errorf("more than 1 table for batch insert: %q != %q", node.Table, c.Nodes[i-1].Table)
|
||||
}
|
||||
values[i] = make(map[string]driver.Value)
|
||||
if node.ID.Value != nil {
|
||||
columns[node.ID.Column] = struct{}{}
|
||||
values[i][node.ID.Column] = node.ID.Value
|
||||
}
|
||||
edges := EdgeSpecs(node.Edges).GroupRel()
|
||||
err := setTableColumns(node.Fields, edges, func(column string, value driver.Value) {
|
||||
columns[column] = struct{}{}
|
||||
values[i][column] = value
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for column := range columns {
|
||||
for i := range values {
|
||||
switch _, exists := values[i][column]; {
|
||||
case column == c.Nodes[i].ID.Column && !exists:
|
||||
// If the ID value was provided to one of the nodes, it should be
|
||||
// provided to all others because this affects the way we calculate
|
||||
// their values in MySQL and SQLite dialects.
|
||||
return fmt.Errorf("incosistent id values for batch insert")
|
||||
case !exists:
|
||||
// Assign NULL values for empty placeholders.
|
||||
values[i][column] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
sorted := keys(columns)
|
||||
insert := c.builder.Insert(c.Nodes[0].Table).Default().Columns(sorted...)
|
||||
for i := range values {
|
||||
vs := make([]interface{}, len(sorted))
|
||||
for j, c := range sorted {
|
||||
vs[j] = values[i][c]
|
||||
}
|
||||
insert.Values(vs...)
|
||||
}
|
||||
if err := c.batchInsert(ctx, tx, insert); err != nil {
|
||||
return fmt.Errorf("insert nodes to table %q: %v", c.Nodes[0].Table, err)
|
||||
}
|
||||
if err := c.batchAddM2M(ctx, c.BatchCreateSpec); err != nil {
|
||||
return err
|
||||
}
|
||||
// FKs that exist in different tables can't be updated in batch (using the CASE
|
||||
// statement), because we rely on RowsAffected to check if the FK column is NULL.
|
||||
for _, node := range c.Nodes {
|
||||
edges := EdgeSpecs(node.Edges).GroupRel()
|
||||
if err := c.graph.addFKEdges(ctx, []driver.Value{node.ID.Value}, append(edges[O2M], edges[O2O]...)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setTableColumns sets the table columns and foreign_keys used in insert.
|
||||
func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*EdgeSpec) error {
|
||||
err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) {
|
||||
@@ -785,6 +871,18 @@ func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sq
|
||||
return nil
|
||||
}
|
||||
|
||||
// batchInsert inserts a batch of nodes to their table and sets their ID if it wasn't provided by the user.
|
||||
func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
|
||||
ids, err := insertLastIDs(ctx, tx, insert.Returning(c.Nodes[0].ID.Column))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i, node := range c.Nodes {
|
||||
node.ID.Value = ids[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GroupRel groups edges by their relation type.
|
||||
func (es EdgeSpecs) GroupRel() map[Rel][]*EdgeSpec {
|
||||
edges := make(map[Rel][]*EdgeSpec)
|
||||
@@ -803,6 +901,17 @@ func (es EdgeSpecs) GroupTable() map[string][]*EdgeSpec {
|
||||
return edges
|
||||
}
|
||||
|
||||
// FilterRel returns edges for the given relation type.
|
||||
func (es EdgeSpecs) FilterRel(r Rel) EdgeSpecs {
|
||||
edges := make([]*EdgeSpec, 0, len(es))
|
||||
for _, edge := range es {
|
||||
if edge.Rel == r {
|
||||
edges = append(edges, edge)
|
||||
}
|
||||
}
|
||||
return edges
|
||||
}
|
||||
|
||||
// The common operations shared between the different builders.
|
||||
//
|
||||
// M2M edges reside in join tables and require INSERT and DELETE
|
||||
@@ -822,7 +931,7 @@ func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges Edg
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables = edges.GroupTable()
|
||||
)
|
||||
for _, table := range sortedKeys(tables) {
|
||||
for _, table := range edgeKeys(tables) {
|
||||
edges := tables[table]
|
||||
preds := make([]*sql.Predicate, 0, len(edges))
|
||||
for _, edge := range edges {
|
||||
@@ -850,7 +959,7 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables = edges.GroupTable()
|
||||
)
|
||||
for _, table := range sortedKeys(tables) {
|
||||
for _, table := range edgeKeys(tables) {
|
||||
edges := tables[table]
|
||||
insert := g.builder.Insert(table).Columns(edges[0].Columns...)
|
||||
for _, edge := range edges {
|
||||
@@ -873,6 +982,44 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *graph) batchAddM2M(ctx context.Context, spec *BatchCreateSpec) error {
|
||||
tables := make(map[string]*sql.InsertBuilder)
|
||||
for _, node := range spec.Nodes {
|
||||
edges := EdgeSpecs(node.Edges).FilterRel(M2M)
|
||||
for t, edges := range edges.GroupTable() {
|
||||
insert, ok := tables[t]
|
||||
if !ok {
|
||||
insert = g.builder.Insert(t).Columns(edges[0].Columns...)
|
||||
}
|
||||
tables[t] = insert
|
||||
if len(edges) != 1 {
|
||||
return fmt.Errorf("expect exactly 1 edge-spec per table, but got %d", len(edges))
|
||||
}
|
||||
edge := edges[0]
|
||||
pk1, pk2 := []driver.Value{node.ID.Value}, edge.Target.Nodes
|
||||
if edge.Inverse {
|
||||
pk1, pk2 = pk2, pk1
|
||||
}
|
||||
for _, pair := range product(pk1, pk2) {
|
||||
insert.Values(pair[0], pair[1])
|
||||
if edge.Bidi {
|
||||
insert.Values(pair[1], pair[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, table := range insertKeys(tables) {
|
||||
var (
|
||||
res sql.Result
|
||||
query, args = tables[table].Query()
|
||||
)
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("add m2m edge for table %s: %v", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error {
|
||||
for _, edge := range edges {
|
||||
if edge.Rel == O2O && edge.Inverse {
|
||||
@@ -925,6 +1072,8 @@ func (g *graph) addFKEdges(ctx context.Context, ids []driver.Value, edges []*Edg
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Setting the FK value of the "other" table
|
||||
// without clearing it before, is not allowed.
|
||||
if ids := edge.Target.Nodes; int(affected) < len(ids) {
|
||||
return &ConstraintError{msg: fmt.Sprintf("one of %v is already connected to a different %s", ids, edge.Columns[0])}
|
||||
}
|
||||
@@ -979,6 +1128,45 @@ func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *sql.Inser
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// insertLastIDs invokes the batch insert query on the transaction and returns the LastInsertID of all entities.
|
||||
func insertLastIDs(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) (ids []int64, err error) {
|
||||
query, args := insert.Query()
|
||||
// PostgreSQL does not support the LastInsertId() method of sql.Result
|
||||
// on Exec, and should be extracted manually using the `RETURNING` clause.
|
||||
if insert.Dialect() == dialect.Postgres {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return ids, sql.ScanSlice(rows, &ids)
|
||||
}
|
||||
// MySQL, SQLite, etc.
|
||||
var res sql.Result
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = make([]int64, 0, affected)
|
||||
switch insert.Dialect() {
|
||||
case dialect.SQLite:
|
||||
id -= affected - 1
|
||||
fallthrough
|
||||
case dialect.MySQL:
|
||||
for i := int64(0); i < affected; i++ {
|
||||
ids = append(ids, id+i)
|
||||
}
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
|
||||
func rollback(tx dialect.Tx, err error) error {
|
||||
if rerr := tx.Rollback(); rerr != nil {
|
||||
@@ -987,7 +1175,25 @@ func rollback(tx dialect.Tx, err error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func sortedKeys(m map[string][]*EdgeSpec) []string {
|
||||
func edgeKeys(m map[string][]*EdgeSpec) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func insertKeys(m map[string]*sql.InsertBuilder) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func keys(m map[string]struct{}) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
|
||||
Reference in New Issue
Block a user