diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index 7dc2df2c2..31b2e86f3 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -4,6 +4,17 @@ package sql +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" + + "github.com/facebookincubator/ent/dialect" + "github.com/facebookincubator/ent/schema/field" +) + // Rel is a relation type of an edge. type Rel int @@ -262,3 +273,109 @@ func HasNeighborsWith(q *Selector, s *Step, pred func(*Selector)) { q.Where(In(from.C(s.From.Column), matches)) } } + +type ( + // FieldSpec holds the information for updating a field + // column in the database. + FieldSpec struct { + Column string + Type field.Type + Value driver.Value // value to be stored. + } + + // EdgeSpec holds the information for updating a field + // column in the database. + EdgeSpec struct { + Rel Rel + Table string + Columns []string + Inverse bool + Value driver.Value + } + + // CreateSpec holds the information for creating a node + // in the graph. + CreateSpec struct { + // Type or table name. + Table string + // ID field. + ID *FieldSpec + // Fields. + Fields []*FieldSpec + // Edges. + Edges []*EdgeSpec + } +) + +// CreateNode applies the spec on the graph. +func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error { + tx, err := drv.Tx(ctx) + if err != nil { + return err + } + insert := Dialect(drv.Dialect()).Insert(spec.Table).Default() + for _, fi := range spec.Fields { + value := fi.Value + if fi.Type == field.TypeJSON { + if value, err = json.Marshal(value); err != nil { + return err + } + } + insert.Set(fi.Column, value) + } + // ID was provided by the user. + if spec.ID.Value != nil { + insert.Set(spec.ID.Column, spec.ID.Value) + query, args := insert.Query() + if err := tx.Exec(ctx, query, args, new(sql.Result)); err != nil { + return rollback(tx, err) + } + } else { + id, err := insertLastID(ctx, tx, insert.Returning(spec.ID.Column)) + if err != nil { + return rollback(tx, err) + } + spec.ID.Value = id + } + return tx.Commit() +} + +// insertLastID invokes the insert query on the transaction and returns the LastInsertID. +func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBuilder) (int64, error) { + query, args := insert.Query() + // PostgreSQL does not support the LastInsertId() method of sql.Result + // on Exec, and should be extracted manually using the `RETURNING` clause. + if insert.Dialect() == dialect.Postgres { + rows := &sql.Rows{} + if err := tx.Query(ctx, query, args, rows); err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, fmt.Errorf("no rows found for query: %v", query) + } + var id int64 + if err := rows.Scan(&id); err != nil { + return 0, err + } + return id, nil + } + // MySQL, SQLite, etc. + var res sql.Result + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, err + } + id, err := res.LastInsertId() + if err != nil { + return 0, err + } + return id, nil +} + +// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred. +func rollback(tx dialect.Tx, err error) error { + if rerr := tx.Rollback(); rerr != nil { + err = fmt.Errorf("%s: %v", err.Error(), rerr) + } + return err +} diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index f3c619129..aa4895c5b 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -5,9 +5,14 @@ package sql import ( + "context" + "regexp" "strings" "testing" + "github.com/facebookincubator/ent/schema/field" + + "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" ) @@ -431,3 +436,84 @@ WHERE "groups"."id" IN }) } } + +func TestCreateNode(t *testing.T) { + tests := []struct { + name string + spec *CreateSpec + expect func(sqlmock.Sqlmock) + wantErr bool + }{ + { + name: "fields", + spec: &CreateSpec{ + Table: "users", + ID: &FieldSpec{Column: "id"}, + Fields: []*FieldSpec{ + {Column: "age", Type: field.TypeInt, Value: 30}, + {Column: "name", Type: field.TypeString, Value: "a8m"}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `users` (`age`, `name`) VALUES (?, ?)")). + WithArgs(30, "a8m"). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectCommit() + }, + }, + { + name: "fields/user-defined-id", + spec: &CreateSpec{ + Table: "users", + ID: &FieldSpec{Column: "id", Value: 1}, + Fields: []*FieldSpec{ + {Column: "age", Type: field.TypeInt, Value: 30}, + {Column: "name", Type: field.TypeString, Value: "a8m"}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `users` (`age`, `name`, `id`) VALUES (?, ?, ?)")). + WithArgs(30, "a8m", 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectCommit() + }, + }, + { + name: "fields/json", + spec: &CreateSpec{ + Table: "users", + ID: &FieldSpec{Column: "id"}, + Fields: []*FieldSpec{ + {Column: "json", Type: field.TypeJSON, Value: struct{}{}}, + }, + }, + expect: func(m sqlmock.Sqlmock) { + m.ExpectBegin() + m.ExpectExec(escape("INSERT INTO `users` (`json`) VALUES (?)")). + WithArgs([]byte("{}")). + WillReturnResult(sqlmock.NewResult(1, 1)) + m.ExpectCommit() + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + tt.expect(mock) + err = CreateNode(context.Background(), OpenDB("", db), tt.spec) + require.Equal(t, tt.wantErr, err != nil, err) + }) + } +} + +func escape(query string) string { + rows := strings.Split(query, "\n") + for i := range rows { + rows[i] = strings.TrimPrefix(rows[i], " ") + } + query = strings.Join(rows, " ") + return regexp.QuoteMeta(query) +}