dialect/sqlgraph: initial work for create-node api (#211)

Summary:
Pull Request resolved: https://github.com/facebookincubator/ent/pull/211

Move out the logic from the Go templates to Go code.

Next diff will add the edges of the node.

Reviewed By: alexsn

Differential Revision: D18762049

fbshipit-source-id: c9a93672415a26a6f4a7d466e569b8b0e8b0f9ee
This commit is contained in:
Ariel Mashraki
2019-12-02 08:10:28 -08:00
committed by Facebook Github Bot
parent 0f4fc12cc5
commit a4fac2db3b
2 changed files with 203 additions and 0 deletions

View File

@@ -4,6 +4,17 @@
package sql
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"github.com/facebookincubator/ent/dialect"
"github.com/facebookincubator/ent/schema/field"
)
// Rel is a relation type of an edge.
type Rel int
@@ -262,3 +273,109 @@ func HasNeighborsWith(q *Selector, s *Step, pred func(*Selector)) {
q.Where(In(from.C(s.From.Column), matches))
}
}
type (
// FieldSpec holds the information for updating a field
// column in the database.
FieldSpec struct {
Column string
Type field.Type
Value driver.Value // value to be stored.
}
// EdgeSpec holds the information for updating a field
// column in the database.
EdgeSpec struct {
Rel Rel
Table string
Columns []string
Inverse bool
Value driver.Value
}
// CreateSpec holds the information for creating a node
// in the graph.
CreateSpec struct {
// Type or table name.
Table string
// ID field.
ID *FieldSpec
// Fields.
Fields []*FieldSpec
// Edges.
Edges []*EdgeSpec
}
)
// CreateNode applies the spec on the graph.
func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error {
tx, err := drv.Tx(ctx)
if err != nil {
return err
}
insert := Dialect(drv.Dialect()).Insert(spec.Table).Default()
for _, fi := range spec.Fields {
value := fi.Value
if fi.Type == field.TypeJSON {
if value, err = json.Marshal(value); err != nil {
return err
}
}
insert.Set(fi.Column, value)
}
// ID was provided by the user.
if spec.ID.Value != nil {
insert.Set(spec.ID.Column, spec.ID.Value)
query, args := insert.Query()
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
return rollback(tx, err)
}
} else {
id, err := insertLastID(ctx, tx, insert.Returning(spec.ID.Column))
if err != nil {
return rollback(tx, err)
}
spec.ID.Value = id
}
return tx.Commit()
}
// insertLastID invokes the insert query on the transaction and returns the LastInsertID.
func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBuilder) (int64, error) {
query, args := insert.Query()
// PostgreSQL does not support the LastInsertId() method of sql.Result
// on Exec, and should be extracted manually using the `RETURNING` clause.
if insert.Dialect() == dialect.Postgres {
rows := &sql.Rows{}
if err := tx.Query(ctx, query, args, rows); err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, fmt.Errorf("no rows found for query: %v", query)
}
var id int64
if err := rows.Scan(&id); err != nil {
return 0, err
}
return id, nil
}
// MySQL, SQLite, etc.
var res sql.Result
if err := tx.Exec(ctx, query, args, &res); err != nil {
return 0, err
}
id, err := res.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
}
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
func rollback(tx dialect.Tx, err error) error {
if rerr := tx.Rollback(); rerr != nil {
err = fmt.Errorf("%s: %v", err.Error(), rerr)
}
return err
}