mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
dialect/sqlgraph: add edges in node creation (#216)
Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/216 WIP - ignore for now Reviewed By: alexsn Differential Revision: D18795361 fbshipit-source-id: d3a4ef5562be5faf0837cad6364130ec203a9d37
This commit is contained in:
committed by
Facebook Github Bot
parent
a5e4a9cf54
commit
bb051603ac
@@ -283,61 +283,173 @@ type (
|
||||
Value driver.Value // value to be stored.
|
||||
}
|
||||
|
||||
// EdgeTarget holds the information for the target nodes
|
||||
// of an edge.
|
||||
EdgeTarget struct {
|
||||
Nodes []driver.Value
|
||||
IDSpec *FieldSpec
|
||||
}
|
||||
|
||||
// EdgeSpec holds the information for updating a field
|
||||
// column in the database.
|
||||
EdgeSpec struct {
|
||||
Rel Rel
|
||||
Inverse bool
|
||||
Table string
|
||||
Columns []string
|
||||
Inverse bool
|
||||
Value driver.Value
|
||||
Bidi bool // bidirectional edge.
|
||||
Target *EdgeTarget // target nodes.
|
||||
}
|
||||
|
||||
// EdgeSpecs used for perform common operations on list of edges.
|
||||
EdgeSpecs []*EdgeSpec
|
||||
|
||||
// CreateSpec holds the information for creating a node
|
||||
// in the graph.
|
||||
CreateSpec struct {
|
||||
// Type or table name.
|
||||
Table string
|
||||
// ID field.
|
||||
ID *FieldSpec
|
||||
// Fields.
|
||||
Table string
|
||||
ID *FieldSpec
|
||||
Fields []*FieldSpec
|
||||
// Edges.
|
||||
Edges []*EdgeSpec
|
||||
Edges []*EdgeSpec
|
||||
}
|
||||
)
|
||||
|
||||
// CreateNode applies the spec on the graph.
|
||||
// CreateNode applies the CreateSpec on the graph.
|
||||
func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error {
|
||||
tx, err := drv.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
insert := Dialect(drv.Dialect()).Insert(spec.Table).Default()
|
||||
for _, fi := range spec.Fields {
|
||||
cr := &creator{CreateSpec: spec, builder: Dialect(drv.Dialect())}
|
||||
if err := cr.node(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
type creator struct {
|
||||
*CreateSpec
|
||||
builder *dialectBuilder
|
||||
}
|
||||
|
||||
func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
var (
|
||||
res sql.Result
|
||||
edges = EdgeSpecs(c.Edges).GroupRel()
|
||||
insert = c.builder.Insert(c.Table).Default()
|
||||
)
|
||||
// Set and create the node.
|
||||
if err := c.setTableColumns(insert, edges); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.insert(ctx, tx, insert); err != nil {
|
||||
return fmt.Errorf("insert node to table %s: %v", c.Table, err)
|
||||
}
|
||||
// Insert all M2M edges from the same type at once.
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables := EdgeSpecs(edges[M2M]).GroupTable()
|
||||
for table, edges := range tables {
|
||||
edge := edges[0]
|
||||
insert = c.builder.Insert(table).Columns(edge.Columns...)
|
||||
for _, edge := range edges {
|
||||
pk1, pk2 := c.ID.Value, edge.Target.Nodes[0]
|
||||
if edge.Inverse {
|
||||
pk1, pk2 = pk2, pk1
|
||||
}
|
||||
insert.Values(pk1, pk2)
|
||||
if edge.Bidi {
|
||||
insert.Values(pk2, pk1)
|
||||
}
|
||||
}
|
||||
query, args := insert.Query()
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("add m2m edge for table %s: %v", table, err)
|
||||
}
|
||||
}
|
||||
// O2M and non-inverse O2O edges also reside in external tables.
|
||||
for _, edge := range append(edges[O2M], edges[O2O]...) {
|
||||
if edge.Rel == O2O && edge.Inverse {
|
||||
continue
|
||||
}
|
||||
p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0])
|
||||
// Use "IN" predicate instead of list of "OR"
|
||||
// in case of more than on nodes to connect.
|
||||
if len(edge.Target.Nodes) > 1 {
|
||||
p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...)
|
||||
}
|
||||
query, args := c.builder.Update(edge.Table).
|
||||
Set(edge.Columns[0], c.ID.Value).
|
||||
Where(And(p, IsNull(edge.Columns[0]))).
|
||||
Query()
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("add m2m edge for table %s: %v", edge.Table, err)
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ids := edge.Target.Nodes; int(affected) < len(ids) {
|
||||
return fmt.Errorf("one of %v is already connected to a different %s", ids, edge.Columns[0])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setTableColumns sets the table columns and foreign_keys used in insert.
|
||||
func (c *creator) setTableColumns(insert *InsertBuilder, edges map[Rel][]*EdgeSpec) (err error) {
|
||||
for _, fi := range c.Fields {
|
||||
value := fi.Value
|
||||
if fi.Type == field.TypeJSON {
|
||||
if value, err = json.Marshal(value); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("marshal value for column %s: %v", fi.Column, err)
|
||||
}
|
||||
}
|
||||
insert.Set(fi.Column, value)
|
||||
}
|
||||
// ID was provided by the user.
|
||||
if spec.ID.Value != nil {
|
||||
insert.Set(spec.ID.Column, spec.ID.Value)
|
||||
query, args := insert.Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
} else {
|
||||
id, err := insertLastID(ctx, tx, insert.Returning(spec.ID.Column))
|
||||
if err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
spec.ID.Value = id
|
||||
for _, e := range edges[M2O] {
|
||||
insert.Set(e.Columns[0], e.Target.Nodes[0])
|
||||
}
|
||||
return tx.Commit()
|
||||
for _, e := range edges[O2O] {
|
||||
if e.Inverse || e.Bidi {
|
||||
insert.Set(e.Columns[0], e.Target.Nodes[0])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// insert inserts the node to its table and sets its ID if it wasn't provided by the user.
|
||||
func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBuilder) error {
|
||||
var res sql.Result
|
||||
// If the id field was provided by the user.
|
||||
if c.ID.Value != nil {
|
||||
insert.Set(c.ID.Column, c.ID.Value)
|
||||
query, args := insert.Query()
|
||||
return tx.Exec(ctx, query, args, &res)
|
||||
}
|
||||
id, err := insertLastID(ctx, tx, insert.Returning(c.ID.Column))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ID.Value = id
|
||||
return nil
|
||||
}
|
||||
|
||||
// GroupRel groups edges by their relation type.
|
||||
func (es EdgeSpecs) GroupRel() map[Rel][]*EdgeSpec {
|
||||
edges := make(map[Rel][]*EdgeSpec)
|
||||
for _, edge := range es {
|
||||
edges[edge.Rel] = append(edges[edge.Rel], edge)
|
||||
}
|
||||
return edges
|
||||
}
|
||||
|
||||
// GroupTable groups edges by their table name.
|
||||
func (es EdgeSpecs) GroupTable() map[string][]*EdgeSpec {
|
||||
edges := make(map[string][]*EdgeSpec)
|
||||
for _, edge := range es {
|
||||
edges[edge.Table] = append(edges[edge.Table], edge)
|
||||
}
|
||||
return edges
|
||||
}
|
||||
|
||||
// insertLastID invokes the insert query on the transaction and returns the LastInsertID.
|
||||
|
||||
Reference in New Issue
Block a user