dialect/sql/sqlgraph: allow setting stmt modifiers for create-one

This commit is contained in:
Ariel Mashraki
2021-07-29 15:07:06 +03:00
committed by Ariel Mashraki
parent 19b418d1a0
commit b9fcbff724
2 changed files with 40 additions and 1 deletions

View File

@@ -344,7 +344,20 @@ type (
ID *FieldSpec
Fields []*FieldSpec
Edges []*EdgeSpec
// The Modifiers option allows providing custom functions to
// modify the INSERT statement of the node before it is executed.
//
// sqlgraph.CreateSpec{
// Modifiers: []func(*sql.InsertBuilder) {
// OnConflictOptions(...),
// ReturningOptions(...),
// },
// }
//
Modifiers []func(*sql.InsertBuilder)
}
// BatchCreateSpec holds the information for creating
// multiple nodes in the graph.
BatchCreateSpec struct {
@@ -958,6 +971,9 @@ func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sq
query, args := insert.Query()
return tx.Exec(ctx, query, args, &res)
}
for _, m := range c.CreateSpec.Modifiers {
m(insert)
}
id, err := insertLastID(ctx, tx, insert.Returning(c.ID.Column))
if err != nil {
return err

View File

@@ -858,6 +858,29 @@ func TestCreateNode(t *testing.T) {
m.ExpectCommit()
},
},
{
name: "modifiers",
spec: &CreateSpec{
Table: "users",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "age", Type: field.TypeInt, Value: 30},
{Column: "name", Type: field.TypeString, Value: "a8m"},
},
Modifiers: []func(i *sql.InsertBuilder){
func(i *sql.InsertBuilder) {
i.OnConflict(sql.ResolveWithNewValues())
},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `users` (`age`, `name`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `age` = VALUES(`age`), `name` = VALUES(`name`)")).
WithArgs(30, "a8m").
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "fields/user-defined-id",
spec: &CreateSpec{
@@ -1148,7 +1171,7 @@ func TestCreateNode(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
tt.expect(mock)
err = CreateNode(context.Background(), sql.OpenDB("", db), tt.spec)
err = CreateNode(context.Background(), sql.OpenDB(dialect.MySQL, db), tt.spec)
require.Equal(t, tt.wantErr, err != nil, err)
})
}