From 4ab244cf934a30c8272b388921ba5cdb12385cb1 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Thu, 12 Dec 2019 12:56:31 +0200 Subject: [PATCH] dialect/sql/sqlgraph: first work on nodes querying --- dialect/sql/graph.go | 59 +++++++++++++++++++++++++++++++++++++++ dialect/sql/graph_test.go | 55 ++++++++++++++++++++++++++++++++++-- 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/dialect/sql/graph.go b/dialect/sql/graph.go index a6ea35ce3..af4d0f39d 100644 --- a/dialect/sql/graph.go +++ b/dialect/sql/graph.go @@ -10,6 +10,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "math" "sort" "github.com/facebookincubator/ent/dialect" @@ -426,6 +427,64 @@ func DeleteNodes(ctx context.Context, drv dialect.Driver, spec *DeleteSpec) (int return int(affected), tx.Commit() } +// QuerySpec holds the information for querying +// nodes in the graph. +type QuerySpec struct { + Node *NodeSpec // Nodes info. + From *Selector // Optional query source (from path). + + Limit int + Offset int + Unique bool + Order func(*Selector) + Predicate func(*Selector) + + ScanValues func() []interface{} + Assign func(...interface{}) error +} + +// QueryNodes query the nodes in the graph and scan them to the given values. +func QueryNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) error { + builder := Dialect(drv.Dialect()) + selector := builder.Select().From(builder.Table(spec.Node.Table)) + if spec.From != nil { + selector = spec.From + } + selector.Select(spec.Node.Columns...) + if pred := spec.Predicate; pred != nil { + pred(selector) + } + if order := spec.Order; order != nil { + order(selector) + } + if spec.Offset != 0 { + // Limit is mandatory for the offset clause. We start + // with default value, and override it below if needed. + selector.Offset(spec.Offset).Limit(math.MaxInt32) + } + if spec.Limit != 0 { + selector.Limit(spec.Limit) + } + if spec.Unique { + selector.Distinct() + } + rows := &Rows{} + query, args := selector.Query() + if err := drv.Query(ctx, query, args, rows); err != nil { + return err + } + for rows.Next() { + values := spec.ScanValues() + if err := rows.Scan(values...); err != nil { + return err + } + if err := spec.Assign(values...); err != nil { + return err + } + } + return nil +} + type updater struct { graph *UpdateSpec diff --git a/dialect/sql/graph_test.go b/dialect/sql/graph_test.go index d6ecdb3d2..4859b65e2 100644 --- a/dialect/sql/graph_test.go +++ b/dialect/sql/graph_test.go @@ -741,9 +741,13 @@ func TestCreateNode(t *testing.T) { } type user struct { - id int - age int - name string + id int + age int + name string + edges struct { + fk1 int + fk2 int + } } func (*user) values() []interface{} { @@ -754,6 +758,11 @@ func (u *user) assign(values ...interface{}) error { u.id = int(values[0].(*NullInt64).Int64) u.age = int(values[1].(*NullInt64).Int64) u.name = values[2].(*NullString).String + // loaded with foreign-keys. + if len(values) > 3 { + u.edges.fk1 = int(values[3].(*NullInt64).Int64) + u.edges.fk2 = int(values[4].(*NullInt64).Int64) + } return nil } @@ -1141,6 +1150,46 @@ func TestDeleteNodes(t *testing.T) { require.Equal(t, 2, affected) } +func TestQueryNodes(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + mock.ExpectQuery(escape("SELECT DISTINCT `id`, `age`, `name`, `fk1`, `fk2` FROM `users` WHERE `age` < ? ORDER BY `id` LIMIT ? OFFSET ?")). + WithArgs(40, 3, 4). + WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name", "fk1", "fk2"}). + AddRow(1, 10, nil, nil, nil). + AddRow(2, 20, "", 0, 0). + AddRow(3, 30, "a8m", 1, 1)) + var users []*user + err = QueryNodes(context.Background(), OpenDB("", db), &QuerySpec{ + Node: &NodeSpec{ + Table: "users", + Columns: []string{"id", "age", "name", "fk1", "fk2"}, + ID: &FieldSpec{Column: "id", Type: field.TypeInt}, + }, + Limit: 3, + Offset: 4, + Unique: true, + Order: func(s *Selector) { + s.OrderBy("id") + }, + Predicate: func(s *Selector) { + s.Where(LT("age", 40)) + }, + ScanValues: func() []interface{} { + u := &user{} + users = append(users, u) + return append(u.values(), &NullInt64{}, &NullInt64{}) // extra values for fks. + }, + Assign: func(values ...interface{}) error { + return users[len(users)-1].assign(values...) + }, + }) + require.NoError(t, err) + require.Equal(t, &user{id: 1, age: 10, name: ""}, users[0]) + require.Equal(t, &user{id: 2, age: 20, name: ""}, users[1]) + require.Equal(t, &user{id: 3, age: 30, name: "a8m", edges: struct{ fk1, fk2 int }{1, 1}}, users[2]) +} + func escape(query string) string { rows := strings.Split(query, "\n") for i := range rows {