entc/gen: allow scanning dynamic sql values (#3432)

This commit is contained in:
Ariel Mashraki
2023-04-03 17:07:24 +03:00
committed by GitHub
parent 6065db39fa
commit ba7f158a9c
135 changed files with 1740 additions and 325 deletions

View File

@@ -309,9 +309,6 @@ type (
OrderByOptions struct {
// Step to get the edge to order by its count.
Step *Step
// Desc indicates if the ordering should be in descending order. When false, NULL values
// are ordered first (default in MySQL and SQLite). When true, NULLs are ordered last.
Desc bool
// Terms used for non-aggregation ordering.
// See, OrderByNeighborTerms for more info.
Terms []OrderByTerm
@@ -332,11 +329,17 @@ type (
OrderByOption func(*OrderByOptions)
)
// OrderDesc sets the order to be descending order.
// This option is valid only for a single count order.
// OrderDesc sets the latest order by term as descending order,
// or add a new descending order term if no terms are present.
func OrderDesc() OrderByOption {
return func(opts *OrderByOptions) {
opts.Desc = true
if len(opts.Terms) > 0 {
opts.Terms[len(opts.Terms)-1].Desc = true
} else {
opts.Terms = append(opts.Terms, OrderByTerm{
Desc: true,
})
}
}
}
@@ -402,12 +405,15 @@ func NewOrderBy(s *Step, opts ...OrderByOption) *OrderByOptions {
}
// countAlias returns the alias to use for the count column.
func countAlias(q *sql.Selector, s *Step) string {
func countAlias(q *sql.Selector, opts *OrderByOptions) string {
if len(opts.Terms) == 1 && opts.Terms[0].As != "" {
return opts.Terms[0].As
}
selected := make(map[string]struct{})
for _, c := range q.SelectedColumns() {
selected[c] = struct{}{}
}
column := fmt.Sprintf("count_%s", s.To.Table)
column := fmt.Sprintf("count_%s", opts.Step.To.Table)
// If the column was already selected,
// try to find a free alias.
if _, ok := selected[column]; ok {
@@ -423,32 +429,30 @@ func countAlias(q *sql.Selector, s *Step) string {
// OrderByNeighborsCount appends ordering based on the number of neighbors.
// For example, order users by their number of posts.
func OrderByNeighborsCount(q *sql.Selector, opts *OrderByOptions) *OrderByInfo {
func OrderByNeighborsCount(q *sql.Selector, opts *OrderByOptions) {
var (
countC string
join *sql.Selector
build = sql.Dialect(q.Dialect())
desc bool
join *sql.Selector
build = sql.Dialect(q.Dialect())
)
if len(opts.Terms) == 1 {
desc = opts.Terms[0].Desc
}
switch s := opts.Step; {
case s.FromEdgeOwner():
// For M2O and O2O inverse, the FK resides in the same table.
// Hence, the order by is on the nullability of the column.
x := func(b *sql.Builder) {
b.Ident(s.From.Column)
if opts.Desc {
if desc {
b.WriteOp(sql.OpNotNull)
} else {
b.WriteOp(sql.OpIsNull)
}
}
q.OrderExpr(build.Expr(x))
return &OrderByInfo{
Terms: []OrderByTerm{
{Expr: build.Expr(x), Type: field.TypeBool},
},
}
case s.ThroughEdgeTable():
countC = countAlias(q, s)
countC := countAlias(q, opts)
pk1 := s.Edge.Columns[0]
if s.Edge.Inverse {
pk1 = s.Edge.Columns[1]
@@ -465,8 +469,11 @@ func OrderByNeighborsCount(q *sql.Selector, opts *OrderByOptions) *OrderByInfo {
q.C(s.From.Column),
join.C(pk1),
)
orderTerms(q, join, []OrderByTerm{
{Column: countC, Type: field.TypeInt, Desc: desc},
})
case s.ToEdgeOwner():
countC = countAlias(q, s)
countC := countAlias(q, opts)
edgeT := build.Table(s.Edge.Table).Schema(s.Edge.Schema)
join = build.Select(
edgeT.C(s.Edge.Columns[0]),
@@ -479,12 +486,10 @@ func OrderByNeighborsCount(q *sql.Selector, opts *OrderByOptions) *OrderByInfo {
q.C(s.From.Column),
join.C(s.Edge.Columns[0]),
)
orderTerms(q, join, []OrderByTerm{
{Column: countC, Type: field.TypeInt, Desc: desc},
})
}
terms := []OrderByTerm{
{Column: countC, Type: field.TypeInt, Desc: opts.Desc},
}
orderTerms(q, join, terms)
return &OrderByInfo{Terms: terms}
}
func orderTerms(q, join *sql.Selector, ts []OrderByTerm) {
@@ -961,6 +966,11 @@ func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
if err != nil {
return err
}
for i, v := range values {
if _, ok := v.(*sql.UnknownType); ok {
values[i] = sql.ScanTypeOf(rows, i)
}
}
if err := rows.Scan(values...); err != nil {
return err
}
@@ -1303,6 +1313,11 @@ func (u *updater) scan(rows *sql.Rows) error {
if err != nil {
return err
}
for i, v := range values {
if _, ok := v.(*sql.UnknownType); ok {
values[i] = sql.ScanTypeOf(rows, i)
}
}
if err := rows.Scan(values...); err != nil {
return fmt.Errorf("failed scanning rows: %w", err)
}