mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
89 lines
2.5 KiB
Cheetah
89 lines
2.5 KiB
Cheetah
{{ define "dialect/sql/query" }}
|
|
{{/* $pkg is the package name taken from the template scope; the generated error messages below use it as their prefix. */}}
{{ $pkg := $.Scope.Package }}
|
|
{{/* $builder is the scope's builder name passed through the pascal helper; it is the type every method below is declared on. */}}
{{ $builder := pascal $.Scope.Builder }}
|
|
{{/* $receiver is the method receiver identifier derived from $builder by the receiver helper. */}}
{{ $receiver := receiver $builder }}
|
|
|
|
// sqlAll builds the statement with sqlQuery, runs it through the driver and
// loads the returned rows into a {{ plural $.Name }} value via FromRows. The
// builder's config is attached to the result before it is returned.
func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name }}, error) {
	{{- $nodes := plural $.Receiver }}
	sel := {{ $receiver }}.sqlQuery()
	// No unique columns are set on the builder, so Distinct is applied to the selector.
	if len({{ $receiver }}.unique) == 0 {
		sel.Distinct()
	}
	var (
		stmt, params = sel.Query()
		rows         = &sql.Rows{}
	)
	err := {{ $receiver }}.driver.Query(ctx, stmt, params, rows)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var {{ $nodes }} {{ plural $.Name }}
	if err = {{ $nodes }}.FromRows(rows); err != nil {
		return nil, err
	}
	{{ $nodes }}.config({{ $receiver }}.config)
	return {{ $nodes }}, nil
}
|
|
|
|
// sqlCount runs a COUNT over the selector built by sqlQuery and returns the
// scanned value. The count is taken over the distinct values of the builder's
// unique columns, falling back to the ID column when none are set.
//
// It returns an error when the driver query fails, when row iteration fails,
// when the result set is empty, or when the count cannot be scanned.
func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
	rows := &sql.Rows{}
	selector := {{ $receiver }}.sqlQuery()
	// Default to counting distinct IDs unless the builder specifies its own columns.
	unique := []string{ {{ $.Package }}.{{ $.ID.Constant }} }
	if len({{ $receiver }}.unique) > 0 {
		unique = {{ $receiver }}.unique
	}
	selector.Count(sql.Distinct(selector.Columns(unique...)...))
	query, args := selector.Query()
	if err := {{ $receiver }}.driver.Query(ctx, query, args, rows); err != nil {
		return 0, err
	}
	defer rows.Close()
	if !rows.Next() {
		// In database/sql, Next returns false both at the end of the result set
		// and when iteration fails; Err tells the two apart, so a failure is not
		// misreported as an empty result.
		// NOTE(review): relies on the dialect's sql.Rows exposing Err like
		// database/sql.Rows does - confirm.
		if err := rows.Err(); err != nil {
			return 0, fmt.Errorf("{{ $pkg }}: failed reading count: %v", err)
		}
		return 0, errors.New("{{ $pkg }}: no rows found")
	}
	var n int
	if err := rows.Scan(&n); err != nil {
		return 0, fmt.Errorf("{{ $pkg }}: failed reading count: %v", err)
	}
	return n, nil
}
|
|
|
|
// sqlExist reports whether sqlCount returns a value greater than zero. A
// failure from sqlCount is returned wrapped with a "check existence" prefix.
func ({{ $receiver }} *{{ $builder }}) sqlExist(ctx context.Context) (bool, error) {
	switch count, err := {{ $receiver }}.sqlCount(ctx); {
	case err != nil:
		return false, fmt.Errorf("{{ $pkg }}: check existence: %v", err)
	default:
		return count > 0, nil
	}
}
|
|
|
|
// sqlIDs loads the matching entities through sqlAll and returns their ID
// fields in the same order. The returned slice is nil when sqlAll yields no
// entities.
func ({{ $receiver }} *{{ $builder }}) sqlIDs(ctx context.Context) ([]{{ $.ID.Type }}, error) {
	nodes, err := {{ $receiver }}.sqlAll(ctx)
	if err != nil {
		return nil, err
	}
	var ids []{{ $.ID.Type }}
	for i := range nodes {
		ids = append(ids, nodes[i].ID)
	}
	return ids, nil
}
|
|
|
|
|
|
// sqlQuery assembles the selector for this builder. It starts from a plain
// SELECT of the entity columns from the entity table, or from the selector
// stored on the builder when one is set, then applies the predicates, the
// ordering functions and the optional limit, in that order.
//
// NOTE(review): a selector stored on the builder is reused and modified in
// place rather than copied - confirm callers never build from the same
// builder twice.
func ({{ $receiver }} *{{ $builder }}) sqlQuery() *sql.Selector {
	tbl := sql.Table({{ $.Package }}.Table)
	sel := sql.Select(tbl.Columns({{ $.Package }}.Columns...)...).From(tbl)
	if prev := {{ $receiver }}.sql; prev != nil {
		// Re-point the stored selector at the entity columns and continue from it.
		prev.Select(prev.Columns({{ $.Package }}.Columns...)...)
		sel = prev
	}
	for _, pred := range {{ $receiver }}.predicates {
		pred.SQL(sel)
	}
	for _, ord := range {{ $receiver }}.order {
		ord.SQL(sel)
	}
	if {{ $receiver }}.limit != nil {
		sel.Limit(*{{ $receiver }}.limit)
	}
	return sel
}
|
|
|
|
{{ end }} |