dialect/sql/sqlgraph: add support for edge-schema in create-bulk

This commit is contained in:
Ariel Mashraki
2022-11-20 16:10:32 +02:00
committed by Ariel Mashraki
parent aa3d21f01a
commit f46dd7ace8
3 changed files with 74 additions and 17 deletions

View File

@@ -332,6 +332,15 @@ type (
}
)
// FieldValues returns the values of additional fields that were set on the join-table.
func (e *EdgeTarget) FieldValues() []any {
vs := make([]any, len(e.Fields))
for i, f := range e.Fields {
vs[i] = f.Value
}
return vs
}
type (
// CreateSpec holds the information for creating
// a node in the graph.
@@ -372,7 +381,7 @@ type (
}
)
// SetField appends a new field setter to the create spec.
// SetField appends a new field setter to the creation spec.
func (u *CreateSpec) SetField(column string, t field.Type, value driver.Value) {
u.Fields = append(u.Fields, &FieldSpec{
Column: column,
@@ -1328,29 +1337,39 @@ func (g *graph) batchAddM2M(ctx context.Context, spec *BatchCreateSpec) error {
tables := make(map[string]*sql.InsertBuilder)
for _, node := range spec.Nodes {
edges := EdgeSpecs(node.Edges).FilterRel(M2M)
for t, edges := range edges.GroupTable() {
insert, ok := tables[t]
if !ok {
insert = g.builder.Insert(t).Columns(edges[0].Columns...)
if edges[0].Schema != "" {
// If the Schema field was provided to the EdgeSpec (by the
// generated code), it should be the same for all EdgeSpecs.
insert.Schema(edges[0].Schema)
}
}
tables[t] = insert
for name, edges := range edges.GroupTable() {
if len(edges) != 1 {
return fmt.Errorf("expect exactly 1 edge-spec per table, but got %d", len(edges))
}
edge := edges[0]
insert, ok := tables[name]
if !ok {
columns := edge.Columns
// Additional fields, such as edge-schema fields.
for _, f := range edge.Target.Fields {
columns = append(columns, f.Column)
}
insert = g.builder.Insert(name).Columns(columns...)
if edge.Schema != "" {
// If the Schema field was provided to the EdgeSpec (by the
// generated code), it should be the same for all EdgeSpecs.
insert.Schema(edge.Schema)
}
// Ignore conflicts only if edges do not contain extra fields, because these fields
// can hold different values on different insertions (e.g. time.Now() or uuid.New()).
if len(edge.Target.Fields) == 0 {
insert.OnConflict(sql.DoNothing())
}
}
tables[name] = insert
pk1, pk2 := []driver.Value{node.ID.Value}, edge.Target.Nodes
if edge.Inverse {
pk1, pk2 = pk2, pk1
}
for _, pair := range product(pk1, pk2) {
insert.Values(pair[0], pair[1])
insert.Values(append([]any{pair[0], pair[1]}, edge.Target.FieldValues()...)...)
if edge.Bidi {
insert.Values(pair[1], pair[0])
insert.Values(append([]any{pair[1], pair[0]}, edge.Target.FieldValues()...)...)
}
}
}