mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
all: move sqlgraph to its own package
This commit is contained in:
@@ -1314,6 +1314,12 @@ func (s *Selector) Distinct() *Selector {
|
||||
return s
|
||||
}
|
||||
|
||||
// SetDistinct sets explicitly if the returned rows are distinct or indistinct.
|
||||
func (s *Selector) SetDistinct(v bool) *Selector {
|
||||
s.distinct = v
|
||||
return s
|
||||
}
|
||||
|
||||
// Limit adds the `LIMIT` clause to the `SELECT` statement.
|
||||
func (s *Selector) Limit(limit int) *Selector {
|
||||
s.limit = &limit
|
||||
@@ -1897,14 +1903,14 @@ type state interface {
|
||||
SetTotal(int)
|
||||
}
|
||||
|
||||
// dialectBuilder prefixes all root builders with the `Dialect` constructor.
|
||||
type dialectBuilder struct {
|
||||
// DialectBuilder prefixes all root builders with the `Dialect` constructor.
|
||||
type DialectBuilder struct {
|
||||
dialect string
|
||||
}
|
||||
|
||||
// Dialect creates a new dialectBuilder with the given dialect name.
|
||||
func Dialect(name string) *dialectBuilder {
|
||||
return &dialectBuilder{name}
|
||||
// Dialect creates a new DialectBuilder with the given dialect name.
|
||||
func Dialect(name string) *DialectBuilder {
|
||||
return &DialectBuilder{name}
|
||||
}
|
||||
|
||||
// Describe creates a DescribeBuilder for the configured dialect.
|
||||
@@ -1912,7 +1918,7 @@ func Dialect(name string) *dialectBuilder {
|
||||
// Dialect(dialect.Postgres).
|
||||
// Describe("users")
|
||||
//
|
||||
func (d *dialectBuilder) Describe(name string) *DescribeBuilder {
|
||||
func (d *DialectBuilder) Describe(name string) *DescribeBuilder {
|
||||
b := Describe(name)
|
||||
b.SetDialect(d.dialect)
|
||||
return b
|
||||
@@ -1928,7 +1934,7 @@ func (d *dialectBuilder) Describe(name string) *DescribeBuilder {
|
||||
// ).
|
||||
// PrimaryKey("id")
|
||||
//
|
||||
func (d *dialectBuilder) CreateTable(name string) *TableBuilder {
|
||||
func (d *DialectBuilder) CreateTable(name string) *TableBuilder {
|
||||
b := CreateTable(name)
|
||||
b.SetDialect(d.dialect)
|
||||
return b
|
||||
@@ -1944,7 +1950,7 @@ func (d *dialectBuilder) CreateTable(name string) *TableBuilder {
|
||||
// OnDelete("CASCADE"),
|
||||
// )
|
||||
//
|
||||
func (d *dialectBuilder) AlterTable(name string) *TableAlter {
|
||||
func (d *DialectBuilder) AlterTable(name string) *TableAlter {
|
||||
b := AlterTable(name)
|
||||
b.SetDialect(d.dialect)
|
||||
return b
|
||||
@@ -1955,7 +1961,7 @@ func (d *dialectBuilder) AlterTable(name string) *TableAlter {
|
||||
// Dialect(dialect.Postgres)..
|
||||
// Column("group_id").Type("int").Attr("UNIQUE")
|
||||
//
|
||||
func (d *dialectBuilder) Column(name string) *ColumnBuilder {
|
||||
func (d *DialectBuilder) Column(name string) *ColumnBuilder {
|
||||
b := Column(name)
|
||||
b.SetDialect(d.dialect)
|
||||
return b
|
||||
@@ -1966,7 +1972,7 @@ func (d *dialectBuilder) Column(name string) *ColumnBuilder {
|
||||
// Dialect(dialect.Postgres).
|
||||
// Insert("users").Columns("age").Values(1)
|
||||
//
|
||||
func (d *dialectBuilder) Insert(table string) *InsertBuilder {
|
||||
func (d *DialectBuilder) Insert(table string) *InsertBuilder {
|
||||
b := Insert(table)
|
||||
b.SetDialect(d.dialect)
|
||||
return b
|
||||
@@ -1977,7 +1983,7 @@ func (d *dialectBuilder) Insert(table string) *InsertBuilder {
|
||||
// Dialect(dialect.Postgres).
|
||||
// Update("users").Set("name", "foo")
|
||||
//
|
||||
func (d *dialectBuilder) Update(table string) *UpdateBuilder {
|
||||
func (d *DialectBuilder) Update(table string) *UpdateBuilder {
|
||||
b := Update(table)
|
||||
b.SetDialect(d.dialect)
|
||||
return b
|
||||
@@ -1988,7 +1994,7 @@ func (d *dialectBuilder) Update(table string) *UpdateBuilder {
|
||||
// Dialect(dialect.Postgres).
|
||||
// Delete().From("users")
|
||||
//
|
||||
func (d *dialectBuilder) Delete(table string) *DeleteBuilder {
|
||||
func (d *DialectBuilder) Delete(table string) *DeleteBuilder {
|
||||
b := Delete(table)
|
||||
b.SetDialect(d.dialect)
|
||||
return b
|
||||
@@ -1999,7 +2005,7 @@ func (d *dialectBuilder) Delete(table string) *DeleteBuilder {
|
||||
// Dialect(dialect.Postgres).
|
||||
// Select().From(Table("users"))
|
||||
//
|
||||
func (d *dialectBuilder) Select(columns ...string) *Selector {
|
||||
func (d *DialectBuilder) Select(columns ...string) *Selector {
|
||||
b := Select(columns...)
|
||||
b.SetDialect(d.dialect)
|
||||
return b
|
||||
@@ -2010,7 +2016,7 @@ func (d *dialectBuilder) Select(columns ...string) *Selector {
|
||||
// Dialect(dialect.Postgres).
|
||||
// Table("users").As("u")
|
||||
//
|
||||
func (d *dialectBuilder) Table(name string) *SelectTable {
|
||||
func (d *DialectBuilder) Table(name string) *SelectTable {
|
||||
b := Table(name)
|
||||
b.SetDialect(d.dialect)
|
||||
return b
|
||||
@@ -2022,7 +2028,7 @@ func (d *dialectBuilder) Table(name string) *SelectTable {
|
||||
// With("users_view").
|
||||
// As(Select().From(Table("users")))
|
||||
//
|
||||
func (d *dialectBuilder) With(name string) *WithBuilder {
|
||||
func (d *DialectBuilder) With(name string) *WithBuilder {
|
||||
b := With(name)
|
||||
b.SetDialect(d.dialect)
|
||||
return b
|
||||
@@ -2036,7 +2042,7 @@ func (d *dialectBuilder) With(name string) *WithBuilder {
|
||||
// Table("users").
|
||||
// Columns("first", "last")
|
||||
//
|
||||
func (d *dialectBuilder) CreateIndex(name string) *IndexBuilder {
|
||||
func (d *DialectBuilder) CreateIndex(name string) *IndexBuilder {
|
||||
b := CreateIndex(name)
|
||||
b.SetDialect(d.dialect)
|
||||
return b
|
||||
@@ -2047,7 +2053,7 @@ func (d *dialectBuilder) CreateIndex(name string) *IndexBuilder {
|
||||
// Dialect(dialect.Postgres).
|
||||
// DropIndex("name")
|
||||
//
|
||||
func (d *dialectBuilder) DropIndex(name string) *DropIndexBuilder {
|
||||
func (d *DialectBuilder) DropIndex(name string) *DropIndexBuilder {
|
||||
b := DropIndex(name)
|
||||
b.SetDialect(d.dialect)
|
||||
return b
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package sql
|
||||
// sqlgraph provides graph abstraction capabilities on top
|
||||
// of sql-based databases for ent codegen.
|
||||
package sqlgraph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"sort"
|
||||
|
||||
"github.com/facebookincubator/ent/dialect"
|
||||
"github.com/facebookincubator/ent/dialect/sql"
|
||||
"github.com/facebookincubator/ent/schema/field"
|
||||
)
|
||||
|
||||
@@ -22,7 +24,7 @@ type Rel int
|
||||
|
||||
// Relation types.
|
||||
const (
|
||||
Unk Rel = iota // Unknown.
|
||||
_ Rel = iota // Unknown.
|
||||
O2O // One to one / has one.
|
||||
O2M // One to many / has many.
|
||||
M2O // Many to one (inverse perspective for O2M).
|
||||
@@ -130,8 +132,8 @@ func NewStep(opts ...StepOption) *Step {
|
||||
|
||||
// Neighbors returns a Selector for evaluating the path-step
|
||||
// and getting the neighbors of one vertex.
|
||||
func Neighbors(dialect string, s *Step) (q *Selector) {
|
||||
builder := Dialect(dialect)
|
||||
func Neighbors(dialect string, s *Step) (q *sql.Selector) {
|
||||
builder := sql.Dialect(dialect)
|
||||
switch r := s.Edge.Rel; {
|
||||
case r == M2M:
|
||||
pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0]
|
||||
@@ -142,7 +144,7 @@ func Neighbors(dialect string, s *Step) (q *Selector) {
|
||||
join := builder.Table(s.Edge.Table)
|
||||
match := builder.Select(join.C(pk1)).
|
||||
From(join).
|
||||
Where(EQ(join.C(pk2), s.From.V))
|
||||
Where(sql.EQ(join.C(pk2), s.From.V))
|
||||
q = builder.Select().
|
||||
From(to).
|
||||
Join(match).
|
||||
@@ -151,7 +153,7 @@ func Neighbors(dialect string, s *Step) (q *Selector) {
|
||||
t1 := builder.Table(s.To.Table)
|
||||
t2 := builder.Select(s.Edge.Columns[0]).
|
||||
From(builder.Table(s.Edge.Table)).
|
||||
Where(EQ(s.From.Column, s.From.V))
|
||||
Where(sql.EQ(s.From.Column, s.From.V))
|
||||
q = builder.Select().
|
||||
From(t1).
|
||||
Join(t2).
|
||||
@@ -159,16 +161,16 @@ func Neighbors(dialect string, s *Step) (q *Selector) {
|
||||
case r == O2M || (r == O2O && !s.Edge.Inverse):
|
||||
q = builder.Select().
|
||||
From(builder.Table(s.To.Table)).
|
||||
Where(EQ(s.Edge.Columns[0], s.From.V))
|
||||
Where(sql.EQ(s.Edge.Columns[0], s.From.V))
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// SetNeighbors returns a Selector for evaluating the path-step
|
||||
// and getting the neighbors of set of vertices.
|
||||
func SetNeighbors(dialect string, s *Step) (q *Selector) {
|
||||
set := s.From.V.(*Selector)
|
||||
builder := Dialect(dialect)
|
||||
func SetNeighbors(dialect string, s *Step) (q *sql.Selector) {
|
||||
set := s.From.V.(*sql.Selector)
|
||||
builder := sql.Dialect(dialect)
|
||||
switch r := s.Edge.Rel; {
|
||||
case r == M2M:
|
||||
pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0]
|
||||
@@ -205,8 +207,8 @@ func SetNeighbors(dialect string, s *Step) (q *Selector) {
|
||||
}
|
||||
|
||||
// HasNeighbors applies on the given Selector a neighbors check.
|
||||
func HasNeighbors(q *Selector, s *Step) {
|
||||
builder := Dialect(q.dialect)
|
||||
func HasNeighbors(q *sql.Selector, s *Step) {
|
||||
builder := sql.Dialect(q.Dialect())
|
||||
switch r := s.Edge.Rel; {
|
||||
case r == M2M:
|
||||
pk1 := s.Edge.Columns[0]
|
||||
@@ -216,23 +218,23 @@ func HasNeighbors(q *Selector, s *Step) {
|
||||
from := q.Table()
|
||||
join := builder.Table(s.Edge.Table)
|
||||
q.Where(
|
||||
In(
|
||||
sql.In(
|
||||
from.C(s.From.Column),
|
||||
builder.Select(join.C(pk1)).From(join),
|
||||
),
|
||||
)
|
||||
case r == M2O || (r == O2O && s.Edge.Inverse):
|
||||
from := q.Table()
|
||||
q.Where(NotNull(from.C(s.Edge.Columns[0])))
|
||||
q.Where(sql.NotNull(from.C(s.Edge.Columns[0])))
|
||||
case r == O2M || (r == O2O && !s.Edge.Inverse):
|
||||
from := q.Table()
|
||||
to := builder.Table(s.Edge.Table)
|
||||
q.Where(
|
||||
In(
|
||||
sql.In(
|
||||
from.C(s.From.Column),
|
||||
builder.Select(to.C(s.Edge.Columns[0])).
|
||||
From(to).
|
||||
Where(NotNull(to.C(s.Edge.Columns[0]))),
|
||||
Where(sql.NotNull(to.C(s.Edge.Columns[0]))),
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -240,8 +242,8 @@ func HasNeighbors(q *Selector, s *Step) {
|
||||
|
||||
// HasNeighborsWith applies on the given Selector a neighbors check.
|
||||
// The given predicate applies its filtering on the selector.
|
||||
func HasNeighborsWith(q *Selector, s *Step, pred func(*Selector)) {
|
||||
builder := Dialect(q.dialect)
|
||||
func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
|
||||
builder := sql.Dialect(q.Dialect())
|
||||
switch r := s.Edge.Rel; {
|
||||
case r == M2M:
|
||||
pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0]
|
||||
@@ -258,21 +260,21 @@ func HasNeighborsWith(q *Selector, s *Step, pred func(*Selector)) {
|
||||
matches := builder.Select().From(to)
|
||||
pred(matches)
|
||||
join.FromSelect(matches)
|
||||
q.Where(In(from.C(s.From.Column), join))
|
||||
q.Where(sql.In(from.C(s.From.Column), join))
|
||||
case r == M2O || (r == O2O && s.Edge.Inverse):
|
||||
from := q.Table()
|
||||
to := builder.Table(s.To.Table)
|
||||
matches := builder.Select(to.C(s.To.Column)).
|
||||
From(to)
|
||||
pred(matches)
|
||||
q.Where(In(from.C(s.Edge.Columns[0]), matches))
|
||||
q.Where(sql.In(from.C(s.Edge.Columns[0]), matches))
|
||||
case r == O2M || (r == O2O && !s.Edge.Inverse):
|
||||
from := q.Table()
|
||||
to := builder.Table(s.Edge.Table)
|
||||
matches := builder.Select(to.C(s.Edge.Columns[0])).
|
||||
From(to)
|
||||
pred(matches)
|
||||
q.Where(In(from.C(s.From.Column), matches))
|
||||
q.Where(sql.In(from.C(s.From.Column), matches))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -330,7 +332,7 @@ func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
gr := graph{tx: tx, builder: Dialect(drv.Dialect())}
|
||||
gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())}
|
||||
cr := &creator{CreateSpec: spec, graph: gr}
|
||||
if err := cr.node(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
@@ -358,7 +360,7 @@ type (
|
||||
Node *NodeSpec
|
||||
Edges EdgeMut
|
||||
Fields FieldMut
|
||||
Predicate func(*Selector)
|
||||
Predicate func(*sql.Selector)
|
||||
|
||||
ScanTypes []interface{}
|
||||
Assign func(...interface{}) error
|
||||
@@ -371,7 +373,7 @@ func UpdateNode(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
gr := graph{tx: tx, builder: Dialect(drv.Dialect())}
|
||||
gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())}
|
||||
cr := &updater{UpdateSpec: spec, graph: gr}
|
||||
if err := cr.node(ctx, tx); err != nil {
|
||||
return rollback(tx, err)
|
||||
@@ -385,7 +387,7 @@ func UpdateNodes(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) (int
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
gr := graph{tx: tx, builder: Dialect(drv.Dialect())}
|
||||
gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())}
|
||||
cr := &updater{UpdateSpec: spec, graph: gr}
|
||||
affected, err := cr.nodes(ctx, tx)
|
||||
if err != nil {
|
||||
@@ -398,7 +400,7 @@ func UpdateNodes(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) (int
|
||||
// or more nodes in the graph.
|
||||
type DeleteSpec struct {
|
||||
Node *NodeSpec
|
||||
Predicate func(*Selector)
|
||||
Predicate func(*sql.Selector)
|
||||
}
|
||||
|
||||
// DeleteNodes applies the DeleteSpec on the graph.
|
||||
@@ -408,8 +410,8 @@ func DeleteNodes(ctx context.Context, drv dialect.Driver, spec *DeleteSpec) (int
|
||||
return 0, err
|
||||
}
|
||||
var (
|
||||
res Result
|
||||
builder = Dialect(drv.Dialect())
|
||||
res sql.Result
|
||||
builder = sql.Dialect(drv.Dialect())
|
||||
)
|
||||
selector := builder.Select().
|
||||
From(builder.Table(spec.Node.Table))
|
||||
@@ -430,14 +432,14 @@ func DeleteNodes(ctx context.Context, drv dialect.Driver, spec *DeleteSpec) (int
|
||||
// QuerySpec holds the information for querying
|
||||
// nodes in the graph.
|
||||
type QuerySpec struct {
|
||||
Node *NodeSpec // Nodes info.
|
||||
From *Selector // Optional query source (from path).
|
||||
Node *NodeSpec // Nodes info.
|
||||
From *sql.Selector // Optional query source (from path).
|
||||
|
||||
Limit int
|
||||
Offset int
|
||||
Unique bool
|
||||
Order func(*Selector)
|
||||
Predicate func(*Selector)
|
||||
Order func(*sql.Selector)
|
||||
Predicate func(*sql.Selector)
|
||||
|
||||
ScanValues func() []interface{}
|
||||
Assign func(...interface{}) error
|
||||
@@ -445,14 +447,14 @@ type QuerySpec struct {
|
||||
|
||||
// QueryNodes query the nodes in the graph query and scans them to the given values.
|
||||
func QueryNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) error {
|
||||
builder := Dialect(drv.Dialect())
|
||||
builder := sql.Dialect(drv.Dialect())
|
||||
qr := &query{graph: graph{builder: builder}, QuerySpec: spec}
|
||||
return qr.nodes(ctx, drv)
|
||||
}
|
||||
|
||||
// CountNodes counts the nodes in the given graph query.
|
||||
func CountNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) (int, error) {
|
||||
builder := Dialect(drv.Dialect())
|
||||
builder := sql.Dialect(drv.Dialect())
|
||||
qr := &query{graph: graph{builder: builder}, QuerySpec: spec}
|
||||
return qr.count(ctx, drv)
|
||||
}
|
||||
@@ -463,7 +465,7 @@ type query struct {
|
||||
}
|
||||
|
||||
func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
|
||||
rows := &Rows{}
|
||||
rows := &sql.Rows{}
|
||||
query, args := q.selector().Query()
|
||||
if err := drv.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
@@ -482,21 +484,21 @@ func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
|
||||
}
|
||||
|
||||
func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) {
|
||||
rows := &Rows{}
|
||||
rows := &sql.Rows{}
|
||||
selector := q.selector().Count(q.Node.ID.Column)
|
||||
if q.Unique {
|
||||
selector.distinct = false
|
||||
selector.Count(Distinct(q.Node.ID.Column))
|
||||
selector.SetDistinct(false)
|
||||
selector.Count(sql.Distinct(q.Node.ID.Column))
|
||||
}
|
||||
query, args := selector.Query()
|
||||
if err := drv.Query(ctx, query, args, rows); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return ScanInt(rows)
|
||||
return sql.ScanInt(rows)
|
||||
}
|
||||
|
||||
func (q *query) selector() *Selector {
|
||||
func (q *query) selector() *sql.Selector {
|
||||
selector := q.builder.Select().From(q.builder.Table(q.Node.Table))
|
||||
if q.From != nil {
|
||||
selector = q.From
|
||||
@@ -535,12 +537,12 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
addEdges = EdgeSpecs(u.Edges.Add).GroupRel()
|
||||
clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
|
||||
)
|
||||
update := u.builder.Update(u.Node.Table).Where(EQ(u.Node.ID.Column, id))
|
||||
update := u.builder.Update(u.Node.Table).Where(sql.EQ(u.Node.ID.Column, id))
|
||||
if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
|
||||
return err
|
||||
}
|
||||
if !update.Empty() {
|
||||
var res Result
|
||||
var res sql.Result
|
||||
query, args := update.Query()
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return err
|
||||
@@ -551,8 +553,8 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
}
|
||||
selector := u.builder.Select(u.Node.Columns...).
|
||||
From(u.builder.Table(u.Node.Table)).
|
||||
Where(EQ(u.Node.ID.Column, u.Node.ID.Value))
|
||||
rows := &Rows{}
|
||||
Where(sql.EQ(u.Node.ID.Column, u.Node.ID.Value))
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := tx.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
@@ -572,12 +574,12 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error
|
||||
pred(selector)
|
||||
}
|
||||
query, args := selector.Query()
|
||||
rows := &Rows{}
|
||||
rows := &sql.Rows{}
|
||||
if err := u.tx.Query(ctx, query, args, rows); err != nil {
|
||||
return 0, fmt.Errorf("querying table %s: %v", u.Node.Table, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if err := ScanSlice(rows, &ids); err != nil {
|
||||
if err := sql.ScanSlice(rows, &ids); err != nil {
|
||||
return 0, fmt.Errorf("scan node ids: %v", err)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
@@ -591,7 +593,7 @@ func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error
|
||||
return 0, err
|
||||
}
|
||||
if !update.Empty() {
|
||||
var res Result
|
||||
var res sql.Result
|
||||
query, args := update.Query()
|
||||
if err := tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return 0, err
|
||||
@@ -620,7 +622,7 @@ func (u *updater) setExternalEdges(ctx context.Context, ids []driver.Value, addE
|
||||
}
|
||||
|
||||
// setTableColumns sets the table columns and foreign_keys used in insert.
|
||||
func (u *updater) setTableColumns(update *UpdateBuilder, addEdges, clearEdges map[Rel][]*EdgeSpec) error {
|
||||
func (u *updater) setTableColumns(update *sql.UpdateBuilder, addEdges, clearEdges map[Rel][]*EdgeSpec) error {
|
||||
for _, fi := range u.Fields.Clear {
|
||||
update.SetNull(fi.Column)
|
||||
}
|
||||
@@ -644,7 +646,7 @@ func (u *updater) setTableColumns(update *UpdateBuilder, addEdges, clearEdges ma
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) scan(rows *Rows) error {
|
||||
func (u *updater) scan(rows *sql.Rows) error {
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return fmt.Errorf("record with id %v not found in table %s", u.Node.ID.Value, u.Node.Table)
|
||||
@@ -682,7 +684,7 @@ func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error {
|
||||
}
|
||||
|
||||
// setTableColumns sets the table columns and foreign_keys used in insert.
|
||||
func (c *creator) setTableColumns(insert *InsertBuilder, edges map[Rel][]*EdgeSpec) error {
|
||||
func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*EdgeSpec) error {
|
||||
err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) {
|
||||
insert.Set(column, value)
|
||||
})
|
||||
@@ -690,7 +692,7 @@ func (c *creator) setTableColumns(insert *InsertBuilder, edges map[Rel][]*EdgeSp
|
||||
}
|
||||
|
||||
// insert inserts the node to its table and sets its ID if it wasn't provided by the user.
|
||||
func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBuilder) error {
|
||||
func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
|
||||
var res sql.Result
|
||||
// If the id field was provided by the user.
|
||||
if c.ID.Value != nil {
|
||||
@@ -733,19 +735,19 @@ func (es EdgeSpecs) GroupTable() map[string][]*EdgeSpec {
|
||||
// but use UPDATE queries (fk = ?, fk = NULL).
|
||||
type graph struct {
|
||||
tx dialect.ExecQuerier
|
||||
builder *dialectBuilder
|
||||
builder *sql.DialectBuilder
|
||||
}
|
||||
|
||||
func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error {
|
||||
var (
|
||||
res Result
|
||||
res sql.Result
|
||||
// Delete all M2M edges from the same type at once.
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables = edges.GroupTable()
|
||||
)
|
||||
for _, table := range sortedKeys(tables) {
|
||||
edges := tables[table]
|
||||
preds := make([]*Predicate, 0, len(edges))
|
||||
preds := make([]*sql.Predicate, 0, len(edges))
|
||||
for _, edge := range edges {
|
||||
pk1, pk2 := ids, edge.Target.Nodes
|
||||
if edge.Inverse {
|
||||
@@ -756,7 +758,7 @@ func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges Edg
|
||||
preds = append(preds, matchIDs(edge.Columns[0], pk2, edge.Columns[1], pk1))
|
||||
}
|
||||
}
|
||||
query, args := g.builder.Delete(table).Where(Or(preds...)).Query()
|
||||
query, args := g.builder.Delete(table).Where(sql.Or(preds...)).Query()
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("remove m2m edge for table %s: %v", table, err)
|
||||
}
|
||||
@@ -766,7 +768,7 @@ func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges Edg
|
||||
|
||||
func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error {
|
||||
var (
|
||||
res Result
|
||||
res sql.Result
|
||||
// Insert all M2M edges from the same type at once.
|
||||
// The EdgeSpec is the same for all members in a group.
|
||||
tables = edges.GroupTable()
|
||||
@@ -809,7 +811,7 @@ func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*E
|
||||
SetNull(edge.Columns[0]).
|
||||
Where(pred).
|
||||
Query()
|
||||
var res Result
|
||||
var res sql.Result
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("add %s edge for table %s: %v", edge.Rel, edge.Table, err)
|
||||
}
|
||||
@@ -828,17 +830,17 @@ func (g *graph) addFKEdges(ctx context.Context, ids []driver.Value, edges []*Edg
|
||||
if edge.Rel == O2O && edge.Inverse {
|
||||
continue
|
||||
}
|
||||
p := EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0])
|
||||
p := sql.EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0])
|
||||
// Use "IN" predicate instead of list of "OR"
|
||||
// in case of more than on nodes to connect.
|
||||
if len(edge.Target.Nodes) > 1 {
|
||||
p = InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...)
|
||||
p = sql.InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...)
|
||||
}
|
||||
query, args := g.builder.Update(edge.Table).
|
||||
Set(edge.Columns[0], id).
|
||||
Where(And(p, IsNull(edge.Columns[0]))).
|
||||
Where(sql.And(p, sql.IsNull(edge.Columns[0]))).
|
||||
Query()
|
||||
var res Result
|
||||
var res sql.Result
|
||||
if err := g.tx.Exec(ctx, query, args, &res); err != nil {
|
||||
return fmt.Errorf("add %s edge for table %s: %v", edge.Rel, edge.Table, err)
|
||||
}
|
||||
@@ -876,7 +878,7 @@ func setTableColumns(fields []*FieldSpec, edges map[Rel][]*EdgeSpec, set func(st
|
||||
}
|
||||
|
||||
// insertLastID invokes the insert query on the transaction and returns the LastInsertID.
|
||||
func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBuilder) (int64, error) {
|
||||
func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) (int64, error) {
|
||||
query, args := insert.Query()
|
||||
// PostgreSQL does not support the LastInsertId() method of sql.Result
|
||||
// on Exec, and should be extracted manually using the `RETURNING` clause.
|
||||
@@ -886,7 +888,7 @@ func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *InsertBui
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return ScanInt64(rows)
|
||||
return sql.ScanInt64(rows)
|
||||
}
|
||||
// MySQL, SQLite, etc.
|
||||
var res sql.Result
|
||||
@@ -913,14 +915,14 @@ func sortedKeys(m map[string][]*EdgeSpec) []string {
|
||||
return keys
|
||||
}
|
||||
|
||||
func matchID(column string, pk []driver.Value) *Predicate {
|
||||
func matchID(column string, pk []driver.Value) *sql.Predicate {
|
||||
if len(pk) > 1 {
|
||||
return InValues(column, pk...)
|
||||
return sql.InValues(column, pk...)
|
||||
}
|
||||
return EQ(column, pk[0])
|
||||
return sql.EQ(column, pk[0])
|
||||
}
|
||||
|
||||
func matchIDs(column1 string, pk1 []driver.Value, column2 string, pk2 []driver.Value) *Predicate {
|
||||
func matchIDs(column1 string, pk1 []driver.Value, column2 string, pk2 []driver.Value) *sql.Predicate {
|
||||
p := matchID(column1, pk1)
|
||||
if len(pk2) > 1 {
|
||||
// Use "IN" predicate instead of list of "OR"
|
||||
@@ -2,7 +2,7 @@
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package sql
|
||||
package sqlgraph
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/facebookincubator/ent/dialect/sql"
|
||||
"github.com/facebookincubator/ent/schema/field"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
@@ -26,7 +27,7 @@ func TestNeighbors(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "O2O/1type",
|
||||
// Since the relation is on the same table,
|
||||
// Since the relation is on the same sql.Table,
|
||||
// V used as a reference value.
|
||||
input: NewStep(
|
||||
From("users", "id", 1),
|
||||
@@ -147,7 +148,7 @@ func TestSetNeighbors(t *testing.T) {
|
||||
{
|
||||
name: "O2M/2types",
|
||||
input: NewStep(
|
||||
From("users", "id", Select().From(Table("users")).Where(EQ("name", "a8m"))),
|
||||
From("users", "id", sql.Select().From(sql.Table("users")).Where(sql.EQ("name", "a8m"))),
|
||||
To("pets", "id"),
|
||||
Edge(O2M, false, "users", "owner_id"),
|
||||
),
|
||||
@@ -157,7 +158,7 @@ func TestSetNeighbors(t *testing.T) {
|
||||
{
|
||||
name: "M2O/2types",
|
||||
input: NewStep(
|
||||
From("pets", "id", Select().From(Table("pets")).Where(EQ("name", "pedro"))),
|
||||
From("pets", "id", sql.Select().From(sql.Table("pets")).Where(sql.EQ("name", "pedro"))),
|
||||
To("users", "id"),
|
||||
Edge(M2O, true, "pets", "owner_id"),
|
||||
),
|
||||
@@ -167,7 +168,7 @@ func TestSetNeighbors(t *testing.T) {
|
||||
{
|
||||
name: "M2M/2types",
|
||||
input: NewStep(
|
||||
From("users", "id", Select().From(Table("users")).Where(EQ("name", "a8m"))),
|
||||
From("users", "id", sql.Select().From(sql.Table("users")).Where(sql.EQ("name", "a8m"))),
|
||||
To("groups", "id"),
|
||||
Edge(M2M, false, "user_groups", "user_id", "group_id"),
|
||||
),
|
||||
@@ -186,7 +187,7 @@ JOIN
|
||||
{
|
||||
name: "M2M/2types/inverse",
|
||||
input: NewStep(
|
||||
From("groups", "id", Select().From(Table("groups")).Where(EQ("name", "GitHub"))),
|
||||
From("groups", "id", sql.Select().From(sql.Table("groups")).Where(sql.EQ("name", "GitHub"))),
|
||||
To("users", "id"),
|
||||
Edge(M2M, true, "user_groups", "user_id", "group_id"),
|
||||
),
|
||||
@@ -218,12 +219,12 @@ func TestHasNeighbors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
step *Step
|
||||
selector *Selector
|
||||
selector *sql.Selector
|
||||
wantQuery string
|
||||
}{
|
||||
{
|
||||
name: "O2O/1type",
|
||||
// A nodes table; linked-list (next->prev). The "prev"
|
||||
// A nodes sql.Table; linked-list (next->prev). The "prev"
|
||||
// node holds association pointer. The neighbors query
|
||||
// here checks if a node "has-next".
|
||||
step: NewStep(
|
||||
@@ -231,7 +232,7 @@ func TestHasNeighbors(t *testing.T) {
|
||||
To("nodes", "id"),
|
||||
Edge(O2O, false, "nodes", "prev_id"),
|
||||
),
|
||||
selector: Select("*").From(Table("nodes")),
|
||||
selector: sql.Select("*").From(sql.Table("nodes")),
|
||||
wantQuery: "SELECT * FROM `nodes` WHERE `nodes`.`id` IN (SELECT `nodes`.`prev_id` FROM `nodes` WHERE `nodes`.`prev_id` IS NOT NULL)",
|
||||
},
|
||||
{
|
||||
@@ -243,7 +244,7 @@ func TestHasNeighbors(t *testing.T) {
|
||||
To("nodes", "id"),
|
||||
Edge(O2O, true, "nodes", "prev_id"),
|
||||
),
|
||||
selector: Select("*").From(Table("nodes")),
|
||||
selector: sql.Select("*").From(sql.Table("nodes")),
|
||||
wantQuery: "SELECT * FROM `nodes` WHERE `nodes`.`prev_id` IS NOT NULL",
|
||||
},
|
||||
{
|
||||
@@ -253,7 +254,7 @@ func TestHasNeighbors(t *testing.T) {
|
||||
To("pets", "id"),
|
||||
Edge(O2M, false, "pets", "owner_id"),
|
||||
),
|
||||
selector: Select("*").From(Table("users")),
|
||||
selector: sql.Select("*").From(sql.Table("users")),
|
||||
wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `pets`.`owner_id` FROM `pets` WHERE `pets`.`owner_id` IS NOT NULL)",
|
||||
},
|
||||
{
|
||||
@@ -263,7 +264,7 @@ func TestHasNeighbors(t *testing.T) {
|
||||
To("users", "id"),
|
||||
Edge(M2O, true, "pets", "owner_id"),
|
||||
),
|
||||
selector: Select("*").From(Table("pets")),
|
||||
selector: sql.Select("*").From(sql.Table("pets")),
|
||||
wantQuery: "SELECT * FROM `pets` WHERE `pets`.`owner_id` IS NOT NULL",
|
||||
},
|
||||
{
|
||||
@@ -273,7 +274,7 @@ func TestHasNeighbors(t *testing.T) {
|
||||
To("groups", "id"),
|
||||
Edge(M2M, false, "user_groups", "user_id", "group_id"),
|
||||
),
|
||||
selector: Select("*").From(Table("users")),
|
||||
selector: sql.Select("*").From(sql.Table("users")),
|
||||
wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `user_groups`.`user_id` FROM `user_groups`)",
|
||||
},
|
||||
{
|
||||
@@ -283,7 +284,7 @@ func TestHasNeighbors(t *testing.T) {
|
||||
To("groups", "id"),
|
||||
Edge(M2M, true, "group_users", "group_id", "user_id"),
|
||||
),
|
||||
selector: Select("*").From(Table("users")),
|
||||
selector: sql.Select("*").From(sql.Table("users")),
|
||||
wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `group_users`.`user_id` FROM `group_users`)",
|
||||
},
|
||||
}
|
||||
@@ -301,8 +302,8 @@ func TestHasNeighborsWith(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
step *Step
|
||||
selector *Selector
|
||||
predicate func(*Selector)
|
||||
selector *sql.Selector
|
||||
predicate func(*sql.Selector)
|
||||
wantQuery string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
@@ -313,9 +314,9 @@ func TestHasNeighborsWith(t *testing.T) {
|
||||
To("cards", "id"),
|
||||
Edge(O2O, false, "cards", "owner_id"),
|
||||
),
|
||||
selector: Dialect("postgres").Select("*").From(Table("users")),
|
||||
predicate: func(s *Selector) {
|
||||
s.Where(EQ("expired", false))
|
||||
selector: sql.Dialect("postgres").Select("*").From(sql.Table("users")),
|
||||
predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.EQ("expired", false))
|
||||
},
|
||||
wantQuery: `SELECT * FROM "users" WHERE "users"."id" IN (SELECT "cards"."owner_id" FROM "cards" WHERE "expired" = $1)`,
|
||||
wantArgs: []interface{}{false},
|
||||
@@ -327,9 +328,9 @@ func TestHasNeighborsWith(t *testing.T) {
|
||||
To("users", "id"),
|
||||
Edge(O2O, true, "cards", "owner_id"),
|
||||
),
|
||||
selector: Dialect("postgres").Select("*").From(Table("cards")),
|
||||
predicate: func(s *Selector) {
|
||||
s.Where(EQ("name", "a8m"))
|
||||
selector: sql.Dialect("postgres").Select("*").From(sql.Table("cards")),
|
||||
predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.EQ("name", "a8m"))
|
||||
},
|
||||
wantQuery: `SELECT * FROM "cards" WHERE "cards"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "name" = $1)`,
|
||||
wantArgs: []interface{}{"a8m"},
|
||||
@@ -341,11 +342,11 @@ func TestHasNeighborsWith(t *testing.T) {
|
||||
To("pets", "id"),
|
||||
Edge(O2M, false, "pets", "owner_id"),
|
||||
),
|
||||
selector: Dialect("postgres").Select("*").
|
||||
From(Table("users")).
|
||||
Where(EQ("last_name", "mashraki")),
|
||||
predicate: func(s *Selector) {
|
||||
s.Where(EQ("name", "pedro"))
|
||||
selector: sql.Dialect("postgres").Select("*").
|
||||
From(sql.Table("users")).
|
||||
Where(sql.EQ("last_name", "mashraki")),
|
||||
predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.EQ("name", "pedro"))
|
||||
},
|
||||
wantQuery: `SELECT * FROM "users" WHERE "last_name" = $1 AND "users"."id" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $2)`,
|
||||
wantArgs: []interface{}{"mashraki", "pedro"},
|
||||
@@ -357,11 +358,11 @@ func TestHasNeighborsWith(t *testing.T) {
|
||||
To("users", "id"),
|
||||
Edge(M2O, true, "pets", "owner_id"),
|
||||
),
|
||||
selector: Dialect("postgres").Select("*").
|
||||
From(Table("pets")).
|
||||
Where(EQ("name", "pedro")),
|
||||
predicate: func(s *Selector) {
|
||||
s.Where(EQ("last_name", "mashraki"))
|
||||
selector: sql.Dialect("postgres").Select("*").
|
||||
From(sql.Table("pets")).
|
||||
Where(sql.EQ("name", "pedro")),
|
||||
predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.EQ("last_name", "mashraki"))
|
||||
},
|
||||
wantQuery: `SELECT * FROM "pets" WHERE "name" = $1 AND "pets"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "last_name" = $2)`,
|
||||
wantArgs: []interface{}{"pedro", "mashraki"},
|
||||
@@ -373,9 +374,9 @@ func TestHasNeighborsWith(t *testing.T) {
|
||||
To("groups", "id"),
|
||||
Edge(M2M, false, "user_groups", "user_id", "group_id"),
|
||||
),
|
||||
selector: Dialect("postgres").Select("*").From(Table("users")),
|
||||
predicate: func(s *Selector) {
|
||||
s.Where(EQ("name", "GitHub"))
|
||||
selector: sql.Dialect("postgres").Select("*").From(sql.Table("users")),
|
||||
predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.EQ("name", "GitHub"))
|
||||
},
|
||||
wantQuery: `
|
||||
SELECT *
|
||||
@@ -393,9 +394,9 @@ WHERE "users"."id" IN
|
||||
To("users", "id"),
|
||||
Edge(M2M, true, "user_groups", "user_id", "group_id"),
|
||||
),
|
||||
selector: Dialect("postgres").Select("*").From(Table("groups")),
|
||||
predicate: func(s *Selector) {
|
||||
s.Where(EQ("name", "a8m"))
|
||||
selector: sql.Dialect("postgres").Select("*").From(sql.Table("groups")),
|
||||
predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.EQ("name", "a8m"))
|
||||
},
|
||||
wantQuery: `
|
||||
SELECT *
|
||||
@@ -413,9 +414,9 @@ WHERE "groups"."id" IN
|
||||
To("users", "id"),
|
||||
Edge(M2M, true, "user_groups", "user_id", "group_id"),
|
||||
),
|
||||
selector: Dialect("postgres").Select("*").From(Table("groups")),
|
||||
predicate: func(s *Selector) {
|
||||
s.Where(And(NotNull("name"), EQ("name", "a8m")))
|
||||
selector: sql.Dialect("postgres").Select("*").From(sql.Table("groups")),
|
||||
predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.And(sql.NotNull("name"), sql.EQ("name", "a8m")))
|
||||
},
|
||||
wantQuery: `
|
||||
SELECT *
|
||||
@@ -734,7 +735,7 @@ func TestCreateNode(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
tt.expect(mock)
|
||||
err = CreateNode(context.Background(), OpenDB("", db), tt.spec)
|
||||
err = CreateNode(context.Background(), sql.OpenDB("", db), tt.spec)
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
})
|
||||
}
|
||||
@@ -751,17 +752,17 @@ type user struct {
|
||||
}
|
||||
|
||||
func (*user) values() []interface{} {
|
||||
return []interface{}{&NullInt64{}, &NullInt64{}, &NullString{}}
|
||||
return []interface{}{&sql.NullInt64{}, &sql.NullInt64{}, &sql.NullString{}}
|
||||
}
|
||||
|
||||
func (u *user) assign(values ...interface{}) error {
|
||||
u.id = int(values[0].(*NullInt64).Int64)
|
||||
u.age = int(values[1].(*NullInt64).Int64)
|
||||
u.name = values[2].(*NullString).String
|
||||
u.id = int(values[0].(*sql.NullInt64).Int64)
|
||||
u.age = int(values[1].(*sql.NullInt64).Int64)
|
||||
u.name = values[2].(*sql.NullString).String
|
||||
// loaded with foreign-keys.
|
||||
if len(values) > 3 {
|
||||
u.edges.fk1 = int(values[3].(*NullInt64).Int64)
|
||||
u.edges.fk2 = int(values[4].(*NullInt64).Int64)
|
||||
u.edges.fk1 = int(values[3].(*sql.NullInt64).Int64)
|
||||
u.edges.fk2 = int(values[4].(*sql.NullInt64).Int64)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -963,7 +964,7 @@ func TestUpdateNode(t *testing.T) {
|
||||
usr := &user{}
|
||||
tt.spec.Assign = usr.assign
|
||||
tt.spec.ScanTypes = usr.values()
|
||||
err = UpdateNode(context.Background(), OpenDB("", db), tt.spec)
|
||||
err = UpdateNode(context.Background(), sql.OpenDB("", db), tt.spec)
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
require.Equal(t, tt.wantUser, usr)
|
||||
})
|
||||
@@ -1020,8 +1021,8 @@ func TestUpdateNodes(t *testing.T) {
|
||||
{Column: "name", Type: field.TypeString},
|
||||
},
|
||||
},
|
||||
Predicate: func(s *Selector) {
|
||||
s.Where(EQ("name", "a8m"))
|
||||
Predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.EQ("name", "a8m"))
|
||||
},
|
||||
},
|
||||
prepare: func(mock sqlmock.Sqlmock) {
|
||||
@@ -1126,7 +1127,7 @@ func TestUpdateNodes(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
tt.prepare(mock)
|
||||
affected, err := UpdateNodes(context.Background(), OpenDB("", db), tt.spec)
|
||||
affected, err := UpdateNodes(context.Background(), sql.OpenDB("", db), tt.spec)
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
require.Equal(t, tt.wantAffected, affected)
|
||||
})
|
||||
@@ -1140,7 +1141,7 @@ func TestDeleteNodes(t *testing.T) {
|
||||
mock.ExpectExec(escape("DELETE FROM `users`")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 2))
|
||||
mock.ExpectCommit()
|
||||
affected, err := DeleteNodes(context.Background(), OpenDB("", db), &DeleteSpec{
|
||||
affected, err := DeleteNodes(context.Background(), sql.OpenDB("", db), &DeleteSpec{
|
||||
Node: &NodeSpec{
|
||||
Table: "users",
|
||||
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
|
||||
@@ -1175,16 +1176,16 @@ func TestQueryNodes(t *testing.T) {
|
||||
Limit: 3,
|
||||
Offset: 4,
|
||||
Unique: true,
|
||||
Order: func(s *Selector) {
|
||||
Order: func(s *sql.Selector) {
|
||||
s.OrderBy("id")
|
||||
},
|
||||
Predicate: func(s *Selector) {
|
||||
s.Where(LT("age", 40))
|
||||
Predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.LT("age", 40))
|
||||
},
|
||||
ScanValues: func() []interface{} {
|
||||
u := &user{}
|
||||
users = append(users, u)
|
||||
return append(u.values(), &NullInt64{}, &NullInt64{}) // extra values for fks.
|
||||
return append(u.values(), &sql.NullInt64{}, &sql.NullInt64{}) // extra values for fks.
|
||||
},
|
||||
Assign: func(values ...interface{}) error {
|
||||
return users[len(users)-1].assign(values...)
|
||||
@@ -1193,14 +1194,14 @@ func TestQueryNodes(t *testing.T) {
|
||||
)
|
||||
|
||||
// Query and scan.
|
||||
err = QueryNodes(context.Background(), OpenDB("", db), spec)
|
||||
err = QueryNodes(context.Background(), sql.OpenDB("", db), spec)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &user{id: 1, age: 10, name: ""}, users[0])
|
||||
require.Equal(t, &user{id: 2, age: 20, name: ""}, users[1])
|
||||
require.Equal(t, &user{id: 3, age: 30, name: "a8m", edges: struct{ fk1, fk2 int }{1, 1}}, users[2])
|
||||
|
||||
// Count nodes.
|
||||
n, err := CountNodes(context.Background(), OpenDB("", db), spec)
|
||||
n, err := CountNodes(context.Background(), sql.OpenDB("", db), spec)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, n)
|
||||
}
|
||||
Reference in New Issue
Block a user