dialect/sqlgraph: initial work for create-node api (#211)

Summary:
Pull Request resolved: https://github.com/facebookincubator/ent/pull/211

Move out the logic from the Go templates to Go code.

Next diff will add the edges of the node.

Reviewed By: alexsn

Differential Revision: D18762049

fbshipit-source-id: c9a93672415a26a6f4a7d466e569b8b0e8b0f9ee
This commit is contained in:
Ariel Mashraki
2019-12-02 08:10:28 -08:00
committed by Facebook Github Bot
parent 0f4fc12cc5
commit a4fac2db3b
2 changed files with 203 additions and 0 deletions

View File

@@ -4,6 +4,17 @@
package sql
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"github.com/facebookincubator/ent/dialect"
"github.com/facebookincubator/ent/schema/field"
)
// Rel is a relation type of an edge.
type Rel int
@@ -262,3 +273,109 @@ func HasNeighborsWith(q *Selector, s *Step, pred func(*Selector)) {
q.Where(In(from.C(s.From.Column), matches))
}
}
type (
// FieldSpec holds the information for updating a field
// column in the database.
FieldSpec struct {
Column string
Type field.Type
Value driver.Value // value to be stored.
}
// EdgeSpec holds the information for updating a field
// column in the database.
EdgeSpec struct {
Rel Rel
Table string
Columns []string
Inverse bool
Value driver.Value
}
// CreateSpec holds the information for creating a node
// in the graph.
CreateSpec struct {
// Type or table name.
Table string
// ID field.
ID *FieldSpec
// Fields.
Fields []*FieldSpec
// Edges.
Edges []*EdgeSpec
}
)
// CreateNode applies the spec on the graph.
func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error {
tx, err := drv.Tx(ctx)
if err != nil {
return err
}
insert := Dialect(drv.Dialect()).Insert(spec.Table).Default()
for _, fi := range spec.Fields {
value := fi.Value
if fi.Type == field.TypeJSON {
if value, err = json.Marshal(value); err != nil {
return err
}
}
insert.Set(fi.Column, value)
}
// ID was provided by the user.
if spec.ID.Value != nil {
insert.Set(spec.ID.Column, spec.ID.Value)
query, args := insert.Query()
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
return rollback(tx, err)
}
} else {
id, err := insertLastID(ctx, tx, insert.Returning(spec.ID.Column))
if err != nil {
return rollback(tx, err)
}
spec.ID.Value = id
}
return tx.Commit()
}
// insertLastID invokes the insert query on the transaction and returns the LastInsertID.
func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBuilder) (int64, error) {
query, args := insert.Query()
// PostgreSQL does not support the LastInsertId() method of sql.Result
// on Exec, and should be extracted manually using the `RETURNING` clause.
if insert.Dialect() == dialect.Postgres {
rows := &sql.Rows{}
if err := tx.Query(ctx, query, args, rows); err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, fmt.Errorf("no rows found for query: %v", query)
}
var id int64
if err := rows.Scan(&id); err != nil {
return 0, err
}
return id, nil
}
// MySQL, SQLite, etc.
var res sql.Result
if err := tx.Exec(ctx, query, args, &res); err != nil {
return 0, err
}
id, err := res.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
}
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
func rollback(tx dialect.Tx, err error) error {
if rerr := tx.Rollback(); rerr != nil {
err = fmt.Errorf("%s: %v", err.Error(), rerr)
}
return err
}

View File

@@ -5,9 +5,14 @@
package sql
import (
"context"
"regexp"
"strings"
"testing"
"github.com/facebookincubator/ent/schema/field"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
@@ -431,3 +436,84 @@ WHERE "groups"."id" IN
})
}
}
func TestCreateNode(t *testing.T) {
tests := []struct {
name string
spec *CreateSpec
expect func(sqlmock.Sqlmock)
wantErr bool
}{
{
name: "fields",
spec: &CreateSpec{
Table: "users",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "age", Type: field.TypeInt, Value: 30},
{Column: "name", Type: field.TypeString, Value: "a8m"},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `users` (`age`, `name`) VALUES (?, ?)")).
WithArgs(30, "a8m").
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "fields/user-defined-id",
spec: &CreateSpec{
Table: "users",
ID: &FieldSpec{Column: "id", Value: 1},
Fields: []*FieldSpec{
{Column: "age", Type: field.TypeInt, Value: 30},
{Column: "name", Type: field.TypeString, Value: "a8m"},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `users` (`age`, `name`, `id`) VALUES (?, ?, ?)")).
WithArgs(30, "a8m", 1).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "fields/json",
spec: &CreateSpec{
Table: "users",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "json", Type: field.TypeJSON, Value: struct{}{}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `users` (`json`) VALUES (?)")).
WithArgs([]byte("{}")).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
tt.expect(mock)
err = CreateNode(context.Background(), OpenDB("", db), tt.spec)
require.Equal(t, tt.wantErr, err != nil, err)
})
}
}
func escape(query string) string {
rows := strings.Split(query, "\n")
for i := range rows {
rows[i] = strings.TrimPrefix(rows[i], " ")
}
query = strings.Join(rows, " ")
return regexp.QuoteMeta(query)
}