mirror of
https://github.com/ent/ent.git
synced 2026-04-29 06:00:55 +03:00
dialect/sql/sqlgraph: initial work for batch insert (#573)
This is the first part for adding batch insert support for the framework. The second part if the codegen.
This commit is contained in:
@@ -324,14 +324,21 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
// CreateSpec holds the information for creating
|
||||
// a node in the graph.
|
||||
type CreateSpec struct {
|
||||
Table string
|
||||
ID *FieldSpec
|
||||
Fields []*FieldSpec
|
||||
Edges []*EdgeSpec
|
||||
}
|
||||
type (
|
||||
// CreateSpec holds the information for creating
|
||||
// a node in the graph.
|
||||
CreateSpec struct {
|
||||
Table string
|
||||
ID *FieldSpec
|
||||
Fields []*FieldSpec
|
||||
Edges []*EdgeSpec
|
||||
}
|
||||
// BatchCreateSpec holds the information for creating
|
||||
// multiple nodes in the graph.
|
||||
BatchCreateSpec struct {
|
||||
Nodes []*CreateSpec
|
||||
}
|
||||
)
|
||||
|
||||
// CreateNode applies the CreateSpec on the graph.
|
||||
func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error {
|
||||
@@ -347,6 +354,20 @@ func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// BatchCreate applies the BatchCreateSpec on the graph.
|
||||
func BatchCreate(ctx context.Context, drv dialect.Driver, spec *BatchCreateSpec) error {
|
||||
tx, err := drv.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())}
|
||||
cr := &creator{BatchCreateSpec: spec, graph: gr}
|
||||
if err := cr.nodes(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
type (
|
||||
// EdgeMut defines edge mutations.
|
||||
EdgeMut struct {
|
||||
@@ -737,6 +758,7 @@ func (u *updater) scan(rows *sql.Rows) error {
|
||||
type creator struct {
|
||||
graph
|
||||
*CreateSpec
|
||||
*BatchCreateSpec
|
||||
}
|
||||
|
||||
func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
@@ -760,6 +782,70 @@ func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *creator) nodes(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
if len(c.Nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
columns := make(map[string]struct{})
|
||||
values := make([]map[string]driver.Value, len(c.Nodes))
|
||||
for i, node := range c.Nodes {
|
||||
if i > 0 && node.Table != c.Nodes[i-1].Table {
|
||||
return fmt.Errorf("more than 1 table for batch insert: %q != %q", node.Table, c.Nodes[i-1].Table)
|
||||
}
|
||||
values[i] = make(map[string]driver.Value)
|
||||
if node.ID.Value != nil {
|
||||
columns[node.ID.Column] = struct{}{}
|
||||
values[i][node.ID.Column] = node.ID.Value
|
||||
}
|
||||
edges := EdgeSpecs(node.Edges).GroupRel()
|
||||
err := setTableColumns(node.Fields, edges, func(column string, value driver.Value) {
|
||||
columns[column] = struct{}{}
|
||||
values[i][column] = value
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for column := range columns {
|
||||
for i := range values {
|
||||
switch _, exists := values[i][column]; {
|
||||
case column == c.Nodes[i].ID.Column && !exists:
|
||||
// If the ID value was provided to one of the nodes, it should be
|
||||
// provided to all others because this affects the way we calculate
|
||||
// their values in MySQL and SQLite dialects.
|
||||
return fmt.Errorf("incosistent id values for batch insert")
|
||||
case !exists:
|
||||
// Assign NULL values for empty placeholders.
|
||||
values[i][column] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
sorted := keys(columns)
|
||||
insert := c.builder.Insert(c.Nodes[0].Table).Default().Columns(sorted...)
|
||||
for i := range values {
|
||||
vs := make([]interface{}, len(sorted))
|
||||
for j, c := range sorted {
|
||||
vs[j] = values[i][c]
|
||||
}
|
||||
insert.Values(vs...)
|
||||
}
|
||||
if err := c.batchInsert(ctx, tx, insert); err != nil {
|
||||
return fmt.Errorf("insert nodes to table %q: %v", c.Nodes[0].Table, err)
|
||||
}
|
||||
if err := c.batchAddM2M(ctx, c.BatchCreateSpec); err != nil {
|
||||
return err
|
||||
}
|
||||
// FKs that exist in different tables can't be updated in batch (using the CASE
|
||||
// statement), because we rely on RowsAffected to check if the FK column is NULL.
|
||||
for _, node := range c.Nodes {
|
||||
edges := EdgeSpecs(node.Edges).GroupRel()
|
||||
if err := c.graph.addFKEdges(ctx, []driver.Value{node.ID.Value}, append(edges[O2M], edges[O2O]...)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setTableColumns sets the table columns and foreign_keys used in insert.
|
||||
func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*EdgeSpec) error {
|
||||
err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) {
|
||||
@@ -785,6 +871,18 @@ func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sq
|
||||
return nil
|
||||
}
|
||||
|
||||
// batchInsert inserts a batch of nodes to their table and sets their ID if it wasn't provided by the user.
|
||||
func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
|
||||
ids, err := insertLastIDs(ctx, tx, insert.Returning(c.Nodes[0].ID.Column))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i, node := range c.Nodes {
|
||||
node.ID.Value = ids[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GroupRel groups edges by their relation type.
|
||||
func (es EdgeSpecs) GroupRel() map[Rel][]*EdgeSpec {
|
||||
edges := make(map[Rel][]*EdgeSpec)
|
||||
@@ -803,6 +901,17 @@ func (es EdgeSpecs) GroupTable() map[string][]*EdgeSpec {
|
||||
return edges
|
||||
}
|
||||
|
||||
// FilterRel returns edges for the given relation type.
|
||||
func (es EdgeSpecs) FilterRel(r Rel) EdgeSpecs {
|
||||
edges := make([]*EdgeSpec, 0, len(es))
|
||||
for _, edge := range es {
|
||||
if edge.Rel == r {
|
||||
edges = append(edges, edge)
|
||||
}
|
||||
}
|
||||
return edges
|
||||
}
|
||||
|
||||
// The common operations shared between the different builders.
|
||||
//
|
||||
// M2M edges reside in join tables and require INSERT and DELETE
|
||||
@@ -822,7 +931,7 @@ func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges Edg
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables = edges.GroupTable()
|
||||
)
|
||||
for _, table := range sortedKeys(tables) {
|
||||
for _, table := range edgeKeys(tables) {
|
||||
edges := tables[table]
|
||||
preds := make([]*sql.Predicate, 0, len(edges))
|
||||
for _, edge := range edges {
|
||||
@@ -850,7 +959,7 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables = edges.GroupTable()
|
||||
)
|
||||
for _, table := range sortedKeys(tables) {
|
||||
for _, table := range edgeKeys(tables) {
|
||||
edges := tables[table]
|
||||
insert := g.builder.Insert(table).Columns(edges[0].Columns...)
|
||||
for _, edge := range edges {
|
||||
@@ -873,6 +982,44 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *graph) batchAddM2M(ctx context.Context, spec *BatchCreateSpec) error {
|
||||
tables := make(map[string]*sql.InsertBuilder)
|
||||
for _, node := range spec.Nodes {
|
||||
edges := EdgeSpecs(node.Edges).FilterRel(M2M)
|
||||
for t, edges := range edges.GroupTable() {
|
||||
insert, ok := tables[t]
|
||||
if !ok {
|
||||
insert = g.builder.Insert(t).Columns(edges[0].Columns...)
|
||||
}
|
||||
tables[t] = insert
|
||||
if len(edges) != 1 {
|
||||
return fmt.Errorf("expect exactly 1 edge-spec per table, but got %d", len(edges))
|
||||
}
|
||||
edge := edges[0]
|
||||
pk1, pk2 := []driver.Value{node.ID.Value}, edge.Target.Nodes
|
||||
if edge.Inverse {
|
||||
pk1, pk2 = pk2, pk1
|
||||
}
|
||||
for _, pair := range product(pk1, pk2) {
|
||||
insert.Values(pair[0], pair[1])
|
||||
if edge.Bidi {
|
||||
insert.Values(pair[1], pair[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, table := range insertKeys(tables) {
|
||||
var (
|
||||
res sql.Result
|
||||
query, args = tables[table].Query()
|
||||
)
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("add m2m edge for table %s: %v", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error {
|
||||
for _, edge := range edges {
|
||||
if edge.Rel == O2O && edge.Inverse {
|
||||
@@ -925,6 +1072,8 @@ func (g *graph) addFKEdges(ctx context.Context, ids []driver.Value, edges []*Edg
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Setting the FK value of the "other" table
|
||||
// without clearing it before, is not allowed.
|
||||
if ids := edge.Target.Nodes; int(affected) < len(ids) {
|
||||
return &ConstraintError{msg: fmt.Sprintf("one of %v is already connected to a different %s", ids, edge.Columns[0])}
|
||||
}
|
||||
@@ -979,6 +1128,45 @@ func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *sql.Inser
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// insertLastIDs invokes the batch insert query on the transaction and returns the LastInsertID of all entities.
|
||||
func insertLastIDs(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) (ids []int64, err error) {
|
||||
query, args := insert.Query()
|
||||
// PostgreSQL does not support the LastInsertId() method of sql.Result
|
||||
// on Exec, and should be extracted manually using the `RETURNING` clause.
|
||||
if insert.Dialect() == dialect.Postgres {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return ids, sql.ScanSlice(rows, &ids)
|
||||
}
|
||||
// MySQL, SQLite, etc.
|
||||
var res sql.Result
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = make([]int64, 0, affected)
|
||||
switch insert.Dialect() {
|
||||
case dialect.SQLite:
|
||||
id -= affected - 1
|
||||
fallthrough
|
||||
case dialect.MySQL:
|
||||
for i := int64(0); i < affected; i++ {
|
||||
ids = append(ids, id+i)
|
||||
}
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
|
||||
func rollback(tx dialect.Tx, err error) error {
|
||||
if rerr := tx.Rollback(); rerr != nil {
|
||||
@@ -987,7 +1175,25 @@ func rollback(tx dialect.Tx, err error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func sortedKeys(m map[string][]*EdgeSpec) []string {
|
||||
func edgeKeys(m map[string][]*EdgeSpec) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func insertKeys(m map[string]*sql.InsertBuilder) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func keys(m map[string]struct{}) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
|
||||
@@ -741,6 +741,94 @@ func TestCreateNode(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchCreate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nodes []*CreateSpec
|
||||
expect func(sqlmock.Sqlmock)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
expect: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectBegin()
|
||||
m.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
nodes: []*CreateSpec{
|
||||
{
|
||||
Table: "users",
|
||||
ID: &FieldSpec{Column: "id"},
|
||||
Fields: []*FieldSpec{
|
||||
{Column: "age", Type: field.TypeInt, Value: 32},
|
||||
{Column: "name", Type: field.TypeString, Value: "a8m"},
|
||||
{Column: "active", Type: field.TypeBool, Value: false},
|
||||
},
|
||||
Edges: []*EdgeSpec{
|
||||
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
|
||||
{Rel: M2M, Table: "user_products", Columns: []string{"user_id", "product_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
|
||||
{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
|
||||
{Rel: M2O, Table: "company", Columns: []string{"workplace_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}}},
|
||||
{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
Table: "users",
|
||||
ID: &FieldSpec{Column: "id"},
|
||||
Fields: []*FieldSpec{
|
||||
{Column: "age", Type: field.TypeInt, Value: 30},
|
||||
{Column: "name", Type: field.TypeString, Value: "nati"},
|
||||
},
|
||||
Edges: []*EdgeSpec{
|
||||
{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
|
||||
{Rel: M2M, Table: "user_products", Columns: []string{"user_id", "product_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
|
||||
{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
|
||||
{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{3}, IDSpec: &FieldSpec{Column: "id"}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
expect: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectBegin()
|
||||
// Insert nodes with FKs.
|
||||
m.ExpectExec(escape("INSERT INTO `users` (`active`, `age`, `name`, `workplace_id`) VALUES (?, ?, ?, ?), (?, ?, ?, ?)")).
|
||||
WithArgs(false, 32, "a8m", 2, nil, 30, "nati", nil).
|
||||
WillReturnResult(sqlmock.NewResult(10, 2))
|
||||
// Insert M2M inverse-edges.
|
||||
m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")).
|
||||
WithArgs(2, 10, 2, 11).
|
||||
WillReturnResult(sqlmock.NewResult(2, 2))
|
||||
// Insert M2M bidirectional edges.
|
||||
m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")).
|
||||
WithArgs(10, 2, 2, 10, 11, 2, 2, 11).
|
||||
WillReturnResult(sqlmock.NewResult(2, 2))
|
||||
// Insert M2M edges.
|
||||
m.ExpectExec(escape("INSERT INTO `user_products` (`user_id`, `product_id`) VALUES (?, ?), (?, ?)")).
|
||||
WithArgs(10, 2, 11, 2).
|
||||
WillReturnResult(sqlmock.NewResult(2, 2))
|
||||
// Update FKs exist in different tables.
|
||||
m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE (`id` = ?) AND (`owner_id` IS NULL")).
|
||||
WithArgs(10 /* id of the 1st new node */, 2 /* pet id */).
|
||||
WillReturnResult(sqlmock.NewResult(2, 2))
|
||||
m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE (`id` = ?) AND (`owner_id` IS NULL")).
|
||||
WithArgs(11 /* id of the 2nd new node */, 3 /* pet id */).
|
||||
WillReturnResult(sqlmock.NewResult(2, 2))
|
||||
m.ExpectCommit()
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
tt.expect(mock)
|
||||
err = BatchCreate(context.Background(), sql.OpenDB("mysql", db), &BatchCreateSpec{Nodes: tt.nodes})
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type user struct {
|
||||
id int
|
||||
age int
|
||||
|
||||
Reference in New Issue
Block a user