dialect/sql/sqlgraph: add function to order by edge count (#3422)

This commit is contained in:
Ariel Mashraki
2023-03-30 10:10:59 +03:00
committed by GitHub
parent 651a2a166e
commit e3cee0adc2
6 changed files with 331 additions and 12 deletions

View File

@@ -3330,12 +3330,12 @@ func (b *Builder) Quote(ident string) string {
func (b *Builder) Ident(s string) *Builder {
switch {
case len(s) == 0:
case !strings.HasSuffix(s, "*") && !b.isIdent(s) && !isFunc(s) && !isModifier(s):
case !strings.HasSuffix(s, "*") && !b.isIdent(s) && !isFunc(s) && !isModifier(s) && !isAlias(s):
if b.qualifier != "" {
b.WriteString(b.Quote(b.qualifier)).WriteByte('.')
}
b.WriteString(b.Quote(s))
case (isFunc(s) || isModifier(s)) && b.postgres():
case (isFunc(s) || isModifier(s) || isAlias(s)) && b.postgres():
// Modifiers and aggregation functions that
// were called without dialect information.
b.WriteString(strings.ReplaceAll(s, "`", `"`))
@@ -3432,8 +3432,8 @@ func (b *Builder) Err() error {
// An Op represents an operator.
type Op int
// Predicate and arithmetic operators.
const (
// Predicate operators.
OpEQ Op = iota // =
OpNEQ // <>
OpGT // >
@@ -3445,13 +3445,11 @@ const (
OpLike // LIKE
OpIsNull // IS NULL
OpNotNull // IS NOT NULL
// Arithmetic operators.
OpAdd // +
OpSub // -
OpMul // *
OpDiv // / (Quotient)
OpMod // % (Reminder)
OpAdd // +
OpSub // -
OpMul // *
OpDiv // / (Quotient)
OpMod // % (Reminder)
)
var ops = [...]string{
@@ -3713,6 +3711,19 @@ func Dialect(name string) *DialectBuilder {
return &DialectBuilder{name}
}
// String builds a dialect-aware expression string from the given callback.
func (d *DialectBuilder) String(f func(*Builder)) string {
b := &Builder{}
b.SetDialect(d.dialect)
f(b)
return b.String()
}
// Expr builds a dialect-aware expression from the given callback.
func (d *DialectBuilder) Expr(f func(*Builder)) Querier {
return Expr(d.String(f))
}
// Describe creates a DescribeBuilder for the configured dialect.
//
// Dialect(dialect.Postgres).
@@ -3870,6 +3881,10 @@ func (d *DialectBuilder) DropIndex(name string) *DropIndexBuilder {
return b
}
func isAlias(s string) bool {
return strings.Contains(s, " AS ") || strings.Contains(s, " as ")
}
func isFunc(s string) bool {
return strings.Contains(s, "(") && strings.Contains(s, ")")
}

View File

@@ -82,6 +82,8 @@ type Step struct {
Columns []string
// Inverse indicates if the edge is an inverse edge.
Inverse bool
// Name allows giving this edge a name for making queries more readable.
Name string
}
// To is the dest of the path (the neighbors).
To struct {
@@ -287,6 +289,126 @@ func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) {
}
}
type (
// OrderByOptions holds the information needed to order a query by an edge.
OrderByOptions struct {
// Step to get the edge to order by.
Step *Step
// Desc indicates if the ordering should be descending.
// When false, nulls are ordered first. When true, nulls
// are ordered last.
Desc bool
}
// OrderByInfo holds the information done by the OrderBy functions.
OrderByInfo struct {
Terms []OrderByTerm
}
// OrderByTerm holds the terms of an order by clause.
OrderByTerm struct {
Column string // Column name. If empty, an expression is used.
Expr sql.Querier // Expression. If nil, the column is used.
Type field.Type // Term type.
}
)
// countAlias returns the alias to use for the count column.
func countAlias(q *sql.Selector, s *Step) string {
eName := s.Edge.Name
if eName == "" {
eName = s.To.Table
}
selected := make(map[string]struct{})
for _, c := range q.SelectedColumns() {
selected[c] = struct{}{}
}
column := fmt.Sprintf("count_%s", eName)
// If the column was already selected,
// try to find a free alias.
if _, ok := selected[column]; ok {
for i := 1; i <= 5; i++ {
ci := fmt.Sprintf("%s_%d", column, i)
if _, ok := selected[ci]; !ok {
return ci
}
}
}
return column
}
// OrderByCountNeighbors appends ordering based on the number of neighbors.
// For example, order users by their number of posts.
// HasNeighbors applies on the given Selector a neighbors check.
func OrderByCountNeighbors(q *sql.Selector, opts *OrderByOptions) *OrderByInfo {
var (
countC string
build = sql.Dialect(q.Dialect())
)
switch s, r := opts.Step, opts.Step.Edge.Rel; {
case r == M2O || (r == O2O && s.Edge.Inverse):
// For M2O and O2O inverse, the FK resides in the same table.
// Hence, the order by is on the nullability of the column.
x := func(b *sql.Builder) {
b.Ident(s.From.Column)
if opts.Desc {
b.WriteOp(sql.OpNotNull)
} else {
b.WriteOp(sql.OpIsNull)
}
}
q.OrderExpr(build.Expr(x))
return &OrderByInfo{
Terms: []OrderByTerm{
{Expr: build.Expr(x), Type: field.TypeBool},
},
}
case r == M2M:
countC = countAlias(q, s)
pk1 := s.Edge.Columns[0]
if s.Edge.Inverse {
pk1 = s.Edge.Columns[1]
}
joinT := build.Table(s.Edge.Table).Schema(s.Edge.Schema)
to := build.Select(
joinT.C(pk1),
build.String(func(b *sql.Builder) {
b.WriteString("COUNT(*) AS ").Ident(countC)
}),
).From(joinT).GroupBy(joinT.C(pk1))
q.LeftJoin(to).
On(
q.C(s.From.Column),
to.C(pk1),
)
case r == O2M || (r == O2O && !s.Edge.Inverse):
countC = countAlias(q, s)
edgeT := build.Table(s.Edge.Table).Schema(s.Edge.Schema)
to := build.Select(
edgeT.C(s.Edge.Columns[0]),
build.String(func(b *sql.Builder) {
b.WriteString("COUNT(*) AS ").Ident(countC)
}),
).From(edgeT).GroupBy(edgeT.C(s.Edge.Columns[0]))
q.LeftJoin(to).
On(
q.C(s.From.Column),
to.C(s.Edge.Columns[0]),
)
}
q.OrderExpr(
build.Expr(func(b *sql.Builder) {
b.WriteString("COALESCE(").Ident(countC).WriteString(", 0)")
if opts.Desc {
b.WriteString(" DESC")
}
}),
)
return &OrderByInfo{
Terms: []OrderByTerm{
{Column: countC, Type: field.TypeInt},
},
}
}
type (
// FieldSpec holds the information for updating a field
// column in the database.

View File

@@ -911,6 +911,66 @@ func TestHasNeighborsWithContext(t *testing.T) {
}
}
func TestOrderByCountNeighbors(t *testing.T) {
build := sql.Dialect(dialect.Postgres)
t1 := build.Table("users")
s := build.Select(t1.C("name")).
From(t1)
t.Run("O2M", func(t *testing.T) {
s := s.Clone()
OrderByCountNeighbors(s, &OrderByOptions{
Step: NewStep(
From("users", "id"),
To("pets", "owner_id"),
Edge(O2M, false, "pets", "owner_id"),
),
Desc: true,
})
query, args := s.Query()
require.Empty(t, args)
require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "pets"."owner_id", COUNT(*) AS "count_pets" FROM "pets" GROUP BY "pets"."owner_id") AS "t1" ON "users"."id" = "t1"."owner_id" ORDER BY COALESCE("count_pets", 0) DESC`, query)
})
t.Run("M2M", func(t *testing.T) {
s := s.Clone()
OrderByCountNeighbors(s, &OrderByOptions{
Step: NewStep(
From("users", "id"),
To("groups", "id"),
Edge(M2M, false, "user_groups", "user_id", "group_id"),
),
})
query, args := s.Query()
require.Empty(t, args)
require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "user_groups"."user_id", COUNT(*) AS "count_groups" FROM "user_groups" GROUP BY "user_groups"."user_id") AS "t1" ON "users"."id" = "t1"."user_id" ORDER BY COALESCE("count_groups", 0)`, query)
})
// Zero or one.
t.Run("M2O", func(t *testing.T) {
s1, s2 := s.Clone(), s.Clone()
OrderByCountNeighbors(s1, &OrderByOptions{
Step: NewStep(
From("pets", "owner_id"),
To("users", "id"),
Edge(M2O, true, "pets", "owner_id"),
),
})
query, args := s1.Query()
require.Empty(t, args)
require.Equal(t, `SELECT "users"."name" FROM "users" ORDER BY "owner_id" IS NULL`, query)
OrderByCountNeighbors(s2, &OrderByOptions{
Step: NewStep(
From("pets", "owner_id"),
To("users", "id"),
Edge(M2O, true, "pets", "owner_id"),
),
Desc: true,
})
query, args = s2.Query()
require.Empty(t, args)
require.Equal(t, `SELECT "users"."name" FROM "users" ORDER BY "owner_id" IS NOT NULL`, query)
})
}
func TestCreateNode(t *testing.T) {
tests := []struct {
name string

View File

@@ -85,7 +85,7 @@ var drivers = []*Storage{
},
}
// NewStorage returns a the storage driver type from the given string.
// NewStorage returns the storage driver type from the given string.
// It fails if the provided string is not a valid option. this function
// is here in order to remove the validation logic from entc command line.
func NewStorage(s string) (*Storage, error) {

View File

@@ -36,4 +36,4 @@ import (
}
{{- end }}
{{ end }}
{{ end }}
{{ end }}

View File

@@ -166,6 +166,7 @@ var (
ConstraintChecks,
NillableRequired,
ExtValueScan,
OrderByEdgeCount,
}
)
@@ -2361,6 +2362,127 @@ func ExtValueScan(t *testing.T, client *ent.Client) {
require.False(t, client.ExValueScan.Query().Where(exvaluescan.CustomHasSuffix("io")).ExistX(ctx))
}
// Testing the "low-level" behavior of the sqlgraph package.
// This functionality may be extended to the generated fluent API.
func OrderByEdgeCount(t *testing.T, client *ent.Client) {
ctx := context.Background()
users := client.User.CreateBulk(
client.User.Create().SetName("a").SetAge(1),
client.User.Create().SetName("b").SetAge(2),
client.User.Create().SetName("c").SetAge(3),
client.User.Create().SetName("d").SetAge(4),
).SaveX(ctx)
pets := client.Pet.CreateBulk(
client.Pet.Create().SetName("aa").SetOwner(users[0]),
client.Pet.Create().SetName("ab").SetOwner(users[0]),
client.Pet.Create().SetName("ac").SetOwner(users[0]),
client.Pet.Create().SetName("ba").SetOwner(users[1]),
client.Pet.Create().SetName("bb").SetOwner(users[1]),
client.Pet.Create().SetName("ca").SetOwner(users[2]),
client.Pet.Create().SetName("d"),
client.Pet.Create().SetName("e"),
).SaveX(ctx)
// O2M edge.
for _, tt := range []struct {
desc bool
ids []int
}{
{desc: true, ids: []int{users[0].ID, users[1].ID, users[2].ID, users[3].ID}},
{desc: false, ids: []int{users[3].ID, users[2].ID, users[1].ID, users[0].ID}},
} {
ids := client.User.Query().
Order(func(s *sql.Selector) {
sqlgraph.OrderByCountNeighbors(s, &sqlgraph.OrderByOptions{
Desc: tt.desc,
Step: sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID),
sqlgraph.To(pet.Table, pet.OwnerColumn),
sqlgraph.Edge(sqlgraph.O2M, false, pet.Table, pet.OwnerColumn),
),
})
}).
IDsX(ctx)
require.Equal(t, tt.ids, ids)
}
// M2O edge (true or false).
for _, tt := range []struct {
desc bool
ids []int
}{
{desc: true, ids: []int{pets[6].ID, pets[7].ID, pets[0].ID, pets[1].ID, pets[2].ID, pets[3].ID, pets[4].ID, pets[5].ID}},
{desc: false, ids: []int{pets[0].ID, pets[1].ID, pets[2].ID, pets[3].ID, pets[4].ID, pets[5].ID, pets[6].ID, pets[7].ID}},
} {
ids := client.Pet.Query().
Order(
func(s *sql.Selector) {
sqlgraph.OrderByCountNeighbors(s, &sqlgraph.OrderByOptions{
Desc: tt.desc,
Step: sqlgraph.NewStep(
sqlgraph.From(pet.Table, pet.OwnerColumn),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, pet.Table, pet.OwnerColumn),
),
})
},
ent.Asc(pet.FieldID),
).
IDsX(ctx)
require.Equal(t, tt.ids, ids)
}
inf, exp := client.GroupInfo.Create().SetDesc("desc").SaveX(ctx), time.Now()
groups := client.Group.CreateBulk(
client.Group.Create().SetName("Group: 4 users").SetExpire(exp).SetInfo(inf).AddUsers(users...),
client.Group.Create().SetName("Group: 3 users").SetExpire(exp).SetInfo(inf).AddUsers(users[:3]...),
client.Group.Create().SetName("Group: 2 users").SetExpire(exp).SetInfo(inf).AddUsers(users[:2]...),
client.Group.Create().SetName("Group: 1 users").SetExpire(exp).SetInfo(inf).AddUsers(users[:1]...),
client.Group.Create().SetName("Group: 0 users").SetExpire(exp).SetInfo(inf),
).SaveX(ctx)
// M2M edge (inverse).
for _, tt := range []struct {
desc bool
ids []int
}{
{desc: true, ids: []int{groups[0].ID, groups[1].ID, groups[2].ID, groups[3].ID, groups[4].ID}},
{desc: false, ids: []int{groups[4].ID, groups[3].ID, groups[2].ID, groups[1].ID, groups[0].ID}},
} {
ids := client.Group.Query().
Order(func(s *sql.Selector) {
sqlgraph.OrderByCountNeighbors(s, &sqlgraph.OrderByOptions{
Desc: tt.desc,
Step: sqlgraph.NewStep(
sqlgraph.From(group.Table, group.FieldID),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2M, true, group.UsersTable, group.UsersPrimaryKey...),
),
})
}).
IDsX(ctx)
require.Equal(t, tt.ids, ids)
}
// M2M edge (assoc).
for _, tt := range []struct {
desc bool
ids []int
}{
{desc: true, ids: []int{users[0].ID, users[1].ID, users[2].ID, users[3].ID}},
{desc: false, ids: []int{users[3].ID, users[2].ID, users[1].ID, users[0].ID}},
} {
ids := client.User.Query().
Order(func(s *sql.Selector) {
sqlgraph.OrderByCountNeighbors(s, &sqlgraph.OrderByOptions{
Desc: tt.desc,
Step: sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID),
sqlgraph.To(group.Table, group.FieldID),
sqlgraph.Edge(sqlgraph.M2M, false, user.GroupsTable, user.GroupsPrimaryKey...),
),
})
}).
IDsX(ctx)
require.Equal(t, tt.ids, ids)
}
}
func skip(t *testing.T, names ...string) {
for _, n := range names {
if strings.Contains(t.Name(), n) {