mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sqlgraph: initial work for create-node api (#211)
Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/211 Move out the logic from the Go templates to Go code. Next diff will add the edges of the node. Reviewed By: alexsn Differential Revision: D18762049 fbshipit-source-id: c9a93672415a26a6f4a7d466e569b8b0e8b0f9ee
This commit is contained in:
committed by
Facebook Github Bot
parent
0f4fc12cc5
commit
a4fac2db3b
@@ -4,6 +4,17 @@
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/facebookincubator/ent/dialect"
|
||||
"github.com/facebookincubator/ent/schema/field"
|
||||
)
|
||||
|
||||
// Rel is a relation type of an edge.
|
||||
type Rel int
|
||||
|
||||
@@ -262,3 +273,109 @@ func HasNeighborsWith(q *Selector, s *Step, pred func(*Selector)) {
|
||||
q.Where(In(from.C(s.From.Column), matches))
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
// FieldSpec holds the information for updating a field
|
||||
// column in the database.
|
||||
FieldSpec struct {
|
||||
Column string
|
||||
Type field.Type
|
||||
Value driver.Value // value to be stored.
|
||||
}
|
||||
|
||||
// EdgeSpec holds the information for updating a field
|
||||
// column in the database.
|
||||
EdgeSpec struct {
|
||||
Rel Rel
|
||||
Table string
|
||||
Columns []string
|
||||
Inverse bool
|
||||
Value driver.Value
|
||||
}
|
||||
|
||||
// CreateSpec holds the information for creating a node
|
||||
// in the graph.
|
||||
CreateSpec struct {
|
||||
// Type or table name.
|
||||
Table string
|
||||
// ID field.
|
||||
ID *FieldSpec
|
||||
// Fields.
|
||||
Fields []*FieldSpec
|
||||
// Edges.
|
||||
Edges []*EdgeSpec
|
||||
}
|
||||
)
|
||||
|
||||
// CreateNode applies the spec on the graph.
|
||||
func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error {
|
||||
tx, err := drv.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
insert := Dialect(drv.Dialect()).Insert(spec.Table).Default()
|
||||
for _, fi := range spec.Fields {
|
||||
value := fi.Value
|
||||
if fi.Type == field.TypeJSON {
|
||||
if value, err = json.Marshal(value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
insert.Set(fi.Column, value)
|
||||
}
|
||||
// ID was provided by the user.
|
||||
if spec.ID.Value != nil {
|
||||
insert.Set(spec.ID.Column, spec.ID.Value)
|
||||
query, args := insert.Query()
|
||||
if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
} else {
|
||||
id, err := insertLastID(ctx, tx, insert.Returning(spec.ID.Column))
|
||||
if err != nil {
|
||||
return rollback(tx, err)
|
||||
}
|
||||
spec.ID.Value = id
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// insertLastID invokes the insert query on the transaction and returns the LastInsertID.
|
||||
func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBuilder) (int64, error) {
|
||||
query, args := insert.Query()
|
||||
// PostgreSQL does not support the LastInsertId() method of sql.Result
|
||||
// on Exec, and should be extracted manually using the `RETURNING` clause.
|
||||
if insert.Dialect() == dialect.Postgres {
|
||||
rows := &sql.Rows{}
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, fmt.Errorf("no rows found for query: %v", query)
|
||||
}
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
// MySQL, SQLite, etc.
|
||||
var res sql.Result
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred.
|
||||
func rollback(tx dialect.Tx, err error) error {
|
||||
if rerr := tx.Rollback(); rerr != nil {
|
||||
err = fmt.Errorf("%s: %v", err.Error(), rerr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -5,9 +5,14 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/facebookincubator/ent/schema/field"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -431,3 +436,84 @@ WHERE "groups"."id" IN
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateNode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spec *CreateSpec
|
||||
expect func(sqlmock.Sqlmock)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "fields",
|
||||
spec: &CreateSpec{
|
||||
Table: "users",
|
||||
ID: &FieldSpec{Column: "id"},
|
||||
Fields: []*FieldSpec{
|
||||
{Column: "age", Type: field.TypeInt, Value: 30},
|
||||
{Column: "name", Type: field.TypeString, Value: "a8m"},
|
||||
},
|
||||
},
|
||||
expect: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectBegin()
|
||||
m.ExpectExec(escape("INSERT INTO `users` (`age`, `name`) VALUES (?, ?)")).
|
||||
WithArgs(30, "a8m").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fields/user-defined-id",
|
||||
spec: &CreateSpec{
|
||||
Table: "users",
|
||||
ID: &FieldSpec{Column: "id", Value: 1},
|
||||
Fields: []*FieldSpec{
|
||||
{Column: "age", Type: field.TypeInt, Value: 30},
|
||||
{Column: "name", Type: field.TypeString, Value: "a8m"},
|
||||
},
|
||||
},
|
||||
expect: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectBegin()
|
||||
m.ExpectExec(escape("INSERT INTO `users` (`age`, `name`, `id`) VALUES (?, ?, ?)")).
|
||||
WithArgs(30, "a8m", 1).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.ExpectCommit()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fields/json",
|
||||
spec: &CreateSpec{
|
||||
Table: "users",
|
||||
ID: &FieldSpec{Column: "id"},
|
||||
Fields: []*FieldSpec{
|
||||
{Column: "json", Type: field.TypeJSON, Value: struct{}{}},
|
||||
},
|
||||
},
|
||||
expect: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectBegin()
|
||||
m.ExpectExec(escape("INSERT INTO `users` (`json`) VALUES (?)")).
|
||||
WithArgs([]byte("{}")).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.ExpectCommit()
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
tt.expect(mock)
|
||||
err = CreateNode(context.Background(), OpenDB("", db), tt.spec)
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func escape(query string) string {
|
||||
rows := strings.Split(query, "\n")
|
||||
for i := range rows {
|
||||
rows[i] = strings.TrimPrefix(rows[i], " ")
|
||||
}
|
||||
query = strings.Join(rows, " ")
|
||||
return regexp.QuoteMeta(query)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user