dialect/sql/sqlgraph: schema options for operations (#1136)

* Add schema options for sql graph

* PR Fixes
This commit is contained in:
Marwan Sulaiman
2021-01-12 02:29:01 -05:00
committed by GitHub
parent 5fc907451f
commit 13b61ff455
2 changed files with 549 additions and 24 deletions

View File

@@ -71,6 +71,9 @@ type Step struct {
Edge struct {
// Rel of the edge.
Rel Rel
// Schema is an optional name of the database
// where the table is defined.
Schema string
// Table name of where this edge columns reside.
Table string
// Columns of the edge.
@@ -84,6 +87,9 @@ type Step struct {
To struct {
// Table holds the table name of the neighbors (to).
Table string
// Schema is an optional name of the database
// where the table is defined.
Schema string
// Column to join with. Usually the "id" column.
Column string
}
@@ -147,8 +153,8 @@ func Neighbors(dialect string, s *Step) (q *sql.Selector) {
if s.Edge.Inverse {
pk1, pk2 = pk2, pk1
}
to := builder.Table(s.To.Table)
join := builder.Table(s.Edge.Table)
to := builder.Table(s.To.Table).Schema(s.To.Schema)
join := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
match := builder.Select(join.C(pk1)).
From(join).
Where(sql.EQ(join.C(pk2), s.From.V))
@@ -157,9 +163,9 @@ func Neighbors(dialect string, s *Step) (q *sql.Selector) {
Join(match).
On(to.C(s.To.Column), match.C(pk1))
case r == M2O || (r == O2O && s.Edge.Inverse):
t1 := builder.Table(s.To.Table)
t1 := builder.Table(s.To.Table).Schema(s.To.Schema)
t2 := builder.Select(s.Edge.Columns[0]).
From(builder.Table(s.Edge.Table)).
From(builder.Table(s.Edge.Table).Schema(s.Edge.Schema)).
Where(sql.EQ(s.From.Column, s.From.V))
q = builder.Select().
From(t1).
@@ -167,7 +173,7 @@ func Neighbors(dialect string, s *Step) (q *sql.Selector) {
On(t1.C(s.To.Column), t2.C(s.Edge.Columns[0]))
case r == O2M || (r == O2O && !s.Edge.Inverse):
q = builder.Select().
From(builder.Table(s.To.Table)).
From(builder.Table(s.To.Table).Schema(s.To.Schema)).
Where(sql.EQ(s.Edge.Columns[0], s.From.V))
}
return q
@@ -184,9 +190,9 @@ func SetNeighbors(dialect string, s *Step) (q *sql.Selector) {
if s.Edge.Inverse {
pk1, pk2 = pk2, pk1
}
to := builder.Table(s.To.Table)
to := builder.Table(s.To.Table).Schema(s.To.Schema)
set.Select(set.C(s.From.Column))
join := builder.Table(s.Edge.Table)
join := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
match := builder.Select(join.C(pk1)).
From(join).
Join(set).
@@ -196,14 +202,14 @@ func SetNeighbors(dialect string, s *Step) (q *sql.Selector) {
Join(match).
On(to.C(s.To.Column), match.C(pk1))
case r == M2O || (r == O2O && s.Edge.Inverse):
t1 := builder.Table(s.To.Table)
t1 := builder.Table(s.To.Table).Schema(s.To.Schema)
set.Select(set.C(s.Edge.Columns[0]))
q = builder.Select().
From(t1).
Join(set).
On(t1.C(s.To.Column), set.C(s.Edge.Columns[0]))
case r == O2M || (r == O2O && !s.Edge.Inverse):
t1 := builder.Table(s.To.Table)
t1 := builder.Table(s.To.Table).Schema(s.To.Schema)
set.Select(set.C(s.From.Column))
q = builder.Select().
From(t1).
@@ -223,7 +229,7 @@ func HasNeighbors(q *sql.Selector, s *Step) {
pk1 = s.Edge.Columns[1]
}
from := q.Table()
join := builder.Table(s.Edge.Table)
join := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
q.Where(
sql.In(
from.C(s.From.Column),
@@ -235,7 +241,7 @@ func HasNeighbors(q *sql.Selector, s *Step) {
q.Where(sql.NotNull(from.C(s.Edge.Columns[0])))
case r == O2M || (r == O2O && !s.Edge.Inverse):
from := q.Table()
to := builder.Table(s.Edge.Table)
to := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
q.Where(
sql.In(
from.C(s.From.Column),
@@ -258,8 +264,8 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
pk1, pk2 = pk2, pk1
}
from := q.Table()
to := builder.Table(s.To.Table)
edge := builder.Table(s.Edge.Table)
to := builder.Table(s.To.Table).Schema(s.To.Schema)
edge := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
join := builder.Select(edge.C(pk2)).
From(edge).
Join(to).
@@ -270,14 +276,14 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
q.Where(sql.In(from.C(s.From.Column), join))
case r == M2O || (r == O2O && s.Edge.Inverse):
from := q.Table()
to := builder.Table(s.To.Table)
to := builder.Table(s.To.Table).Schema(s.To.Schema)
matches := builder.Select(to.C(s.To.Column)).
From(to)
pred(matches)
q.Where(sql.In(from.C(s.Edge.Columns[0]), matches))
case r == O2M || (r == O2O && !s.Edge.Inverse):
from := q.Table()
to := builder.Table(s.Edge.Table)
to := builder.Table(s.Edge.Table).Schema(s.Edge.Schema)
matches := builder.Select(to.C(s.Edge.Columns[0])).
From(to)
pred(matches)
@@ -307,6 +313,7 @@ type (
Rel Rel
Inverse bool
Table string
Schema string
Columns []string
Bidi bool // bidirectional edge.
Target *EdgeTarget // target nodes.
@@ -319,6 +326,7 @@ type (
// decoding nodes in the graph.
NodeSpec struct {
Table string
Schema string
Columns []string
ID *FieldSpec
}
@@ -329,6 +337,7 @@ type (
// a node in the graph.
CreateSpec struct {
Table string
Schema string
ID *FieldSpec
Fields []*FieldSpec
Edges []*EdgeSpec
@@ -453,11 +462,11 @@ func DeleteNodes(ctx context.Context, drv dialect.Driver, spec *DeleteSpec) (int
builder = sql.Dialect(drv.Dialect())
)
selector := builder.Select().
From(builder.Table(spec.Node.Table))
From(builder.Table(spec.Node.Table).Schema(spec.Node.Schema))
if pred := spec.Predicate; pred != nil {
pred(selector)
}
query, args := builder.Delete(spec.Node.Table).FromSelect(selector).Query()
query, args := builder.Delete(spec.Node.Table).Schema(spec.Node.Schema).FromSelect(selector).Query()
if err := tx.Exec(ctx, query, args, &res); err != nil {
return 0, rollback(tx, err)
}
@@ -518,7 +527,7 @@ func QueryEdges(ctx context.Context, drv dialect.Driver, spec *EdgeQuerySpec) er
}
selector := sql.Dialect(drv.Dialect()).
Select(out, in).
From(sql.Table(spec.Edge.Table))
From(sql.Table(spec.Edge.Table).Schema(spec.Edge.Schema))
if p := spec.Predicate; p != nil {
p(selector)
}
@@ -595,7 +604,7 @@ func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) {
}
func (q *query) selector() (*sql.Selector, error) {
selector := q.builder.Select().From(q.builder.Table(q.Node.Table))
selector := q.builder.Select().From(q.builder.Table(q.Node.Table).Schema(q.Node.Schema))
if q.From != nil {
selector = q.From
}
@@ -636,7 +645,7 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
addEdges = EdgeSpecs(u.Edges.Add).GroupRel()
clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
)
update := u.builder.Update(u.Node.Table).Where(sql.EQ(u.Node.ID.Column, id))
update := u.builder.Update(u.Node.Table).Schema(u.Node.Schema).Where(sql.EQ(u.Node.ID.Column, id))
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
return err
}
@@ -651,7 +660,7 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
return err
}
selector := u.builder.Select(u.Node.Columns...).
From(u.builder.Table(u.Node.Table)).
From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).
Where(sql.EQ(u.Node.ID.Column, u.Node.ID.Value))
rows := &sql.Rows{}
query, args := selector.Query()
@@ -667,9 +676,9 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error
addEdges = EdgeSpecs(u.Edges.Add).GroupRel()
clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
multiple = u.hasExternalEdges(addEdges, clearEdges)
update = u.builder.Update(u.Node.Table)
update = u.builder.Update(u.Node.Table).Schema(u.Node.Schema)
selector = u.builder.Select(u.Node.ID.Column).
From(u.builder.Table(u.Node.Table))
From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema))
)
if pred := u.Predicate; pred != nil {
pred(selector)
@@ -825,7 +834,7 @@ type creator struct {
func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
var (
edges = EdgeSpecs(c.Edges).GroupRel()
insert = c.builder.Insert(c.Table).Default()
insert = c.builder.Insert(c.Table).Schema(c.Schema).Default()
)
// Set and create the node.
if err := c.setTableColumns(insert, edges); err != nil {

View File

@@ -128,6 +128,154 @@ func TestNeighbors(t *testing.T) {
wantQuery: "SELECT * FROM `groups` JOIN (SELECT `user_groups`.`group_id` FROM `user_groups` WHERE `user_groups`.`user_id` = ?) AS `t1` ON `groups`.`id` = `t1`.`group_id`",
wantArgs: []interface{}{2},
},
{
name: "schema/O2O/1type",
// Since the relation is on the same sql.Table,
// V used as a reference value.
input: func() *Step {
step := NewStep(
From("users", "id", 1),
To("users", "id"),
Edge(O2O, false, "users", "spouse_id"),
)
step.To.Schema = "mydb"
return step
}(),
wantQuery: "SELECT * FROM `mydb`.`users` WHERE `spouse_id` = ?",
wantArgs: []interface{}{1},
},
{
name: "schema/O2O/1type/inverse",
input: func() *Step {
step := NewStep(
From("nodes", "id", 1),
To("nodes", "id"),
Edge(O2O, true, "nodes", "prev_id"),
)
step.To.Schema = "mydb"
step.Edge.Schema = "mydb"
return step
}(),
wantQuery: "SELECT * FROM `mydb`.`nodes` JOIN (SELECT `prev_id` FROM `mydb`.`nodes` WHERE `id` = ?) AS `t1` ON `mydb`.`nodes`.`id` = `t1`.`prev_id`",
wantArgs: []interface{}{1},
},
{
name: "schema/O2M/1type",
input: func() *Step {
step := NewStep(
From("users", "id", 1),
To("users", "id"),
Edge(O2M, false, "users", "parent_id"),
)
step.To.Schema = "mydb"
return step
}(),
wantQuery: "SELECT * FROM `mydb`.`users` WHERE `parent_id` = ?",
wantArgs: []interface{}{1},
},
{
name: "schema/O2O/2types",
input: func() *Step {
step := NewStep(
From("users", "id", 2),
To("card", "id"),
Edge(O2O, false, "cards", "owner_id"),
)
step.To.Schema = "mydb"
return step
}(),
wantQuery: "SELECT * FROM `mydb`.`card` WHERE `owner_id` = ?",
wantArgs: []interface{}{2},
},
{
name: "schema/O2O/2types/inverse",
input: func() *Step {
step := NewStep(
From("cards", "id", 2),
To("users", "id"),
Edge(O2O, true, "cards", "owner_id"),
)
step.To.Schema = "mydb"
step.Edge.Schema = "mydb"
return step
}(),
wantQuery: "SELECT * FROM `mydb`.`users` JOIN (SELECT `owner_id` FROM `mydb`.`cards` WHERE `id` = ?) AS `t1` ON `mydb`.`users`.`id` = `t1`.`owner_id`",
wantArgs: []interface{}{2},
},
{
name: "schema/O2M/2types",
input: func() *Step {
step := NewStep(
From("users", "id", 1),
To("pets", "id"),
Edge(O2M, false, "pets", "owner_id"),
)
step.To.Schema = "mydb"
return step
}(),
wantQuery: "SELECT * FROM `mydb`.`pets` WHERE `owner_id` = ?",
wantArgs: []interface{}{1},
},
{
name: "schema/M2O/2types/inverse",
input: func() *Step {
step := NewStep(
From("pets", "id", 2),
To("users", "id"),
Edge(M2O, true, "pets", "owner_id"),
)
step.To.Schema = "s1"
step.Edge.Schema = "s2"
return step
}(),
wantQuery: "SELECT * FROM `s1`.`users` JOIN (SELECT `owner_id` FROM `s2`.`pets` WHERE `id` = ?) AS `t1` ON `s1`.`users`.`id` = `t1`.`owner_id`",
wantArgs: []interface{}{2},
},
{
name: "schema/M2O/1type/inverse",
input: func() *Step {
step := NewStep(
From("users", "id", 2),
To("users", "id"),
Edge(M2O, true, "users", "parent_id"),
)
step.To.Schema = "s1"
step.Edge.Schema = "s1"
return step
}(),
wantQuery: "SELECT * FROM `s1`.`users` JOIN (SELECT `parent_id` FROM `s1`.`users` WHERE `id` = ?) AS `t1` ON `s1`.`users`.`id` = `t1`.`parent_id`",
wantArgs: []interface{}{2},
},
{
name: "schema/M2M/2type",
input: func() *Step {
step := NewStep(
From("groups", "id", 2),
To("users", "id"),
Edge(M2M, false, "user_groups", "group_id", "user_id"),
)
step.To.Schema = "s1"
step.Edge.Schema = "s2"
return step
}(),
wantQuery: "SELECT * FROM `s1`.`users` JOIN (SELECT `s2`.`user_groups`.`user_id` FROM `s2`.`user_groups` WHERE `s2`.`user_groups`.`group_id` = ?) AS `t1` ON `s1`.`users`.`id` = `t1`.`user_id`",
wantArgs: []interface{}{2},
},
{
name: "schema/M2M/2type/inverse",
input: func() *Step {
step := NewStep(
From("users", "id", 2),
To("groups", "id"),
Edge(M2M, true, "user_groups", "group_id", "user_id"),
)
step.To.Schema = "s1"
step.Edge.Schema = "s2"
return step
}(),
wantQuery: "SELECT * FROM `s1`.`groups` JOIN (SELECT `s2`.`user_groups`.`group_id` FROM `s2`.`user_groups` WHERE `s2`.`user_groups`.`user_id` = ?) AS `t1` ON `s1`.`groups`.`id` = `t1`.`group_id`",
wantArgs: []interface{}{2},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -204,6 +352,82 @@ JOIN
WHERE "name" = $1) AS "t1" ON "user_groups"."group_id" = "t1"."id") AS "t1" ON "users"."id" = "t1"."user_id"`,
wantArgs: []interface{}{"GitHub"},
},
{
name: "schema/O2M/2types",
input: func() *Step {
step := NewStep(
From("users", "id", sql.Select().From(sql.Table("users").Schema("s2")).Where(sql.EQ("name", "a8m"))),
To("pets", "id"),
Edge(O2M, false, "users", "owner_id"),
)
step.To.Schema = "s1"
return step
}(),
wantQuery: `SELECT * FROM "s1"."pets" JOIN (SELECT "s2"."users"."id" FROM "s2"."users" WHERE "name" = $1) AS "t1" ON "s1"."pets"."owner_id" = "t1"."id"`,
wantArgs: []interface{}{"a8m"},
},
{
name: "schema/M2O/2types",
input: func() *Step {
step := NewStep(
From("pets", "id", sql.Select().From(sql.Table("pets").Schema("s2")).Where(sql.EQ("name", "pedro"))),
To("users", "id"),
Edge(M2O, true, "pets", "owner_id"),
)
step.To.Schema = "s1"
return step
}(),
wantQuery: `SELECT * FROM "s1"."users" JOIN (SELECT "s2"."pets"."owner_id" FROM "s2"."pets" WHERE "name" = $1) AS "t1" ON "s1"."users"."id" = "t1"."owner_id"`,
wantArgs: []interface{}{"pedro"},
},
{
name: "schema/M2M/2types",
input: func() *Step {
step := NewStep(
From("users", "id", sql.Select().From(sql.Table("users").Schema("s2")).Where(sql.EQ("name", "a8m"))),
To("groups", "id"),
Edge(M2M, false, "user_groups", "user_id", "group_id"),
)
step.To.Schema = "s1"
step.Edge.Schema = "s3"
return step
}(),
wantQuery: `
SELECT *
FROM "s1"."groups"
JOIN
(SELECT "s3"."user_groups"."group_id"
FROM "s3"."user_groups"
JOIN
(SELECT "s2"."users"."id"
FROM "s2"."users"
WHERE "name" = $1) AS "t1" ON "s3"."user_groups"."user_id" = "t1"."id") AS "t1" ON "s1"."groups"."id" = "t1"."group_id"`,
wantArgs: []interface{}{"a8m"},
},
{
name: "schema/M2M/2types/inverse",
input: func() *Step {
step := NewStep(
From("groups", "id", sql.Select().From(sql.Table("groups").Schema("s2")).Where(sql.EQ("name", "GitHub"))),
To("users", "id"),
Edge(M2M, true, "user_groups", "user_id", "group_id"),
)
step.To.Schema = "s1"
step.Edge.Schema = "s3"
return step
}(),
wantQuery: `
SELECT *
FROM "s1"."users"
JOIN
(SELECT "s3"."user_groups"."user_id"
FROM "s3"."user_groups"
JOIN
(SELECT "s2"."groups"."id"
FROM "s2"."groups"
WHERE "name" = $1) AS "t1" ON "s3"."user_groups"."group_id" = "t1"."id") AS "t1" ON "s1"."users"."id" = "t1"."user_id"`,
wantArgs: []interface{}{"GitHub"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -288,6 +512,84 @@ func TestHasNeighbors(t *testing.T) {
selector: sql.Select("*").From(sql.Table("users")),
wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `group_users`.`user_id` FROM `group_users`)",
},
{
name: "schema/O2O/1type",
step: func() *Step {
step := NewStep(
From("nodes", "id"),
To("nodes", "id"),
Edge(O2O, false, "nodes", "prev_id"),
)
step.Edge.Schema = "s1"
return step
}(),
selector: sql.Select("*").From(sql.Table("nodes").Schema("s1")),
wantQuery: "SELECT * FROM `s1`.`nodes` WHERE `s1`.`nodes`.`id` IN (SELECT `s1`.`nodes`.`prev_id` FROM `s1`.`nodes` WHERE `s1`.`nodes`.`prev_id` IS NOT NULL)",
},
{
name: "schema/O2O/1type/inverse",
// Same example as above, but the neighbors
// query checks if a node "has-previous".
step: NewStep(
From("nodes", "id"),
To("nodes", "id"),
Edge(O2O, true, "nodes", "prev_id"),
),
selector: sql.Select("*").From(sql.Table("nodes").Schema("s1")),
wantQuery: "SELECT * FROM `s1`.`nodes` WHERE `s1`.`nodes`.`prev_id` IS NOT NULL",
},
{
name: "schema/O2M/2type2",
step: func() *Step {
step := NewStep(
From("users", "id"),
To("pets", "id"),
Edge(O2M, false, "pets", "owner_id"),
)
step.Edge.Schema = "s2"
return step
}(),
selector: sql.Select("*").From(sql.Table("users").Schema("s1")),
wantQuery: "SELECT * FROM `s1`.`users` WHERE `s1`.`users`.`id` IN (SELECT `s2`.`pets`.`owner_id` FROM `s2`.`pets` WHERE `s2`.`pets`.`owner_id` IS NOT NULL)",
},
{
name: "schema/M2O/2type2",
step: NewStep(
From("pets", "id"),
To("users", "id"),
Edge(M2O, true, "pets", "owner_id"),
),
selector: sql.Select("*").From(sql.Table("pets").Schema("s1")),
wantQuery: "SELECT * FROM `s1`.`pets` WHERE `s1`.`pets`.`owner_id` IS NOT NULL",
},
{
name: "schema/M2M/2types",
step: func() *Step {
step := NewStep(
From("users", "id"),
To("groups", "id"),
Edge(M2M, false, "user_groups", "user_id", "group_id"),
)
step.Edge.Schema = "s2"
return step
}(),
selector: sql.Select("*").From(sql.Table("users").Schema("s1")),
wantQuery: "SELECT * FROM `s1`.`users` WHERE `s1`.`users`.`id` IN (SELECT `s2`.`user_groups`.`user_id` FROM `s2`.`user_groups`)",
},
{
name: "schema/M2M/2types/inverse",
step: func() *Step {
step := NewStep(
From("users", "id"),
To("groups", "id"),
Edge(M2M, true, "group_users", "group_id", "user_id"),
)
step.Edge.Schema = "s2"
return step
}(),
selector: sql.Select("*").From(sql.Table("users").Schema("s1")),
wantQuery: "SELECT * FROM `s1`.`users` WHERE `s1`.`users`.`id` IN (SELECT `s2`.`group_users`.`user_id` FROM `s2`.`group_users`)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -428,6 +730,69 @@ WHERE "groups"."id" IN
JOIN "users" AS "t0" ON "user_groups"."user_id" = "t0"."id" WHERE "name" IS NOT NULL AND "name" = $1)`,
wantArgs: []interface{}{"a8m"},
},
{
name: "schema/O2O",
step: func() *Step {
step := NewStep(
From("users", "id"),
To("cards", "id"),
Edge(O2O, false, "cards", "owner_id"),
)
step.Edge.Schema = "s2"
return step
}(),
selector: sql.Dialect("postgres").Select("*").From(sql.Table("users").Schema("s1")),
predicate: func(s *sql.Selector) {
s.Where(sql.EQ("expired", false))
},
wantQuery: `SELECT * FROM "s1"."users" WHERE "s1"."users"."id" IN (SELECT "s2"."cards"."owner_id" FROM "s2"."cards" WHERE "expired" = $1)`,
wantArgs: []interface{}{false},
},
{
name: "schema/O2M",
step: func() *Step {
step := NewStep(
From("users", "id"),
To("pets", "id"),
Edge(O2M, false, "pets", "owner_id"),
)
step.Edge.Schema = "s2"
return step
}(),
selector: sql.Dialect("postgres").Select("*").
From(sql.Table("users").Schema("s1")).
Where(sql.EQ("last_name", "mashraki")),
predicate: func(s *sql.Selector) {
s.Where(sql.EQ("name", "pedro"))
},
wantQuery: `SELECT * FROM "s1"."users" WHERE "last_name" = $1 AND "s1"."users"."id" IN (SELECT "s2"."pets"."owner_id" FROM "s2"."pets" WHERE "name" = $2)`,
wantArgs: []interface{}{"mashraki", "pedro"},
},
{
name: "schema/M2M",
step: func() *Step {
step := NewStep(
From("users", "id"),
To("groups", "id"),
Edge(M2M, false, "user_groups", "user_id", "group_id"),
)
step.To.Schema = "s3"
step.Edge.Schema = "s2"
return step
}(),
selector: sql.Dialect("postgres").Select("*").From(sql.Table("users").Schema("s1")),
predicate: func(s *sql.Selector) {
s.Where(sql.EQ("name", "GitHub"))
},
wantQuery: `
SELECT *
FROM "s1"."users"
WHERE "s1"."users"."id" IN
(SELECT "s2"."user_groups"."user_id"
FROM "s2"."user_groups"
JOIN "s3"."groups" AS "t0" ON "s2"."user_groups"."group_id" = "t0"."id" WHERE "name" = $1)`,
wantArgs: []interface{}{"GitHub"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -730,6 +1095,25 @@ func TestCreateNode(t *testing.T) {
m.ExpectCommit()
},
},
{
name: "schema",
spec: &CreateSpec{
Table: "users",
Schema: "mydb",
ID: &FieldSpec{Column: "id"},
Fields: []*FieldSpec{
{Column: "age", Type: field.TypeInt, Value: 30},
{Column: "name", Type: field.TypeString, Value: "a8m"},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `mydb`.`users` (`age`, `name`) VALUES (?, ?)")).
WithArgs(30, "a8m").
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -1085,6 +1469,35 @@ func TestUpdateNode(t *testing.T) {
},
wantUser: &user{age: 31, id: 1},
},
{
name: "schema/fields/set",
spec: &UpdateSpec{
Node: &NodeSpec{
Table: "users",
Schema: "mydb",
Columns: []string{"id", "name", "age"},
ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
},
Fields: FieldMut{
Set: []*FieldSpec{
{Column: "age", Type: field.TypeInt, Value: 30},
{Column: "name", Type: field.TypeString, Value: "Ariel"},
},
},
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec(escape("UPDATE `mydb`.`users` SET `age` = ?, `name` = ? WHERE `id` = ?")).
WithArgs(30, "Ariel", 1).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `mydb`.`users` WHERE `id` = ?")).
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
AddRow(1, 30, "Ariel"))
mock.ExpectCommit()
},
wantUser: &user{name: "Ariel", age: 30, id: 1},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -1317,6 +1730,24 @@ func TestDeleteNodes(t *testing.T) {
require.Equal(t, 2, affected)
}
func TestDeleteNodesSchema(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
mock.ExpectBegin()
mock.ExpectExec(escape("DELETE FROM `mydb`.`users`")).
WillReturnResult(sqlmock.NewResult(0, 2))
mock.ExpectCommit()
affected, err := DeleteNodes(context.Background(), sql.OpenDB("", db), &DeleteSpec{
Node: &NodeSpec{
Table: "users",
Schema: "mydb",
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
},
})
require.NoError(t, err)
require.Equal(t, 2, affected)
}
func TestQueryNodes(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
@@ -1372,6 +1803,53 @@ func TestQueryNodes(t *testing.T) {
require.Equal(t, 3, n)
}
func TestQueryNodesSchema(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
mock.ExpectQuery(escape("SELECT DISTINCT `mydb`.`users`.`id`, `mydb`.`users`.`age`, `mydb`.`users`.`name`, `mydb`.`users`.`fk1`, `mydb`.`users`.`fk2` FROM `mydb`.`users` WHERE `age` < ? ORDER BY `id` LIMIT 3 OFFSET 4")).
WithArgs(40).
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name", "fk1", "fk2"}).
AddRow(1, 10, nil, nil, nil).
AddRow(2, 20, "", 0, 0).
AddRow(3, 30, "a8m", 1, 1))
var (
users []*user
spec = &QuerySpec{
Node: &NodeSpec{
Table: "users",
Schema: "mydb",
Columns: []string{"id", "age", "name", "fk1", "fk2"},
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
},
Limit: 3,
Offset: 4,
Unique: true,
Order: func(s *sql.Selector) {
s.OrderBy("id")
},
Predicate: func(s *sql.Selector) {
s.Where(sql.LT("age", 40))
},
ScanValues: func(columns []string) ([]interface{}, error) {
u := &user{}
users = append(users, u)
return u.values(columns)
},
Assign: func(columns []string, values []interface{}) error {
return users[len(users)-1].assign(columns, values)
},
}
)
// Query and scan.
err = QueryNodes(context.Background(), sql.OpenDB("", db), spec)
require.NoError(t, err)
require.Equal(t, &user{id: 1, age: 10, name: ""}, users[0])
require.Equal(t, &user{id: 2, age: 20, name: ""}, users[1])
require.Equal(t, &user{id: 3, age: 30, name: "a8m", edges: struct{ fk1, fk2 int }{1, 1}}, users[2])
}
func TestQueryEdges(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
@@ -1409,6 +1887,44 @@ func TestQueryEdges(t *testing.T) {
require.Equal(t, [][]int64{{4, 5}, {4, 6}}, edges)
}
func TestQueryEdgesSchema(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
mock.ExpectQuery(escape("SELECT `group_id`, `user_id` FROM `mydb`.`user_groups` WHERE `user_id` IN (?, ?, ?)")).
WithArgs(1, 2, 3).
WillReturnRows(sqlmock.NewRows([]string{"group_id", "user_id"}).
AddRow(4, 5).
AddRow(4, 6))
var (
edges [][]int64
spec = &EdgeQuerySpec{
Edge: &EdgeSpec{
Inverse: true,
Table: "user_groups",
Schema: "mydb",
Columns: []string{"user_id", "group_id"},
},
Predicate: func(s *sql.Selector) {
s.Where(sql.InValues("user_id", 1, 2, 3))
},
ScanValues: func() [2]interface{} {
return [2]interface{}{&sql.NullInt64{}, &sql.NullInt64{}}
},
Assign: func(out, in interface{}) error {
o, i := out.(*sql.NullInt64), in.(*sql.NullInt64)
edges = append(edges, []int64{o.Int64, i.Int64})
return nil
},
}
)
// Query and scan.
err = QueryEdges(context.Background(), sql.OpenDB("", db), spec)
require.NoError(t, err)
require.Equal(t, [][]int64{{4, 5}, {4, 6}}, edges)
}
func escape(query string) string {
rows := strings.Split(query, "\n")
for i := range rows {