diff --git a/ent.go b/ent.go index f9f939377..a0eb0b28f 100644 --- a/ent.go +++ b/ent.go @@ -473,8 +473,21 @@ type ( // QueryContext contains additional information about // the context in which the query is executed. QueryContext struct { - Op string // operation name - Type string // type name + // Op defines the operation name. e.g., First, All, Count, etc. + Op string + // Type defines the query type as defined in the generated code. + Type string + // Unique indicates if the Unique modifier was set on the query and + // its value. Calling Unique(false) sets the value of Unique to false. + Unique *bool + // Limit indicates if the Limit modifier was set on the query and + // its value. Calling Limit(10) sets the value of Limit to 10. + Limit *int + // Offset indicates if the Offset modifier was set on the query and + // its value. Calling Offset(10) sets the value of Offset to 10. + Offset *int + // Fields specifies the fields that were selected in the query. + Fields []string } queryCtxKey struct{} ) @@ -489,3 +502,25 @@ func QueryFromContext(ctx context.Context) *QueryContext { c, _ := ctx.Value(queryCtxKey{}).(*QueryContext) return c } + +// Clone returns a deep copy of the query context. +func (q *QueryContext) Clone() *QueryContext { + c := &QueryContext{ + Op: q.Op, + Type: q.Type, + Fields: append([]string(nil), q.Fields...), + } + if q.Unique != nil { + v := *q.Unique + c.Unique = &v + } + if q.Limit != nil { + v := *q.Limit + c.Limit = &v + } + if q.Offset != nil { + v := *q.Offset + c.Offset = &v + } + return c +} diff --git a/entc/gen/template/base.tmpl b/entc/gen/template/base.tmpl index aa36f24c7..61085ad0b 100644 --- a/entc/gen/template/base.tmpl +++ b/entc/gen/template/base.tmpl @@ -26,6 +26,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -305,10 +306,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type:typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/gen/template/builder/query.tmpl b/entc/gen/template/builder/query.tmpl index 06bdcf8fa..01f7062c6 100644 --- a/entc/gen/template/builder/query.tmpl +++ b/entc/gen/template/builder/query.tmpl @@ -25,11 +25,8 @@ import ( // {{ $builder }} is the builder for querying {{ $.Name }} entities. type {{ $builder }} struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.{{ $.Name }} {{- /* Eager loading fields. */}} @@ -54,20 +51,20 @@ func ({{ $receiver }} *{{ $builder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $ // Limit the number of records to be returned by this query. func ({{ $receiver }} *{{ $builder }}) Limit(limit int) *{{ $builder }} { - {{ $receiver }}.limit = &limit + {{ $receiver }}.ctx.Limit = &limit return {{ $receiver }} } // Offset to start from. func ({{ $receiver }} *{{ $builder }}) Offset(offset int) *{{ $builder }} { - {{ $receiver }}.offset = &offset + {{ $receiver }}.ctx.Offset = &offset return {{ $receiver }} } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func ({{ $receiver }} *{{ $builder }}) Unique(unique bool) *{{ $builder }} { - {{ $receiver }}.unique = &unique + {{ $receiver }}.ctx.Unique = &unique return {{ $receiver }} } @@ -100,7 +97,7 @@ func ({{ $receiver }} *{{ $builder }}) Order(o ...OrderFunc) *{{ $builder }} { // First returns the first {{ $.Name }} entity from the query. // Returns a *NotFoundError when no {{ $.Name }} was found. func ({{ $receiver }} *{{ $builder }}) First(ctx context.Context) (*{{ $.Name }}, error) { - nodes, err := {{ $receiver }}.Limit(1).All(newQueryContext(ctx, {{ $.TypeName }}, "First")) + nodes, err := {{ $receiver }}.Limit(1).All(setContextOp(ctx, {{ $receiver }}.ctx, "First")) if err != nil { return nil, err } @@ -124,7 +121,7 @@ func ({{ $receiver }} *{{ $builder }}) FirstX(ctx context.Context) *{{ $.Name }} // Returns a *NotFoundError when no {{ $.Name }} ID was found. func ({{ $receiver }} *{{ $builder }}) FirstID(ctx context.Context) (id {{ $.ID.Type }}, err error) { var ids []{{ $.ID.Type }} - if ids, err = {{ $receiver }}.Limit(1).IDs(newQueryContext(ctx, {{ $.TypeName }}, "FirstID")); err != nil { + if ids, err = {{ $receiver }}.Limit(1).IDs(setContextOp(ctx, {{ $receiver }}.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -133,7 +130,7 @@ func ({{ $receiver }} *{{ $builder }}) FirstX(ctx context.Context) *{{ $.Name }} } return ids[0], nil } - + // FirstIDX is like FirstID, but panics if an error occurs. func ({{ $receiver }} *{{ $builder }}) FirstIDX(ctx context.Context) {{ $.ID.Type }} { id, err := {{ $receiver }}.FirstID(ctx) @@ -148,7 +145,7 @@ func ({{ $receiver }} *{{ $builder }}) FirstX(ctx context.Context) *{{ $.Name }} // Returns a *NotSingularError when more than one {{ $.Name }} entity is found. // Returns a *NotFoundError when no {{ $.Name }} entities are found. func ({{ $receiver }} *{{ $builder }}) Only(ctx context.Context) (*{{ $.Name }}, error) { - nodes, err := {{ $receiver }}.Limit(2).All(newQueryContext(ctx, {{ $.TypeName }}, "Only")) + nodes, err := {{ $receiver }}.Limit(2).All(setContextOp(ctx, {{ $receiver }}.ctx, "Only")) if err != nil { return nil, err } @@ -177,7 +174,7 @@ func ({{ $receiver }} *{{ $builder }}) OnlyX(ctx context.Context) *{{ $.Name }} // Returns a *NotFoundError when no entities are found. func ({{ $receiver }} *{{ $builder }}) OnlyID(ctx context.Context) (id {{ $.ID.Type }}, err error) { var ids []{{ $.ID.Type }} - if ids, err = {{ $receiver }}.Limit(2).IDs(newQueryContext(ctx, {{ $.TypeName }}, "OnlyID")); err != nil { + if ids, err = {{ $receiver }}.Limit(2).IDs(setContextOp(ctx, {{ $receiver }}.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -190,7 +187,7 @@ func ({{ $receiver }} *{{ $builder }}) OnlyX(ctx context.Context) *{{ $.Name }} } return } - + // OnlyIDX is like OnlyID, but panics if an error occurs. func ({{ $receiver }} *{{ $builder }}) OnlyIDX(ctx context.Context) {{ $.ID.Type }} { id, err := {{ $receiver }}.OnlyID(ctx) @@ -203,7 +200,7 @@ func ({{ $receiver }} *{{ $builder }}) OnlyX(ctx context.Context) *{{ $.Name }} // All executes the query and returns a list of {{ plural $.Name }}. func ({{ $receiver }} *{{ $builder }}) All(ctx context.Context) ([]*{{ $.Name }}, error) { - ctx = newQueryContext(ctx, {{ $.TypeName }}, "All") + ctx = setContextOp(ctx, {{ $receiver }}.ctx, "All") if err := {{ $receiver }}.prepareQuery(ctx); err != nil { return nil, err } @@ -224,13 +221,13 @@ func ({{ $receiver }} *{{ $builder }}) AllX(ctx context.Context) []*{{ $.Name }} // IDs executes the query and returns a list of {{ $.Name }} IDs. func ({{ $receiver }} *{{ $builder }}) IDs(ctx context.Context) ([]{{ $.ID.Type }}, error) { var ids []{{ $.ID.Type }} - ctx = newQueryContext(ctx, {{ $.TypeName }}, "IDs") + ctx = setContextOp(ctx, {{ $receiver }}.ctx, "IDs") if err := {{ $receiver }}.Select({{ $.Package }}.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil } - + // IDsX is like IDs, but panics if an error occurs. func ({{ $receiver }} *{{ $builder }}) IDsX(ctx context.Context) []{{ $.ID.Type }} { ids, err := {{ $receiver }}.IDs(ctx) @@ -243,7 +240,7 @@ func ({{ $receiver }} *{{ $builder }}) AllX(ctx context.Context) []*{{ $.Name }} // Count returns the count of the given query. func ({{ $receiver }} *{{ $builder }}) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, {{ $.TypeName }}, "Count") + ctx = setContextOp(ctx, {{ $receiver }}.ctx, "Count") if err := {{ $receiver }}.prepareQuery(ctx); err != nil { return 0, err } @@ -261,7 +258,7 @@ func ({{ $receiver }} *{{ $builder }}) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func ({{ $receiver }} *{{ $builder }}) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, {{ $.TypeName }}, "Exist") + ctx = setContextOp(ctx, {{ $receiver }}.ctx, "Exist") switch _, err := {{ $receiver }}.First{{ if $.HasOneFieldID }}ID{{ end }}(ctx);{ case IsNotFound(err): return false, nil @@ -289,8 +286,7 @@ func ({{ $receiver }} *{{ $builder }}) Clone() *{{ $builder }} { } return &{{ $builder }}{ config: {{ $receiver }}.config, - limit: {{ $receiver }}.limit, - offset: {{ $receiver }}.offset, + ctx: {{ $receiver }}.ctx.Clone(), order: append([]OrderFunc{}, {{ $receiver }}.order...), inters: append([]Interceptor{}, {{ $receiver }}.inters...), predicates: append([]predicate.{{ $.Name }}{}, {{ $receiver }}.predicates...), @@ -300,7 +296,6 @@ func ({{ $receiver }} *{{ $builder }}) Clone() *{{ $builder }} { // clone intermediate query. {{ $.Storage }}: {{ $receiver }}.{{ $.Storage }}.Clone(), path: {{ $receiver }}.path, - unique: {{ $receiver }}.unique, } } @@ -340,9 +335,9 @@ func ({{ $receiver }} *{{ $builder }}) Clone() *{{ $builder }} { // {{- end }} func ({{ $receiver }} *{{ $builder }}) GroupBy(field string, fields ...string) *{{ $groupBuilder }} { - {{ $receiver }}.fields = append([]string{field}, fields...) + {{ $receiver }}.ctx.Fields = append([]string{field}, fields...) grbuild := &{{ $groupBuilder }}{build: {{ $receiver }}} - grbuild.flds = &{{ $receiver }}.fields + grbuild.flds = &{{ $receiver }}.ctx.Fields grbuild.label = {{ $.Package }}.Label grbuild.scan = grbuild.Scan return grbuild @@ -367,10 +362,10 @@ func ({{ $receiver }} *{{ $builder }}) GroupBy(field string, fields ...string) * // {{- end }} func ({{ $receiver }} *{{ $builder }}) Select(fields ...string) *{{ $selectBuilder }} { - {{ $receiver }}.fields = append({{ $receiver }}.fields, fields...) + {{ $receiver }}.ctx.Fields = append({{ $receiver }}.ctx.Fields, fields...) sbuild := &{{ $selectBuilder }}{ {{ $builder }}: {{ $receiver }} } sbuild.label = {{ $.Package }}.Label - sbuild.flds, sbuild.scan = &{{ $receiver }}.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &{{ $receiver }}.ctx.Fields, sbuild.Scan return sbuild } @@ -447,7 +442,7 @@ func ({{ $groupReceiver }} *{{ $groupBuilder }}) Aggregate(fns ...AggregateFunc) // Scan applies the selector query and scans the result into the given value. func ({{ $groupReceiver }} *{{ $groupBuilder }}) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, {{ $.TypeName }}, "GroupBy") + ctx = setContextOp(ctx, {{ $groupReceiver }}.build.ctx, "GroupBy") if err := {{ $groupReceiver }}.build.prepareQuery(ctx); err != nil { return err } @@ -477,7 +472,7 @@ func ({{ $selectReceiver }} *{{ $selectBuilder }}) Aggregate(fns ...AggregateFun // Scan applies the selector query and scans the result into the given value. func ({{ $selectReceiver }} *{{ $selectBuilder }}) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, {{ $.TypeName }}, "Select") + ctx = setContextOp(ctx, {{ $selectReceiver }}.ctx, "Select") if err := {{ $selectReceiver }}.prepareQuery(ctx); err != nil { return err } diff --git a/entc/gen/template/client.tmpl b/entc/gen/template/client.tmpl index 1809879c3..56f2f1670 100644 --- a/entc/gen/template/client.tmpl +++ b/entc/gen/template/client.tmpl @@ -253,6 +253,7 @@ func (c *{{ $client }}) Delete() *{{ $n.DeleteName }} { func (c *{{ $client }}) Query() *{{ $n.QueryName }} { return &{{ $n.QueryName }}{ config: c.config, + ctx: &QueryContext{Type: {{ $n.TypeName }} }, inters: c.Interceptors(), {{- with $tmpls := matchTemplate (printf "dialect/%s/query/fields/init/*" $.Storage) }} {{- range $tmpl := $tmpls }} diff --git a/entc/gen/template/dialect/gremlin/query.tmpl b/entc/gen/template/dialect/gremlin/query.tmpl index 2dc7e1d69..9bf82b9c2 100644 --- a/entc/gen/template/dialect/gremlin/query.tmpl +++ b/entc/gen/template/dialect/gremlin/query.tmpl @@ -13,9 +13,9 @@ in the LICENSE file in the root directory of this source tree. func ({{ $receiver }} *{{ $builder }}) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) { res := &gremlin.Response{} traversal := {{ $receiver }}.gremlinQuery(ctx) - if len({{ $receiver }}.fields) > 0 { - fields := make([]any, len({{ $receiver }}.fields)) - for i, f := range {{ $receiver }}.fields { + if len({{ $receiver }}.ctx.Fields) > 0 { + fields := make([]any, len({{ $receiver }}.ctx.Fields)) + for i, f := range {{ $receiver }}.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -57,7 +57,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlinQuery(context.Context) *dsl.Traver p(v) } } - switch limit, offset := {{ $receiver }}.limit, {{ $receiver }}.offset; { + switch limit, offset := {{ $receiver }}.ctx.Limit, {{ $receiver }}.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset + *limit) case offset != nil: @@ -65,7 +65,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlinQuery(context.Context) *dsl.Traver case limit != nil: v.Limit(*limit) } - if unique := {{ $receiver }}.unique; unique == nil || *unique { + if unique := {{ $receiver }}.ctx.Unique; unique == nil || *unique { v.Dedup() } return v diff --git a/entc/gen/template/dialect/gremlin/select.tmpl b/entc/gen/template/dialect/gremlin/select.tmpl index da2e3965d..6087b0f97 100644 --- a/entc/gen/template/dialect/gremlin/select.tmpl +++ b/entc/gen/template/dialect/gremlin/select.tmpl @@ -15,15 +15,15 @@ func ({{ $receiver }} *{{ $builder }}) gremlinScan(ctx context.Context, root *{{ res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len({{ $receiver }}.fields) == 1 { - if {{ $receiver }}.fields[0] != {{ $.Package }}.FieldID { - traversal = traversal.Values({{ $receiver }}.fields...) + if fields := {{ $receiver }}.ctx.Fields; len(fields) == 1 { + if fields[0] != {{ $.Package }}.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len({{ $receiver }}.fields)) - for i, f := range {{ $receiver }}.fields { + fields := make([]any, len({{ $receiver }}.ctx.Fields)) + for i, f := range {{ $receiver }}.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -32,7 +32,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlinScan(ctx context.Context, root *{{ if err := {{ $receiver }}.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/gen/template/dialect/sql/query.tmpl b/entc/gen/template/dialect/sql/query.tmpl index 209053b8c..c2562306a 100644 --- a/entc/gen/template/dialect/sql/query.tmpl +++ b/entc/gen/template/dialect/sql/query.tmpl @@ -249,10 +249,10 @@ func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error _spec.Unique = false _spec.Node.Columns = nil {{- else }} - _spec.Node.Columns = {{ $receiver }}.fields - if len({{ $receiver }}.fields) > 0 { + _spec.Node.Columns = {{ $receiver }}.ctx.Fields + if len({{ $receiver }}.ctx.Fields) > 0 { {{- /* In case of field selection, configure query to unique only if was explicitly set to true. */}} - _spec.Unique = {{ $receiver }}.unique != nil && *{{ $receiver }}.unique + _spec.Unique = {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique } {{- end }} return sqlgraph.CountNodes(ctx, {{ $receiver }}.driver, _spec) @@ -273,10 +273,10 @@ func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec { From: {{ $receiver }}.sql, Unique: true, } - if unique := {{ $receiver }}.unique; unique != nil { + if unique := {{ $receiver }}.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := {{ $receiver }}.fields; len(fields) > 0 { + if fields := {{ $receiver }}.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) {{- if $.HasOneFieldID }} _spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.{{ $.ID.Constant }}) @@ -298,10 +298,10 @@ func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec { } } } - if limit := {{ $receiver }}.limit; limit != nil { + if limit := {{ $receiver }}.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := {{ $receiver }}.offset; offset != nil { + if offset := {{ $receiver }}.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := {{ $receiver }}.order; len(ps) > 0 { @@ -333,7 +333,7 @@ func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec { func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect({{ $receiver }}.driver.Dialect()) t1 := builder.Table({{ $.Package }}.Table) - columns := {{ $receiver }}.fields + columns := {{ $receiver }}.ctx.Fields if len(columns) == 0 { columns = {{ $.Package }}.Columns } @@ -342,7 +342,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select selector = {{ $receiver }}.sql selector.Select(selector.Columns(columns...)...) } - if {{ $receiver }}.unique != nil && *{{ $receiver }}.unique { + if {{ $receiver }}.ctx.Unique != nil && *{{ $receiver }}.ctx.Unique { selector.Distinct() } {{- /* Allow mutating the sql.Selector by ent extensions or user templates.*/}} @@ -357,12 +357,12 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select for _, p := range {{ $receiver }}.order { p(selector) } - if offset := {{ $receiver }}.offset; offset != nil { + if offset := {{ $receiver }}.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := {{ $receiver }}.limit; limit != nil { + if limit := {{ $receiver }}.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -440,7 +440,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select {{ define "dialect/sql/query/preparecheck" }} {{- $pkg := $.Scope.Package }} {{- $receiver := $.Scope.Receiver }} - for _, f := range {{ $receiver }}.fields { + for _, f := range {{ $receiver }}.ctx.Fields { if !{{ $.Package }}.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("{{ $pkg }}: invalid field %q for query", f)} } diff --git a/entc/gen/type.go b/entc/gen/type.go index 217f5b772..fb67bb4e1 100644 --- a/entc/gen/type.go +++ b/entc/gen/type.go @@ -2086,6 +2086,7 @@ var ( // private fields used by the different builders. privateField = names( "config", + "ctx", "done", "hooks", "inters", diff --git a/entc/integration/cascadelete/ent/client.go b/entc/integration/cascadelete/ent/client.go index fc6f66916..5b76495c5 100644 --- a/entc/integration/cascadelete/ent/client.go +++ b/entc/integration/cascadelete/ent/client.go @@ -237,6 +237,7 @@ func (c *CommentClient) DeleteOneID(id int) *CommentDeleteOne { func (c *CommentClient) Query() *CommentQuery { return &CommentQuery{ config: c.config, + ctx: &QueryContext{Type: TypeComment}, inters: c.Interceptors(), } } @@ -370,6 +371,7 @@ func (c *PostClient) DeleteOneID(id int) *PostDeleteOne { func (c *PostClient) Query() *PostQuery { return &PostQuery{ config: c.config, + ctx: &QueryContext{Type: TypePost}, inters: c.Interceptors(), } } @@ -519,6 +521,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/cascadelete/ent/comment_query.go b/entc/integration/cascadelete/ent/comment_query.go index b0d4164fc..c90fd4c48 100644 --- a/entc/integration/cascadelete/ent/comment_query.go +++ b/entc/integration/cascadelete/ent/comment_query.go @@ -22,11 +22,8 @@ import ( // CommentQuery is the builder for querying Comment entities. type CommentQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Comment withPost *PostQuery @@ -43,20 +40,20 @@ func (cq *CommentQuery) Where(ps ...predicate.Comment) *CommentQuery { // Limit the number of records to be returned by this query. func (cq *CommentQuery) Limit(limit int) *CommentQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *CommentQuery) Offset(offset int) *CommentQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *CommentQuery) Unique(unique bool) *CommentQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -91,7 +88,7 @@ func (cq *CommentQuery) QueryPost() *PostQuery { // First returns the first Comment entity from the query. // Returns a *NotFoundError when no Comment was found. func (cq *CommentQuery) First(ctx context.Context) (*Comment, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeComment, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -114,7 +111,7 @@ func (cq *CommentQuery) FirstX(ctx context.Context) *Comment { // Returns a *NotFoundError when no Comment ID was found. func (cq *CommentQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeComment, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -137,7 +134,7 @@ func (cq *CommentQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Comment entity is found. // Returns a *NotFoundError when no Comment entities are found. func (cq *CommentQuery) Only(ctx context.Context) (*Comment, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeComment, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -165,7 +162,7 @@ func (cq *CommentQuery) OnlyX(ctx context.Context) *Comment { // Returns a *NotFoundError when no entities are found. func (cq *CommentQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeComment, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -190,7 +187,7 @@ func (cq *CommentQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Comments. func (cq *CommentQuery) All(ctx context.Context) ([]*Comment, error) { - ctx = newQueryContext(ctx, TypeComment, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -210,7 +207,7 @@ func (cq *CommentQuery) AllX(ctx context.Context) []*Comment { // IDs executes the query and returns a list of Comment IDs. func (cq *CommentQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeComment, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(comment.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -228,7 +225,7 @@ func (cq *CommentQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (cq *CommentQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeComment, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -246,7 +243,7 @@ func (cq *CommentQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *CommentQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeComment, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -274,16 +271,14 @@ func (cq *CommentQuery) Clone() *CommentQuery { } return &CommentQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Comment{}, cq.predicates...), withPost: cq.withPost.Clone(), // clone intermediate query. - sql: cq.sql.Clone(), - path: cq.path, - unique: cq.unique, + sql: cq.sql.Clone(), + path: cq.path, } } @@ -313,9 +308,9 @@ func (cq *CommentQuery) WithPost(opts ...func(*PostQuery)) *CommentQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (cq *CommentQuery) GroupBy(field string, fields ...string) *CommentGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &CommentGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = comment.Label grbuild.scan = grbuild.Scan return grbuild @@ -334,10 +329,10 @@ func (cq *CommentQuery) GroupBy(field string, fields ...string) *CommentGroupBy // Select(comment.FieldText). // Scan(ctx, &v) func (cq *CommentQuery) Select(fields ...string) *CommentSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &CommentSelect{CommentQuery: cq} sbuild.label = comment.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -357,7 +352,7 @@ func (cq *CommentQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range cq.fields { + for _, f := range cq.ctx.Fields { if !comment.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -439,9 +434,9 @@ func (cq *CommentQuery) loadPost(ctx context.Context, query *PostQuery, nodes [] func (cq *CommentQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() - _spec.Node.Columns = cq.fields - if len(cq.fields) > 0 { - _spec.Unique = cq.unique != nil && *cq.unique + _spec.Node.Columns = cq.ctx.Fields + if len(cq.ctx.Fields) > 0 { + _spec.Unique = cq.ctx.Unique != nil && *cq.ctx.Unique } return sqlgraph.CountNodes(ctx, cq.driver, _spec) } @@ -459,10 +454,10 @@ func (cq *CommentQuery) querySpec() *sqlgraph.QuerySpec { From: cq.sql, Unique: true, } - if unique := cq.unique; unique != nil { + if unique := cq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := cq.fields; len(fields) > 0 { + if fields := cq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, comment.FieldID) for i := range fields { @@ -478,10 +473,10 @@ func (cq *CommentQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := cq.order; len(ps) > 0 { @@ -497,7 +492,7 @@ func (cq *CommentQuery) querySpec() *sqlgraph.QuerySpec { func (cq *CommentQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(comment.Table) - columns := cq.fields + columns := cq.ctx.Fields if len(columns) == 0 { columns = comment.Columns } @@ -506,7 +501,7 @@ func (cq *CommentQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = cq.sql selector.Select(selector.Columns(columns...)...) } - if cq.unique != nil && *cq.unique { + if cq.ctx.Unique != nil && *cq.ctx.Unique { selector.Distinct() } for _, p := range cq.predicates { @@ -515,12 +510,12 @@ func (cq *CommentQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range cq.order { p(selector) } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -540,7 +535,7 @@ func (cgb *CommentGroupBy) Aggregate(fns ...AggregateFunc) *CommentGroupBy { // Scan applies the selector query and scans the result into the given value. func (cgb *CommentGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeComment, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -588,7 +583,7 @@ func (cs *CommentSelect) Aggregate(fns ...AggregateFunc) *CommentSelect { // Scan applies the selector query and scans the result into the given value. func (cs *CommentSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeComment, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/cascadelete/ent/ent.go b/entc/integration/cascadelete/ent/ent.go index 27aa4671f..c99218100 100644 --- a/entc/integration/cascadelete/ent/ent.go +++ b/entc/integration/cascadelete/ent/ent.go @@ -26,6 +26,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -511,10 +512,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/cascadelete/ent/post_query.go b/entc/integration/cascadelete/ent/post_query.go index 506937611..ea119aece 100644 --- a/entc/integration/cascadelete/ent/post_query.go +++ b/entc/integration/cascadelete/ent/post_query.go @@ -24,11 +24,8 @@ import ( // PostQuery is the builder for querying Post entities. type PostQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Post withAuthor *UserQuery @@ -46,20 +43,20 @@ func (pq *PostQuery) Where(ps ...predicate.Post) *PostQuery { // Limit the number of records to be returned by this query. func (pq *PostQuery) Limit(limit int) *PostQuery { - pq.limit = &limit + pq.ctx.Limit = &limit return pq } // Offset to start from. func (pq *PostQuery) Offset(offset int) *PostQuery { - pq.offset = &offset + pq.ctx.Offset = &offset return pq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (pq *PostQuery) Unique(unique bool) *PostQuery { - pq.unique = &unique + pq.ctx.Unique = &unique return pq } @@ -116,7 +113,7 @@ func (pq *PostQuery) QueryComments() *CommentQuery { // First returns the first Post entity from the query. // Returns a *NotFoundError when no Post was found. func (pq *PostQuery) First(ctx context.Context) (*Post, error) { - nodes, err := pq.Limit(1).All(newQueryContext(ctx, TypePost, "First")) + nodes, err := pq.Limit(1).All(setContextOp(ctx, pq.ctx, "First")) if err != nil { return nil, err } @@ -139,7 +136,7 @@ func (pq *PostQuery) FirstX(ctx context.Context) *Post { // Returns a *NotFoundError when no Post ID was found. func (pq *PostQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(1).IDs(newQueryContext(ctx, TypePost, "FirstID")); err != nil { + if ids, err = pq.Limit(1).IDs(setContextOp(ctx, pq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -162,7 +159,7 @@ func (pq *PostQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Post entity is found. // Returns a *NotFoundError when no Post entities are found. func (pq *PostQuery) Only(ctx context.Context) (*Post, error) { - nodes, err := pq.Limit(2).All(newQueryContext(ctx, TypePost, "Only")) + nodes, err := pq.Limit(2).All(setContextOp(ctx, pq.ctx, "Only")) if err != nil { return nil, err } @@ -190,7 +187,7 @@ func (pq *PostQuery) OnlyX(ctx context.Context) *Post { // Returns a *NotFoundError when no entities are found. func (pq *PostQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(2).IDs(newQueryContext(ctx, TypePost, "OnlyID")); err != nil { + if ids, err = pq.Limit(2).IDs(setContextOp(ctx, pq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -215,7 +212,7 @@ func (pq *PostQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Posts. func (pq *PostQuery) All(ctx context.Context) ([]*Post, error) { - ctx = newQueryContext(ctx, TypePost, "All") + ctx = setContextOp(ctx, pq.ctx, "All") if err := pq.prepareQuery(ctx); err != nil { return nil, err } @@ -235,7 +232,7 @@ func (pq *PostQuery) AllX(ctx context.Context) []*Post { // IDs executes the query and returns a list of Post IDs. func (pq *PostQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypePost, "IDs") + ctx = setContextOp(ctx, pq.ctx, "IDs") if err := pq.Select(post.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -253,7 +250,7 @@ func (pq *PostQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (pq *PostQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypePost, "Count") + ctx = setContextOp(ctx, pq.ctx, "Count") if err := pq.prepareQuery(ctx); err != nil { return 0, err } @@ -271,7 +268,7 @@ func (pq *PostQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (pq *PostQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypePost, "Exist") + ctx = setContextOp(ctx, pq.ctx, "Exist") switch _, err := pq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -299,17 +296,15 @@ func (pq *PostQuery) Clone() *PostQuery { } return &PostQuery{ config: pq.config, - limit: pq.limit, - offset: pq.offset, + ctx: pq.ctx.Clone(), order: append([]OrderFunc{}, pq.order...), inters: append([]Interceptor{}, pq.inters...), predicates: append([]predicate.Post{}, pq.predicates...), withAuthor: pq.withAuthor.Clone(), withComments: pq.withComments.Clone(), // clone intermediate query. - sql: pq.sql.Clone(), - path: pq.path, - unique: pq.unique, + sql: pq.sql.Clone(), + path: pq.path, } } @@ -350,9 +345,9 @@ func (pq *PostQuery) WithComments(opts ...func(*CommentQuery)) *PostQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (pq *PostQuery) GroupBy(field string, fields ...string) *PostGroupBy { - pq.fields = append([]string{field}, fields...) + pq.ctx.Fields = append([]string{field}, fields...) grbuild := &PostGroupBy{build: pq} - grbuild.flds = &pq.fields + grbuild.flds = &pq.ctx.Fields grbuild.label = post.Label grbuild.scan = grbuild.Scan return grbuild @@ -371,10 +366,10 @@ func (pq *PostQuery) GroupBy(field string, fields ...string) *PostGroupBy { // Select(post.FieldText). // Scan(ctx, &v) func (pq *PostQuery) Select(fields ...string) *PostSelect { - pq.fields = append(pq.fields, fields...) + pq.ctx.Fields = append(pq.ctx.Fields, fields...) sbuild := &PostSelect{PostQuery: pq} sbuild.label = post.Label - sbuild.flds, sbuild.scan = &pq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &pq.ctx.Fields, sbuild.Scan return sbuild } @@ -394,7 +389,7 @@ func (pq *PostQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range pq.fields { + for _, f := range pq.ctx.Fields { if !post.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -511,9 +506,9 @@ func (pq *PostQuery) loadComments(ctx context.Context, query *CommentQuery, node func (pq *PostQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() - _spec.Node.Columns = pq.fields - if len(pq.fields) > 0 { - _spec.Unique = pq.unique != nil && *pq.unique + _spec.Node.Columns = pq.ctx.Fields + if len(pq.ctx.Fields) > 0 { + _spec.Unique = pq.ctx.Unique != nil && *pq.ctx.Unique } return sqlgraph.CountNodes(ctx, pq.driver, _spec) } @@ -531,10 +526,10 @@ func (pq *PostQuery) querySpec() *sqlgraph.QuerySpec { From: pq.sql, Unique: true, } - if unique := pq.unique; unique != nil { + if unique := pq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := pq.fields; len(fields) > 0 { + if fields := pq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, post.FieldID) for i := range fields { @@ -550,10 +545,10 @@ func (pq *PostQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := pq.order; len(ps) > 0 { @@ -569,7 +564,7 @@ func (pq *PostQuery) querySpec() *sqlgraph.QuerySpec { func (pq *PostQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(pq.driver.Dialect()) t1 := builder.Table(post.Table) - columns := pq.fields + columns := pq.ctx.Fields if len(columns) == 0 { columns = post.Columns } @@ -578,7 +573,7 @@ func (pq *PostQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = pq.sql selector.Select(selector.Columns(columns...)...) } - if pq.unique != nil && *pq.unique { + if pq.ctx.Unique != nil && *pq.ctx.Unique { selector.Distinct() } for _, p := range pq.predicates { @@ -587,12 +582,12 @@ func (pq *PostQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range pq.order { p(selector) } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -612,7 +607,7 @@ func (pgb *PostGroupBy) Aggregate(fns ...AggregateFunc) *PostGroupBy { // Scan applies the selector query and scans the result into the given value. func (pgb *PostGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePost, "GroupBy") + ctx = setContextOp(ctx, pgb.build.ctx, "GroupBy") if err := pgb.build.prepareQuery(ctx); err != nil { return err } @@ -660,7 +655,7 @@ func (ps *PostSelect) Aggregate(fns ...AggregateFunc) *PostSelect { // Scan applies the selector query and scans the result into the given value. func (ps *PostSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePost, "Select") + ctx = setContextOp(ctx, ps.ctx, "Select") if err := ps.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/cascadelete/ent/user_query.go b/entc/integration/cascadelete/ent/user_query.go index c4b047d67..22a55d938 100644 --- a/entc/integration/cascadelete/ent/user_query.go +++ b/entc/integration/cascadelete/ent/user_query.go @@ -23,11 +23,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withPosts *PostQuery @@ -44,20 +41,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -92,7 +89,7 @@ func (uq *UserQuery) QueryPosts() *PostQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -115,7 +112,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -138,7 +135,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -166,7 +163,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -191,7 +188,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -211,7 +208,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -229,7 +226,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -247,7 +244,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -275,16 +272,14 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), withPosts: uq.withPosts.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -314,9 +309,9 @@ func (uq *UserQuery) WithPosts(opts ...func(*PostQuery)) *UserQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -335,10 +330,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldName). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -358,7 +353,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -439,9 +434,9 @@ func (uq *UserQuery) loadPosts(ctx context.Context, query *PostQuery, nodes []*U func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -459,10 +454,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -478,10 +473,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -497,7 +492,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -506,7 +501,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -515,12 +510,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -540,7 +535,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -588,7 +583,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/config/ent/client.go b/entc/integration/config/ent/client.go index 58ab340d0..a1f1fefc0 100644 --- a/entc/integration/config/ent/client.go +++ b/entc/integration/config/ent/client.go @@ -216,6 +216,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/config/ent/ent.go b/entc/integration/config/ent/ent.go index 0b1ff16f8..3a2349a62 100644 --- a/entc/integration/config/ent/ent.go +++ b/entc/integration/config/ent/ent.go @@ -24,6 +24,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -507,10 +508,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/config/ent/user_query.go b/entc/integration/config/ent/user_query.go index f7867f841..4ab865448 100644 --- a/entc/integration/config/ent/user_query.go +++ b/entc/integration/config/ent/user_query.go @@ -21,11 +21,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User // intermediate query (i.e. traversal path). @@ -41,20 +38,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -67,7 +64,7 @@ func (uq *UserQuery) Order(o ...OrderFunc) *UserQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -90,7 +87,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -113,7 +110,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -166,7 +163,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -204,7 +201,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -222,7 +219,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -250,15 +247,13 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -277,9 +272,9 @@ func (uq *UserQuery) Clone() *UserQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -298,10 +293,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldName). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -321,7 +316,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -363,9 +358,9 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -383,10 +378,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -402,10 +397,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -421,7 +416,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -430,7 +425,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -439,12 +434,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -464,7 +459,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -512,7 +507,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/account_query.go b/entc/integration/customid/ent/account_query.go index 76ae2d304..9c6087435 100644 --- a/entc/integration/customid/ent/account_query.go +++ b/entc/integration/customid/ent/account_query.go @@ -24,11 +24,8 @@ import ( // AccountQuery is the builder for querying Account entities. type AccountQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Account withToken *TokenQuery @@ -45,20 +42,20 @@ func (aq *AccountQuery) Where(ps ...predicate.Account) *AccountQuery { // Limit the number of records to be returned by this query. func (aq *AccountQuery) Limit(limit int) *AccountQuery { - aq.limit = &limit + aq.ctx.Limit = &limit return aq } // Offset to start from. func (aq *AccountQuery) Offset(offset int) *AccountQuery { - aq.offset = &offset + aq.ctx.Offset = &offset return aq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (aq *AccountQuery) Unique(unique bool) *AccountQuery { - aq.unique = &unique + aq.ctx.Unique = &unique return aq } @@ -93,7 +90,7 @@ func (aq *AccountQuery) QueryToken() *TokenQuery { // First returns the first Account entity from the query. // Returns a *NotFoundError when no Account was found. func (aq *AccountQuery) First(ctx context.Context) (*Account, error) { - nodes, err := aq.Limit(1).All(newQueryContext(ctx, TypeAccount, "First")) + nodes, err := aq.Limit(1).All(setContextOp(ctx, aq.ctx, "First")) if err != nil { return nil, err } @@ -116,7 +113,7 @@ func (aq *AccountQuery) FirstX(ctx context.Context) *Account { // Returns a *NotFoundError when no Account ID was found. func (aq *AccountQuery) FirstID(ctx context.Context) (id sid.ID, err error) { var ids []sid.ID - if ids, err = aq.Limit(1).IDs(newQueryContext(ctx, TypeAccount, "FirstID")); err != nil { + if ids, err = aq.Limit(1).IDs(setContextOp(ctx, aq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -139,7 +136,7 @@ func (aq *AccountQuery) FirstIDX(ctx context.Context) sid.ID { // Returns a *NotSingularError when more than one Account entity is found. // Returns a *NotFoundError when no Account entities are found. func (aq *AccountQuery) Only(ctx context.Context) (*Account, error) { - nodes, err := aq.Limit(2).All(newQueryContext(ctx, TypeAccount, "Only")) + nodes, err := aq.Limit(2).All(setContextOp(ctx, aq.ctx, "Only")) if err != nil { return nil, err } @@ -167,7 +164,7 @@ func (aq *AccountQuery) OnlyX(ctx context.Context) *Account { // Returns a *NotFoundError when no entities are found. func (aq *AccountQuery) OnlyID(ctx context.Context) (id sid.ID, err error) { var ids []sid.ID - if ids, err = aq.Limit(2).IDs(newQueryContext(ctx, TypeAccount, "OnlyID")); err != nil { + if ids, err = aq.Limit(2).IDs(setContextOp(ctx, aq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -192,7 +189,7 @@ func (aq *AccountQuery) OnlyIDX(ctx context.Context) sid.ID { // All executes the query and returns a list of Accounts. func (aq *AccountQuery) All(ctx context.Context) ([]*Account, error) { - ctx = newQueryContext(ctx, TypeAccount, "All") + ctx = setContextOp(ctx, aq.ctx, "All") if err := aq.prepareQuery(ctx); err != nil { return nil, err } @@ -212,7 +209,7 @@ func (aq *AccountQuery) AllX(ctx context.Context) []*Account { // IDs executes the query and returns a list of Account IDs. func (aq *AccountQuery) IDs(ctx context.Context) ([]sid.ID, error) { var ids []sid.ID - ctx = newQueryContext(ctx, TypeAccount, "IDs") + ctx = setContextOp(ctx, aq.ctx, "IDs") if err := aq.Select(account.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -230,7 +227,7 @@ func (aq *AccountQuery) IDsX(ctx context.Context) []sid.ID { // Count returns the count of the given query. func (aq *AccountQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeAccount, "Count") + ctx = setContextOp(ctx, aq.ctx, "Count") if err := aq.prepareQuery(ctx); err != nil { return 0, err } @@ -248,7 +245,7 @@ func (aq *AccountQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (aq *AccountQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeAccount, "Exist") + ctx = setContextOp(ctx, aq.ctx, "Exist") switch _, err := aq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -276,16 +273,14 @@ func (aq *AccountQuery) Clone() *AccountQuery { } return &AccountQuery{ config: aq.config, - limit: aq.limit, - offset: aq.offset, + ctx: aq.ctx.Clone(), order: append([]OrderFunc{}, aq.order...), inters: append([]Interceptor{}, aq.inters...), predicates: append([]predicate.Account{}, aq.predicates...), withToken: aq.withToken.Clone(), // clone intermediate query. - sql: aq.sql.Clone(), - path: aq.path, - unique: aq.unique, + sql: aq.sql.Clone(), + path: aq.path, } } @@ -315,9 +310,9 @@ func (aq *AccountQuery) WithToken(opts ...func(*TokenQuery)) *AccountQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (aq *AccountQuery) GroupBy(field string, fields ...string) *AccountGroupBy { - aq.fields = append([]string{field}, fields...) + aq.ctx.Fields = append([]string{field}, fields...) grbuild := &AccountGroupBy{build: aq} - grbuild.flds = &aq.fields + grbuild.flds = &aq.ctx.Fields grbuild.label = account.Label grbuild.scan = grbuild.Scan return grbuild @@ -336,10 +331,10 @@ func (aq *AccountQuery) GroupBy(field string, fields ...string) *AccountGroupBy // Select(account.FieldEmail). // Scan(ctx, &v) func (aq *AccountQuery) Select(fields ...string) *AccountSelect { - aq.fields = append(aq.fields, fields...) + aq.ctx.Fields = append(aq.ctx.Fields, fields...) sbuild := &AccountSelect{AccountQuery: aq} sbuild.label = account.Label - sbuild.flds, sbuild.scan = &aq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &aq.ctx.Fields, sbuild.Scan return sbuild } @@ -359,7 +354,7 @@ func (aq *AccountQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range aq.fields { + for _, f := range aq.ctx.Fields { if !account.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -444,9 +439,9 @@ func (aq *AccountQuery) loadToken(ctx context.Context, query *TokenQuery, nodes func (aq *AccountQuery) sqlCount(ctx context.Context) (int, error) { _spec := aq.querySpec() - _spec.Node.Columns = aq.fields - if len(aq.fields) > 0 { - _spec.Unique = aq.unique != nil && *aq.unique + _spec.Node.Columns = aq.ctx.Fields + if len(aq.ctx.Fields) > 0 { + _spec.Unique = aq.ctx.Unique != nil && *aq.ctx.Unique } return sqlgraph.CountNodes(ctx, aq.driver, _spec) } @@ -464,10 +459,10 @@ func (aq *AccountQuery) querySpec() *sqlgraph.QuerySpec { From: aq.sql, Unique: true, } - if unique := aq.unique; unique != nil { + if unique := aq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := aq.fields; len(fields) > 0 { + if fields := aq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, account.FieldID) for i := range fields { @@ -483,10 +478,10 @@ func (aq *AccountQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := aq.limit; limit != nil { + if limit := aq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := aq.offset; offset != nil { + if offset := aq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := aq.order; len(ps) > 0 { @@ -502,7 +497,7 @@ func (aq *AccountQuery) querySpec() *sqlgraph.QuerySpec { func (aq *AccountQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(aq.driver.Dialect()) t1 := builder.Table(account.Table) - columns := aq.fields + columns := aq.ctx.Fields if len(columns) == 0 { columns = account.Columns } @@ -511,7 +506,7 @@ func (aq *AccountQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = aq.sql selector.Select(selector.Columns(columns...)...) } - if aq.unique != nil && *aq.unique { + if aq.ctx.Unique != nil && *aq.ctx.Unique { selector.Distinct() } for _, p := range aq.predicates { @@ -520,12 +515,12 @@ func (aq *AccountQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range aq.order { p(selector) } - if offset := aq.offset; offset != nil { + if offset := aq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := aq.limit; limit != nil { + if limit := aq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -545,7 +540,7 @@ func (agb *AccountGroupBy) Aggregate(fns ...AggregateFunc) *AccountGroupBy { // Scan applies the selector query and scans the result into the given value. func (agb *AccountGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeAccount, "GroupBy") + ctx = setContextOp(ctx, agb.build.ctx, "GroupBy") if err := agb.build.prepareQuery(ctx); err != nil { return err } @@ -593,7 +588,7 @@ func (as *AccountSelect) Aggregate(fns ...AggregateFunc) *AccountSelect { // Scan applies the selector query and scans the result into the given value. func (as *AccountSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeAccount, "Select") + ctx = setContextOp(ctx, as.ctx, "Select") if err := as.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/blob_query.go b/entc/integration/customid/ent/blob_query.go index 01aba2c27..633410fa4 100644 --- a/entc/integration/customid/ent/blob_query.go +++ b/entc/integration/customid/ent/blob_query.go @@ -24,11 +24,8 @@ import ( // BlobQuery is the builder for querying Blob entities. type BlobQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Blob withParent *BlobQuery @@ -48,20 +45,20 @@ func (bq *BlobQuery) Where(ps ...predicate.Blob) *BlobQuery { // Limit the number of records to be returned by this query. func (bq *BlobQuery) Limit(limit int) *BlobQuery { - bq.limit = &limit + bq.ctx.Limit = &limit return bq } // Offset to start from. func (bq *BlobQuery) Offset(offset int) *BlobQuery { - bq.offset = &offset + bq.ctx.Offset = &offset return bq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (bq *BlobQuery) Unique(unique bool) *BlobQuery { - bq.unique = &unique + bq.ctx.Unique = &unique return bq } @@ -140,7 +137,7 @@ func (bq *BlobQuery) QueryBlobLinks() *BlobLinkQuery { // First returns the first Blob entity from the query. // Returns a *NotFoundError when no Blob was found. func (bq *BlobQuery) First(ctx context.Context) (*Blob, error) { - nodes, err := bq.Limit(1).All(newQueryContext(ctx, TypeBlob, "First")) + nodes, err := bq.Limit(1).All(setContextOp(ctx, bq.ctx, "First")) if err != nil { return nil, err } @@ -163,7 +160,7 @@ func (bq *BlobQuery) FirstX(ctx context.Context) *Blob { // Returns a *NotFoundError when no Blob ID was found. func (bq *BlobQuery) FirstID(ctx context.Context) (id uuid.UUID, err error) { var ids []uuid.UUID - if ids, err = bq.Limit(1).IDs(newQueryContext(ctx, TypeBlob, "FirstID")); err != nil { + if ids, err = bq.Limit(1).IDs(setContextOp(ctx, bq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -186,7 +183,7 @@ func (bq *BlobQuery) FirstIDX(ctx context.Context) uuid.UUID { // Returns a *NotSingularError when more than one Blob entity is found. // Returns a *NotFoundError when no Blob entities are found. func (bq *BlobQuery) Only(ctx context.Context) (*Blob, error) { - nodes, err := bq.Limit(2).All(newQueryContext(ctx, TypeBlob, "Only")) + nodes, err := bq.Limit(2).All(setContextOp(ctx, bq.ctx, "Only")) if err != nil { return nil, err } @@ -214,7 +211,7 @@ func (bq *BlobQuery) OnlyX(ctx context.Context) *Blob { // Returns a *NotFoundError when no entities are found. func (bq *BlobQuery) OnlyID(ctx context.Context) (id uuid.UUID, err error) { var ids []uuid.UUID - if ids, err = bq.Limit(2).IDs(newQueryContext(ctx, TypeBlob, "OnlyID")); err != nil { + if ids, err = bq.Limit(2).IDs(setContextOp(ctx, bq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -239,7 +236,7 @@ func (bq *BlobQuery) OnlyIDX(ctx context.Context) uuid.UUID { // All executes the query and returns a list of Blobs. func (bq *BlobQuery) All(ctx context.Context) ([]*Blob, error) { - ctx = newQueryContext(ctx, TypeBlob, "All") + ctx = setContextOp(ctx, bq.ctx, "All") if err := bq.prepareQuery(ctx); err != nil { return nil, err } @@ -259,7 +256,7 @@ func (bq *BlobQuery) AllX(ctx context.Context) []*Blob { // IDs executes the query and returns a list of Blob IDs. func (bq *BlobQuery) IDs(ctx context.Context) ([]uuid.UUID, error) { var ids []uuid.UUID - ctx = newQueryContext(ctx, TypeBlob, "IDs") + ctx = setContextOp(ctx, bq.ctx, "IDs") if err := bq.Select(blob.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -277,7 +274,7 @@ func (bq *BlobQuery) IDsX(ctx context.Context) []uuid.UUID { // Count returns the count of the given query. func (bq *BlobQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeBlob, "Count") + ctx = setContextOp(ctx, bq.ctx, "Count") if err := bq.prepareQuery(ctx); err != nil { return 0, err } @@ -295,7 +292,7 @@ func (bq *BlobQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (bq *BlobQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeBlob, "Exist") + ctx = setContextOp(ctx, bq.ctx, "Exist") switch _, err := bq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -323,8 +320,7 @@ func (bq *BlobQuery) Clone() *BlobQuery { } return &BlobQuery{ config: bq.config, - limit: bq.limit, - offset: bq.offset, + ctx: bq.ctx.Clone(), order: append([]OrderFunc{}, bq.order...), inters: append([]Interceptor{}, bq.inters...), predicates: append([]predicate.Blob{}, bq.predicates...), @@ -332,9 +328,8 @@ func (bq *BlobQuery) Clone() *BlobQuery { withLinks: bq.withLinks.Clone(), withBlobLinks: bq.withBlobLinks.Clone(), // clone intermediate query. - sql: bq.sql.Clone(), - path: bq.path, - unique: bq.unique, + sql: bq.sql.Clone(), + path: bq.path, } } @@ -386,9 +381,9 @@ func (bq *BlobQuery) WithBlobLinks(opts ...func(*BlobLinkQuery)) *BlobQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (bq *BlobQuery) GroupBy(field string, fields ...string) *BlobGroupBy { - bq.fields = append([]string{field}, fields...) + bq.ctx.Fields = append([]string{field}, fields...) grbuild := &BlobGroupBy{build: bq} - grbuild.flds = &bq.fields + grbuild.flds = &bq.ctx.Fields grbuild.label = blob.Label grbuild.scan = grbuild.Scan return grbuild @@ -407,10 +402,10 @@ func (bq *BlobQuery) GroupBy(field string, fields ...string) *BlobGroupBy { // Select(blob.FieldUUID). // Scan(ctx, &v) func (bq *BlobQuery) Select(fields ...string) *BlobSelect { - bq.fields = append(bq.fields, fields...) + bq.ctx.Fields = append(bq.ctx.Fields, fields...) sbuild := &BlobSelect{BlobQuery: bq} sbuild.label = blob.Label - sbuild.flds, sbuild.scan = &bq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &bq.ctx.Fields, sbuild.Scan return sbuild } @@ -430,7 +425,7 @@ func (bq *BlobQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range bq.fields { + for _, f := range bq.ctx.Fields { if !blob.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -623,9 +618,9 @@ func (bq *BlobQuery) loadBlobLinks(ctx context.Context, query *BlobLinkQuery, no func (bq *BlobQuery) sqlCount(ctx context.Context) (int, error) { _spec := bq.querySpec() - _spec.Node.Columns = bq.fields - if len(bq.fields) > 0 { - _spec.Unique = bq.unique != nil && *bq.unique + _spec.Node.Columns = bq.ctx.Fields + if len(bq.ctx.Fields) > 0 { + _spec.Unique = bq.ctx.Unique != nil && *bq.ctx.Unique } return sqlgraph.CountNodes(ctx, bq.driver, _spec) } @@ -643,10 +638,10 @@ func (bq *BlobQuery) querySpec() *sqlgraph.QuerySpec { From: bq.sql, Unique: true, } - if unique := bq.unique; unique != nil { + if unique := bq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := bq.fields; len(fields) > 0 { + if fields := bq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, blob.FieldID) for i := range fields { @@ -662,10 +657,10 @@ func (bq *BlobQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := bq.limit; limit != nil { + if limit := bq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := bq.offset; offset != nil { + if offset := bq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := bq.order; len(ps) > 0 { @@ -681,7 +676,7 @@ func (bq *BlobQuery) querySpec() *sqlgraph.QuerySpec { func (bq *BlobQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(bq.driver.Dialect()) t1 := builder.Table(blob.Table) - columns := bq.fields + columns := bq.ctx.Fields if len(columns) == 0 { columns = blob.Columns } @@ -690,7 +685,7 @@ func (bq *BlobQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = bq.sql selector.Select(selector.Columns(columns...)...) } - if bq.unique != nil && *bq.unique { + if bq.ctx.Unique != nil && *bq.ctx.Unique { selector.Distinct() } for _, p := range bq.predicates { @@ -699,12 +694,12 @@ func (bq *BlobQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range bq.order { p(selector) } - if offset := bq.offset; offset != nil { + if offset := bq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := bq.limit; limit != nil { + if limit := bq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -724,7 +719,7 @@ func (bgb *BlobGroupBy) Aggregate(fns ...AggregateFunc) *BlobGroupBy { // Scan applies the selector query and scans the result into the given value. func (bgb *BlobGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeBlob, "GroupBy") + ctx = setContextOp(ctx, bgb.build.ctx, "GroupBy") if err := bgb.build.prepareQuery(ctx); err != nil { return err } @@ -772,7 +767,7 @@ func (bs *BlobSelect) Aggregate(fns ...AggregateFunc) *BlobSelect { // Scan applies the selector query and scans the result into the given value. func (bs *BlobSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeBlob, "Select") + ctx = setContextOp(ctx, bs.ctx, "Select") if err := bs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/bloblink_query.go b/entc/integration/customid/ent/bloblink_query.go index 3e80b17cb..f98e8b894 100644 --- a/entc/integration/customid/ent/bloblink_query.go +++ b/entc/integration/customid/ent/bloblink_query.go @@ -22,11 +22,8 @@ import ( // BlobLinkQuery is the builder for querying BlobLink entities. type BlobLinkQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.BlobLink withBlob *BlobQuery @@ -44,20 +41,20 @@ func (blq *BlobLinkQuery) Where(ps ...predicate.BlobLink) *BlobLinkQuery { // Limit the number of records to be returned by this query. func (blq *BlobLinkQuery) Limit(limit int) *BlobLinkQuery { - blq.limit = &limit + blq.ctx.Limit = &limit return blq } // Offset to start from. func (blq *BlobLinkQuery) Offset(offset int) *BlobLinkQuery { - blq.offset = &offset + blq.ctx.Offset = &offset return blq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (blq *BlobLinkQuery) Unique(unique bool) *BlobLinkQuery { - blq.unique = &unique + blq.ctx.Unique = &unique return blq } @@ -114,7 +111,7 @@ func (blq *BlobLinkQuery) QueryLink() *BlobQuery { // First returns the first BlobLink entity from the query. // Returns a *NotFoundError when no BlobLink was found. func (blq *BlobLinkQuery) First(ctx context.Context) (*BlobLink, error) { - nodes, err := blq.Limit(1).All(newQueryContext(ctx, TypeBlobLink, "First")) + nodes, err := blq.Limit(1).All(setContextOp(ctx, blq.ctx, "First")) if err != nil { return nil, err } @@ -137,7 +134,7 @@ func (blq *BlobLinkQuery) FirstX(ctx context.Context) *BlobLink { // Returns a *NotSingularError when more than one BlobLink entity is found. // Returns a *NotFoundError when no BlobLink entities are found. func (blq *BlobLinkQuery) Only(ctx context.Context) (*BlobLink, error) { - nodes, err := blq.Limit(2).All(newQueryContext(ctx, TypeBlobLink, "Only")) + nodes, err := blq.Limit(2).All(setContextOp(ctx, blq.ctx, "Only")) if err != nil { return nil, err } @@ -162,7 +159,7 @@ func (blq *BlobLinkQuery) OnlyX(ctx context.Context) *BlobLink { // All executes the query and returns a list of BlobLinks. func (blq *BlobLinkQuery) All(ctx context.Context) ([]*BlobLink, error) { - ctx = newQueryContext(ctx, TypeBlobLink, "All") + ctx = setContextOp(ctx, blq.ctx, "All") if err := blq.prepareQuery(ctx); err != nil { return nil, err } @@ -181,7 +178,7 @@ func (blq *BlobLinkQuery) AllX(ctx context.Context) []*BlobLink { // Count returns the count of the given query. func (blq *BlobLinkQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeBlobLink, "Count") + ctx = setContextOp(ctx, blq.ctx, "Count") if err := blq.prepareQuery(ctx); err != nil { return 0, err } @@ -199,7 +196,7 @@ func (blq *BlobLinkQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (blq *BlobLinkQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeBlobLink, "Exist") + ctx = setContextOp(ctx, blq.ctx, "Exist") switch _, err := blq.First(ctx); { case IsNotFound(err): return false, nil @@ -227,17 +224,15 @@ func (blq *BlobLinkQuery) Clone() *BlobLinkQuery { } return &BlobLinkQuery{ config: blq.config, - limit: blq.limit, - offset: blq.offset, + ctx: blq.ctx.Clone(), order: append([]OrderFunc{}, blq.order...), inters: append([]Interceptor{}, blq.inters...), predicates: append([]predicate.BlobLink{}, blq.predicates...), withBlob: blq.withBlob.Clone(), withLink: blq.withLink.Clone(), // clone intermediate query. - sql: blq.sql.Clone(), - path: blq.path, - unique: blq.unique, + sql: blq.sql.Clone(), + path: blq.path, } } @@ -278,9 +273,9 @@ func (blq *BlobLinkQuery) WithLink(opts ...func(*BlobQuery)) *BlobLinkQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (blq *BlobLinkQuery) GroupBy(field string, fields ...string) *BlobLinkGroupBy { - blq.fields = append([]string{field}, fields...) + blq.ctx.Fields = append([]string{field}, fields...) grbuild := &BlobLinkGroupBy{build: blq} - grbuild.flds = &blq.fields + grbuild.flds = &blq.ctx.Fields grbuild.label = bloblink.Label grbuild.scan = grbuild.Scan return grbuild @@ -299,10 +294,10 @@ func (blq *BlobLinkQuery) GroupBy(field string, fields ...string) *BlobLinkGroup // Select(bloblink.FieldCreatedAt). // Scan(ctx, &v) func (blq *BlobLinkQuery) Select(fields ...string) *BlobLinkSelect { - blq.fields = append(blq.fields, fields...) + blq.ctx.Fields = append(blq.ctx.Fields, fields...) sbuild := &BlobLinkSelect{BlobLinkQuery: blq} sbuild.label = bloblink.Label - sbuild.flds, sbuild.scan = &blq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &blq.ctx.Fields, sbuild.Scan return sbuild } @@ -322,7 +317,7 @@ func (blq *BlobLinkQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range blq.fields { + for _, f := range blq.ctx.Fields { if !bloblink.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -454,10 +449,10 @@ func (blq *BlobLinkQuery) querySpec() *sqlgraph.QuerySpec { From: blq.sql, Unique: true, } - if unique := blq.unique; unique != nil { + if unique := blq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := blq.fields; len(fields) > 0 { + if fields := blq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) for i := range fields { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) @@ -470,10 +465,10 @@ func (blq *BlobLinkQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := blq.limit; limit != nil { + if limit := blq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := blq.offset; offset != nil { + if offset := blq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := blq.order; len(ps) > 0 { @@ -489,7 +484,7 @@ func (blq *BlobLinkQuery) querySpec() *sqlgraph.QuerySpec { func (blq *BlobLinkQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(blq.driver.Dialect()) t1 := builder.Table(bloblink.Table) - columns := blq.fields + columns := blq.ctx.Fields if len(columns) == 0 { columns = bloblink.Columns } @@ -498,7 +493,7 @@ func (blq *BlobLinkQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = blq.sql selector.Select(selector.Columns(columns...)...) } - if blq.unique != nil && *blq.unique { + if blq.ctx.Unique != nil && *blq.ctx.Unique { selector.Distinct() } for _, p := range blq.predicates { @@ -507,12 +502,12 @@ func (blq *BlobLinkQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range blq.order { p(selector) } - if offset := blq.offset; offset != nil { + if offset := blq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := blq.limit; limit != nil { + if limit := blq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -532,7 +527,7 @@ func (blgb *BlobLinkGroupBy) Aggregate(fns ...AggregateFunc) *BlobLinkGroupBy { // Scan applies the selector query and scans the result into the given value. func (blgb *BlobLinkGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeBlobLink, "GroupBy") + ctx = setContextOp(ctx, blgb.build.ctx, "GroupBy") if err := blgb.build.prepareQuery(ctx); err != nil { return err } @@ -580,7 +575,7 @@ func (bls *BlobLinkSelect) Aggregate(fns ...AggregateFunc) *BlobLinkSelect { // Scan applies the selector query and scans the result into the given value. func (bls *BlobLinkSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeBlobLink, "Select") + ctx = setContextOp(ctx, bls.ctx, "Select") if err := bls.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/car_query.go b/entc/integration/customid/ent/car_query.go index 113bab886..491f037a6 100644 --- a/entc/integration/customid/ent/car_query.go +++ b/entc/integration/customid/ent/car_query.go @@ -22,11 +22,8 @@ import ( // CarQuery is the builder for querying Car entities. type CarQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Car withOwner *PetQuery @@ -44,20 +41,20 @@ func (cq *CarQuery) Where(ps ...predicate.Car) *CarQuery { // Limit the number of records to be returned by this query. func (cq *CarQuery) Limit(limit int) *CarQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *CarQuery) Offset(offset int) *CarQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *CarQuery) Unique(unique bool) *CarQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -92,7 +89,7 @@ func (cq *CarQuery) QueryOwner() *PetQuery { // First returns the first Car entity from the query. // Returns a *NotFoundError when no Car was found. func (cq *CarQuery) First(ctx context.Context) (*Car, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeCar, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -115,7 +112,7 @@ func (cq *CarQuery) FirstX(ctx context.Context) *Car { // Returns a *NotFoundError when no Car ID was found. func (cq *CarQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeCar, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -138,7 +135,7 @@ func (cq *CarQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Car entity is found. // Returns a *NotFoundError when no Car entities are found. func (cq *CarQuery) Only(ctx context.Context) (*Car, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeCar, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -166,7 +163,7 @@ func (cq *CarQuery) OnlyX(ctx context.Context) *Car { // Returns a *NotFoundError when no entities are found. func (cq *CarQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeCar, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -191,7 +188,7 @@ func (cq *CarQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Cars. func (cq *CarQuery) All(ctx context.Context) ([]*Car, error) { - ctx = newQueryContext(ctx, TypeCar, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -211,7 +208,7 @@ func (cq *CarQuery) AllX(ctx context.Context) []*Car { // IDs executes the query and returns a list of Car IDs. func (cq *CarQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeCar, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(car.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -229,7 +226,7 @@ func (cq *CarQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (cq *CarQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeCar, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -247,7 +244,7 @@ func (cq *CarQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *CarQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeCar, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -275,16 +272,14 @@ func (cq *CarQuery) Clone() *CarQuery { } return &CarQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Car{}, cq.predicates...), withOwner: cq.withOwner.Clone(), // clone intermediate query. - sql: cq.sql.Clone(), - path: cq.path, - unique: cq.unique, + sql: cq.sql.Clone(), + path: cq.path, } } @@ -314,9 +309,9 @@ func (cq *CarQuery) WithOwner(opts ...func(*PetQuery)) *CarQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (cq *CarQuery) GroupBy(field string, fields ...string) *CarGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &CarGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = car.Label grbuild.scan = grbuild.Scan return grbuild @@ -335,10 +330,10 @@ func (cq *CarQuery) GroupBy(field string, fields ...string) *CarGroupBy { // Select(car.FieldBeforeID). // Scan(ctx, &v) func (cq *CarQuery) Select(fields ...string) *CarSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &CarSelect{CarQuery: cq} sbuild.label = car.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -358,7 +353,7 @@ func (cq *CarQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range cq.fields { + for _, f := range cq.ctx.Fields { if !car.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -450,9 +445,9 @@ func (cq *CarQuery) loadOwner(ctx context.Context, query *PetQuery, nodes []*Car func (cq *CarQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() - _spec.Node.Columns = cq.fields - if len(cq.fields) > 0 { - _spec.Unique = cq.unique != nil && *cq.unique + _spec.Node.Columns = cq.ctx.Fields + if len(cq.ctx.Fields) > 0 { + _spec.Unique = cq.ctx.Unique != nil && *cq.ctx.Unique } return sqlgraph.CountNodes(ctx, cq.driver, _spec) } @@ -470,10 +465,10 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { From: cq.sql, Unique: true, } - if unique := cq.unique; unique != nil { + if unique := cq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := cq.fields; len(fields) > 0 { + if fields := cq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, car.FieldID) for i := range fields { @@ -489,10 +484,10 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := cq.order; len(ps) > 0 { @@ -508,7 +503,7 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(car.Table) - columns := cq.fields + columns := cq.ctx.Fields if len(columns) == 0 { columns = car.Columns } @@ -517,7 +512,7 @@ func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = cq.sql selector.Select(selector.Columns(columns...)...) } - if cq.unique != nil && *cq.unique { + if cq.ctx.Unique != nil && *cq.ctx.Unique { selector.Distinct() } for _, p := range cq.predicates { @@ -526,12 +521,12 @@ func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range cq.order { p(selector) } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -551,7 +546,7 @@ func (cgb *CarGroupBy) Aggregate(fns ...AggregateFunc) *CarGroupBy { // Scan applies the selector query and scans the result into the given value. func (cgb *CarGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCar, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -599,7 +594,7 @@ func (cs *CarSelect) Aggregate(fns ...AggregateFunc) *CarSelect { // Scan applies the selector query and scans the result into the given value. func (cs *CarSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCar, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/client.go b/entc/integration/customid/ent/client.go index 8d12d74cb..d0d3e4826 100644 --- a/entc/integration/customid/ent/client.go +++ b/entc/integration/customid/ent/client.go @@ -381,6 +381,7 @@ func (c *AccountClient) DeleteOneID(id sid.ID) *AccountDeleteOne { func (c *AccountClient) Query() *AccountQuery { return &AccountQuery{ config: c.config, + ctx: &QueryContext{Type: TypeAccount}, inters: c.Interceptors(), } } @@ -514,6 +515,7 @@ func (c *BlobClient) DeleteOneID(id uuid.UUID) *BlobDeleteOne { func (c *BlobClient) Query() *BlobQuery { return &BlobQuery{ config: c.config, + ctx: &QueryContext{Type: TypeBlob}, inters: c.Interceptors(), } } @@ -662,6 +664,7 @@ func (c *BlobLinkClient) Delete() *BlobLinkDelete { func (c *BlobLinkClient) Query() *BlobLinkQuery { return &BlobLinkQuery{ config: c.config, + ctx: &QueryContext{Type: TypeBlobLink}, inters: c.Interceptors(), } } @@ -779,6 +782,7 @@ func (c *CarClient) DeleteOneID(id int) *CarDeleteOne { func (c *CarClient) Query() *CarQuery { return &CarQuery{ config: c.config, + ctx: &QueryContext{Type: TypeCar}, inters: c.Interceptors(), } } @@ -912,6 +916,7 @@ func (c *DeviceClient) DeleteOneID(id schema.ID) *DeviceDeleteOne { func (c *DeviceClient) Query() *DeviceQuery { return &DeviceQuery{ config: c.config, + ctx: &QueryContext{Type: TypeDevice}, inters: c.Interceptors(), } } @@ -1061,6 +1066,7 @@ func (c *DocClient) DeleteOneID(id schema.DocID) *DocDeleteOne { func (c *DocClient) Query() *DocQuery { return &DocQuery{ config: c.config, + ctx: &QueryContext{Type: TypeDoc}, inters: c.Interceptors(), } } @@ -1226,6 +1232,7 @@ func (c *GroupClient) DeleteOneID(id int) *GroupDeleteOne { func (c *GroupClient) Query() *GroupQuery { return &GroupQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGroup}, inters: c.Interceptors(), } } @@ -1359,6 +1366,7 @@ func (c *IntSIDClient) DeleteOneID(id sid.ID) *IntSIDDeleteOne { func (c *IntSIDClient) Query() *IntSIDQuery { return &IntSIDQuery{ config: c.config, + ctx: &QueryContext{Type: TypeIntSID}, inters: c.Interceptors(), } } @@ -1508,6 +1516,7 @@ func (c *LinkClient) DeleteOneID(id uuidc.UUIDC) *LinkDeleteOne { func (c *LinkClient) Query() *LinkQuery { return &LinkQuery{ config: c.config, + ctx: &QueryContext{Type: TypeLink}, inters: c.Interceptors(), } } @@ -1625,6 +1634,7 @@ func (c *MixinIDClient) DeleteOneID(id uuid.UUID) *MixinIDDeleteOne { func (c *MixinIDClient) Query() *MixinIDQuery { return &MixinIDQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMixinID}, inters: c.Interceptors(), } } @@ -1742,6 +1752,7 @@ func (c *NoteClient) DeleteOneID(id schema.NoteID) *NoteDeleteOne { func (c *NoteClient) Query() *NoteQuery { return &NoteQuery{ config: c.config, + ctx: &QueryContext{Type: TypeNote}, inters: c.Interceptors(), } } @@ -1891,6 +1902,7 @@ func (c *OtherClient) DeleteOneID(id sid.ID) *OtherDeleteOne { func (c *OtherClient) Query() *OtherQuery { return &OtherQuery{ config: c.config, + ctx: &QueryContext{Type: TypeOther}, inters: c.Interceptors(), } } @@ -2008,6 +2020,7 @@ func (c *PetClient) DeleteOneID(id string) *PetDeleteOne { func (c *PetClient) Query() *PetQuery { return &PetQuery{ config: c.config, + ctx: &QueryContext{Type: TypePet}, inters: c.Interceptors(), } } @@ -2189,6 +2202,7 @@ func (c *RevisionClient) DeleteOneID(id string) *RevisionDeleteOne { func (c *RevisionClient) Query() *RevisionQuery { return &RevisionQuery{ config: c.config, + ctx: &QueryContext{Type: TypeRevision}, inters: c.Interceptors(), } } @@ -2306,6 +2320,7 @@ func (c *SessionClient) DeleteOneID(id schema.ID) *SessionDeleteOne { func (c *SessionClient) Query() *SessionQuery { return &SessionQuery{ config: c.config, + ctx: &QueryContext{Type: TypeSession}, inters: c.Interceptors(), } } @@ -2439,6 +2454,7 @@ func (c *TokenClient) DeleteOneID(id sid.ID) *TokenDeleteOne { func (c *TokenClient) Query() *TokenQuery { return &TokenQuery{ config: c.config, + ctx: &QueryContext{Type: TypeToken}, inters: c.Interceptors(), } } @@ -2572,6 +2588,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/customid/ent/device_query.go b/entc/integration/customid/ent/device_query.go index 7e3670593..e91e945ce 100644 --- a/entc/integration/customid/ent/device_query.go +++ b/entc/integration/customid/ent/device_query.go @@ -24,11 +24,8 @@ import ( // DeviceQuery is the builder for querying Device entities. type DeviceQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Device withActiveSession *SessionQuery @@ -47,20 +44,20 @@ func (dq *DeviceQuery) Where(ps ...predicate.Device) *DeviceQuery { // Limit the number of records to be returned by this query. func (dq *DeviceQuery) Limit(limit int) *DeviceQuery { - dq.limit = &limit + dq.ctx.Limit = &limit return dq } // Offset to start from. func (dq *DeviceQuery) Offset(offset int) *DeviceQuery { - dq.offset = &offset + dq.ctx.Offset = &offset return dq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (dq *DeviceQuery) Unique(unique bool) *DeviceQuery { - dq.unique = &unique + dq.ctx.Unique = &unique return dq } @@ -117,7 +114,7 @@ func (dq *DeviceQuery) QuerySessions() *SessionQuery { // First returns the first Device entity from the query. // Returns a *NotFoundError when no Device was found. func (dq *DeviceQuery) First(ctx context.Context) (*Device, error) { - nodes, err := dq.Limit(1).All(newQueryContext(ctx, TypeDevice, "First")) + nodes, err := dq.Limit(1).All(setContextOp(ctx, dq.ctx, "First")) if err != nil { return nil, err } @@ -140,7 +137,7 @@ func (dq *DeviceQuery) FirstX(ctx context.Context) *Device { // Returns a *NotFoundError when no Device ID was found. func (dq *DeviceQuery) FirstID(ctx context.Context) (id schema.ID, err error) { var ids []schema.ID - if ids, err = dq.Limit(1).IDs(newQueryContext(ctx, TypeDevice, "FirstID")); err != nil { + if ids, err = dq.Limit(1).IDs(setContextOp(ctx, dq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -163,7 +160,7 @@ func (dq *DeviceQuery) FirstIDX(ctx context.Context) schema.ID { // Returns a *NotSingularError when more than one Device entity is found. // Returns a *NotFoundError when no Device entities are found. func (dq *DeviceQuery) Only(ctx context.Context) (*Device, error) { - nodes, err := dq.Limit(2).All(newQueryContext(ctx, TypeDevice, "Only")) + nodes, err := dq.Limit(2).All(setContextOp(ctx, dq.ctx, "Only")) if err != nil { return nil, err } @@ -191,7 +188,7 @@ func (dq *DeviceQuery) OnlyX(ctx context.Context) *Device { // Returns a *NotFoundError when no entities are found. func (dq *DeviceQuery) OnlyID(ctx context.Context) (id schema.ID, err error) { var ids []schema.ID - if ids, err = dq.Limit(2).IDs(newQueryContext(ctx, TypeDevice, "OnlyID")); err != nil { + if ids, err = dq.Limit(2).IDs(setContextOp(ctx, dq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -216,7 +213,7 @@ func (dq *DeviceQuery) OnlyIDX(ctx context.Context) schema.ID { // All executes the query and returns a list of Devices. func (dq *DeviceQuery) All(ctx context.Context) ([]*Device, error) { - ctx = newQueryContext(ctx, TypeDevice, "All") + ctx = setContextOp(ctx, dq.ctx, "All") if err := dq.prepareQuery(ctx); err != nil { return nil, err } @@ -236,7 +233,7 @@ func (dq *DeviceQuery) AllX(ctx context.Context) []*Device { // IDs executes the query and returns a list of Device IDs. func (dq *DeviceQuery) IDs(ctx context.Context) ([]schema.ID, error) { var ids []schema.ID - ctx = newQueryContext(ctx, TypeDevice, "IDs") + ctx = setContextOp(ctx, dq.ctx, "IDs") if err := dq.Select(device.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -254,7 +251,7 @@ func (dq *DeviceQuery) IDsX(ctx context.Context) []schema.ID { // Count returns the count of the given query. func (dq *DeviceQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeDevice, "Count") + ctx = setContextOp(ctx, dq.ctx, "Count") if err := dq.prepareQuery(ctx); err != nil { return 0, err } @@ -272,7 +269,7 @@ func (dq *DeviceQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (dq *DeviceQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeDevice, "Exist") + ctx = setContextOp(ctx, dq.ctx, "Exist") switch _, err := dq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -300,17 +297,15 @@ func (dq *DeviceQuery) Clone() *DeviceQuery { } return &DeviceQuery{ config: dq.config, - limit: dq.limit, - offset: dq.offset, + ctx: dq.ctx.Clone(), order: append([]OrderFunc{}, dq.order...), inters: append([]Interceptor{}, dq.inters...), predicates: append([]predicate.Device{}, dq.predicates...), withActiveSession: dq.withActiveSession.Clone(), withSessions: dq.withSessions.Clone(), // clone intermediate query. - sql: dq.sql.Clone(), - path: dq.path, - unique: dq.unique, + sql: dq.sql.Clone(), + path: dq.path, } } @@ -339,9 +334,9 @@ func (dq *DeviceQuery) WithSessions(opts ...func(*SessionQuery)) *DeviceQuery { // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (dq *DeviceQuery) GroupBy(field string, fields ...string) *DeviceGroupBy { - dq.fields = append([]string{field}, fields...) + dq.ctx.Fields = append([]string{field}, fields...) grbuild := &DeviceGroupBy{build: dq} - grbuild.flds = &dq.fields + grbuild.flds = &dq.ctx.Fields grbuild.label = device.Label grbuild.scan = grbuild.Scan return grbuild @@ -350,10 +345,10 @@ func (dq *DeviceQuery) GroupBy(field string, fields ...string) *DeviceGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (dq *DeviceQuery) Select(fields ...string) *DeviceSelect { - dq.fields = append(dq.fields, fields...) + dq.ctx.Fields = append(dq.ctx.Fields, fields...) sbuild := &DeviceSelect{DeviceQuery: dq} sbuild.label = device.Label - sbuild.flds, sbuild.scan = &dq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &dq.ctx.Fields, sbuild.Scan return sbuild } @@ -373,7 +368,7 @@ func (dq *DeviceQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range dq.fields { + for _, f := range dq.ctx.Fields { if !device.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -504,9 +499,9 @@ func (dq *DeviceQuery) loadSessions(ctx context.Context, query *SessionQuery, no func (dq *DeviceQuery) sqlCount(ctx context.Context) (int, error) { _spec := dq.querySpec() - _spec.Node.Columns = dq.fields - if len(dq.fields) > 0 { - _spec.Unique = dq.unique != nil && *dq.unique + _spec.Node.Columns = dq.ctx.Fields + if len(dq.ctx.Fields) > 0 { + _spec.Unique = dq.ctx.Unique != nil && *dq.ctx.Unique } return sqlgraph.CountNodes(ctx, dq.driver, _spec) } @@ -524,10 +519,10 @@ func (dq *DeviceQuery) querySpec() *sqlgraph.QuerySpec { From: dq.sql, Unique: true, } - if unique := dq.unique; unique != nil { + if unique := dq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := dq.fields; len(fields) > 0 { + if fields := dq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, device.FieldID) for i := range fields { @@ -543,10 +538,10 @@ func (dq *DeviceQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := dq.limit; limit != nil { + if limit := dq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := dq.offset; offset != nil { + if offset := dq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := dq.order; len(ps) > 0 { @@ -562,7 +557,7 @@ func (dq *DeviceQuery) querySpec() *sqlgraph.QuerySpec { func (dq *DeviceQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(dq.driver.Dialect()) t1 := builder.Table(device.Table) - columns := dq.fields + columns := dq.ctx.Fields if len(columns) == 0 { columns = device.Columns } @@ -571,7 +566,7 @@ func (dq *DeviceQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = dq.sql selector.Select(selector.Columns(columns...)...) } - if dq.unique != nil && *dq.unique { + if dq.ctx.Unique != nil && *dq.ctx.Unique { selector.Distinct() } for _, p := range dq.predicates { @@ -580,12 +575,12 @@ func (dq *DeviceQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range dq.order { p(selector) } - if offset := dq.offset; offset != nil { + if offset := dq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := dq.limit; limit != nil { + if limit := dq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -605,7 +600,7 @@ func (dgb *DeviceGroupBy) Aggregate(fns ...AggregateFunc) *DeviceGroupBy { // Scan applies the selector query and scans the result into the given value. func (dgb *DeviceGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeDevice, "GroupBy") + ctx = setContextOp(ctx, dgb.build.ctx, "GroupBy") if err := dgb.build.prepareQuery(ctx); err != nil { return err } @@ -653,7 +648,7 @@ func (ds *DeviceSelect) Aggregate(fns ...AggregateFunc) *DeviceSelect { // Scan applies the selector query and scans the result into the given value. func (ds *DeviceSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeDevice, "Select") + ctx = setContextOp(ctx, ds.ctx, "Select") if err := ds.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/doc_query.go b/entc/integration/customid/ent/doc_query.go index 95e78bd72..e018d332c 100644 --- a/entc/integration/customid/ent/doc_query.go +++ b/entc/integration/customid/ent/doc_query.go @@ -23,11 +23,8 @@ import ( // DocQuery is the builder for querying Doc entities. type DocQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Doc withParent *DocQuery @@ -47,20 +44,20 @@ func (dq *DocQuery) Where(ps ...predicate.Doc) *DocQuery { // Limit the number of records to be returned by this query. func (dq *DocQuery) Limit(limit int) *DocQuery { - dq.limit = &limit + dq.ctx.Limit = &limit return dq } // Offset to start from. func (dq *DocQuery) Offset(offset int) *DocQuery { - dq.offset = &offset + dq.ctx.Offset = &offset return dq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (dq *DocQuery) Unique(unique bool) *DocQuery { - dq.unique = &unique + dq.ctx.Unique = &unique return dq } @@ -139,7 +136,7 @@ func (dq *DocQuery) QueryRelated() *DocQuery { // First returns the first Doc entity from the query. // Returns a *NotFoundError when no Doc was found. func (dq *DocQuery) First(ctx context.Context) (*Doc, error) { - nodes, err := dq.Limit(1).All(newQueryContext(ctx, TypeDoc, "First")) + nodes, err := dq.Limit(1).All(setContextOp(ctx, dq.ctx, "First")) if err != nil { return nil, err } @@ -162,7 +159,7 @@ func (dq *DocQuery) FirstX(ctx context.Context) *Doc { // Returns a *NotFoundError when no Doc ID was found. func (dq *DocQuery) FirstID(ctx context.Context) (id schema.DocID, err error) { var ids []schema.DocID - if ids, err = dq.Limit(1).IDs(newQueryContext(ctx, TypeDoc, "FirstID")); err != nil { + if ids, err = dq.Limit(1).IDs(setContextOp(ctx, dq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -185,7 +182,7 @@ func (dq *DocQuery) FirstIDX(ctx context.Context) schema.DocID { // Returns a *NotSingularError when more than one Doc entity is found. // Returns a *NotFoundError when no Doc entities are found. func (dq *DocQuery) Only(ctx context.Context) (*Doc, error) { - nodes, err := dq.Limit(2).All(newQueryContext(ctx, TypeDoc, "Only")) + nodes, err := dq.Limit(2).All(setContextOp(ctx, dq.ctx, "Only")) if err != nil { return nil, err } @@ -213,7 +210,7 @@ func (dq *DocQuery) OnlyX(ctx context.Context) *Doc { // Returns a *NotFoundError when no entities are found. func (dq *DocQuery) OnlyID(ctx context.Context) (id schema.DocID, err error) { var ids []schema.DocID - if ids, err = dq.Limit(2).IDs(newQueryContext(ctx, TypeDoc, "OnlyID")); err != nil { + if ids, err = dq.Limit(2).IDs(setContextOp(ctx, dq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -238,7 +235,7 @@ func (dq *DocQuery) OnlyIDX(ctx context.Context) schema.DocID { // All executes the query and returns a list of Docs. func (dq *DocQuery) All(ctx context.Context) ([]*Doc, error) { - ctx = newQueryContext(ctx, TypeDoc, "All") + ctx = setContextOp(ctx, dq.ctx, "All") if err := dq.prepareQuery(ctx); err != nil { return nil, err } @@ -258,7 +255,7 @@ func (dq *DocQuery) AllX(ctx context.Context) []*Doc { // IDs executes the query and returns a list of Doc IDs. func (dq *DocQuery) IDs(ctx context.Context) ([]schema.DocID, error) { var ids []schema.DocID - ctx = newQueryContext(ctx, TypeDoc, "IDs") + ctx = setContextOp(ctx, dq.ctx, "IDs") if err := dq.Select(doc.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -276,7 +273,7 @@ func (dq *DocQuery) IDsX(ctx context.Context) []schema.DocID { // Count returns the count of the given query. func (dq *DocQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeDoc, "Count") + ctx = setContextOp(ctx, dq.ctx, "Count") if err := dq.prepareQuery(ctx); err != nil { return 0, err } @@ -294,7 +291,7 @@ func (dq *DocQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (dq *DocQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeDoc, "Exist") + ctx = setContextOp(ctx, dq.ctx, "Exist") switch _, err := dq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -322,8 +319,7 @@ func (dq *DocQuery) Clone() *DocQuery { } return &DocQuery{ config: dq.config, - limit: dq.limit, - offset: dq.offset, + ctx: dq.ctx.Clone(), order: append([]OrderFunc{}, dq.order...), inters: append([]Interceptor{}, dq.inters...), predicates: append([]predicate.Doc{}, dq.predicates...), @@ -331,9 +327,8 @@ func (dq *DocQuery) Clone() *DocQuery { withChildren: dq.withChildren.Clone(), withRelated: dq.withRelated.Clone(), // clone intermediate query. - sql: dq.sql.Clone(), - path: dq.path, - unique: dq.unique, + sql: dq.sql.Clone(), + path: dq.path, } } @@ -385,9 +380,9 @@ func (dq *DocQuery) WithRelated(opts ...func(*DocQuery)) *DocQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (dq *DocQuery) GroupBy(field string, fields ...string) *DocGroupBy { - dq.fields = append([]string{field}, fields...) + dq.ctx.Fields = append([]string{field}, fields...) grbuild := &DocGroupBy{build: dq} - grbuild.flds = &dq.fields + grbuild.flds = &dq.ctx.Fields grbuild.label = doc.Label grbuild.scan = grbuild.Scan return grbuild @@ -406,10 +401,10 @@ func (dq *DocQuery) GroupBy(field string, fields ...string) *DocGroupBy { // Select(doc.FieldText). // Scan(ctx, &v) func (dq *DocQuery) Select(fields ...string) *DocSelect { - dq.fields = append(dq.fields, fields...) + dq.ctx.Fields = append(dq.ctx.Fields, fields...) sbuild := &DocSelect{DocQuery: dq} sbuild.label = doc.Label - sbuild.flds, sbuild.scan = &dq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &dq.ctx.Fields, sbuild.Scan return sbuild } @@ -429,7 +424,7 @@ func (dq *DocQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range dq.fields { + for _, f := range dq.ctx.Fields { if !doc.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -626,9 +621,9 @@ func (dq *DocQuery) loadRelated(ctx context.Context, query *DocQuery, nodes []*D func (dq *DocQuery) sqlCount(ctx context.Context) (int, error) { _spec := dq.querySpec() - _spec.Node.Columns = dq.fields - if len(dq.fields) > 0 { - _spec.Unique = dq.unique != nil && *dq.unique + _spec.Node.Columns = dq.ctx.Fields + if len(dq.ctx.Fields) > 0 { + _spec.Unique = dq.ctx.Unique != nil && *dq.ctx.Unique } return sqlgraph.CountNodes(ctx, dq.driver, _spec) } @@ -646,10 +641,10 @@ func (dq *DocQuery) querySpec() *sqlgraph.QuerySpec { From: dq.sql, Unique: true, } - if unique := dq.unique; unique != nil { + if unique := dq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := dq.fields; len(fields) > 0 { + if fields := dq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, doc.FieldID) for i := range fields { @@ -665,10 +660,10 @@ func (dq *DocQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := dq.limit; limit != nil { + if limit := dq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := dq.offset; offset != nil { + if offset := dq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := dq.order; len(ps) > 0 { @@ -684,7 +679,7 @@ func (dq *DocQuery) querySpec() *sqlgraph.QuerySpec { func (dq *DocQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(dq.driver.Dialect()) t1 := builder.Table(doc.Table) - columns := dq.fields + columns := dq.ctx.Fields if len(columns) == 0 { columns = doc.Columns } @@ -693,7 +688,7 @@ func (dq *DocQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = dq.sql selector.Select(selector.Columns(columns...)...) } - if dq.unique != nil && *dq.unique { + if dq.ctx.Unique != nil && *dq.ctx.Unique { selector.Distinct() } for _, p := range dq.predicates { @@ -702,12 +697,12 @@ func (dq *DocQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range dq.order { p(selector) } - if offset := dq.offset; offset != nil { + if offset := dq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := dq.limit; limit != nil { + if limit := dq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -727,7 +722,7 @@ func (dgb *DocGroupBy) Aggregate(fns ...AggregateFunc) *DocGroupBy { // Scan applies the selector query and scans the result into the given value. func (dgb *DocGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeDoc, "GroupBy") + ctx = setContextOp(ctx, dgb.build.ctx, "GroupBy") if err := dgb.build.prepareQuery(ctx); err != nil { return err } @@ -775,7 +770,7 @@ func (ds *DocSelect) Aggregate(fns ...AggregateFunc) *DocSelect { // Scan applies the selector query and scans the result into the given value. func (ds *DocSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeDoc, "Select") + ctx = setContextOp(ctx, ds.ctx, "Select") if err := ds.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/ent.go b/entc/integration/customid/ent/ent.go index a5c81e197..5469c7375 100644 --- a/entc/integration/customid/ent/ent.go +++ b/entc/integration/customid/ent/ent.go @@ -40,6 +40,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -539,10 +540,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/customid/ent/group_query.go b/entc/integration/customid/ent/group_query.go index 346562d37..92cf7f64a 100644 --- a/entc/integration/customid/ent/group_query.go +++ b/entc/integration/customid/ent/group_query.go @@ -23,11 +23,8 @@ import ( // GroupQuery is the builder for querying Group entities. type GroupQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Group withUsers *UserQuery @@ -44,20 +41,20 @@ func (gq *GroupQuery) Where(ps ...predicate.Group) *GroupQuery { // Limit the number of records to be returned by this query. func (gq *GroupQuery) Limit(limit int) *GroupQuery { - gq.limit = &limit + gq.ctx.Limit = &limit return gq } // Offset to start from. func (gq *GroupQuery) Offset(offset int) *GroupQuery { - gq.offset = &offset + gq.ctx.Offset = &offset return gq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (gq *GroupQuery) Unique(unique bool) *GroupQuery { - gq.unique = &unique + gq.ctx.Unique = &unique return gq } @@ -92,7 +89,7 @@ func (gq *GroupQuery) QueryUsers() *UserQuery { // First returns the first Group entity from the query. // Returns a *NotFoundError when no Group was found. func (gq *GroupQuery) First(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(1).All(newQueryContext(ctx, TypeGroup, "First")) + nodes, err := gq.Limit(1).All(setContextOp(ctx, gq.ctx, "First")) if err != nil { return nil, err } @@ -115,7 +112,7 @@ func (gq *GroupQuery) FirstX(ctx context.Context) *Group { // Returns a *NotFoundError when no Group ID was found. func (gq *GroupQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(1).IDs(newQueryContext(ctx, TypeGroup, "FirstID")); err != nil { + if ids, err = gq.Limit(1).IDs(setContextOp(ctx, gq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -138,7 +135,7 @@ func (gq *GroupQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Group entity is found. // Returns a *NotFoundError when no Group entities are found. func (gq *GroupQuery) Only(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(2).All(newQueryContext(ctx, TypeGroup, "Only")) + nodes, err := gq.Limit(2).All(setContextOp(ctx, gq.ctx, "Only")) if err != nil { return nil, err } @@ -166,7 +163,7 @@ func (gq *GroupQuery) OnlyX(ctx context.Context) *Group { // Returns a *NotFoundError when no entities are found. func (gq *GroupQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(2).IDs(newQueryContext(ctx, TypeGroup, "OnlyID")); err != nil { + if ids, err = gq.Limit(2).IDs(setContextOp(ctx, gq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -191,7 +188,7 @@ func (gq *GroupQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Groups. func (gq *GroupQuery) All(ctx context.Context) ([]*Group, error) { - ctx = newQueryContext(ctx, TypeGroup, "All") + ctx = setContextOp(ctx, gq.ctx, "All") if err := gq.prepareQuery(ctx); err != nil { return nil, err } @@ -211,7 +208,7 @@ func (gq *GroupQuery) AllX(ctx context.Context) []*Group { // IDs executes the query and returns a list of Group IDs. func (gq *GroupQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeGroup, "IDs") + ctx = setContextOp(ctx, gq.ctx, "IDs") if err := gq.Select(group.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -229,7 +226,7 @@ func (gq *GroupQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (gq *GroupQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGroup, "Count") + ctx = setContextOp(ctx, gq.ctx, "Count") if err := gq.prepareQuery(ctx); err != nil { return 0, err } @@ -247,7 +244,7 @@ func (gq *GroupQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (gq *GroupQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGroup, "Exist") + ctx = setContextOp(ctx, gq.ctx, "Exist") switch _, err := gq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -275,16 +272,14 @@ func (gq *GroupQuery) Clone() *GroupQuery { } return &GroupQuery{ config: gq.config, - limit: gq.limit, - offset: gq.offset, + ctx: gq.ctx.Clone(), order: append([]OrderFunc{}, gq.order...), inters: append([]Interceptor{}, gq.inters...), predicates: append([]predicate.Group{}, gq.predicates...), withUsers: gq.withUsers.Clone(), // clone intermediate query. - sql: gq.sql.Clone(), - path: gq.path, - unique: gq.unique, + sql: gq.sql.Clone(), + path: gq.path, } } @@ -302,9 +297,9 @@ func (gq *GroupQuery) WithUsers(opts ...func(*UserQuery)) *GroupQuery { // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { - gq.fields = append([]string{field}, fields...) + gq.ctx.Fields = append([]string{field}, fields...) grbuild := &GroupGroupBy{build: gq} - grbuild.flds = &gq.fields + grbuild.flds = &gq.ctx.Fields grbuild.label = group.Label grbuild.scan = grbuild.Scan return grbuild @@ -313,10 +308,10 @@ func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (gq *GroupQuery) Select(fields ...string) *GroupSelect { - gq.fields = append(gq.fields, fields...) + gq.ctx.Fields = append(gq.ctx.Fields, fields...) sbuild := &GroupSelect{GroupQuery: gq} sbuild.label = group.Label - sbuild.flds, sbuild.scan = &gq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &gq.ctx.Fields, sbuild.Scan return sbuild } @@ -336,7 +331,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range gq.fields { + for _, f := range gq.ctx.Fields { if !group.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -448,9 +443,9 @@ func (gq *GroupQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []* func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { _spec := gq.querySpec() - _spec.Node.Columns = gq.fields - if len(gq.fields) > 0 { - _spec.Unique = gq.unique != nil && *gq.unique + _spec.Node.Columns = gq.ctx.Fields + if len(gq.ctx.Fields) > 0 { + _spec.Unique = gq.ctx.Unique != nil && *gq.ctx.Unique } return sqlgraph.CountNodes(ctx, gq.driver, _spec) } @@ -468,10 +463,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { From: gq.sql, Unique: true, } - if unique := gq.unique; unique != nil { + if unique := gq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := gq.fields; len(fields) > 0 { + if fields := gq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, group.FieldID) for i := range fields { @@ -487,10 +482,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := gq.order; len(ps) > 0 { @@ -506,7 +501,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(gq.driver.Dialect()) t1 := builder.Table(group.Table) - columns := gq.fields + columns := gq.ctx.Fields if len(columns) == 0 { columns = group.Columns } @@ -515,7 +510,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = gq.sql selector.Select(selector.Columns(columns...)...) } - if gq.unique != nil && *gq.unique { + if gq.ctx.Unique != nil && *gq.ctx.Unique { selector.Distinct() } for _, p := range gq.predicates { @@ -524,12 +519,12 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range gq.order { p(selector) } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -549,7 +544,7 @@ func (ggb *GroupGroupBy) Aggregate(fns ...AggregateFunc) *GroupGroupBy { // Scan applies the selector query and scans the result into the given value. func (ggb *GroupGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "GroupBy") + ctx = setContextOp(ctx, ggb.build.ctx, "GroupBy") if err := ggb.build.prepareQuery(ctx); err != nil { return err } @@ -597,7 +592,7 @@ func (gs *GroupSelect) Aggregate(fns ...AggregateFunc) *GroupSelect { // Scan applies the selector query and scans the result into the given value. func (gs *GroupSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "Select") + ctx = setContextOp(ctx, gs.ctx, "Select") if err := gs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/intsid_query.go b/entc/integration/customid/ent/intsid_query.go index e8c77f7df..446982e5a 100644 --- a/entc/integration/customid/ent/intsid_query.go +++ b/entc/integration/customid/ent/intsid_query.go @@ -23,11 +23,8 @@ import ( // IntSIDQuery is the builder for querying IntSID entities. type IntSIDQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.IntSID withParent *IntSIDQuery @@ -46,20 +43,20 @@ func (isq *IntSIDQuery) Where(ps ...predicate.IntSID) *IntSIDQuery { // Limit the number of records to be returned by this query. func (isq *IntSIDQuery) Limit(limit int) *IntSIDQuery { - isq.limit = &limit + isq.ctx.Limit = &limit return isq } // Offset to start from. func (isq *IntSIDQuery) Offset(offset int) *IntSIDQuery { - isq.offset = &offset + isq.ctx.Offset = &offset return isq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (isq *IntSIDQuery) Unique(unique bool) *IntSIDQuery { - isq.unique = &unique + isq.ctx.Unique = &unique return isq } @@ -116,7 +113,7 @@ func (isq *IntSIDQuery) QueryChildren() *IntSIDQuery { // First returns the first IntSID entity from the query. // Returns a *NotFoundError when no IntSID was found. func (isq *IntSIDQuery) First(ctx context.Context) (*IntSID, error) { - nodes, err := isq.Limit(1).All(newQueryContext(ctx, TypeIntSID, "First")) + nodes, err := isq.Limit(1).All(setContextOp(ctx, isq.ctx, "First")) if err != nil { return nil, err } @@ -139,7 +136,7 @@ func (isq *IntSIDQuery) FirstX(ctx context.Context) *IntSID { // Returns a *NotFoundError when no IntSID ID was found. func (isq *IntSIDQuery) FirstID(ctx context.Context) (id sid.ID, err error) { var ids []sid.ID - if ids, err = isq.Limit(1).IDs(newQueryContext(ctx, TypeIntSID, "FirstID")); err != nil { + if ids, err = isq.Limit(1).IDs(setContextOp(ctx, isq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -162,7 +159,7 @@ func (isq *IntSIDQuery) FirstIDX(ctx context.Context) sid.ID { // Returns a *NotSingularError when more than one IntSID entity is found. // Returns a *NotFoundError when no IntSID entities are found. func (isq *IntSIDQuery) Only(ctx context.Context) (*IntSID, error) { - nodes, err := isq.Limit(2).All(newQueryContext(ctx, TypeIntSID, "Only")) + nodes, err := isq.Limit(2).All(setContextOp(ctx, isq.ctx, "Only")) if err != nil { return nil, err } @@ -190,7 +187,7 @@ func (isq *IntSIDQuery) OnlyX(ctx context.Context) *IntSID { // Returns a *NotFoundError when no entities are found. func (isq *IntSIDQuery) OnlyID(ctx context.Context) (id sid.ID, err error) { var ids []sid.ID - if ids, err = isq.Limit(2).IDs(newQueryContext(ctx, TypeIntSID, "OnlyID")); err != nil { + if ids, err = isq.Limit(2).IDs(setContextOp(ctx, isq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -215,7 +212,7 @@ func (isq *IntSIDQuery) OnlyIDX(ctx context.Context) sid.ID { // All executes the query and returns a list of IntSIDs. func (isq *IntSIDQuery) All(ctx context.Context) ([]*IntSID, error) { - ctx = newQueryContext(ctx, TypeIntSID, "All") + ctx = setContextOp(ctx, isq.ctx, "All") if err := isq.prepareQuery(ctx); err != nil { return nil, err } @@ -235,7 +232,7 @@ func (isq *IntSIDQuery) AllX(ctx context.Context) []*IntSID { // IDs executes the query and returns a list of IntSID IDs. func (isq *IntSIDQuery) IDs(ctx context.Context) ([]sid.ID, error) { var ids []sid.ID - ctx = newQueryContext(ctx, TypeIntSID, "IDs") + ctx = setContextOp(ctx, isq.ctx, "IDs") if err := isq.Select(intsid.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -253,7 +250,7 @@ func (isq *IntSIDQuery) IDsX(ctx context.Context) []sid.ID { // Count returns the count of the given query. func (isq *IntSIDQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeIntSID, "Count") + ctx = setContextOp(ctx, isq.ctx, "Count") if err := isq.prepareQuery(ctx); err != nil { return 0, err } @@ -271,7 +268,7 @@ func (isq *IntSIDQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (isq *IntSIDQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeIntSID, "Exist") + ctx = setContextOp(ctx, isq.ctx, "Exist") switch _, err := isq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -299,17 +296,15 @@ func (isq *IntSIDQuery) Clone() *IntSIDQuery { } return &IntSIDQuery{ config: isq.config, - limit: isq.limit, - offset: isq.offset, + ctx: isq.ctx.Clone(), order: append([]OrderFunc{}, isq.order...), inters: append([]Interceptor{}, isq.inters...), predicates: append([]predicate.IntSID{}, isq.predicates...), withParent: isq.withParent.Clone(), withChildren: isq.withChildren.Clone(), // clone intermediate query. - sql: isq.sql.Clone(), - path: isq.path, - unique: isq.unique, + sql: isq.sql.Clone(), + path: isq.path, } } @@ -338,9 +333,9 @@ func (isq *IntSIDQuery) WithChildren(opts ...func(*IntSIDQuery)) *IntSIDQuery { // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (isq *IntSIDQuery) GroupBy(field string, fields ...string) *IntSIDGroupBy { - isq.fields = append([]string{field}, fields...) + isq.ctx.Fields = append([]string{field}, fields...) grbuild := &IntSIDGroupBy{build: isq} - grbuild.flds = &isq.fields + grbuild.flds = &isq.ctx.Fields grbuild.label = intsid.Label grbuild.scan = grbuild.Scan return grbuild @@ -349,10 +344,10 @@ func (isq *IntSIDQuery) GroupBy(field string, fields ...string) *IntSIDGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (isq *IntSIDQuery) Select(fields ...string) *IntSIDSelect { - isq.fields = append(isq.fields, fields...) + isq.ctx.Fields = append(isq.ctx.Fields, fields...) sbuild := &IntSIDSelect{IntSIDQuery: isq} sbuild.label = intsid.Label - sbuild.flds, sbuild.scan = &isq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &isq.ctx.Fields, sbuild.Scan return sbuild } @@ -372,7 +367,7 @@ func (isq *IntSIDQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range isq.fields { + for _, f := range isq.ctx.Fields { if !intsid.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -503,9 +498,9 @@ func (isq *IntSIDQuery) loadChildren(ctx context.Context, query *IntSIDQuery, no func (isq *IntSIDQuery) sqlCount(ctx context.Context) (int, error) { _spec := isq.querySpec() - _spec.Node.Columns = isq.fields - if len(isq.fields) > 0 { - _spec.Unique = isq.unique != nil && *isq.unique + _spec.Node.Columns = isq.ctx.Fields + if len(isq.ctx.Fields) > 0 { + _spec.Unique = isq.ctx.Unique != nil && *isq.ctx.Unique } return sqlgraph.CountNodes(ctx, isq.driver, _spec) } @@ -523,10 +518,10 @@ func (isq *IntSIDQuery) querySpec() *sqlgraph.QuerySpec { From: isq.sql, Unique: true, } - if unique := isq.unique; unique != nil { + if unique := isq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := isq.fields; len(fields) > 0 { + if fields := isq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, intsid.FieldID) for i := range fields { @@ -542,10 +537,10 @@ func (isq *IntSIDQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := isq.limit; limit != nil { + if limit := isq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := isq.offset; offset != nil { + if offset := isq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := isq.order; len(ps) > 0 { @@ -561,7 +556,7 @@ func (isq *IntSIDQuery) querySpec() *sqlgraph.QuerySpec { func (isq *IntSIDQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(isq.driver.Dialect()) t1 := builder.Table(intsid.Table) - columns := isq.fields + columns := isq.ctx.Fields if len(columns) == 0 { columns = intsid.Columns } @@ -570,7 +565,7 @@ func (isq *IntSIDQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = isq.sql selector.Select(selector.Columns(columns...)...) } - if isq.unique != nil && *isq.unique { + if isq.ctx.Unique != nil && *isq.ctx.Unique { selector.Distinct() } for _, p := range isq.predicates { @@ -579,12 +574,12 @@ func (isq *IntSIDQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range isq.order { p(selector) } - if offset := isq.offset; offset != nil { + if offset := isq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := isq.limit; limit != nil { + if limit := isq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -604,7 +599,7 @@ func (isgb *IntSIDGroupBy) Aggregate(fns ...AggregateFunc) *IntSIDGroupBy { // Scan applies the selector query and scans the result into the given value. func (isgb *IntSIDGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeIntSID, "GroupBy") + ctx = setContextOp(ctx, isgb.build.ctx, "GroupBy") if err := isgb.build.prepareQuery(ctx); err != nil { return err } @@ -652,7 +647,7 @@ func (iss *IntSIDSelect) Aggregate(fns ...AggregateFunc) *IntSIDSelect { // Scan applies the selector query and scans the result into the given value. func (iss *IntSIDSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeIntSID, "Select") + ctx = setContextOp(ctx, iss.ctx, "Select") if err := iss.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/link_query.go b/entc/integration/customid/ent/link_query.go index 2628de94a..6975c1838 100644 --- a/entc/integration/customid/ent/link_query.go +++ b/entc/integration/customid/ent/link_query.go @@ -22,11 +22,8 @@ import ( // LinkQuery is the builder for querying Link entities. type LinkQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Link // intermediate query (i.e. traversal path). @@ -42,20 +39,20 @@ func (lq *LinkQuery) Where(ps ...predicate.Link) *LinkQuery { // Limit the number of records to be returned by this query. func (lq *LinkQuery) Limit(limit int) *LinkQuery { - lq.limit = &limit + lq.ctx.Limit = &limit return lq } // Offset to start from. func (lq *LinkQuery) Offset(offset int) *LinkQuery { - lq.offset = &offset + lq.ctx.Offset = &offset return lq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (lq *LinkQuery) Unique(unique bool) *LinkQuery { - lq.unique = &unique + lq.ctx.Unique = &unique return lq } @@ -68,7 +65,7 @@ func (lq *LinkQuery) Order(o ...OrderFunc) *LinkQuery { // First returns the first Link entity from the query. // Returns a *NotFoundError when no Link was found. func (lq *LinkQuery) First(ctx context.Context) (*Link, error) { - nodes, err := lq.Limit(1).All(newQueryContext(ctx, TypeLink, "First")) + nodes, err := lq.Limit(1).All(setContextOp(ctx, lq.ctx, "First")) if err != nil { return nil, err } @@ -91,7 +88,7 @@ func (lq *LinkQuery) FirstX(ctx context.Context) *Link { // Returns a *NotFoundError when no Link ID was found. func (lq *LinkQuery) FirstID(ctx context.Context) (id uuidc.UUIDC, err error) { var ids []uuidc.UUIDC - if ids, err = lq.Limit(1).IDs(newQueryContext(ctx, TypeLink, "FirstID")); err != nil { + if ids, err = lq.Limit(1).IDs(setContextOp(ctx, lq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -114,7 +111,7 @@ func (lq *LinkQuery) FirstIDX(ctx context.Context) uuidc.UUIDC { // Returns a *NotSingularError when more than one Link entity is found. // Returns a *NotFoundError when no Link entities are found. func (lq *LinkQuery) Only(ctx context.Context) (*Link, error) { - nodes, err := lq.Limit(2).All(newQueryContext(ctx, TypeLink, "Only")) + nodes, err := lq.Limit(2).All(setContextOp(ctx, lq.ctx, "Only")) if err != nil { return nil, err } @@ -142,7 +139,7 @@ func (lq *LinkQuery) OnlyX(ctx context.Context) *Link { // Returns a *NotFoundError when no entities are found. func (lq *LinkQuery) OnlyID(ctx context.Context) (id uuidc.UUIDC, err error) { var ids []uuidc.UUIDC - if ids, err = lq.Limit(2).IDs(newQueryContext(ctx, TypeLink, "OnlyID")); err != nil { + if ids, err = lq.Limit(2).IDs(setContextOp(ctx, lq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -167,7 +164,7 @@ func (lq *LinkQuery) OnlyIDX(ctx context.Context) uuidc.UUIDC { // All executes the query and returns a list of Links. func (lq *LinkQuery) All(ctx context.Context) ([]*Link, error) { - ctx = newQueryContext(ctx, TypeLink, "All") + ctx = setContextOp(ctx, lq.ctx, "All") if err := lq.prepareQuery(ctx); err != nil { return nil, err } @@ -187,7 +184,7 @@ func (lq *LinkQuery) AllX(ctx context.Context) []*Link { // IDs executes the query and returns a list of Link IDs. func (lq *LinkQuery) IDs(ctx context.Context) ([]uuidc.UUIDC, error) { var ids []uuidc.UUIDC - ctx = newQueryContext(ctx, TypeLink, "IDs") + ctx = setContextOp(ctx, lq.ctx, "IDs") if err := lq.Select(link.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -205,7 +202,7 @@ func (lq *LinkQuery) IDsX(ctx context.Context) []uuidc.UUIDC { // Count returns the count of the given query. func (lq *LinkQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeLink, "Count") + ctx = setContextOp(ctx, lq.ctx, "Count") if err := lq.prepareQuery(ctx); err != nil { return 0, err } @@ -223,7 +220,7 @@ func (lq *LinkQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (lq *LinkQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeLink, "Exist") + ctx = setContextOp(ctx, lq.ctx, "Exist") switch _, err := lq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -251,15 +248,13 @@ func (lq *LinkQuery) Clone() *LinkQuery { } return &LinkQuery{ config: lq.config, - limit: lq.limit, - offset: lq.offset, + ctx: lq.ctx.Clone(), order: append([]OrderFunc{}, lq.order...), inters: append([]Interceptor{}, lq.inters...), predicates: append([]predicate.Link{}, lq.predicates...), // clone intermediate query. - sql: lq.sql.Clone(), - path: lq.path, - unique: lq.unique, + sql: lq.sql.Clone(), + path: lq.path, } } @@ -278,9 +273,9 @@ func (lq *LinkQuery) Clone() *LinkQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (lq *LinkQuery) GroupBy(field string, fields ...string) *LinkGroupBy { - lq.fields = append([]string{field}, fields...) + lq.ctx.Fields = append([]string{field}, fields...) grbuild := &LinkGroupBy{build: lq} - grbuild.flds = &lq.fields + grbuild.flds = &lq.ctx.Fields grbuild.label = link.Label grbuild.scan = grbuild.Scan return grbuild @@ -299,10 +294,10 @@ func (lq *LinkQuery) GroupBy(field string, fields ...string) *LinkGroupBy { // Select(link.FieldLinkInformation). // Scan(ctx, &v) func (lq *LinkQuery) Select(fields ...string) *LinkSelect { - lq.fields = append(lq.fields, fields...) + lq.ctx.Fields = append(lq.ctx.Fields, fields...) sbuild := &LinkSelect{LinkQuery: lq} sbuild.label = link.Label - sbuild.flds, sbuild.scan = &lq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &lq.ctx.Fields, sbuild.Scan return sbuild } @@ -322,7 +317,7 @@ func (lq *LinkQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range lq.fields { + for _, f := range lq.ctx.Fields { if !link.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -364,9 +359,9 @@ func (lq *LinkQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Link, e func (lq *LinkQuery) sqlCount(ctx context.Context) (int, error) { _spec := lq.querySpec() - _spec.Node.Columns = lq.fields - if len(lq.fields) > 0 { - _spec.Unique = lq.unique != nil && *lq.unique + _spec.Node.Columns = lq.ctx.Fields + if len(lq.ctx.Fields) > 0 { + _spec.Unique = lq.ctx.Unique != nil && *lq.ctx.Unique } return sqlgraph.CountNodes(ctx, lq.driver, _spec) } @@ -384,10 +379,10 @@ func (lq *LinkQuery) querySpec() *sqlgraph.QuerySpec { From: lq.sql, Unique: true, } - if unique := lq.unique; unique != nil { + if unique := lq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := lq.fields; len(fields) > 0 { + if fields := lq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, link.FieldID) for i := range fields { @@ -403,10 +398,10 @@ func (lq *LinkQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := lq.limit; limit != nil { + if limit := lq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := lq.offset; offset != nil { + if offset := lq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := lq.order; len(ps) > 0 { @@ -422,7 +417,7 @@ func (lq *LinkQuery) querySpec() *sqlgraph.QuerySpec { func (lq *LinkQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(lq.driver.Dialect()) t1 := builder.Table(link.Table) - columns := lq.fields + columns := lq.ctx.Fields if len(columns) == 0 { columns = link.Columns } @@ -431,7 +426,7 @@ func (lq *LinkQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = lq.sql selector.Select(selector.Columns(columns...)...) } - if lq.unique != nil && *lq.unique { + if lq.ctx.Unique != nil && *lq.ctx.Unique { selector.Distinct() } for _, p := range lq.predicates { @@ -440,12 +435,12 @@ func (lq *LinkQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range lq.order { p(selector) } - if offset := lq.offset; offset != nil { + if offset := lq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := lq.limit; limit != nil { + if limit := lq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -465,7 +460,7 @@ func (lgb *LinkGroupBy) Aggregate(fns ...AggregateFunc) *LinkGroupBy { // Scan applies the selector query and scans the result into the given value. func (lgb *LinkGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeLink, "GroupBy") + ctx = setContextOp(ctx, lgb.build.ctx, "GroupBy") if err := lgb.build.prepareQuery(ctx); err != nil { return err } @@ -513,7 +508,7 @@ func (ls *LinkSelect) Aggregate(fns ...AggregateFunc) *LinkSelect { // Scan applies the selector query and scans the result into the given value. func (ls *LinkSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeLink, "Select") + ctx = setContextOp(ctx, ls.ctx, "Select") if err := ls.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/mixinid_query.go b/entc/integration/customid/ent/mixinid_query.go index 89c42eb64..c88cb44d7 100644 --- a/entc/integration/customid/ent/mixinid_query.go +++ b/entc/integration/customid/ent/mixinid_query.go @@ -22,11 +22,8 @@ import ( // MixinIDQuery is the builder for querying MixinID entities. type MixinIDQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.MixinID // intermediate query (i.e. traversal path). @@ -42,20 +39,20 @@ func (miq *MixinIDQuery) Where(ps ...predicate.MixinID) *MixinIDQuery { // Limit the number of records to be returned by this query. func (miq *MixinIDQuery) Limit(limit int) *MixinIDQuery { - miq.limit = &limit + miq.ctx.Limit = &limit return miq } // Offset to start from. func (miq *MixinIDQuery) Offset(offset int) *MixinIDQuery { - miq.offset = &offset + miq.ctx.Offset = &offset return miq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (miq *MixinIDQuery) Unique(unique bool) *MixinIDQuery { - miq.unique = &unique + miq.ctx.Unique = &unique return miq } @@ -68,7 +65,7 @@ func (miq *MixinIDQuery) Order(o ...OrderFunc) *MixinIDQuery { // First returns the first MixinID entity from the query. // Returns a *NotFoundError when no MixinID was found. func (miq *MixinIDQuery) First(ctx context.Context) (*MixinID, error) { - nodes, err := miq.Limit(1).All(newQueryContext(ctx, TypeMixinID, "First")) + nodes, err := miq.Limit(1).All(setContextOp(ctx, miq.ctx, "First")) if err != nil { return nil, err } @@ -91,7 +88,7 @@ func (miq *MixinIDQuery) FirstX(ctx context.Context) *MixinID { // Returns a *NotFoundError when no MixinID ID was found. func (miq *MixinIDQuery) FirstID(ctx context.Context) (id uuid.UUID, err error) { var ids []uuid.UUID - if ids, err = miq.Limit(1).IDs(newQueryContext(ctx, TypeMixinID, "FirstID")); err != nil { + if ids, err = miq.Limit(1).IDs(setContextOp(ctx, miq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -114,7 +111,7 @@ func (miq *MixinIDQuery) FirstIDX(ctx context.Context) uuid.UUID { // Returns a *NotSingularError when more than one MixinID entity is found. // Returns a *NotFoundError when no MixinID entities are found. func (miq *MixinIDQuery) Only(ctx context.Context) (*MixinID, error) { - nodes, err := miq.Limit(2).All(newQueryContext(ctx, TypeMixinID, "Only")) + nodes, err := miq.Limit(2).All(setContextOp(ctx, miq.ctx, "Only")) if err != nil { return nil, err } @@ -142,7 +139,7 @@ func (miq *MixinIDQuery) OnlyX(ctx context.Context) *MixinID { // Returns a *NotFoundError when no entities are found. func (miq *MixinIDQuery) OnlyID(ctx context.Context) (id uuid.UUID, err error) { var ids []uuid.UUID - if ids, err = miq.Limit(2).IDs(newQueryContext(ctx, TypeMixinID, "OnlyID")); err != nil { + if ids, err = miq.Limit(2).IDs(setContextOp(ctx, miq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -167,7 +164,7 @@ func (miq *MixinIDQuery) OnlyIDX(ctx context.Context) uuid.UUID { // All executes the query and returns a list of MixinIDs. func (miq *MixinIDQuery) All(ctx context.Context) ([]*MixinID, error) { - ctx = newQueryContext(ctx, TypeMixinID, "All") + ctx = setContextOp(ctx, miq.ctx, "All") if err := miq.prepareQuery(ctx); err != nil { return nil, err } @@ -187,7 +184,7 @@ func (miq *MixinIDQuery) AllX(ctx context.Context) []*MixinID { // IDs executes the query and returns a list of MixinID IDs. func (miq *MixinIDQuery) IDs(ctx context.Context) ([]uuid.UUID, error) { var ids []uuid.UUID - ctx = newQueryContext(ctx, TypeMixinID, "IDs") + ctx = setContextOp(ctx, miq.ctx, "IDs") if err := miq.Select(mixinid.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -205,7 +202,7 @@ func (miq *MixinIDQuery) IDsX(ctx context.Context) []uuid.UUID { // Count returns the count of the given query. func (miq *MixinIDQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeMixinID, "Count") + ctx = setContextOp(ctx, miq.ctx, "Count") if err := miq.prepareQuery(ctx); err != nil { return 0, err } @@ -223,7 +220,7 @@ func (miq *MixinIDQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (miq *MixinIDQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeMixinID, "Exist") + ctx = setContextOp(ctx, miq.ctx, "Exist") switch _, err := miq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -251,15 +248,13 @@ func (miq *MixinIDQuery) Clone() *MixinIDQuery { } return &MixinIDQuery{ config: miq.config, - limit: miq.limit, - offset: miq.offset, + ctx: miq.ctx.Clone(), order: append([]OrderFunc{}, miq.order...), inters: append([]Interceptor{}, miq.inters...), predicates: append([]predicate.MixinID{}, miq.predicates...), // clone intermediate query. - sql: miq.sql.Clone(), - path: miq.path, - unique: miq.unique, + sql: miq.sql.Clone(), + path: miq.path, } } @@ -278,9 +273,9 @@ func (miq *MixinIDQuery) Clone() *MixinIDQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (miq *MixinIDQuery) GroupBy(field string, fields ...string) *MixinIDGroupBy { - miq.fields = append([]string{field}, fields...) + miq.ctx.Fields = append([]string{field}, fields...) grbuild := &MixinIDGroupBy{build: miq} - grbuild.flds = &miq.fields + grbuild.flds = &miq.ctx.Fields grbuild.label = mixinid.Label grbuild.scan = grbuild.Scan return grbuild @@ -299,10 +294,10 @@ func (miq *MixinIDQuery) GroupBy(field string, fields ...string) *MixinIDGroupBy // Select(mixinid.FieldSomeField). // Scan(ctx, &v) func (miq *MixinIDQuery) Select(fields ...string) *MixinIDSelect { - miq.fields = append(miq.fields, fields...) + miq.ctx.Fields = append(miq.ctx.Fields, fields...) sbuild := &MixinIDSelect{MixinIDQuery: miq} sbuild.label = mixinid.Label - sbuild.flds, sbuild.scan = &miq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &miq.ctx.Fields, sbuild.Scan return sbuild } @@ -322,7 +317,7 @@ func (miq *MixinIDQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range miq.fields { + for _, f := range miq.ctx.Fields { if !mixinid.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -364,9 +359,9 @@ func (miq *MixinIDQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Mix func (miq *MixinIDQuery) sqlCount(ctx context.Context) (int, error) { _spec := miq.querySpec() - _spec.Node.Columns = miq.fields - if len(miq.fields) > 0 { - _spec.Unique = miq.unique != nil && *miq.unique + _spec.Node.Columns = miq.ctx.Fields + if len(miq.ctx.Fields) > 0 { + _spec.Unique = miq.ctx.Unique != nil && *miq.ctx.Unique } return sqlgraph.CountNodes(ctx, miq.driver, _spec) } @@ -384,10 +379,10 @@ func (miq *MixinIDQuery) querySpec() *sqlgraph.QuerySpec { From: miq.sql, Unique: true, } - if unique := miq.unique; unique != nil { + if unique := miq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := miq.fields; len(fields) > 0 { + if fields := miq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, mixinid.FieldID) for i := range fields { @@ -403,10 +398,10 @@ func (miq *MixinIDQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := miq.limit; limit != nil { + if limit := miq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := miq.offset; offset != nil { + if offset := miq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := miq.order; len(ps) > 0 { @@ -422,7 +417,7 @@ func (miq *MixinIDQuery) querySpec() *sqlgraph.QuerySpec { func (miq *MixinIDQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(miq.driver.Dialect()) t1 := builder.Table(mixinid.Table) - columns := miq.fields + columns := miq.ctx.Fields if len(columns) == 0 { columns = mixinid.Columns } @@ -431,7 +426,7 @@ func (miq *MixinIDQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = miq.sql selector.Select(selector.Columns(columns...)...) } - if miq.unique != nil && *miq.unique { + if miq.ctx.Unique != nil && *miq.ctx.Unique { selector.Distinct() } for _, p := range miq.predicates { @@ -440,12 +435,12 @@ func (miq *MixinIDQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range miq.order { p(selector) } - if offset := miq.offset; offset != nil { + if offset := miq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := miq.limit; limit != nil { + if limit := miq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -465,7 +460,7 @@ func (migb *MixinIDGroupBy) Aggregate(fns ...AggregateFunc) *MixinIDGroupBy { // Scan applies the selector query and scans the result into the given value. func (migb *MixinIDGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeMixinID, "GroupBy") + ctx = setContextOp(ctx, migb.build.ctx, "GroupBy") if err := migb.build.prepareQuery(ctx); err != nil { return err } @@ -513,7 +508,7 @@ func (mis *MixinIDSelect) Aggregate(fns ...AggregateFunc) *MixinIDSelect { // Scan applies the selector query and scans the result into the given value. func (mis *MixinIDSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeMixinID, "Select") + ctx = setContextOp(ctx, mis.ctx, "Select") if err := mis.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/note_query.go b/entc/integration/customid/ent/note_query.go index fac501793..8fc519dbe 100644 --- a/entc/integration/customid/ent/note_query.go +++ b/entc/integration/customid/ent/note_query.go @@ -23,11 +23,8 @@ import ( // NoteQuery is the builder for querying Note entities. type NoteQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Note withParent *NoteQuery @@ -46,20 +43,20 @@ func (nq *NoteQuery) Where(ps ...predicate.Note) *NoteQuery { // Limit the number of records to be returned by this query. func (nq *NoteQuery) Limit(limit int) *NoteQuery { - nq.limit = &limit + nq.ctx.Limit = &limit return nq } // Offset to start from. func (nq *NoteQuery) Offset(offset int) *NoteQuery { - nq.offset = &offset + nq.ctx.Offset = &offset return nq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (nq *NoteQuery) Unique(unique bool) *NoteQuery { - nq.unique = &unique + nq.ctx.Unique = &unique return nq } @@ -116,7 +113,7 @@ func (nq *NoteQuery) QueryChildren() *NoteQuery { // First returns the first Note entity from the query. // Returns a *NotFoundError when no Note was found. func (nq *NoteQuery) First(ctx context.Context) (*Note, error) { - nodes, err := nq.Limit(1).All(newQueryContext(ctx, TypeNote, "First")) + nodes, err := nq.Limit(1).All(setContextOp(ctx, nq.ctx, "First")) if err != nil { return nil, err } @@ -139,7 +136,7 @@ func (nq *NoteQuery) FirstX(ctx context.Context) *Note { // Returns a *NotFoundError when no Note ID was found. func (nq *NoteQuery) FirstID(ctx context.Context) (id schema.NoteID, err error) { var ids []schema.NoteID - if ids, err = nq.Limit(1).IDs(newQueryContext(ctx, TypeNote, "FirstID")); err != nil { + if ids, err = nq.Limit(1).IDs(setContextOp(ctx, nq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -162,7 +159,7 @@ func (nq *NoteQuery) FirstIDX(ctx context.Context) schema.NoteID { // Returns a *NotSingularError when more than one Note entity is found. // Returns a *NotFoundError when no Note entities are found. func (nq *NoteQuery) Only(ctx context.Context) (*Note, error) { - nodes, err := nq.Limit(2).All(newQueryContext(ctx, TypeNote, "Only")) + nodes, err := nq.Limit(2).All(setContextOp(ctx, nq.ctx, "Only")) if err != nil { return nil, err } @@ -190,7 +187,7 @@ func (nq *NoteQuery) OnlyX(ctx context.Context) *Note { // Returns a *NotFoundError when no entities are found. func (nq *NoteQuery) OnlyID(ctx context.Context) (id schema.NoteID, err error) { var ids []schema.NoteID - if ids, err = nq.Limit(2).IDs(newQueryContext(ctx, TypeNote, "OnlyID")); err != nil { + if ids, err = nq.Limit(2).IDs(setContextOp(ctx, nq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -215,7 +212,7 @@ func (nq *NoteQuery) OnlyIDX(ctx context.Context) schema.NoteID { // All executes the query and returns a list of Notes. func (nq *NoteQuery) All(ctx context.Context) ([]*Note, error) { - ctx = newQueryContext(ctx, TypeNote, "All") + ctx = setContextOp(ctx, nq.ctx, "All") if err := nq.prepareQuery(ctx); err != nil { return nil, err } @@ -235,7 +232,7 @@ func (nq *NoteQuery) AllX(ctx context.Context) []*Note { // IDs executes the query and returns a list of Note IDs. func (nq *NoteQuery) IDs(ctx context.Context) ([]schema.NoteID, error) { var ids []schema.NoteID - ctx = newQueryContext(ctx, TypeNote, "IDs") + ctx = setContextOp(ctx, nq.ctx, "IDs") if err := nq.Select(note.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -253,7 +250,7 @@ func (nq *NoteQuery) IDsX(ctx context.Context) []schema.NoteID { // Count returns the count of the given query. func (nq *NoteQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeNote, "Count") + ctx = setContextOp(ctx, nq.ctx, "Count") if err := nq.prepareQuery(ctx); err != nil { return 0, err } @@ -271,7 +268,7 @@ func (nq *NoteQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (nq *NoteQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeNote, "Exist") + ctx = setContextOp(ctx, nq.ctx, "Exist") switch _, err := nq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -299,17 +296,15 @@ func (nq *NoteQuery) Clone() *NoteQuery { } return &NoteQuery{ config: nq.config, - limit: nq.limit, - offset: nq.offset, + ctx: nq.ctx.Clone(), order: append([]OrderFunc{}, nq.order...), inters: append([]Interceptor{}, nq.inters...), predicates: append([]predicate.Note{}, nq.predicates...), withParent: nq.withParent.Clone(), withChildren: nq.withChildren.Clone(), // clone intermediate query. - sql: nq.sql.Clone(), - path: nq.path, - unique: nq.unique, + sql: nq.sql.Clone(), + path: nq.path, } } @@ -350,9 +345,9 @@ func (nq *NoteQuery) WithChildren(opts ...func(*NoteQuery)) *NoteQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (nq *NoteQuery) GroupBy(field string, fields ...string) *NoteGroupBy { - nq.fields = append([]string{field}, fields...) + nq.ctx.Fields = append([]string{field}, fields...) grbuild := &NoteGroupBy{build: nq} - grbuild.flds = &nq.fields + grbuild.flds = &nq.ctx.Fields grbuild.label = note.Label grbuild.scan = grbuild.Scan return grbuild @@ -371,10 +366,10 @@ func (nq *NoteQuery) GroupBy(field string, fields ...string) *NoteGroupBy { // Select(note.FieldText). // Scan(ctx, &v) func (nq *NoteQuery) Select(fields ...string) *NoteSelect { - nq.fields = append(nq.fields, fields...) + nq.ctx.Fields = append(nq.ctx.Fields, fields...) sbuild := &NoteSelect{NoteQuery: nq} sbuild.label = note.Label - sbuild.flds, sbuild.scan = &nq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &nq.ctx.Fields, sbuild.Scan return sbuild } @@ -394,7 +389,7 @@ func (nq *NoteQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range nq.fields { + for _, f := range nq.ctx.Fields { if !note.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -525,9 +520,9 @@ func (nq *NoteQuery) loadChildren(ctx context.Context, query *NoteQuery, nodes [ func (nq *NoteQuery) sqlCount(ctx context.Context) (int, error) { _spec := nq.querySpec() - _spec.Node.Columns = nq.fields - if len(nq.fields) > 0 { - _spec.Unique = nq.unique != nil && *nq.unique + _spec.Node.Columns = nq.ctx.Fields + if len(nq.ctx.Fields) > 0 { + _spec.Unique = nq.ctx.Unique != nil && *nq.ctx.Unique } return sqlgraph.CountNodes(ctx, nq.driver, _spec) } @@ -545,10 +540,10 @@ func (nq *NoteQuery) querySpec() *sqlgraph.QuerySpec { From: nq.sql, Unique: true, } - if unique := nq.unique; unique != nil { + if unique := nq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := nq.fields; len(fields) > 0 { + if fields := nq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, note.FieldID) for i := range fields { @@ -564,10 +559,10 @@ func (nq *NoteQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := nq.limit; limit != nil { + if limit := nq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := nq.offset; offset != nil { + if offset := nq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := nq.order; len(ps) > 0 { @@ -583,7 +578,7 @@ func (nq *NoteQuery) querySpec() *sqlgraph.QuerySpec { func (nq *NoteQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(nq.driver.Dialect()) t1 := builder.Table(note.Table) - columns := nq.fields + columns := nq.ctx.Fields if len(columns) == 0 { columns = note.Columns } @@ -592,7 +587,7 @@ func (nq *NoteQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = nq.sql selector.Select(selector.Columns(columns...)...) } - if nq.unique != nil && *nq.unique { + if nq.ctx.Unique != nil && *nq.ctx.Unique { selector.Distinct() } for _, p := range nq.predicates { @@ -601,12 +596,12 @@ func (nq *NoteQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range nq.order { p(selector) } - if offset := nq.offset; offset != nil { + if offset := nq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := nq.limit; limit != nil { + if limit := nq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -626,7 +621,7 @@ func (ngb *NoteGroupBy) Aggregate(fns ...AggregateFunc) *NoteGroupBy { // Scan applies the selector query and scans the result into the given value. func (ngb *NoteGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeNote, "GroupBy") + ctx = setContextOp(ctx, ngb.build.ctx, "GroupBy") if err := ngb.build.prepareQuery(ctx); err != nil { return err } @@ -674,7 +669,7 @@ func (ns *NoteSelect) Aggregate(fns ...AggregateFunc) *NoteSelect { // Scan applies the selector query and scans the result into the given value. func (ns *NoteSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeNote, "Select") + ctx = setContextOp(ctx, ns.ctx, "Select") if err := ns.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/other_query.go b/entc/integration/customid/ent/other_query.go index 7f54292cf..83f3b7938 100644 --- a/entc/integration/customid/ent/other_query.go +++ b/entc/integration/customid/ent/other_query.go @@ -22,11 +22,8 @@ import ( // OtherQuery is the builder for querying Other entities. type OtherQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Other // intermediate query (i.e. traversal path). @@ -42,20 +39,20 @@ func (oq *OtherQuery) Where(ps ...predicate.Other) *OtherQuery { // Limit the number of records to be returned by this query. func (oq *OtherQuery) Limit(limit int) *OtherQuery { - oq.limit = &limit + oq.ctx.Limit = &limit return oq } // Offset to start from. func (oq *OtherQuery) Offset(offset int) *OtherQuery { - oq.offset = &offset + oq.ctx.Offset = &offset return oq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (oq *OtherQuery) Unique(unique bool) *OtherQuery { - oq.unique = &unique + oq.ctx.Unique = &unique return oq } @@ -68,7 +65,7 @@ func (oq *OtherQuery) Order(o ...OrderFunc) *OtherQuery { // First returns the first Other entity from the query. // Returns a *NotFoundError when no Other was found. func (oq *OtherQuery) First(ctx context.Context) (*Other, error) { - nodes, err := oq.Limit(1).All(newQueryContext(ctx, TypeOther, "First")) + nodes, err := oq.Limit(1).All(setContextOp(ctx, oq.ctx, "First")) if err != nil { return nil, err } @@ -91,7 +88,7 @@ func (oq *OtherQuery) FirstX(ctx context.Context) *Other { // Returns a *NotFoundError when no Other ID was found. func (oq *OtherQuery) FirstID(ctx context.Context) (id sid.ID, err error) { var ids []sid.ID - if ids, err = oq.Limit(1).IDs(newQueryContext(ctx, TypeOther, "FirstID")); err != nil { + if ids, err = oq.Limit(1).IDs(setContextOp(ctx, oq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -114,7 +111,7 @@ func (oq *OtherQuery) FirstIDX(ctx context.Context) sid.ID { // Returns a *NotSingularError when more than one Other entity is found. // Returns a *NotFoundError when no Other entities are found. func (oq *OtherQuery) Only(ctx context.Context) (*Other, error) { - nodes, err := oq.Limit(2).All(newQueryContext(ctx, TypeOther, "Only")) + nodes, err := oq.Limit(2).All(setContextOp(ctx, oq.ctx, "Only")) if err != nil { return nil, err } @@ -142,7 +139,7 @@ func (oq *OtherQuery) OnlyX(ctx context.Context) *Other { // Returns a *NotFoundError when no entities are found. func (oq *OtherQuery) OnlyID(ctx context.Context) (id sid.ID, err error) { var ids []sid.ID - if ids, err = oq.Limit(2).IDs(newQueryContext(ctx, TypeOther, "OnlyID")); err != nil { + if ids, err = oq.Limit(2).IDs(setContextOp(ctx, oq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -167,7 +164,7 @@ func (oq *OtherQuery) OnlyIDX(ctx context.Context) sid.ID { // All executes the query and returns a list of Others. func (oq *OtherQuery) All(ctx context.Context) ([]*Other, error) { - ctx = newQueryContext(ctx, TypeOther, "All") + ctx = setContextOp(ctx, oq.ctx, "All") if err := oq.prepareQuery(ctx); err != nil { return nil, err } @@ -187,7 +184,7 @@ func (oq *OtherQuery) AllX(ctx context.Context) []*Other { // IDs executes the query and returns a list of Other IDs. func (oq *OtherQuery) IDs(ctx context.Context) ([]sid.ID, error) { var ids []sid.ID - ctx = newQueryContext(ctx, TypeOther, "IDs") + ctx = setContextOp(ctx, oq.ctx, "IDs") if err := oq.Select(other.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -205,7 +202,7 @@ func (oq *OtherQuery) IDsX(ctx context.Context) []sid.ID { // Count returns the count of the given query. func (oq *OtherQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeOther, "Count") + ctx = setContextOp(ctx, oq.ctx, "Count") if err := oq.prepareQuery(ctx); err != nil { return 0, err } @@ -223,7 +220,7 @@ func (oq *OtherQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (oq *OtherQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeOther, "Exist") + ctx = setContextOp(ctx, oq.ctx, "Exist") switch _, err := oq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -251,24 +248,22 @@ func (oq *OtherQuery) Clone() *OtherQuery { } return &OtherQuery{ config: oq.config, - limit: oq.limit, - offset: oq.offset, + ctx: oq.ctx.Clone(), order: append([]OrderFunc{}, oq.order...), inters: append([]Interceptor{}, oq.inters...), predicates: append([]predicate.Other{}, oq.predicates...), // clone intermediate query. - sql: oq.sql.Clone(), - path: oq.path, - unique: oq.unique, + sql: oq.sql.Clone(), + path: oq.path, } } // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (oq *OtherQuery) GroupBy(field string, fields ...string) *OtherGroupBy { - oq.fields = append([]string{field}, fields...) + oq.ctx.Fields = append([]string{field}, fields...) grbuild := &OtherGroupBy{build: oq} - grbuild.flds = &oq.fields + grbuild.flds = &oq.ctx.Fields grbuild.label = other.Label grbuild.scan = grbuild.Scan return grbuild @@ -277,10 +272,10 @@ func (oq *OtherQuery) GroupBy(field string, fields ...string) *OtherGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (oq *OtherQuery) Select(fields ...string) *OtherSelect { - oq.fields = append(oq.fields, fields...) + oq.ctx.Fields = append(oq.ctx.Fields, fields...) sbuild := &OtherSelect{OtherQuery: oq} sbuild.label = other.Label - sbuild.flds, sbuild.scan = &oq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &oq.ctx.Fields, sbuild.Scan return sbuild } @@ -300,7 +295,7 @@ func (oq *OtherQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range oq.fields { + for _, f := range oq.ctx.Fields { if !other.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -342,9 +337,9 @@ func (oq *OtherQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Other, func (oq *OtherQuery) sqlCount(ctx context.Context) (int, error) { _spec := oq.querySpec() - _spec.Node.Columns = oq.fields - if len(oq.fields) > 0 { - _spec.Unique = oq.unique != nil && *oq.unique + _spec.Node.Columns = oq.ctx.Fields + if len(oq.ctx.Fields) > 0 { + _spec.Unique = oq.ctx.Unique != nil && *oq.ctx.Unique } return sqlgraph.CountNodes(ctx, oq.driver, _spec) } @@ -362,10 +357,10 @@ func (oq *OtherQuery) querySpec() *sqlgraph.QuerySpec { From: oq.sql, Unique: true, } - if unique := oq.unique; unique != nil { + if unique := oq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := oq.fields; len(fields) > 0 { + if fields := oq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, other.FieldID) for i := range fields { @@ -381,10 +376,10 @@ func (oq *OtherQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := oq.limit; limit != nil { + if limit := oq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := oq.offset; offset != nil { + if offset := oq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := oq.order; len(ps) > 0 { @@ -400,7 +395,7 @@ func (oq *OtherQuery) querySpec() *sqlgraph.QuerySpec { func (oq *OtherQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(oq.driver.Dialect()) t1 := builder.Table(other.Table) - columns := oq.fields + columns := oq.ctx.Fields if len(columns) == 0 { columns = other.Columns } @@ -409,7 +404,7 @@ func (oq *OtherQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = oq.sql selector.Select(selector.Columns(columns...)...) } - if oq.unique != nil && *oq.unique { + if oq.ctx.Unique != nil && *oq.ctx.Unique { selector.Distinct() } for _, p := range oq.predicates { @@ -418,12 +413,12 @@ func (oq *OtherQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range oq.order { p(selector) } - if offset := oq.offset; offset != nil { + if offset := oq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := oq.limit; limit != nil { + if limit := oq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -443,7 +438,7 @@ func (ogb *OtherGroupBy) Aggregate(fns ...AggregateFunc) *OtherGroupBy { // Scan applies the selector query and scans the result into the given value. func (ogb *OtherGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeOther, "GroupBy") + ctx = setContextOp(ctx, ogb.build.ctx, "GroupBy") if err := ogb.build.prepareQuery(ctx); err != nil { return err } @@ -491,7 +486,7 @@ func (os *OtherSelect) Aggregate(fns ...AggregateFunc) *OtherSelect { // Scan applies the selector query and scans the result into the given value. func (os *OtherSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeOther, "Select") + ctx = setContextOp(ctx, os.ctx, "Select") if err := os.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/pet_query.go b/entc/integration/customid/ent/pet_query.go index bdf98df71..f21c3189c 100644 --- a/entc/integration/customid/ent/pet_query.go +++ b/entc/integration/customid/ent/pet_query.go @@ -24,11 +24,8 @@ import ( // PetQuery is the builder for querying Pet entities. type PetQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Pet withOwner *UserQuery @@ -49,20 +46,20 @@ func (pq *PetQuery) Where(ps ...predicate.Pet) *PetQuery { // Limit the number of records to be returned by this query. func (pq *PetQuery) Limit(limit int) *PetQuery { - pq.limit = &limit + pq.ctx.Limit = &limit return pq } // Offset to start from. func (pq *PetQuery) Offset(offset int) *PetQuery { - pq.offset = &offset + pq.ctx.Offset = &offset return pq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (pq *PetQuery) Unique(unique bool) *PetQuery { - pq.unique = &unique + pq.ctx.Unique = &unique return pq } @@ -163,7 +160,7 @@ func (pq *PetQuery) QueryBestFriend() *PetQuery { // First returns the first Pet entity from the query. // Returns a *NotFoundError when no Pet was found. func (pq *PetQuery) First(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(1).All(newQueryContext(ctx, TypePet, "First")) + nodes, err := pq.Limit(1).All(setContextOp(ctx, pq.ctx, "First")) if err != nil { return nil, err } @@ -186,7 +183,7 @@ func (pq *PetQuery) FirstX(ctx context.Context) *Pet { // Returns a *NotFoundError when no Pet ID was found. func (pq *PetQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = pq.Limit(1).IDs(newQueryContext(ctx, TypePet, "FirstID")); err != nil { + if ids, err = pq.Limit(1).IDs(setContextOp(ctx, pq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -209,7 +206,7 @@ func (pq *PetQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Pet entity is found. // Returns a *NotFoundError when no Pet entities are found. func (pq *PetQuery) Only(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(2).All(newQueryContext(ctx, TypePet, "Only")) + nodes, err := pq.Limit(2).All(setContextOp(ctx, pq.ctx, "Only")) if err != nil { return nil, err } @@ -237,7 +234,7 @@ func (pq *PetQuery) OnlyX(ctx context.Context) *Pet { // Returns a *NotFoundError when no entities are found. func (pq *PetQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = pq.Limit(2).IDs(newQueryContext(ctx, TypePet, "OnlyID")); err != nil { + if ids, err = pq.Limit(2).IDs(setContextOp(ctx, pq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -262,7 +259,7 @@ func (pq *PetQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Pets. func (pq *PetQuery) All(ctx context.Context) ([]*Pet, error) { - ctx = newQueryContext(ctx, TypePet, "All") + ctx = setContextOp(ctx, pq.ctx, "All") if err := pq.prepareQuery(ctx); err != nil { return nil, err } @@ -282,7 +279,7 @@ func (pq *PetQuery) AllX(ctx context.Context) []*Pet { // IDs executes the query and returns a list of Pet IDs. func (pq *PetQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypePet, "IDs") + ctx = setContextOp(ctx, pq.ctx, "IDs") if err := pq.Select(pet.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -300,7 +297,7 @@ func (pq *PetQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (pq *PetQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypePet, "Count") + ctx = setContextOp(ctx, pq.ctx, "Count") if err := pq.prepareQuery(ctx); err != nil { return 0, err } @@ -318,7 +315,7 @@ func (pq *PetQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (pq *PetQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypePet, "Exist") + ctx = setContextOp(ctx, pq.ctx, "Exist") switch _, err := pq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -346,8 +343,7 @@ func (pq *PetQuery) Clone() *PetQuery { } return &PetQuery{ config: pq.config, - limit: pq.limit, - offset: pq.offset, + ctx: pq.ctx.Clone(), order: append([]OrderFunc{}, pq.order...), inters: append([]Interceptor{}, pq.inters...), predicates: append([]predicate.Pet{}, pq.predicates...), @@ -356,9 +352,8 @@ func (pq *PetQuery) Clone() *PetQuery { withFriends: pq.withFriends.Clone(), withBestFriend: pq.withBestFriend.Clone(), // clone intermediate query. - sql: pq.sql.Clone(), - path: pq.path, - unique: pq.unique, + sql: pq.sql.Clone(), + path: pq.path, } } @@ -409,9 +404,9 @@ func (pq *PetQuery) WithBestFriend(opts ...func(*PetQuery)) *PetQuery { // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { - pq.fields = append([]string{field}, fields...) + pq.ctx.Fields = append([]string{field}, fields...) grbuild := &PetGroupBy{build: pq} - grbuild.flds = &pq.fields + grbuild.flds = &pq.ctx.Fields grbuild.label = pet.Label grbuild.scan = grbuild.Scan return grbuild @@ -420,10 +415,10 @@ func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (pq *PetQuery) Select(fields ...string) *PetSelect { - pq.fields = append(pq.fields, fields...) + pq.ctx.Fields = append(pq.ctx.Fields, fields...) sbuild := &PetSelect{PetQuery: pq} sbuild.label = pet.Label - sbuild.flds, sbuild.scan = &pq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &pq.ctx.Fields, sbuild.Scan return sbuild } @@ -443,7 +438,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range pq.fields { + for _, f := range pq.ctx.Fields { if !pet.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -679,9 +674,9 @@ func (pq *PetQuery) loadBestFriend(ctx context.Context, query *PetQuery, nodes [ func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() - _spec.Node.Columns = pq.fields - if len(pq.fields) > 0 { - _spec.Unique = pq.unique != nil && *pq.unique + _spec.Node.Columns = pq.ctx.Fields + if len(pq.ctx.Fields) > 0 { + _spec.Unique = pq.ctx.Unique != nil && *pq.ctx.Unique } return sqlgraph.CountNodes(ctx, pq.driver, _spec) } @@ -699,10 +694,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { From: pq.sql, Unique: true, } - if unique := pq.unique; unique != nil { + if unique := pq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := pq.fields; len(fields) > 0 { + if fields := pq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, pet.FieldID) for i := range fields { @@ -718,10 +713,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := pq.order; len(ps) > 0 { @@ -737,7 +732,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(pq.driver.Dialect()) t1 := builder.Table(pet.Table) - columns := pq.fields + columns := pq.ctx.Fields if len(columns) == 0 { columns = pet.Columns } @@ -746,7 +741,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = pq.sql selector.Select(selector.Columns(columns...)...) } - if pq.unique != nil && *pq.unique { + if pq.ctx.Unique != nil && *pq.ctx.Unique { selector.Distinct() } for _, p := range pq.predicates { @@ -755,12 +750,12 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range pq.order { p(selector) } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -780,7 +775,7 @@ func (pgb *PetGroupBy) Aggregate(fns ...AggregateFunc) *PetGroupBy { // Scan applies the selector query and scans the result into the given value. func (pgb *PetGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "GroupBy") + ctx = setContextOp(ctx, pgb.build.ctx, "GroupBy") if err := pgb.build.prepareQuery(ctx); err != nil { return err } @@ -828,7 +823,7 @@ func (ps *PetSelect) Aggregate(fns ...AggregateFunc) *PetSelect { // Scan applies the selector query and scans the result into the given value. func (ps *PetSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "Select") + ctx = setContextOp(ctx, ps.ctx, "Select") if err := ps.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/revision_query.go b/entc/integration/customid/ent/revision_query.go index c54b5c9e3..5913796dd 100644 --- a/entc/integration/customid/ent/revision_query.go +++ b/entc/integration/customid/ent/revision_query.go @@ -21,11 +21,8 @@ import ( // RevisionQuery is the builder for querying Revision entities. type RevisionQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Revision // intermediate query (i.e. traversal path). @@ -41,20 +38,20 @@ func (rq *RevisionQuery) Where(ps ...predicate.Revision) *RevisionQuery { // Limit the number of records to be returned by this query. func (rq *RevisionQuery) Limit(limit int) *RevisionQuery { - rq.limit = &limit + rq.ctx.Limit = &limit return rq } // Offset to start from. func (rq *RevisionQuery) Offset(offset int) *RevisionQuery { - rq.offset = &offset + rq.ctx.Offset = &offset return rq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (rq *RevisionQuery) Unique(unique bool) *RevisionQuery { - rq.unique = &unique + rq.ctx.Unique = &unique return rq } @@ -67,7 +64,7 @@ func (rq *RevisionQuery) Order(o ...OrderFunc) *RevisionQuery { // First returns the first Revision entity from the query. // Returns a *NotFoundError when no Revision was found. func (rq *RevisionQuery) First(ctx context.Context) (*Revision, error) { - nodes, err := rq.Limit(1).All(newQueryContext(ctx, TypeRevision, "First")) + nodes, err := rq.Limit(1).All(setContextOp(ctx, rq.ctx, "First")) if err != nil { return nil, err } @@ -90,7 +87,7 @@ func (rq *RevisionQuery) FirstX(ctx context.Context) *Revision { // Returns a *NotFoundError when no Revision ID was found. func (rq *RevisionQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = rq.Limit(1).IDs(newQueryContext(ctx, TypeRevision, "FirstID")); err != nil { + if ids, err = rq.Limit(1).IDs(setContextOp(ctx, rq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -113,7 +110,7 @@ func (rq *RevisionQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Revision entity is found. // Returns a *NotFoundError when no Revision entities are found. func (rq *RevisionQuery) Only(ctx context.Context) (*Revision, error) { - nodes, err := rq.Limit(2).All(newQueryContext(ctx, TypeRevision, "Only")) + nodes, err := rq.Limit(2).All(setContextOp(ctx, rq.ctx, "Only")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (rq *RevisionQuery) OnlyX(ctx context.Context) *Revision { // Returns a *NotFoundError when no entities are found. func (rq *RevisionQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = rq.Limit(2).IDs(newQueryContext(ctx, TypeRevision, "OnlyID")); err != nil { + if ids, err = rq.Limit(2).IDs(setContextOp(ctx, rq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -166,7 +163,7 @@ func (rq *RevisionQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Revisions. func (rq *RevisionQuery) All(ctx context.Context) ([]*Revision, error) { - ctx = newQueryContext(ctx, TypeRevision, "All") + ctx = setContextOp(ctx, rq.ctx, "All") if err := rq.prepareQuery(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (rq *RevisionQuery) AllX(ctx context.Context) []*Revision { // IDs executes the query and returns a list of Revision IDs. func (rq *RevisionQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeRevision, "IDs") + ctx = setContextOp(ctx, rq.ctx, "IDs") if err := rq.Select(revision.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -204,7 +201,7 @@ func (rq *RevisionQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (rq *RevisionQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeRevision, "Count") + ctx = setContextOp(ctx, rq.ctx, "Count") if err := rq.prepareQuery(ctx); err != nil { return 0, err } @@ -222,7 +219,7 @@ func (rq *RevisionQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (rq *RevisionQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeRevision, "Exist") + ctx = setContextOp(ctx, rq.ctx, "Exist") switch _, err := rq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -250,24 +247,22 @@ func (rq *RevisionQuery) Clone() *RevisionQuery { } return &RevisionQuery{ config: rq.config, - limit: rq.limit, - offset: rq.offset, + ctx: rq.ctx.Clone(), order: append([]OrderFunc{}, rq.order...), inters: append([]Interceptor{}, rq.inters...), predicates: append([]predicate.Revision{}, rq.predicates...), // clone intermediate query. - sql: rq.sql.Clone(), - path: rq.path, - unique: rq.unique, + sql: rq.sql.Clone(), + path: rq.path, } } // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (rq *RevisionQuery) GroupBy(field string, fields ...string) *RevisionGroupBy { - rq.fields = append([]string{field}, fields...) + rq.ctx.Fields = append([]string{field}, fields...) grbuild := &RevisionGroupBy{build: rq} - grbuild.flds = &rq.fields + grbuild.flds = &rq.ctx.Fields grbuild.label = revision.Label grbuild.scan = grbuild.Scan return grbuild @@ -276,10 +271,10 @@ func (rq *RevisionQuery) GroupBy(field string, fields ...string) *RevisionGroupB // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (rq *RevisionQuery) Select(fields ...string) *RevisionSelect { - rq.fields = append(rq.fields, fields...) + rq.ctx.Fields = append(rq.ctx.Fields, fields...) sbuild := &RevisionSelect{RevisionQuery: rq} sbuild.label = revision.Label - sbuild.flds, sbuild.scan = &rq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &rq.ctx.Fields, sbuild.Scan return sbuild } @@ -299,7 +294,7 @@ func (rq *RevisionQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range rq.fields { + for _, f := range rq.ctx.Fields { if !revision.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -341,9 +336,9 @@ func (rq *RevisionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Rev func (rq *RevisionQuery) sqlCount(ctx context.Context) (int, error) { _spec := rq.querySpec() - _spec.Node.Columns = rq.fields - if len(rq.fields) > 0 { - _spec.Unique = rq.unique != nil && *rq.unique + _spec.Node.Columns = rq.ctx.Fields + if len(rq.ctx.Fields) > 0 { + _spec.Unique = rq.ctx.Unique != nil && *rq.ctx.Unique } return sqlgraph.CountNodes(ctx, rq.driver, _spec) } @@ -361,10 +356,10 @@ func (rq *RevisionQuery) querySpec() *sqlgraph.QuerySpec { From: rq.sql, Unique: true, } - if unique := rq.unique; unique != nil { + if unique := rq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := rq.fields; len(fields) > 0 { + if fields := rq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, revision.FieldID) for i := range fields { @@ -380,10 +375,10 @@ func (rq *RevisionQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := rq.limit; limit != nil { + if limit := rq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := rq.offset; offset != nil { + if offset := rq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := rq.order; len(ps) > 0 { @@ -399,7 +394,7 @@ func (rq *RevisionQuery) querySpec() *sqlgraph.QuerySpec { func (rq *RevisionQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(rq.driver.Dialect()) t1 := builder.Table(revision.Table) - columns := rq.fields + columns := rq.ctx.Fields if len(columns) == 0 { columns = revision.Columns } @@ -408,7 +403,7 @@ func (rq *RevisionQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = rq.sql selector.Select(selector.Columns(columns...)...) } - if rq.unique != nil && *rq.unique { + if rq.ctx.Unique != nil && *rq.ctx.Unique { selector.Distinct() } for _, p := range rq.predicates { @@ -417,12 +412,12 @@ func (rq *RevisionQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range rq.order { p(selector) } - if offset := rq.offset; offset != nil { + if offset := rq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := rq.limit; limit != nil { + if limit := rq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -442,7 +437,7 @@ func (rgb *RevisionGroupBy) Aggregate(fns ...AggregateFunc) *RevisionGroupBy { // Scan applies the selector query and scans the result into the given value. func (rgb *RevisionGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeRevision, "GroupBy") + ctx = setContextOp(ctx, rgb.build.ctx, "GroupBy") if err := rgb.build.prepareQuery(ctx); err != nil { return err } @@ -490,7 +485,7 @@ func (rs *RevisionSelect) Aggregate(fns ...AggregateFunc) *RevisionSelect { // Scan applies the selector query and scans the result into the given value. func (rs *RevisionSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeRevision, "Select") + ctx = setContextOp(ctx, rs.ctx, "Select") if err := rs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/session_query.go b/entc/integration/customid/ent/session_query.go index c03d305af..5ba2e0598 100644 --- a/entc/integration/customid/ent/session_query.go +++ b/entc/integration/customid/ent/session_query.go @@ -23,11 +23,8 @@ import ( // SessionQuery is the builder for querying Session entities. type SessionQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Session withDevice *DeviceQuery @@ -45,20 +42,20 @@ func (sq *SessionQuery) Where(ps ...predicate.Session) *SessionQuery { // Limit the number of records to be returned by this query. func (sq *SessionQuery) Limit(limit int) *SessionQuery { - sq.limit = &limit + sq.ctx.Limit = &limit return sq } // Offset to start from. func (sq *SessionQuery) Offset(offset int) *SessionQuery { - sq.offset = &offset + sq.ctx.Offset = &offset return sq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (sq *SessionQuery) Unique(unique bool) *SessionQuery { - sq.unique = &unique + sq.ctx.Unique = &unique return sq } @@ -93,7 +90,7 @@ func (sq *SessionQuery) QueryDevice() *DeviceQuery { // First returns the first Session entity from the query. // Returns a *NotFoundError when no Session was found. func (sq *SessionQuery) First(ctx context.Context) (*Session, error) { - nodes, err := sq.Limit(1).All(newQueryContext(ctx, TypeSession, "First")) + nodes, err := sq.Limit(1).All(setContextOp(ctx, sq.ctx, "First")) if err != nil { return nil, err } @@ -116,7 +113,7 @@ func (sq *SessionQuery) FirstX(ctx context.Context) *Session { // Returns a *NotFoundError when no Session ID was found. func (sq *SessionQuery) FirstID(ctx context.Context) (id schema.ID, err error) { var ids []schema.ID - if ids, err = sq.Limit(1).IDs(newQueryContext(ctx, TypeSession, "FirstID")); err != nil { + if ids, err = sq.Limit(1).IDs(setContextOp(ctx, sq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -139,7 +136,7 @@ func (sq *SessionQuery) FirstIDX(ctx context.Context) schema.ID { // Returns a *NotSingularError when more than one Session entity is found. // Returns a *NotFoundError when no Session entities are found. func (sq *SessionQuery) Only(ctx context.Context) (*Session, error) { - nodes, err := sq.Limit(2).All(newQueryContext(ctx, TypeSession, "Only")) + nodes, err := sq.Limit(2).All(setContextOp(ctx, sq.ctx, "Only")) if err != nil { return nil, err } @@ -167,7 +164,7 @@ func (sq *SessionQuery) OnlyX(ctx context.Context) *Session { // Returns a *NotFoundError when no entities are found. func (sq *SessionQuery) OnlyID(ctx context.Context) (id schema.ID, err error) { var ids []schema.ID - if ids, err = sq.Limit(2).IDs(newQueryContext(ctx, TypeSession, "OnlyID")); err != nil { + if ids, err = sq.Limit(2).IDs(setContextOp(ctx, sq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -192,7 +189,7 @@ func (sq *SessionQuery) OnlyIDX(ctx context.Context) schema.ID { // All executes the query and returns a list of Sessions. func (sq *SessionQuery) All(ctx context.Context) ([]*Session, error) { - ctx = newQueryContext(ctx, TypeSession, "All") + ctx = setContextOp(ctx, sq.ctx, "All") if err := sq.prepareQuery(ctx); err != nil { return nil, err } @@ -212,7 +209,7 @@ func (sq *SessionQuery) AllX(ctx context.Context) []*Session { // IDs executes the query and returns a list of Session IDs. func (sq *SessionQuery) IDs(ctx context.Context) ([]schema.ID, error) { var ids []schema.ID - ctx = newQueryContext(ctx, TypeSession, "IDs") + ctx = setContextOp(ctx, sq.ctx, "IDs") if err := sq.Select(session.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -230,7 +227,7 @@ func (sq *SessionQuery) IDsX(ctx context.Context) []schema.ID { // Count returns the count of the given query. func (sq *SessionQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeSession, "Count") + ctx = setContextOp(ctx, sq.ctx, "Count") if err := sq.prepareQuery(ctx); err != nil { return 0, err } @@ -248,7 +245,7 @@ func (sq *SessionQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (sq *SessionQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeSession, "Exist") + ctx = setContextOp(ctx, sq.ctx, "Exist") switch _, err := sq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -276,16 +273,14 @@ func (sq *SessionQuery) Clone() *SessionQuery { } return &SessionQuery{ config: sq.config, - limit: sq.limit, - offset: sq.offset, + ctx: sq.ctx.Clone(), order: append([]OrderFunc{}, sq.order...), inters: append([]Interceptor{}, sq.inters...), predicates: append([]predicate.Session{}, sq.predicates...), withDevice: sq.withDevice.Clone(), // clone intermediate query. - sql: sq.sql.Clone(), - path: sq.path, - unique: sq.unique, + sql: sq.sql.Clone(), + path: sq.path, } } @@ -303,9 +298,9 @@ func (sq *SessionQuery) WithDevice(opts ...func(*DeviceQuery)) *SessionQuery { // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (sq *SessionQuery) GroupBy(field string, fields ...string) *SessionGroupBy { - sq.fields = append([]string{field}, fields...) + sq.ctx.Fields = append([]string{field}, fields...) grbuild := &SessionGroupBy{build: sq} - grbuild.flds = &sq.fields + grbuild.flds = &sq.ctx.Fields grbuild.label = session.Label grbuild.scan = grbuild.Scan return grbuild @@ -314,10 +309,10 @@ func (sq *SessionQuery) GroupBy(field string, fields ...string) *SessionGroupBy // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (sq *SessionQuery) Select(fields ...string) *SessionSelect { - sq.fields = append(sq.fields, fields...) + sq.ctx.Fields = append(sq.ctx.Fields, fields...) sbuild := &SessionSelect{SessionQuery: sq} sbuild.label = session.Label - sbuild.flds, sbuild.scan = &sq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &sq.ctx.Fields, sbuild.Scan return sbuild } @@ -337,7 +332,7 @@ func (sq *SessionQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range sq.fields { + for _, f := range sq.ctx.Fields { if !session.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -429,9 +424,9 @@ func (sq *SessionQuery) loadDevice(ctx context.Context, query *DeviceQuery, node func (sq *SessionQuery) sqlCount(ctx context.Context) (int, error) { _spec := sq.querySpec() - _spec.Node.Columns = sq.fields - if len(sq.fields) > 0 { - _spec.Unique = sq.unique != nil && *sq.unique + _spec.Node.Columns = sq.ctx.Fields + if len(sq.ctx.Fields) > 0 { + _spec.Unique = sq.ctx.Unique != nil && *sq.ctx.Unique } return sqlgraph.CountNodes(ctx, sq.driver, _spec) } @@ -449,10 +444,10 @@ func (sq *SessionQuery) querySpec() *sqlgraph.QuerySpec { From: sq.sql, Unique: true, } - if unique := sq.unique; unique != nil { + if unique := sq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := sq.fields; len(fields) > 0 { + if fields := sq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, session.FieldID) for i := range fields { @@ -468,10 +463,10 @@ func (sq *SessionQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := sq.limit; limit != nil { + if limit := sq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := sq.offset; offset != nil { + if offset := sq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := sq.order; len(ps) > 0 { @@ -487,7 +482,7 @@ func (sq *SessionQuery) querySpec() *sqlgraph.QuerySpec { func (sq *SessionQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(sq.driver.Dialect()) t1 := builder.Table(session.Table) - columns := sq.fields + columns := sq.ctx.Fields if len(columns) == 0 { columns = session.Columns } @@ -496,7 +491,7 @@ func (sq *SessionQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = sq.sql selector.Select(selector.Columns(columns...)...) } - if sq.unique != nil && *sq.unique { + if sq.ctx.Unique != nil && *sq.ctx.Unique { selector.Distinct() } for _, p := range sq.predicates { @@ -505,12 +500,12 @@ func (sq *SessionQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range sq.order { p(selector) } - if offset := sq.offset; offset != nil { + if offset := sq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := sq.limit; limit != nil { + if limit := sq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -530,7 +525,7 @@ func (sgb *SessionGroupBy) Aggregate(fns ...AggregateFunc) *SessionGroupBy { // Scan applies the selector query and scans the result into the given value. func (sgb *SessionGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeSession, "GroupBy") + ctx = setContextOp(ctx, sgb.build.ctx, "GroupBy") if err := sgb.build.prepareQuery(ctx); err != nil { return err } @@ -578,7 +573,7 @@ func (ss *SessionSelect) Aggregate(fns ...AggregateFunc) *SessionSelect { // Scan applies the selector query and scans the result into the given value. func (ss *SessionSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeSession, "Select") + ctx = setContextOp(ctx, ss.ctx, "Select") if err := ss.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/token_query.go b/entc/integration/customid/ent/token_query.go index d49578dc0..ce005e39e 100644 --- a/entc/integration/customid/ent/token_query.go +++ b/entc/integration/customid/ent/token_query.go @@ -23,11 +23,8 @@ import ( // TokenQuery is the builder for querying Token entities. type TokenQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Token withAccount *AccountQuery @@ -45,20 +42,20 @@ func (tq *TokenQuery) Where(ps ...predicate.Token) *TokenQuery { // Limit the number of records to be returned by this query. func (tq *TokenQuery) Limit(limit int) *TokenQuery { - tq.limit = &limit + tq.ctx.Limit = &limit return tq } // Offset to start from. func (tq *TokenQuery) Offset(offset int) *TokenQuery { - tq.offset = &offset + tq.ctx.Offset = &offset return tq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (tq *TokenQuery) Unique(unique bool) *TokenQuery { - tq.unique = &unique + tq.ctx.Unique = &unique return tq } @@ -93,7 +90,7 @@ func (tq *TokenQuery) QueryAccount() *AccountQuery { // First returns the first Token entity from the query. // Returns a *NotFoundError when no Token was found. func (tq *TokenQuery) First(ctx context.Context) (*Token, error) { - nodes, err := tq.Limit(1).All(newQueryContext(ctx, TypeToken, "First")) + nodes, err := tq.Limit(1).All(setContextOp(ctx, tq.ctx, "First")) if err != nil { return nil, err } @@ -116,7 +113,7 @@ func (tq *TokenQuery) FirstX(ctx context.Context) *Token { // Returns a *NotFoundError when no Token ID was found. func (tq *TokenQuery) FirstID(ctx context.Context) (id sid.ID, err error) { var ids []sid.ID - if ids, err = tq.Limit(1).IDs(newQueryContext(ctx, TypeToken, "FirstID")); err != nil { + if ids, err = tq.Limit(1).IDs(setContextOp(ctx, tq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -139,7 +136,7 @@ func (tq *TokenQuery) FirstIDX(ctx context.Context) sid.ID { // Returns a *NotSingularError when more than one Token entity is found. // Returns a *NotFoundError when no Token entities are found. func (tq *TokenQuery) Only(ctx context.Context) (*Token, error) { - nodes, err := tq.Limit(2).All(newQueryContext(ctx, TypeToken, "Only")) + nodes, err := tq.Limit(2).All(setContextOp(ctx, tq.ctx, "Only")) if err != nil { return nil, err } @@ -167,7 +164,7 @@ func (tq *TokenQuery) OnlyX(ctx context.Context) *Token { // Returns a *NotFoundError when no entities are found. func (tq *TokenQuery) OnlyID(ctx context.Context) (id sid.ID, err error) { var ids []sid.ID - if ids, err = tq.Limit(2).IDs(newQueryContext(ctx, TypeToken, "OnlyID")); err != nil { + if ids, err = tq.Limit(2).IDs(setContextOp(ctx, tq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -192,7 +189,7 @@ func (tq *TokenQuery) OnlyIDX(ctx context.Context) sid.ID { // All executes the query and returns a list of Tokens. func (tq *TokenQuery) All(ctx context.Context) ([]*Token, error) { - ctx = newQueryContext(ctx, TypeToken, "All") + ctx = setContextOp(ctx, tq.ctx, "All") if err := tq.prepareQuery(ctx); err != nil { return nil, err } @@ -212,7 +209,7 @@ func (tq *TokenQuery) AllX(ctx context.Context) []*Token { // IDs executes the query and returns a list of Token IDs. func (tq *TokenQuery) IDs(ctx context.Context) ([]sid.ID, error) { var ids []sid.ID - ctx = newQueryContext(ctx, TypeToken, "IDs") + ctx = setContextOp(ctx, tq.ctx, "IDs") if err := tq.Select(token.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -230,7 +227,7 @@ func (tq *TokenQuery) IDsX(ctx context.Context) []sid.ID { // Count returns the count of the given query. func (tq *TokenQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeToken, "Count") + ctx = setContextOp(ctx, tq.ctx, "Count") if err := tq.prepareQuery(ctx); err != nil { return 0, err } @@ -248,7 +245,7 @@ func (tq *TokenQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (tq *TokenQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeToken, "Exist") + ctx = setContextOp(ctx, tq.ctx, "Exist") switch _, err := tq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -276,16 +273,14 @@ func (tq *TokenQuery) Clone() *TokenQuery { } return &TokenQuery{ config: tq.config, - limit: tq.limit, - offset: tq.offset, + ctx: tq.ctx.Clone(), order: append([]OrderFunc{}, tq.order...), inters: append([]Interceptor{}, tq.inters...), predicates: append([]predicate.Token{}, tq.predicates...), withAccount: tq.withAccount.Clone(), // clone intermediate query. - sql: tq.sql.Clone(), - path: tq.path, - unique: tq.unique, + sql: tq.sql.Clone(), + path: tq.path, } } @@ -315,9 +310,9 @@ func (tq *TokenQuery) WithAccount(opts ...func(*AccountQuery)) *TokenQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (tq *TokenQuery) GroupBy(field string, fields ...string) *TokenGroupBy { - tq.fields = append([]string{field}, fields...) + tq.ctx.Fields = append([]string{field}, fields...) grbuild := &TokenGroupBy{build: tq} - grbuild.flds = &tq.fields + grbuild.flds = &tq.ctx.Fields grbuild.label = token.Label grbuild.scan = grbuild.Scan return grbuild @@ -336,10 +331,10 @@ func (tq *TokenQuery) GroupBy(field string, fields ...string) *TokenGroupBy { // Select(token.FieldBody). // Scan(ctx, &v) func (tq *TokenQuery) Select(fields ...string) *TokenSelect { - tq.fields = append(tq.fields, fields...) + tq.ctx.Fields = append(tq.ctx.Fields, fields...) sbuild := &TokenSelect{TokenQuery: tq} sbuild.label = token.Label - sbuild.flds, sbuild.scan = &tq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &tq.ctx.Fields, sbuild.Scan return sbuild } @@ -359,7 +354,7 @@ func (tq *TokenQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range tq.fields { + for _, f := range tq.ctx.Fields { if !token.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -451,9 +446,9 @@ func (tq *TokenQuery) loadAccount(ctx context.Context, query *AccountQuery, node func (tq *TokenQuery) sqlCount(ctx context.Context) (int, error) { _spec := tq.querySpec() - _spec.Node.Columns = tq.fields - if len(tq.fields) > 0 { - _spec.Unique = tq.unique != nil && *tq.unique + _spec.Node.Columns = tq.ctx.Fields + if len(tq.ctx.Fields) > 0 { + _spec.Unique = tq.ctx.Unique != nil && *tq.ctx.Unique } return sqlgraph.CountNodes(ctx, tq.driver, _spec) } @@ -471,10 +466,10 @@ func (tq *TokenQuery) querySpec() *sqlgraph.QuerySpec { From: tq.sql, Unique: true, } - if unique := tq.unique; unique != nil { + if unique := tq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := tq.fields; len(fields) > 0 { + if fields := tq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, token.FieldID) for i := range fields { @@ -490,10 +485,10 @@ func (tq *TokenQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := tq.order; len(ps) > 0 { @@ -509,7 +504,7 @@ func (tq *TokenQuery) querySpec() *sqlgraph.QuerySpec { func (tq *TokenQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(tq.driver.Dialect()) t1 := builder.Table(token.Table) - columns := tq.fields + columns := tq.ctx.Fields if len(columns) == 0 { columns = token.Columns } @@ -518,7 +513,7 @@ func (tq *TokenQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = tq.sql selector.Select(selector.Columns(columns...)...) } - if tq.unique != nil && *tq.unique { + if tq.ctx.Unique != nil && *tq.ctx.Unique { selector.Distinct() } for _, p := range tq.predicates { @@ -527,12 +522,12 @@ func (tq *TokenQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range tq.order { p(selector) } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -552,7 +547,7 @@ func (tgb *TokenGroupBy) Aggregate(fns ...AggregateFunc) *TokenGroupBy { // Scan applies the selector query and scans the result into the given value. func (tgb *TokenGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeToken, "GroupBy") + ctx = setContextOp(ctx, tgb.build.ctx, "GroupBy") if err := tgb.build.prepareQuery(ctx); err != nil { return err } @@ -600,7 +595,7 @@ func (ts *TokenSelect) Aggregate(fns ...AggregateFunc) *TokenSelect { // Scan applies the selector query and scans the result into the given value. func (ts *TokenSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeToken, "Select") + ctx = setContextOp(ctx, ts.ctx, "Select") if err := ts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/customid/ent/user_query.go b/entc/integration/customid/ent/user_query.go index a805a4577..eae5c896f 100644 --- a/entc/integration/customid/ent/user_query.go +++ b/entc/integration/customid/ent/user_query.go @@ -24,11 +24,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withGroups *GroupQuery @@ -49,20 +46,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -163,7 +160,7 @@ func (uq *UserQuery) QueryPets() *PetQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -186,7 +183,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -209,7 +206,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -237,7 +234,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -262,7 +259,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -282,7 +279,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -300,7 +297,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -318,7 +315,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -346,8 +343,7 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), @@ -356,9 +352,8 @@ func (uq *UserQuery) Clone() *UserQuery { withChildren: uq.withChildren.Clone(), withPets: uq.withPets.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -409,9 +404,9 @@ func (uq *UserQuery) WithPets(opts ...func(*PetQuery)) *UserQuery { // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -420,10 +415,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -443,7 +438,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -679,9 +674,9 @@ func (uq *UserQuery) loadPets(ctx context.Context, query *PetQuery, nodes []*Use func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -699,10 +694,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -718,10 +713,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -737,7 +732,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -746,7 +741,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -755,12 +750,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -780,7 +775,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -828,7 +823,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgefield/ent/car_query.go b/entc/integration/edgefield/ent/car_query.go index 1067cb3cf..b97c85876 100644 --- a/entc/integration/edgefield/ent/car_query.go +++ b/entc/integration/edgefield/ent/car_query.go @@ -24,11 +24,8 @@ import ( // CarQuery is the builder for querying Car entities. type CarQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Car withRentals *RentalQuery @@ -45,20 +42,20 @@ func (cq *CarQuery) Where(ps ...predicate.Car) *CarQuery { // Limit the number of records to be returned by this query. func (cq *CarQuery) Limit(limit int) *CarQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *CarQuery) Offset(offset int) *CarQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *CarQuery) Unique(unique bool) *CarQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -93,7 +90,7 @@ func (cq *CarQuery) QueryRentals() *RentalQuery { // First returns the first Car entity from the query. // Returns a *NotFoundError when no Car was found. func (cq *CarQuery) First(ctx context.Context) (*Car, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeCar, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -116,7 +113,7 @@ func (cq *CarQuery) FirstX(ctx context.Context) *Car { // Returns a *NotFoundError when no Car ID was found. func (cq *CarQuery) FirstID(ctx context.Context) (id uuid.UUID, err error) { var ids []uuid.UUID - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeCar, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -139,7 +136,7 @@ func (cq *CarQuery) FirstIDX(ctx context.Context) uuid.UUID { // Returns a *NotSingularError when more than one Car entity is found. // Returns a *NotFoundError when no Car entities are found. func (cq *CarQuery) Only(ctx context.Context) (*Car, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeCar, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -167,7 +164,7 @@ func (cq *CarQuery) OnlyX(ctx context.Context) *Car { // Returns a *NotFoundError when no entities are found. func (cq *CarQuery) OnlyID(ctx context.Context) (id uuid.UUID, err error) { var ids []uuid.UUID - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeCar, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -192,7 +189,7 @@ func (cq *CarQuery) OnlyIDX(ctx context.Context) uuid.UUID { // All executes the query and returns a list of Cars. func (cq *CarQuery) All(ctx context.Context) ([]*Car, error) { - ctx = newQueryContext(ctx, TypeCar, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -212,7 +209,7 @@ func (cq *CarQuery) AllX(ctx context.Context) []*Car { // IDs executes the query and returns a list of Car IDs. func (cq *CarQuery) IDs(ctx context.Context) ([]uuid.UUID, error) { var ids []uuid.UUID - ctx = newQueryContext(ctx, TypeCar, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(car.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -230,7 +227,7 @@ func (cq *CarQuery) IDsX(ctx context.Context) []uuid.UUID { // Count returns the count of the given query. func (cq *CarQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeCar, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -248,7 +245,7 @@ func (cq *CarQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *CarQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeCar, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -276,16 +273,14 @@ func (cq *CarQuery) Clone() *CarQuery { } return &CarQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Car{}, cq.predicates...), withRentals: cq.withRentals.Clone(), // clone intermediate query. - sql: cq.sql.Clone(), - path: cq.path, - unique: cq.unique, + sql: cq.sql.Clone(), + path: cq.path, } } @@ -315,9 +310,9 @@ func (cq *CarQuery) WithRentals(opts ...func(*RentalQuery)) *CarQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (cq *CarQuery) GroupBy(field string, fields ...string) *CarGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &CarGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = car.Label grbuild.scan = grbuild.Scan return grbuild @@ -336,10 +331,10 @@ func (cq *CarQuery) GroupBy(field string, fields ...string) *CarGroupBy { // Select(car.FieldNumber). // Scan(ctx, &v) func (cq *CarQuery) Select(fields ...string) *CarSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &CarSelect{CarQuery: cq} sbuild.label = car.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -359,7 +354,7 @@ func (cq *CarQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range cq.fields { + for _, f := range cq.ctx.Fields { if !car.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -440,9 +435,9 @@ func (cq *CarQuery) loadRentals(ctx context.Context, query *RentalQuery, nodes [ func (cq *CarQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() - _spec.Node.Columns = cq.fields - if len(cq.fields) > 0 { - _spec.Unique = cq.unique != nil && *cq.unique + _spec.Node.Columns = cq.ctx.Fields + if len(cq.ctx.Fields) > 0 { + _spec.Unique = cq.ctx.Unique != nil && *cq.ctx.Unique } return sqlgraph.CountNodes(ctx, cq.driver, _spec) } @@ -460,10 +455,10 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { From: cq.sql, Unique: true, } - if unique := cq.unique; unique != nil { + if unique := cq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := cq.fields; len(fields) > 0 { + if fields := cq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, car.FieldID) for i := range fields { @@ -479,10 +474,10 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := cq.order; len(ps) > 0 { @@ -498,7 +493,7 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(car.Table) - columns := cq.fields + columns := cq.ctx.Fields if len(columns) == 0 { columns = car.Columns } @@ -507,7 +502,7 @@ func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = cq.sql selector.Select(selector.Columns(columns...)...) } - if cq.unique != nil && *cq.unique { + if cq.ctx.Unique != nil && *cq.ctx.Unique { selector.Distinct() } for _, p := range cq.predicates { @@ -516,12 +511,12 @@ func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range cq.order { p(selector) } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -541,7 +536,7 @@ func (cgb *CarGroupBy) Aggregate(fns ...AggregateFunc) *CarGroupBy { // Scan applies the selector query and scans the result into the given value. func (cgb *CarGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCar, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -589,7 +584,7 @@ func (cs *CarSelect) Aggregate(fns ...AggregateFunc) *CarSelect { // Scan applies the selector query and scans the result into the given value. func (cs *CarSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCar, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgefield/ent/card_query.go b/entc/integration/edgefield/ent/card_query.go index d5cfaa153..63241bc66 100644 --- a/entc/integration/edgefield/ent/card_query.go +++ b/entc/integration/edgefield/ent/card_query.go @@ -22,11 +22,8 @@ import ( // CardQuery is the builder for querying Card entities. type CardQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Card withOwner *UserQuery @@ -43,20 +40,20 @@ func (cq *CardQuery) Where(ps ...predicate.Card) *CardQuery { // Limit the number of records to be returned by this query. func (cq *CardQuery) Limit(limit int) *CardQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *CardQuery) Offset(offset int) *CardQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *CardQuery) Unique(unique bool) *CardQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -91,7 +88,7 @@ func (cq *CardQuery) QueryOwner() *UserQuery { // First returns the first Card entity from the query. // Returns a *NotFoundError when no Card was found. func (cq *CardQuery) First(ctx context.Context) (*Card, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeCard, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -114,7 +111,7 @@ func (cq *CardQuery) FirstX(ctx context.Context) *Card { // Returns a *NotFoundError when no Card ID was found. func (cq *CardQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeCard, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -137,7 +134,7 @@ func (cq *CardQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Card entity is found. // Returns a *NotFoundError when no Card entities are found. func (cq *CardQuery) Only(ctx context.Context) (*Card, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeCard, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -165,7 +162,7 @@ func (cq *CardQuery) OnlyX(ctx context.Context) *Card { // Returns a *NotFoundError when no entities are found. func (cq *CardQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeCard, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -190,7 +187,7 @@ func (cq *CardQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Cards. func (cq *CardQuery) All(ctx context.Context) ([]*Card, error) { - ctx = newQueryContext(ctx, TypeCard, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -210,7 +207,7 @@ func (cq *CardQuery) AllX(ctx context.Context) []*Card { // IDs executes the query and returns a list of Card IDs. func (cq *CardQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeCard, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(card.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -228,7 +225,7 @@ func (cq *CardQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (cq *CardQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeCard, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -246,7 +243,7 @@ func (cq *CardQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *CardQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeCard, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -274,16 +271,14 @@ func (cq *CardQuery) Clone() *CardQuery { } return &CardQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Card{}, cq.predicates...), withOwner: cq.withOwner.Clone(), // clone intermediate query. - sql: cq.sql.Clone(), - path: cq.path, - unique: cq.unique, + sql: cq.sql.Clone(), + path: cq.path, } } @@ -313,9 +308,9 @@ func (cq *CardQuery) WithOwner(opts ...func(*UserQuery)) *CardQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (cq *CardQuery) GroupBy(field string, fields ...string) *CardGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &CardGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = card.Label grbuild.scan = grbuild.Scan return grbuild @@ -334,10 +329,10 @@ func (cq *CardQuery) GroupBy(field string, fields ...string) *CardGroupBy { // Select(card.FieldNumber). // Scan(ctx, &v) func (cq *CardQuery) Select(fields ...string) *CardSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &CardSelect{CardQuery: cq} sbuild.label = card.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -357,7 +352,7 @@ func (cq *CardQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range cq.fields { + for _, f := range cq.ctx.Fields { if !card.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -439,9 +434,9 @@ func (cq *CardQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*C func (cq *CardQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() - _spec.Node.Columns = cq.fields - if len(cq.fields) > 0 { - _spec.Unique = cq.unique != nil && *cq.unique + _spec.Node.Columns = cq.ctx.Fields + if len(cq.ctx.Fields) > 0 { + _spec.Unique = cq.ctx.Unique != nil && *cq.ctx.Unique } return sqlgraph.CountNodes(ctx, cq.driver, _spec) } @@ -459,10 +454,10 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { From: cq.sql, Unique: true, } - if unique := cq.unique; unique != nil { + if unique := cq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := cq.fields; len(fields) > 0 { + if fields := cq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, card.FieldID) for i := range fields { @@ -478,10 +473,10 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := cq.order; len(ps) > 0 { @@ -497,7 +492,7 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(card.Table) - columns := cq.fields + columns := cq.ctx.Fields if len(columns) == 0 { columns = card.Columns } @@ -506,7 +501,7 @@ func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = cq.sql selector.Select(selector.Columns(columns...)...) } - if cq.unique != nil && *cq.unique { + if cq.ctx.Unique != nil && *cq.ctx.Unique { selector.Distinct() } for _, p := range cq.predicates { @@ -515,12 +510,12 @@ func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range cq.order { p(selector) } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -540,7 +535,7 @@ func (cgb *CardGroupBy) Aggregate(fns ...AggregateFunc) *CardGroupBy { // Scan applies the selector query and scans the result into the given value. func (cgb *CardGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCard, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -588,7 +583,7 @@ func (cs *CardSelect) Aggregate(fns ...AggregateFunc) *CardSelect { // Scan applies the selector query and scans the result into the given value. func (cs *CardSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCard, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgefield/ent/client.go b/entc/integration/edgefield/ent/client.go index 8fca9d113..81d0ed7e9 100644 --- a/entc/integration/edgefield/ent/client.go +++ b/entc/integration/edgefield/ent/client.go @@ -298,6 +298,7 @@ func (c *CarClient) DeleteOneID(id uuid.UUID) *CarDeleteOne { func (c *CarClient) Query() *CarQuery { return &CarQuery{ config: c.config, + ctx: &QueryContext{Type: TypeCar}, inters: c.Interceptors(), } } @@ -431,6 +432,7 @@ func (c *CardClient) DeleteOneID(id int) *CardDeleteOne { func (c *CardClient) Query() *CardQuery { return &CardQuery{ config: c.config, + ctx: &QueryContext{Type: TypeCard}, inters: c.Interceptors(), } } @@ -564,6 +566,7 @@ func (c *InfoClient) DeleteOneID(id int) *InfoDeleteOne { func (c *InfoClient) Query() *InfoQuery { return &InfoQuery{ config: c.config, + ctx: &QueryContext{Type: TypeInfo}, inters: c.Interceptors(), } } @@ -697,6 +700,7 @@ func (c *MetadataClient) DeleteOneID(id int) *MetadataDeleteOne { func (c *MetadataClient) Query() *MetadataQuery { return &MetadataQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMetadata}, inters: c.Interceptors(), } } @@ -862,6 +866,7 @@ func (c *NodeClient) DeleteOneID(id int) *NodeDeleteOne { func (c *NodeClient) Query() *NodeQuery { return &NodeQuery{ config: c.config, + ctx: &QueryContext{Type: TypeNode}, inters: c.Interceptors(), } } @@ -1011,6 +1016,7 @@ func (c *PetClient) DeleteOneID(id int) *PetDeleteOne { func (c *PetClient) Query() *PetQuery { return &PetQuery{ config: c.config, + ctx: &QueryContext{Type: TypePet}, inters: c.Interceptors(), } } @@ -1144,6 +1150,7 @@ func (c *PostClient) DeleteOneID(id int) *PostDeleteOne { func (c *PostClient) Query() *PostQuery { return &PostQuery{ config: c.config, + ctx: &QueryContext{Type: TypePost}, inters: c.Interceptors(), } } @@ -1277,6 +1284,7 @@ func (c *RentalClient) DeleteOneID(id int) *RentalDeleteOne { func (c *RentalClient) Query() *RentalQuery { return &RentalQuery{ config: c.config, + ctx: &QueryContext{Type: TypeRental}, inters: c.Interceptors(), } } @@ -1426,6 +1434,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/edgefield/ent/ent.go b/entc/integration/edgefield/ent/ent.go index 2dd9faee9..239f06f48 100644 --- a/entc/integration/edgefield/ent/ent.go +++ b/entc/integration/edgefield/ent/ent.go @@ -32,6 +32,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -523,10 +524,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/edgefield/ent/info_query.go b/entc/integration/edgefield/ent/info_query.go index 618e75c88..9c727b7b8 100644 --- a/entc/integration/edgefield/ent/info_query.go +++ b/entc/integration/edgefield/ent/info_query.go @@ -22,11 +22,8 @@ import ( // InfoQuery is the builder for querying Info entities. type InfoQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Info withUser *UserQuery @@ -43,20 +40,20 @@ func (iq *InfoQuery) Where(ps ...predicate.Info) *InfoQuery { // Limit the number of records to be returned by this query. func (iq *InfoQuery) Limit(limit int) *InfoQuery { - iq.limit = &limit + iq.ctx.Limit = &limit return iq } // Offset to start from. func (iq *InfoQuery) Offset(offset int) *InfoQuery { - iq.offset = &offset + iq.ctx.Offset = &offset return iq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (iq *InfoQuery) Unique(unique bool) *InfoQuery { - iq.unique = &unique + iq.ctx.Unique = &unique return iq } @@ -91,7 +88,7 @@ func (iq *InfoQuery) QueryUser() *UserQuery { // First returns the first Info entity from the query. // Returns a *NotFoundError when no Info was found. func (iq *InfoQuery) First(ctx context.Context) (*Info, error) { - nodes, err := iq.Limit(1).All(newQueryContext(ctx, TypeInfo, "First")) + nodes, err := iq.Limit(1).All(setContextOp(ctx, iq.ctx, "First")) if err != nil { return nil, err } @@ -114,7 +111,7 @@ func (iq *InfoQuery) FirstX(ctx context.Context) *Info { // Returns a *NotFoundError when no Info ID was found. func (iq *InfoQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = iq.Limit(1).IDs(newQueryContext(ctx, TypeInfo, "FirstID")); err != nil { + if ids, err = iq.Limit(1).IDs(setContextOp(ctx, iq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -137,7 +134,7 @@ func (iq *InfoQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Info entity is found. // Returns a *NotFoundError when no Info entities are found. func (iq *InfoQuery) Only(ctx context.Context) (*Info, error) { - nodes, err := iq.Limit(2).All(newQueryContext(ctx, TypeInfo, "Only")) + nodes, err := iq.Limit(2).All(setContextOp(ctx, iq.ctx, "Only")) if err != nil { return nil, err } @@ -165,7 +162,7 @@ func (iq *InfoQuery) OnlyX(ctx context.Context) *Info { // Returns a *NotFoundError when no entities are found. func (iq *InfoQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = iq.Limit(2).IDs(newQueryContext(ctx, TypeInfo, "OnlyID")); err != nil { + if ids, err = iq.Limit(2).IDs(setContextOp(ctx, iq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -190,7 +187,7 @@ func (iq *InfoQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Infos. func (iq *InfoQuery) All(ctx context.Context) ([]*Info, error) { - ctx = newQueryContext(ctx, TypeInfo, "All") + ctx = setContextOp(ctx, iq.ctx, "All") if err := iq.prepareQuery(ctx); err != nil { return nil, err } @@ -210,7 +207,7 @@ func (iq *InfoQuery) AllX(ctx context.Context) []*Info { // IDs executes the query and returns a list of Info IDs. func (iq *InfoQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeInfo, "IDs") + ctx = setContextOp(ctx, iq.ctx, "IDs") if err := iq.Select(info.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -228,7 +225,7 @@ func (iq *InfoQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (iq *InfoQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeInfo, "Count") + ctx = setContextOp(ctx, iq.ctx, "Count") if err := iq.prepareQuery(ctx); err != nil { return 0, err } @@ -246,7 +243,7 @@ func (iq *InfoQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (iq *InfoQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeInfo, "Exist") + ctx = setContextOp(ctx, iq.ctx, "Exist") switch _, err := iq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -274,16 +271,14 @@ func (iq *InfoQuery) Clone() *InfoQuery { } return &InfoQuery{ config: iq.config, - limit: iq.limit, - offset: iq.offset, + ctx: iq.ctx.Clone(), order: append([]OrderFunc{}, iq.order...), inters: append([]Interceptor{}, iq.inters...), predicates: append([]predicate.Info{}, iq.predicates...), withUser: iq.withUser.Clone(), // clone intermediate query. - sql: iq.sql.Clone(), - path: iq.path, - unique: iq.unique, + sql: iq.sql.Clone(), + path: iq.path, } } @@ -313,9 +308,9 @@ func (iq *InfoQuery) WithUser(opts ...func(*UserQuery)) *InfoQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (iq *InfoQuery) GroupBy(field string, fields ...string) *InfoGroupBy { - iq.fields = append([]string{field}, fields...) + iq.ctx.Fields = append([]string{field}, fields...) grbuild := &InfoGroupBy{build: iq} - grbuild.flds = &iq.fields + grbuild.flds = &iq.ctx.Fields grbuild.label = info.Label grbuild.scan = grbuild.Scan return grbuild @@ -334,10 +329,10 @@ func (iq *InfoQuery) GroupBy(field string, fields ...string) *InfoGroupBy { // Select(info.FieldContent). // Scan(ctx, &v) func (iq *InfoQuery) Select(fields ...string) *InfoSelect { - iq.fields = append(iq.fields, fields...) + iq.ctx.Fields = append(iq.ctx.Fields, fields...) sbuild := &InfoSelect{InfoQuery: iq} sbuild.label = info.Label - sbuild.flds, sbuild.scan = &iq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &iq.ctx.Fields, sbuild.Scan return sbuild } @@ -357,7 +352,7 @@ func (iq *InfoQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range iq.fields { + for _, f := range iq.ctx.Fields { if !info.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -439,9 +434,9 @@ func (iq *InfoQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*In func (iq *InfoQuery) sqlCount(ctx context.Context) (int, error) { _spec := iq.querySpec() - _spec.Node.Columns = iq.fields - if len(iq.fields) > 0 { - _spec.Unique = iq.unique != nil && *iq.unique + _spec.Node.Columns = iq.ctx.Fields + if len(iq.ctx.Fields) > 0 { + _spec.Unique = iq.ctx.Unique != nil && *iq.ctx.Unique } return sqlgraph.CountNodes(ctx, iq.driver, _spec) } @@ -459,10 +454,10 @@ func (iq *InfoQuery) querySpec() *sqlgraph.QuerySpec { From: iq.sql, Unique: true, } - if unique := iq.unique; unique != nil { + if unique := iq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := iq.fields; len(fields) > 0 { + if fields := iq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, info.FieldID) for i := range fields { @@ -478,10 +473,10 @@ func (iq *InfoQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := iq.limit; limit != nil { + if limit := iq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := iq.offset; offset != nil { + if offset := iq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := iq.order; len(ps) > 0 { @@ -497,7 +492,7 @@ func (iq *InfoQuery) querySpec() *sqlgraph.QuerySpec { func (iq *InfoQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(iq.driver.Dialect()) t1 := builder.Table(info.Table) - columns := iq.fields + columns := iq.ctx.Fields if len(columns) == 0 { columns = info.Columns } @@ -506,7 +501,7 @@ func (iq *InfoQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = iq.sql selector.Select(selector.Columns(columns...)...) } - if iq.unique != nil && *iq.unique { + if iq.ctx.Unique != nil && *iq.ctx.Unique { selector.Distinct() } for _, p := range iq.predicates { @@ -515,12 +510,12 @@ func (iq *InfoQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range iq.order { p(selector) } - if offset := iq.offset; offset != nil { + if offset := iq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := iq.limit; limit != nil { + if limit := iq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -540,7 +535,7 @@ func (igb *InfoGroupBy) Aggregate(fns ...AggregateFunc) *InfoGroupBy { // Scan applies the selector query and scans the result into the given value. func (igb *InfoGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeInfo, "GroupBy") + ctx = setContextOp(ctx, igb.build.ctx, "GroupBy") if err := igb.build.prepareQuery(ctx); err != nil { return err } @@ -588,7 +583,7 @@ func (is *InfoSelect) Aggregate(fns ...AggregateFunc) *InfoSelect { // Scan applies the selector query and scans the result into the given value. func (is *InfoSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeInfo, "Select") + ctx = setContextOp(ctx, is.ctx, "Select") if err := is.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgefield/ent/metadata_query.go b/entc/integration/edgefield/ent/metadata_query.go index 238492e7f..f622660e7 100644 --- a/entc/integration/edgefield/ent/metadata_query.go +++ b/entc/integration/edgefield/ent/metadata_query.go @@ -23,11 +23,8 @@ import ( // MetadataQuery is the builder for querying Metadata entities. type MetadataQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Metadata withUser *UserQuery @@ -46,20 +43,20 @@ func (mq *MetadataQuery) Where(ps ...predicate.Metadata) *MetadataQuery { // Limit the number of records to be returned by this query. func (mq *MetadataQuery) Limit(limit int) *MetadataQuery { - mq.limit = &limit + mq.ctx.Limit = &limit return mq } // Offset to start from. func (mq *MetadataQuery) Offset(offset int) *MetadataQuery { - mq.offset = &offset + mq.ctx.Offset = &offset return mq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mq *MetadataQuery) Unique(unique bool) *MetadataQuery { - mq.unique = &unique + mq.ctx.Unique = &unique return mq } @@ -138,7 +135,7 @@ func (mq *MetadataQuery) QueryParent() *MetadataQuery { // First returns the first Metadata entity from the query. // Returns a *NotFoundError when no Metadata was found. func (mq *MetadataQuery) First(ctx context.Context) (*Metadata, error) { - nodes, err := mq.Limit(1).All(newQueryContext(ctx, TypeMetadata, "First")) + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) if err != nil { return nil, err } @@ -161,7 +158,7 @@ func (mq *MetadataQuery) FirstX(ctx context.Context) *Metadata { // Returns a *NotFoundError when no Metadata ID was found. func (mq *MetadataQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(1).IDs(newQueryContext(ctx, TypeMetadata, "FirstID")); err != nil { + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -184,7 +181,7 @@ func (mq *MetadataQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Metadata entity is found. // Returns a *NotFoundError when no Metadata entities are found. func (mq *MetadataQuery) Only(ctx context.Context) (*Metadata, error) { - nodes, err := mq.Limit(2).All(newQueryContext(ctx, TypeMetadata, "Only")) + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) if err != nil { return nil, err } @@ -212,7 +209,7 @@ func (mq *MetadataQuery) OnlyX(ctx context.Context) *Metadata { // Returns a *NotFoundError when no entities are found. func (mq *MetadataQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(2).IDs(newQueryContext(ctx, TypeMetadata, "OnlyID")); err != nil { + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -237,7 +234,7 @@ func (mq *MetadataQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MetadataSlice. func (mq *MetadataQuery) All(ctx context.Context) ([]*Metadata, error) { - ctx = newQueryContext(ctx, TypeMetadata, "All") + ctx = setContextOp(ctx, mq.ctx, "All") if err := mq.prepareQuery(ctx); err != nil { return nil, err } @@ -257,7 +254,7 @@ func (mq *MetadataQuery) AllX(ctx context.Context) []*Metadata { // IDs executes the query and returns a list of Metadata IDs. func (mq *MetadataQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeMetadata, "IDs") + ctx = setContextOp(ctx, mq.ctx, "IDs") if err := mq.Select(metadata.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -275,7 +272,7 @@ func (mq *MetadataQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mq *MetadataQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeMetadata, "Count") + ctx = setContextOp(ctx, mq.ctx, "Count") if err := mq.prepareQuery(ctx); err != nil { return 0, err } @@ -293,7 +290,7 @@ func (mq *MetadataQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mq *MetadataQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeMetadata, "Exist") + ctx = setContextOp(ctx, mq.ctx, "Exist") switch _, err := mq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -321,8 +318,7 @@ func (mq *MetadataQuery) Clone() *MetadataQuery { } return &MetadataQuery{ config: mq.config, - limit: mq.limit, - offset: mq.offset, + ctx: mq.ctx.Clone(), order: append([]OrderFunc{}, mq.order...), inters: append([]Interceptor{}, mq.inters...), predicates: append([]predicate.Metadata{}, mq.predicates...), @@ -330,9 +326,8 @@ func (mq *MetadataQuery) Clone() *MetadataQuery { withChildren: mq.withChildren.Clone(), withParent: mq.withParent.Clone(), // clone intermediate query. - sql: mq.sql.Clone(), - path: mq.path, - unique: mq.unique, + sql: mq.sql.Clone(), + path: mq.path, } } @@ -384,9 +379,9 @@ func (mq *MetadataQuery) WithParent(opts ...func(*MetadataQuery)) *MetadataQuery // Aggregate(ent.Count()). // Scan(ctx, &v) func (mq *MetadataQuery) GroupBy(field string, fields ...string) *MetadataGroupBy { - mq.fields = append([]string{field}, fields...) + mq.ctx.Fields = append([]string{field}, fields...) grbuild := &MetadataGroupBy{build: mq} - grbuild.flds = &mq.fields + grbuild.flds = &mq.ctx.Fields grbuild.label = metadata.Label grbuild.scan = grbuild.Scan return grbuild @@ -405,10 +400,10 @@ func (mq *MetadataQuery) GroupBy(field string, fields ...string) *MetadataGroupB // Select(metadata.FieldAge). // Scan(ctx, &v) func (mq *MetadataQuery) Select(fields ...string) *MetadataSelect { - mq.fields = append(mq.fields, fields...) + mq.ctx.Fields = append(mq.ctx.Fields, fields...) sbuild := &MetadataSelect{MetadataQuery: mq} sbuild.label = metadata.Label - sbuild.flds, sbuild.scan = &mq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan return sbuild } @@ -428,7 +423,7 @@ func (mq *MetadataQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range mq.fields { + for _, f := range mq.ctx.Fields { if !metadata.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -581,9 +576,9 @@ func (mq *MetadataQuery) loadParent(ctx context.Context, query *MetadataQuery, n func (mq *MetadataQuery) sqlCount(ctx context.Context) (int, error) { _spec := mq.querySpec() - _spec.Node.Columns = mq.fields - if len(mq.fields) > 0 { - _spec.Unique = mq.unique != nil && *mq.unique + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique } return sqlgraph.CountNodes(ctx, mq.driver, _spec) } @@ -601,10 +596,10 @@ func (mq *MetadataQuery) querySpec() *sqlgraph.QuerySpec { From: mq.sql, Unique: true, } - if unique := mq.unique; unique != nil { + if unique := mq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := mq.fields; len(fields) > 0 { + if fields := mq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, metadata.FieldID) for i := range fields { @@ -620,10 +615,10 @@ func (mq *MetadataQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mq.order; len(ps) > 0 { @@ -639,7 +634,7 @@ func (mq *MetadataQuery) querySpec() *sqlgraph.QuerySpec { func (mq *MetadataQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mq.driver.Dialect()) t1 := builder.Table(metadata.Table) - columns := mq.fields + columns := mq.ctx.Fields if len(columns) == 0 { columns = metadata.Columns } @@ -648,7 +643,7 @@ func (mq *MetadataQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mq.sql selector.Select(selector.Columns(columns...)...) } - if mq.unique != nil && *mq.unique { + if mq.ctx.Unique != nil && *mq.ctx.Unique { selector.Distinct() } for _, p := range mq.predicates { @@ -657,12 +652,12 @@ func (mq *MetadataQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mq.order { p(selector) } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -682,7 +677,7 @@ func (mgb *MetadataGroupBy) Aggregate(fns ...AggregateFunc) *MetadataGroupBy { // Scan applies the selector query and scans the result into the given value. func (mgb *MetadataGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeMetadata, "GroupBy") + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") if err := mgb.build.prepareQuery(ctx); err != nil { return err } @@ -730,7 +725,7 @@ func (ms *MetadataSelect) Aggregate(fns ...AggregateFunc) *MetadataSelect { // Scan applies the selector query and scans the result into the given value. func (ms *MetadataSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeMetadata, "Select") + ctx = setContextOp(ctx, ms.ctx, "Select") if err := ms.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgefield/ent/node_query.go b/entc/integration/edgefield/ent/node_query.go index fec365a5a..08778edf0 100644 --- a/entc/integration/edgefield/ent/node_query.go +++ b/entc/integration/edgefield/ent/node_query.go @@ -22,11 +22,8 @@ import ( // NodeQuery is the builder for querying Node entities. type NodeQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Node withPrev *NodeQuery @@ -44,20 +41,20 @@ func (nq *NodeQuery) Where(ps ...predicate.Node) *NodeQuery { // Limit the number of records to be returned by this query. func (nq *NodeQuery) Limit(limit int) *NodeQuery { - nq.limit = &limit + nq.ctx.Limit = &limit return nq } // Offset to start from. func (nq *NodeQuery) Offset(offset int) *NodeQuery { - nq.offset = &offset + nq.ctx.Offset = &offset return nq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (nq *NodeQuery) Unique(unique bool) *NodeQuery { - nq.unique = &unique + nq.ctx.Unique = &unique return nq } @@ -114,7 +111,7 @@ func (nq *NodeQuery) QueryNext() *NodeQuery { // First returns the first Node entity from the query. // Returns a *NotFoundError when no Node was found. func (nq *NodeQuery) First(ctx context.Context) (*Node, error) { - nodes, err := nq.Limit(1).All(newQueryContext(ctx, TypeNode, "First")) + nodes, err := nq.Limit(1).All(setContextOp(ctx, nq.ctx, "First")) if err != nil { return nil, err } @@ -137,7 +134,7 @@ func (nq *NodeQuery) FirstX(ctx context.Context) *Node { // Returns a *NotFoundError when no Node ID was found. func (nq *NodeQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = nq.Limit(1).IDs(newQueryContext(ctx, TypeNode, "FirstID")); err != nil { + if ids, err = nq.Limit(1).IDs(setContextOp(ctx, nq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -160,7 +157,7 @@ func (nq *NodeQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Node entity is found. // Returns a *NotFoundError when no Node entities are found. func (nq *NodeQuery) Only(ctx context.Context) (*Node, error) { - nodes, err := nq.Limit(2).All(newQueryContext(ctx, TypeNode, "Only")) + nodes, err := nq.Limit(2).All(setContextOp(ctx, nq.ctx, "Only")) if err != nil { return nil, err } @@ -188,7 +185,7 @@ func (nq *NodeQuery) OnlyX(ctx context.Context) *Node { // Returns a *NotFoundError when no entities are found. func (nq *NodeQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = nq.Limit(2).IDs(newQueryContext(ctx, TypeNode, "OnlyID")); err != nil { + if ids, err = nq.Limit(2).IDs(setContextOp(ctx, nq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -213,7 +210,7 @@ func (nq *NodeQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Nodes. func (nq *NodeQuery) All(ctx context.Context) ([]*Node, error) { - ctx = newQueryContext(ctx, TypeNode, "All") + ctx = setContextOp(ctx, nq.ctx, "All") if err := nq.prepareQuery(ctx); err != nil { return nil, err } @@ -233,7 +230,7 @@ func (nq *NodeQuery) AllX(ctx context.Context) []*Node { // IDs executes the query and returns a list of Node IDs. func (nq *NodeQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeNode, "IDs") + ctx = setContextOp(ctx, nq.ctx, "IDs") if err := nq.Select(node.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -251,7 +248,7 @@ func (nq *NodeQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (nq *NodeQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeNode, "Count") + ctx = setContextOp(ctx, nq.ctx, "Count") if err := nq.prepareQuery(ctx); err != nil { return 0, err } @@ -269,7 +266,7 @@ func (nq *NodeQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (nq *NodeQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeNode, "Exist") + ctx = setContextOp(ctx, nq.ctx, "Exist") switch _, err := nq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -297,17 +294,15 @@ func (nq *NodeQuery) Clone() *NodeQuery { } return &NodeQuery{ config: nq.config, - limit: nq.limit, - offset: nq.offset, + ctx: nq.ctx.Clone(), order: append([]OrderFunc{}, nq.order...), inters: append([]Interceptor{}, nq.inters...), predicates: append([]predicate.Node{}, nq.predicates...), withPrev: nq.withPrev.Clone(), withNext: nq.withNext.Clone(), // clone intermediate query. - sql: nq.sql.Clone(), - path: nq.path, - unique: nq.unique, + sql: nq.sql.Clone(), + path: nq.path, } } @@ -348,9 +343,9 @@ func (nq *NodeQuery) WithNext(opts ...func(*NodeQuery)) *NodeQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (nq *NodeQuery) GroupBy(field string, fields ...string) *NodeGroupBy { - nq.fields = append([]string{field}, fields...) + nq.ctx.Fields = append([]string{field}, fields...) grbuild := &NodeGroupBy{build: nq} - grbuild.flds = &nq.fields + grbuild.flds = &nq.ctx.Fields grbuild.label = node.Label grbuild.scan = grbuild.Scan return grbuild @@ -369,10 +364,10 @@ func (nq *NodeQuery) GroupBy(field string, fields ...string) *NodeGroupBy { // Select(node.FieldValue). // Scan(ctx, &v) func (nq *NodeQuery) Select(fields ...string) *NodeSelect { - nq.fields = append(nq.fields, fields...) + nq.ctx.Fields = append(nq.ctx.Fields, fields...) sbuild := &NodeSelect{NodeQuery: nq} sbuild.label = node.Label - sbuild.flds, sbuild.scan = &nq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &nq.ctx.Fields, sbuild.Scan return sbuild } @@ -392,7 +387,7 @@ func (nq *NodeQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range nq.fields { + for _, f := range nq.ctx.Fields { if !node.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -505,9 +500,9 @@ func (nq *NodeQuery) loadNext(ctx context.Context, query *NodeQuery, nodes []*No func (nq *NodeQuery) sqlCount(ctx context.Context) (int, error) { _spec := nq.querySpec() - _spec.Node.Columns = nq.fields - if len(nq.fields) > 0 { - _spec.Unique = nq.unique != nil && *nq.unique + _spec.Node.Columns = nq.ctx.Fields + if len(nq.ctx.Fields) > 0 { + _spec.Unique = nq.ctx.Unique != nil && *nq.ctx.Unique } return sqlgraph.CountNodes(ctx, nq.driver, _spec) } @@ -525,10 +520,10 @@ func (nq *NodeQuery) querySpec() *sqlgraph.QuerySpec { From: nq.sql, Unique: true, } - if unique := nq.unique; unique != nil { + if unique := nq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := nq.fields; len(fields) > 0 { + if fields := nq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, node.FieldID) for i := range fields { @@ -544,10 +539,10 @@ func (nq *NodeQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := nq.limit; limit != nil { + if limit := nq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := nq.offset; offset != nil { + if offset := nq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := nq.order; len(ps) > 0 { @@ -563,7 +558,7 @@ func (nq *NodeQuery) querySpec() *sqlgraph.QuerySpec { func (nq *NodeQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(nq.driver.Dialect()) t1 := builder.Table(node.Table) - columns := nq.fields + columns := nq.ctx.Fields if len(columns) == 0 { columns = node.Columns } @@ -572,7 +567,7 @@ func (nq *NodeQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = nq.sql selector.Select(selector.Columns(columns...)...) } - if nq.unique != nil && *nq.unique { + if nq.ctx.Unique != nil && *nq.ctx.Unique { selector.Distinct() } for _, p := range nq.predicates { @@ -581,12 +576,12 @@ func (nq *NodeQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range nq.order { p(selector) } - if offset := nq.offset; offset != nil { + if offset := nq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := nq.limit; limit != nil { + if limit := nq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -606,7 +601,7 @@ func (ngb *NodeGroupBy) Aggregate(fns ...AggregateFunc) *NodeGroupBy { // Scan applies the selector query and scans the result into the given value. func (ngb *NodeGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeNode, "GroupBy") + ctx = setContextOp(ctx, ngb.build.ctx, "GroupBy") if err := ngb.build.prepareQuery(ctx); err != nil { return err } @@ -654,7 +649,7 @@ func (ns *NodeSelect) Aggregate(fns ...AggregateFunc) *NodeSelect { // Scan applies the selector query and scans the result into the given value. func (ns *NodeSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeNode, "Select") + ctx = setContextOp(ctx, ns.ctx, "Select") if err := ns.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgefield/ent/pet_query.go b/entc/integration/edgefield/ent/pet_query.go index 510cb36ef..9b13a4242 100644 --- a/entc/integration/edgefield/ent/pet_query.go +++ b/entc/integration/edgefield/ent/pet_query.go @@ -22,11 +22,8 @@ import ( // PetQuery is the builder for querying Pet entities. type PetQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Pet withOwner *UserQuery @@ -43,20 +40,20 @@ func (pq *PetQuery) Where(ps ...predicate.Pet) *PetQuery { // Limit the number of records to be returned by this query. func (pq *PetQuery) Limit(limit int) *PetQuery { - pq.limit = &limit + pq.ctx.Limit = &limit return pq } // Offset to start from. func (pq *PetQuery) Offset(offset int) *PetQuery { - pq.offset = &offset + pq.ctx.Offset = &offset return pq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (pq *PetQuery) Unique(unique bool) *PetQuery { - pq.unique = &unique + pq.ctx.Unique = &unique return pq } @@ -91,7 +88,7 @@ func (pq *PetQuery) QueryOwner() *UserQuery { // First returns the first Pet entity from the query. // Returns a *NotFoundError when no Pet was found. func (pq *PetQuery) First(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(1).All(newQueryContext(ctx, TypePet, "First")) + nodes, err := pq.Limit(1).All(setContextOp(ctx, pq.ctx, "First")) if err != nil { return nil, err } @@ -114,7 +111,7 @@ func (pq *PetQuery) FirstX(ctx context.Context) *Pet { // Returns a *NotFoundError when no Pet ID was found. func (pq *PetQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(1).IDs(newQueryContext(ctx, TypePet, "FirstID")); err != nil { + if ids, err = pq.Limit(1).IDs(setContextOp(ctx, pq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -137,7 +134,7 @@ func (pq *PetQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Pet entity is found. // Returns a *NotFoundError when no Pet entities are found. func (pq *PetQuery) Only(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(2).All(newQueryContext(ctx, TypePet, "Only")) + nodes, err := pq.Limit(2).All(setContextOp(ctx, pq.ctx, "Only")) if err != nil { return nil, err } @@ -165,7 +162,7 @@ func (pq *PetQuery) OnlyX(ctx context.Context) *Pet { // Returns a *NotFoundError when no entities are found. func (pq *PetQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(2).IDs(newQueryContext(ctx, TypePet, "OnlyID")); err != nil { + if ids, err = pq.Limit(2).IDs(setContextOp(ctx, pq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -190,7 +187,7 @@ func (pq *PetQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Pets. func (pq *PetQuery) All(ctx context.Context) ([]*Pet, error) { - ctx = newQueryContext(ctx, TypePet, "All") + ctx = setContextOp(ctx, pq.ctx, "All") if err := pq.prepareQuery(ctx); err != nil { return nil, err } @@ -210,7 +207,7 @@ func (pq *PetQuery) AllX(ctx context.Context) []*Pet { // IDs executes the query and returns a list of Pet IDs. func (pq *PetQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypePet, "IDs") + ctx = setContextOp(ctx, pq.ctx, "IDs") if err := pq.Select(pet.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -228,7 +225,7 @@ func (pq *PetQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (pq *PetQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypePet, "Count") + ctx = setContextOp(ctx, pq.ctx, "Count") if err := pq.prepareQuery(ctx); err != nil { return 0, err } @@ -246,7 +243,7 @@ func (pq *PetQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (pq *PetQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypePet, "Exist") + ctx = setContextOp(ctx, pq.ctx, "Exist") switch _, err := pq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -274,16 +271,14 @@ func (pq *PetQuery) Clone() *PetQuery { } return &PetQuery{ config: pq.config, - limit: pq.limit, - offset: pq.offset, + ctx: pq.ctx.Clone(), order: append([]OrderFunc{}, pq.order...), inters: append([]Interceptor{}, pq.inters...), predicates: append([]predicate.Pet{}, pq.predicates...), withOwner: pq.withOwner.Clone(), // clone intermediate query. - sql: pq.sql.Clone(), - path: pq.path, - unique: pq.unique, + sql: pq.sql.Clone(), + path: pq.path, } } @@ -313,9 +308,9 @@ func (pq *PetQuery) WithOwner(opts ...func(*UserQuery)) *PetQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { - pq.fields = append([]string{field}, fields...) + pq.ctx.Fields = append([]string{field}, fields...) grbuild := &PetGroupBy{build: pq} - grbuild.flds = &pq.fields + grbuild.flds = &pq.ctx.Fields grbuild.label = pet.Label grbuild.scan = grbuild.Scan return grbuild @@ -334,10 +329,10 @@ func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { // Select(pet.FieldOwnerID). // Scan(ctx, &v) func (pq *PetQuery) Select(fields ...string) *PetSelect { - pq.fields = append(pq.fields, fields...) + pq.ctx.Fields = append(pq.ctx.Fields, fields...) sbuild := &PetSelect{PetQuery: pq} sbuild.label = pet.Label - sbuild.flds, sbuild.scan = &pq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &pq.ctx.Fields, sbuild.Scan return sbuild } @@ -357,7 +352,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range pq.fields { + for _, f := range pq.ctx.Fields { if !pet.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -439,9 +434,9 @@ func (pq *PetQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Pe func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() - _spec.Node.Columns = pq.fields - if len(pq.fields) > 0 { - _spec.Unique = pq.unique != nil && *pq.unique + _spec.Node.Columns = pq.ctx.Fields + if len(pq.ctx.Fields) > 0 { + _spec.Unique = pq.ctx.Unique != nil && *pq.ctx.Unique } return sqlgraph.CountNodes(ctx, pq.driver, _spec) } @@ -459,10 +454,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { From: pq.sql, Unique: true, } - if unique := pq.unique; unique != nil { + if unique := pq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := pq.fields; len(fields) > 0 { + if fields := pq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, pet.FieldID) for i := range fields { @@ -478,10 +473,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := pq.order; len(ps) > 0 { @@ -497,7 +492,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(pq.driver.Dialect()) t1 := builder.Table(pet.Table) - columns := pq.fields + columns := pq.ctx.Fields if len(columns) == 0 { columns = pet.Columns } @@ -506,7 +501,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = pq.sql selector.Select(selector.Columns(columns...)...) } - if pq.unique != nil && *pq.unique { + if pq.ctx.Unique != nil && *pq.ctx.Unique { selector.Distinct() } for _, p := range pq.predicates { @@ -515,12 +510,12 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range pq.order { p(selector) } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -540,7 +535,7 @@ func (pgb *PetGroupBy) Aggregate(fns ...AggregateFunc) *PetGroupBy { // Scan applies the selector query and scans the result into the given value. func (pgb *PetGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "GroupBy") + ctx = setContextOp(ctx, pgb.build.ctx, "GroupBy") if err := pgb.build.prepareQuery(ctx); err != nil { return err } @@ -588,7 +583,7 @@ func (ps *PetSelect) Aggregate(fns ...AggregateFunc) *PetSelect { // Scan applies the selector query and scans the result into the given value. func (ps *PetSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "Select") + ctx = setContextOp(ctx, ps.ctx, "Select") if err := ps.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgefield/ent/post_query.go b/entc/integration/edgefield/ent/post_query.go index 10e5d1a54..31ef6680e 100644 --- a/entc/integration/edgefield/ent/post_query.go +++ b/entc/integration/edgefield/ent/post_query.go @@ -22,11 +22,8 @@ import ( // PostQuery is the builder for querying Post entities. type PostQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Post withAuthor *UserQuery @@ -43,20 +40,20 @@ func (pq *PostQuery) Where(ps ...predicate.Post) *PostQuery { // Limit the number of records to be returned by this query. func (pq *PostQuery) Limit(limit int) *PostQuery { - pq.limit = &limit + pq.ctx.Limit = &limit return pq } // Offset to start from. func (pq *PostQuery) Offset(offset int) *PostQuery { - pq.offset = &offset + pq.ctx.Offset = &offset return pq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (pq *PostQuery) Unique(unique bool) *PostQuery { - pq.unique = &unique + pq.ctx.Unique = &unique return pq } @@ -91,7 +88,7 @@ func (pq *PostQuery) QueryAuthor() *UserQuery { // First returns the first Post entity from the query. // Returns a *NotFoundError when no Post was found. func (pq *PostQuery) First(ctx context.Context) (*Post, error) { - nodes, err := pq.Limit(1).All(newQueryContext(ctx, TypePost, "First")) + nodes, err := pq.Limit(1).All(setContextOp(ctx, pq.ctx, "First")) if err != nil { return nil, err } @@ -114,7 +111,7 @@ func (pq *PostQuery) FirstX(ctx context.Context) *Post { // Returns a *NotFoundError when no Post ID was found. func (pq *PostQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(1).IDs(newQueryContext(ctx, TypePost, "FirstID")); err != nil { + if ids, err = pq.Limit(1).IDs(setContextOp(ctx, pq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -137,7 +134,7 @@ func (pq *PostQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Post entity is found. // Returns a *NotFoundError when no Post entities are found. func (pq *PostQuery) Only(ctx context.Context) (*Post, error) { - nodes, err := pq.Limit(2).All(newQueryContext(ctx, TypePost, "Only")) + nodes, err := pq.Limit(2).All(setContextOp(ctx, pq.ctx, "Only")) if err != nil { return nil, err } @@ -165,7 +162,7 @@ func (pq *PostQuery) OnlyX(ctx context.Context) *Post { // Returns a *NotFoundError when no entities are found. func (pq *PostQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(2).IDs(newQueryContext(ctx, TypePost, "OnlyID")); err != nil { + if ids, err = pq.Limit(2).IDs(setContextOp(ctx, pq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -190,7 +187,7 @@ func (pq *PostQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Posts. func (pq *PostQuery) All(ctx context.Context) ([]*Post, error) { - ctx = newQueryContext(ctx, TypePost, "All") + ctx = setContextOp(ctx, pq.ctx, "All") if err := pq.prepareQuery(ctx); err != nil { return nil, err } @@ -210,7 +207,7 @@ func (pq *PostQuery) AllX(ctx context.Context) []*Post { // IDs executes the query and returns a list of Post IDs. func (pq *PostQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypePost, "IDs") + ctx = setContextOp(ctx, pq.ctx, "IDs") if err := pq.Select(post.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -228,7 +225,7 @@ func (pq *PostQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (pq *PostQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypePost, "Count") + ctx = setContextOp(ctx, pq.ctx, "Count") if err := pq.prepareQuery(ctx); err != nil { return 0, err } @@ -246,7 +243,7 @@ func (pq *PostQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (pq *PostQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypePost, "Exist") + ctx = setContextOp(ctx, pq.ctx, "Exist") switch _, err := pq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -274,16 +271,14 @@ func (pq *PostQuery) Clone() *PostQuery { } return &PostQuery{ config: pq.config, - limit: pq.limit, - offset: pq.offset, + ctx: pq.ctx.Clone(), order: append([]OrderFunc{}, pq.order...), inters: append([]Interceptor{}, pq.inters...), predicates: append([]predicate.Post{}, pq.predicates...), withAuthor: pq.withAuthor.Clone(), // clone intermediate query. - sql: pq.sql.Clone(), - path: pq.path, - unique: pq.unique, + sql: pq.sql.Clone(), + path: pq.path, } } @@ -313,9 +308,9 @@ func (pq *PostQuery) WithAuthor(opts ...func(*UserQuery)) *PostQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (pq *PostQuery) GroupBy(field string, fields ...string) *PostGroupBy { - pq.fields = append([]string{field}, fields...) + pq.ctx.Fields = append([]string{field}, fields...) grbuild := &PostGroupBy{build: pq} - grbuild.flds = &pq.fields + grbuild.flds = &pq.ctx.Fields grbuild.label = post.Label grbuild.scan = grbuild.Scan return grbuild @@ -334,10 +329,10 @@ func (pq *PostQuery) GroupBy(field string, fields ...string) *PostGroupBy { // Select(post.FieldText). // Scan(ctx, &v) func (pq *PostQuery) Select(fields ...string) *PostSelect { - pq.fields = append(pq.fields, fields...) + pq.ctx.Fields = append(pq.ctx.Fields, fields...) sbuild := &PostSelect{PostQuery: pq} sbuild.label = post.Label - sbuild.flds, sbuild.scan = &pq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &pq.ctx.Fields, sbuild.Scan return sbuild } @@ -357,7 +352,7 @@ func (pq *PostQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range pq.fields { + for _, f := range pq.ctx.Fields { if !post.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -442,9 +437,9 @@ func (pq *PostQuery) loadAuthor(ctx context.Context, query *UserQuery, nodes []* func (pq *PostQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() - _spec.Node.Columns = pq.fields - if len(pq.fields) > 0 { - _spec.Unique = pq.unique != nil && *pq.unique + _spec.Node.Columns = pq.ctx.Fields + if len(pq.ctx.Fields) > 0 { + _spec.Unique = pq.ctx.Unique != nil && *pq.ctx.Unique } return sqlgraph.CountNodes(ctx, pq.driver, _spec) } @@ -462,10 +457,10 @@ func (pq *PostQuery) querySpec() *sqlgraph.QuerySpec { From: pq.sql, Unique: true, } - if unique := pq.unique; unique != nil { + if unique := pq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := pq.fields; len(fields) > 0 { + if fields := pq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, post.FieldID) for i := range fields { @@ -481,10 +476,10 @@ func (pq *PostQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := pq.order; len(ps) > 0 { @@ -500,7 +495,7 @@ func (pq *PostQuery) querySpec() *sqlgraph.QuerySpec { func (pq *PostQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(pq.driver.Dialect()) t1 := builder.Table(post.Table) - columns := pq.fields + columns := pq.ctx.Fields if len(columns) == 0 { columns = post.Columns } @@ -509,7 +504,7 @@ func (pq *PostQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = pq.sql selector.Select(selector.Columns(columns...)...) } - if pq.unique != nil && *pq.unique { + if pq.ctx.Unique != nil && *pq.ctx.Unique { selector.Distinct() } for _, p := range pq.predicates { @@ -518,12 +513,12 @@ func (pq *PostQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range pq.order { p(selector) } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -543,7 +538,7 @@ func (pgb *PostGroupBy) Aggregate(fns ...AggregateFunc) *PostGroupBy { // Scan applies the selector query and scans the result into the given value. func (pgb *PostGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePost, "GroupBy") + ctx = setContextOp(ctx, pgb.build.ctx, "GroupBy") if err := pgb.build.prepareQuery(ctx); err != nil { return err } @@ -591,7 +586,7 @@ func (ps *PostSelect) Aggregate(fns ...AggregateFunc) *PostSelect { // Scan applies the selector query and scans the result into the given value. func (ps *PostSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePost, "Select") + ctx = setContextOp(ctx, ps.ctx, "Select") if err := ps.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgefield/ent/rental_query.go b/entc/integration/edgefield/ent/rental_query.go index 5eeb4bbae..01215f7d4 100644 --- a/entc/integration/edgefield/ent/rental_query.go +++ b/entc/integration/edgefield/ent/rental_query.go @@ -24,11 +24,8 @@ import ( // RentalQuery is the builder for querying Rental entities. type RentalQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Rental withUser *UserQuery @@ -46,20 +43,20 @@ func (rq *RentalQuery) Where(ps ...predicate.Rental) *RentalQuery { // Limit the number of records to be returned by this query. func (rq *RentalQuery) Limit(limit int) *RentalQuery { - rq.limit = &limit + rq.ctx.Limit = &limit return rq } // Offset to start from. func (rq *RentalQuery) Offset(offset int) *RentalQuery { - rq.offset = &offset + rq.ctx.Offset = &offset return rq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (rq *RentalQuery) Unique(unique bool) *RentalQuery { - rq.unique = &unique + rq.ctx.Unique = &unique return rq } @@ -116,7 +113,7 @@ func (rq *RentalQuery) QueryCar() *CarQuery { // First returns the first Rental entity from the query. // Returns a *NotFoundError when no Rental was found. func (rq *RentalQuery) First(ctx context.Context) (*Rental, error) { - nodes, err := rq.Limit(1).All(newQueryContext(ctx, TypeRental, "First")) + nodes, err := rq.Limit(1).All(setContextOp(ctx, rq.ctx, "First")) if err != nil { return nil, err } @@ -139,7 +136,7 @@ func (rq *RentalQuery) FirstX(ctx context.Context) *Rental { // Returns a *NotFoundError when no Rental ID was found. func (rq *RentalQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = rq.Limit(1).IDs(newQueryContext(ctx, TypeRental, "FirstID")); err != nil { + if ids, err = rq.Limit(1).IDs(setContextOp(ctx, rq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -162,7 +159,7 @@ func (rq *RentalQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Rental entity is found. // Returns a *NotFoundError when no Rental entities are found. func (rq *RentalQuery) Only(ctx context.Context) (*Rental, error) { - nodes, err := rq.Limit(2).All(newQueryContext(ctx, TypeRental, "Only")) + nodes, err := rq.Limit(2).All(setContextOp(ctx, rq.ctx, "Only")) if err != nil { return nil, err } @@ -190,7 +187,7 @@ func (rq *RentalQuery) OnlyX(ctx context.Context) *Rental { // Returns a *NotFoundError when no entities are found. func (rq *RentalQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = rq.Limit(2).IDs(newQueryContext(ctx, TypeRental, "OnlyID")); err != nil { + if ids, err = rq.Limit(2).IDs(setContextOp(ctx, rq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -215,7 +212,7 @@ func (rq *RentalQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Rentals. func (rq *RentalQuery) All(ctx context.Context) ([]*Rental, error) { - ctx = newQueryContext(ctx, TypeRental, "All") + ctx = setContextOp(ctx, rq.ctx, "All") if err := rq.prepareQuery(ctx); err != nil { return nil, err } @@ -235,7 +232,7 @@ func (rq *RentalQuery) AllX(ctx context.Context) []*Rental { // IDs executes the query and returns a list of Rental IDs. func (rq *RentalQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeRental, "IDs") + ctx = setContextOp(ctx, rq.ctx, "IDs") if err := rq.Select(rental.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -253,7 +250,7 @@ func (rq *RentalQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (rq *RentalQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeRental, "Count") + ctx = setContextOp(ctx, rq.ctx, "Count") if err := rq.prepareQuery(ctx); err != nil { return 0, err } @@ -271,7 +268,7 @@ func (rq *RentalQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (rq *RentalQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeRental, "Exist") + ctx = setContextOp(ctx, rq.ctx, "Exist") switch _, err := rq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -299,17 +296,15 @@ func (rq *RentalQuery) Clone() *RentalQuery { } return &RentalQuery{ config: rq.config, - limit: rq.limit, - offset: rq.offset, + ctx: rq.ctx.Clone(), order: append([]OrderFunc{}, rq.order...), inters: append([]Interceptor{}, rq.inters...), predicates: append([]predicate.Rental{}, rq.predicates...), withUser: rq.withUser.Clone(), withCar: rq.withCar.Clone(), // clone intermediate query. - sql: rq.sql.Clone(), - path: rq.path, - unique: rq.unique, + sql: rq.sql.Clone(), + path: rq.path, } } @@ -350,9 +345,9 @@ func (rq *RentalQuery) WithCar(opts ...func(*CarQuery)) *RentalQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (rq *RentalQuery) GroupBy(field string, fields ...string) *RentalGroupBy { - rq.fields = append([]string{field}, fields...) + rq.ctx.Fields = append([]string{field}, fields...) grbuild := &RentalGroupBy{build: rq} - grbuild.flds = &rq.fields + grbuild.flds = &rq.ctx.Fields grbuild.label = rental.Label grbuild.scan = grbuild.Scan return grbuild @@ -371,10 +366,10 @@ func (rq *RentalQuery) GroupBy(field string, fields ...string) *RentalGroupBy { // Select(rental.FieldDate). // Scan(ctx, &v) func (rq *RentalQuery) Select(fields ...string) *RentalSelect { - rq.fields = append(rq.fields, fields...) + rq.ctx.Fields = append(rq.ctx.Fields, fields...) sbuild := &RentalSelect{RentalQuery: rq} sbuild.label = rental.Label - sbuild.flds, sbuild.scan = &rq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &rq.ctx.Fields, sbuild.Scan return sbuild } @@ -394,7 +389,7 @@ func (rq *RentalQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range rq.fields { + for _, f := range rq.ctx.Fields { if !rental.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -512,9 +507,9 @@ func (rq *RentalQuery) loadCar(ctx context.Context, query *CarQuery, nodes []*Re func (rq *RentalQuery) sqlCount(ctx context.Context) (int, error) { _spec := rq.querySpec() - _spec.Node.Columns = rq.fields - if len(rq.fields) > 0 { - _spec.Unique = rq.unique != nil && *rq.unique + _spec.Node.Columns = rq.ctx.Fields + if len(rq.ctx.Fields) > 0 { + _spec.Unique = rq.ctx.Unique != nil && *rq.ctx.Unique } return sqlgraph.CountNodes(ctx, rq.driver, _spec) } @@ -532,10 +527,10 @@ func (rq *RentalQuery) querySpec() *sqlgraph.QuerySpec { From: rq.sql, Unique: true, } - if unique := rq.unique; unique != nil { + if unique := rq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := rq.fields; len(fields) > 0 { + if fields := rq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, rental.FieldID) for i := range fields { @@ -551,10 +546,10 @@ func (rq *RentalQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := rq.limit; limit != nil { + if limit := rq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := rq.offset; offset != nil { + if offset := rq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := rq.order; len(ps) > 0 { @@ -570,7 +565,7 @@ func (rq *RentalQuery) querySpec() *sqlgraph.QuerySpec { func (rq *RentalQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(rq.driver.Dialect()) t1 := builder.Table(rental.Table) - columns := rq.fields + columns := rq.ctx.Fields if len(columns) == 0 { columns = rental.Columns } @@ -579,7 +574,7 @@ func (rq *RentalQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = rq.sql selector.Select(selector.Columns(columns...)...) } - if rq.unique != nil && *rq.unique { + if rq.ctx.Unique != nil && *rq.ctx.Unique { selector.Distinct() } for _, p := range rq.predicates { @@ -588,12 +583,12 @@ func (rq *RentalQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range rq.order { p(selector) } - if offset := rq.offset; offset != nil { + if offset := rq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := rq.limit; limit != nil { + if limit := rq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -613,7 +608,7 @@ func (rgb *RentalGroupBy) Aggregate(fns ...AggregateFunc) *RentalGroupBy { // Scan applies the selector query and scans the result into the given value. func (rgb *RentalGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeRental, "GroupBy") + ctx = setContextOp(ctx, rgb.build.ctx, "GroupBy") if err := rgb.build.prepareQuery(ctx); err != nil { return err } @@ -661,7 +656,7 @@ func (rs *RentalSelect) Aggregate(fns ...AggregateFunc) *RentalSelect { // Scan applies the selector query and scans the result into the given value. func (rs *RentalSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeRental, "Select") + ctx = setContextOp(ctx, rs.ctx, "Select") if err := rs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgefield/ent/user_query.go b/entc/integration/edgefield/ent/user_query.go index 09fa36f12..fc6f8bbb0 100644 --- a/entc/integration/edgefield/ent/user_query.go +++ b/entc/integration/edgefield/ent/user_query.go @@ -27,11 +27,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withPets *PetQuery @@ -55,20 +52,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -257,7 +254,7 @@ func (uq *UserQuery) QueryRentals() *RentalQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -280,7 +277,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -303,7 +300,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -331,7 +328,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -356,7 +353,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -376,7 +373,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -394,7 +391,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -412,7 +409,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -440,8 +437,7 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), @@ -454,9 +450,8 @@ func (uq *UserQuery) Clone() *UserQuery { withInfo: uq.withInfo.Clone(), withRentals: uq.withRentals.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -563,9 +558,9 @@ func (uq *UserQuery) WithRentals(opts ...func(*RentalQuery)) *UserQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -584,10 +579,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldParentID). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -607,7 +602,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -927,9 +922,9 @@ func (uq *UserQuery) loadRentals(ctx context.Context, query *RentalQuery, nodes func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -947,10 +942,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -966,10 +961,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -985,7 +980,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -994,7 +989,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -1003,12 +998,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -1028,7 +1023,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -1076,7 +1071,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/client.go b/entc/integration/edgeschema/ent/client.go index 3be016245..89b97c857 100644 --- a/entc/integration/edgeschema/ent/client.go +++ b/entc/integration/edgeschema/ent/client.go @@ -348,6 +348,7 @@ func (c *FriendshipClient) DeleteOneID(id int) *FriendshipDeleteOne { func (c *FriendshipClient) Query() *FriendshipQuery { return &FriendshipQuery{ config: c.config, + ctx: &QueryContext{Type: TypeFriendship}, inters: c.Interceptors(), } } @@ -497,6 +498,7 @@ func (c *GroupClient) DeleteOneID(id int) *GroupDeleteOne { func (c *GroupClient) Query() *GroupQuery { return &GroupQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGroup}, inters: c.Interceptors(), } } @@ -678,6 +680,7 @@ func (c *GroupTagClient) DeleteOneID(id int) *GroupTagDeleteOne { func (c *GroupTagClient) Query() *GroupTagQuery { return &GroupTagQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGroupTag}, inters: c.Interceptors(), } } @@ -810,6 +813,7 @@ func (c *RelationshipClient) Delete() *RelationshipDelete { func (c *RelationshipClient) Query() *RelationshipQuery { return &RelationshipQuery{ config: c.config, + ctx: &QueryContext{Type: TypeRelationship}, inters: c.Interceptors(), } } @@ -935,6 +939,7 @@ func (c *RelationshipInfoClient) DeleteOneID(id int) *RelationshipInfoDeleteOne func (c *RelationshipInfoClient) Query() *RelationshipInfoQuery { return &RelationshipInfoQuery{ config: c.config, + ctx: &QueryContext{Type: TypeRelationshipInfo}, inters: c.Interceptors(), } } @@ -1052,6 +1057,7 @@ func (c *RoleClient) DeleteOneID(id int) *RoleDeleteOne { func (c *RoleClient) Query() *RoleQuery { return &RoleQuery{ config: c.config, + ctx: &QueryContext{Type: TypeRole}, inters: c.Interceptors(), } } @@ -1184,6 +1190,7 @@ func (c *RoleUserClient) Delete() *RoleUserDelete { func (c *RoleUserClient) Query() *RoleUserQuery { return &RoleUserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeRoleUser}, inters: c.Interceptors(), } } @@ -1301,6 +1308,7 @@ func (c *TagClient) DeleteOneID(id int) *TagDeleteOne { func (c *TagClient) Query() *TagQuery { return &TagQuery{ config: c.config, + ctx: &QueryContext{Type: TypeTag}, inters: c.Interceptors(), } } @@ -1482,6 +1490,7 @@ func (c *TweetClient) DeleteOneID(id int) *TweetDeleteOne { func (c *TweetClient) Query() *TweetQuery { return &TweetQuery{ config: c.config, + ctx: &QueryContext{Type: TypeTweet}, inters: c.Interceptors(), } } @@ -1678,6 +1687,7 @@ func (c *TweetLikeClient) Delete() *TweetLikeDelete { func (c *TweetLikeClient) Query() *TweetLikeQuery { return &TweetLikeQuery{ config: c.config, + ctx: &QueryContext{Type: TypeTweetLike}, inters: c.Interceptors(), } } @@ -1796,6 +1806,7 @@ func (c *TweetTagClient) DeleteOneID(id uuid.UUID) *TweetTagDeleteOne { func (c *TweetTagClient) Query() *TweetTagQuery { return &TweetTagQuery{ config: c.config, + ctx: &QueryContext{Type: TypeTweetTag}, inters: c.Interceptors(), } } @@ -1945,6 +1956,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } @@ -2255,6 +2267,7 @@ func (c *UserGroupClient) DeleteOneID(id int) *UserGroupDeleteOne { func (c *UserGroupClient) Query() *UserGroupQuery { return &UserGroupQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUserGroup}, inters: c.Interceptors(), } } @@ -2404,6 +2417,7 @@ func (c *UserTweetClient) DeleteOneID(id int) *UserTweetDeleteOne { func (c *UserTweetClient) Query() *UserTweetQuery { return &UserTweetQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUserTweet}, inters: c.Interceptors(), } } diff --git a/entc/integration/edgeschema/ent/ent.go b/entc/integration/edgeschema/ent/ent.go index 73d49bcdc..1e3a699df 100644 --- a/entc/integration/edgeschema/ent/ent.go +++ b/entc/integration/edgeschema/ent/ent.go @@ -37,6 +37,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -533,10 +534,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/edgeschema/ent/friendship_query.go b/entc/integration/edgeschema/ent/friendship_query.go index 6a7728eb8..18bbc9dee 100644 --- a/entc/integration/edgeschema/ent/friendship_query.go +++ b/entc/integration/edgeschema/ent/friendship_query.go @@ -22,11 +22,8 @@ import ( // FriendshipQuery is the builder for querying Friendship entities. type FriendshipQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Friendship withUser *UserQuery @@ -44,20 +41,20 @@ func (fq *FriendshipQuery) Where(ps ...predicate.Friendship) *FriendshipQuery { // Limit the number of records to be returned by this query. func (fq *FriendshipQuery) Limit(limit int) *FriendshipQuery { - fq.limit = &limit + fq.ctx.Limit = &limit return fq } // Offset to start from. func (fq *FriendshipQuery) Offset(offset int) *FriendshipQuery { - fq.offset = &offset + fq.ctx.Offset = &offset return fq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (fq *FriendshipQuery) Unique(unique bool) *FriendshipQuery { - fq.unique = &unique + fq.ctx.Unique = &unique return fq } @@ -114,7 +111,7 @@ func (fq *FriendshipQuery) QueryFriend() *UserQuery { // First returns the first Friendship entity from the query. // Returns a *NotFoundError when no Friendship was found. func (fq *FriendshipQuery) First(ctx context.Context) (*Friendship, error) { - nodes, err := fq.Limit(1).All(newQueryContext(ctx, TypeFriendship, "First")) + nodes, err := fq.Limit(1).All(setContextOp(ctx, fq.ctx, "First")) if err != nil { return nil, err } @@ -137,7 +134,7 @@ func (fq *FriendshipQuery) FirstX(ctx context.Context) *Friendship { // Returns a *NotFoundError when no Friendship ID was found. func (fq *FriendshipQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = fq.Limit(1).IDs(newQueryContext(ctx, TypeFriendship, "FirstID")); err != nil { + if ids, err = fq.Limit(1).IDs(setContextOp(ctx, fq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -160,7 +157,7 @@ func (fq *FriendshipQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Friendship entity is found. // Returns a *NotFoundError when no Friendship entities are found. func (fq *FriendshipQuery) Only(ctx context.Context) (*Friendship, error) { - nodes, err := fq.Limit(2).All(newQueryContext(ctx, TypeFriendship, "Only")) + nodes, err := fq.Limit(2).All(setContextOp(ctx, fq.ctx, "Only")) if err != nil { return nil, err } @@ -188,7 +185,7 @@ func (fq *FriendshipQuery) OnlyX(ctx context.Context) *Friendship { // Returns a *NotFoundError when no entities are found. func (fq *FriendshipQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = fq.Limit(2).IDs(newQueryContext(ctx, TypeFriendship, "OnlyID")); err != nil { + if ids, err = fq.Limit(2).IDs(setContextOp(ctx, fq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -213,7 +210,7 @@ func (fq *FriendshipQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Friendships. func (fq *FriendshipQuery) All(ctx context.Context) ([]*Friendship, error) { - ctx = newQueryContext(ctx, TypeFriendship, "All") + ctx = setContextOp(ctx, fq.ctx, "All") if err := fq.prepareQuery(ctx); err != nil { return nil, err } @@ -233,7 +230,7 @@ func (fq *FriendshipQuery) AllX(ctx context.Context) []*Friendship { // IDs executes the query and returns a list of Friendship IDs. func (fq *FriendshipQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeFriendship, "IDs") + ctx = setContextOp(ctx, fq.ctx, "IDs") if err := fq.Select(friendship.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -251,7 +248,7 @@ func (fq *FriendshipQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (fq *FriendshipQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeFriendship, "Count") + ctx = setContextOp(ctx, fq.ctx, "Count") if err := fq.prepareQuery(ctx); err != nil { return 0, err } @@ -269,7 +266,7 @@ func (fq *FriendshipQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (fq *FriendshipQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeFriendship, "Exist") + ctx = setContextOp(ctx, fq.ctx, "Exist") switch _, err := fq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -297,17 +294,15 @@ func (fq *FriendshipQuery) Clone() *FriendshipQuery { } return &FriendshipQuery{ config: fq.config, - limit: fq.limit, - offset: fq.offset, + ctx: fq.ctx.Clone(), order: append([]OrderFunc{}, fq.order...), inters: append([]Interceptor{}, fq.inters...), predicates: append([]predicate.Friendship{}, fq.predicates...), withUser: fq.withUser.Clone(), withFriend: fq.withFriend.Clone(), // clone intermediate query. - sql: fq.sql.Clone(), - path: fq.path, - unique: fq.unique, + sql: fq.sql.Clone(), + path: fq.path, } } @@ -348,9 +343,9 @@ func (fq *FriendshipQuery) WithFriend(opts ...func(*UserQuery)) *FriendshipQuery // Aggregate(ent.Count()). // Scan(ctx, &v) func (fq *FriendshipQuery) GroupBy(field string, fields ...string) *FriendshipGroupBy { - fq.fields = append([]string{field}, fields...) + fq.ctx.Fields = append([]string{field}, fields...) grbuild := &FriendshipGroupBy{build: fq} - grbuild.flds = &fq.fields + grbuild.flds = &fq.ctx.Fields grbuild.label = friendship.Label grbuild.scan = grbuild.Scan return grbuild @@ -369,10 +364,10 @@ func (fq *FriendshipQuery) GroupBy(field string, fields ...string) *FriendshipGr // Select(friendship.FieldWeight). // Scan(ctx, &v) func (fq *FriendshipQuery) Select(fields ...string) *FriendshipSelect { - fq.fields = append(fq.fields, fields...) + fq.ctx.Fields = append(fq.ctx.Fields, fields...) sbuild := &FriendshipSelect{FriendshipQuery: fq} sbuild.label = friendship.Label - sbuild.flds, sbuild.scan = &fq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &fq.ctx.Fields, sbuild.Scan return sbuild } @@ -392,7 +387,7 @@ func (fq *FriendshipQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range fq.fields { + for _, f := range fq.ctx.Fields { if !friendship.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -510,9 +505,9 @@ func (fq *FriendshipQuery) loadFriend(ctx context.Context, query *UserQuery, nod func (fq *FriendshipQuery) sqlCount(ctx context.Context) (int, error) { _spec := fq.querySpec() - _spec.Node.Columns = fq.fields - if len(fq.fields) > 0 { - _spec.Unique = fq.unique != nil && *fq.unique + _spec.Node.Columns = fq.ctx.Fields + if len(fq.ctx.Fields) > 0 { + _spec.Unique = fq.ctx.Unique != nil && *fq.ctx.Unique } return sqlgraph.CountNodes(ctx, fq.driver, _spec) } @@ -530,10 +525,10 @@ func (fq *FriendshipQuery) querySpec() *sqlgraph.QuerySpec { From: fq.sql, Unique: true, } - if unique := fq.unique; unique != nil { + if unique := fq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := fq.fields; len(fields) > 0 { + if fields := fq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, friendship.FieldID) for i := range fields { @@ -549,10 +544,10 @@ func (fq *FriendshipQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := fq.limit; limit != nil { + if limit := fq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := fq.offset; offset != nil { + if offset := fq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := fq.order; len(ps) > 0 { @@ -568,7 +563,7 @@ func (fq *FriendshipQuery) querySpec() *sqlgraph.QuerySpec { func (fq *FriendshipQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(fq.driver.Dialect()) t1 := builder.Table(friendship.Table) - columns := fq.fields + columns := fq.ctx.Fields if len(columns) == 0 { columns = friendship.Columns } @@ -577,7 +572,7 @@ func (fq *FriendshipQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = fq.sql selector.Select(selector.Columns(columns...)...) } - if fq.unique != nil && *fq.unique { + if fq.ctx.Unique != nil && *fq.ctx.Unique { selector.Distinct() } for _, p := range fq.predicates { @@ -586,12 +581,12 @@ func (fq *FriendshipQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range fq.order { p(selector) } - if offset := fq.offset; offset != nil { + if offset := fq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := fq.limit; limit != nil { + if limit := fq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -611,7 +606,7 @@ func (fgb *FriendshipGroupBy) Aggregate(fns ...AggregateFunc) *FriendshipGroupBy // Scan applies the selector query and scans the result into the given value. func (fgb *FriendshipGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFriendship, "GroupBy") + ctx = setContextOp(ctx, fgb.build.ctx, "GroupBy") if err := fgb.build.prepareQuery(ctx); err != nil { return err } @@ -659,7 +654,7 @@ func (fs *FriendshipSelect) Aggregate(fns ...AggregateFunc) *FriendshipSelect { // Scan applies the selector query and scans the result into the given value. func (fs *FriendshipSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFriendship, "Select") + ctx = setContextOp(ctx, fs.ctx, "Select") if err := fs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/group_query.go b/entc/integration/edgeschema/ent/group_query.go index 658d66fa2..269f18151 100644 --- a/entc/integration/edgeschema/ent/group_query.go +++ b/entc/integration/edgeschema/ent/group_query.go @@ -26,11 +26,8 @@ import ( // GroupQuery is the builder for querying Group entities. type GroupQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Group withUsers *UserQuery @@ -50,20 +47,20 @@ func (gq *GroupQuery) Where(ps ...predicate.Group) *GroupQuery { // Limit the number of records to be returned by this query. func (gq *GroupQuery) Limit(limit int) *GroupQuery { - gq.limit = &limit + gq.ctx.Limit = &limit return gq } // Offset to start from. func (gq *GroupQuery) Offset(offset int) *GroupQuery { - gq.offset = &offset + gq.ctx.Offset = &offset return gq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (gq *GroupQuery) Unique(unique bool) *GroupQuery { - gq.unique = &unique + gq.ctx.Unique = &unique return gq } @@ -164,7 +161,7 @@ func (gq *GroupQuery) QueryGroupTags() *GroupTagQuery { // First returns the first Group entity from the query. // Returns a *NotFoundError when no Group was found. func (gq *GroupQuery) First(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(1).All(newQueryContext(ctx, TypeGroup, "First")) + nodes, err := gq.Limit(1).All(setContextOp(ctx, gq.ctx, "First")) if err != nil { return nil, err } @@ -187,7 +184,7 @@ func (gq *GroupQuery) FirstX(ctx context.Context) *Group { // Returns a *NotFoundError when no Group ID was found. func (gq *GroupQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(1).IDs(newQueryContext(ctx, TypeGroup, "FirstID")); err != nil { + if ids, err = gq.Limit(1).IDs(setContextOp(ctx, gq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -210,7 +207,7 @@ func (gq *GroupQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Group entity is found. // Returns a *NotFoundError when no Group entities are found. func (gq *GroupQuery) Only(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(2).All(newQueryContext(ctx, TypeGroup, "Only")) + nodes, err := gq.Limit(2).All(setContextOp(ctx, gq.ctx, "Only")) if err != nil { return nil, err } @@ -238,7 +235,7 @@ func (gq *GroupQuery) OnlyX(ctx context.Context) *Group { // Returns a *NotFoundError when no entities are found. func (gq *GroupQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(2).IDs(newQueryContext(ctx, TypeGroup, "OnlyID")); err != nil { + if ids, err = gq.Limit(2).IDs(setContextOp(ctx, gq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -263,7 +260,7 @@ func (gq *GroupQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Groups. func (gq *GroupQuery) All(ctx context.Context) ([]*Group, error) { - ctx = newQueryContext(ctx, TypeGroup, "All") + ctx = setContextOp(ctx, gq.ctx, "All") if err := gq.prepareQuery(ctx); err != nil { return nil, err } @@ -283,7 +280,7 @@ func (gq *GroupQuery) AllX(ctx context.Context) []*Group { // IDs executes the query and returns a list of Group IDs. func (gq *GroupQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeGroup, "IDs") + ctx = setContextOp(ctx, gq.ctx, "IDs") if err := gq.Select(group.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -301,7 +298,7 @@ func (gq *GroupQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (gq *GroupQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGroup, "Count") + ctx = setContextOp(ctx, gq.ctx, "Count") if err := gq.prepareQuery(ctx); err != nil { return 0, err } @@ -319,7 +316,7 @@ func (gq *GroupQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (gq *GroupQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGroup, "Exist") + ctx = setContextOp(ctx, gq.ctx, "Exist") switch _, err := gq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -347,8 +344,7 @@ func (gq *GroupQuery) Clone() *GroupQuery { } return &GroupQuery{ config: gq.config, - limit: gq.limit, - offset: gq.offset, + ctx: gq.ctx.Clone(), order: append([]OrderFunc{}, gq.order...), inters: append([]Interceptor{}, gq.inters...), predicates: append([]predicate.Group{}, gq.predicates...), @@ -357,9 +353,8 @@ func (gq *GroupQuery) Clone() *GroupQuery { withJoinedUsers: gq.withJoinedUsers.Clone(), withGroupTags: gq.withGroupTags.Clone(), // clone intermediate query. - sql: gq.sql.Clone(), - path: gq.path, - unique: gq.unique, + sql: gq.sql.Clone(), + path: gq.path, } } @@ -422,9 +417,9 @@ func (gq *GroupQuery) WithGroupTags(opts ...func(*GroupTagQuery)) *GroupQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { - gq.fields = append([]string{field}, fields...) + gq.ctx.Fields = append([]string{field}, fields...) grbuild := &GroupGroupBy{build: gq} - grbuild.flds = &gq.fields + grbuild.flds = &gq.ctx.Fields grbuild.label = group.Label grbuild.scan = grbuild.Scan return grbuild @@ -443,10 +438,10 @@ func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { // Select(group.FieldName). // Scan(ctx, &v) func (gq *GroupQuery) Select(fields ...string) *GroupSelect { - gq.fields = append(gq.fields, fields...) + gq.ctx.Fields = append(gq.ctx.Fields, fields...) sbuild := &GroupSelect{GroupQuery: gq} sbuild.label = group.Label - sbuild.flds, sbuild.scan = &gq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &gq.ctx.Fields, sbuild.Scan return sbuild } @@ -466,7 +461,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range gq.fields { + for _, f := range gq.ctx.Fields { if !group.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -714,9 +709,9 @@ func (gq *GroupQuery) loadGroupTags(ctx context.Context, query *GroupTagQuery, n func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { _spec := gq.querySpec() - _spec.Node.Columns = gq.fields - if len(gq.fields) > 0 { - _spec.Unique = gq.unique != nil && *gq.unique + _spec.Node.Columns = gq.ctx.Fields + if len(gq.ctx.Fields) > 0 { + _spec.Unique = gq.ctx.Unique != nil && *gq.ctx.Unique } return sqlgraph.CountNodes(ctx, gq.driver, _spec) } @@ -734,10 +729,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { From: gq.sql, Unique: true, } - if unique := gq.unique; unique != nil { + if unique := gq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := gq.fields; len(fields) > 0 { + if fields := gq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, group.FieldID) for i := range fields { @@ -753,10 +748,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := gq.order; len(ps) > 0 { @@ -772,7 +767,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(gq.driver.Dialect()) t1 := builder.Table(group.Table) - columns := gq.fields + columns := gq.ctx.Fields if len(columns) == 0 { columns = group.Columns } @@ -781,7 +776,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = gq.sql selector.Select(selector.Columns(columns...)...) } - if gq.unique != nil && *gq.unique { + if gq.ctx.Unique != nil && *gq.ctx.Unique { selector.Distinct() } for _, p := range gq.predicates { @@ -790,12 +785,12 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range gq.order { p(selector) } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -815,7 +810,7 @@ func (ggb *GroupGroupBy) Aggregate(fns ...AggregateFunc) *GroupGroupBy { // Scan applies the selector query and scans the result into the given value. func (ggb *GroupGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "GroupBy") + ctx = setContextOp(ctx, ggb.build.ctx, "GroupBy") if err := ggb.build.prepareQuery(ctx); err != nil { return err } @@ -863,7 +858,7 @@ func (gs *GroupSelect) Aggregate(fns ...AggregateFunc) *GroupSelect { // Scan applies the selector query and scans the result into the given value. func (gs *GroupSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "Select") + ctx = setContextOp(ctx, gs.ctx, "Select") if err := gs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/grouptag_query.go b/entc/integration/edgeschema/ent/grouptag_query.go index 5b4b15e05..8e5545e06 100644 --- a/entc/integration/edgeschema/ent/grouptag_query.go +++ b/entc/integration/edgeschema/ent/grouptag_query.go @@ -23,11 +23,8 @@ import ( // GroupTagQuery is the builder for querying GroupTag entities. type GroupTagQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.GroupTag withTag *TagQuery @@ -45,20 +42,20 @@ func (gtq *GroupTagQuery) Where(ps ...predicate.GroupTag) *GroupTagQuery { // Limit the number of records to be returned by this query. func (gtq *GroupTagQuery) Limit(limit int) *GroupTagQuery { - gtq.limit = &limit + gtq.ctx.Limit = &limit return gtq } // Offset to start from. func (gtq *GroupTagQuery) Offset(offset int) *GroupTagQuery { - gtq.offset = &offset + gtq.ctx.Offset = &offset return gtq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (gtq *GroupTagQuery) Unique(unique bool) *GroupTagQuery { - gtq.unique = &unique + gtq.ctx.Unique = &unique return gtq } @@ -115,7 +112,7 @@ func (gtq *GroupTagQuery) QueryGroup() *GroupQuery { // First returns the first GroupTag entity from the query. // Returns a *NotFoundError when no GroupTag was found. func (gtq *GroupTagQuery) First(ctx context.Context) (*GroupTag, error) { - nodes, err := gtq.Limit(1).All(newQueryContext(ctx, TypeGroupTag, "First")) + nodes, err := gtq.Limit(1).All(setContextOp(ctx, gtq.ctx, "First")) if err != nil { return nil, err } @@ -138,7 +135,7 @@ func (gtq *GroupTagQuery) FirstX(ctx context.Context) *GroupTag { // Returns a *NotFoundError when no GroupTag ID was found. func (gtq *GroupTagQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gtq.Limit(1).IDs(newQueryContext(ctx, TypeGroupTag, "FirstID")); err != nil { + if ids, err = gtq.Limit(1).IDs(setContextOp(ctx, gtq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -161,7 +158,7 @@ func (gtq *GroupTagQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one GroupTag entity is found. // Returns a *NotFoundError when no GroupTag entities are found. func (gtq *GroupTagQuery) Only(ctx context.Context) (*GroupTag, error) { - nodes, err := gtq.Limit(2).All(newQueryContext(ctx, TypeGroupTag, "Only")) + nodes, err := gtq.Limit(2).All(setContextOp(ctx, gtq.ctx, "Only")) if err != nil { return nil, err } @@ -189,7 +186,7 @@ func (gtq *GroupTagQuery) OnlyX(ctx context.Context) *GroupTag { // Returns a *NotFoundError when no entities are found. func (gtq *GroupTagQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gtq.Limit(2).IDs(newQueryContext(ctx, TypeGroupTag, "OnlyID")); err != nil { + if ids, err = gtq.Limit(2).IDs(setContextOp(ctx, gtq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -214,7 +211,7 @@ func (gtq *GroupTagQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of GroupTags. func (gtq *GroupTagQuery) All(ctx context.Context) ([]*GroupTag, error) { - ctx = newQueryContext(ctx, TypeGroupTag, "All") + ctx = setContextOp(ctx, gtq.ctx, "All") if err := gtq.prepareQuery(ctx); err != nil { return nil, err } @@ -234,7 +231,7 @@ func (gtq *GroupTagQuery) AllX(ctx context.Context) []*GroupTag { // IDs executes the query and returns a list of GroupTag IDs. func (gtq *GroupTagQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeGroupTag, "IDs") + ctx = setContextOp(ctx, gtq.ctx, "IDs") if err := gtq.Select(grouptag.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -252,7 +249,7 @@ func (gtq *GroupTagQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (gtq *GroupTagQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGroupTag, "Count") + ctx = setContextOp(ctx, gtq.ctx, "Count") if err := gtq.prepareQuery(ctx); err != nil { return 0, err } @@ -270,7 +267,7 @@ func (gtq *GroupTagQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (gtq *GroupTagQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGroupTag, "Exist") + ctx = setContextOp(ctx, gtq.ctx, "Exist") switch _, err := gtq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -298,17 +295,15 @@ func (gtq *GroupTagQuery) Clone() *GroupTagQuery { } return &GroupTagQuery{ config: gtq.config, - limit: gtq.limit, - offset: gtq.offset, + ctx: gtq.ctx.Clone(), order: append([]OrderFunc{}, gtq.order...), inters: append([]Interceptor{}, gtq.inters...), predicates: append([]predicate.GroupTag{}, gtq.predicates...), withTag: gtq.withTag.Clone(), withGroup: gtq.withGroup.Clone(), // clone intermediate query. - sql: gtq.sql.Clone(), - path: gtq.path, - unique: gtq.unique, + sql: gtq.sql.Clone(), + path: gtq.path, } } @@ -349,9 +344,9 @@ func (gtq *GroupTagQuery) WithGroup(opts ...func(*GroupQuery)) *GroupTagQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (gtq *GroupTagQuery) GroupBy(field string, fields ...string) *GroupTagGroupBy { - gtq.fields = append([]string{field}, fields...) + gtq.ctx.Fields = append([]string{field}, fields...) grbuild := &GroupTagGroupBy{build: gtq} - grbuild.flds = >q.fields + grbuild.flds = >q.ctx.Fields grbuild.label = grouptag.Label grbuild.scan = grbuild.Scan return grbuild @@ -370,10 +365,10 @@ func (gtq *GroupTagQuery) GroupBy(field string, fields ...string) *GroupTagGroup // Select(grouptag.FieldTagID). // Scan(ctx, &v) func (gtq *GroupTagQuery) Select(fields ...string) *GroupTagSelect { - gtq.fields = append(gtq.fields, fields...) + gtq.ctx.Fields = append(gtq.ctx.Fields, fields...) sbuild := &GroupTagSelect{GroupTagQuery: gtq} sbuild.label = grouptag.Label - sbuild.flds, sbuild.scan = >q.fields, sbuild.Scan + sbuild.flds, sbuild.scan = >q.ctx.Fields, sbuild.Scan return sbuild } @@ -393,7 +388,7 @@ func (gtq *GroupTagQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range gtq.fields { + for _, f := range gtq.ctx.Fields { if !grouptag.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -511,9 +506,9 @@ func (gtq *GroupTagQuery) loadGroup(ctx context.Context, query *GroupQuery, node func (gtq *GroupTagQuery) sqlCount(ctx context.Context) (int, error) { _spec := gtq.querySpec() - _spec.Node.Columns = gtq.fields - if len(gtq.fields) > 0 { - _spec.Unique = gtq.unique != nil && *gtq.unique + _spec.Node.Columns = gtq.ctx.Fields + if len(gtq.ctx.Fields) > 0 { + _spec.Unique = gtq.ctx.Unique != nil && *gtq.ctx.Unique } return sqlgraph.CountNodes(ctx, gtq.driver, _spec) } @@ -531,10 +526,10 @@ func (gtq *GroupTagQuery) querySpec() *sqlgraph.QuerySpec { From: gtq.sql, Unique: true, } - if unique := gtq.unique; unique != nil { + if unique := gtq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := gtq.fields; len(fields) > 0 { + if fields := gtq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, grouptag.FieldID) for i := range fields { @@ -550,10 +545,10 @@ func (gtq *GroupTagQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := gtq.limit; limit != nil { + if limit := gtq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := gtq.offset; offset != nil { + if offset := gtq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := gtq.order; len(ps) > 0 { @@ -569,7 +564,7 @@ func (gtq *GroupTagQuery) querySpec() *sqlgraph.QuerySpec { func (gtq *GroupTagQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(gtq.driver.Dialect()) t1 := builder.Table(grouptag.Table) - columns := gtq.fields + columns := gtq.ctx.Fields if len(columns) == 0 { columns = grouptag.Columns } @@ -578,7 +573,7 @@ func (gtq *GroupTagQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = gtq.sql selector.Select(selector.Columns(columns...)...) } - if gtq.unique != nil && *gtq.unique { + if gtq.ctx.Unique != nil && *gtq.ctx.Unique { selector.Distinct() } for _, p := range gtq.predicates { @@ -587,12 +582,12 @@ func (gtq *GroupTagQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range gtq.order { p(selector) } - if offset := gtq.offset; offset != nil { + if offset := gtq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := gtq.limit; limit != nil { + if limit := gtq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -612,7 +607,7 @@ func (gtgb *GroupTagGroupBy) Aggregate(fns ...AggregateFunc) *GroupTagGroupBy { // Scan applies the selector query and scans the result into the given value. func (gtgb *GroupTagGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroupTag, "GroupBy") + ctx = setContextOp(ctx, gtgb.build.ctx, "GroupBy") if err := gtgb.build.prepareQuery(ctx); err != nil { return err } @@ -660,7 +655,7 @@ func (gts *GroupTagSelect) Aggregate(fns ...AggregateFunc) *GroupTagSelect { // Scan applies the selector query and scans the result into the given value. func (gts *GroupTagSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroupTag, "Select") + ctx = setContextOp(ctx, gts.ctx, "Select") if err := gts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/relationship_query.go b/entc/integration/edgeschema/ent/relationship_query.go index cacb1248d..13c47fbd7 100644 --- a/entc/integration/edgeschema/ent/relationship_query.go +++ b/entc/integration/edgeschema/ent/relationship_query.go @@ -23,11 +23,8 @@ import ( // RelationshipQuery is the builder for querying Relationship entities. type RelationshipQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Relationship withUser *UserQuery @@ -46,20 +43,20 @@ func (rq *RelationshipQuery) Where(ps ...predicate.Relationship) *RelationshipQu // Limit the number of records to be returned by this query. func (rq *RelationshipQuery) Limit(limit int) *RelationshipQuery { - rq.limit = &limit + rq.ctx.Limit = &limit return rq } // Offset to start from. func (rq *RelationshipQuery) Offset(offset int) *RelationshipQuery { - rq.offset = &offset + rq.ctx.Offset = &offset return rq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (rq *RelationshipQuery) Unique(unique bool) *RelationshipQuery { - rq.unique = &unique + rq.ctx.Unique = &unique return rq } @@ -138,7 +135,7 @@ func (rq *RelationshipQuery) QueryInfo() *RelationshipInfoQuery { // First returns the first Relationship entity from the query. // Returns a *NotFoundError when no Relationship was found. func (rq *RelationshipQuery) First(ctx context.Context) (*Relationship, error) { - nodes, err := rq.Limit(1).All(newQueryContext(ctx, TypeRelationship, "First")) + nodes, err := rq.Limit(1).All(setContextOp(ctx, rq.ctx, "First")) if err != nil { return nil, err } @@ -161,7 +158,7 @@ func (rq *RelationshipQuery) FirstX(ctx context.Context) *Relationship { // Returns a *NotSingularError when more than one Relationship entity is found. // Returns a *NotFoundError when no Relationship entities are found. func (rq *RelationshipQuery) Only(ctx context.Context) (*Relationship, error) { - nodes, err := rq.Limit(2).All(newQueryContext(ctx, TypeRelationship, "Only")) + nodes, err := rq.Limit(2).All(setContextOp(ctx, rq.ctx, "Only")) if err != nil { return nil, err } @@ -186,7 +183,7 @@ func (rq *RelationshipQuery) OnlyX(ctx context.Context) *Relationship { // All executes the query and returns a list of Relationships. func (rq *RelationshipQuery) All(ctx context.Context) ([]*Relationship, error) { - ctx = newQueryContext(ctx, TypeRelationship, "All") + ctx = setContextOp(ctx, rq.ctx, "All") if err := rq.prepareQuery(ctx); err != nil { return nil, err } @@ -205,7 +202,7 @@ func (rq *RelationshipQuery) AllX(ctx context.Context) []*Relationship { // Count returns the count of the given query. func (rq *RelationshipQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeRelationship, "Count") + ctx = setContextOp(ctx, rq.ctx, "Count") if err := rq.prepareQuery(ctx); err != nil { return 0, err } @@ -223,7 +220,7 @@ func (rq *RelationshipQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (rq *RelationshipQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeRelationship, "Exist") + ctx = setContextOp(ctx, rq.ctx, "Exist") switch _, err := rq.First(ctx); { case IsNotFound(err): return false, nil @@ -251,8 +248,7 @@ func (rq *RelationshipQuery) Clone() *RelationshipQuery { } return &RelationshipQuery{ config: rq.config, - limit: rq.limit, - offset: rq.offset, + ctx: rq.ctx.Clone(), order: append([]OrderFunc{}, rq.order...), inters: append([]Interceptor{}, rq.inters...), predicates: append([]predicate.Relationship{}, rq.predicates...), @@ -260,9 +256,8 @@ func (rq *RelationshipQuery) Clone() *RelationshipQuery { withRelative: rq.withRelative.Clone(), withInfo: rq.withInfo.Clone(), // clone intermediate query. - sql: rq.sql.Clone(), - path: rq.path, - unique: rq.unique, + sql: rq.sql.Clone(), + path: rq.path, } } @@ -314,9 +309,9 @@ func (rq *RelationshipQuery) WithInfo(opts ...func(*RelationshipInfoQuery)) *Rel // Aggregate(ent.Count()). // Scan(ctx, &v) func (rq *RelationshipQuery) GroupBy(field string, fields ...string) *RelationshipGroupBy { - rq.fields = append([]string{field}, fields...) + rq.ctx.Fields = append([]string{field}, fields...) grbuild := &RelationshipGroupBy{build: rq} - grbuild.flds = &rq.fields + grbuild.flds = &rq.ctx.Fields grbuild.label = relationship.Label grbuild.scan = grbuild.Scan return grbuild @@ -335,10 +330,10 @@ func (rq *RelationshipQuery) GroupBy(field string, fields ...string) *Relationsh // Select(relationship.FieldWeight). // Scan(ctx, &v) func (rq *RelationshipQuery) Select(fields ...string) *RelationshipSelect { - rq.fields = append(rq.fields, fields...) + rq.ctx.Fields = append(rq.ctx.Fields, fields...) sbuild := &RelationshipSelect{RelationshipQuery: rq} sbuild.label = relationship.Label - sbuild.flds, sbuild.scan = &rq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &rq.ctx.Fields, sbuild.Scan return sbuild } @@ -358,7 +353,7 @@ func (rq *RelationshipQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range rq.fields { + for _, f := range rq.ctx.Fields { if !relationship.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -532,10 +527,10 @@ func (rq *RelationshipQuery) querySpec() *sqlgraph.QuerySpec { From: rq.sql, Unique: true, } - if unique := rq.unique; unique != nil { + if unique := rq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := rq.fields; len(fields) > 0 { + if fields := rq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) for i := range fields { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) @@ -548,10 +543,10 @@ func (rq *RelationshipQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := rq.limit; limit != nil { + if limit := rq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := rq.offset; offset != nil { + if offset := rq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := rq.order; len(ps) > 0 { @@ -567,7 +562,7 @@ func (rq *RelationshipQuery) querySpec() *sqlgraph.QuerySpec { func (rq *RelationshipQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(rq.driver.Dialect()) t1 := builder.Table(relationship.Table) - columns := rq.fields + columns := rq.ctx.Fields if len(columns) == 0 { columns = relationship.Columns } @@ -576,7 +571,7 @@ func (rq *RelationshipQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = rq.sql selector.Select(selector.Columns(columns...)...) } - if rq.unique != nil && *rq.unique { + if rq.ctx.Unique != nil && *rq.ctx.Unique { selector.Distinct() } for _, p := range rq.predicates { @@ -585,12 +580,12 @@ func (rq *RelationshipQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range rq.order { p(selector) } - if offset := rq.offset; offset != nil { + if offset := rq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := rq.limit; limit != nil { + if limit := rq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -610,7 +605,7 @@ func (rgb *RelationshipGroupBy) Aggregate(fns ...AggregateFunc) *RelationshipGro // Scan applies the selector query and scans the result into the given value. func (rgb *RelationshipGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeRelationship, "GroupBy") + ctx = setContextOp(ctx, rgb.build.ctx, "GroupBy") if err := rgb.build.prepareQuery(ctx); err != nil { return err } @@ -658,7 +653,7 @@ func (rs *RelationshipSelect) Aggregate(fns ...AggregateFunc) *RelationshipSelec // Scan applies the selector query and scans the result into the given value. func (rs *RelationshipSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeRelationship, "Select") + ctx = setContextOp(ctx, rs.ctx, "Select") if err := rs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/relationshipinfo_query.go b/entc/integration/edgeschema/ent/relationshipinfo_query.go index c206e83f9..152233ecf 100644 --- a/entc/integration/edgeschema/ent/relationshipinfo_query.go +++ b/entc/integration/edgeschema/ent/relationshipinfo_query.go @@ -21,11 +21,8 @@ import ( // RelationshipInfoQuery is the builder for querying RelationshipInfo entities. type RelationshipInfoQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.RelationshipInfo // intermediate query (i.e. traversal path). @@ -41,20 +38,20 @@ func (riq *RelationshipInfoQuery) Where(ps ...predicate.RelationshipInfo) *Relat // Limit the number of records to be returned by this query. func (riq *RelationshipInfoQuery) Limit(limit int) *RelationshipInfoQuery { - riq.limit = &limit + riq.ctx.Limit = &limit return riq } // Offset to start from. func (riq *RelationshipInfoQuery) Offset(offset int) *RelationshipInfoQuery { - riq.offset = &offset + riq.ctx.Offset = &offset return riq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (riq *RelationshipInfoQuery) Unique(unique bool) *RelationshipInfoQuery { - riq.unique = &unique + riq.ctx.Unique = &unique return riq } @@ -67,7 +64,7 @@ func (riq *RelationshipInfoQuery) Order(o ...OrderFunc) *RelationshipInfoQuery { // First returns the first RelationshipInfo entity from the query. // Returns a *NotFoundError when no RelationshipInfo was found. func (riq *RelationshipInfoQuery) First(ctx context.Context) (*RelationshipInfo, error) { - nodes, err := riq.Limit(1).All(newQueryContext(ctx, TypeRelationshipInfo, "First")) + nodes, err := riq.Limit(1).All(setContextOp(ctx, riq.ctx, "First")) if err != nil { return nil, err } @@ -90,7 +87,7 @@ func (riq *RelationshipInfoQuery) FirstX(ctx context.Context) *RelationshipInfo // Returns a *NotFoundError when no RelationshipInfo ID was found. func (riq *RelationshipInfoQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = riq.Limit(1).IDs(newQueryContext(ctx, TypeRelationshipInfo, "FirstID")); err != nil { + if ids, err = riq.Limit(1).IDs(setContextOp(ctx, riq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -113,7 +110,7 @@ func (riq *RelationshipInfoQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one RelationshipInfo entity is found. // Returns a *NotFoundError when no RelationshipInfo entities are found. func (riq *RelationshipInfoQuery) Only(ctx context.Context) (*RelationshipInfo, error) { - nodes, err := riq.Limit(2).All(newQueryContext(ctx, TypeRelationshipInfo, "Only")) + nodes, err := riq.Limit(2).All(setContextOp(ctx, riq.ctx, "Only")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (riq *RelationshipInfoQuery) OnlyX(ctx context.Context) *RelationshipInfo { // Returns a *NotFoundError when no entities are found. func (riq *RelationshipInfoQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = riq.Limit(2).IDs(newQueryContext(ctx, TypeRelationshipInfo, "OnlyID")); err != nil { + if ids, err = riq.Limit(2).IDs(setContextOp(ctx, riq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -166,7 +163,7 @@ func (riq *RelationshipInfoQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of RelationshipInfos. func (riq *RelationshipInfoQuery) All(ctx context.Context) ([]*RelationshipInfo, error) { - ctx = newQueryContext(ctx, TypeRelationshipInfo, "All") + ctx = setContextOp(ctx, riq.ctx, "All") if err := riq.prepareQuery(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (riq *RelationshipInfoQuery) AllX(ctx context.Context) []*RelationshipInfo // IDs executes the query and returns a list of RelationshipInfo IDs. func (riq *RelationshipInfoQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeRelationshipInfo, "IDs") + ctx = setContextOp(ctx, riq.ctx, "IDs") if err := riq.Select(relationshipinfo.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -204,7 +201,7 @@ func (riq *RelationshipInfoQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (riq *RelationshipInfoQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeRelationshipInfo, "Count") + ctx = setContextOp(ctx, riq.ctx, "Count") if err := riq.prepareQuery(ctx); err != nil { return 0, err } @@ -222,7 +219,7 @@ func (riq *RelationshipInfoQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (riq *RelationshipInfoQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeRelationshipInfo, "Exist") + ctx = setContextOp(ctx, riq.ctx, "Exist") switch _, err := riq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -250,15 +247,13 @@ func (riq *RelationshipInfoQuery) Clone() *RelationshipInfoQuery { } return &RelationshipInfoQuery{ config: riq.config, - limit: riq.limit, - offset: riq.offset, + ctx: riq.ctx.Clone(), order: append([]OrderFunc{}, riq.order...), inters: append([]Interceptor{}, riq.inters...), predicates: append([]predicate.RelationshipInfo{}, riq.predicates...), // clone intermediate query. - sql: riq.sql.Clone(), - path: riq.path, - unique: riq.unique, + sql: riq.sql.Clone(), + path: riq.path, } } @@ -277,9 +272,9 @@ func (riq *RelationshipInfoQuery) Clone() *RelationshipInfoQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (riq *RelationshipInfoQuery) GroupBy(field string, fields ...string) *RelationshipInfoGroupBy { - riq.fields = append([]string{field}, fields...) + riq.ctx.Fields = append([]string{field}, fields...) grbuild := &RelationshipInfoGroupBy{build: riq} - grbuild.flds = &riq.fields + grbuild.flds = &riq.ctx.Fields grbuild.label = relationshipinfo.Label grbuild.scan = grbuild.Scan return grbuild @@ -298,10 +293,10 @@ func (riq *RelationshipInfoQuery) GroupBy(field string, fields ...string) *Relat // Select(relationshipinfo.FieldText). // Scan(ctx, &v) func (riq *RelationshipInfoQuery) Select(fields ...string) *RelationshipInfoSelect { - riq.fields = append(riq.fields, fields...) + riq.ctx.Fields = append(riq.ctx.Fields, fields...) sbuild := &RelationshipInfoSelect{RelationshipInfoQuery: riq} sbuild.label = relationshipinfo.Label - sbuild.flds, sbuild.scan = &riq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &riq.ctx.Fields, sbuild.Scan return sbuild } @@ -321,7 +316,7 @@ func (riq *RelationshipInfoQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range riq.fields { + for _, f := range riq.ctx.Fields { if !relationshipinfo.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -363,9 +358,9 @@ func (riq *RelationshipInfoQuery) sqlAll(ctx context.Context, hooks ...queryHook func (riq *RelationshipInfoQuery) sqlCount(ctx context.Context) (int, error) { _spec := riq.querySpec() - _spec.Node.Columns = riq.fields - if len(riq.fields) > 0 { - _spec.Unique = riq.unique != nil && *riq.unique + _spec.Node.Columns = riq.ctx.Fields + if len(riq.ctx.Fields) > 0 { + _spec.Unique = riq.ctx.Unique != nil && *riq.ctx.Unique } return sqlgraph.CountNodes(ctx, riq.driver, _spec) } @@ -383,10 +378,10 @@ func (riq *RelationshipInfoQuery) querySpec() *sqlgraph.QuerySpec { From: riq.sql, Unique: true, } - if unique := riq.unique; unique != nil { + if unique := riq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := riq.fields; len(fields) > 0 { + if fields := riq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, relationshipinfo.FieldID) for i := range fields { @@ -402,10 +397,10 @@ func (riq *RelationshipInfoQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := riq.limit; limit != nil { + if limit := riq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := riq.offset; offset != nil { + if offset := riq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := riq.order; len(ps) > 0 { @@ -421,7 +416,7 @@ func (riq *RelationshipInfoQuery) querySpec() *sqlgraph.QuerySpec { func (riq *RelationshipInfoQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(riq.driver.Dialect()) t1 := builder.Table(relationshipinfo.Table) - columns := riq.fields + columns := riq.ctx.Fields if len(columns) == 0 { columns = relationshipinfo.Columns } @@ -430,7 +425,7 @@ func (riq *RelationshipInfoQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = riq.sql selector.Select(selector.Columns(columns...)...) } - if riq.unique != nil && *riq.unique { + if riq.ctx.Unique != nil && *riq.ctx.Unique { selector.Distinct() } for _, p := range riq.predicates { @@ -439,12 +434,12 @@ func (riq *RelationshipInfoQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range riq.order { p(selector) } - if offset := riq.offset; offset != nil { + if offset := riq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := riq.limit; limit != nil { + if limit := riq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -464,7 +459,7 @@ func (rigb *RelationshipInfoGroupBy) Aggregate(fns ...AggregateFunc) *Relationsh // Scan applies the selector query and scans the result into the given value. func (rigb *RelationshipInfoGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeRelationshipInfo, "GroupBy") + ctx = setContextOp(ctx, rigb.build.ctx, "GroupBy") if err := rigb.build.prepareQuery(ctx); err != nil { return err } @@ -512,7 +507,7 @@ func (ris *RelationshipInfoSelect) Aggregate(fns ...AggregateFunc) *Relationship // Scan applies the selector query and scans the result into the given value. func (ris *RelationshipInfoSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeRelationshipInfo, "Select") + ctx = setContextOp(ctx, ris.ctx, "Select") if err := ris.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/role_query.go b/entc/integration/edgeschema/ent/role_query.go index 251d9c428..f28ab7492 100644 --- a/entc/integration/edgeschema/ent/role_query.go +++ b/entc/integration/edgeschema/ent/role_query.go @@ -24,11 +24,8 @@ import ( // RoleQuery is the builder for querying Role entities. type RoleQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Role withUser *UserQuery @@ -46,20 +43,20 @@ func (rq *RoleQuery) Where(ps ...predicate.Role) *RoleQuery { // Limit the number of records to be returned by this query. func (rq *RoleQuery) Limit(limit int) *RoleQuery { - rq.limit = &limit + rq.ctx.Limit = &limit return rq } // Offset to start from. func (rq *RoleQuery) Offset(offset int) *RoleQuery { - rq.offset = &offset + rq.ctx.Offset = &offset return rq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (rq *RoleQuery) Unique(unique bool) *RoleQuery { - rq.unique = &unique + rq.ctx.Unique = &unique return rq } @@ -116,7 +113,7 @@ func (rq *RoleQuery) QueryRolesUsers() *RoleUserQuery { // First returns the first Role entity from the query. // Returns a *NotFoundError when no Role was found. func (rq *RoleQuery) First(ctx context.Context) (*Role, error) { - nodes, err := rq.Limit(1).All(newQueryContext(ctx, TypeRole, "First")) + nodes, err := rq.Limit(1).All(setContextOp(ctx, rq.ctx, "First")) if err != nil { return nil, err } @@ -139,7 +136,7 @@ func (rq *RoleQuery) FirstX(ctx context.Context) *Role { // Returns a *NotFoundError when no Role ID was found. func (rq *RoleQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = rq.Limit(1).IDs(newQueryContext(ctx, TypeRole, "FirstID")); err != nil { + if ids, err = rq.Limit(1).IDs(setContextOp(ctx, rq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -162,7 +159,7 @@ func (rq *RoleQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Role entity is found. // Returns a *NotFoundError when no Role entities are found. func (rq *RoleQuery) Only(ctx context.Context) (*Role, error) { - nodes, err := rq.Limit(2).All(newQueryContext(ctx, TypeRole, "Only")) + nodes, err := rq.Limit(2).All(setContextOp(ctx, rq.ctx, "Only")) if err != nil { return nil, err } @@ -190,7 +187,7 @@ func (rq *RoleQuery) OnlyX(ctx context.Context) *Role { // Returns a *NotFoundError when no entities are found. func (rq *RoleQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = rq.Limit(2).IDs(newQueryContext(ctx, TypeRole, "OnlyID")); err != nil { + if ids, err = rq.Limit(2).IDs(setContextOp(ctx, rq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -215,7 +212,7 @@ func (rq *RoleQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Roles. func (rq *RoleQuery) All(ctx context.Context) ([]*Role, error) { - ctx = newQueryContext(ctx, TypeRole, "All") + ctx = setContextOp(ctx, rq.ctx, "All") if err := rq.prepareQuery(ctx); err != nil { return nil, err } @@ -235,7 +232,7 @@ func (rq *RoleQuery) AllX(ctx context.Context) []*Role { // IDs executes the query and returns a list of Role IDs. func (rq *RoleQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeRole, "IDs") + ctx = setContextOp(ctx, rq.ctx, "IDs") if err := rq.Select(role.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -253,7 +250,7 @@ func (rq *RoleQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (rq *RoleQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeRole, "Count") + ctx = setContextOp(ctx, rq.ctx, "Count") if err := rq.prepareQuery(ctx); err != nil { return 0, err } @@ -271,7 +268,7 @@ func (rq *RoleQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (rq *RoleQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeRole, "Exist") + ctx = setContextOp(ctx, rq.ctx, "Exist") switch _, err := rq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -299,17 +296,15 @@ func (rq *RoleQuery) Clone() *RoleQuery { } return &RoleQuery{ config: rq.config, - limit: rq.limit, - offset: rq.offset, + ctx: rq.ctx.Clone(), order: append([]OrderFunc{}, rq.order...), inters: append([]Interceptor{}, rq.inters...), predicates: append([]predicate.Role{}, rq.predicates...), withUser: rq.withUser.Clone(), withRolesUsers: rq.withRolesUsers.Clone(), // clone intermediate query. - sql: rq.sql.Clone(), - path: rq.path, - unique: rq.unique, + sql: rq.sql.Clone(), + path: rq.path, } } @@ -350,9 +345,9 @@ func (rq *RoleQuery) WithRolesUsers(opts ...func(*RoleUserQuery)) *RoleQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (rq *RoleQuery) GroupBy(field string, fields ...string) *RoleGroupBy { - rq.fields = append([]string{field}, fields...) + rq.ctx.Fields = append([]string{field}, fields...) grbuild := &RoleGroupBy{build: rq} - grbuild.flds = &rq.fields + grbuild.flds = &rq.ctx.Fields grbuild.label = role.Label grbuild.scan = grbuild.Scan return grbuild @@ -371,10 +366,10 @@ func (rq *RoleQuery) GroupBy(field string, fields ...string) *RoleGroupBy { // Select(role.FieldName). // Scan(ctx, &v) func (rq *RoleQuery) Select(fields ...string) *RoleSelect { - rq.fields = append(rq.fields, fields...) + rq.ctx.Fields = append(rq.ctx.Fields, fields...) sbuild := &RoleSelect{RoleQuery: rq} sbuild.label = role.Label - sbuild.flds, sbuild.scan = &rq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &rq.ctx.Fields, sbuild.Scan return sbuild } @@ -394,7 +389,7 @@ func (rq *RoleQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range rq.fields { + for _, f := range rq.ctx.Fields { if !role.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -541,9 +536,9 @@ func (rq *RoleQuery) loadRolesUsers(ctx context.Context, query *RoleUserQuery, n func (rq *RoleQuery) sqlCount(ctx context.Context) (int, error) { _spec := rq.querySpec() - _spec.Node.Columns = rq.fields - if len(rq.fields) > 0 { - _spec.Unique = rq.unique != nil && *rq.unique + _spec.Node.Columns = rq.ctx.Fields + if len(rq.ctx.Fields) > 0 { + _spec.Unique = rq.ctx.Unique != nil && *rq.ctx.Unique } return sqlgraph.CountNodes(ctx, rq.driver, _spec) } @@ -561,10 +556,10 @@ func (rq *RoleQuery) querySpec() *sqlgraph.QuerySpec { From: rq.sql, Unique: true, } - if unique := rq.unique; unique != nil { + if unique := rq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := rq.fields; len(fields) > 0 { + if fields := rq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, role.FieldID) for i := range fields { @@ -580,10 +575,10 @@ func (rq *RoleQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := rq.limit; limit != nil { + if limit := rq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := rq.offset; offset != nil { + if offset := rq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := rq.order; len(ps) > 0 { @@ -599,7 +594,7 @@ func (rq *RoleQuery) querySpec() *sqlgraph.QuerySpec { func (rq *RoleQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(rq.driver.Dialect()) t1 := builder.Table(role.Table) - columns := rq.fields + columns := rq.ctx.Fields if len(columns) == 0 { columns = role.Columns } @@ -608,7 +603,7 @@ func (rq *RoleQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = rq.sql selector.Select(selector.Columns(columns...)...) } - if rq.unique != nil && *rq.unique { + if rq.ctx.Unique != nil && *rq.ctx.Unique { selector.Distinct() } for _, p := range rq.predicates { @@ -617,12 +612,12 @@ func (rq *RoleQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range rq.order { p(selector) } - if offset := rq.offset; offset != nil { + if offset := rq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := rq.limit; limit != nil { + if limit := rq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -642,7 +637,7 @@ func (rgb *RoleGroupBy) Aggregate(fns ...AggregateFunc) *RoleGroupBy { // Scan applies the selector query and scans the result into the given value. func (rgb *RoleGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeRole, "GroupBy") + ctx = setContextOp(ctx, rgb.build.ctx, "GroupBy") if err := rgb.build.prepareQuery(ctx); err != nil { return err } @@ -690,7 +685,7 @@ func (rs *RoleSelect) Aggregate(fns ...AggregateFunc) *RoleSelect { // Scan applies the selector query and scans the result into the given value. func (rs *RoleSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeRole, "Select") + ctx = setContextOp(ctx, rs.ctx, "Select") if err := rs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/roleuser_query.go b/entc/integration/edgeschema/ent/roleuser_query.go index 4612a7fa2..1b6546b0e 100644 --- a/entc/integration/edgeschema/ent/roleuser_query.go +++ b/entc/integration/edgeschema/ent/roleuser_query.go @@ -22,11 +22,8 @@ import ( // RoleUserQuery is the builder for querying RoleUser entities. type RoleUserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.RoleUser withRole *RoleQuery @@ -44,20 +41,20 @@ func (ruq *RoleUserQuery) Where(ps ...predicate.RoleUser) *RoleUserQuery { // Limit the number of records to be returned by this query. func (ruq *RoleUserQuery) Limit(limit int) *RoleUserQuery { - ruq.limit = &limit + ruq.ctx.Limit = &limit return ruq } // Offset to start from. func (ruq *RoleUserQuery) Offset(offset int) *RoleUserQuery { - ruq.offset = &offset + ruq.ctx.Offset = &offset return ruq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (ruq *RoleUserQuery) Unique(unique bool) *RoleUserQuery { - ruq.unique = &unique + ruq.ctx.Unique = &unique return ruq } @@ -114,7 +111,7 @@ func (ruq *RoleUserQuery) QueryUser() *UserQuery { // First returns the first RoleUser entity from the query. // Returns a *NotFoundError when no RoleUser was found. func (ruq *RoleUserQuery) First(ctx context.Context) (*RoleUser, error) { - nodes, err := ruq.Limit(1).All(newQueryContext(ctx, TypeRoleUser, "First")) + nodes, err := ruq.Limit(1).All(setContextOp(ctx, ruq.ctx, "First")) if err != nil { return nil, err } @@ -137,7 +134,7 @@ func (ruq *RoleUserQuery) FirstX(ctx context.Context) *RoleUser { // Returns a *NotSingularError when more than one RoleUser entity is found. // Returns a *NotFoundError when no RoleUser entities are found. func (ruq *RoleUserQuery) Only(ctx context.Context) (*RoleUser, error) { - nodes, err := ruq.Limit(2).All(newQueryContext(ctx, TypeRoleUser, "Only")) + nodes, err := ruq.Limit(2).All(setContextOp(ctx, ruq.ctx, "Only")) if err != nil { return nil, err } @@ -162,7 +159,7 @@ func (ruq *RoleUserQuery) OnlyX(ctx context.Context) *RoleUser { // All executes the query and returns a list of RoleUsers. func (ruq *RoleUserQuery) All(ctx context.Context) ([]*RoleUser, error) { - ctx = newQueryContext(ctx, TypeRoleUser, "All") + ctx = setContextOp(ctx, ruq.ctx, "All") if err := ruq.prepareQuery(ctx); err != nil { return nil, err } @@ -181,7 +178,7 @@ func (ruq *RoleUserQuery) AllX(ctx context.Context) []*RoleUser { // Count returns the count of the given query. func (ruq *RoleUserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeRoleUser, "Count") + ctx = setContextOp(ctx, ruq.ctx, "Count") if err := ruq.prepareQuery(ctx); err != nil { return 0, err } @@ -199,7 +196,7 @@ func (ruq *RoleUserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (ruq *RoleUserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeRoleUser, "Exist") + ctx = setContextOp(ctx, ruq.ctx, "Exist") switch _, err := ruq.First(ctx); { case IsNotFound(err): return false, nil @@ -227,17 +224,15 @@ func (ruq *RoleUserQuery) Clone() *RoleUserQuery { } return &RoleUserQuery{ config: ruq.config, - limit: ruq.limit, - offset: ruq.offset, + ctx: ruq.ctx.Clone(), order: append([]OrderFunc{}, ruq.order...), inters: append([]Interceptor{}, ruq.inters...), predicates: append([]predicate.RoleUser{}, ruq.predicates...), withRole: ruq.withRole.Clone(), withUser: ruq.withUser.Clone(), // clone intermediate query. - sql: ruq.sql.Clone(), - path: ruq.path, - unique: ruq.unique, + sql: ruq.sql.Clone(), + path: ruq.path, } } @@ -278,9 +273,9 @@ func (ruq *RoleUserQuery) WithUser(opts ...func(*UserQuery)) *RoleUserQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (ruq *RoleUserQuery) GroupBy(field string, fields ...string) *RoleUserGroupBy { - ruq.fields = append([]string{field}, fields...) + ruq.ctx.Fields = append([]string{field}, fields...) grbuild := &RoleUserGroupBy{build: ruq} - grbuild.flds = &ruq.fields + grbuild.flds = &ruq.ctx.Fields grbuild.label = roleuser.Label grbuild.scan = grbuild.Scan return grbuild @@ -299,10 +294,10 @@ func (ruq *RoleUserQuery) GroupBy(field string, fields ...string) *RoleUserGroup // Select(roleuser.FieldCreatedAt). // Scan(ctx, &v) func (ruq *RoleUserQuery) Select(fields ...string) *RoleUserSelect { - ruq.fields = append(ruq.fields, fields...) + ruq.ctx.Fields = append(ruq.ctx.Fields, fields...) sbuild := &RoleUserSelect{RoleUserQuery: ruq} sbuild.label = roleuser.Label - sbuild.flds, sbuild.scan = &ruq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &ruq.ctx.Fields, sbuild.Scan return sbuild } @@ -322,7 +317,7 @@ func (ruq *RoleUserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range ruq.fields { + for _, f := range ruq.ctx.Fields { if !roleuser.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -454,10 +449,10 @@ func (ruq *RoleUserQuery) querySpec() *sqlgraph.QuerySpec { From: ruq.sql, Unique: true, } - if unique := ruq.unique; unique != nil { + if unique := ruq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := ruq.fields; len(fields) > 0 { + if fields := ruq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) for i := range fields { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) @@ -470,10 +465,10 @@ func (ruq *RoleUserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := ruq.limit; limit != nil { + if limit := ruq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := ruq.offset; offset != nil { + if offset := ruq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := ruq.order; len(ps) > 0 { @@ -489,7 +484,7 @@ func (ruq *RoleUserQuery) querySpec() *sqlgraph.QuerySpec { func (ruq *RoleUserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(ruq.driver.Dialect()) t1 := builder.Table(roleuser.Table) - columns := ruq.fields + columns := ruq.ctx.Fields if len(columns) == 0 { columns = roleuser.Columns } @@ -498,7 +493,7 @@ func (ruq *RoleUserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = ruq.sql selector.Select(selector.Columns(columns...)...) } - if ruq.unique != nil && *ruq.unique { + if ruq.ctx.Unique != nil && *ruq.ctx.Unique { selector.Distinct() } for _, p := range ruq.predicates { @@ -507,12 +502,12 @@ func (ruq *RoleUserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range ruq.order { p(selector) } - if offset := ruq.offset; offset != nil { + if offset := ruq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := ruq.limit; limit != nil { + if limit := ruq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -532,7 +527,7 @@ func (rugb *RoleUserGroupBy) Aggregate(fns ...AggregateFunc) *RoleUserGroupBy { // Scan applies the selector query and scans the result into the given value. func (rugb *RoleUserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeRoleUser, "GroupBy") + ctx = setContextOp(ctx, rugb.build.ctx, "GroupBy") if err := rugb.build.prepareQuery(ctx); err != nil { return err } @@ -580,7 +575,7 @@ func (rus *RoleUserSelect) Aggregate(fns ...AggregateFunc) *RoleUserSelect { // Scan applies the selector query and scans the result into the given value. func (rus *RoleUserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeRoleUser, "Select") + ctx = setContextOp(ctx, rus.ctx, "Select") if err := rus.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/tag_query.go b/entc/integration/edgeschema/ent/tag_query.go index 554e1a7d0..1acd0dcaf 100644 --- a/entc/integration/edgeschema/ent/tag_query.go +++ b/entc/integration/edgeschema/ent/tag_query.go @@ -26,11 +26,8 @@ import ( // TagQuery is the builder for querying Tag entities. type TagQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Tag withTweets *TweetQuery @@ -50,20 +47,20 @@ func (tq *TagQuery) Where(ps ...predicate.Tag) *TagQuery { // Limit the number of records to be returned by this query. func (tq *TagQuery) Limit(limit int) *TagQuery { - tq.limit = &limit + tq.ctx.Limit = &limit return tq } // Offset to start from. func (tq *TagQuery) Offset(offset int) *TagQuery { - tq.offset = &offset + tq.ctx.Offset = &offset return tq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (tq *TagQuery) Unique(unique bool) *TagQuery { - tq.unique = &unique + tq.ctx.Unique = &unique return tq } @@ -164,7 +161,7 @@ func (tq *TagQuery) QueryGroupTags() *GroupTagQuery { // First returns the first Tag entity from the query. // Returns a *NotFoundError when no Tag was found. func (tq *TagQuery) First(ctx context.Context) (*Tag, error) { - nodes, err := tq.Limit(1).All(newQueryContext(ctx, TypeTag, "First")) + nodes, err := tq.Limit(1).All(setContextOp(ctx, tq.ctx, "First")) if err != nil { return nil, err } @@ -187,7 +184,7 @@ func (tq *TagQuery) FirstX(ctx context.Context) *Tag { // Returns a *NotFoundError when no Tag ID was found. func (tq *TagQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = tq.Limit(1).IDs(newQueryContext(ctx, TypeTag, "FirstID")); err != nil { + if ids, err = tq.Limit(1).IDs(setContextOp(ctx, tq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -210,7 +207,7 @@ func (tq *TagQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Tag entity is found. // Returns a *NotFoundError when no Tag entities are found. func (tq *TagQuery) Only(ctx context.Context) (*Tag, error) { - nodes, err := tq.Limit(2).All(newQueryContext(ctx, TypeTag, "Only")) + nodes, err := tq.Limit(2).All(setContextOp(ctx, tq.ctx, "Only")) if err != nil { return nil, err } @@ -238,7 +235,7 @@ func (tq *TagQuery) OnlyX(ctx context.Context) *Tag { // Returns a *NotFoundError when no entities are found. func (tq *TagQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = tq.Limit(2).IDs(newQueryContext(ctx, TypeTag, "OnlyID")); err != nil { + if ids, err = tq.Limit(2).IDs(setContextOp(ctx, tq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -263,7 +260,7 @@ func (tq *TagQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Tags. func (tq *TagQuery) All(ctx context.Context) ([]*Tag, error) { - ctx = newQueryContext(ctx, TypeTag, "All") + ctx = setContextOp(ctx, tq.ctx, "All") if err := tq.prepareQuery(ctx); err != nil { return nil, err } @@ -283,7 +280,7 @@ func (tq *TagQuery) AllX(ctx context.Context) []*Tag { // IDs executes the query and returns a list of Tag IDs. func (tq *TagQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeTag, "IDs") + ctx = setContextOp(ctx, tq.ctx, "IDs") if err := tq.Select(tag.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -301,7 +298,7 @@ func (tq *TagQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (tq *TagQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeTag, "Count") + ctx = setContextOp(ctx, tq.ctx, "Count") if err := tq.prepareQuery(ctx); err != nil { return 0, err } @@ -319,7 +316,7 @@ func (tq *TagQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (tq *TagQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeTag, "Exist") + ctx = setContextOp(ctx, tq.ctx, "Exist") switch _, err := tq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -347,8 +344,7 @@ func (tq *TagQuery) Clone() *TagQuery { } return &TagQuery{ config: tq.config, - limit: tq.limit, - offset: tq.offset, + ctx: tq.ctx.Clone(), order: append([]OrderFunc{}, tq.order...), inters: append([]Interceptor{}, tq.inters...), predicates: append([]predicate.Tag{}, tq.predicates...), @@ -357,9 +353,8 @@ func (tq *TagQuery) Clone() *TagQuery { withTweetTags: tq.withTweetTags.Clone(), withGroupTags: tq.withGroupTags.Clone(), // clone intermediate query. - sql: tq.sql.Clone(), - path: tq.path, - unique: tq.unique, + sql: tq.sql.Clone(), + path: tq.path, } } @@ -422,9 +417,9 @@ func (tq *TagQuery) WithGroupTags(opts ...func(*GroupTagQuery)) *TagQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (tq *TagQuery) GroupBy(field string, fields ...string) *TagGroupBy { - tq.fields = append([]string{field}, fields...) + tq.ctx.Fields = append([]string{field}, fields...) grbuild := &TagGroupBy{build: tq} - grbuild.flds = &tq.fields + grbuild.flds = &tq.ctx.Fields grbuild.label = tag.Label grbuild.scan = grbuild.Scan return grbuild @@ -443,10 +438,10 @@ func (tq *TagQuery) GroupBy(field string, fields ...string) *TagGroupBy { // Select(tag.FieldValue). // Scan(ctx, &v) func (tq *TagQuery) Select(fields ...string) *TagSelect { - tq.fields = append(tq.fields, fields...) + tq.ctx.Fields = append(tq.ctx.Fields, fields...) sbuild := &TagSelect{TagQuery: tq} sbuild.label = tag.Label - sbuild.flds, sbuild.scan = &tq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &tq.ctx.Fields, sbuild.Scan return sbuild } @@ -466,7 +461,7 @@ func (tq *TagQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range tq.fields { + for _, f := range tq.ctx.Fields { if !tag.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -714,9 +709,9 @@ func (tq *TagQuery) loadGroupTags(ctx context.Context, query *GroupTagQuery, nod func (tq *TagQuery) sqlCount(ctx context.Context) (int, error) { _spec := tq.querySpec() - _spec.Node.Columns = tq.fields - if len(tq.fields) > 0 { - _spec.Unique = tq.unique != nil && *tq.unique + _spec.Node.Columns = tq.ctx.Fields + if len(tq.ctx.Fields) > 0 { + _spec.Unique = tq.ctx.Unique != nil && *tq.ctx.Unique } return sqlgraph.CountNodes(ctx, tq.driver, _spec) } @@ -734,10 +729,10 @@ func (tq *TagQuery) querySpec() *sqlgraph.QuerySpec { From: tq.sql, Unique: true, } - if unique := tq.unique; unique != nil { + if unique := tq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := tq.fields; len(fields) > 0 { + if fields := tq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, tag.FieldID) for i := range fields { @@ -753,10 +748,10 @@ func (tq *TagQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := tq.order; len(ps) > 0 { @@ -772,7 +767,7 @@ func (tq *TagQuery) querySpec() *sqlgraph.QuerySpec { func (tq *TagQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(tq.driver.Dialect()) t1 := builder.Table(tag.Table) - columns := tq.fields + columns := tq.ctx.Fields if len(columns) == 0 { columns = tag.Columns } @@ -781,7 +776,7 @@ func (tq *TagQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = tq.sql selector.Select(selector.Columns(columns...)...) } - if tq.unique != nil && *tq.unique { + if tq.ctx.Unique != nil && *tq.ctx.Unique { selector.Distinct() } for _, p := range tq.predicates { @@ -790,12 +785,12 @@ func (tq *TagQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range tq.order { p(selector) } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -815,7 +810,7 @@ func (tgb *TagGroupBy) Aggregate(fns ...AggregateFunc) *TagGroupBy { // Scan applies the selector query and scans the result into the given value. func (tgb *TagGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTag, "GroupBy") + ctx = setContextOp(ctx, tgb.build.ctx, "GroupBy") if err := tgb.build.prepareQuery(ctx); err != nil { return err } @@ -863,7 +858,7 @@ func (ts *TagSelect) Aggregate(fns ...AggregateFunc) *TagSelect { // Scan applies the selector query and scans the result into the given value. func (ts *TagSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTag, "Select") + ctx = setContextOp(ctx, ts.ctx, "Select") if err := ts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/tweet_query.go b/entc/integration/edgeschema/ent/tweet_query.go index 8ff37f3a8..4b9e82fba 100644 --- a/entc/integration/edgeschema/ent/tweet_query.go +++ b/entc/integration/edgeschema/ent/tweet_query.go @@ -27,11 +27,8 @@ import ( // TweetQuery is the builder for querying Tweet entities. type TweetQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Tweet withLikedUsers *UserQuery @@ -53,20 +50,20 @@ func (tq *TweetQuery) Where(ps ...predicate.Tweet) *TweetQuery { // Limit the number of records to be returned by this query. func (tq *TweetQuery) Limit(limit int) *TweetQuery { - tq.limit = &limit + tq.ctx.Limit = &limit return tq } // Offset to start from. func (tq *TweetQuery) Offset(offset int) *TweetQuery { - tq.offset = &offset + tq.ctx.Offset = &offset return tq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (tq *TweetQuery) Unique(unique bool) *TweetQuery { - tq.unique = &unique + tq.ctx.Unique = &unique return tq } @@ -211,7 +208,7 @@ func (tq *TweetQuery) QueryTweetTags() *TweetTagQuery { // First returns the first Tweet entity from the query. // Returns a *NotFoundError when no Tweet was found. func (tq *TweetQuery) First(ctx context.Context) (*Tweet, error) { - nodes, err := tq.Limit(1).All(newQueryContext(ctx, TypeTweet, "First")) + nodes, err := tq.Limit(1).All(setContextOp(ctx, tq.ctx, "First")) if err != nil { return nil, err } @@ -234,7 +231,7 @@ func (tq *TweetQuery) FirstX(ctx context.Context) *Tweet { // Returns a *NotFoundError when no Tweet ID was found. func (tq *TweetQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = tq.Limit(1).IDs(newQueryContext(ctx, TypeTweet, "FirstID")); err != nil { + if ids, err = tq.Limit(1).IDs(setContextOp(ctx, tq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -257,7 +254,7 @@ func (tq *TweetQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Tweet entity is found. // Returns a *NotFoundError when no Tweet entities are found. func (tq *TweetQuery) Only(ctx context.Context) (*Tweet, error) { - nodes, err := tq.Limit(2).All(newQueryContext(ctx, TypeTweet, "Only")) + nodes, err := tq.Limit(2).All(setContextOp(ctx, tq.ctx, "Only")) if err != nil { return nil, err } @@ -285,7 +282,7 @@ func (tq *TweetQuery) OnlyX(ctx context.Context) *Tweet { // Returns a *NotFoundError when no entities are found. func (tq *TweetQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = tq.Limit(2).IDs(newQueryContext(ctx, TypeTweet, "OnlyID")); err != nil { + if ids, err = tq.Limit(2).IDs(setContextOp(ctx, tq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -310,7 +307,7 @@ func (tq *TweetQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Tweets. func (tq *TweetQuery) All(ctx context.Context) ([]*Tweet, error) { - ctx = newQueryContext(ctx, TypeTweet, "All") + ctx = setContextOp(ctx, tq.ctx, "All") if err := tq.prepareQuery(ctx); err != nil { return nil, err } @@ -330,7 +327,7 @@ func (tq *TweetQuery) AllX(ctx context.Context) []*Tweet { // IDs executes the query and returns a list of Tweet IDs. func (tq *TweetQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeTweet, "IDs") + ctx = setContextOp(ctx, tq.ctx, "IDs") if err := tq.Select(tweet.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -348,7 +345,7 @@ func (tq *TweetQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (tq *TweetQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeTweet, "Count") + ctx = setContextOp(ctx, tq.ctx, "Count") if err := tq.prepareQuery(ctx); err != nil { return 0, err } @@ -366,7 +363,7 @@ func (tq *TweetQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (tq *TweetQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeTweet, "Exist") + ctx = setContextOp(ctx, tq.ctx, "Exist") switch _, err := tq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -394,8 +391,7 @@ func (tq *TweetQuery) Clone() *TweetQuery { } return &TweetQuery{ config: tq.config, - limit: tq.limit, - offset: tq.offset, + ctx: tq.ctx.Clone(), order: append([]OrderFunc{}, tq.order...), inters: append([]Interceptor{}, tq.inters...), predicates: append([]predicate.Tweet{}, tq.predicates...), @@ -406,9 +402,8 @@ func (tq *TweetQuery) Clone() *TweetQuery { withTweetUser: tq.withTweetUser.Clone(), withTweetTags: tq.withTweetTags.Clone(), // clone intermediate query. - sql: tq.sql.Clone(), - path: tq.path, - unique: tq.unique, + sql: tq.sql.Clone(), + path: tq.path, } } @@ -493,9 +488,9 @@ func (tq *TweetQuery) WithTweetTags(opts ...func(*TweetTagQuery)) *TweetQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (tq *TweetQuery) GroupBy(field string, fields ...string) *TweetGroupBy { - tq.fields = append([]string{field}, fields...) + tq.ctx.Fields = append([]string{field}, fields...) grbuild := &TweetGroupBy{build: tq} - grbuild.flds = &tq.fields + grbuild.flds = &tq.ctx.Fields grbuild.label = tweet.Label grbuild.scan = grbuild.Scan return grbuild @@ -514,10 +509,10 @@ func (tq *TweetQuery) GroupBy(field string, fields ...string) *TweetGroupBy { // Select(tweet.FieldText). // Scan(ctx, &v) func (tq *TweetQuery) Select(fields ...string) *TweetSelect { - tq.fields = append(tq.fields, fields...) + tq.ctx.Fields = append(tq.ctx.Fields, fields...) sbuild := &TweetSelect{TweetQuery: tq} sbuild.label = tweet.Label - sbuild.flds, sbuild.scan = &tq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &tq.ctx.Fields, sbuild.Scan return sbuild } @@ -537,7 +532,7 @@ func (tq *TweetQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range tq.fields { + for _, f := range tq.ctx.Fields { if !tweet.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -886,9 +881,9 @@ func (tq *TweetQuery) loadTweetTags(ctx context.Context, query *TweetTagQuery, n func (tq *TweetQuery) sqlCount(ctx context.Context) (int, error) { _spec := tq.querySpec() - _spec.Node.Columns = tq.fields - if len(tq.fields) > 0 { - _spec.Unique = tq.unique != nil && *tq.unique + _spec.Node.Columns = tq.ctx.Fields + if len(tq.ctx.Fields) > 0 { + _spec.Unique = tq.ctx.Unique != nil && *tq.ctx.Unique } return sqlgraph.CountNodes(ctx, tq.driver, _spec) } @@ -906,10 +901,10 @@ func (tq *TweetQuery) querySpec() *sqlgraph.QuerySpec { From: tq.sql, Unique: true, } - if unique := tq.unique; unique != nil { + if unique := tq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := tq.fields; len(fields) > 0 { + if fields := tq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, tweet.FieldID) for i := range fields { @@ -925,10 +920,10 @@ func (tq *TweetQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := tq.order; len(ps) > 0 { @@ -944,7 +939,7 @@ func (tq *TweetQuery) querySpec() *sqlgraph.QuerySpec { func (tq *TweetQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(tq.driver.Dialect()) t1 := builder.Table(tweet.Table) - columns := tq.fields + columns := tq.ctx.Fields if len(columns) == 0 { columns = tweet.Columns } @@ -953,7 +948,7 @@ func (tq *TweetQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = tq.sql selector.Select(selector.Columns(columns...)...) } - if tq.unique != nil && *tq.unique { + if tq.ctx.Unique != nil && *tq.ctx.Unique { selector.Distinct() } for _, p := range tq.predicates { @@ -962,12 +957,12 @@ func (tq *TweetQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range tq.order { p(selector) } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -987,7 +982,7 @@ func (tgb *TweetGroupBy) Aggregate(fns ...AggregateFunc) *TweetGroupBy { // Scan applies the selector query and scans the result into the given value. func (tgb *TweetGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTweet, "GroupBy") + ctx = setContextOp(ctx, tgb.build.ctx, "GroupBy") if err := tgb.build.prepareQuery(ctx); err != nil { return err } @@ -1035,7 +1030,7 @@ func (ts *TweetSelect) Aggregate(fns ...AggregateFunc) *TweetSelect { // Scan applies the selector query and scans the result into the given value. func (ts *TweetSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTweet, "Select") + ctx = setContextOp(ctx, ts.ctx, "Select") if err := ts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/tweetlike_query.go b/entc/integration/edgeschema/ent/tweetlike_query.go index 0f22d6df9..db326d2f3 100644 --- a/entc/integration/edgeschema/ent/tweetlike_query.go +++ b/entc/integration/edgeschema/ent/tweetlike_query.go @@ -23,11 +23,8 @@ import ( // TweetLikeQuery is the builder for querying TweetLike entities. type TweetLikeQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.TweetLike withTweet *TweetQuery @@ -45,20 +42,20 @@ func (tlq *TweetLikeQuery) Where(ps ...predicate.TweetLike) *TweetLikeQuery { // Limit the number of records to be returned by this query. func (tlq *TweetLikeQuery) Limit(limit int) *TweetLikeQuery { - tlq.limit = &limit + tlq.ctx.Limit = &limit return tlq } // Offset to start from. func (tlq *TweetLikeQuery) Offset(offset int) *TweetLikeQuery { - tlq.offset = &offset + tlq.ctx.Offset = &offset return tlq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (tlq *TweetLikeQuery) Unique(unique bool) *TweetLikeQuery { - tlq.unique = &unique + tlq.ctx.Unique = &unique return tlq } @@ -115,7 +112,7 @@ func (tlq *TweetLikeQuery) QueryUser() *UserQuery { // First returns the first TweetLike entity from the query. // Returns a *NotFoundError when no TweetLike was found. func (tlq *TweetLikeQuery) First(ctx context.Context) (*TweetLike, error) { - nodes, err := tlq.Limit(1).All(newQueryContext(ctx, TypeTweetLike, "First")) + nodes, err := tlq.Limit(1).All(setContextOp(ctx, tlq.ctx, "First")) if err != nil { return nil, err } @@ -138,7 +135,7 @@ func (tlq *TweetLikeQuery) FirstX(ctx context.Context) *TweetLike { // Returns a *NotSingularError when more than one TweetLike entity is found. // Returns a *NotFoundError when no TweetLike entities are found. func (tlq *TweetLikeQuery) Only(ctx context.Context) (*TweetLike, error) { - nodes, err := tlq.Limit(2).All(newQueryContext(ctx, TypeTweetLike, "Only")) + nodes, err := tlq.Limit(2).All(setContextOp(ctx, tlq.ctx, "Only")) if err != nil { return nil, err } @@ -163,7 +160,7 @@ func (tlq *TweetLikeQuery) OnlyX(ctx context.Context) *TweetLike { // All executes the query and returns a list of TweetLikes. func (tlq *TweetLikeQuery) All(ctx context.Context) ([]*TweetLike, error) { - ctx = newQueryContext(ctx, TypeTweetLike, "All") + ctx = setContextOp(ctx, tlq.ctx, "All") if err := tlq.prepareQuery(ctx); err != nil { return nil, err } @@ -182,7 +179,7 @@ func (tlq *TweetLikeQuery) AllX(ctx context.Context) []*TweetLike { // Count returns the count of the given query. func (tlq *TweetLikeQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeTweetLike, "Count") + ctx = setContextOp(ctx, tlq.ctx, "Count") if err := tlq.prepareQuery(ctx); err != nil { return 0, err } @@ -200,7 +197,7 @@ func (tlq *TweetLikeQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (tlq *TweetLikeQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeTweetLike, "Exist") + ctx = setContextOp(ctx, tlq.ctx, "Exist") switch _, err := tlq.First(ctx); { case IsNotFound(err): return false, nil @@ -228,17 +225,15 @@ func (tlq *TweetLikeQuery) Clone() *TweetLikeQuery { } return &TweetLikeQuery{ config: tlq.config, - limit: tlq.limit, - offset: tlq.offset, + ctx: tlq.ctx.Clone(), order: append([]OrderFunc{}, tlq.order...), inters: append([]Interceptor{}, tlq.inters...), predicates: append([]predicate.TweetLike{}, tlq.predicates...), withTweet: tlq.withTweet.Clone(), withUser: tlq.withUser.Clone(), // clone intermediate query. - sql: tlq.sql.Clone(), - path: tlq.path, - unique: tlq.unique, + sql: tlq.sql.Clone(), + path: tlq.path, } } @@ -279,9 +274,9 @@ func (tlq *TweetLikeQuery) WithUser(opts ...func(*UserQuery)) *TweetLikeQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (tlq *TweetLikeQuery) GroupBy(field string, fields ...string) *TweetLikeGroupBy { - tlq.fields = append([]string{field}, fields...) + tlq.ctx.Fields = append([]string{field}, fields...) grbuild := &TweetLikeGroupBy{build: tlq} - grbuild.flds = &tlq.fields + grbuild.flds = &tlq.ctx.Fields grbuild.label = tweetlike.Label grbuild.scan = grbuild.Scan return grbuild @@ -300,10 +295,10 @@ func (tlq *TweetLikeQuery) GroupBy(field string, fields ...string) *TweetLikeGro // Select(tweetlike.FieldLikedAt). // Scan(ctx, &v) func (tlq *TweetLikeQuery) Select(fields ...string) *TweetLikeSelect { - tlq.fields = append(tlq.fields, fields...) + tlq.ctx.Fields = append(tlq.ctx.Fields, fields...) sbuild := &TweetLikeSelect{TweetLikeQuery: tlq} sbuild.label = tweetlike.Label - sbuild.flds, sbuild.scan = &tlq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &tlq.ctx.Fields, sbuild.Scan return sbuild } @@ -323,7 +318,7 @@ func (tlq *TweetLikeQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range tlq.fields { + for _, f := range tlq.ctx.Fields { if !tweetlike.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -461,10 +456,10 @@ func (tlq *TweetLikeQuery) querySpec() *sqlgraph.QuerySpec { From: tlq.sql, Unique: true, } - if unique := tlq.unique; unique != nil { + if unique := tlq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := tlq.fields; len(fields) > 0 { + if fields := tlq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) for i := range fields { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) @@ -477,10 +472,10 @@ func (tlq *TweetLikeQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := tlq.limit; limit != nil { + if limit := tlq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := tlq.offset; offset != nil { + if offset := tlq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := tlq.order; len(ps) > 0 { @@ -496,7 +491,7 @@ func (tlq *TweetLikeQuery) querySpec() *sqlgraph.QuerySpec { func (tlq *TweetLikeQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(tlq.driver.Dialect()) t1 := builder.Table(tweetlike.Table) - columns := tlq.fields + columns := tlq.ctx.Fields if len(columns) == 0 { columns = tweetlike.Columns } @@ -505,7 +500,7 @@ func (tlq *TweetLikeQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = tlq.sql selector.Select(selector.Columns(columns...)...) } - if tlq.unique != nil && *tlq.unique { + if tlq.ctx.Unique != nil && *tlq.ctx.Unique { selector.Distinct() } for _, p := range tlq.predicates { @@ -514,12 +509,12 @@ func (tlq *TweetLikeQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range tlq.order { p(selector) } - if offset := tlq.offset; offset != nil { + if offset := tlq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := tlq.limit; limit != nil { + if limit := tlq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -539,7 +534,7 @@ func (tlgb *TweetLikeGroupBy) Aggregate(fns ...AggregateFunc) *TweetLikeGroupBy // Scan applies the selector query and scans the result into the given value. func (tlgb *TweetLikeGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTweetLike, "GroupBy") + ctx = setContextOp(ctx, tlgb.build.ctx, "GroupBy") if err := tlgb.build.prepareQuery(ctx); err != nil { return err } @@ -587,7 +582,7 @@ func (tls *TweetLikeSelect) Aggregate(fns ...AggregateFunc) *TweetLikeSelect { // Scan applies the selector query and scans the result into the given value. func (tls *TweetLikeSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTweetLike, "Select") + ctx = setContextOp(ctx, tls.ctx, "Select") if err := tls.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/tweettag_query.go b/entc/integration/edgeschema/ent/tweettag_query.go index 7d1e85b78..de6c94cdc 100644 --- a/entc/integration/edgeschema/ent/tweettag_query.go +++ b/entc/integration/edgeschema/ent/tweettag_query.go @@ -24,11 +24,8 @@ import ( // TweetTagQuery is the builder for querying TweetTag entities. type TweetTagQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.TweetTag withTag *TagQuery @@ -46,20 +43,20 @@ func (ttq *TweetTagQuery) Where(ps ...predicate.TweetTag) *TweetTagQuery { // Limit the number of records to be returned by this query. func (ttq *TweetTagQuery) Limit(limit int) *TweetTagQuery { - ttq.limit = &limit + ttq.ctx.Limit = &limit return ttq } // Offset to start from. func (ttq *TweetTagQuery) Offset(offset int) *TweetTagQuery { - ttq.offset = &offset + ttq.ctx.Offset = &offset return ttq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (ttq *TweetTagQuery) Unique(unique bool) *TweetTagQuery { - ttq.unique = &unique + ttq.ctx.Unique = &unique return ttq } @@ -116,7 +113,7 @@ func (ttq *TweetTagQuery) QueryTweet() *TweetQuery { // First returns the first TweetTag entity from the query. // Returns a *NotFoundError when no TweetTag was found. func (ttq *TweetTagQuery) First(ctx context.Context) (*TweetTag, error) { - nodes, err := ttq.Limit(1).All(newQueryContext(ctx, TypeTweetTag, "First")) + nodes, err := ttq.Limit(1).All(setContextOp(ctx, ttq.ctx, "First")) if err != nil { return nil, err } @@ -139,7 +136,7 @@ func (ttq *TweetTagQuery) FirstX(ctx context.Context) *TweetTag { // Returns a *NotFoundError when no TweetTag ID was found. func (ttq *TweetTagQuery) FirstID(ctx context.Context) (id uuid.UUID, err error) { var ids []uuid.UUID - if ids, err = ttq.Limit(1).IDs(newQueryContext(ctx, TypeTweetTag, "FirstID")); err != nil { + if ids, err = ttq.Limit(1).IDs(setContextOp(ctx, ttq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -162,7 +159,7 @@ func (ttq *TweetTagQuery) FirstIDX(ctx context.Context) uuid.UUID { // Returns a *NotSingularError when more than one TweetTag entity is found. // Returns a *NotFoundError when no TweetTag entities are found. func (ttq *TweetTagQuery) Only(ctx context.Context) (*TweetTag, error) { - nodes, err := ttq.Limit(2).All(newQueryContext(ctx, TypeTweetTag, "Only")) + nodes, err := ttq.Limit(2).All(setContextOp(ctx, ttq.ctx, "Only")) if err != nil { return nil, err } @@ -190,7 +187,7 @@ func (ttq *TweetTagQuery) OnlyX(ctx context.Context) *TweetTag { // Returns a *NotFoundError when no entities are found. func (ttq *TweetTagQuery) OnlyID(ctx context.Context) (id uuid.UUID, err error) { var ids []uuid.UUID - if ids, err = ttq.Limit(2).IDs(newQueryContext(ctx, TypeTweetTag, "OnlyID")); err != nil { + if ids, err = ttq.Limit(2).IDs(setContextOp(ctx, ttq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -215,7 +212,7 @@ func (ttq *TweetTagQuery) OnlyIDX(ctx context.Context) uuid.UUID { // All executes the query and returns a list of TweetTags. func (ttq *TweetTagQuery) All(ctx context.Context) ([]*TweetTag, error) { - ctx = newQueryContext(ctx, TypeTweetTag, "All") + ctx = setContextOp(ctx, ttq.ctx, "All") if err := ttq.prepareQuery(ctx); err != nil { return nil, err } @@ -235,7 +232,7 @@ func (ttq *TweetTagQuery) AllX(ctx context.Context) []*TweetTag { // IDs executes the query and returns a list of TweetTag IDs. func (ttq *TweetTagQuery) IDs(ctx context.Context) ([]uuid.UUID, error) { var ids []uuid.UUID - ctx = newQueryContext(ctx, TypeTweetTag, "IDs") + ctx = setContextOp(ctx, ttq.ctx, "IDs") if err := ttq.Select(tweettag.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -253,7 +250,7 @@ func (ttq *TweetTagQuery) IDsX(ctx context.Context) []uuid.UUID { // Count returns the count of the given query. func (ttq *TweetTagQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeTweetTag, "Count") + ctx = setContextOp(ctx, ttq.ctx, "Count") if err := ttq.prepareQuery(ctx); err != nil { return 0, err } @@ -271,7 +268,7 @@ func (ttq *TweetTagQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (ttq *TweetTagQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeTweetTag, "Exist") + ctx = setContextOp(ctx, ttq.ctx, "Exist") switch _, err := ttq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -299,17 +296,15 @@ func (ttq *TweetTagQuery) Clone() *TweetTagQuery { } return &TweetTagQuery{ config: ttq.config, - limit: ttq.limit, - offset: ttq.offset, + ctx: ttq.ctx.Clone(), order: append([]OrderFunc{}, ttq.order...), inters: append([]Interceptor{}, ttq.inters...), predicates: append([]predicate.TweetTag{}, ttq.predicates...), withTag: ttq.withTag.Clone(), withTweet: ttq.withTweet.Clone(), // clone intermediate query. - sql: ttq.sql.Clone(), - path: ttq.path, - unique: ttq.unique, + sql: ttq.sql.Clone(), + path: ttq.path, } } @@ -350,9 +345,9 @@ func (ttq *TweetTagQuery) WithTweet(opts ...func(*TweetQuery)) *TweetTagQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (ttq *TweetTagQuery) GroupBy(field string, fields ...string) *TweetTagGroupBy { - ttq.fields = append([]string{field}, fields...) + ttq.ctx.Fields = append([]string{field}, fields...) grbuild := &TweetTagGroupBy{build: ttq} - grbuild.flds = &ttq.fields + grbuild.flds = &ttq.ctx.Fields grbuild.label = tweettag.Label grbuild.scan = grbuild.Scan return grbuild @@ -371,10 +366,10 @@ func (ttq *TweetTagQuery) GroupBy(field string, fields ...string) *TweetTagGroup // Select(tweettag.FieldAddedAt). // Scan(ctx, &v) func (ttq *TweetTagQuery) Select(fields ...string) *TweetTagSelect { - ttq.fields = append(ttq.fields, fields...) + ttq.ctx.Fields = append(ttq.ctx.Fields, fields...) sbuild := &TweetTagSelect{TweetTagQuery: ttq} sbuild.label = tweettag.Label - sbuild.flds, sbuild.scan = &ttq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &ttq.ctx.Fields, sbuild.Scan return sbuild } @@ -394,7 +389,7 @@ func (ttq *TweetTagQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range ttq.fields { + for _, f := range ttq.ctx.Fields { if !tweettag.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -512,9 +507,9 @@ func (ttq *TweetTagQuery) loadTweet(ctx context.Context, query *TweetQuery, node func (ttq *TweetTagQuery) sqlCount(ctx context.Context) (int, error) { _spec := ttq.querySpec() - _spec.Node.Columns = ttq.fields - if len(ttq.fields) > 0 { - _spec.Unique = ttq.unique != nil && *ttq.unique + _spec.Node.Columns = ttq.ctx.Fields + if len(ttq.ctx.Fields) > 0 { + _spec.Unique = ttq.ctx.Unique != nil && *ttq.ctx.Unique } return sqlgraph.CountNodes(ctx, ttq.driver, _spec) } @@ -532,10 +527,10 @@ func (ttq *TweetTagQuery) querySpec() *sqlgraph.QuerySpec { From: ttq.sql, Unique: true, } - if unique := ttq.unique; unique != nil { + if unique := ttq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := ttq.fields; len(fields) > 0 { + if fields := ttq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, tweettag.FieldID) for i := range fields { @@ -551,10 +546,10 @@ func (ttq *TweetTagQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := ttq.limit; limit != nil { + if limit := ttq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := ttq.offset; offset != nil { + if offset := ttq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := ttq.order; len(ps) > 0 { @@ -570,7 +565,7 @@ func (ttq *TweetTagQuery) querySpec() *sqlgraph.QuerySpec { func (ttq *TweetTagQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(ttq.driver.Dialect()) t1 := builder.Table(tweettag.Table) - columns := ttq.fields + columns := ttq.ctx.Fields if len(columns) == 0 { columns = tweettag.Columns } @@ -579,7 +574,7 @@ func (ttq *TweetTagQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = ttq.sql selector.Select(selector.Columns(columns...)...) } - if ttq.unique != nil && *ttq.unique { + if ttq.ctx.Unique != nil && *ttq.ctx.Unique { selector.Distinct() } for _, p := range ttq.predicates { @@ -588,12 +583,12 @@ func (ttq *TweetTagQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range ttq.order { p(selector) } - if offset := ttq.offset; offset != nil { + if offset := ttq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := ttq.limit; limit != nil { + if limit := ttq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -613,7 +608,7 @@ func (ttgb *TweetTagGroupBy) Aggregate(fns ...AggregateFunc) *TweetTagGroupBy { // Scan applies the selector query and scans the result into the given value. func (ttgb *TweetTagGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTweetTag, "GroupBy") + ctx = setContextOp(ctx, ttgb.build.ctx, "GroupBy") if err := ttgb.build.prepareQuery(ctx); err != nil { return err } @@ -661,7 +656,7 @@ func (tts *TweetTagSelect) Aggregate(fns ...AggregateFunc) *TweetTagSelect { // Scan applies the selector query and scans the result into the given value. func (tts *TweetTagSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTweetTag, "Select") + ctx = setContextOp(ctx, tts.ctx, "Select") if err := tts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/user_query.go b/entc/integration/edgeschema/ent/user_query.go index c2bf89d55..fd42588c4 100644 --- a/entc/integration/edgeschema/ent/user_query.go +++ b/entc/integration/edgeschema/ent/user_query.go @@ -32,11 +32,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withGroups *GroupQuery @@ -64,20 +61,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -354,7 +351,7 @@ func (uq *UserQuery) QueryRolesUsers() *RoleUserQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -377,7 +374,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -400,7 +397,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -428,7 +425,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -453,7 +450,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -473,7 +470,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -491,7 +488,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -509,7 +506,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -537,8 +534,7 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), @@ -555,9 +551,8 @@ func (uq *UserQuery) Clone() *UserQuery { withUserTweets: uq.withUserTweets.Clone(), withRolesUsers: uq.withRolesUsers.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -708,9 +703,9 @@ func (uq *UserQuery) WithRolesUsers(opts ...func(*RoleUserQuery)) *UserQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -729,10 +724,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldName). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -752,7 +747,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -1410,9 +1405,9 @@ func (uq *UserQuery) loadRolesUsers(ctx context.Context, query *RoleUserQuery, n func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -1430,10 +1425,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -1449,10 +1444,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -1468,7 +1463,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -1477,7 +1472,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -1486,12 +1481,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -1511,7 +1506,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -1559,7 +1554,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/usergroup_query.go b/entc/integration/edgeschema/ent/usergroup_query.go index 7eb307f8b..779a56e43 100644 --- a/entc/integration/edgeschema/ent/usergroup_query.go +++ b/entc/integration/edgeschema/ent/usergroup_query.go @@ -23,11 +23,8 @@ import ( // UserGroupQuery is the builder for querying UserGroup entities. type UserGroupQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.UserGroup withUser *UserQuery @@ -45,20 +42,20 @@ func (ugq *UserGroupQuery) Where(ps ...predicate.UserGroup) *UserGroupQuery { // Limit the number of records to be returned by this query. func (ugq *UserGroupQuery) Limit(limit int) *UserGroupQuery { - ugq.limit = &limit + ugq.ctx.Limit = &limit return ugq } // Offset to start from. func (ugq *UserGroupQuery) Offset(offset int) *UserGroupQuery { - ugq.offset = &offset + ugq.ctx.Offset = &offset return ugq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (ugq *UserGroupQuery) Unique(unique bool) *UserGroupQuery { - ugq.unique = &unique + ugq.ctx.Unique = &unique return ugq } @@ -115,7 +112,7 @@ func (ugq *UserGroupQuery) QueryGroup() *GroupQuery { // First returns the first UserGroup entity from the query. // Returns a *NotFoundError when no UserGroup was found. func (ugq *UserGroupQuery) First(ctx context.Context) (*UserGroup, error) { - nodes, err := ugq.Limit(1).All(newQueryContext(ctx, TypeUserGroup, "First")) + nodes, err := ugq.Limit(1).All(setContextOp(ctx, ugq.ctx, "First")) if err != nil { return nil, err } @@ -138,7 +135,7 @@ func (ugq *UserGroupQuery) FirstX(ctx context.Context) *UserGroup { // Returns a *NotFoundError when no UserGroup ID was found. func (ugq *UserGroupQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ugq.Limit(1).IDs(newQueryContext(ctx, TypeUserGroup, "FirstID")); err != nil { + if ids, err = ugq.Limit(1).IDs(setContextOp(ctx, ugq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -161,7 +158,7 @@ func (ugq *UserGroupQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one UserGroup entity is found. // Returns a *NotFoundError when no UserGroup entities are found. func (ugq *UserGroupQuery) Only(ctx context.Context) (*UserGroup, error) { - nodes, err := ugq.Limit(2).All(newQueryContext(ctx, TypeUserGroup, "Only")) + nodes, err := ugq.Limit(2).All(setContextOp(ctx, ugq.ctx, "Only")) if err != nil { return nil, err } @@ -189,7 +186,7 @@ func (ugq *UserGroupQuery) OnlyX(ctx context.Context) *UserGroup { // Returns a *NotFoundError when no entities are found. func (ugq *UserGroupQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ugq.Limit(2).IDs(newQueryContext(ctx, TypeUserGroup, "OnlyID")); err != nil { + if ids, err = ugq.Limit(2).IDs(setContextOp(ctx, ugq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -214,7 +211,7 @@ func (ugq *UserGroupQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of UserGroups. func (ugq *UserGroupQuery) All(ctx context.Context) ([]*UserGroup, error) { - ctx = newQueryContext(ctx, TypeUserGroup, "All") + ctx = setContextOp(ctx, ugq.ctx, "All") if err := ugq.prepareQuery(ctx); err != nil { return nil, err } @@ -234,7 +231,7 @@ func (ugq *UserGroupQuery) AllX(ctx context.Context) []*UserGroup { // IDs executes the query and returns a list of UserGroup IDs. func (ugq *UserGroupQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUserGroup, "IDs") + ctx = setContextOp(ctx, ugq.ctx, "IDs") if err := ugq.Select(usergroup.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -252,7 +249,7 @@ func (ugq *UserGroupQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (ugq *UserGroupQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUserGroup, "Count") + ctx = setContextOp(ctx, ugq.ctx, "Count") if err := ugq.prepareQuery(ctx); err != nil { return 0, err } @@ -270,7 +267,7 @@ func (ugq *UserGroupQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (ugq *UserGroupQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUserGroup, "Exist") + ctx = setContextOp(ctx, ugq.ctx, "Exist") switch _, err := ugq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -298,17 +295,15 @@ func (ugq *UserGroupQuery) Clone() *UserGroupQuery { } return &UserGroupQuery{ config: ugq.config, - limit: ugq.limit, - offset: ugq.offset, + ctx: ugq.ctx.Clone(), order: append([]OrderFunc{}, ugq.order...), inters: append([]Interceptor{}, ugq.inters...), predicates: append([]predicate.UserGroup{}, ugq.predicates...), withUser: ugq.withUser.Clone(), withGroup: ugq.withGroup.Clone(), // clone intermediate query. - sql: ugq.sql.Clone(), - path: ugq.path, - unique: ugq.unique, + sql: ugq.sql.Clone(), + path: ugq.path, } } @@ -349,9 +344,9 @@ func (ugq *UserGroupQuery) WithGroup(opts ...func(*GroupQuery)) *UserGroupQuery // Aggregate(ent.Count()). // Scan(ctx, &v) func (ugq *UserGroupQuery) GroupBy(field string, fields ...string) *UserGroupGroupBy { - ugq.fields = append([]string{field}, fields...) + ugq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupGroupBy{build: ugq} - grbuild.flds = &ugq.fields + grbuild.flds = &ugq.ctx.Fields grbuild.label = usergroup.Label grbuild.scan = grbuild.Scan return grbuild @@ -370,10 +365,10 @@ func (ugq *UserGroupQuery) GroupBy(field string, fields ...string) *UserGroupGro // Select(usergroup.FieldJoinedAt). // Scan(ctx, &v) func (ugq *UserGroupQuery) Select(fields ...string) *UserGroupSelect { - ugq.fields = append(ugq.fields, fields...) + ugq.ctx.Fields = append(ugq.ctx.Fields, fields...) sbuild := &UserGroupSelect{UserGroupQuery: ugq} sbuild.label = usergroup.Label - sbuild.flds, sbuild.scan = &ugq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &ugq.ctx.Fields, sbuild.Scan return sbuild } @@ -393,7 +388,7 @@ func (ugq *UserGroupQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range ugq.fields { + for _, f := range ugq.ctx.Fields { if !usergroup.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -511,9 +506,9 @@ func (ugq *UserGroupQuery) loadGroup(ctx context.Context, query *GroupQuery, nod func (ugq *UserGroupQuery) sqlCount(ctx context.Context) (int, error) { _spec := ugq.querySpec() - _spec.Node.Columns = ugq.fields - if len(ugq.fields) > 0 { - _spec.Unique = ugq.unique != nil && *ugq.unique + _spec.Node.Columns = ugq.ctx.Fields + if len(ugq.ctx.Fields) > 0 { + _spec.Unique = ugq.ctx.Unique != nil && *ugq.ctx.Unique } return sqlgraph.CountNodes(ctx, ugq.driver, _spec) } @@ -531,10 +526,10 @@ func (ugq *UserGroupQuery) querySpec() *sqlgraph.QuerySpec { From: ugq.sql, Unique: true, } - if unique := ugq.unique; unique != nil { + if unique := ugq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := ugq.fields; len(fields) > 0 { + if fields := ugq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, usergroup.FieldID) for i := range fields { @@ -550,10 +545,10 @@ func (ugq *UserGroupQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := ugq.limit; limit != nil { + if limit := ugq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := ugq.offset; offset != nil { + if offset := ugq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := ugq.order; len(ps) > 0 { @@ -569,7 +564,7 @@ func (ugq *UserGroupQuery) querySpec() *sqlgraph.QuerySpec { func (ugq *UserGroupQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(ugq.driver.Dialect()) t1 := builder.Table(usergroup.Table) - columns := ugq.fields + columns := ugq.ctx.Fields if len(columns) == 0 { columns = usergroup.Columns } @@ -578,7 +573,7 @@ func (ugq *UserGroupQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = ugq.sql selector.Select(selector.Columns(columns...)...) } - if ugq.unique != nil && *ugq.unique { + if ugq.ctx.Unique != nil && *ugq.ctx.Unique { selector.Distinct() } for _, p := range ugq.predicates { @@ -587,12 +582,12 @@ func (ugq *UserGroupQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range ugq.order { p(selector) } - if offset := ugq.offset; offset != nil { + if offset := ugq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := ugq.limit; limit != nil { + if limit := ugq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -612,7 +607,7 @@ func (uggb *UserGroupGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupGroupBy // Scan applies the selector query and scans the result into the given value. func (uggb *UserGroupGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUserGroup, "GroupBy") + ctx = setContextOp(ctx, uggb.build.ctx, "GroupBy") if err := uggb.build.prepareQuery(ctx); err != nil { return err } @@ -660,7 +655,7 @@ func (ugs *UserGroupSelect) Aggregate(fns ...AggregateFunc) *UserGroupSelect { // Scan applies the selector query and scans the result into the given value. func (ugs *UserGroupSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUserGroup, "Select") + ctx = setContextOp(ctx, ugs.ctx, "Select") if err := ugs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/edgeschema/ent/usertweet_query.go b/entc/integration/edgeschema/ent/usertweet_query.go index 5fae44fe2..128400935 100644 --- a/entc/integration/edgeschema/ent/usertweet_query.go +++ b/entc/integration/edgeschema/ent/usertweet_query.go @@ -23,11 +23,8 @@ import ( // UserTweetQuery is the builder for querying UserTweet entities. type UserTweetQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.UserTweet withUser *UserQuery @@ -45,20 +42,20 @@ func (utq *UserTweetQuery) Where(ps ...predicate.UserTweet) *UserTweetQuery { // Limit the number of records to be returned by this query. func (utq *UserTweetQuery) Limit(limit int) *UserTweetQuery { - utq.limit = &limit + utq.ctx.Limit = &limit return utq } // Offset to start from. func (utq *UserTweetQuery) Offset(offset int) *UserTweetQuery { - utq.offset = &offset + utq.ctx.Offset = &offset return utq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (utq *UserTweetQuery) Unique(unique bool) *UserTweetQuery { - utq.unique = &unique + utq.ctx.Unique = &unique return utq } @@ -115,7 +112,7 @@ func (utq *UserTweetQuery) QueryTweet() *TweetQuery { // First returns the first UserTweet entity from the query. // Returns a *NotFoundError when no UserTweet was found. func (utq *UserTweetQuery) First(ctx context.Context) (*UserTweet, error) { - nodes, err := utq.Limit(1).All(newQueryContext(ctx, TypeUserTweet, "First")) + nodes, err := utq.Limit(1).All(setContextOp(ctx, utq.ctx, "First")) if err != nil { return nil, err } @@ -138,7 +135,7 @@ func (utq *UserTweetQuery) FirstX(ctx context.Context) *UserTweet { // Returns a *NotFoundError when no UserTweet ID was found. func (utq *UserTweetQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = utq.Limit(1).IDs(newQueryContext(ctx, TypeUserTweet, "FirstID")); err != nil { + if ids, err = utq.Limit(1).IDs(setContextOp(ctx, utq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -161,7 +158,7 @@ func (utq *UserTweetQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one UserTweet entity is found. // Returns a *NotFoundError when no UserTweet entities are found. func (utq *UserTweetQuery) Only(ctx context.Context) (*UserTweet, error) { - nodes, err := utq.Limit(2).All(newQueryContext(ctx, TypeUserTweet, "Only")) + nodes, err := utq.Limit(2).All(setContextOp(ctx, utq.ctx, "Only")) if err != nil { return nil, err } @@ -189,7 +186,7 @@ func (utq *UserTweetQuery) OnlyX(ctx context.Context) *UserTweet { // Returns a *NotFoundError when no entities are found. func (utq *UserTweetQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = utq.Limit(2).IDs(newQueryContext(ctx, TypeUserTweet, "OnlyID")); err != nil { + if ids, err = utq.Limit(2).IDs(setContextOp(ctx, utq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -214,7 +211,7 @@ func (utq *UserTweetQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of UserTweets. func (utq *UserTweetQuery) All(ctx context.Context) ([]*UserTweet, error) { - ctx = newQueryContext(ctx, TypeUserTweet, "All") + ctx = setContextOp(ctx, utq.ctx, "All") if err := utq.prepareQuery(ctx); err != nil { return nil, err } @@ -234,7 +231,7 @@ func (utq *UserTweetQuery) AllX(ctx context.Context) []*UserTweet { // IDs executes the query and returns a list of UserTweet IDs. func (utq *UserTweetQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUserTweet, "IDs") + ctx = setContextOp(ctx, utq.ctx, "IDs") if err := utq.Select(usertweet.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -252,7 +249,7 @@ func (utq *UserTweetQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (utq *UserTweetQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUserTweet, "Count") + ctx = setContextOp(ctx, utq.ctx, "Count") if err := utq.prepareQuery(ctx); err != nil { return 0, err } @@ -270,7 +267,7 @@ func (utq *UserTweetQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (utq *UserTweetQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUserTweet, "Exist") + ctx = setContextOp(ctx, utq.ctx, "Exist") switch _, err := utq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -298,17 +295,15 @@ func (utq *UserTweetQuery) Clone() *UserTweetQuery { } return &UserTweetQuery{ config: utq.config, - limit: utq.limit, - offset: utq.offset, + ctx: utq.ctx.Clone(), order: append([]OrderFunc{}, utq.order...), inters: append([]Interceptor{}, utq.inters...), predicates: append([]predicate.UserTweet{}, utq.predicates...), withUser: utq.withUser.Clone(), withTweet: utq.withTweet.Clone(), // clone intermediate query. - sql: utq.sql.Clone(), - path: utq.path, - unique: utq.unique, + sql: utq.sql.Clone(), + path: utq.path, } } @@ -349,9 +344,9 @@ func (utq *UserTweetQuery) WithTweet(opts ...func(*TweetQuery)) *UserTweetQuery // Aggregate(ent.Count()). // Scan(ctx, &v) func (utq *UserTweetQuery) GroupBy(field string, fields ...string) *UserTweetGroupBy { - utq.fields = append([]string{field}, fields...) + utq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserTweetGroupBy{build: utq} - grbuild.flds = &utq.fields + grbuild.flds = &utq.ctx.Fields grbuild.label = usertweet.Label grbuild.scan = grbuild.Scan return grbuild @@ -370,10 +365,10 @@ func (utq *UserTweetQuery) GroupBy(field string, fields ...string) *UserTweetGro // Select(usertweet.FieldCreatedAt). // Scan(ctx, &v) func (utq *UserTweetQuery) Select(fields ...string) *UserTweetSelect { - utq.fields = append(utq.fields, fields...) + utq.ctx.Fields = append(utq.ctx.Fields, fields...) sbuild := &UserTweetSelect{UserTweetQuery: utq} sbuild.label = usertweet.Label - sbuild.flds, sbuild.scan = &utq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &utq.ctx.Fields, sbuild.Scan return sbuild } @@ -393,7 +388,7 @@ func (utq *UserTweetQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range utq.fields { + for _, f := range utq.ctx.Fields { if !usertweet.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -511,9 +506,9 @@ func (utq *UserTweetQuery) loadTweet(ctx context.Context, query *TweetQuery, nod func (utq *UserTweetQuery) sqlCount(ctx context.Context) (int, error) { _spec := utq.querySpec() - _spec.Node.Columns = utq.fields - if len(utq.fields) > 0 { - _spec.Unique = utq.unique != nil && *utq.unique + _spec.Node.Columns = utq.ctx.Fields + if len(utq.ctx.Fields) > 0 { + _spec.Unique = utq.ctx.Unique != nil && *utq.ctx.Unique } return sqlgraph.CountNodes(ctx, utq.driver, _spec) } @@ -531,10 +526,10 @@ func (utq *UserTweetQuery) querySpec() *sqlgraph.QuerySpec { From: utq.sql, Unique: true, } - if unique := utq.unique; unique != nil { + if unique := utq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := utq.fields; len(fields) > 0 { + if fields := utq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, usertweet.FieldID) for i := range fields { @@ -550,10 +545,10 @@ func (utq *UserTweetQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := utq.limit; limit != nil { + if limit := utq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := utq.offset; offset != nil { + if offset := utq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := utq.order; len(ps) > 0 { @@ -569,7 +564,7 @@ func (utq *UserTweetQuery) querySpec() *sqlgraph.QuerySpec { func (utq *UserTweetQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(utq.driver.Dialect()) t1 := builder.Table(usertweet.Table) - columns := utq.fields + columns := utq.ctx.Fields if len(columns) == 0 { columns = usertweet.Columns } @@ -578,7 +573,7 @@ func (utq *UserTweetQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = utq.sql selector.Select(selector.Columns(columns...)...) } - if utq.unique != nil && *utq.unique { + if utq.ctx.Unique != nil && *utq.ctx.Unique { selector.Distinct() } for _, p := range utq.predicates { @@ -587,12 +582,12 @@ func (utq *UserTweetQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range utq.order { p(selector) } - if offset := utq.offset; offset != nil { + if offset := utq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := utq.limit; limit != nil { + if limit := utq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -612,7 +607,7 @@ func (utgb *UserTweetGroupBy) Aggregate(fns ...AggregateFunc) *UserTweetGroupBy // Scan applies the selector query and scans the result into the given value. func (utgb *UserTweetGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUserTweet, "GroupBy") + ctx = setContextOp(ctx, utgb.build.ctx, "GroupBy") if err := utgb.build.prepareQuery(ctx); err != nil { return err } @@ -660,7 +655,7 @@ func (uts *UserTweetSelect) Aggregate(fns ...AggregateFunc) *UserTweetSelect { // Scan applies the selector query and scans the result into the given value. func (uts *UserTweetSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUserTweet, "Select") + ctx = setContextOp(ctx, uts.ctx, "Select") if err := uts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/api_query.go b/entc/integration/ent/api_query.go index e90236426..a2b936d93 100644 --- a/entc/integration/ent/api_query.go +++ b/entc/integration/ent/api_query.go @@ -22,11 +22,8 @@ import ( // APIQuery is the builder for querying Api entities. type APIQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Api modifiers []func(*sql.Selector) @@ -43,20 +40,20 @@ func (aq *APIQuery) Where(ps ...predicate.Api) *APIQuery { // Limit the number of records to be returned by this query. func (aq *APIQuery) Limit(limit int) *APIQuery { - aq.limit = &limit + aq.ctx.Limit = &limit return aq } // Offset to start from. func (aq *APIQuery) Offset(offset int) *APIQuery { - aq.offset = &offset + aq.ctx.Offset = &offset return aq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (aq *APIQuery) Unique(unique bool) *APIQuery { - aq.unique = &unique + aq.ctx.Unique = &unique return aq } @@ -69,7 +66,7 @@ func (aq *APIQuery) Order(o ...OrderFunc) *APIQuery { // First returns the first Api entity from the query. // Returns a *NotFoundError when no Api was found. func (aq *APIQuery) First(ctx context.Context) (*Api, error) { - nodes, err := aq.Limit(1).All(newQueryContext(ctx, TypeAPI, "First")) + nodes, err := aq.Limit(1).All(setContextOp(ctx, aq.ctx, "First")) if err != nil { return nil, err } @@ -92,7 +89,7 @@ func (aq *APIQuery) FirstX(ctx context.Context) *Api { // Returns a *NotFoundError when no Api ID was found. func (aq *APIQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = aq.Limit(1).IDs(newQueryContext(ctx, TypeAPI, "FirstID")); err != nil { + if ids, err = aq.Limit(1).IDs(setContextOp(ctx, aq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -115,7 +112,7 @@ func (aq *APIQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Api entity is found. // Returns a *NotFoundError when no Api entities are found. func (aq *APIQuery) Only(ctx context.Context) (*Api, error) { - nodes, err := aq.Limit(2).All(newQueryContext(ctx, TypeAPI, "Only")) + nodes, err := aq.Limit(2).All(setContextOp(ctx, aq.ctx, "Only")) if err != nil { return nil, err } @@ -143,7 +140,7 @@ func (aq *APIQuery) OnlyX(ctx context.Context) *Api { // Returns a *NotFoundError when no entities are found. func (aq *APIQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = aq.Limit(2).IDs(newQueryContext(ctx, TypeAPI, "OnlyID")); err != nil { + if ids, err = aq.Limit(2).IDs(setContextOp(ctx, aq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -168,7 +165,7 @@ func (aq *APIQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Apis. func (aq *APIQuery) All(ctx context.Context) ([]*Api, error) { - ctx = newQueryContext(ctx, TypeAPI, "All") + ctx = setContextOp(ctx, aq.ctx, "All") if err := aq.prepareQuery(ctx); err != nil { return nil, err } @@ -188,7 +185,7 @@ func (aq *APIQuery) AllX(ctx context.Context) []*Api { // IDs executes the query and returns a list of Api IDs. func (aq *APIQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeAPI, "IDs") + ctx = setContextOp(ctx, aq.ctx, "IDs") if err := aq.Select(api.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -206,7 +203,7 @@ func (aq *APIQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (aq *APIQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeAPI, "Count") + ctx = setContextOp(ctx, aq.ctx, "Count") if err := aq.prepareQuery(ctx); err != nil { return 0, err } @@ -224,7 +221,7 @@ func (aq *APIQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (aq *APIQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeAPI, "Exist") + ctx = setContextOp(ctx, aq.ctx, "Exist") switch _, err := aq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -252,24 +249,22 @@ func (aq *APIQuery) Clone() *APIQuery { } return &APIQuery{ config: aq.config, - limit: aq.limit, - offset: aq.offset, + ctx: aq.ctx.Clone(), order: append([]OrderFunc{}, aq.order...), inters: append([]Interceptor{}, aq.inters...), predicates: append([]predicate.Api{}, aq.predicates...), // clone intermediate query. - sql: aq.sql.Clone(), - path: aq.path, - unique: aq.unique, + sql: aq.sql.Clone(), + path: aq.path, } } // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (aq *APIQuery) GroupBy(field string, fields ...string) *APIGroupBy { - aq.fields = append([]string{field}, fields...) + aq.ctx.Fields = append([]string{field}, fields...) grbuild := &APIGroupBy{build: aq} - grbuild.flds = &aq.fields + grbuild.flds = &aq.ctx.Fields grbuild.label = api.Label grbuild.scan = grbuild.Scan return grbuild @@ -278,10 +273,10 @@ func (aq *APIQuery) GroupBy(field string, fields ...string) *APIGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (aq *APIQuery) Select(fields ...string) *APISelect { - aq.fields = append(aq.fields, fields...) + aq.ctx.Fields = append(aq.ctx.Fields, fields...) sbuild := &APISelect{APIQuery: aq} sbuild.label = api.Label - sbuild.flds, sbuild.scan = &aq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &aq.ctx.Fields, sbuild.Scan return sbuild } @@ -301,7 +296,7 @@ func (aq *APIQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range aq.fields { + for _, f := range aq.ctx.Fields { if !api.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -349,9 +344,9 @@ func (aq *APIQuery) sqlCount(ctx context.Context) (int, error) { if len(aq.modifiers) > 0 { _spec.Modifiers = aq.modifiers } - _spec.Node.Columns = aq.fields - if len(aq.fields) > 0 { - _spec.Unique = aq.unique != nil && *aq.unique + _spec.Node.Columns = aq.ctx.Fields + if len(aq.ctx.Fields) > 0 { + _spec.Unique = aq.ctx.Unique != nil && *aq.ctx.Unique } return sqlgraph.CountNodes(ctx, aq.driver, _spec) } @@ -369,10 +364,10 @@ func (aq *APIQuery) querySpec() *sqlgraph.QuerySpec { From: aq.sql, Unique: true, } - if unique := aq.unique; unique != nil { + if unique := aq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := aq.fields; len(fields) > 0 { + if fields := aq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, api.FieldID) for i := range fields { @@ -388,10 +383,10 @@ func (aq *APIQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := aq.limit; limit != nil { + if limit := aq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := aq.offset; offset != nil { + if offset := aq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := aq.order; len(ps) > 0 { @@ -407,7 +402,7 @@ func (aq *APIQuery) querySpec() *sqlgraph.QuerySpec { func (aq *APIQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(aq.driver.Dialect()) t1 := builder.Table(api.Table) - columns := aq.fields + columns := aq.ctx.Fields if len(columns) == 0 { columns = api.Columns } @@ -416,7 +411,7 @@ func (aq *APIQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = aq.sql selector.Select(selector.Columns(columns...)...) } - if aq.unique != nil && *aq.unique { + if aq.ctx.Unique != nil && *aq.ctx.Unique { selector.Distinct() } for _, m := range aq.modifiers { @@ -428,12 +423,12 @@ func (aq *APIQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range aq.order { p(selector) } - if offset := aq.offset; offset != nil { + if offset := aq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := aq.limit; limit != nil { + if limit := aq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -485,7 +480,7 @@ func (agb *APIGroupBy) Aggregate(fns ...AggregateFunc) *APIGroupBy { // Scan applies the selector query and scans the result into the given value. func (agb *APIGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeAPI, "GroupBy") + ctx = setContextOp(ctx, agb.build.ctx, "GroupBy") if err := agb.build.prepareQuery(ctx); err != nil { return err } @@ -533,7 +528,7 @@ func (as *APISelect) Aggregate(fns ...AggregateFunc) *APISelect { // Scan applies the selector query and scans the result into the given value. func (as *APISelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeAPI, "Select") + ctx = setContextOp(ctx, as.ctx, "Select") if err := as.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/card_query.go b/entc/integration/ent/card_query.go index 8735e04de..79136b0c4 100644 --- a/entc/integration/ent/card_query.go +++ b/entc/integration/ent/card_query.go @@ -25,11 +25,8 @@ import ( // CardQuery is the builder for querying Card entities. type CardQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Card withOwner *UserQuery @@ -50,20 +47,20 @@ func (cq *CardQuery) Where(ps ...predicate.Card) *CardQuery { // Limit the number of records to be returned by this query. func (cq *CardQuery) Limit(limit int) *CardQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *CardQuery) Offset(offset int) *CardQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *CardQuery) Unique(unique bool) *CardQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -120,7 +117,7 @@ func (cq *CardQuery) QuerySpec() *SpecQuery { // First returns the first Card entity from the query. // Returns a *NotFoundError when no Card was found. func (cq *CardQuery) First(ctx context.Context) (*Card, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeCard, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -143,7 +140,7 @@ func (cq *CardQuery) FirstX(ctx context.Context) *Card { // Returns a *NotFoundError when no Card ID was found. func (cq *CardQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeCard, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -166,7 +163,7 @@ func (cq *CardQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Card entity is found. // Returns a *NotFoundError when no Card entities are found. func (cq *CardQuery) Only(ctx context.Context) (*Card, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeCard, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -194,7 +191,7 @@ func (cq *CardQuery) OnlyX(ctx context.Context) *Card { // Returns a *NotFoundError when no entities are found. func (cq *CardQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeCard, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -219,7 +216,7 @@ func (cq *CardQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Cards. func (cq *CardQuery) All(ctx context.Context) ([]*Card, error) { - ctx = newQueryContext(ctx, TypeCard, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -239,7 +236,7 @@ func (cq *CardQuery) AllX(ctx context.Context) []*Card { // IDs executes the query and returns a list of Card IDs. func (cq *CardQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeCard, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(card.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -257,7 +254,7 @@ func (cq *CardQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (cq *CardQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeCard, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -275,7 +272,7 @@ func (cq *CardQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *CardQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeCard, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -303,17 +300,15 @@ func (cq *CardQuery) Clone() *CardQuery { } return &CardQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Card{}, cq.predicates...), withOwner: cq.withOwner.Clone(), withSpec: cq.withSpec.Clone(), // clone intermediate query. - sql: cq.sql.Clone(), - path: cq.path, - unique: cq.unique, + sql: cq.sql.Clone(), + path: cq.path, } } @@ -354,9 +349,9 @@ func (cq *CardQuery) WithSpec(opts ...func(*SpecQuery)) *CardQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (cq *CardQuery) GroupBy(field string, fields ...string) *CardGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &CardGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = card.Label grbuild.scan = grbuild.Scan return grbuild @@ -375,10 +370,10 @@ func (cq *CardQuery) GroupBy(field string, fields ...string) *CardGroupBy { // Select(card.FieldCreateTime). // Scan(ctx, &v) func (cq *CardQuery) Select(fields ...string) *CardSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &CardSelect{CardQuery: cq} sbuild.label = card.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -398,7 +393,7 @@ func (cq *CardQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range cq.fields { + for _, f := range cq.ctx.Fields { if !card.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -569,9 +564,9 @@ func (cq *CardQuery) sqlCount(ctx context.Context) (int, error) { if len(cq.modifiers) > 0 { _spec.Modifiers = cq.modifiers } - _spec.Node.Columns = cq.fields - if len(cq.fields) > 0 { - _spec.Unique = cq.unique != nil && *cq.unique + _spec.Node.Columns = cq.ctx.Fields + if len(cq.ctx.Fields) > 0 { + _spec.Unique = cq.ctx.Unique != nil && *cq.ctx.Unique } return sqlgraph.CountNodes(ctx, cq.driver, _spec) } @@ -589,10 +584,10 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { From: cq.sql, Unique: true, } - if unique := cq.unique; unique != nil { + if unique := cq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := cq.fields; len(fields) > 0 { + if fields := cq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, card.FieldID) for i := range fields { @@ -608,10 +603,10 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := cq.order; len(ps) > 0 { @@ -627,7 +622,7 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(card.Table) - columns := cq.fields + columns := cq.ctx.Fields if len(columns) == 0 { columns = card.Columns } @@ -636,7 +631,7 @@ func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = cq.sql selector.Select(selector.Columns(columns...)...) } - if cq.unique != nil && *cq.unique { + if cq.ctx.Unique != nil && *cq.ctx.Unique { selector.Distinct() } for _, m := range cq.modifiers { @@ -648,12 +643,12 @@ func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range cq.order { p(selector) } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -719,7 +714,7 @@ func (cgb *CardGroupBy) Aggregate(fns ...AggregateFunc) *CardGroupBy { // Scan applies the selector query and scans the result into the given value. func (cgb *CardGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCard, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -767,7 +762,7 @@ func (cs *CardSelect) Aggregate(fns ...AggregateFunc) *CardSelect { // Scan applies the selector query and scans the result into the given value. func (cs *CardSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCard, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/client.go b/entc/integration/ent/client.go index 34ae20fad..311316a1c 100644 --- a/entc/integration/ent/client.go +++ b/entc/integration/ent/client.go @@ -378,6 +378,7 @@ func (c *APIClient) DeleteOneID(id int) *APIDeleteOne { func (c *APIClient) Query() *APIQuery { return &APIQuery{ config: c.config, + ctx: &QueryContext{Type: TypeAPI}, inters: c.Interceptors(), } } @@ -495,6 +496,7 @@ func (c *CardClient) DeleteOneID(id int) *CardDeleteOne { func (c *CardClient) Query() *CardQuery { return &CardQuery{ config: c.config, + ctx: &QueryContext{Type: TypeCard}, inters: c.Interceptors(), } } @@ -644,6 +646,7 @@ func (c *CommentClient) DeleteOneID(id int) *CommentDeleteOne { func (c *CommentClient) Query() *CommentQuery { return &CommentQuery{ config: c.config, + ctx: &QueryContext{Type: TypeComment}, inters: c.Interceptors(), } } @@ -761,6 +764,7 @@ func (c *FieldTypeClient) DeleteOneID(id int) *FieldTypeDeleteOne { func (c *FieldTypeClient) Query() *FieldTypeQuery { return &FieldTypeQuery{ config: c.config, + ctx: &QueryContext{Type: TypeFieldType}, inters: c.Interceptors(), } } @@ -878,6 +882,7 @@ func (c *FileClient) DeleteOneID(id int) *FileDeleteOne { func (c *FileClient) Query() *FileQuery { return &FileQuery{ config: c.config, + ctx: &QueryContext{Type: TypeFile}, inters: c.Interceptors(), } } @@ -1043,6 +1048,7 @@ func (c *FileTypeClient) DeleteOneID(id int) *FileTypeDeleteOne { func (c *FileTypeClient) Query() *FileTypeQuery { return &FileTypeQuery{ config: c.config, + ctx: &QueryContext{Type: TypeFileType}, inters: c.Interceptors(), } } @@ -1176,6 +1182,7 @@ func (c *GoodsClient) DeleteOneID(id int) *GoodsDeleteOne { func (c *GoodsClient) Query() *GoodsQuery { return &GoodsQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGoods}, inters: c.Interceptors(), } } @@ -1293,6 +1300,7 @@ func (c *GroupClient) DeleteOneID(id int) *GroupDeleteOne { func (c *GroupClient) Query() *GroupQuery { return &GroupQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGroup}, inters: c.Interceptors(), } } @@ -1474,6 +1482,7 @@ func (c *GroupInfoClient) DeleteOneID(id int) *GroupInfoDeleteOne { func (c *GroupInfoClient) Query() *GroupInfoQuery { return &GroupInfoQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGroupInfo}, inters: c.Interceptors(), } } @@ -1607,6 +1616,7 @@ func (c *ItemClient) DeleteOneID(id string) *ItemDeleteOne { func (c *ItemClient) Query() *ItemQuery { return &ItemQuery{ config: c.config, + ctx: &QueryContext{Type: TypeItem}, inters: c.Interceptors(), } } @@ -1724,6 +1734,7 @@ func (c *LicenseClient) DeleteOneID(id int) *LicenseDeleteOne { func (c *LicenseClient) Query() *LicenseQuery { return &LicenseQuery{ config: c.config, + ctx: &QueryContext{Type: TypeLicense}, inters: c.Interceptors(), } } @@ -1841,6 +1852,7 @@ func (c *NodeClient) DeleteOneID(id int) *NodeDeleteOne { func (c *NodeClient) Query() *NodeQuery { return &NodeQuery{ config: c.config, + ctx: &QueryContext{Type: TypeNode}, inters: c.Interceptors(), } } @@ -1990,6 +2002,7 @@ func (c *PetClient) DeleteOneID(id int) *PetDeleteOne { func (c *PetClient) Query() *PetQuery { return &PetQuery{ config: c.config, + ctx: &QueryContext{Type: TypePet}, inters: c.Interceptors(), } } @@ -2139,6 +2152,7 @@ func (c *SpecClient) DeleteOneID(id int) *SpecDeleteOne { func (c *SpecClient) Query() *SpecQuery { return &SpecQuery{ config: c.config, + ctx: &QueryContext{Type: TypeSpec}, inters: c.Interceptors(), } } @@ -2272,6 +2286,7 @@ func (c *TaskClient) DeleteOneID(id int) *TaskDeleteOne { func (c *TaskClient) Query() *TaskQuery { return &TaskQuery{ config: c.config, + ctx: &QueryContext{Type: TypeTask}, inters: c.Interceptors(), } } @@ -2389,6 +2404,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/ent/comment_query.go b/entc/integration/ent/comment_query.go index 971e095ea..6a667ae91 100644 --- a/entc/integration/ent/comment_query.go +++ b/entc/integration/ent/comment_query.go @@ -22,11 +22,8 @@ import ( // CommentQuery is the builder for querying Comment entities. type CommentQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Comment modifiers []func(*sql.Selector) @@ -43,20 +40,20 @@ func (cq *CommentQuery) Where(ps ...predicate.Comment) *CommentQuery { // Limit the number of records to be returned by this query. func (cq *CommentQuery) Limit(limit int) *CommentQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *CommentQuery) Offset(offset int) *CommentQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *CommentQuery) Unique(unique bool) *CommentQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -69,7 +66,7 @@ func (cq *CommentQuery) Order(o ...OrderFunc) *CommentQuery { // First returns the first Comment entity from the query. // Returns a *NotFoundError when no Comment was found. func (cq *CommentQuery) First(ctx context.Context) (*Comment, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeComment, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -92,7 +89,7 @@ func (cq *CommentQuery) FirstX(ctx context.Context) *Comment { // Returns a *NotFoundError when no Comment ID was found. func (cq *CommentQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeComment, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -115,7 +112,7 @@ func (cq *CommentQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Comment entity is found. // Returns a *NotFoundError when no Comment entities are found. func (cq *CommentQuery) Only(ctx context.Context) (*Comment, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeComment, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -143,7 +140,7 @@ func (cq *CommentQuery) OnlyX(ctx context.Context) *Comment { // Returns a *NotFoundError when no entities are found. func (cq *CommentQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeComment, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -168,7 +165,7 @@ func (cq *CommentQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Comments. func (cq *CommentQuery) All(ctx context.Context) ([]*Comment, error) { - ctx = newQueryContext(ctx, TypeComment, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -188,7 +185,7 @@ func (cq *CommentQuery) AllX(ctx context.Context) []*Comment { // IDs executes the query and returns a list of Comment IDs. func (cq *CommentQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeComment, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(comment.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -206,7 +203,7 @@ func (cq *CommentQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (cq *CommentQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeComment, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -224,7 +221,7 @@ func (cq *CommentQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *CommentQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeComment, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -252,15 +249,13 @@ func (cq *CommentQuery) Clone() *CommentQuery { } return &CommentQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Comment{}, cq.predicates...), // clone intermediate query. - sql: cq.sql.Clone(), - path: cq.path, - unique: cq.unique, + sql: cq.sql.Clone(), + path: cq.path, } } @@ -279,9 +274,9 @@ func (cq *CommentQuery) Clone() *CommentQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (cq *CommentQuery) GroupBy(field string, fields ...string) *CommentGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &CommentGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = comment.Label grbuild.scan = grbuild.Scan return grbuild @@ -300,10 +295,10 @@ func (cq *CommentQuery) GroupBy(field string, fields ...string) *CommentGroupBy // Select(comment.FieldUniqueInt). // Scan(ctx, &v) func (cq *CommentQuery) Select(fields ...string) *CommentSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &CommentSelect{CommentQuery: cq} sbuild.label = comment.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -323,7 +318,7 @@ func (cq *CommentQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range cq.fields { + for _, f := range cq.ctx.Fields { if !comment.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -371,9 +366,9 @@ func (cq *CommentQuery) sqlCount(ctx context.Context) (int, error) { if len(cq.modifiers) > 0 { _spec.Modifiers = cq.modifiers } - _spec.Node.Columns = cq.fields - if len(cq.fields) > 0 { - _spec.Unique = cq.unique != nil && *cq.unique + _spec.Node.Columns = cq.ctx.Fields + if len(cq.ctx.Fields) > 0 { + _spec.Unique = cq.ctx.Unique != nil && *cq.ctx.Unique } return sqlgraph.CountNodes(ctx, cq.driver, _spec) } @@ -391,10 +386,10 @@ func (cq *CommentQuery) querySpec() *sqlgraph.QuerySpec { From: cq.sql, Unique: true, } - if unique := cq.unique; unique != nil { + if unique := cq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := cq.fields; len(fields) > 0 { + if fields := cq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, comment.FieldID) for i := range fields { @@ -410,10 +405,10 @@ func (cq *CommentQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := cq.order; len(ps) > 0 { @@ -429,7 +424,7 @@ func (cq *CommentQuery) querySpec() *sqlgraph.QuerySpec { func (cq *CommentQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(comment.Table) - columns := cq.fields + columns := cq.ctx.Fields if len(columns) == 0 { columns = comment.Columns } @@ -438,7 +433,7 @@ func (cq *CommentQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = cq.sql selector.Select(selector.Columns(columns...)...) } - if cq.unique != nil && *cq.unique { + if cq.ctx.Unique != nil && *cq.ctx.Unique { selector.Distinct() } for _, m := range cq.modifiers { @@ -450,12 +445,12 @@ func (cq *CommentQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range cq.order { p(selector) } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -507,7 +502,7 @@ func (cgb *CommentGroupBy) Aggregate(fns ...AggregateFunc) *CommentGroupBy { // Scan applies the selector query and scans the result into the given value. func (cgb *CommentGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeComment, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -555,7 +550,7 @@ func (cs *CommentSelect) Aggregate(fns ...AggregateFunc) *CommentSelect { // Scan applies the selector query and scans the result into the given value. func (cs *CommentSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeComment, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/ent.go b/entc/integration/ent/ent.go index f34b81de2..6c03f2dca 100644 --- a/entc/integration/ent/ent.go +++ b/entc/integration/ent/ent.go @@ -40,6 +40,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -538,10 +539,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/ent/fieldtype_query.go b/entc/integration/ent/fieldtype_query.go index cfd1768ae..bc74a4bc6 100644 --- a/entc/integration/ent/fieldtype_query.go +++ b/entc/integration/ent/fieldtype_query.go @@ -22,11 +22,8 @@ import ( // FieldTypeQuery is the builder for querying FieldType entities. type FieldTypeQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.FieldType withFKs bool @@ -44,20 +41,20 @@ func (ftq *FieldTypeQuery) Where(ps ...predicate.FieldType) *FieldTypeQuery { // Limit the number of records to be returned by this query. func (ftq *FieldTypeQuery) Limit(limit int) *FieldTypeQuery { - ftq.limit = &limit + ftq.ctx.Limit = &limit return ftq } // Offset to start from. func (ftq *FieldTypeQuery) Offset(offset int) *FieldTypeQuery { - ftq.offset = &offset + ftq.ctx.Offset = &offset return ftq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (ftq *FieldTypeQuery) Unique(unique bool) *FieldTypeQuery { - ftq.unique = &unique + ftq.ctx.Unique = &unique return ftq } @@ -70,7 +67,7 @@ func (ftq *FieldTypeQuery) Order(o ...OrderFunc) *FieldTypeQuery { // First returns the first FieldType entity from the query. // Returns a *NotFoundError when no FieldType was found. func (ftq *FieldTypeQuery) First(ctx context.Context) (*FieldType, error) { - nodes, err := ftq.Limit(1).All(newQueryContext(ctx, TypeFieldType, "First")) + nodes, err := ftq.Limit(1).All(setContextOp(ctx, ftq.ctx, "First")) if err != nil { return nil, err } @@ -93,7 +90,7 @@ func (ftq *FieldTypeQuery) FirstX(ctx context.Context) *FieldType { // Returns a *NotFoundError when no FieldType ID was found. func (ftq *FieldTypeQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ftq.Limit(1).IDs(newQueryContext(ctx, TypeFieldType, "FirstID")); err != nil { + if ids, err = ftq.Limit(1).IDs(setContextOp(ctx, ftq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -116,7 +113,7 @@ func (ftq *FieldTypeQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one FieldType entity is found. // Returns a *NotFoundError when no FieldType entities are found. func (ftq *FieldTypeQuery) Only(ctx context.Context) (*FieldType, error) { - nodes, err := ftq.Limit(2).All(newQueryContext(ctx, TypeFieldType, "Only")) + nodes, err := ftq.Limit(2).All(setContextOp(ctx, ftq.ctx, "Only")) if err != nil { return nil, err } @@ -144,7 +141,7 @@ func (ftq *FieldTypeQuery) OnlyX(ctx context.Context) *FieldType { // Returns a *NotFoundError when no entities are found. func (ftq *FieldTypeQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ftq.Limit(2).IDs(newQueryContext(ctx, TypeFieldType, "OnlyID")); err != nil { + if ids, err = ftq.Limit(2).IDs(setContextOp(ctx, ftq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -169,7 +166,7 @@ func (ftq *FieldTypeQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of FieldTypes. func (ftq *FieldTypeQuery) All(ctx context.Context) ([]*FieldType, error) { - ctx = newQueryContext(ctx, TypeFieldType, "All") + ctx = setContextOp(ctx, ftq.ctx, "All") if err := ftq.prepareQuery(ctx); err != nil { return nil, err } @@ -189,7 +186,7 @@ func (ftq *FieldTypeQuery) AllX(ctx context.Context) []*FieldType { // IDs executes the query and returns a list of FieldType IDs. func (ftq *FieldTypeQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeFieldType, "IDs") + ctx = setContextOp(ctx, ftq.ctx, "IDs") if err := ftq.Select(fieldtype.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -207,7 +204,7 @@ func (ftq *FieldTypeQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (ftq *FieldTypeQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeFieldType, "Count") + ctx = setContextOp(ctx, ftq.ctx, "Count") if err := ftq.prepareQuery(ctx); err != nil { return 0, err } @@ -225,7 +222,7 @@ func (ftq *FieldTypeQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (ftq *FieldTypeQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeFieldType, "Exist") + ctx = setContextOp(ctx, ftq.ctx, "Exist") switch _, err := ftq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -253,15 +250,13 @@ func (ftq *FieldTypeQuery) Clone() *FieldTypeQuery { } return &FieldTypeQuery{ config: ftq.config, - limit: ftq.limit, - offset: ftq.offset, + ctx: ftq.ctx.Clone(), order: append([]OrderFunc{}, ftq.order...), inters: append([]Interceptor{}, ftq.inters...), predicates: append([]predicate.FieldType{}, ftq.predicates...), // clone intermediate query. - sql: ftq.sql.Clone(), - path: ftq.path, - unique: ftq.unique, + sql: ftq.sql.Clone(), + path: ftq.path, } } @@ -280,9 +275,9 @@ func (ftq *FieldTypeQuery) Clone() *FieldTypeQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (ftq *FieldTypeQuery) GroupBy(field string, fields ...string) *FieldTypeGroupBy { - ftq.fields = append([]string{field}, fields...) + ftq.ctx.Fields = append([]string{field}, fields...) grbuild := &FieldTypeGroupBy{build: ftq} - grbuild.flds = &ftq.fields + grbuild.flds = &ftq.ctx.Fields grbuild.label = fieldtype.Label grbuild.scan = grbuild.Scan return grbuild @@ -301,10 +296,10 @@ func (ftq *FieldTypeQuery) GroupBy(field string, fields ...string) *FieldTypeGro // Select(fieldtype.FieldInt). // Scan(ctx, &v) func (ftq *FieldTypeQuery) Select(fields ...string) *FieldTypeSelect { - ftq.fields = append(ftq.fields, fields...) + ftq.ctx.Fields = append(ftq.ctx.Fields, fields...) sbuild := &FieldTypeSelect{FieldTypeQuery: ftq} sbuild.label = fieldtype.Label - sbuild.flds, sbuild.scan = &ftq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &ftq.ctx.Fields, sbuild.Scan return sbuild } @@ -324,7 +319,7 @@ func (ftq *FieldTypeQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range ftq.fields { + for _, f := range ftq.ctx.Fields { if !fieldtype.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -376,9 +371,9 @@ func (ftq *FieldTypeQuery) sqlCount(ctx context.Context) (int, error) { if len(ftq.modifiers) > 0 { _spec.Modifiers = ftq.modifiers } - _spec.Node.Columns = ftq.fields - if len(ftq.fields) > 0 { - _spec.Unique = ftq.unique != nil && *ftq.unique + _spec.Node.Columns = ftq.ctx.Fields + if len(ftq.ctx.Fields) > 0 { + _spec.Unique = ftq.ctx.Unique != nil && *ftq.ctx.Unique } return sqlgraph.CountNodes(ctx, ftq.driver, _spec) } @@ -396,10 +391,10 @@ func (ftq *FieldTypeQuery) querySpec() *sqlgraph.QuerySpec { From: ftq.sql, Unique: true, } - if unique := ftq.unique; unique != nil { + if unique := ftq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := ftq.fields; len(fields) > 0 { + if fields := ftq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, fieldtype.FieldID) for i := range fields { @@ -415,10 +410,10 @@ func (ftq *FieldTypeQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := ftq.limit; limit != nil { + if limit := ftq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := ftq.offset; offset != nil { + if offset := ftq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := ftq.order; len(ps) > 0 { @@ -434,7 +429,7 @@ func (ftq *FieldTypeQuery) querySpec() *sqlgraph.QuerySpec { func (ftq *FieldTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(ftq.driver.Dialect()) t1 := builder.Table(fieldtype.Table) - columns := ftq.fields + columns := ftq.ctx.Fields if len(columns) == 0 { columns = fieldtype.Columns } @@ -443,7 +438,7 @@ func (ftq *FieldTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = ftq.sql selector.Select(selector.Columns(columns...)...) } - if ftq.unique != nil && *ftq.unique { + if ftq.ctx.Unique != nil && *ftq.ctx.Unique { selector.Distinct() } for _, m := range ftq.modifiers { @@ -455,12 +450,12 @@ func (ftq *FieldTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range ftq.order { p(selector) } - if offset := ftq.offset; offset != nil { + if offset := ftq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := ftq.limit; limit != nil { + if limit := ftq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -512,7 +507,7 @@ func (ftgb *FieldTypeGroupBy) Aggregate(fns ...AggregateFunc) *FieldTypeGroupBy // Scan applies the selector query and scans the result into the given value. func (ftgb *FieldTypeGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFieldType, "GroupBy") + ctx = setContextOp(ctx, ftgb.build.ctx, "GroupBy") if err := ftgb.build.prepareQuery(ctx); err != nil { return err } @@ -560,7 +555,7 @@ func (fts *FieldTypeSelect) Aggregate(fns ...AggregateFunc) *FieldTypeSelect { // Scan applies the selector query and scans the result into the given value. func (fts *FieldTypeSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFieldType, "Select") + ctx = setContextOp(ctx, fts.ctx, "Select") if err := fts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/file_query.go b/entc/integration/ent/file_query.go index 386e3dee1..33841824b 100644 --- a/entc/integration/ent/file_query.go +++ b/entc/integration/ent/file_query.go @@ -26,11 +26,8 @@ import ( // FileQuery is the builder for querying File entities. type FileQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.File withOwner *UserQuery @@ -52,20 +49,20 @@ func (fq *FileQuery) Where(ps ...predicate.File) *FileQuery { // Limit the number of records to be returned by this query. func (fq *FileQuery) Limit(limit int) *FileQuery { - fq.limit = &limit + fq.ctx.Limit = &limit return fq } // Offset to start from. func (fq *FileQuery) Offset(offset int) *FileQuery { - fq.offset = &offset + fq.ctx.Offset = &offset return fq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (fq *FileQuery) Unique(unique bool) *FileQuery { - fq.unique = &unique + fq.ctx.Unique = &unique return fq } @@ -144,7 +141,7 @@ func (fq *FileQuery) QueryField() *FieldTypeQuery { // First returns the first File entity from the query. // Returns a *NotFoundError when no File was found. func (fq *FileQuery) First(ctx context.Context) (*File, error) { - nodes, err := fq.Limit(1).All(newQueryContext(ctx, TypeFile, "First")) + nodes, err := fq.Limit(1).All(setContextOp(ctx, fq.ctx, "First")) if err != nil { return nil, err } @@ -167,7 +164,7 @@ func (fq *FileQuery) FirstX(ctx context.Context) *File { // Returns a *NotFoundError when no File ID was found. func (fq *FileQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = fq.Limit(1).IDs(newQueryContext(ctx, TypeFile, "FirstID")); err != nil { + if ids, err = fq.Limit(1).IDs(setContextOp(ctx, fq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -190,7 +187,7 @@ func (fq *FileQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one File entity is found. // Returns a *NotFoundError when no File entities are found. func (fq *FileQuery) Only(ctx context.Context) (*File, error) { - nodes, err := fq.Limit(2).All(newQueryContext(ctx, TypeFile, "Only")) + nodes, err := fq.Limit(2).All(setContextOp(ctx, fq.ctx, "Only")) if err != nil { return nil, err } @@ -218,7 +215,7 @@ func (fq *FileQuery) OnlyX(ctx context.Context) *File { // Returns a *NotFoundError when no entities are found. func (fq *FileQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = fq.Limit(2).IDs(newQueryContext(ctx, TypeFile, "OnlyID")); err != nil { + if ids, err = fq.Limit(2).IDs(setContextOp(ctx, fq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -243,7 +240,7 @@ func (fq *FileQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Files. func (fq *FileQuery) All(ctx context.Context) ([]*File, error) { - ctx = newQueryContext(ctx, TypeFile, "All") + ctx = setContextOp(ctx, fq.ctx, "All") if err := fq.prepareQuery(ctx); err != nil { return nil, err } @@ -263,7 +260,7 @@ func (fq *FileQuery) AllX(ctx context.Context) []*File { // IDs executes the query and returns a list of File IDs. func (fq *FileQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeFile, "IDs") + ctx = setContextOp(ctx, fq.ctx, "IDs") if err := fq.Select(file.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -281,7 +278,7 @@ func (fq *FileQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (fq *FileQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeFile, "Count") + ctx = setContextOp(ctx, fq.ctx, "Count") if err := fq.prepareQuery(ctx); err != nil { return 0, err } @@ -299,7 +296,7 @@ func (fq *FileQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (fq *FileQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeFile, "Exist") + ctx = setContextOp(ctx, fq.ctx, "Exist") switch _, err := fq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -327,8 +324,7 @@ func (fq *FileQuery) Clone() *FileQuery { } return &FileQuery{ config: fq.config, - limit: fq.limit, - offset: fq.offset, + ctx: fq.ctx.Clone(), order: append([]OrderFunc{}, fq.order...), inters: append([]Interceptor{}, fq.inters...), predicates: append([]predicate.File{}, fq.predicates...), @@ -336,9 +332,8 @@ func (fq *FileQuery) Clone() *FileQuery { withType: fq.withType.Clone(), withField: fq.withField.Clone(), // clone intermediate query. - sql: fq.sql.Clone(), - path: fq.path, - unique: fq.unique, + sql: fq.sql.Clone(), + path: fq.path, } } @@ -390,9 +385,9 @@ func (fq *FileQuery) WithField(opts ...func(*FieldTypeQuery)) *FileQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (fq *FileQuery) GroupBy(field string, fields ...string) *FileGroupBy { - fq.fields = append([]string{field}, fields...) + fq.ctx.Fields = append([]string{field}, fields...) grbuild := &FileGroupBy{build: fq} - grbuild.flds = &fq.fields + grbuild.flds = &fq.ctx.Fields grbuild.label = file.Label grbuild.scan = grbuild.Scan return grbuild @@ -411,10 +406,10 @@ func (fq *FileQuery) GroupBy(field string, fields ...string) *FileGroupBy { // Select(file.FieldSize). // Scan(ctx, &v) func (fq *FileQuery) Select(fields ...string) *FileSelect { - fq.fields = append(fq.fields, fields...) + fq.ctx.Fields = append(fq.ctx.Fields, fields...) sbuild := &FileSelect{FileQuery: fq} sbuild.label = file.Label - sbuild.flds, sbuild.scan = &fq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &fq.ctx.Fields, sbuild.Scan return sbuild } @@ -434,7 +429,7 @@ func (fq *FileQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range fq.fields { + for _, f := range fq.ctx.Fields { if !file.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -617,9 +612,9 @@ func (fq *FileQuery) sqlCount(ctx context.Context) (int, error) { if len(fq.modifiers) > 0 { _spec.Modifiers = fq.modifiers } - _spec.Node.Columns = fq.fields - if len(fq.fields) > 0 { - _spec.Unique = fq.unique != nil && *fq.unique + _spec.Node.Columns = fq.ctx.Fields + if len(fq.ctx.Fields) > 0 { + _spec.Unique = fq.ctx.Unique != nil && *fq.ctx.Unique } return sqlgraph.CountNodes(ctx, fq.driver, _spec) } @@ -637,10 +632,10 @@ func (fq *FileQuery) querySpec() *sqlgraph.QuerySpec { From: fq.sql, Unique: true, } - if unique := fq.unique; unique != nil { + if unique := fq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := fq.fields; len(fields) > 0 { + if fields := fq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, file.FieldID) for i := range fields { @@ -656,10 +651,10 @@ func (fq *FileQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := fq.limit; limit != nil { + if limit := fq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := fq.offset; offset != nil { + if offset := fq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := fq.order; len(ps) > 0 { @@ -675,7 +670,7 @@ func (fq *FileQuery) querySpec() *sqlgraph.QuerySpec { func (fq *FileQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(fq.driver.Dialect()) t1 := builder.Table(file.Table) - columns := fq.fields + columns := fq.ctx.Fields if len(columns) == 0 { columns = file.Columns } @@ -684,7 +679,7 @@ func (fq *FileQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = fq.sql selector.Select(selector.Columns(columns...)...) } - if fq.unique != nil && *fq.unique { + if fq.ctx.Unique != nil && *fq.ctx.Unique { selector.Distinct() } for _, m := range fq.modifiers { @@ -696,12 +691,12 @@ func (fq *FileQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range fq.order { p(selector) } - if offset := fq.offset; offset != nil { + if offset := fq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := fq.limit; limit != nil { + if limit := fq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -767,7 +762,7 @@ func (fgb *FileGroupBy) Aggregate(fns ...AggregateFunc) *FileGroupBy { // Scan applies the selector query and scans the result into the given value. func (fgb *FileGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFile, "GroupBy") + ctx = setContextOp(ctx, fgb.build.ctx, "GroupBy") if err := fgb.build.prepareQuery(ctx); err != nil { return err } @@ -815,7 +810,7 @@ func (fs *FileSelect) Aggregate(fns ...AggregateFunc) *FileSelect { // Scan applies the selector query and scans the result into the given value. func (fs *FileSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFile, "Select") + ctx = setContextOp(ctx, fs.ctx, "Select") if err := fs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/filetype_query.go b/entc/integration/ent/filetype_query.go index 4ec15bdbb..95cbdc0da 100644 --- a/entc/integration/ent/filetype_query.go +++ b/entc/integration/ent/filetype_query.go @@ -24,11 +24,8 @@ import ( // FileTypeQuery is the builder for querying FileType entities. type FileTypeQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.FileType withFiles *FileQuery @@ -47,20 +44,20 @@ func (ftq *FileTypeQuery) Where(ps ...predicate.FileType) *FileTypeQuery { // Limit the number of records to be returned by this query. func (ftq *FileTypeQuery) Limit(limit int) *FileTypeQuery { - ftq.limit = &limit + ftq.ctx.Limit = &limit return ftq } // Offset to start from. func (ftq *FileTypeQuery) Offset(offset int) *FileTypeQuery { - ftq.offset = &offset + ftq.ctx.Offset = &offset return ftq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (ftq *FileTypeQuery) Unique(unique bool) *FileTypeQuery { - ftq.unique = &unique + ftq.ctx.Unique = &unique return ftq } @@ -95,7 +92,7 @@ func (ftq *FileTypeQuery) QueryFiles() *FileQuery { // First returns the first FileType entity from the query. // Returns a *NotFoundError when no FileType was found. func (ftq *FileTypeQuery) First(ctx context.Context) (*FileType, error) { - nodes, err := ftq.Limit(1).All(newQueryContext(ctx, TypeFileType, "First")) + nodes, err := ftq.Limit(1).All(setContextOp(ctx, ftq.ctx, "First")) if err != nil { return nil, err } @@ -118,7 +115,7 @@ func (ftq *FileTypeQuery) FirstX(ctx context.Context) *FileType { // Returns a *NotFoundError when no FileType ID was found. func (ftq *FileTypeQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ftq.Limit(1).IDs(newQueryContext(ctx, TypeFileType, "FirstID")); err != nil { + if ids, err = ftq.Limit(1).IDs(setContextOp(ctx, ftq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -141,7 +138,7 @@ func (ftq *FileTypeQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one FileType entity is found. // Returns a *NotFoundError when no FileType entities are found. func (ftq *FileTypeQuery) Only(ctx context.Context) (*FileType, error) { - nodes, err := ftq.Limit(2).All(newQueryContext(ctx, TypeFileType, "Only")) + nodes, err := ftq.Limit(2).All(setContextOp(ctx, ftq.ctx, "Only")) if err != nil { return nil, err } @@ -169,7 +166,7 @@ func (ftq *FileTypeQuery) OnlyX(ctx context.Context) *FileType { // Returns a *NotFoundError when no entities are found. func (ftq *FileTypeQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ftq.Limit(2).IDs(newQueryContext(ctx, TypeFileType, "OnlyID")); err != nil { + if ids, err = ftq.Limit(2).IDs(setContextOp(ctx, ftq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -194,7 +191,7 @@ func (ftq *FileTypeQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of FileTypes. func (ftq *FileTypeQuery) All(ctx context.Context) ([]*FileType, error) { - ctx = newQueryContext(ctx, TypeFileType, "All") + ctx = setContextOp(ctx, ftq.ctx, "All") if err := ftq.prepareQuery(ctx); err != nil { return nil, err } @@ -214,7 +211,7 @@ func (ftq *FileTypeQuery) AllX(ctx context.Context) []*FileType { // IDs executes the query and returns a list of FileType IDs. func (ftq *FileTypeQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeFileType, "IDs") + ctx = setContextOp(ctx, ftq.ctx, "IDs") if err := ftq.Select(filetype.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -232,7 +229,7 @@ func (ftq *FileTypeQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (ftq *FileTypeQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeFileType, "Count") + ctx = setContextOp(ctx, ftq.ctx, "Count") if err := ftq.prepareQuery(ctx); err != nil { return 0, err } @@ -250,7 +247,7 @@ func (ftq *FileTypeQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (ftq *FileTypeQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeFileType, "Exist") + ctx = setContextOp(ctx, ftq.ctx, "Exist") switch _, err := ftq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -278,16 +275,14 @@ func (ftq *FileTypeQuery) Clone() *FileTypeQuery { } return &FileTypeQuery{ config: ftq.config, - limit: ftq.limit, - offset: ftq.offset, + ctx: ftq.ctx.Clone(), order: append([]OrderFunc{}, ftq.order...), inters: append([]Interceptor{}, ftq.inters...), predicates: append([]predicate.FileType{}, ftq.predicates...), withFiles: ftq.withFiles.Clone(), // clone intermediate query. - sql: ftq.sql.Clone(), - path: ftq.path, - unique: ftq.unique, + sql: ftq.sql.Clone(), + path: ftq.path, } } @@ -317,9 +312,9 @@ func (ftq *FileTypeQuery) WithFiles(opts ...func(*FileQuery)) *FileTypeQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (ftq *FileTypeQuery) GroupBy(field string, fields ...string) *FileTypeGroupBy { - ftq.fields = append([]string{field}, fields...) + ftq.ctx.Fields = append([]string{field}, fields...) grbuild := &FileTypeGroupBy{build: ftq} - grbuild.flds = &ftq.fields + grbuild.flds = &ftq.ctx.Fields grbuild.label = filetype.Label grbuild.scan = grbuild.Scan return grbuild @@ -338,10 +333,10 @@ func (ftq *FileTypeQuery) GroupBy(field string, fields ...string) *FileTypeGroup // Select(filetype.FieldName). // Scan(ctx, &v) func (ftq *FileTypeQuery) Select(fields ...string) *FileTypeSelect { - ftq.fields = append(ftq.fields, fields...) + ftq.ctx.Fields = append(ftq.ctx.Fields, fields...) sbuild := &FileTypeSelect{FileTypeQuery: ftq} sbuild.label = filetype.Label - sbuild.flds, sbuild.scan = &ftq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &ftq.ctx.Fields, sbuild.Scan return sbuild } @@ -361,7 +356,7 @@ func (ftq *FileTypeQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range ftq.fields { + for _, f := range ftq.ctx.Fields { if !filetype.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -459,9 +454,9 @@ func (ftq *FileTypeQuery) sqlCount(ctx context.Context) (int, error) { if len(ftq.modifiers) > 0 { _spec.Modifiers = ftq.modifiers } - _spec.Node.Columns = ftq.fields - if len(ftq.fields) > 0 { - _spec.Unique = ftq.unique != nil && *ftq.unique + _spec.Node.Columns = ftq.ctx.Fields + if len(ftq.ctx.Fields) > 0 { + _spec.Unique = ftq.ctx.Unique != nil && *ftq.ctx.Unique } return sqlgraph.CountNodes(ctx, ftq.driver, _spec) } @@ -479,10 +474,10 @@ func (ftq *FileTypeQuery) querySpec() *sqlgraph.QuerySpec { From: ftq.sql, Unique: true, } - if unique := ftq.unique; unique != nil { + if unique := ftq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := ftq.fields; len(fields) > 0 { + if fields := ftq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, filetype.FieldID) for i := range fields { @@ -498,10 +493,10 @@ func (ftq *FileTypeQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := ftq.limit; limit != nil { + if limit := ftq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := ftq.offset; offset != nil { + if offset := ftq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := ftq.order; len(ps) > 0 { @@ -517,7 +512,7 @@ func (ftq *FileTypeQuery) querySpec() *sqlgraph.QuerySpec { func (ftq *FileTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(ftq.driver.Dialect()) t1 := builder.Table(filetype.Table) - columns := ftq.fields + columns := ftq.ctx.Fields if len(columns) == 0 { columns = filetype.Columns } @@ -526,7 +521,7 @@ func (ftq *FileTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = ftq.sql selector.Select(selector.Columns(columns...)...) } - if ftq.unique != nil && *ftq.unique { + if ftq.ctx.Unique != nil && *ftq.ctx.Unique { selector.Distinct() } for _, m := range ftq.modifiers { @@ -538,12 +533,12 @@ func (ftq *FileTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range ftq.order { p(selector) } - if offset := ftq.offset; offset != nil { + if offset := ftq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := ftq.limit; limit != nil { + if limit := ftq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -609,7 +604,7 @@ func (ftgb *FileTypeGroupBy) Aggregate(fns ...AggregateFunc) *FileTypeGroupBy { // Scan applies the selector query and scans the result into the given value. func (ftgb *FileTypeGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFileType, "GroupBy") + ctx = setContextOp(ctx, ftgb.build.ctx, "GroupBy") if err := ftgb.build.prepareQuery(ctx); err != nil { return err } @@ -657,7 +652,7 @@ func (fts *FileTypeSelect) Aggregate(fns ...AggregateFunc) *FileTypeSelect { // Scan applies the selector query and scans the result into the given value. func (fts *FileTypeSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFileType, "Select") + ctx = setContextOp(ctx, fts.ctx, "Select") if err := fts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/goods_query.go b/entc/integration/ent/goods_query.go index 0795b1775..27c9c86fa 100644 --- a/entc/integration/ent/goods_query.go +++ b/entc/integration/ent/goods_query.go @@ -22,11 +22,8 @@ import ( // GoodsQuery is the builder for querying Goods entities. type GoodsQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Goods modifiers []func(*sql.Selector) @@ -43,20 +40,20 @@ func (gq *GoodsQuery) Where(ps ...predicate.Goods) *GoodsQuery { // Limit the number of records to be returned by this query. func (gq *GoodsQuery) Limit(limit int) *GoodsQuery { - gq.limit = &limit + gq.ctx.Limit = &limit return gq } // Offset to start from. func (gq *GoodsQuery) Offset(offset int) *GoodsQuery { - gq.offset = &offset + gq.ctx.Offset = &offset return gq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (gq *GoodsQuery) Unique(unique bool) *GoodsQuery { - gq.unique = &unique + gq.ctx.Unique = &unique return gq } @@ -69,7 +66,7 @@ func (gq *GoodsQuery) Order(o ...OrderFunc) *GoodsQuery { // First returns the first Goods entity from the query. // Returns a *NotFoundError when no Goods was found. func (gq *GoodsQuery) First(ctx context.Context) (*Goods, error) { - nodes, err := gq.Limit(1).All(newQueryContext(ctx, TypeGoods, "First")) + nodes, err := gq.Limit(1).All(setContextOp(ctx, gq.ctx, "First")) if err != nil { return nil, err } @@ -92,7 +89,7 @@ func (gq *GoodsQuery) FirstX(ctx context.Context) *Goods { // Returns a *NotFoundError when no Goods ID was found. func (gq *GoodsQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(1).IDs(newQueryContext(ctx, TypeGoods, "FirstID")); err != nil { + if ids, err = gq.Limit(1).IDs(setContextOp(ctx, gq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -115,7 +112,7 @@ func (gq *GoodsQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Goods entity is found. // Returns a *NotFoundError when no Goods entities are found. func (gq *GoodsQuery) Only(ctx context.Context) (*Goods, error) { - nodes, err := gq.Limit(2).All(newQueryContext(ctx, TypeGoods, "Only")) + nodes, err := gq.Limit(2).All(setContextOp(ctx, gq.ctx, "Only")) if err != nil { return nil, err } @@ -143,7 +140,7 @@ func (gq *GoodsQuery) OnlyX(ctx context.Context) *Goods { // Returns a *NotFoundError when no entities are found. func (gq *GoodsQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(2).IDs(newQueryContext(ctx, TypeGoods, "OnlyID")); err != nil { + if ids, err = gq.Limit(2).IDs(setContextOp(ctx, gq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -168,7 +165,7 @@ func (gq *GoodsQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of GoodsSlice. func (gq *GoodsQuery) All(ctx context.Context) ([]*Goods, error) { - ctx = newQueryContext(ctx, TypeGoods, "All") + ctx = setContextOp(ctx, gq.ctx, "All") if err := gq.prepareQuery(ctx); err != nil { return nil, err } @@ -188,7 +185,7 @@ func (gq *GoodsQuery) AllX(ctx context.Context) []*Goods { // IDs executes the query and returns a list of Goods IDs. func (gq *GoodsQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeGoods, "IDs") + ctx = setContextOp(ctx, gq.ctx, "IDs") if err := gq.Select(goods.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -206,7 +203,7 @@ func (gq *GoodsQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (gq *GoodsQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGoods, "Count") + ctx = setContextOp(ctx, gq.ctx, "Count") if err := gq.prepareQuery(ctx); err != nil { return 0, err } @@ -224,7 +221,7 @@ func (gq *GoodsQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (gq *GoodsQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGoods, "Exist") + ctx = setContextOp(ctx, gq.ctx, "Exist") switch _, err := gq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -252,24 +249,22 @@ func (gq *GoodsQuery) Clone() *GoodsQuery { } return &GoodsQuery{ config: gq.config, - limit: gq.limit, - offset: gq.offset, + ctx: gq.ctx.Clone(), order: append([]OrderFunc{}, gq.order...), inters: append([]Interceptor{}, gq.inters...), predicates: append([]predicate.Goods{}, gq.predicates...), // clone intermediate query. - sql: gq.sql.Clone(), - path: gq.path, - unique: gq.unique, + sql: gq.sql.Clone(), + path: gq.path, } } // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (gq *GoodsQuery) GroupBy(field string, fields ...string) *GoodsGroupBy { - gq.fields = append([]string{field}, fields...) + gq.ctx.Fields = append([]string{field}, fields...) grbuild := &GoodsGroupBy{build: gq} - grbuild.flds = &gq.fields + grbuild.flds = &gq.ctx.Fields grbuild.label = goods.Label grbuild.scan = grbuild.Scan return grbuild @@ -278,10 +273,10 @@ func (gq *GoodsQuery) GroupBy(field string, fields ...string) *GoodsGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (gq *GoodsQuery) Select(fields ...string) *GoodsSelect { - gq.fields = append(gq.fields, fields...) + gq.ctx.Fields = append(gq.ctx.Fields, fields...) sbuild := &GoodsSelect{GoodsQuery: gq} sbuild.label = goods.Label - sbuild.flds, sbuild.scan = &gq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &gq.ctx.Fields, sbuild.Scan return sbuild } @@ -301,7 +296,7 @@ func (gq *GoodsQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range gq.fields { + for _, f := range gq.ctx.Fields { if !goods.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -349,9 +344,9 @@ func (gq *GoodsQuery) sqlCount(ctx context.Context) (int, error) { if len(gq.modifiers) > 0 { _spec.Modifiers = gq.modifiers } - _spec.Node.Columns = gq.fields - if len(gq.fields) > 0 { - _spec.Unique = gq.unique != nil && *gq.unique + _spec.Node.Columns = gq.ctx.Fields + if len(gq.ctx.Fields) > 0 { + _spec.Unique = gq.ctx.Unique != nil && *gq.ctx.Unique } return sqlgraph.CountNodes(ctx, gq.driver, _spec) } @@ -369,10 +364,10 @@ func (gq *GoodsQuery) querySpec() *sqlgraph.QuerySpec { From: gq.sql, Unique: true, } - if unique := gq.unique; unique != nil { + if unique := gq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := gq.fields; len(fields) > 0 { + if fields := gq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, goods.FieldID) for i := range fields { @@ -388,10 +383,10 @@ func (gq *GoodsQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := gq.order; len(ps) > 0 { @@ -407,7 +402,7 @@ func (gq *GoodsQuery) querySpec() *sqlgraph.QuerySpec { func (gq *GoodsQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(gq.driver.Dialect()) t1 := builder.Table(goods.Table) - columns := gq.fields + columns := gq.ctx.Fields if len(columns) == 0 { columns = goods.Columns } @@ -416,7 +411,7 @@ func (gq *GoodsQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = gq.sql selector.Select(selector.Columns(columns...)...) } - if gq.unique != nil && *gq.unique { + if gq.ctx.Unique != nil && *gq.ctx.Unique { selector.Distinct() } for _, m := range gq.modifiers { @@ -428,12 +423,12 @@ func (gq *GoodsQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range gq.order { p(selector) } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -485,7 +480,7 @@ func (ggb *GoodsGroupBy) Aggregate(fns ...AggregateFunc) *GoodsGroupBy { // Scan applies the selector query and scans the result into the given value. func (ggb *GoodsGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGoods, "GroupBy") + ctx = setContextOp(ctx, ggb.build.ctx, "GroupBy") if err := ggb.build.prepareQuery(ctx); err != nil { return err } @@ -533,7 +528,7 @@ func (gs *GoodsSelect) Aggregate(fns ...AggregateFunc) *GoodsSelect { // Scan applies the selector query and scans the result into the given value. func (gs *GoodsSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGoods, "Select") + ctx = setContextOp(ctx, gs.ctx, "Select") if err := gs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/group_query.go b/entc/integration/ent/group_query.go index b2ab85b3a..022519999 100644 --- a/entc/integration/ent/group_query.go +++ b/entc/integration/ent/group_query.go @@ -26,11 +26,8 @@ import ( // GroupQuery is the builder for querying Group entities. type GroupQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Group withFiles *FileQuery @@ -55,20 +52,20 @@ func (gq *GroupQuery) Where(ps ...predicate.Group) *GroupQuery { // Limit the number of records to be returned by this query. func (gq *GroupQuery) Limit(limit int) *GroupQuery { - gq.limit = &limit + gq.ctx.Limit = &limit return gq } // Offset to start from. func (gq *GroupQuery) Offset(offset int) *GroupQuery { - gq.offset = &offset + gq.ctx.Offset = &offset return gq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (gq *GroupQuery) Unique(unique bool) *GroupQuery { - gq.unique = &unique + gq.ctx.Unique = &unique return gq } @@ -169,7 +166,7 @@ func (gq *GroupQuery) QueryInfo() *GroupInfoQuery { // First returns the first Group entity from the query. // Returns a *NotFoundError when no Group was found. func (gq *GroupQuery) First(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(1).All(newQueryContext(ctx, TypeGroup, "First")) + nodes, err := gq.Limit(1).All(setContextOp(ctx, gq.ctx, "First")) if err != nil { return nil, err } @@ -192,7 +189,7 @@ func (gq *GroupQuery) FirstX(ctx context.Context) *Group { // Returns a *NotFoundError when no Group ID was found. func (gq *GroupQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(1).IDs(newQueryContext(ctx, TypeGroup, "FirstID")); err != nil { + if ids, err = gq.Limit(1).IDs(setContextOp(ctx, gq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -215,7 +212,7 @@ func (gq *GroupQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Group entity is found. // Returns a *NotFoundError when no Group entities are found. func (gq *GroupQuery) Only(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(2).All(newQueryContext(ctx, TypeGroup, "Only")) + nodes, err := gq.Limit(2).All(setContextOp(ctx, gq.ctx, "Only")) if err != nil { return nil, err } @@ -243,7 +240,7 @@ func (gq *GroupQuery) OnlyX(ctx context.Context) *Group { // Returns a *NotFoundError when no entities are found. func (gq *GroupQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(2).IDs(newQueryContext(ctx, TypeGroup, "OnlyID")); err != nil { + if ids, err = gq.Limit(2).IDs(setContextOp(ctx, gq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -268,7 +265,7 @@ func (gq *GroupQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Groups. func (gq *GroupQuery) All(ctx context.Context) ([]*Group, error) { - ctx = newQueryContext(ctx, TypeGroup, "All") + ctx = setContextOp(ctx, gq.ctx, "All") if err := gq.prepareQuery(ctx); err != nil { return nil, err } @@ -288,7 +285,7 @@ func (gq *GroupQuery) AllX(ctx context.Context) []*Group { // IDs executes the query and returns a list of Group IDs. func (gq *GroupQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeGroup, "IDs") + ctx = setContextOp(ctx, gq.ctx, "IDs") if err := gq.Select(group.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -306,7 +303,7 @@ func (gq *GroupQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (gq *GroupQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGroup, "Count") + ctx = setContextOp(ctx, gq.ctx, "Count") if err := gq.prepareQuery(ctx); err != nil { return 0, err } @@ -324,7 +321,7 @@ func (gq *GroupQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (gq *GroupQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGroup, "Exist") + ctx = setContextOp(ctx, gq.ctx, "Exist") switch _, err := gq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -352,8 +349,7 @@ func (gq *GroupQuery) Clone() *GroupQuery { } return &GroupQuery{ config: gq.config, - limit: gq.limit, - offset: gq.offset, + ctx: gq.ctx.Clone(), order: append([]OrderFunc{}, gq.order...), inters: append([]Interceptor{}, gq.inters...), predicates: append([]predicate.Group{}, gq.predicates...), @@ -362,9 +358,8 @@ func (gq *GroupQuery) Clone() *GroupQuery { withUsers: gq.withUsers.Clone(), withInfo: gq.withInfo.Clone(), // clone intermediate query. - sql: gq.sql.Clone(), - path: gq.path, - unique: gq.unique, + sql: gq.sql.Clone(), + path: gq.path, } } @@ -427,9 +422,9 @@ func (gq *GroupQuery) WithInfo(opts ...func(*GroupInfoQuery)) *GroupQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { - gq.fields = append([]string{field}, fields...) + gq.ctx.Fields = append([]string{field}, fields...) grbuild := &GroupGroupBy{build: gq} - grbuild.flds = &gq.fields + grbuild.flds = &gq.ctx.Fields grbuild.label = group.Label grbuild.scan = grbuild.Scan return grbuild @@ -448,10 +443,10 @@ func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { // Select(group.FieldActive). // Scan(ctx, &v) func (gq *GroupQuery) Select(fields ...string) *GroupSelect { - gq.fields = append(gq.fields, fields...) + gq.ctx.Fields = append(gq.ctx.Fields, fields...) sbuild := &GroupSelect{GroupQuery: gq} sbuild.label = group.Label - sbuild.flds, sbuild.scan = &gq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &gq.ctx.Fields, sbuild.Scan return sbuild } @@ -471,7 +466,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range gq.fields { + for _, f := range gq.ctx.Fields { if !group.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -734,9 +729,9 @@ func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { if len(gq.modifiers) > 0 { _spec.Modifiers = gq.modifiers } - _spec.Node.Columns = gq.fields - if len(gq.fields) > 0 { - _spec.Unique = gq.unique != nil && *gq.unique + _spec.Node.Columns = gq.ctx.Fields + if len(gq.ctx.Fields) > 0 { + _spec.Unique = gq.ctx.Unique != nil && *gq.ctx.Unique } return sqlgraph.CountNodes(ctx, gq.driver, _spec) } @@ -754,10 +749,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { From: gq.sql, Unique: true, } - if unique := gq.unique; unique != nil { + if unique := gq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := gq.fields; len(fields) > 0 { + if fields := gq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, group.FieldID) for i := range fields { @@ -773,10 +768,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := gq.order; len(ps) > 0 { @@ -792,7 +787,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(gq.driver.Dialect()) t1 := builder.Table(group.Table) - columns := gq.fields + columns := gq.ctx.Fields if len(columns) == 0 { columns = group.Columns } @@ -801,7 +796,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = gq.sql selector.Select(selector.Columns(columns...)...) } - if gq.unique != nil && *gq.unique { + if gq.ctx.Unique != nil && *gq.ctx.Unique { selector.Distinct() } for _, m := range gq.modifiers { @@ -813,12 +808,12 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range gq.order { p(selector) } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -912,7 +907,7 @@ func (ggb *GroupGroupBy) Aggregate(fns ...AggregateFunc) *GroupGroupBy { // Scan applies the selector query and scans the result into the given value. func (ggb *GroupGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "GroupBy") + ctx = setContextOp(ctx, ggb.build.ctx, "GroupBy") if err := ggb.build.prepareQuery(ctx); err != nil { return err } @@ -960,7 +955,7 @@ func (gs *GroupSelect) Aggregate(fns ...AggregateFunc) *GroupSelect { // Scan applies the selector query and scans the result into the given value. func (gs *GroupSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "Select") + ctx = setContextOp(ctx, gs.ctx, "Select") if err := gs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/groupinfo_query.go b/entc/integration/ent/groupinfo_query.go index 3587c22d1..f917e6a0b 100644 --- a/entc/integration/ent/groupinfo_query.go +++ b/entc/integration/ent/groupinfo_query.go @@ -24,11 +24,8 @@ import ( // GroupInfoQuery is the builder for querying GroupInfo entities. type GroupInfoQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.GroupInfo withGroups *GroupQuery @@ -47,20 +44,20 @@ func (giq *GroupInfoQuery) Where(ps ...predicate.GroupInfo) *GroupInfoQuery { // Limit the number of records to be returned by this query. func (giq *GroupInfoQuery) Limit(limit int) *GroupInfoQuery { - giq.limit = &limit + giq.ctx.Limit = &limit return giq } // Offset to start from. func (giq *GroupInfoQuery) Offset(offset int) *GroupInfoQuery { - giq.offset = &offset + giq.ctx.Offset = &offset return giq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (giq *GroupInfoQuery) Unique(unique bool) *GroupInfoQuery { - giq.unique = &unique + giq.ctx.Unique = &unique return giq } @@ -95,7 +92,7 @@ func (giq *GroupInfoQuery) QueryGroups() *GroupQuery { // First returns the first GroupInfo entity from the query. // Returns a *NotFoundError when no GroupInfo was found. func (giq *GroupInfoQuery) First(ctx context.Context) (*GroupInfo, error) { - nodes, err := giq.Limit(1).All(newQueryContext(ctx, TypeGroupInfo, "First")) + nodes, err := giq.Limit(1).All(setContextOp(ctx, giq.ctx, "First")) if err != nil { return nil, err } @@ -118,7 +115,7 @@ func (giq *GroupInfoQuery) FirstX(ctx context.Context) *GroupInfo { // Returns a *NotFoundError when no GroupInfo ID was found. func (giq *GroupInfoQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = giq.Limit(1).IDs(newQueryContext(ctx, TypeGroupInfo, "FirstID")); err != nil { + if ids, err = giq.Limit(1).IDs(setContextOp(ctx, giq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -141,7 +138,7 @@ func (giq *GroupInfoQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one GroupInfo entity is found. // Returns a *NotFoundError when no GroupInfo entities are found. func (giq *GroupInfoQuery) Only(ctx context.Context) (*GroupInfo, error) { - nodes, err := giq.Limit(2).All(newQueryContext(ctx, TypeGroupInfo, "Only")) + nodes, err := giq.Limit(2).All(setContextOp(ctx, giq.ctx, "Only")) if err != nil { return nil, err } @@ -169,7 +166,7 @@ func (giq *GroupInfoQuery) OnlyX(ctx context.Context) *GroupInfo { // Returns a *NotFoundError when no entities are found. func (giq *GroupInfoQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = giq.Limit(2).IDs(newQueryContext(ctx, TypeGroupInfo, "OnlyID")); err != nil { + if ids, err = giq.Limit(2).IDs(setContextOp(ctx, giq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -194,7 +191,7 @@ func (giq *GroupInfoQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of GroupInfos. func (giq *GroupInfoQuery) All(ctx context.Context) ([]*GroupInfo, error) { - ctx = newQueryContext(ctx, TypeGroupInfo, "All") + ctx = setContextOp(ctx, giq.ctx, "All") if err := giq.prepareQuery(ctx); err != nil { return nil, err } @@ -214,7 +211,7 @@ func (giq *GroupInfoQuery) AllX(ctx context.Context) []*GroupInfo { // IDs executes the query and returns a list of GroupInfo IDs. func (giq *GroupInfoQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeGroupInfo, "IDs") + ctx = setContextOp(ctx, giq.ctx, "IDs") if err := giq.Select(groupinfo.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -232,7 +229,7 @@ func (giq *GroupInfoQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (giq *GroupInfoQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGroupInfo, "Count") + ctx = setContextOp(ctx, giq.ctx, "Count") if err := giq.prepareQuery(ctx); err != nil { return 0, err } @@ -250,7 +247,7 @@ func (giq *GroupInfoQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (giq *GroupInfoQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGroupInfo, "Exist") + ctx = setContextOp(ctx, giq.ctx, "Exist") switch _, err := giq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -278,16 +275,14 @@ func (giq *GroupInfoQuery) Clone() *GroupInfoQuery { } return &GroupInfoQuery{ config: giq.config, - limit: giq.limit, - offset: giq.offset, + ctx: giq.ctx.Clone(), order: append([]OrderFunc{}, giq.order...), inters: append([]Interceptor{}, giq.inters...), predicates: append([]predicate.GroupInfo{}, giq.predicates...), withGroups: giq.withGroups.Clone(), // clone intermediate query. - sql: giq.sql.Clone(), - path: giq.path, - unique: giq.unique, + sql: giq.sql.Clone(), + path: giq.path, } } @@ -317,9 +312,9 @@ func (giq *GroupInfoQuery) WithGroups(opts ...func(*GroupQuery)) *GroupInfoQuery // Aggregate(ent.Count()). // Scan(ctx, &v) func (giq *GroupInfoQuery) GroupBy(field string, fields ...string) *GroupInfoGroupBy { - giq.fields = append([]string{field}, fields...) + giq.ctx.Fields = append([]string{field}, fields...) grbuild := &GroupInfoGroupBy{build: giq} - grbuild.flds = &giq.fields + grbuild.flds = &giq.ctx.Fields grbuild.label = groupinfo.Label grbuild.scan = grbuild.Scan return grbuild @@ -338,10 +333,10 @@ func (giq *GroupInfoQuery) GroupBy(field string, fields ...string) *GroupInfoGro // Select(groupinfo.FieldDesc). // Scan(ctx, &v) func (giq *GroupInfoQuery) Select(fields ...string) *GroupInfoSelect { - giq.fields = append(giq.fields, fields...) + giq.ctx.Fields = append(giq.ctx.Fields, fields...) sbuild := &GroupInfoSelect{GroupInfoQuery: giq} sbuild.label = groupinfo.Label - sbuild.flds, sbuild.scan = &giq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &giq.ctx.Fields, sbuild.Scan return sbuild } @@ -361,7 +356,7 @@ func (giq *GroupInfoQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range giq.fields { + for _, f := range giq.ctx.Fields { if !groupinfo.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -459,9 +454,9 @@ func (giq *GroupInfoQuery) sqlCount(ctx context.Context) (int, error) { if len(giq.modifiers) > 0 { _spec.Modifiers = giq.modifiers } - _spec.Node.Columns = giq.fields - if len(giq.fields) > 0 { - _spec.Unique = giq.unique != nil && *giq.unique + _spec.Node.Columns = giq.ctx.Fields + if len(giq.ctx.Fields) > 0 { + _spec.Unique = giq.ctx.Unique != nil && *giq.ctx.Unique } return sqlgraph.CountNodes(ctx, giq.driver, _spec) } @@ -479,10 +474,10 @@ func (giq *GroupInfoQuery) querySpec() *sqlgraph.QuerySpec { From: giq.sql, Unique: true, } - if unique := giq.unique; unique != nil { + if unique := giq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := giq.fields; len(fields) > 0 { + if fields := giq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, groupinfo.FieldID) for i := range fields { @@ -498,10 +493,10 @@ func (giq *GroupInfoQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := giq.limit; limit != nil { + if limit := giq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := giq.offset; offset != nil { + if offset := giq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := giq.order; len(ps) > 0 { @@ -517,7 +512,7 @@ func (giq *GroupInfoQuery) querySpec() *sqlgraph.QuerySpec { func (giq *GroupInfoQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(giq.driver.Dialect()) t1 := builder.Table(groupinfo.Table) - columns := giq.fields + columns := giq.ctx.Fields if len(columns) == 0 { columns = groupinfo.Columns } @@ -526,7 +521,7 @@ func (giq *GroupInfoQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = giq.sql selector.Select(selector.Columns(columns...)...) } - if giq.unique != nil && *giq.unique { + if giq.ctx.Unique != nil && *giq.ctx.Unique { selector.Distinct() } for _, m := range giq.modifiers { @@ -538,12 +533,12 @@ func (giq *GroupInfoQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range giq.order { p(selector) } - if offset := giq.offset; offset != nil { + if offset := giq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := giq.limit; limit != nil { + if limit := giq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -609,7 +604,7 @@ func (gigb *GroupInfoGroupBy) Aggregate(fns ...AggregateFunc) *GroupInfoGroupBy // Scan applies the selector query and scans the result into the given value. func (gigb *GroupInfoGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroupInfo, "GroupBy") + ctx = setContextOp(ctx, gigb.build.ctx, "GroupBy") if err := gigb.build.prepareQuery(ctx); err != nil { return err } @@ -657,7 +652,7 @@ func (gis *GroupInfoSelect) Aggregate(fns ...AggregateFunc) *GroupInfoSelect { // Scan applies the selector query and scans the result into the given value. func (gis *GroupInfoSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroupInfo, "Select") + ctx = setContextOp(ctx, gis.ctx, "Select") if err := gis.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/item_query.go b/entc/integration/ent/item_query.go index 0424f5f00..7335e6ce6 100644 --- a/entc/integration/ent/item_query.go +++ b/entc/integration/ent/item_query.go @@ -22,11 +22,8 @@ import ( // ItemQuery is the builder for querying Item entities. type ItemQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Item modifiers []func(*sql.Selector) @@ -43,20 +40,20 @@ func (iq *ItemQuery) Where(ps ...predicate.Item) *ItemQuery { // Limit the number of records to be returned by this query. func (iq *ItemQuery) Limit(limit int) *ItemQuery { - iq.limit = &limit + iq.ctx.Limit = &limit return iq } // Offset to start from. func (iq *ItemQuery) Offset(offset int) *ItemQuery { - iq.offset = &offset + iq.ctx.Offset = &offset return iq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (iq *ItemQuery) Unique(unique bool) *ItemQuery { - iq.unique = &unique + iq.ctx.Unique = &unique return iq } @@ -69,7 +66,7 @@ func (iq *ItemQuery) Order(o ...OrderFunc) *ItemQuery { // First returns the first Item entity from the query. // Returns a *NotFoundError when no Item was found. func (iq *ItemQuery) First(ctx context.Context) (*Item, error) { - nodes, err := iq.Limit(1).All(newQueryContext(ctx, TypeItem, "First")) + nodes, err := iq.Limit(1).All(setContextOp(ctx, iq.ctx, "First")) if err != nil { return nil, err } @@ -92,7 +89,7 @@ func (iq *ItemQuery) FirstX(ctx context.Context) *Item { // Returns a *NotFoundError when no Item ID was found. func (iq *ItemQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = iq.Limit(1).IDs(newQueryContext(ctx, TypeItem, "FirstID")); err != nil { + if ids, err = iq.Limit(1).IDs(setContextOp(ctx, iq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -115,7 +112,7 @@ func (iq *ItemQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Item entity is found. // Returns a *NotFoundError when no Item entities are found. func (iq *ItemQuery) Only(ctx context.Context) (*Item, error) { - nodes, err := iq.Limit(2).All(newQueryContext(ctx, TypeItem, "Only")) + nodes, err := iq.Limit(2).All(setContextOp(ctx, iq.ctx, "Only")) if err != nil { return nil, err } @@ -143,7 +140,7 @@ func (iq *ItemQuery) OnlyX(ctx context.Context) *Item { // Returns a *NotFoundError when no entities are found. func (iq *ItemQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = iq.Limit(2).IDs(newQueryContext(ctx, TypeItem, "OnlyID")); err != nil { + if ids, err = iq.Limit(2).IDs(setContextOp(ctx, iq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -168,7 +165,7 @@ func (iq *ItemQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Items. func (iq *ItemQuery) All(ctx context.Context) ([]*Item, error) { - ctx = newQueryContext(ctx, TypeItem, "All") + ctx = setContextOp(ctx, iq.ctx, "All") if err := iq.prepareQuery(ctx); err != nil { return nil, err } @@ -188,7 +185,7 @@ func (iq *ItemQuery) AllX(ctx context.Context) []*Item { // IDs executes the query and returns a list of Item IDs. func (iq *ItemQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeItem, "IDs") + ctx = setContextOp(ctx, iq.ctx, "IDs") if err := iq.Select(item.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -206,7 +203,7 @@ func (iq *ItemQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (iq *ItemQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeItem, "Count") + ctx = setContextOp(ctx, iq.ctx, "Count") if err := iq.prepareQuery(ctx); err != nil { return 0, err } @@ -224,7 +221,7 @@ func (iq *ItemQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (iq *ItemQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeItem, "Exist") + ctx = setContextOp(ctx, iq.ctx, "Exist") switch _, err := iq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -252,15 +249,13 @@ func (iq *ItemQuery) Clone() *ItemQuery { } return &ItemQuery{ config: iq.config, - limit: iq.limit, - offset: iq.offset, + ctx: iq.ctx.Clone(), order: append([]OrderFunc{}, iq.order...), inters: append([]Interceptor{}, iq.inters...), predicates: append([]predicate.Item{}, iq.predicates...), // clone intermediate query. - sql: iq.sql.Clone(), - path: iq.path, - unique: iq.unique, + sql: iq.sql.Clone(), + path: iq.path, } } @@ -279,9 +274,9 @@ func (iq *ItemQuery) Clone() *ItemQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (iq *ItemQuery) GroupBy(field string, fields ...string) *ItemGroupBy { - iq.fields = append([]string{field}, fields...) + iq.ctx.Fields = append([]string{field}, fields...) grbuild := &ItemGroupBy{build: iq} - grbuild.flds = &iq.fields + grbuild.flds = &iq.ctx.Fields grbuild.label = item.Label grbuild.scan = grbuild.Scan return grbuild @@ -300,10 +295,10 @@ func (iq *ItemQuery) GroupBy(field string, fields ...string) *ItemGroupBy { // Select(item.FieldText). // Scan(ctx, &v) func (iq *ItemQuery) Select(fields ...string) *ItemSelect { - iq.fields = append(iq.fields, fields...) + iq.ctx.Fields = append(iq.ctx.Fields, fields...) sbuild := &ItemSelect{ItemQuery: iq} sbuild.label = item.Label - sbuild.flds, sbuild.scan = &iq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &iq.ctx.Fields, sbuild.Scan return sbuild } @@ -323,7 +318,7 @@ func (iq *ItemQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range iq.fields { + for _, f := range iq.ctx.Fields { if !item.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -371,9 +366,9 @@ func (iq *ItemQuery) sqlCount(ctx context.Context) (int, error) { if len(iq.modifiers) > 0 { _spec.Modifiers = iq.modifiers } - _spec.Node.Columns = iq.fields - if len(iq.fields) > 0 { - _spec.Unique = iq.unique != nil && *iq.unique + _spec.Node.Columns = iq.ctx.Fields + if len(iq.ctx.Fields) > 0 { + _spec.Unique = iq.ctx.Unique != nil && *iq.ctx.Unique } return sqlgraph.CountNodes(ctx, iq.driver, _spec) } @@ -391,10 +386,10 @@ func (iq *ItemQuery) querySpec() *sqlgraph.QuerySpec { From: iq.sql, Unique: true, } - if unique := iq.unique; unique != nil { + if unique := iq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := iq.fields; len(fields) > 0 { + if fields := iq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, item.FieldID) for i := range fields { @@ -410,10 +405,10 @@ func (iq *ItemQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := iq.limit; limit != nil { + if limit := iq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := iq.offset; offset != nil { + if offset := iq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := iq.order; len(ps) > 0 { @@ -429,7 +424,7 @@ func (iq *ItemQuery) querySpec() *sqlgraph.QuerySpec { func (iq *ItemQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(iq.driver.Dialect()) t1 := builder.Table(item.Table) - columns := iq.fields + columns := iq.ctx.Fields if len(columns) == 0 { columns = item.Columns } @@ -438,7 +433,7 @@ func (iq *ItemQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = iq.sql selector.Select(selector.Columns(columns...)...) } - if iq.unique != nil && *iq.unique { + if iq.ctx.Unique != nil && *iq.ctx.Unique { selector.Distinct() } for _, m := range iq.modifiers { @@ -450,12 +445,12 @@ func (iq *ItemQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range iq.order { p(selector) } - if offset := iq.offset; offset != nil { + if offset := iq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := iq.limit; limit != nil { + if limit := iq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -507,7 +502,7 @@ func (igb *ItemGroupBy) Aggregate(fns ...AggregateFunc) *ItemGroupBy { // Scan applies the selector query and scans the result into the given value. func (igb *ItemGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeItem, "GroupBy") + ctx = setContextOp(ctx, igb.build.ctx, "GroupBy") if err := igb.build.prepareQuery(ctx); err != nil { return err } @@ -555,7 +550,7 @@ func (is *ItemSelect) Aggregate(fns ...AggregateFunc) *ItemSelect { // Scan applies the selector query and scans the result into the given value. func (is *ItemSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeItem, "Select") + ctx = setContextOp(ctx, is.ctx, "Select") if err := is.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/license_query.go b/entc/integration/ent/license_query.go index 13e1f2a65..949a77345 100644 --- a/entc/integration/ent/license_query.go +++ b/entc/integration/ent/license_query.go @@ -22,11 +22,8 @@ import ( // LicenseQuery is the builder for querying License entities. type LicenseQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.License modifiers []func(*sql.Selector) @@ -43,20 +40,20 @@ func (lq *LicenseQuery) Where(ps ...predicate.License) *LicenseQuery { // Limit the number of records to be returned by this query. func (lq *LicenseQuery) Limit(limit int) *LicenseQuery { - lq.limit = &limit + lq.ctx.Limit = &limit return lq } // Offset to start from. func (lq *LicenseQuery) Offset(offset int) *LicenseQuery { - lq.offset = &offset + lq.ctx.Offset = &offset return lq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (lq *LicenseQuery) Unique(unique bool) *LicenseQuery { - lq.unique = &unique + lq.ctx.Unique = &unique return lq } @@ -69,7 +66,7 @@ func (lq *LicenseQuery) Order(o ...OrderFunc) *LicenseQuery { // First returns the first License entity from the query. // Returns a *NotFoundError when no License was found. func (lq *LicenseQuery) First(ctx context.Context) (*License, error) { - nodes, err := lq.Limit(1).All(newQueryContext(ctx, TypeLicense, "First")) + nodes, err := lq.Limit(1).All(setContextOp(ctx, lq.ctx, "First")) if err != nil { return nil, err } @@ -92,7 +89,7 @@ func (lq *LicenseQuery) FirstX(ctx context.Context) *License { // Returns a *NotFoundError when no License ID was found. func (lq *LicenseQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = lq.Limit(1).IDs(newQueryContext(ctx, TypeLicense, "FirstID")); err != nil { + if ids, err = lq.Limit(1).IDs(setContextOp(ctx, lq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -115,7 +112,7 @@ func (lq *LicenseQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one License entity is found. // Returns a *NotFoundError when no License entities are found. func (lq *LicenseQuery) Only(ctx context.Context) (*License, error) { - nodes, err := lq.Limit(2).All(newQueryContext(ctx, TypeLicense, "Only")) + nodes, err := lq.Limit(2).All(setContextOp(ctx, lq.ctx, "Only")) if err != nil { return nil, err } @@ -143,7 +140,7 @@ func (lq *LicenseQuery) OnlyX(ctx context.Context) *License { // Returns a *NotFoundError when no entities are found. func (lq *LicenseQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = lq.Limit(2).IDs(newQueryContext(ctx, TypeLicense, "OnlyID")); err != nil { + if ids, err = lq.Limit(2).IDs(setContextOp(ctx, lq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -168,7 +165,7 @@ func (lq *LicenseQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Licenses. func (lq *LicenseQuery) All(ctx context.Context) ([]*License, error) { - ctx = newQueryContext(ctx, TypeLicense, "All") + ctx = setContextOp(ctx, lq.ctx, "All") if err := lq.prepareQuery(ctx); err != nil { return nil, err } @@ -188,7 +185,7 @@ func (lq *LicenseQuery) AllX(ctx context.Context) []*License { // IDs executes the query and returns a list of License IDs. func (lq *LicenseQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeLicense, "IDs") + ctx = setContextOp(ctx, lq.ctx, "IDs") if err := lq.Select(license.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -206,7 +203,7 @@ func (lq *LicenseQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (lq *LicenseQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeLicense, "Count") + ctx = setContextOp(ctx, lq.ctx, "Count") if err := lq.prepareQuery(ctx); err != nil { return 0, err } @@ -224,7 +221,7 @@ func (lq *LicenseQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (lq *LicenseQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeLicense, "Exist") + ctx = setContextOp(ctx, lq.ctx, "Exist") switch _, err := lq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -252,15 +249,13 @@ func (lq *LicenseQuery) Clone() *LicenseQuery { } return &LicenseQuery{ config: lq.config, - limit: lq.limit, - offset: lq.offset, + ctx: lq.ctx.Clone(), order: append([]OrderFunc{}, lq.order...), inters: append([]Interceptor{}, lq.inters...), predicates: append([]predicate.License{}, lq.predicates...), // clone intermediate query. - sql: lq.sql.Clone(), - path: lq.path, - unique: lq.unique, + sql: lq.sql.Clone(), + path: lq.path, } } @@ -279,9 +274,9 @@ func (lq *LicenseQuery) Clone() *LicenseQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (lq *LicenseQuery) GroupBy(field string, fields ...string) *LicenseGroupBy { - lq.fields = append([]string{field}, fields...) + lq.ctx.Fields = append([]string{field}, fields...) grbuild := &LicenseGroupBy{build: lq} - grbuild.flds = &lq.fields + grbuild.flds = &lq.ctx.Fields grbuild.label = license.Label grbuild.scan = grbuild.Scan return grbuild @@ -300,10 +295,10 @@ func (lq *LicenseQuery) GroupBy(field string, fields ...string) *LicenseGroupBy // Select(license.FieldCreateTime). // Scan(ctx, &v) func (lq *LicenseQuery) Select(fields ...string) *LicenseSelect { - lq.fields = append(lq.fields, fields...) + lq.ctx.Fields = append(lq.ctx.Fields, fields...) sbuild := &LicenseSelect{LicenseQuery: lq} sbuild.label = license.Label - sbuild.flds, sbuild.scan = &lq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &lq.ctx.Fields, sbuild.Scan return sbuild } @@ -323,7 +318,7 @@ func (lq *LicenseQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range lq.fields { + for _, f := range lq.ctx.Fields { if !license.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -371,9 +366,9 @@ func (lq *LicenseQuery) sqlCount(ctx context.Context) (int, error) { if len(lq.modifiers) > 0 { _spec.Modifiers = lq.modifiers } - _spec.Node.Columns = lq.fields - if len(lq.fields) > 0 { - _spec.Unique = lq.unique != nil && *lq.unique + _spec.Node.Columns = lq.ctx.Fields + if len(lq.ctx.Fields) > 0 { + _spec.Unique = lq.ctx.Unique != nil && *lq.ctx.Unique } return sqlgraph.CountNodes(ctx, lq.driver, _spec) } @@ -391,10 +386,10 @@ func (lq *LicenseQuery) querySpec() *sqlgraph.QuerySpec { From: lq.sql, Unique: true, } - if unique := lq.unique; unique != nil { + if unique := lq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := lq.fields; len(fields) > 0 { + if fields := lq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, license.FieldID) for i := range fields { @@ -410,10 +405,10 @@ func (lq *LicenseQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := lq.limit; limit != nil { + if limit := lq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := lq.offset; offset != nil { + if offset := lq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := lq.order; len(ps) > 0 { @@ -429,7 +424,7 @@ func (lq *LicenseQuery) querySpec() *sqlgraph.QuerySpec { func (lq *LicenseQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(lq.driver.Dialect()) t1 := builder.Table(license.Table) - columns := lq.fields + columns := lq.ctx.Fields if len(columns) == 0 { columns = license.Columns } @@ -438,7 +433,7 @@ func (lq *LicenseQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = lq.sql selector.Select(selector.Columns(columns...)...) } - if lq.unique != nil && *lq.unique { + if lq.ctx.Unique != nil && *lq.ctx.Unique { selector.Distinct() } for _, m := range lq.modifiers { @@ -450,12 +445,12 @@ func (lq *LicenseQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range lq.order { p(selector) } - if offset := lq.offset; offset != nil { + if offset := lq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := lq.limit; limit != nil { + if limit := lq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -507,7 +502,7 @@ func (lgb *LicenseGroupBy) Aggregate(fns ...AggregateFunc) *LicenseGroupBy { // Scan applies the selector query and scans the result into the given value. func (lgb *LicenseGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeLicense, "GroupBy") + ctx = setContextOp(ctx, lgb.build.ctx, "GroupBy") if err := lgb.build.prepareQuery(ctx); err != nil { return err } @@ -555,7 +550,7 @@ func (ls *LicenseSelect) Aggregate(fns ...AggregateFunc) *LicenseSelect { // Scan applies the selector query and scans the result into the given value. func (ls *LicenseSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeLicense, "Select") + ctx = setContextOp(ctx, ls.ctx, "Select") if err := ls.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/node_query.go b/entc/integration/ent/node_query.go index 8d637a4ad..67097ebaa 100644 --- a/entc/integration/ent/node_query.go +++ b/entc/integration/ent/node_query.go @@ -23,11 +23,8 @@ import ( // NodeQuery is the builder for querying Node entities. type NodeQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Node withPrev *NodeQuery @@ -47,20 +44,20 @@ func (nq *NodeQuery) Where(ps ...predicate.Node) *NodeQuery { // Limit the number of records to be returned by this query. func (nq *NodeQuery) Limit(limit int) *NodeQuery { - nq.limit = &limit + nq.ctx.Limit = &limit return nq } // Offset to start from. func (nq *NodeQuery) Offset(offset int) *NodeQuery { - nq.offset = &offset + nq.ctx.Offset = &offset return nq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (nq *NodeQuery) Unique(unique bool) *NodeQuery { - nq.unique = &unique + nq.ctx.Unique = &unique return nq } @@ -117,7 +114,7 @@ func (nq *NodeQuery) QueryNext() *NodeQuery { // First returns the first Node entity from the query. // Returns a *NotFoundError when no Node was found. func (nq *NodeQuery) First(ctx context.Context) (*Node, error) { - nodes, err := nq.Limit(1).All(newQueryContext(ctx, TypeNode, "First")) + nodes, err := nq.Limit(1).All(setContextOp(ctx, nq.ctx, "First")) if err != nil { return nil, err } @@ -140,7 +137,7 @@ func (nq *NodeQuery) FirstX(ctx context.Context) *Node { // Returns a *NotFoundError when no Node ID was found. func (nq *NodeQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = nq.Limit(1).IDs(newQueryContext(ctx, TypeNode, "FirstID")); err != nil { + if ids, err = nq.Limit(1).IDs(setContextOp(ctx, nq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -163,7 +160,7 @@ func (nq *NodeQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Node entity is found. // Returns a *NotFoundError when no Node entities are found. func (nq *NodeQuery) Only(ctx context.Context) (*Node, error) { - nodes, err := nq.Limit(2).All(newQueryContext(ctx, TypeNode, "Only")) + nodes, err := nq.Limit(2).All(setContextOp(ctx, nq.ctx, "Only")) if err != nil { return nil, err } @@ -191,7 +188,7 @@ func (nq *NodeQuery) OnlyX(ctx context.Context) *Node { // Returns a *NotFoundError when no entities are found. func (nq *NodeQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = nq.Limit(2).IDs(newQueryContext(ctx, TypeNode, "OnlyID")); err != nil { + if ids, err = nq.Limit(2).IDs(setContextOp(ctx, nq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -216,7 +213,7 @@ func (nq *NodeQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Nodes. func (nq *NodeQuery) All(ctx context.Context) ([]*Node, error) { - ctx = newQueryContext(ctx, TypeNode, "All") + ctx = setContextOp(ctx, nq.ctx, "All") if err := nq.prepareQuery(ctx); err != nil { return nil, err } @@ -236,7 +233,7 @@ func (nq *NodeQuery) AllX(ctx context.Context) []*Node { // IDs executes the query and returns a list of Node IDs. func (nq *NodeQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeNode, "IDs") + ctx = setContextOp(ctx, nq.ctx, "IDs") if err := nq.Select(node.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -254,7 +251,7 @@ func (nq *NodeQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (nq *NodeQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeNode, "Count") + ctx = setContextOp(ctx, nq.ctx, "Count") if err := nq.prepareQuery(ctx); err != nil { return 0, err } @@ -272,7 +269,7 @@ func (nq *NodeQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (nq *NodeQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeNode, "Exist") + ctx = setContextOp(ctx, nq.ctx, "Exist") switch _, err := nq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -300,17 +297,15 @@ func (nq *NodeQuery) Clone() *NodeQuery { } return &NodeQuery{ config: nq.config, - limit: nq.limit, - offset: nq.offset, + ctx: nq.ctx.Clone(), order: append([]OrderFunc{}, nq.order...), inters: append([]Interceptor{}, nq.inters...), predicates: append([]predicate.Node{}, nq.predicates...), withPrev: nq.withPrev.Clone(), withNext: nq.withNext.Clone(), // clone intermediate query. - sql: nq.sql.Clone(), - path: nq.path, - unique: nq.unique, + sql: nq.sql.Clone(), + path: nq.path, } } @@ -351,9 +346,9 @@ func (nq *NodeQuery) WithNext(opts ...func(*NodeQuery)) *NodeQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (nq *NodeQuery) GroupBy(field string, fields ...string) *NodeGroupBy { - nq.fields = append([]string{field}, fields...) + nq.ctx.Fields = append([]string{field}, fields...) grbuild := &NodeGroupBy{build: nq} - grbuild.flds = &nq.fields + grbuild.flds = &nq.ctx.Fields grbuild.label = node.Label grbuild.scan = grbuild.Scan return grbuild @@ -372,10 +367,10 @@ func (nq *NodeQuery) GroupBy(field string, fields ...string) *NodeGroupBy { // Select(node.FieldValue). // Scan(ctx, &v) func (nq *NodeQuery) Select(fields ...string) *NodeSelect { - nq.fields = append(nq.fields, fields...) + nq.ctx.Fields = append(nq.ctx.Fields, fields...) sbuild := &NodeSelect{NodeQuery: nq} sbuild.label = node.Label - sbuild.flds, sbuild.scan = &nq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &nq.ctx.Fields, sbuild.Scan return sbuild } @@ -395,7 +390,7 @@ func (nq *NodeQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range nq.fields { + for _, f := range nq.ctx.Fields { if !node.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -528,9 +523,9 @@ func (nq *NodeQuery) sqlCount(ctx context.Context) (int, error) { if len(nq.modifiers) > 0 { _spec.Modifiers = nq.modifiers } - _spec.Node.Columns = nq.fields - if len(nq.fields) > 0 { - _spec.Unique = nq.unique != nil && *nq.unique + _spec.Node.Columns = nq.ctx.Fields + if len(nq.ctx.Fields) > 0 { + _spec.Unique = nq.ctx.Unique != nil && *nq.ctx.Unique } return sqlgraph.CountNodes(ctx, nq.driver, _spec) } @@ -548,10 +543,10 @@ func (nq *NodeQuery) querySpec() *sqlgraph.QuerySpec { From: nq.sql, Unique: true, } - if unique := nq.unique; unique != nil { + if unique := nq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := nq.fields; len(fields) > 0 { + if fields := nq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, node.FieldID) for i := range fields { @@ -567,10 +562,10 @@ func (nq *NodeQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := nq.limit; limit != nil { + if limit := nq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := nq.offset; offset != nil { + if offset := nq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := nq.order; len(ps) > 0 { @@ -586,7 +581,7 @@ func (nq *NodeQuery) querySpec() *sqlgraph.QuerySpec { func (nq *NodeQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(nq.driver.Dialect()) t1 := builder.Table(node.Table) - columns := nq.fields + columns := nq.ctx.Fields if len(columns) == 0 { columns = node.Columns } @@ -595,7 +590,7 @@ func (nq *NodeQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = nq.sql selector.Select(selector.Columns(columns...)...) } - if nq.unique != nil && *nq.unique { + if nq.ctx.Unique != nil && *nq.ctx.Unique { selector.Distinct() } for _, m := range nq.modifiers { @@ -607,12 +602,12 @@ func (nq *NodeQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range nq.order { p(selector) } - if offset := nq.offset; offset != nil { + if offset := nq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := nq.limit; limit != nil { + if limit := nq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -664,7 +659,7 @@ func (ngb *NodeGroupBy) Aggregate(fns ...AggregateFunc) *NodeGroupBy { // Scan applies the selector query and scans the result into the given value. func (ngb *NodeGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeNode, "GroupBy") + ctx = setContextOp(ctx, ngb.build.ctx, "GroupBy") if err := ngb.build.prepareQuery(ctx); err != nil { return err } @@ -712,7 +707,7 @@ func (ns *NodeSelect) Aggregate(fns ...AggregateFunc) *NodeSelect { // Scan applies the selector query and scans the result into the given value. func (ns *NodeSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeNode, "Select") + ctx = setContextOp(ctx, ns.ctx, "Select") if err := ns.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/pet_query.go b/entc/integration/ent/pet_query.go index 4d1132939..85f5cebb7 100644 --- a/entc/integration/ent/pet_query.go +++ b/entc/integration/ent/pet_query.go @@ -23,11 +23,8 @@ import ( // PetQuery is the builder for querying Pet entities. type PetQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Pet withTeam *UserQuery @@ -47,20 +44,20 @@ func (pq *PetQuery) Where(ps ...predicate.Pet) *PetQuery { // Limit the number of records to be returned by this query. func (pq *PetQuery) Limit(limit int) *PetQuery { - pq.limit = &limit + pq.ctx.Limit = &limit return pq } // Offset to start from. func (pq *PetQuery) Offset(offset int) *PetQuery { - pq.offset = &offset + pq.ctx.Offset = &offset return pq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (pq *PetQuery) Unique(unique bool) *PetQuery { - pq.unique = &unique + pq.ctx.Unique = &unique return pq } @@ -117,7 +114,7 @@ func (pq *PetQuery) QueryOwner() *UserQuery { // First returns the first Pet entity from the query. // Returns a *NotFoundError when no Pet was found. func (pq *PetQuery) First(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(1).All(newQueryContext(ctx, TypePet, "First")) + nodes, err := pq.Limit(1).All(setContextOp(ctx, pq.ctx, "First")) if err != nil { return nil, err } @@ -140,7 +137,7 @@ func (pq *PetQuery) FirstX(ctx context.Context) *Pet { // Returns a *NotFoundError when no Pet ID was found. func (pq *PetQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(1).IDs(newQueryContext(ctx, TypePet, "FirstID")); err != nil { + if ids, err = pq.Limit(1).IDs(setContextOp(ctx, pq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -163,7 +160,7 @@ func (pq *PetQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Pet entity is found. // Returns a *NotFoundError when no Pet entities are found. func (pq *PetQuery) Only(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(2).All(newQueryContext(ctx, TypePet, "Only")) + nodes, err := pq.Limit(2).All(setContextOp(ctx, pq.ctx, "Only")) if err != nil { return nil, err } @@ -191,7 +188,7 @@ func (pq *PetQuery) OnlyX(ctx context.Context) *Pet { // Returns a *NotFoundError when no entities are found. func (pq *PetQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(2).IDs(newQueryContext(ctx, TypePet, "OnlyID")); err != nil { + if ids, err = pq.Limit(2).IDs(setContextOp(ctx, pq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -216,7 +213,7 @@ func (pq *PetQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Pets. func (pq *PetQuery) All(ctx context.Context) ([]*Pet, error) { - ctx = newQueryContext(ctx, TypePet, "All") + ctx = setContextOp(ctx, pq.ctx, "All") if err := pq.prepareQuery(ctx); err != nil { return nil, err } @@ -236,7 +233,7 @@ func (pq *PetQuery) AllX(ctx context.Context) []*Pet { // IDs executes the query and returns a list of Pet IDs. func (pq *PetQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypePet, "IDs") + ctx = setContextOp(ctx, pq.ctx, "IDs") if err := pq.Select(pet.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -254,7 +251,7 @@ func (pq *PetQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (pq *PetQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypePet, "Count") + ctx = setContextOp(ctx, pq.ctx, "Count") if err := pq.prepareQuery(ctx); err != nil { return 0, err } @@ -272,7 +269,7 @@ func (pq *PetQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (pq *PetQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypePet, "Exist") + ctx = setContextOp(ctx, pq.ctx, "Exist") switch _, err := pq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -300,17 +297,15 @@ func (pq *PetQuery) Clone() *PetQuery { } return &PetQuery{ config: pq.config, - limit: pq.limit, - offset: pq.offset, + ctx: pq.ctx.Clone(), order: append([]OrderFunc{}, pq.order...), inters: append([]Interceptor{}, pq.inters...), predicates: append([]predicate.Pet{}, pq.predicates...), withTeam: pq.withTeam.Clone(), withOwner: pq.withOwner.Clone(), // clone intermediate query. - sql: pq.sql.Clone(), - path: pq.path, - unique: pq.unique, + sql: pq.sql.Clone(), + path: pq.path, } } @@ -351,9 +346,9 @@ func (pq *PetQuery) WithOwner(opts ...func(*UserQuery)) *PetQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { - pq.fields = append([]string{field}, fields...) + pq.ctx.Fields = append([]string{field}, fields...) grbuild := &PetGroupBy{build: pq} - grbuild.flds = &pq.fields + grbuild.flds = &pq.ctx.Fields grbuild.label = pet.Label grbuild.scan = grbuild.Scan return grbuild @@ -372,10 +367,10 @@ func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { // Select(pet.FieldAge). // Scan(ctx, &v) func (pq *PetQuery) Select(fields ...string) *PetSelect { - pq.fields = append(pq.fields, fields...) + pq.ctx.Fields = append(pq.ctx.Fields, fields...) sbuild := &PetSelect{PetQuery: pq} sbuild.label = pet.Label - sbuild.flds, sbuild.scan = &pq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &pq.ctx.Fields, sbuild.Scan return sbuild } @@ -395,7 +390,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range pq.fields { + for _, f := range pq.ctx.Fields { if !pet.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -532,9 +527,9 @@ func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { if len(pq.modifiers) > 0 { _spec.Modifiers = pq.modifiers } - _spec.Node.Columns = pq.fields - if len(pq.fields) > 0 { - _spec.Unique = pq.unique != nil && *pq.unique + _spec.Node.Columns = pq.ctx.Fields + if len(pq.ctx.Fields) > 0 { + _spec.Unique = pq.ctx.Unique != nil && *pq.ctx.Unique } return sqlgraph.CountNodes(ctx, pq.driver, _spec) } @@ -552,10 +547,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { From: pq.sql, Unique: true, } - if unique := pq.unique; unique != nil { + if unique := pq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := pq.fields; len(fields) > 0 { + if fields := pq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, pet.FieldID) for i := range fields { @@ -571,10 +566,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := pq.order; len(ps) > 0 { @@ -590,7 +585,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(pq.driver.Dialect()) t1 := builder.Table(pet.Table) - columns := pq.fields + columns := pq.ctx.Fields if len(columns) == 0 { columns = pet.Columns } @@ -599,7 +594,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = pq.sql selector.Select(selector.Columns(columns...)...) } - if pq.unique != nil && *pq.unique { + if pq.ctx.Unique != nil && *pq.ctx.Unique { selector.Distinct() } for _, m := range pq.modifiers { @@ -611,12 +606,12 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range pq.order { p(selector) } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -668,7 +663,7 @@ func (pgb *PetGroupBy) Aggregate(fns ...AggregateFunc) *PetGroupBy { // Scan applies the selector query and scans the result into the given value. func (pgb *PetGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "GroupBy") + ctx = setContextOp(ctx, pgb.build.ctx, "GroupBy") if err := pgb.build.prepareQuery(ctx); err != nil { return err } @@ -716,7 +711,7 @@ func (ps *PetSelect) Aggregate(fns ...AggregateFunc) *PetSelect { // Scan applies the selector query and scans the result into the given value. func (ps *PetSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "Select") + ctx = setContextOp(ctx, ps.ctx, "Select") if err := ps.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/spec_query.go b/entc/integration/ent/spec_query.go index 21ce5e3af..b852d01a1 100644 --- a/entc/integration/ent/spec_query.go +++ b/entc/integration/ent/spec_query.go @@ -24,11 +24,8 @@ import ( // SpecQuery is the builder for querying Spec entities. type SpecQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Spec withCard *CardQuery @@ -47,20 +44,20 @@ func (sq *SpecQuery) Where(ps ...predicate.Spec) *SpecQuery { // Limit the number of records to be returned by this query. func (sq *SpecQuery) Limit(limit int) *SpecQuery { - sq.limit = &limit + sq.ctx.Limit = &limit return sq } // Offset to start from. func (sq *SpecQuery) Offset(offset int) *SpecQuery { - sq.offset = &offset + sq.ctx.Offset = &offset return sq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (sq *SpecQuery) Unique(unique bool) *SpecQuery { - sq.unique = &unique + sq.ctx.Unique = &unique return sq } @@ -95,7 +92,7 @@ func (sq *SpecQuery) QueryCard() *CardQuery { // First returns the first Spec entity from the query. // Returns a *NotFoundError when no Spec was found. func (sq *SpecQuery) First(ctx context.Context) (*Spec, error) { - nodes, err := sq.Limit(1).All(newQueryContext(ctx, TypeSpec, "First")) + nodes, err := sq.Limit(1).All(setContextOp(ctx, sq.ctx, "First")) if err != nil { return nil, err } @@ -118,7 +115,7 @@ func (sq *SpecQuery) FirstX(ctx context.Context) *Spec { // Returns a *NotFoundError when no Spec ID was found. func (sq *SpecQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = sq.Limit(1).IDs(newQueryContext(ctx, TypeSpec, "FirstID")); err != nil { + if ids, err = sq.Limit(1).IDs(setContextOp(ctx, sq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -141,7 +138,7 @@ func (sq *SpecQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Spec entity is found. // Returns a *NotFoundError when no Spec entities are found. func (sq *SpecQuery) Only(ctx context.Context) (*Spec, error) { - nodes, err := sq.Limit(2).All(newQueryContext(ctx, TypeSpec, "Only")) + nodes, err := sq.Limit(2).All(setContextOp(ctx, sq.ctx, "Only")) if err != nil { return nil, err } @@ -169,7 +166,7 @@ func (sq *SpecQuery) OnlyX(ctx context.Context) *Spec { // Returns a *NotFoundError when no entities are found. func (sq *SpecQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = sq.Limit(2).IDs(newQueryContext(ctx, TypeSpec, "OnlyID")); err != nil { + if ids, err = sq.Limit(2).IDs(setContextOp(ctx, sq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -194,7 +191,7 @@ func (sq *SpecQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Specs. func (sq *SpecQuery) All(ctx context.Context) ([]*Spec, error) { - ctx = newQueryContext(ctx, TypeSpec, "All") + ctx = setContextOp(ctx, sq.ctx, "All") if err := sq.prepareQuery(ctx); err != nil { return nil, err } @@ -214,7 +211,7 @@ func (sq *SpecQuery) AllX(ctx context.Context) []*Spec { // IDs executes the query and returns a list of Spec IDs. func (sq *SpecQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeSpec, "IDs") + ctx = setContextOp(ctx, sq.ctx, "IDs") if err := sq.Select(spec.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -232,7 +229,7 @@ func (sq *SpecQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (sq *SpecQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeSpec, "Count") + ctx = setContextOp(ctx, sq.ctx, "Count") if err := sq.prepareQuery(ctx); err != nil { return 0, err } @@ -250,7 +247,7 @@ func (sq *SpecQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (sq *SpecQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeSpec, "Exist") + ctx = setContextOp(ctx, sq.ctx, "Exist") switch _, err := sq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -278,16 +275,14 @@ func (sq *SpecQuery) Clone() *SpecQuery { } return &SpecQuery{ config: sq.config, - limit: sq.limit, - offset: sq.offset, + ctx: sq.ctx.Clone(), order: append([]OrderFunc{}, sq.order...), inters: append([]Interceptor{}, sq.inters...), predicates: append([]predicate.Spec{}, sq.predicates...), withCard: sq.withCard.Clone(), // clone intermediate query. - sql: sq.sql.Clone(), - path: sq.path, - unique: sq.unique, + sql: sq.sql.Clone(), + path: sq.path, } } @@ -305,9 +300,9 @@ func (sq *SpecQuery) WithCard(opts ...func(*CardQuery)) *SpecQuery { // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (sq *SpecQuery) GroupBy(field string, fields ...string) *SpecGroupBy { - sq.fields = append([]string{field}, fields...) + sq.ctx.Fields = append([]string{field}, fields...) grbuild := &SpecGroupBy{build: sq} - grbuild.flds = &sq.fields + grbuild.flds = &sq.ctx.Fields grbuild.label = spec.Label grbuild.scan = grbuild.Scan return grbuild @@ -316,10 +311,10 @@ func (sq *SpecQuery) GroupBy(field string, fields ...string) *SpecGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (sq *SpecQuery) Select(fields ...string) *SpecSelect { - sq.fields = append(sq.fields, fields...) + sq.ctx.Fields = append(sq.ctx.Fields, fields...) sbuild := &SpecSelect{SpecQuery: sq} sbuild.label = spec.Label - sbuild.flds, sbuild.scan = &sq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &sq.ctx.Fields, sbuild.Scan return sbuild } @@ -339,7 +334,7 @@ func (sq *SpecQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range sq.fields { + for _, f := range sq.ctx.Fields { if !spec.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -464,9 +459,9 @@ func (sq *SpecQuery) sqlCount(ctx context.Context) (int, error) { if len(sq.modifiers) > 0 { _spec.Modifiers = sq.modifiers } - _spec.Node.Columns = sq.fields - if len(sq.fields) > 0 { - _spec.Unique = sq.unique != nil && *sq.unique + _spec.Node.Columns = sq.ctx.Fields + if len(sq.ctx.Fields) > 0 { + _spec.Unique = sq.ctx.Unique != nil && *sq.ctx.Unique } return sqlgraph.CountNodes(ctx, sq.driver, _spec) } @@ -484,10 +479,10 @@ func (sq *SpecQuery) querySpec() *sqlgraph.QuerySpec { From: sq.sql, Unique: true, } - if unique := sq.unique; unique != nil { + if unique := sq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := sq.fields; len(fields) > 0 { + if fields := sq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, spec.FieldID) for i := range fields { @@ -503,10 +498,10 @@ func (sq *SpecQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := sq.limit; limit != nil { + if limit := sq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := sq.offset; offset != nil { + if offset := sq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := sq.order; len(ps) > 0 { @@ -522,7 +517,7 @@ func (sq *SpecQuery) querySpec() *sqlgraph.QuerySpec { func (sq *SpecQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(sq.driver.Dialect()) t1 := builder.Table(spec.Table) - columns := sq.fields + columns := sq.ctx.Fields if len(columns) == 0 { columns = spec.Columns } @@ -531,7 +526,7 @@ func (sq *SpecQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = sq.sql selector.Select(selector.Columns(columns...)...) } - if sq.unique != nil && *sq.unique { + if sq.ctx.Unique != nil && *sq.ctx.Unique { selector.Distinct() } for _, m := range sq.modifiers { @@ -543,12 +538,12 @@ func (sq *SpecQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range sq.order { p(selector) } - if offset := sq.offset; offset != nil { + if offset := sq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := sq.limit; limit != nil { + if limit := sq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -614,7 +609,7 @@ func (sgb *SpecGroupBy) Aggregate(fns ...AggregateFunc) *SpecGroupBy { // Scan applies the selector query and scans the result into the given value. func (sgb *SpecGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeSpec, "GroupBy") + ctx = setContextOp(ctx, sgb.build.ctx, "GroupBy") if err := sgb.build.prepareQuery(ctx); err != nil { return err } @@ -662,7 +657,7 @@ func (ss *SpecSelect) Aggregate(fns ...AggregateFunc) *SpecSelect { // Scan applies the selector query and scans the result into the given value. func (ss *SpecSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeSpec, "Select") + ctx = setContextOp(ctx, ss.ctx, "Select") if err := ss.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/task_query.go b/entc/integration/ent/task_query.go index 2a2696fb0..ae4c595b3 100644 --- a/entc/integration/ent/task_query.go +++ b/entc/integration/ent/task_query.go @@ -23,11 +23,8 @@ import ( // TaskQuery is the builder for querying Task entities. type TaskQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Task modifiers []func(*sql.Selector) @@ -44,20 +41,20 @@ func (tq *TaskQuery) Where(ps ...predicate.Task) *TaskQuery { // Limit the number of records to be returned by this query. func (tq *TaskQuery) Limit(limit int) *TaskQuery { - tq.limit = &limit + tq.ctx.Limit = &limit return tq } // Offset to start from. func (tq *TaskQuery) Offset(offset int) *TaskQuery { - tq.offset = &offset + tq.ctx.Offset = &offset return tq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (tq *TaskQuery) Unique(unique bool) *TaskQuery { - tq.unique = &unique + tq.ctx.Unique = &unique return tq } @@ -70,7 +67,7 @@ func (tq *TaskQuery) Order(o ...OrderFunc) *TaskQuery { // First returns the first Task entity from the query. // Returns a *NotFoundError when no Task was found. func (tq *TaskQuery) First(ctx context.Context) (*Task, error) { - nodes, err := tq.Limit(1).All(newQueryContext(ctx, TypeTask, "First")) + nodes, err := tq.Limit(1).All(setContextOp(ctx, tq.ctx, "First")) if err != nil { return nil, err } @@ -93,7 +90,7 @@ func (tq *TaskQuery) FirstX(ctx context.Context) *Task { // Returns a *NotFoundError when no Task ID was found. func (tq *TaskQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = tq.Limit(1).IDs(newQueryContext(ctx, TypeTask, "FirstID")); err != nil { + if ids, err = tq.Limit(1).IDs(setContextOp(ctx, tq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -116,7 +113,7 @@ func (tq *TaskQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Task entity is found. // Returns a *NotFoundError when no Task entities are found. func (tq *TaskQuery) Only(ctx context.Context) (*Task, error) { - nodes, err := tq.Limit(2).All(newQueryContext(ctx, TypeTask, "Only")) + nodes, err := tq.Limit(2).All(setContextOp(ctx, tq.ctx, "Only")) if err != nil { return nil, err } @@ -144,7 +141,7 @@ func (tq *TaskQuery) OnlyX(ctx context.Context) *Task { // Returns a *NotFoundError when no entities are found. func (tq *TaskQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = tq.Limit(2).IDs(newQueryContext(ctx, TypeTask, "OnlyID")); err != nil { + if ids, err = tq.Limit(2).IDs(setContextOp(ctx, tq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -169,7 +166,7 @@ func (tq *TaskQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Tasks. func (tq *TaskQuery) All(ctx context.Context) ([]*Task, error) { - ctx = newQueryContext(ctx, TypeTask, "All") + ctx = setContextOp(ctx, tq.ctx, "All") if err := tq.prepareQuery(ctx); err != nil { return nil, err } @@ -189,7 +186,7 @@ func (tq *TaskQuery) AllX(ctx context.Context) []*Task { // IDs executes the query and returns a list of Task IDs. func (tq *TaskQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeTask, "IDs") + ctx = setContextOp(ctx, tq.ctx, "IDs") if err := tq.Select(enttask.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -207,7 +204,7 @@ func (tq *TaskQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (tq *TaskQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeTask, "Count") + ctx = setContextOp(ctx, tq.ctx, "Count") if err := tq.prepareQuery(ctx); err != nil { return 0, err } @@ -225,7 +222,7 @@ func (tq *TaskQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (tq *TaskQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeTask, "Exist") + ctx = setContextOp(ctx, tq.ctx, "Exist") switch _, err := tq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -253,15 +250,13 @@ func (tq *TaskQuery) Clone() *TaskQuery { } return &TaskQuery{ config: tq.config, - limit: tq.limit, - offset: tq.offset, + ctx: tq.ctx.Clone(), order: append([]OrderFunc{}, tq.order...), inters: append([]Interceptor{}, tq.inters...), predicates: append([]predicate.Task{}, tq.predicates...), // clone intermediate query. - sql: tq.sql.Clone(), - path: tq.path, - unique: tq.unique, + sql: tq.sql.Clone(), + path: tq.path, } } @@ -280,9 +275,9 @@ func (tq *TaskQuery) Clone() *TaskQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (tq *TaskQuery) GroupBy(field string, fields ...string) *TaskGroupBy { - tq.fields = append([]string{field}, fields...) + tq.ctx.Fields = append([]string{field}, fields...) grbuild := &TaskGroupBy{build: tq} - grbuild.flds = &tq.fields + grbuild.flds = &tq.ctx.Fields grbuild.label = enttask.Label grbuild.scan = grbuild.Scan return grbuild @@ -301,10 +296,10 @@ func (tq *TaskQuery) GroupBy(field string, fields ...string) *TaskGroupBy { // Select(enttask.FieldPriority). // Scan(ctx, &v) func (tq *TaskQuery) Select(fields ...string) *TaskSelect { - tq.fields = append(tq.fields, fields...) + tq.ctx.Fields = append(tq.ctx.Fields, fields...) sbuild := &TaskSelect{TaskQuery: tq} sbuild.label = enttask.Label - sbuild.flds, sbuild.scan = &tq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &tq.ctx.Fields, sbuild.Scan return sbuild } @@ -324,7 +319,7 @@ func (tq *TaskQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range tq.fields { + for _, f := range tq.ctx.Fields { if !enttask.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -372,9 +367,9 @@ func (tq *TaskQuery) sqlCount(ctx context.Context) (int, error) { if len(tq.modifiers) > 0 { _spec.Modifiers = tq.modifiers } - _spec.Node.Columns = tq.fields - if len(tq.fields) > 0 { - _spec.Unique = tq.unique != nil && *tq.unique + _spec.Node.Columns = tq.ctx.Fields + if len(tq.ctx.Fields) > 0 { + _spec.Unique = tq.ctx.Unique != nil && *tq.ctx.Unique } return sqlgraph.CountNodes(ctx, tq.driver, _spec) } @@ -392,10 +387,10 @@ func (tq *TaskQuery) querySpec() *sqlgraph.QuerySpec { From: tq.sql, Unique: true, } - if unique := tq.unique; unique != nil { + if unique := tq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := tq.fields; len(fields) > 0 { + if fields := tq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, enttask.FieldID) for i := range fields { @@ -411,10 +406,10 @@ func (tq *TaskQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := tq.order; len(ps) > 0 { @@ -430,7 +425,7 @@ func (tq *TaskQuery) querySpec() *sqlgraph.QuerySpec { func (tq *TaskQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(tq.driver.Dialect()) t1 := builder.Table(enttask.Table) - columns := tq.fields + columns := tq.ctx.Fields if len(columns) == 0 { columns = enttask.Columns } @@ -439,7 +434,7 @@ func (tq *TaskQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = tq.sql selector.Select(selector.Columns(columns...)...) } - if tq.unique != nil && *tq.unique { + if tq.ctx.Unique != nil && *tq.ctx.Unique { selector.Distinct() } for _, m := range tq.modifiers { @@ -451,12 +446,12 @@ func (tq *TaskQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range tq.order { p(selector) } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -508,7 +503,7 @@ func (tgb *TaskGroupBy) Aggregate(fns ...AggregateFunc) *TaskGroupBy { // Scan applies the selector query and scans the result into the given value. func (tgb *TaskGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTask, "GroupBy") + ctx = setContextOp(ctx, tgb.build.ctx, "GroupBy") if err := tgb.build.prepareQuery(ctx); err != nil { return err } @@ -556,7 +551,7 @@ func (ts *TaskSelect) Aggregate(fns ...AggregateFunc) *TaskSelect { // Scan applies the selector query and scans the result into the given value. func (ts *TaskSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTask, "Select") + ctx = setContextOp(ctx, ts.ctx, "Select") if err := ts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/ent/user_query.go b/entc/integration/ent/user_query.go index 41f858a58..e06be007e 100644 --- a/entc/integration/ent/user_query.go +++ b/entc/integration/ent/user_query.go @@ -27,11 +27,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withCard *CardQuery @@ -67,20 +64,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -335,7 +332,7 @@ func (uq *UserQuery) QueryParent() *UserQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -358,7 +355,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -381,7 +378,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -409,7 +406,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -434,7 +431,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -454,7 +451,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -472,7 +469,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -490,7 +487,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -518,8 +515,7 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), @@ -535,9 +531,8 @@ func (uq *UserQuery) Clone() *UserQuery { withChildren: uq.withChildren.Clone(), withParent: uq.withParent.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -677,9 +672,9 @@ func (uq *UserQuery) WithParent(opts ...func(*UserQuery)) *UserQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -698,10 +693,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldOptionalInt). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -721,7 +716,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -1358,9 +1353,9 @@ func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { if len(uq.modifiers) > 0 { _spec.Modifiers = uq.modifiers } - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -1378,10 +1373,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -1397,10 +1392,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -1416,7 +1411,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -1425,7 +1420,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, m := range uq.modifiers { @@ -1437,12 +1432,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -1592,7 +1587,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -1640,7 +1635,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/gremlin/ent/api_query.go b/entc/integration/gremlin/ent/api_query.go index d2b3d5473..3bac6810a 100644 --- a/entc/integration/gremlin/ent/api_query.go +++ b/entc/integration/gremlin/ent/api_query.go @@ -22,11 +22,8 @@ import ( // APIQuery is the builder for querying Api entities. type APIQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Api // intermediate query (i.e. traversal path). @@ -42,20 +39,20 @@ func (aq *APIQuery) Where(ps ...predicate.Api) *APIQuery { // Limit the number of records to be returned by this query. func (aq *APIQuery) Limit(limit int) *APIQuery { - aq.limit = &limit + aq.ctx.Limit = &limit return aq } // Offset to start from. func (aq *APIQuery) Offset(offset int) *APIQuery { - aq.offset = &offset + aq.ctx.Offset = &offset return aq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (aq *APIQuery) Unique(unique bool) *APIQuery { - aq.unique = &unique + aq.ctx.Unique = &unique return aq } @@ -68,7 +65,7 @@ func (aq *APIQuery) Order(o ...OrderFunc) *APIQuery { // First returns the first Api entity from the query. // Returns a *NotFoundError when no Api was found. func (aq *APIQuery) First(ctx context.Context) (*Api, error) { - nodes, err := aq.Limit(1).All(newQueryContext(ctx, TypeAPI, "First")) + nodes, err := aq.Limit(1).All(setContextOp(ctx, aq.ctx, "First")) if err != nil { return nil, err } @@ -91,7 +88,7 @@ func (aq *APIQuery) FirstX(ctx context.Context) *Api { // Returns a *NotFoundError when no Api ID was found. func (aq *APIQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = aq.Limit(1).IDs(newQueryContext(ctx, TypeAPI, "FirstID")); err != nil { + if ids, err = aq.Limit(1).IDs(setContextOp(ctx, aq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -114,7 +111,7 @@ func (aq *APIQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Api entity is found. // Returns a *NotFoundError when no Api entities are found. func (aq *APIQuery) Only(ctx context.Context) (*Api, error) { - nodes, err := aq.Limit(2).All(newQueryContext(ctx, TypeAPI, "Only")) + nodes, err := aq.Limit(2).All(setContextOp(ctx, aq.ctx, "Only")) if err != nil { return nil, err } @@ -142,7 +139,7 @@ func (aq *APIQuery) OnlyX(ctx context.Context) *Api { // Returns a *NotFoundError when no entities are found. func (aq *APIQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = aq.Limit(2).IDs(newQueryContext(ctx, TypeAPI, "OnlyID")); err != nil { + if ids, err = aq.Limit(2).IDs(setContextOp(ctx, aq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -167,7 +164,7 @@ func (aq *APIQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Apis. func (aq *APIQuery) All(ctx context.Context) ([]*Api, error) { - ctx = newQueryContext(ctx, TypeAPI, "All") + ctx = setContextOp(ctx, aq.ctx, "All") if err := aq.prepareQuery(ctx); err != nil { return nil, err } @@ -187,7 +184,7 @@ func (aq *APIQuery) AllX(ctx context.Context) []*Api { // IDs executes the query and returns a list of Api IDs. func (aq *APIQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeAPI, "IDs") + ctx = setContextOp(ctx, aq.ctx, "IDs") if err := aq.Select(api.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -205,7 +202,7 @@ func (aq *APIQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (aq *APIQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeAPI, "Count") + ctx = setContextOp(ctx, aq.ctx, "Count") if err := aq.prepareQuery(ctx); err != nil { return 0, err } @@ -223,7 +220,7 @@ func (aq *APIQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (aq *APIQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeAPI, "Exist") + ctx = setContextOp(ctx, aq.ctx, "Exist") switch _, err := aq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -251,24 +248,22 @@ func (aq *APIQuery) Clone() *APIQuery { } return &APIQuery{ config: aq.config, - limit: aq.limit, - offset: aq.offset, + ctx: aq.ctx.Clone(), order: append([]OrderFunc{}, aq.order...), inters: append([]Interceptor{}, aq.inters...), predicates: append([]predicate.Api{}, aq.predicates...), // clone intermediate query. gremlin: aq.gremlin.Clone(), path: aq.path, - unique: aq.unique, } } // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (aq *APIQuery) GroupBy(field string, fields ...string) *APIGroupBy { - aq.fields = append([]string{field}, fields...) + aq.ctx.Fields = append([]string{field}, fields...) grbuild := &APIGroupBy{build: aq} - grbuild.flds = &aq.fields + grbuild.flds = &aq.ctx.Fields grbuild.label = api.Label grbuild.scan = grbuild.Scan return grbuild @@ -277,10 +272,10 @@ func (aq *APIQuery) GroupBy(field string, fields ...string) *APIGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (aq *APIQuery) Select(fields ...string) *APISelect { - aq.fields = append(aq.fields, fields...) + aq.ctx.Fields = append(aq.ctx.Fields, fields...) sbuild := &APISelect{APIQuery: aq} sbuild.label = api.Label - sbuild.flds, sbuild.scan = &aq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &aq.ctx.Fields, sbuild.Scan return sbuild } @@ -313,9 +308,9 @@ func (aq *APIQuery) prepareQuery(ctx context.Context) error { func (aq *APIQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*Api, error) { res := &gremlin.Response{} traversal := aq.gremlinQuery(ctx) - if len(aq.fields) > 0 { - fields := make([]any, len(aq.fields)) - for i, f := range aq.fields { + if len(aq.ctx.Fields) > 0 { + fields := make([]any, len(aq.ctx.Fields)) + for i, f := range aq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -357,7 +352,7 @@ func (aq *APIQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := aq.limit, aq.offset; { + switch limit, offset := aq.ctx.Limit, aq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -365,7 +360,7 @@ func (aq *APIQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := aq.unique; unique == nil || *unique { + if unique := aq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -385,7 +380,7 @@ func (agb *APIGroupBy) Aggregate(fns ...AggregateFunc) *APIGroupBy { // Scan applies the selector query and scans the result into the given value. func (agb *APIGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeAPI, "GroupBy") + ctx = setContextOp(ctx, agb.build.ctx, "GroupBy") if err := agb.build.prepareQuery(ctx); err != nil { return err } @@ -440,7 +435,7 @@ func (as *APISelect) Aggregate(fns ...AggregateFunc) *APISelect { // Scan applies the selector query and scans the result into the given value. func (as *APISelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeAPI, "Select") + ctx = setContextOp(ctx, as.ctx, "Select") if err := as.prepareQuery(ctx); err != nil { return err } @@ -452,15 +447,15 @@ func (as *APISelect) gremlinScan(ctx context.Context, root *APIQuery, v any) err res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(as.fields) == 1 { - if as.fields[0] != api.FieldID { - traversal = traversal.Values(as.fields...) + if fields := as.ctx.Fields; len(fields) == 1 { + if fields[0] != api.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(as.fields)) - for i, f := range as.fields { + fields := make([]any, len(as.ctx.Fields)) + for i, f := range as.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -469,7 +464,7 @@ func (as *APISelect) gremlinScan(ctx context.Context, root *APIQuery, v any) err if err := as.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/card_query.go b/entc/integration/gremlin/ent/card_query.go index 25ee9c341..1be3c9762 100644 --- a/entc/integration/gremlin/ent/card_query.go +++ b/entc/integration/gremlin/ent/card_query.go @@ -24,11 +24,8 @@ import ( // CardQuery is the builder for querying Card entities. type CardQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Card withOwner *UserQuery @@ -46,20 +43,20 @@ func (cq *CardQuery) Where(ps ...predicate.Card) *CardQuery { // Limit the number of records to be returned by this query. func (cq *CardQuery) Limit(limit int) *CardQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *CardQuery) Offset(offset int) *CardQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *CardQuery) Unique(unique bool) *CardQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -100,7 +97,7 @@ func (cq *CardQuery) QuerySpec() *SpecQuery { // First returns the first Card entity from the query. // Returns a *NotFoundError when no Card was found. func (cq *CardQuery) First(ctx context.Context) (*Card, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeCard, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -123,7 +120,7 @@ func (cq *CardQuery) FirstX(ctx context.Context) *Card { // Returns a *NotFoundError when no Card ID was found. func (cq *CardQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeCard, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -146,7 +143,7 @@ func (cq *CardQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Card entity is found. // Returns a *NotFoundError when no Card entities are found. func (cq *CardQuery) Only(ctx context.Context) (*Card, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeCard, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -174,7 +171,7 @@ func (cq *CardQuery) OnlyX(ctx context.Context) *Card { // Returns a *NotFoundError when no entities are found. func (cq *CardQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeCard, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -199,7 +196,7 @@ func (cq *CardQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Cards. func (cq *CardQuery) All(ctx context.Context) ([]*Card, error) { - ctx = newQueryContext(ctx, TypeCard, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -219,7 +216,7 @@ func (cq *CardQuery) AllX(ctx context.Context) []*Card { // IDs executes the query and returns a list of Card IDs. func (cq *CardQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeCard, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(card.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -237,7 +234,7 @@ func (cq *CardQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (cq *CardQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeCard, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -255,7 +252,7 @@ func (cq *CardQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *CardQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeCard, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -283,8 +280,7 @@ func (cq *CardQuery) Clone() *CardQuery { } return &CardQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Card{}, cq.predicates...), @@ -293,7 +289,6 @@ func (cq *CardQuery) Clone() *CardQuery { // clone intermediate query. gremlin: cq.gremlin.Clone(), path: cq.path, - unique: cq.unique, } } @@ -334,9 +329,9 @@ func (cq *CardQuery) WithSpec(opts ...func(*SpecQuery)) *CardQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (cq *CardQuery) GroupBy(field string, fields ...string) *CardGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &CardGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = card.Label grbuild.scan = grbuild.Scan return grbuild @@ -355,10 +350,10 @@ func (cq *CardQuery) GroupBy(field string, fields ...string) *CardGroupBy { // Select(card.FieldCreateTime). // Scan(ctx, &v) func (cq *CardQuery) Select(fields ...string) *CardSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &CardSelect{CardQuery: cq} sbuild.label = card.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -391,9 +386,9 @@ func (cq *CardQuery) prepareQuery(ctx context.Context) error { func (cq *CardQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*Card, error) { res := &gremlin.Response{} traversal := cq.gremlinQuery(ctx) - if len(cq.fields) > 0 { - fields := make([]any, len(cq.fields)) - for i, f := range cq.fields { + if len(cq.ctx.Fields) > 0 { + fields := make([]any, len(cq.ctx.Fields)) + for i, f := range cq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -435,7 +430,7 @@ func (cq *CardQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := cq.limit, cq.offset; { + switch limit, offset := cq.ctx.Limit, cq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -443,7 +438,7 @@ func (cq *CardQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := cq.unique; unique == nil || *unique { + if unique := cq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -463,7 +458,7 @@ func (cgb *CardGroupBy) Aggregate(fns ...AggregateFunc) *CardGroupBy { // Scan applies the selector query and scans the result into the given value. func (cgb *CardGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCard, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -518,7 +513,7 @@ func (cs *CardSelect) Aggregate(fns ...AggregateFunc) *CardSelect { // Scan applies the selector query and scans the result into the given value. func (cs *CardSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCard, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } @@ -530,15 +525,15 @@ func (cs *CardSelect) gremlinScan(ctx context.Context, root *CardQuery, v any) e res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(cs.fields) == 1 { - if cs.fields[0] != card.FieldID { - traversal = traversal.Values(cs.fields...) + if fields := cs.ctx.Fields; len(fields) == 1 { + if fields[0] != card.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(cs.fields)) - for i, f := range cs.fields { + fields := make([]any, len(cs.ctx.Fields)) + for i, f := range cs.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -547,7 +542,7 @@ func (cs *CardSelect) gremlinScan(ctx context.Context, root *CardQuery, v any) e if err := cs.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/client.go b/entc/integration/gremlin/ent/client.go index 2fdda56f6..e28f78628 100644 --- a/entc/integration/gremlin/ent/client.go +++ b/entc/integration/gremlin/ent/client.go @@ -349,6 +349,7 @@ func (c *APIClient) DeleteOneID(id string) *APIDeleteOne { func (c *APIClient) Query() *APIQuery { return &APIQuery{ config: c.config, + ctx: &QueryContext{Type: TypeAPI}, inters: c.Interceptors(), } } @@ -466,6 +467,7 @@ func (c *CardClient) DeleteOneID(id string) *CardDeleteOne { func (c *CardClient) Query() *CardQuery { return &CardQuery{ config: c.config, + ctx: &QueryContext{Type: TypeCard}, inters: c.Interceptors(), } } @@ -605,6 +607,7 @@ func (c *CommentClient) DeleteOneID(id string) *CommentDeleteOne { func (c *CommentClient) Query() *CommentQuery { return &CommentQuery{ config: c.config, + ctx: &QueryContext{Type: TypeComment}, inters: c.Interceptors(), } } @@ -722,6 +725,7 @@ func (c *FieldTypeClient) DeleteOneID(id string) *FieldTypeDeleteOne { func (c *FieldTypeClient) Query() *FieldTypeQuery { return &FieldTypeQuery{ config: c.config, + ctx: &QueryContext{Type: TypeFieldType}, inters: c.Interceptors(), } } @@ -839,6 +843,7 @@ func (c *FileClient) DeleteOneID(id string) *FileDeleteOne { func (c *FileClient) Query() *FileQuery { return &FileQuery{ config: c.config, + ctx: &QueryContext{Type: TypeFile}, inters: c.Interceptors(), } } @@ -989,6 +994,7 @@ func (c *FileTypeClient) DeleteOneID(id string) *FileTypeDeleteOne { func (c *FileTypeClient) Query() *FileTypeQuery { return &FileTypeQuery{ config: c.config, + ctx: &QueryContext{Type: TypeFileType}, inters: c.Interceptors(), } } @@ -1117,6 +1123,7 @@ func (c *GoodsClient) DeleteOneID(id string) *GoodsDeleteOne { func (c *GoodsClient) Query() *GoodsQuery { return &GoodsQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGoods}, inters: c.Interceptors(), } } @@ -1234,6 +1241,7 @@ func (c *GroupClient) DeleteOneID(id string) *GroupDeleteOne { func (c *GroupClient) Query() *GroupQuery { return &GroupQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGroup}, inters: c.Interceptors(), } } @@ -1395,6 +1403,7 @@ func (c *GroupInfoClient) DeleteOneID(id string) *GroupInfoDeleteOne { func (c *GroupInfoClient) Query() *GroupInfoQuery { return &GroupInfoQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGroupInfo}, inters: c.Interceptors(), } } @@ -1523,6 +1532,7 @@ func (c *ItemClient) DeleteOneID(id string) *ItemDeleteOne { func (c *ItemClient) Query() *ItemQuery { return &ItemQuery{ config: c.config, + ctx: &QueryContext{Type: TypeItem}, inters: c.Interceptors(), } } @@ -1640,6 +1650,7 @@ func (c *LicenseClient) DeleteOneID(id int) *LicenseDeleteOne { func (c *LicenseClient) Query() *LicenseQuery { return &LicenseQuery{ config: c.config, + ctx: &QueryContext{Type: TypeLicense}, inters: c.Interceptors(), } } @@ -1757,6 +1768,7 @@ func (c *NodeClient) DeleteOneID(id string) *NodeDeleteOne { func (c *NodeClient) Query() *NodeQuery { return &NodeQuery{ config: c.config, + ctx: &QueryContext{Type: TypeNode}, inters: c.Interceptors(), } } @@ -1896,6 +1908,7 @@ func (c *PetClient) DeleteOneID(id string) *PetDeleteOne { func (c *PetClient) Query() *PetQuery { return &PetQuery{ config: c.config, + ctx: &QueryContext{Type: TypePet}, inters: c.Interceptors(), } } @@ -2035,6 +2048,7 @@ func (c *SpecClient) DeleteOneID(id string) *SpecDeleteOne { func (c *SpecClient) Query() *SpecQuery { return &SpecQuery{ config: c.config, + ctx: &QueryContext{Type: TypeSpec}, inters: c.Interceptors(), } } @@ -2163,6 +2177,7 @@ func (c *TaskClient) DeleteOneID(id string) *TaskDeleteOne { func (c *TaskClient) Query() *TaskQuery { return &TaskQuery{ config: c.config, + ctx: &QueryContext{Type: TypeTask}, inters: c.Interceptors(), } } @@ -2280,6 +2295,7 @@ func (c *UserClient) DeleteOneID(id string) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/gremlin/ent/comment_query.go b/entc/integration/gremlin/ent/comment_query.go index bf1b8f105..36d5dbf1a 100644 --- a/entc/integration/gremlin/ent/comment_query.go +++ b/entc/integration/gremlin/ent/comment_query.go @@ -22,11 +22,8 @@ import ( // CommentQuery is the builder for querying Comment entities. type CommentQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Comment // intermediate query (i.e. traversal path). @@ -42,20 +39,20 @@ func (cq *CommentQuery) Where(ps ...predicate.Comment) *CommentQuery { // Limit the number of records to be returned by this query. func (cq *CommentQuery) Limit(limit int) *CommentQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *CommentQuery) Offset(offset int) *CommentQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *CommentQuery) Unique(unique bool) *CommentQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -68,7 +65,7 @@ func (cq *CommentQuery) Order(o ...OrderFunc) *CommentQuery { // First returns the first Comment entity from the query. // Returns a *NotFoundError when no Comment was found. func (cq *CommentQuery) First(ctx context.Context) (*Comment, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeComment, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -91,7 +88,7 @@ func (cq *CommentQuery) FirstX(ctx context.Context) *Comment { // Returns a *NotFoundError when no Comment ID was found. func (cq *CommentQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeComment, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -114,7 +111,7 @@ func (cq *CommentQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Comment entity is found. // Returns a *NotFoundError when no Comment entities are found. func (cq *CommentQuery) Only(ctx context.Context) (*Comment, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeComment, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -142,7 +139,7 @@ func (cq *CommentQuery) OnlyX(ctx context.Context) *Comment { // Returns a *NotFoundError when no entities are found. func (cq *CommentQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeComment, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -167,7 +164,7 @@ func (cq *CommentQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Comments. func (cq *CommentQuery) All(ctx context.Context) ([]*Comment, error) { - ctx = newQueryContext(ctx, TypeComment, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -187,7 +184,7 @@ func (cq *CommentQuery) AllX(ctx context.Context) []*Comment { // IDs executes the query and returns a list of Comment IDs. func (cq *CommentQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeComment, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(comment.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -205,7 +202,7 @@ func (cq *CommentQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (cq *CommentQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeComment, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -223,7 +220,7 @@ func (cq *CommentQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *CommentQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeComment, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -251,15 +248,13 @@ func (cq *CommentQuery) Clone() *CommentQuery { } return &CommentQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Comment{}, cq.predicates...), // clone intermediate query. gremlin: cq.gremlin.Clone(), path: cq.path, - unique: cq.unique, } } @@ -278,9 +273,9 @@ func (cq *CommentQuery) Clone() *CommentQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (cq *CommentQuery) GroupBy(field string, fields ...string) *CommentGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &CommentGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = comment.Label grbuild.scan = grbuild.Scan return grbuild @@ -299,10 +294,10 @@ func (cq *CommentQuery) GroupBy(field string, fields ...string) *CommentGroupBy // Select(comment.FieldUniqueInt). // Scan(ctx, &v) func (cq *CommentQuery) Select(fields ...string) *CommentSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &CommentSelect{CommentQuery: cq} sbuild.label = comment.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -335,9 +330,9 @@ func (cq *CommentQuery) prepareQuery(ctx context.Context) error { func (cq *CommentQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*Comment, error) { res := &gremlin.Response{} traversal := cq.gremlinQuery(ctx) - if len(cq.fields) > 0 { - fields := make([]any, len(cq.fields)) - for i, f := range cq.fields { + if len(cq.ctx.Fields) > 0 { + fields := make([]any, len(cq.ctx.Fields)) + for i, f := range cq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -379,7 +374,7 @@ func (cq *CommentQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := cq.limit, cq.offset; { + switch limit, offset := cq.ctx.Limit, cq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -387,7 +382,7 @@ func (cq *CommentQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := cq.unique; unique == nil || *unique { + if unique := cq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -407,7 +402,7 @@ func (cgb *CommentGroupBy) Aggregate(fns ...AggregateFunc) *CommentGroupBy { // Scan applies the selector query and scans the result into the given value. func (cgb *CommentGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeComment, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -462,7 +457,7 @@ func (cs *CommentSelect) Aggregate(fns ...AggregateFunc) *CommentSelect { // Scan applies the selector query and scans the result into the given value. func (cs *CommentSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeComment, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } @@ -474,15 +469,15 @@ func (cs *CommentSelect) gremlinScan(ctx context.Context, root *CommentQuery, v res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(cs.fields) == 1 { - if cs.fields[0] != comment.FieldID { - traversal = traversal.Values(cs.fields...) + if fields := cs.ctx.Fields; len(fields) == 1 { + if fields[0] != comment.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(cs.fields)) - for i, f := range cs.fields { + fields := make([]any, len(cs.ctx.Fields)) + for i, f := range cs.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -491,7 +486,7 @@ func (cs *CommentSelect) gremlinScan(ctx context.Context, root *CommentQuery, v if err := cs.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/ent.go b/entc/integration/gremlin/ent/ent.go index 609dcdba6..c77e9b013 100644 --- a/entc/integration/gremlin/ent/ent.go +++ b/entc/integration/gremlin/ent/ent.go @@ -27,6 +27,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -510,10 +511,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/gremlin/ent/fieldtype_query.go b/entc/integration/gremlin/ent/fieldtype_query.go index d62491f30..613b71ccf 100644 --- a/entc/integration/gremlin/ent/fieldtype_query.go +++ b/entc/integration/gremlin/ent/fieldtype_query.go @@ -22,11 +22,8 @@ import ( // FieldTypeQuery is the builder for querying FieldType entities. type FieldTypeQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.FieldType // intermediate query (i.e. traversal path). @@ -42,20 +39,20 @@ func (ftq *FieldTypeQuery) Where(ps ...predicate.FieldType) *FieldTypeQuery { // Limit the number of records to be returned by this query. func (ftq *FieldTypeQuery) Limit(limit int) *FieldTypeQuery { - ftq.limit = &limit + ftq.ctx.Limit = &limit return ftq } // Offset to start from. func (ftq *FieldTypeQuery) Offset(offset int) *FieldTypeQuery { - ftq.offset = &offset + ftq.ctx.Offset = &offset return ftq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (ftq *FieldTypeQuery) Unique(unique bool) *FieldTypeQuery { - ftq.unique = &unique + ftq.ctx.Unique = &unique return ftq } @@ -68,7 +65,7 @@ func (ftq *FieldTypeQuery) Order(o ...OrderFunc) *FieldTypeQuery { // First returns the first FieldType entity from the query. // Returns a *NotFoundError when no FieldType was found. func (ftq *FieldTypeQuery) First(ctx context.Context) (*FieldType, error) { - nodes, err := ftq.Limit(1).All(newQueryContext(ctx, TypeFieldType, "First")) + nodes, err := ftq.Limit(1).All(setContextOp(ctx, ftq.ctx, "First")) if err != nil { return nil, err } @@ -91,7 +88,7 @@ func (ftq *FieldTypeQuery) FirstX(ctx context.Context) *FieldType { // Returns a *NotFoundError when no FieldType ID was found. func (ftq *FieldTypeQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = ftq.Limit(1).IDs(newQueryContext(ctx, TypeFieldType, "FirstID")); err != nil { + if ids, err = ftq.Limit(1).IDs(setContextOp(ctx, ftq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -114,7 +111,7 @@ func (ftq *FieldTypeQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one FieldType entity is found. // Returns a *NotFoundError when no FieldType entities are found. func (ftq *FieldTypeQuery) Only(ctx context.Context) (*FieldType, error) { - nodes, err := ftq.Limit(2).All(newQueryContext(ctx, TypeFieldType, "Only")) + nodes, err := ftq.Limit(2).All(setContextOp(ctx, ftq.ctx, "Only")) if err != nil { return nil, err } @@ -142,7 +139,7 @@ func (ftq *FieldTypeQuery) OnlyX(ctx context.Context) *FieldType { // Returns a *NotFoundError when no entities are found. func (ftq *FieldTypeQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = ftq.Limit(2).IDs(newQueryContext(ctx, TypeFieldType, "OnlyID")); err != nil { + if ids, err = ftq.Limit(2).IDs(setContextOp(ctx, ftq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -167,7 +164,7 @@ func (ftq *FieldTypeQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of FieldTypes. func (ftq *FieldTypeQuery) All(ctx context.Context) ([]*FieldType, error) { - ctx = newQueryContext(ctx, TypeFieldType, "All") + ctx = setContextOp(ctx, ftq.ctx, "All") if err := ftq.prepareQuery(ctx); err != nil { return nil, err } @@ -187,7 +184,7 @@ func (ftq *FieldTypeQuery) AllX(ctx context.Context) []*FieldType { // IDs executes the query and returns a list of FieldType IDs. func (ftq *FieldTypeQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeFieldType, "IDs") + ctx = setContextOp(ctx, ftq.ctx, "IDs") if err := ftq.Select(fieldtype.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -205,7 +202,7 @@ func (ftq *FieldTypeQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (ftq *FieldTypeQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeFieldType, "Count") + ctx = setContextOp(ctx, ftq.ctx, "Count") if err := ftq.prepareQuery(ctx); err != nil { return 0, err } @@ -223,7 +220,7 @@ func (ftq *FieldTypeQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (ftq *FieldTypeQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeFieldType, "Exist") + ctx = setContextOp(ctx, ftq.ctx, "Exist") switch _, err := ftq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -251,15 +248,13 @@ func (ftq *FieldTypeQuery) Clone() *FieldTypeQuery { } return &FieldTypeQuery{ config: ftq.config, - limit: ftq.limit, - offset: ftq.offset, + ctx: ftq.ctx.Clone(), order: append([]OrderFunc{}, ftq.order...), inters: append([]Interceptor{}, ftq.inters...), predicates: append([]predicate.FieldType{}, ftq.predicates...), // clone intermediate query. gremlin: ftq.gremlin.Clone(), path: ftq.path, - unique: ftq.unique, } } @@ -278,9 +273,9 @@ func (ftq *FieldTypeQuery) Clone() *FieldTypeQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (ftq *FieldTypeQuery) GroupBy(field string, fields ...string) *FieldTypeGroupBy { - ftq.fields = append([]string{field}, fields...) + ftq.ctx.Fields = append([]string{field}, fields...) grbuild := &FieldTypeGroupBy{build: ftq} - grbuild.flds = &ftq.fields + grbuild.flds = &ftq.ctx.Fields grbuild.label = fieldtype.Label grbuild.scan = grbuild.Scan return grbuild @@ -299,10 +294,10 @@ func (ftq *FieldTypeQuery) GroupBy(field string, fields ...string) *FieldTypeGro // Select(fieldtype.FieldInt). // Scan(ctx, &v) func (ftq *FieldTypeQuery) Select(fields ...string) *FieldTypeSelect { - ftq.fields = append(ftq.fields, fields...) + ftq.ctx.Fields = append(ftq.ctx.Fields, fields...) sbuild := &FieldTypeSelect{FieldTypeQuery: ftq} sbuild.label = fieldtype.Label - sbuild.flds, sbuild.scan = &ftq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &ftq.ctx.Fields, sbuild.Scan return sbuild } @@ -335,9 +330,9 @@ func (ftq *FieldTypeQuery) prepareQuery(ctx context.Context) error { func (ftq *FieldTypeQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*FieldType, error) { res := &gremlin.Response{} traversal := ftq.gremlinQuery(ctx) - if len(ftq.fields) > 0 { - fields := make([]any, len(ftq.fields)) - for i, f := range ftq.fields { + if len(ftq.ctx.Fields) > 0 { + fields := make([]any, len(ftq.ctx.Fields)) + for i, f := range ftq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -379,7 +374,7 @@ func (ftq *FieldTypeQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := ftq.limit, ftq.offset; { + switch limit, offset := ftq.ctx.Limit, ftq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -387,7 +382,7 @@ func (ftq *FieldTypeQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := ftq.unique; unique == nil || *unique { + if unique := ftq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -407,7 +402,7 @@ func (ftgb *FieldTypeGroupBy) Aggregate(fns ...AggregateFunc) *FieldTypeGroupBy // Scan applies the selector query and scans the result into the given value. func (ftgb *FieldTypeGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFieldType, "GroupBy") + ctx = setContextOp(ctx, ftgb.build.ctx, "GroupBy") if err := ftgb.build.prepareQuery(ctx); err != nil { return err } @@ -462,7 +457,7 @@ func (fts *FieldTypeSelect) Aggregate(fns ...AggregateFunc) *FieldTypeSelect { // Scan applies the selector query and scans the result into the given value. func (fts *FieldTypeSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFieldType, "Select") + ctx = setContextOp(ctx, fts.ctx, "Select") if err := fts.prepareQuery(ctx); err != nil { return err } @@ -474,15 +469,15 @@ func (fts *FieldTypeSelect) gremlinScan(ctx context.Context, root *FieldTypeQuer res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(fts.fields) == 1 { - if fts.fields[0] != fieldtype.FieldID { - traversal = traversal.Values(fts.fields...) + if fields := fts.ctx.Fields; len(fields) == 1 { + if fields[0] != fieldtype.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(fts.fields)) - for i, f := range fts.fields { + fields := make([]any, len(fts.ctx.Fields)) + for i, f := range fts.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -491,7 +486,7 @@ func (fts *FieldTypeSelect) gremlinScan(ctx context.Context, root *FieldTypeQuer if err := fts.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/file_query.go b/entc/integration/gremlin/ent/file_query.go index b9464f1b8..8b3cba302 100644 --- a/entc/integration/gremlin/ent/file_query.go +++ b/entc/integration/gremlin/ent/file_query.go @@ -24,11 +24,8 @@ import ( // FileQuery is the builder for querying File entities. type FileQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.File withOwner *UserQuery @@ -47,20 +44,20 @@ func (fq *FileQuery) Where(ps ...predicate.File) *FileQuery { // Limit the number of records to be returned by this query. func (fq *FileQuery) Limit(limit int) *FileQuery { - fq.limit = &limit + fq.ctx.Limit = &limit return fq } // Offset to start from. func (fq *FileQuery) Offset(offset int) *FileQuery { - fq.offset = &offset + fq.ctx.Offset = &offset return fq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (fq *FileQuery) Unique(unique bool) *FileQuery { - fq.unique = &unique + fq.ctx.Unique = &unique return fq } @@ -115,7 +112,7 @@ func (fq *FileQuery) QueryField() *FieldTypeQuery { // First returns the first File entity from the query. // Returns a *NotFoundError when no File was found. func (fq *FileQuery) First(ctx context.Context) (*File, error) { - nodes, err := fq.Limit(1).All(newQueryContext(ctx, TypeFile, "First")) + nodes, err := fq.Limit(1).All(setContextOp(ctx, fq.ctx, "First")) if err != nil { return nil, err } @@ -138,7 +135,7 @@ func (fq *FileQuery) FirstX(ctx context.Context) *File { // Returns a *NotFoundError when no File ID was found. func (fq *FileQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = fq.Limit(1).IDs(newQueryContext(ctx, TypeFile, "FirstID")); err != nil { + if ids, err = fq.Limit(1).IDs(setContextOp(ctx, fq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -161,7 +158,7 @@ func (fq *FileQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one File entity is found. // Returns a *NotFoundError when no File entities are found. func (fq *FileQuery) Only(ctx context.Context) (*File, error) { - nodes, err := fq.Limit(2).All(newQueryContext(ctx, TypeFile, "Only")) + nodes, err := fq.Limit(2).All(setContextOp(ctx, fq.ctx, "Only")) if err != nil { return nil, err } @@ -189,7 +186,7 @@ func (fq *FileQuery) OnlyX(ctx context.Context) *File { // Returns a *NotFoundError when no entities are found. func (fq *FileQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = fq.Limit(2).IDs(newQueryContext(ctx, TypeFile, "OnlyID")); err != nil { + if ids, err = fq.Limit(2).IDs(setContextOp(ctx, fq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -214,7 +211,7 @@ func (fq *FileQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Files. func (fq *FileQuery) All(ctx context.Context) ([]*File, error) { - ctx = newQueryContext(ctx, TypeFile, "All") + ctx = setContextOp(ctx, fq.ctx, "All") if err := fq.prepareQuery(ctx); err != nil { return nil, err } @@ -234,7 +231,7 @@ func (fq *FileQuery) AllX(ctx context.Context) []*File { // IDs executes the query and returns a list of File IDs. func (fq *FileQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeFile, "IDs") + ctx = setContextOp(ctx, fq.ctx, "IDs") if err := fq.Select(file.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -252,7 +249,7 @@ func (fq *FileQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (fq *FileQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeFile, "Count") + ctx = setContextOp(ctx, fq.ctx, "Count") if err := fq.prepareQuery(ctx); err != nil { return 0, err } @@ -270,7 +267,7 @@ func (fq *FileQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (fq *FileQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeFile, "Exist") + ctx = setContextOp(ctx, fq.ctx, "Exist") switch _, err := fq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -298,8 +295,7 @@ func (fq *FileQuery) Clone() *FileQuery { } return &FileQuery{ config: fq.config, - limit: fq.limit, - offset: fq.offset, + ctx: fq.ctx.Clone(), order: append([]OrderFunc{}, fq.order...), inters: append([]Interceptor{}, fq.inters...), predicates: append([]predicate.File{}, fq.predicates...), @@ -309,7 +305,6 @@ func (fq *FileQuery) Clone() *FileQuery { // clone intermediate query. gremlin: fq.gremlin.Clone(), path: fq.path, - unique: fq.unique, } } @@ -361,9 +356,9 @@ func (fq *FileQuery) WithField(opts ...func(*FieldTypeQuery)) *FileQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (fq *FileQuery) GroupBy(field string, fields ...string) *FileGroupBy { - fq.fields = append([]string{field}, fields...) + fq.ctx.Fields = append([]string{field}, fields...) grbuild := &FileGroupBy{build: fq} - grbuild.flds = &fq.fields + grbuild.flds = &fq.ctx.Fields grbuild.label = file.Label grbuild.scan = grbuild.Scan return grbuild @@ -382,10 +377,10 @@ func (fq *FileQuery) GroupBy(field string, fields ...string) *FileGroupBy { // Select(file.FieldSize). // Scan(ctx, &v) func (fq *FileQuery) Select(fields ...string) *FileSelect { - fq.fields = append(fq.fields, fields...) + fq.ctx.Fields = append(fq.ctx.Fields, fields...) sbuild := &FileSelect{FileQuery: fq} sbuild.label = file.Label - sbuild.flds, sbuild.scan = &fq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &fq.ctx.Fields, sbuild.Scan return sbuild } @@ -418,9 +413,9 @@ func (fq *FileQuery) prepareQuery(ctx context.Context) error { func (fq *FileQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*File, error) { res := &gremlin.Response{} traversal := fq.gremlinQuery(ctx) - if len(fq.fields) > 0 { - fields := make([]any, len(fq.fields)) - for i, f := range fq.fields { + if len(fq.ctx.Fields) > 0 { + fields := make([]any, len(fq.ctx.Fields)) + for i, f := range fq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -462,7 +457,7 @@ func (fq *FileQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := fq.limit, fq.offset; { + switch limit, offset := fq.ctx.Limit, fq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -470,7 +465,7 @@ func (fq *FileQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := fq.unique; unique == nil || *unique { + if unique := fq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -490,7 +485,7 @@ func (fgb *FileGroupBy) Aggregate(fns ...AggregateFunc) *FileGroupBy { // Scan applies the selector query and scans the result into the given value. func (fgb *FileGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFile, "GroupBy") + ctx = setContextOp(ctx, fgb.build.ctx, "GroupBy") if err := fgb.build.prepareQuery(ctx); err != nil { return err } @@ -545,7 +540,7 @@ func (fs *FileSelect) Aggregate(fns ...AggregateFunc) *FileSelect { // Scan applies the selector query and scans the result into the given value. func (fs *FileSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFile, "Select") + ctx = setContextOp(ctx, fs.ctx, "Select") if err := fs.prepareQuery(ctx); err != nil { return err } @@ -557,15 +552,15 @@ func (fs *FileSelect) gremlinScan(ctx context.Context, root *FileQuery, v any) e res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(fs.fields) == 1 { - if fs.fields[0] != file.FieldID { - traversal = traversal.Values(fs.fields...) + if fields := fs.ctx.Fields; len(fields) == 1 { + if fields[0] != file.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(fs.fields)) - for i, f := range fs.fields { + fields := make([]any, len(fs.ctx.Fields)) + for i, f := range fs.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -574,7 +569,7 @@ func (fs *FileSelect) gremlinScan(ctx context.Context, root *FileQuery, v any) e if err := fs.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/filetype_query.go b/entc/integration/gremlin/ent/filetype_query.go index 535d7e4ec..3ec2c7efe 100644 --- a/entc/integration/gremlin/ent/filetype_query.go +++ b/entc/integration/gremlin/ent/filetype_query.go @@ -22,11 +22,8 @@ import ( // FileTypeQuery is the builder for querying FileType entities. type FileTypeQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.FileType withFiles *FileQuery @@ -43,20 +40,20 @@ func (ftq *FileTypeQuery) Where(ps ...predicate.FileType) *FileTypeQuery { // Limit the number of records to be returned by this query. func (ftq *FileTypeQuery) Limit(limit int) *FileTypeQuery { - ftq.limit = &limit + ftq.ctx.Limit = &limit return ftq } // Offset to start from. func (ftq *FileTypeQuery) Offset(offset int) *FileTypeQuery { - ftq.offset = &offset + ftq.ctx.Offset = &offset return ftq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (ftq *FileTypeQuery) Unique(unique bool) *FileTypeQuery { - ftq.unique = &unique + ftq.ctx.Unique = &unique return ftq } @@ -83,7 +80,7 @@ func (ftq *FileTypeQuery) QueryFiles() *FileQuery { // First returns the first FileType entity from the query. // Returns a *NotFoundError when no FileType was found. func (ftq *FileTypeQuery) First(ctx context.Context) (*FileType, error) { - nodes, err := ftq.Limit(1).All(newQueryContext(ctx, TypeFileType, "First")) + nodes, err := ftq.Limit(1).All(setContextOp(ctx, ftq.ctx, "First")) if err != nil { return nil, err } @@ -106,7 +103,7 @@ func (ftq *FileTypeQuery) FirstX(ctx context.Context) *FileType { // Returns a *NotFoundError when no FileType ID was found. func (ftq *FileTypeQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = ftq.Limit(1).IDs(newQueryContext(ctx, TypeFileType, "FirstID")); err != nil { + if ids, err = ftq.Limit(1).IDs(setContextOp(ctx, ftq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -129,7 +126,7 @@ func (ftq *FileTypeQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one FileType entity is found. // Returns a *NotFoundError when no FileType entities are found. func (ftq *FileTypeQuery) Only(ctx context.Context) (*FileType, error) { - nodes, err := ftq.Limit(2).All(newQueryContext(ctx, TypeFileType, "Only")) + nodes, err := ftq.Limit(2).All(setContextOp(ctx, ftq.ctx, "Only")) if err != nil { return nil, err } @@ -157,7 +154,7 @@ func (ftq *FileTypeQuery) OnlyX(ctx context.Context) *FileType { // Returns a *NotFoundError when no entities are found. func (ftq *FileTypeQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = ftq.Limit(2).IDs(newQueryContext(ctx, TypeFileType, "OnlyID")); err != nil { + if ids, err = ftq.Limit(2).IDs(setContextOp(ctx, ftq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -182,7 +179,7 @@ func (ftq *FileTypeQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of FileTypes. func (ftq *FileTypeQuery) All(ctx context.Context) ([]*FileType, error) { - ctx = newQueryContext(ctx, TypeFileType, "All") + ctx = setContextOp(ctx, ftq.ctx, "All") if err := ftq.prepareQuery(ctx); err != nil { return nil, err } @@ -202,7 +199,7 @@ func (ftq *FileTypeQuery) AllX(ctx context.Context) []*FileType { // IDs executes the query and returns a list of FileType IDs. func (ftq *FileTypeQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeFileType, "IDs") + ctx = setContextOp(ctx, ftq.ctx, "IDs") if err := ftq.Select(filetype.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -220,7 +217,7 @@ func (ftq *FileTypeQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (ftq *FileTypeQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeFileType, "Count") + ctx = setContextOp(ctx, ftq.ctx, "Count") if err := ftq.prepareQuery(ctx); err != nil { return 0, err } @@ -238,7 +235,7 @@ func (ftq *FileTypeQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (ftq *FileTypeQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeFileType, "Exist") + ctx = setContextOp(ctx, ftq.ctx, "Exist") switch _, err := ftq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -266,8 +263,7 @@ func (ftq *FileTypeQuery) Clone() *FileTypeQuery { } return &FileTypeQuery{ config: ftq.config, - limit: ftq.limit, - offset: ftq.offset, + ctx: ftq.ctx.Clone(), order: append([]OrderFunc{}, ftq.order...), inters: append([]Interceptor{}, ftq.inters...), predicates: append([]predicate.FileType{}, ftq.predicates...), @@ -275,7 +271,6 @@ func (ftq *FileTypeQuery) Clone() *FileTypeQuery { // clone intermediate query. gremlin: ftq.gremlin.Clone(), path: ftq.path, - unique: ftq.unique, } } @@ -305,9 +300,9 @@ func (ftq *FileTypeQuery) WithFiles(opts ...func(*FileQuery)) *FileTypeQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (ftq *FileTypeQuery) GroupBy(field string, fields ...string) *FileTypeGroupBy { - ftq.fields = append([]string{field}, fields...) + ftq.ctx.Fields = append([]string{field}, fields...) grbuild := &FileTypeGroupBy{build: ftq} - grbuild.flds = &ftq.fields + grbuild.flds = &ftq.ctx.Fields grbuild.label = filetype.Label grbuild.scan = grbuild.Scan return grbuild @@ -326,10 +321,10 @@ func (ftq *FileTypeQuery) GroupBy(field string, fields ...string) *FileTypeGroup // Select(filetype.FieldName). // Scan(ctx, &v) func (ftq *FileTypeQuery) Select(fields ...string) *FileTypeSelect { - ftq.fields = append(ftq.fields, fields...) + ftq.ctx.Fields = append(ftq.ctx.Fields, fields...) sbuild := &FileTypeSelect{FileTypeQuery: ftq} sbuild.label = filetype.Label - sbuild.flds, sbuild.scan = &ftq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &ftq.ctx.Fields, sbuild.Scan return sbuild } @@ -362,9 +357,9 @@ func (ftq *FileTypeQuery) prepareQuery(ctx context.Context) error { func (ftq *FileTypeQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*FileType, error) { res := &gremlin.Response{} traversal := ftq.gremlinQuery(ctx) - if len(ftq.fields) > 0 { - fields := make([]any, len(ftq.fields)) - for i, f := range ftq.fields { + if len(ftq.ctx.Fields) > 0 { + fields := make([]any, len(ftq.ctx.Fields)) + for i, f := range ftq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -406,7 +401,7 @@ func (ftq *FileTypeQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := ftq.limit, ftq.offset; { + switch limit, offset := ftq.ctx.Limit, ftq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -414,7 +409,7 @@ func (ftq *FileTypeQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := ftq.unique; unique == nil || *unique { + if unique := ftq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -434,7 +429,7 @@ func (ftgb *FileTypeGroupBy) Aggregate(fns ...AggregateFunc) *FileTypeGroupBy { // Scan applies the selector query and scans the result into the given value. func (ftgb *FileTypeGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFileType, "GroupBy") + ctx = setContextOp(ctx, ftgb.build.ctx, "GroupBy") if err := ftgb.build.prepareQuery(ctx); err != nil { return err } @@ -489,7 +484,7 @@ func (fts *FileTypeSelect) Aggregate(fns ...AggregateFunc) *FileTypeSelect { // Scan applies the selector query and scans the result into the given value. func (fts *FileTypeSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFileType, "Select") + ctx = setContextOp(ctx, fts.ctx, "Select") if err := fts.prepareQuery(ctx); err != nil { return err } @@ -501,15 +496,15 @@ func (fts *FileTypeSelect) gremlinScan(ctx context.Context, root *FileTypeQuery, res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(fts.fields) == 1 { - if fts.fields[0] != filetype.FieldID { - traversal = traversal.Values(fts.fields...) + if fields := fts.ctx.Fields; len(fields) == 1 { + if fields[0] != filetype.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(fts.fields)) - for i, f := range fts.fields { + fields := make([]any, len(fts.ctx.Fields)) + for i, f := range fts.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -518,7 +513,7 @@ func (fts *FileTypeSelect) gremlinScan(ctx context.Context, root *FileTypeQuery, if err := fts.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/goods_query.go b/entc/integration/gremlin/ent/goods_query.go index a553d7328..b4257db63 100644 --- a/entc/integration/gremlin/ent/goods_query.go +++ b/entc/integration/gremlin/ent/goods_query.go @@ -22,11 +22,8 @@ import ( // GoodsQuery is the builder for querying Goods entities. type GoodsQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Goods // intermediate query (i.e. traversal path). @@ -42,20 +39,20 @@ func (gq *GoodsQuery) Where(ps ...predicate.Goods) *GoodsQuery { // Limit the number of records to be returned by this query. func (gq *GoodsQuery) Limit(limit int) *GoodsQuery { - gq.limit = &limit + gq.ctx.Limit = &limit return gq } // Offset to start from. func (gq *GoodsQuery) Offset(offset int) *GoodsQuery { - gq.offset = &offset + gq.ctx.Offset = &offset return gq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (gq *GoodsQuery) Unique(unique bool) *GoodsQuery { - gq.unique = &unique + gq.ctx.Unique = &unique return gq } @@ -68,7 +65,7 @@ func (gq *GoodsQuery) Order(o ...OrderFunc) *GoodsQuery { // First returns the first Goods entity from the query. // Returns a *NotFoundError when no Goods was found. func (gq *GoodsQuery) First(ctx context.Context) (*Goods, error) { - nodes, err := gq.Limit(1).All(newQueryContext(ctx, TypeGoods, "First")) + nodes, err := gq.Limit(1).All(setContextOp(ctx, gq.ctx, "First")) if err != nil { return nil, err } @@ -91,7 +88,7 @@ func (gq *GoodsQuery) FirstX(ctx context.Context) *Goods { // Returns a *NotFoundError when no Goods ID was found. func (gq *GoodsQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = gq.Limit(1).IDs(newQueryContext(ctx, TypeGoods, "FirstID")); err != nil { + if ids, err = gq.Limit(1).IDs(setContextOp(ctx, gq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -114,7 +111,7 @@ func (gq *GoodsQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Goods entity is found. // Returns a *NotFoundError when no Goods entities are found. func (gq *GoodsQuery) Only(ctx context.Context) (*Goods, error) { - nodes, err := gq.Limit(2).All(newQueryContext(ctx, TypeGoods, "Only")) + nodes, err := gq.Limit(2).All(setContextOp(ctx, gq.ctx, "Only")) if err != nil { return nil, err } @@ -142,7 +139,7 @@ func (gq *GoodsQuery) OnlyX(ctx context.Context) *Goods { // Returns a *NotFoundError when no entities are found. func (gq *GoodsQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = gq.Limit(2).IDs(newQueryContext(ctx, TypeGoods, "OnlyID")); err != nil { + if ids, err = gq.Limit(2).IDs(setContextOp(ctx, gq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -167,7 +164,7 @@ func (gq *GoodsQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of GoodsSlice. func (gq *GoodsQuery) All(ctx context.Context) ([]*Goods, error) { - ctx = newQueryContext(ctx, TypeGoods, "All") + ctx = setContextOp(ctx, gq.ctx, "All") if err := gq.prepareQuery(ctx); err != nil { return nil, err } @@ -187,7 +184,7 @@ func (gq *GoodsQuery) AllX(ctx context.Context) []*Goods { // IDs executes the query and returns a list of Goods IDs. func (gq *GoodsQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeGoods, "IDs") + ctx = setContextOp(ctx, gq.ctx, "IDs") if err := gq.Select(goods.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -205,7 +202,7 @@ func (gq *GoodsQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (gq *GoodsQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGoods, "Count") + ctx = setContextOp(ctx, gq.ctx, "Count") if err := gq.prepareQuery(ctx); err != nil { return 0, err } @@ -223,7 +220,7 @@ func (gq *GoodsQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (gq *GoodsQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGoods, "Exist") + ctx = setContextOp(ctx, gq.ctx, "Exist") switch _, err := gq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -251,24 +248,22 @@ func (gq *GoodsQuery) Clone() *GoodsQuery { } return &GoodsQuery{ config: gq.config, - limit: gq.limit, - offset: gq.offset, + ctx: gq.ctx.Clone(), order: append([]OrderFunc{}, gq.order...), inters: append([]Interceptor{}, gq.inters...), predicates: append([]predicate.Goods{}, gq.predicates...), // clone intermediate query. gremlin: gq.gremlin.Clone(), path: gq.path, - unique: gq.unique, } } // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (gq *GoodsQuery) GroupBy(field string, fields ...string) *GoodsGroupBy { - gq.fields = append([]string{field}, fields...) + gq.ctx.Fields = append([]string{field}, fields...) grbuild := &GoodsGroupBy{build: gq} - grbuild.flds = &gq.fields + grbuild.flds = &gq.ctx.Fields grbuild.label = goods.Label grbuild.scan = grbuild.Scan return grbuild @@ -277,10 +272,10 @@ func (gq *GoodsQuery) GroupBy(field string, fields ...string) *GoodsGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (gq *GoodsQuery) Select(fields ...string) *GoodsSelect { - gq.fields = append(gq.fields, fields...) + gq.ctx.Fields = append(gq.ctx.Fields, fields...) sbuild := &GoodsSelect{GoodsQuery: gq} sbuild.label = goods.Label - sbuild.flds, sbuild.scan = &gq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &gq.ctx.Fields, sbuild.Scan return sbuild } @@ -313,9 +308,9 @@ func (gq *GoodsQuery) prepareQuery(ctx context.Context) error { func (gq *GoodsQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*Goods, error) { res := &gremlin.Response{} traversal := gq.gremlinQuery(ctx) - if len(gq.fields) > 0 { - fields := make([]any, len(gq.fields)) - for i, f := range gq.fields { + if len(gq.ctx.Fields) > 0 { + fields := make([]any, len(gq.ctx.Fields)) + for i, f := range gq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -357,7 +352,7 @@ func (gq *GoodsQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := gq.limit, gq.offset; { + switch limit, offset := gq.ctx.Limit, gq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -365,7 +360,7 @@ func (gq *GoodsQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := gq.unique; unique == nil || *unique { + if unique := gq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -385,7 +380,7 @@ func (ggb *GoodsGroupBy) Aggregate(fns ...AggregateFunc) *GoodsGroupBy { // Scan applies the selector query and scans the result into the given value. func (ggb *GoodsGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGoods, "GroupBy") + ctx = setContextOp(ctx, ggb.build.ctx, "GroupBy") if err := ggb.build.prepareQuery(ctx); err != nil { return err } @@ -440,7 +435,7 @@ func (gs *GoodsSelect) Aggregate(fns ...AggregateFunc) *GoodsSelect { // Scan applies the selector query and scans the result into the given value. func (gs *GoodsSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGoods, "Select") + ctx = setContextOp(ctx, gs.ctx, "Select") if err := gs.prepareQuery(ctx); err != nil { return err } @@ -452,15 +447,15 @@ func (gs *GoodsSelect) gremlinScan(ctx context.Context, root *GoodsQuery, v any) res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(gs.fields) == 1 { - if gs.fields[0] != goods.FieldID { - traversal = traversal.Values(gs.fields...) + if fields := gs.ctx.Fields; len(fields) == 1 { + if fields[0] != goods.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(gs.fields)) - for i, f := range gs.fields { + fields := make([]any, len(gs.ctx.Fields)) + for i, f := range gs.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -469,7 +464,7 @@ func (gs *GoodsSelect) gremlinScan(ctx context.Context, root *GoodsQuery, v any) if err := gs.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/group_query.go b/entc/integration/gremlin/ent/group_query.go index a0fb81cf3..72628ecbd 100644 --- a/entc/integration/gremlin/ent/group_query.go +++ b/entc/integration/gremlin/ent/group_query.go @@ -23,11 +23,8 @@ import ( // GroupQuery is the builder for querying Group entities. type GroupQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Group withFiles *FileQuery @@ -47,20 +44,20 @@ func (gq *GroupQuery) Where(ps ...predicate.Group) *GroupQuery { // Limit the number of records to be returned by this query. func (gq *GroupQuery) Limit(limit int) *GroupQuery { - gq.limit = &limit + gq.ctx.Limit = &limit return gq } // Offset to start from. func (gq *GroupQuery) Offset(offset int) *GroupQuery { - gq.offset = &offset + gq.ctx.Offset = &offset return gq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (gq *GroupQuery) Unique(unique bool) *GroupQuery { - gq.unique = &unique + gq.ctx.Unique = &unique return gq } @@ -129,7 +126,7 @@ func (gq *GroupQuery) QueryInfo() *GroupInfoQuery { // First returns the first Group entity from the query. // Returns a *NotFoundError when no Group was found. func (gq *GroupQuery) First(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(1).All(newQueryContext(ctx, TypeGroup, "First")) + nodes, err := gq.Limit(1).All(setContextOp(ctx, gq.ctx, "First")) if err != nil { return nil, err } @@ -152,7 +149,7 @@ func (gq *GroupQuery) FirstX(ctx context.Context) *Group { // Returns a *NotFoundError when no Group ID was found. func (gq *GroupQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = gq.Limit(1).IDs(newQueryContext(ctx, TypeGroup, "FirstID")); err != nil { + if ids, err = gq.Limit(1).IDs(setContextOp(ctx, gq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -175,7 +172,7 @@ func (gq *GroupQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Group entity is found. // Returns a *NotFoundError when no Group entities are found. func (gq *GroupQuery) Only(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(2).All(newQueryContext(ctx, TypeGroup, "Only")) + nodes, err := gq.Limit(2).All(setContextOp(ctx, gq.ctx, "Only")) if err != nil { return nil, err } @@ -203,7 +200,7 @@ func (gq *GroupQuery) OnlyX(ctx context.Context) *Group { // Returns a *NotFoundError when no entities are found. func (gq *GroupQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = gq.Limit(2).IDs(newQueryContext(ctx, TypeGroup, "OnlyID")); err != nil { + if ids, err = gq.Limit(2).IDs(setContextOp(ctx, gq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -228,7 +225,7 @@ func (gq *GroupQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Groups. func (gq *GroupQuery) All(ctx context.Context) ([]*Group, error) { - ctx = newQueryContext(ctx, TypeGroup, "All") + ctx = setContextOp(ctx, gq.ctx, "All") if err := gq.prepareQuery(ctx); err != nil { return nil, err } @@ -248,7 +245,7 @@ func (gq *GroupQuery) AllX(ctx context.Context) []*Group { // IDs executes the query and returns a list of Group IDs. func (gq *GroupQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeGroup, "IDs") + ctx = setContextOp(ctx, gq.ctx, "IDs") if err := gq.Select(group.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -266,7 +263,7 @@ func (gq *GroupQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (gq *GroupQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGroup, "Count") + ctx = setContextOp(ctx, gq.ctx, "Count") if err := gq.prepareQuery(ctx); err != nil { return 0, err } @@ -284,7 +281,7 @@ func (gq *GroupQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (gq *GroupQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGroup, "Exist") + ctx = setContextOp(ctx, gq.ctx, "Exist") switch _, err := gq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -312,8 +309,7 @@ func (gq *GroupQuery) Clone() *GroupQuery { } return &GroupQuery{ config: gq.config, - limit: gq.limit, - offset: gq.offset, + ctx: gq.ctx.Clone(), order: append([]OrderFunc{}, gq.order...), inters: append([]Interceptor{}, gq.inters...), predicates: append([]predicate.Group{}, gq.predicates...), @@ -324,7 +320,6 @@ func (gq *GroupQuery) Clone() *GroupQuery { // clone intermediate query. gremlin: gq.gremlin.Clone(), path: gq.path, - unique: gq.unique, } } @@ -387,9 +382,9 @@ func (gq *GroupQuery) WithInfo(opts ...func(*GroupInfoQuery)) *GroupQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { - gq.fields = append([]string{field}, fields...) + gq.ctx.Fields = append([]string{field}, fields...) grbuild := &GroupGroupBy{build: gq} - grbuild.flds = &gq.fields + grbuild.flds = &gq.ctx.Fields grbuild.label = group.Label grbuild.scan = grbuild.Scan return grbuild @@ -408,10 +403,10 @@ func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { // Select(group.FieldActive). // Scan(ctx, &v) func (gq *GroupQuery) Select(fields ...string) *GroupSelect { - gq.fields = append(gq.fields, fields...) + gq.ctx.Fields = append(gq.ctx.Fields, fields...) sbuild := &GroupSelect{GroupQuery: gq} sbuild.label = group.Label - sbuild.flds, sbuild.scan = &gq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &gq.ctx.Fields, sbuild.Scan return sbuild } @@ -444,9 +439,9 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { func (gq *GroupQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { res := &gremlin.Response{} traversal := gq.gremlinQuery(ctx) - if len(gq.fields) > 0 { - fields := make([]any, len(gq.fields)) - for i, f := range gq.fields { + if len(gq.ctx.Fields) > 0 { + fields := make([]any, len(gq.ctx.Fields)) + for i, f := range gq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -488,7 +483,7 @@ func (gq *GroupQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := gq.limit, gq.offset; { + switch limit, offset := gq.ctx.Limit, gq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -496,7 +491,7 @@ func (gq *GroupQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := gq.unique; unique == nil || *unique { + if unique := gq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -516,7 +511,7 @@ func (ggb *GroupGroupBy) Aggregate(fns ...AggregateFunc) *GroupGroupBy { // Scan applies the selector query and scans the result into the given value. func (ggb *GroupGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "GroupBy") + ctx = setContextOp(ctx, ggb.build.ctx, "GroupBy") if err := ggb.build.prepareQuery(ctx); err != nil { return err } @@ -571,7 +566,7 @@ func (gs *GroupSelect) Aggregate(fns ...AggregateFunc) *GroupSelect { // Scan applies the selector query and scans the result into the given value. func (gs *GroupSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "Select") + ctx = setContextOp(ctx, gs.ctx, "Select") if err := gs.prepareQuery(ctx); err != nil { return err } @@ -583,15 +578,15 @@ func (gs *GroupSelect) gremlinScan(ctx context.Context, root *GroupQuery, v any) res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(gs.fields) == 1 { - if gs.fields[0] != group.FieldID { - traversal = traversal.Values(gs.fields...) + if fields := gs.ctx.Fields; len(fields) == 1 { + if fields[0] != group.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(gs.fields)) - for i, f := range gs.fields { + fields := make([]any, len(gs.ctx.Fields)) + for i, f := range gs.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -600,7 +595,7 @@ func (gs *GroupSelect) gremlinScan(ctx context.Context, root *GroupQuery, v any) if err := gs.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/groupinfo_query.go b/entc/integration/gremlin/ent/groupinfo_query.go index 41e0dd7f0..1fbd6bcac 100644 --- a/entc/integration/gremlin/ent/groupinfo_query.go +++ b/entc/integration/gremlin/ent/groupinfo_query.go @@ -23,11 +23,8 @@ import ( // GroupInfoQuery is the builder for querying GroupInfo entities. type GroupInfoQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.GroupInfo withGroups *GroupQuery @@ -44,20 +41,20 @@ func (giq *GroupInfoQuery) Where(ps ...predicate.GroupInfo) *GroupInfoQuery { // Limit the number of records to be returned by this query. func (giq *GroupInfoQuery) Limit(limit int) *GroupInfoQuery { - giq.limit = &limit + giq.ctx.Limit = &limit return giq } // Offset to start from. func (giq *GroupInfoQuery) Offset(offset int) *GroupInfoQuery { - giq.offset = &offset + giq.ctx.Offset = &offset return giq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (giq *GroupInfoQuery) Unique(unique bool) *GroupInfoQuery { - giq.unique = &unique + giq.ctx.Unique = &unique return giq } @@ -84,7 +81,7 @@ func (giq *GroupInfoQuery) QueryGroups() *GroupQuery { // First returns the first GroupInfo entity from the query. // Returns a *NotFoundError when no GroupInfo was found. func (giq *GroupInfoQuery) First(ctx context.Context) (*GroupInfo, error) { - nodes, err := giq.Limit(1).All(newQueryContext(ctx, TypeGroupInfo, "First")) + nodes, err := giq.Limit(1).All(setContextOp(ctx, giq.ctx, "First")) if err != nil { return nil, err } @@ -107,7 +104,7 @@ func (giq *GroupInfoQuery) FirstX(ctx context.Context) *GroupInfo { // Returns a *NotFoundError when no GroupInfo ID was found. func (giq *GroupInfoQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = giq.Limit(1).IDs(newQueryContext(ctx, TypeGroupInfo, "FirstID")); err != nil { + if ids, err = giq.Limit(1).IDs(setContextOp(ctx, giq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -130,7 +127,7 @@ func (giq *GroupInfoQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one GroupInfo entity is found. // Returns a *NotFoundError when no GroupInfo entities are found. func (giq *GroupInfoQuery) Only(ctx context.Context) (*GroupInfo, error) { - nodes, err := giq.Limit(2).All(newQueryContext(ctx, TypeGroupInfo, "Only")) + nodes, err := giq.Limit(2).All(setContextOp(ctx, giq.ctx, "Only")) if err != nil { return nil, err } @@ -158,7 +155,7 @@ func (giq *GroupInfoQuery) OnlyX(ctx context.Context) *GroupInfo { // Returns a *NotFoundError when no entities are found. func (giq *GroupInfoQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = giq.Limit(2).IDs(newQueryContext(ctx, TypeGroupInfo, "OnlyID")); err != nil { + if ids, err = giq.Limit(2).IDs(setContextOp(ctx, giq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -183,7 +180,7 @@ func (giq *GroupInfoQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of GroupInfos. func (giq *GroupInfoQuery) All(ctx context.Context) ([]*GroupInfo, error) { - ctx = newQueryContext(ctx, TypeGroupInfo, "All") + ctx = setContextOp(ctx, giq.ctx, "All") if err := giq.prepareQuery(ctx); err != nil { return nil, err } @@ -203,7 +200,7 @@ func (giq *GroupInfoQuery) AllX(ctx context.Context) []*GroupInfo { // IDs executes the query and returns a list of GroupInfo IDs. func (giq *GroupInfoQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeGroupInfo, "IDs") + ctx = setContextOp(ctx, giq.ctx, "IDs") if err := giq.Select(groupinfo.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -221,7 +218,7 @@ func (giq *GroupInfoQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (giq *GroupInfoQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGroupInfo, "Count") + ctx = setContextOp(ctx, giq.ctx, "Count") if err := giq.prepareQuery(ctx); err != nil { return 0, err } @@ -239,7 +236,7 @@ func (giq *GroupInfoQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (giq *GroupInfoQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGroupInfo, "Exist") + ctx = setContextOp(ctx, giq.ctx, "Exist") switch _, err := giq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -267,8 +264,7 @@ func (giq *GroupInfoQuery) Clone() *GroupInfoQuery { } return &GroupInfoQuery{ config: giq.config, - limit: giq.limit, - offset: giq.offset, + ctx: giq.ctx.Clone(), order: append([]OrderFunc{}, giq.order...), inters: append([]Interceptor{}, giq.inters...), predicates: append([]predicate.GroupInfo{}, giq.predicates...), @@ -276,7 +272,6 @@ func (giq *GroupInfoQuery) Clone() *GroupInfoQuery { // clone intermediate query. gremlin: giq.gremlin.Clone(), path: giq.path, - unique: giq.unique, } } @@ -306,9 +301,9 @@ func (giq *GroupInfoQuery) WithGroups(opts ...func(*GroupQuery)) *GroupInfoQuery // Aggregate(ent.Count()). // Scan(ctx, &v) func (giq *GroupInfoQuery) GroupBy(field string, fields ...string) *GroupInfoGroupBy { - giq.fields = append([]string{field}, fields...) + giq.ctx.Fields = append([]string{field}, fields...) grbuild := &GroupInfoGroupBy{build: giq} - grbuild.flds = &giq.fields + grbuild.flds = &giq.ctx.Fields grbuild.label = groupinfo.Label grbuild.scan = grbuild.Scan return grbuild @@ -327,10 +322,10 @@ func (giq *GroupInfoQuery) GroupBy(field string, fields ...string) *GroupInfoGro // Select(groupinfo.FieldDesc). // Scan(ctx, &v) func (giq *GroupInfoQuery) Select(fields ...string) *GroupInfoSelect { - giq.fields = append(giq.fields, fields...) + giq.ctx.Fields = append(giq.ctx.Fields, fields...) sbuild := &GroupInfoSelect{GroupInfoQuery: giq} sbuild.label = groupinfo.Label - sbuild.flds, sbuild.scan = &giq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &giq.ctx.Fields, sbuild.Scan return sbuild } @@ -363,9 +358,9 @@ func (giq *GroupInfoQuery) prepareQuery(ctx context.Context) error { func (giq *GroupInfoQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*GroupInfo, error) { res := &gremlin.Response{} traversal := giq.gremlinQuery(ctx) - if len(giq.fields) > 0 { - fields := make([]any, len(giq.fields)) - for i, f := range giq.fields { + if len(giq.ctx.Fields) > 0 { + fields := make([]any, len(giq.ctx.Fields)) + for i, f := range giq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -407,7 +402,7 @@ func (giq *GroupInfoQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := giq.limit, giq.offset; { + switch limit, offset := giq.ctx.Limit, giq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -415,7 +410,7 @@ func (giq *GroupInfoQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := giq.unique; unique == nil || *unique { + if unique := giq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -435,7 +430,7 @@ func (gigb *GroupInfoGroupBy) Aggregate(fns ...AggregateFunc) *GroupInfoGroupBy // Scan applies the selector query and scans the result into the given value. func (gigb *GroupInfoGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroupInfo, "GroupBy") + ctx = setContextOp(ctx, gigb.build.ctx, "GroupBy") if err := gigb.build.prepareQuery(ctx); err != nil { return err } @@ -490,7 +485,7 @@ func (gis *GroupInfoSelect) Aggregate(fns ...AggregateFunc) *GroupInfoSelect { // Scan applies the selector query and scans the result into the given value. func (gis *GroupInfoSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroupInfo, "Select") + ctx = setContextOp(ctx, gis.ctx, "Select") if err := gis.prepareQuery(ctx); err != nil { return err } @@ -502,15 +497,15 @@ func (gis *GroupInfoSelect) gremlinScan(ctx context.Context, root *GroupInfoQuer res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(gis.fields) == 1 { - if gis.fields[0] != groupinfo.FieldID { - traversal = traversal.Values(gis.fields...) + if fields := gis.ctx.Fields; len(fields) == 1 { + if fields[0] != groupinfo.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(gis.fields)) - for i, f := range gis.fields { + fields := make([]any, len(gis.ctx.Fields)) + for i, f := range gis.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -519,7 +514,7 @@ func (gis *GroupInfoSelect) gremlinScan(ctx context.Context, root *GroupInfoQuer if err := gis.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/item_query.go b/entc/integration/gremlin/ent/item_query.go index 792cfbe71..c05659eef 100644 --- a/entc/integration/gremlin/ent/item_query.go +++ b/entc/integration/gremlin/ent/item_query.go @@ -22,11 +22,8 @@ import ( // ItemQuery is the builder for querying Item entities. type ItemQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Item // intermediate query (i.e. traversal path). @@ -42,20 +39,20 @@ func (iq *ItemQuery) Where(ps ...predicate.Item) *ItemQuery { // Limit the number of records to be returned by this query. func (iq *ItemQuery) Limit(limit int) *ItemQuery { - iq.limit = &limit + iq.ctx.Limit = &limit return iq } // Offset to start from. func (iq *ItemQuery) Offset(offset int) *ItemQuery { - iq.offset = &offset + iq.ctx.Offset = &offset return iq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (iq *ItemQuery) Unique(unique bool) *ItemQuery { - iq.unique = &unique + iq.ctx.Unique = &unique return iq } @@ -68,7 +65,7 @@ func (iq *ItemQuery) Order(o ...OrderFunc) *ItemQuery { // First returns the first Item entity from the query. // Returns a *NotFoundError when no Item was found. func (iq *ItemQuery) First(ctx context.Context) (*Item, error) { - nodes, err := iq.Limit(1).All(newQueryContext(ctx, TypeItem, "First")) + nodes, err := iq.Limit(1).All(setContextOp(ctx, iq.ctx, "First")) if err != nil { return nil, err } @@ -91,7 +88,7 @@ func (iq *ItemQuery) FirstX(ctx context.Context) *Item { // Returns a *NotFoundError when no Item ID was found. func (iq *ItemQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = iq.Limit(1).IDs(newQueryContext(ctx, TypeItem, "FirstID")); err != nil { + if ids, err = iq.Limit(1).IDs(setContextOp(ctx, iq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -114,7 +111,7 @@ func (iq *ItemQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Item entity is found. // Returns a *NotFoundError when no Item entities are found. func (iq *ItemQuery) Only(ctx context.Context) (*Item, error) { - nodes, err := iq.Limit(2).All(newQueryContext(ctx, TypeItem, "Only")) + nodes, err := iq.Limit(2).All(setContextOp(ctx, iq.ctx, "Only")) if err != nil { return nil, err } @@ -142,7 +139,7 @@ func (iq *ItemQuery) OnlyX(ctx context.Context) *Item { // Returns a *NotFoundError when no entities are found. func (iq *ItemQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = iq.Limit(2).IDs(newQueryContext(ctx, TypeItem, "OnlyID")); err != nil { + if ids, err = iq.Limit(2).IDs(setContextOp(ctx, iq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -167,7 +164,7 @@ func (iq *ItemQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Items. func (iq *ItemQuery) All(ctx context.Context) ([]*Item, error) { - ctx = newQueryContext(ctx, TypeItem, "All") + ctx = setContextOp(ctx, iq.ctx, "All") if err := iq.prepareQuery(ctx); err != nil { return nil, err } @@ -187,7 +184,7 @@ func (iq *ItemQuery) AllX(ctx context.Context) []*Item { // IDs executes the query and returns a list of Item IDs. func (iq *ItemQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeItem, "IDs") + ctx = setContextOp(ctx, iq.ctx, "IDs") if err := iq.Select(item.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -205,7 +202,7 @@ func (iq *ItemQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (iq *ItemQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeItem, "Count") + ctx = setContextOp(ctx, iq.ctx, "Count") if err := iq.prepareQuery(ctx); err != nil { return 0, err } @@ -223,7 +220,7 @@ func (iq *ItemQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (iq *ItemQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeItem, "Exist") + ctx = setContextOp(ctx, iq.ctx, "Exist") switch _, err := iq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -251,15 +248,13 @@ func (iq *ItemQuery) Clone() *ItemQuery { } return &ItemQuery{ config: iq.config, - limit: iq.limit, - offset: iq.offset, + ctx: iq.ctx.Clone(), order: append([]OrderFunc{}, iq.order...), inters: append([]Interceptor{}, iq.inters...), predicates: append([]predicate.Item{}, iq.predicates...), // clone intermediate query. gremlin: iq.gremlin.Clone(), path: iq.path, - unique: iq.unique, } } @@ -278,9 +273,9 @@ func (iq *ItemQuery) Clone() *ItemQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (iq *ItemQuery) GroupBy(field string, fields ...string) *ItemGroupBy { - iq.fields = append([]string{field}, fields...) + iq.ctx.Fields = append([]string{field}, fields...) grbuild := &ItemGroupBy{build: iq} - grbuild.flds = &iq.fields + grbuild.flds = &iq.ctx.Fields grbuild.label = item.Label grbuild.scan = grbuild.Scan return grbuild @@ -299,10 +294,10 @@ func (iq *ItemQuery) GroupBy(field string, fields ...string) *ItemGroupBy { // Select(item.FieldText). // Scan(ctx, &v) func (iq *ItemQuery) Select(fields ...string) *ItemSelect { - iq.fields = append(iq.fields, fields...) + iq.ctx.Fields = append(iq.ctx.Fields, fields...) sbuild := &ItemSelect{ItemQuery: iq} sbuild.label = item.Label - sbuild.flds, sbuild.scan = &iq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &iq.ctx.Fields, sbuild.Scan return sbuild } @@ -335,9 +330,9 @@ func (iq *ItemQuery) prepareQuery(ctx context.Context) error { func (iq *ItemQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*Item, error) { res := &gremlin.Response{} traversal := iq.gremlinQuery(ctx) - if len(iq.fields) > 0 { - fields := make([]any, len(iq.fields)) - for i, f := range iq.fields { + if len(iq.ctx.Fields) > 0 { + fields := make([]any, len(iq.ctx.Fields)) + for i, f := range iq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -379,7 +374,7 @@ func (iq *ItemQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := iq.limit, iq.offset; { + switch limit, offset := iq.ctx.Limit, iq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -387,7 +382,7 @@ func (iq *ItemQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := iq.unique; unique == nil || *unique { + if unique := iq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -407,7 +402,7 @@ func (igb *ItemGroupBy) Aggregate(fns ...AggregateFunc) *ItemGroupBy { // Scan applies the selector query and scans the result into the given value. func (igb *ItemGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeItem, "GroupBy") + ctx = setContextOp(ctx, igb.build.ctx, "GroupBy") if err := igb.build.prepareQuery(ctx); err != nil { return err } @@ -462,7 +457,7 @@ func (is *ItemSelect) Aggregate(fns ...AggregateFunc) *ItemSelect { // Scan applies the selector query and scans the result into the given value. func (is *ItemSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeItem, "Select") + ctx = setContextOp(ctx, is.ctx, "Select") if err := is.prepareQuery(ctx); err != nil { return err } @@ -474,15 +469,15 @@ func (is *ItemSelect) gremlinScan(ctx context.Context, root *ItemQuery, v any) e res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(is.fields) == 1 { - if is.fields[0] != item.FieldID { - traversal = traversal.Values(is.fields...) + if fields := is.ctx.Fields; len(fields) == 1 { + if fields[0] != item.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(is.fields)) - for i, f := range is.fields { + fields := make([]any, len(is.ctx.Fields)) + for i, f := range is.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -491,7 +486,7 @@ func (is *ItemSelect) gremlinScan(ctx context.Context, root *ItemQuery, v any) e if err := is.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/license_query.go b/entc/integration/gremlin/ent/license_query.go index 2b137ae23..6d06295ca 100644 --- a/entc/integration/gremlin/ent/license_query.go +++ b/entc/integration/gremlin/ent/license_query.go @@ -22,11 +22,8 @@ import ( // LicenseQuery is the builder for querying License entities. type LicenseQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.License // intermediate query (i.e. traversal path). @@ -42,20 +39,20 @@ func (lq *LicenseQuery) Where(ps ...predicate.License) *LicenseQuery { // Limit the number of records to be returned by this query. func (lq *LicenseQuery) Limit(limit int) *LicenseQuery { - lq.limit = &limit + lq.ctx.Limit = &limit return lq } // Offset to start from. func (lq *LicenseQuery) Offset(offset int) *LicenseQuery { - lq.offset = &offset + lq.ctx.Offset = &offset return lq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (lq *LicenseQuery) Unique(unique bool) *LicenseQuery { - lq.unique = &unique + lq.ctx.Unique = &unique return lq } @@ -68,7 +65,7 @@ func (lq *LicenseQuery) Order(o ...OrderFunc) *LicenseQuery { // First returns the first License entity from the query. // Returns a *NotFoundError when no License was found. func (lq *LicenseQuery) First(ctx context.Context) (*License, error) { - nodes, err := lq.Limit(1).All(newQueryContext(ctx, TypeLicense, "First")) + nodes, err := lq.Limit(1).All(setContextOp(ctx, lq.ctx, "First")) if err != nil { return nil, err } @@ -91,7 +88,7 @@ func (lq *LicenseQuery) FirstX(ctx context.Context) *License { // Returns a *NotFoundError when no License ID was found. func (lq *LicenseQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = lq.Limit(1).IDs(newQueryContext(ctx, TypeLicense, "FirstID")); err != nil { + if ids, err = lq.Limit(1).IDs(setContextOp(ctx, lq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -114,7 +111,7 @@ func (lq *LicenseQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one License entity is found. // Returns a *NotFoundError when no License entities are found. func (lq *LicenseQuery) Only(ctx context.Context) (*License, error) { - nodes, err := lq.Limit(2).All(newQueryContext(ctx, TypeLicense, "Only")) + nodes, err := lq.Limit(2).All(setContextOp(ctx, lq.ctx, "Only")) if err != nil { return nil, err } @@ -142,7 +139,7 @@ func (lq *LicenseQuery) OnlyX(ctx context.Context) *License { // Returns a *NotFoundError when no entities are found. func (lq *LicenseQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = lq.Limit(2).IDs(newQueryContext(ctx, TypeLicense, "OnlyID")); err != nil { + if ids, err = lq.Limit(2).IDs(setContextOp(ctx, lq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -167,7 +164,7 @@ func (lq *LicenseQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Licenses. func (lq *LicenseQuery) All(ctx context.Context) ([]*License, error) { - ctx = newQueryContext(ctx, TypeLicense, "All") + ctx = setContextOp(ctx, lq.ctx, "All") if err := lq.prepareQuery(ctx); err != nil { return nil, err } @@ -187,7 +184,7 @@ func (lq *LicenseQuery) AllX(ctx context.Context) []*License { // IDs executes the query and returns a list of License IDs. func (lq *LicenseQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeLicense, "IDs") + ctx = setContextOp(ctx, lq.ctx, "IDs") if err := lq.Select(license.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -205,7 +202,7 @@ func (lq *LicenseQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (lq *LicenseQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeLicense, "Count") + ctx = setContextOp(ctx, lq.ctx, "Count") if err := lq.prepareQuery(ctx); err != nil { return 0, err } @@ -223,7 +220,7 @@ func (lq *LicenseQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (lq *LicenseQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeLicense, "Exist") + ctx = setContextOp(ctx, lq.ctx, "Exist") switch _, err := lq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -251,15 +248,13 @@ func (lq *LicenseQuery) Clone() *LicenseQuery { } return &LicenseQuery{ config: lq.config, - limit: lq.limit, - offset: lq.offset, + ctx: lq.ctx.Clone(), order: append([]OrderFunc{}, lq.order...), inters: append([]Interceptor{}, lq.inters...), predicates: append([]predicate.License{}, lq.predicates...), // clone intermediate query. gremlin: lq.gremlin.Clone(), path: lq.path, - unique: lq.unique, } } @@ -278,9 +273,9 @@ func (lq *LicenseQuery) Clone() *LicenseQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (lq *LicenseQuery) GroupBy(field string, fields ...string) *LicenseGroupBy { - lq.fields = append([]string{field}, fields...) + lq.ctx.Fields = append([]string{field}, fields...) grbuild := &LicenseGroupBy{build: lq} - grbuild.flds = &lq.fields + grbuild.flds = &lq.ctx.Fields grbuild.label = license.Label grbuild.scan = grbuild.Scan return grbuild @@ -299,10 +294,10 @@ func (lq *LicenseQuery) GroupBy(field string, fields ...string) *LicenseGroupBy // Select(license.FieldCreateTime). // Scan(ctx, &v) func (lq *LicenseQuery) Select(fields ...string) *LicenseSelect { - lq.fields = append(lq.fields, fields...) + lq.ctx.Fields = append(lq.ctx.Fields, fields...) sbuild := &LicenseSelect{LicenseQuery: lq} sbuild.label = license.Label - sbuild.flds, sbuild.scan = &lq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &lq.ctx.Fields, sbuild.Scan return sbuild } @@ -335,9 +330,9 @@ func (lq *LicenseQuery) prepareQuery(ctx context.Context) error { func (lq *LicenseQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*License, error) { res := &gremlin.Response{} traversal := lq.gremlinQuery(ctx) - if len(lq.fields) > 0 { - fields := make([]any, len(lq.fields)) - for i, f := range lq.fields { + if len(lq.ctx.Fields) > 0 { + fields := make([]any, len(lq.ctx.Fields)) + for i, f := range lq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -379,7 +374,7 @@ func (lq *LicenseQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := lq.limit, lq.offset; { + switch limit, offset := lq.ctx.Limit, lq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -387,7 +382,7 @@ func (lq *LicenseQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := lq.unique; unique == nil || *unique { + if unique := lq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -407,7 +402,7 @@ func (lgb *LicenseGroupBy) Aggregate(fns ...AggregateFunc) *LicenseGroupBy { // Scan applies the selector query and scans the result into the given value. func (lgb *LicenseGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeLicense, "GroupBy") + ctx = setContextOp(ctx, lgb.build.ctx, "GroupBy") if err := lgb.build.prepareQuery(ctx); err != nil { return err } @@ -462,7 +457,7 @@ func (ls *LicenseSelect) Aggregate(fns ...AggregateFunc) *LicenseSelect { // Scan applies the selector query and scans the result into the given value. func (ls *LicenseSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeLicense, "Select") + ctx = setContextOp(ctx, ls.ctx, "Select") if err := ls.prepareQuery(ctx); err != nil { return err } @@ -474,15 +469,15 @@ func (ls *LicenseSelect) gremlinScan(ctx context.Context, root *LicenseQuery, v res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(ls.fields) == 1 { - if ls.fields[0] != license.FieldID { - traversal = traversal.Values(ls.fields...) + if fields := ls.ctx.Fields; len(fields) == 1 { + if fields[0] != license.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(ls.fields)) - for i, f := range ls.fields { + fields := make([]any, len(ls.ctx.Fields)) + for i, f := range ls.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -491,7 +486,7 @@ func (ls *LicenseSelect) gremlinScan(ctx context.Context, root *LicenseQuery, v if err := ls.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/node_query.go b/entc/integration/gremlin/ent/node_query.go index 3a22a996b..f455f2243 100644 --- a/entc/integration/gremlin/ent/node_query.go +++ b/entc/integration/gremlin/ent/node_query.go @@ -22,11 +22,8 @@ import ( // NodeQuery is the builder for querying Node entities. type NodeQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Node withPrev *NodeQuery @@ -44,20 +41,20 @@ func (nq *NodeQuery) Where(ps ...predicate.Node) *NodeQuery { // Limit the number of records to be returned by this query. func (nq *NodeQuery) Limit(limit int) *NodeQuery { - nq.limit = &limit + nq.ctx.Limit = &limit return nq } // Offset to start from. func (nq *NodeQuery) Offset(offset int) *NodeQuery { - nq.offset = &offset + nq.ctx.Offset = &offset return nq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (nq *NodeQuery) Unique(unique bool) *NodeQuery { - nq.unique = &unique + nq.ctx.Unique = &unique return nq } @@ -98,7 +95,7 @@ func (nq *NodeQuery) QueryNext() *NodeQuery { // First returns the first Node entity from the query. // Returns a *NotFoundError when no Node was found. func (nq *NodeQuery) First(ctx context.Context) (*Node, error) { - nodes, err := nq.Limit(1).All(newQueryContext(ctx, TypeNode, "First")) + nodes, err := nq.Limit(1).All(setContextOp(ctx, nq.ctx, "First")) if err != nil { return nil, err } @@ -121,7 +118,7 @@ func (nq *NodeQuery) FirstX(ctx context.Context) *Node { // Returns a *NotFoundError when no Node ID was found. func (nq *NodeQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = nq.Limit(1).IDs(newQueryContext(ctx, TypeNode, "FirstID")); err != nil { + if ids, err = nq.Limit(1).IDs(setContextOp(ctx, nq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -144,7 +141,7 @@ func (nq *NodeQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Node entity is found. // Returns a *NotFoundError when no Node entities are found. func (nq *NodeQuery) Only(ctx context.Context) (*Node, error) { - nodes, err := nq.Limit(2).All(newQueryContext(ctx, TypeNode, "Only")) + nodes, err := nq.Limit(2).All(setContextOp(ctx, nq.ctx, "Only")) if err != nil { return nil, err } @@ -172,7 +169,7 @@ func (nq *NodeQuery) OnlyX(ctx context.Context) *Node { // Returns a *NotFoundError when no entities are found. func (nq *NodeQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = nq.Limit(2).IDs(newQueryContext(ctx, TypeNode, "OnlyID")); err != nil { + if ids, err = nq.Limit(2).IDs(setContextOp(ctx, nq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -197,7 +194,7 @@ func (nq *NodeQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Nodes. func (nq *NodeQuery) All(ctx context.Context) ([]*Node, error) { - ctx = newQueryContext(ctx, TypeNode, "All") + ctx = setContextOp(ctx, nq.ctx, "All") if err := nq.prepareQuery(ctx); err != nil { return nil, err } @@ -217,7 +214,7 @@ func (nq *NodeQuery) AllX(ctx context.Context) []*Node { // IDs executes the query and returns a list of Node IDs. func (nq *NodeQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeNode, "IDs") + ctx = setContextOp(ctx, nq.ctx, "IDs") if err := nq.Select(node.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -235,7 +232,7 @@ func (nq *NodeQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (nq *NodeQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeNode, "Count") + ctx = setContextOp(ctx, nq.ctx, "Count") if err := nq.prepareQuery(ctx); err != nil { return 0, err } @@ -253,7 +250,7 @@ func (nq *NodeQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (nq *NodeQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeNode, "Exist") + ctx = setContextOp(ctx, nq.ctx, "Exist") switch _, err := nq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -281,8 +278,7 @@ func (nq *NodeQuery) Clone() *NodeQuery { } return &NodeQuery{ config: nq.config, - limit: nq.limit, - offset: nq.offset, + ctx: nq.ctx.Clone(), order: append([]OrderFunc{}, nq.order...), inters: append([]Interceptor{}, nq.inters...), predicates: append([]predicate.Node{}, nq.predicates...), @@ -291,7 +287,6 @@ func (nq *NodeQuery) Clone() *NodeQuery { // clone intermediate query. gremlin: nq.gremlin.Clone(), path: nq.path, - unique: nq.unique, } } @@ -332,9 +327,9 @@ func (nq *NodeQuery) WithNext(opts ...func(*NodeQuery)) *NodeQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (nq *NodeQuery) GroupBy(field string, fields ...string) *NodeGroupBy { - nq.fields = append([]string{field}, fields...) + nq.ctx.Fields = append([]string{field}, fields...) grbuild := &NodeGroupBy{build: nq} - grbuild.flds = &nq.fields + grbuild.flds = &nq.ctx.Fields grbuild.label = node.Label grbuild.scan = grbuild.Scan return grbuild @@ -353,10 +348,10 @@ func (nq *NodeQuery) GroupBy(field string, fields ...string) *NodeGroupBy { // Select(node.FieldValue). // Scan(ctx, &v) func (nq *NodeQuery) Select(fields ...string) *NodeSelect { - nq.fields = append(nq.fields, fields...) + nq.ctx.Fields = append(nq.ctx.Fields, fields...) sbuild := &NodeSelect{NodeQuery: nq} sbuild.label = node.Label - sbuild.flds, sbuild.scan = &nq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &nq.ctx.Fields, sbuild.Scan return sbuild } @@ -389,9 +384,9 @@ func (nq *NodeQuery) prepareQuery(ctx context.Context) error { func (nq *NodeQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*Node, error) { res := &gremlin.Response{} traversal := nq.gremlinQuery(ctx) - if len(nq.fields) > 0 { - fields := make([]any, len(nq.fields)) - for i, f := range nq.fields { + if len(nq.ctx.Fields) > 0 { + fields := make([]any, len(nq.ctx.Fields)) + for i, f := range nq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -433,7 +428,7 @@ func (nq *NodeQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := nq.limit, nq.offset; { + switch limit, offset := nq.ctx.Limit, nq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -441,7 +436,7 @@ func (nq *NodeQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := nq.unique; unique == nil || *unique { + if unique := nq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -461,7 +456,7 @@ func (ngb *NodeGroupBy) Aggregate(fns ...AggregateFunc) *NodeGroupBy { // Scan applies the selector query and scans the result into the given value. func (ngb *NodeGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeNode, "GroupBy") + ctx = setContextOp(ctx, ngb.build.ctx, "GroupBy") if err := ngb.build.prepareQuery(ctx); err != nil { return err } @@ -516,7 +511,7 @@ func (ns *NodeSelect) Aggregate(fns ...AggregateFunc) *NodeSelect { // Scan applies the selector query and scans the result into the given value. func (ns *NodeSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeNode, "Select") + ctx = setContextOp(ctx, ns.ctx, "Select") if err := ns.prepareQuery(ctx); err != nil { return err } @@ -528,15 +523,15 @@ func (ns *NodeSelect) gremlinScan(ctx context.Context, root *NodeQuery, v any) e res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(ns.fields) == 1 { - if ns.fields[0] != node.FieldID { - traversal = traversal.Values(ns.fields...) + if fields := ns.ctx.Fields; len(fields) == 1 { + if fields[0] != node.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(ns.fields)) - for i, f := range ns.fields { + fields := make([]any, len(ns.ctx.Fields)) + for i, f := range ns.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -545,7 +540,7 @@ func (ns *NodeSelect) gremlinScan(ctx context.Context, root *NodeQuery, v any) e if err := ns.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/pet_query.go b/entc/integration/gremlin/ent/pet_query.go index e01c20909..bf3135d7f 100644 --- a/entc/integration/gremlin/ent/pet_query.go +++ b/entc/integration/gremlin/ent/pet_query.go @@ -23,11 +23,8 @@ import ( // PetQuery is the builder for querying Pet entities. type PetQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Pet withTeam *UserQuery @@ -45,20 +42,20 @@ func (pq *PetQuery) Where(ps ...predicate.Pet) *PetQuery { // Limit the number of records to be returned by this query. func (pq *PetQuery) Limit(limit int) *PetQuery { - pq.limit = &limit + pq.ctx.Limit = &limit return pq } // Offset to start from. func (pq *PetQuery) Offset(offset int) *PetQuery { - pq.offset = &offset + pq.ctx.Offset = &offset return pq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (pq *PetQuery) Unique(unique bool) *PetQuery { - pq.unique = &unique + pq.ctx.Unique = &unique return pq } @@ -99,7 +96,7 @@ func (pq *PetQuery) QueryOwner() *UserQuery { // First returns the first Pet entity from the query. // Returns a *NotFoundError when no Pet was found. func (pq *PetQuery) First(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(1).All(newQueryContext(ctx, TypePet, "First")) + nodes, err := pq.Limit(1).All(setContextOp(ctx, pq.ctx, "First")) if err != nil { return nil, err } @@ -122,7 +119,7 @@ func (pq *PetQuery) FirstX(ctx context.Context) *Pet { // Returns a *NotFoundError when no Pet ID was found. func (pq *PetQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = pq.Limit(1).IDs(newQueryContext(ctx, TypePet, "FirstID")); err != nil { + if ids, err = pq.Limit(1).IDs(setContextOp(ctx, pq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -145,7 +142,7 @@ func (pq *PetQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Pet entity is found. // Returns a *NotFoundError when no Pet entities are found. func (pq *PetQuery) Only(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(2).All(newQueryContext(ctx, TypePet, "Only")) + nodes, err := pq.Limit(2).All(setContextOp(ctx, pq.ctx, "Only")) if err != nil { return nil, err } @@ -173,7 +170,7 @@ func (pq *PetQuery) OnlyX(ctx context.Context) *Pet { // Returns a *NotFoundError when no entities are found. func (pq *PetQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = pq.Limit(2).IDs(newQueryContext(ctx, TypePet, "OnlyID")); err != nil { + if ids, err = pq.Limit(2).IDs(setContextOp(ctx, pq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -198,7 +195,7 @@ func (pq *PetQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Pets. func (pq *PetQuery) All(ctx context.Context) ([]*Pet, error) { - ctx = newQueryContext(ctx, TypePet, "All") + ctx = setContextOp(ctx, pq.ctx, "All") if err := pq.prepareQuery(ctx); err != nil { return nil, err } @@ -218,7 +215,7 @@ func (pq *PetQuery) AllX(ctx context.Context) []*Pet { // IDs executes the query and returns a list of Pet IDs. func (pq *PetQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypePet, "IDs") + ctx = setContextOp(ctx, pq.ctx, "IDs") if err := pq.Select(pet.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -236,7 +233,7 @@ func (pq *PetQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (pq *PetQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypePet, "Count") + ctx = setContextOp(ctx, pq.ctx, "Count") if err := pq.prepareQuery(ctx); err != nil { return 0, err } @@ -254,7 +251,7 @@ func (pq *PetQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (pq *PetQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypePet, "Exist") + ctx = setContextOp(ctx, pq.ctx, "Exist") switch _, err := pq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -282,8 +279,7 @@ func (pq *PetQuery) Clone() *PetQuery { } return &PetQuery{ config: pq.config, - limit: pq.limit, - offset: pq.offset, + ctx: pq.ctx.Clone(), order: append([]OrderFunc{}, pq.order...), inters: append([]Interceptor{}, pq.inters...), predicates: append([]predicate.Pet{}, pq.predicates...), @@ -292,7 +288,6 @@ func (pq *PetQuery) Clone() *PetQuery { // clone intermediate query. gremlin: pq.gremlin.Clone(), path: pq.path, - unique: pq.unique, } } @@ -333,9 +328,9 @@ func (pq *PetQuery) WithOwner(opts ...func(*UserQuery)) *PetQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { - pq.fields = append([]string{field}, fields...) + pq.ctx.Fields = append([]string{field}, fields...) grbuild := &PetGroupBy{build: pq} - grbuild.flds = &pq.fields + grbuild.flds = &pq.ctx.Fields grbuild.label = pet.Label grbuild.scan = grbuild.Scan return grbuild @@ -354,10 +349,10 @@ func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { // Select(pet.FieldAge). // Scan(ctx, &v) func (pq *PetQuery) Select(fields ...string) *PetSelect { - pq.fields = append(pq.fields, fields...) + pq.ctx.Fields = append(pq.ctx.Fields, fields...) sbuild := &PetSelect{PetQuery: pq} sbuild.label = pet.Label - sbuild.flds, sbuild.scan = &pq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &pq.ctx.Fields, sbuild.Scan return sbuild } @@ -390,9 +385,9 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { func (pq *PetQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*Pet, error) { res := &gremlin.Response{} traversal := pq.gremlinQuery(ctx) - if len(pq.fields) > 0 { - fields := make([]any, len(pq.fields)) - for i, f := range pq.fields { + if len(pq.ctx.Fields) > 0 { + fields := make([]any, len(pq.ctx.Fields)) + for i, f := range pq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -434,7 +429,7 @@ func (pq *PetQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := pq.limit, pq.offset; { + switch limit, offset := pq.ctx.Limit, pq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -442,7 +437,7 @@ func (pq *PetQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := pq.unique; unique == nil || *unique { + if unique := pq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -462,7 +457,7 @@ func (pgb *PetGroupBy) Aggregate(fns ...AggregateFunc) *PetGroupBy { // Scan applies the selector query and scans the result into the given value. func (pgb *PetGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "GroupBy") + ctx = setContextOp(ctx, pgb.build.ctx, "GroupBy") if err := pgb.build.prepareQuery(ctx); err != nil { return err } @@ -517,7 +512,7 @@ func (ps *PetSelect) Aggregate(fns ...AggregateFunc) *PetSelect { // Scan applies the selector query and scans the result into the given value. func (ps *PetSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "Select") + ctx = setContextOp(ctx, ps.ctx, "Select") if err := ps.prepareQuery(ctx); err != nil { return err } @@ -529,15 +524,15 @@ func (ps *PetSelect) gremlinScan(ctx context.Context, root *PetQuery, v any) err res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(ps.fields) == 1 { - if ps.fields[0] != pet.FieldID { - traversal = traversal.Values(ps.fields...) + if fields := ps.ctx.Fields; len(fields) == 1 { + if fields[0] != pet.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(ps.fields)) - for i, f := range ps.fields { + fields := make([]any, len(ps.ctx.Fields)) + for i, f := range ps.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -546,7 +541,7 @@ func (ps *PetSelect) gremlinScan(ctx context.Context, root *PetQuery, v any) err if err := ps.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/spec_query.go b/entc/integration/gremlin/ent/spec_query.go index d5eb6f997..eff174263 100644 --- a/entc/integration/gremlin/ent/spec_query.go +++ b/entc/integration/gremlin/ent/spec_query.go @@ -22,11 +22,8 @@ import ( // SpecQuery is the builder for querying Spec entities. type SpecQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Spec withCard *CardQuery @@ -43,20 +40,20 @@ func (sq *SpecQuery) Where(ps ...predicate.Spec) *SpecQuery { // Limit the number of records to be returned by this query. func (sq *SpecQuery) Limit(limit int) *SpecQuery { - sq.limit = &limit + sq.ctx.Limit = &limit return sq } // Offset to start from. func (sq *SpecQuery) Offset(offset int) *SpecQuery { - sq.offset = &offset + sq.ctx.Offset = &offset return sq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (sq *SpecQuery) Unique(unique bool) *SpecQuery { - sq.unique = &unique + sq.ctx.Unique = &unique return sq } @@ -83,7 +80,7 @@ func (sq *SpecQuery) QueryCard() *CardQuery { // First returns the first Spec entity from the query. // Returns a *NotFoundError when no Spec was found. func (sq *SpecQuery) First(ctx context.Context) (*Spec, error) { - nodes, err := sq.Limit(1).All(newQueryContext(ctx, TypeSpec, "First")) + nodes, err := sq.Limit(1).All(setContextOp(ctx, sq.ctx, "First")) if err != nil { return nil, err } @@ -106,7 +103,7 @@ func (sq *SpecQuery) FirstX(ctx context.Context) *Spec { // Returns a *NotFoundError when no Spec ID was found. func (sq *SpecQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = sq.Limit(1).IDs(newQueryContext(ctx, TypeSpec, "FirstID")); err != nil { + if ids, err = sq.Limit(1).IDs(setContextOp(ctx, sq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -129,7 +126,7 @@ func (sq *SpecQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Spec entity is found. // Returns a *NotFoundError when no Spec entities are found. func (sq *SpecQuery) Only(ctx context.Context) (*Spec, error) { - nodes, err := sq.Limit(2).All(newQueryContext(ctx, TypeSpec, "Only")) + nodes, err := sq.Limit(2).All(setContextOp(ctx, sq.ctx, "Only")) if err != nil { return nil, err } @@ -157,7 +154,7 @@ func (sq *SpecQuery) OnlyX(ctx context.Context) *Spec { // Returns a *NotFoundError when no entities are found. func (sq *SpecQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = sq.Limit(2).IDs(newQueryContext(ctx, TypeSpec, "OnlyID")); err != nil { + if ids, err = sq.Limit(2).IDs(setContextOp(ctx, sq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -182,7 +179,7 @@ func (sq *SpecQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Specs. func (sq *SpecQuery) All(ctx context.Context) ([]*Spec, error) { - ctx = newQueryContext(ctx, TypeSpec, "All") + ctx = setContextOp(ctx, sq.ctx, "All") if err := sq.prepareQuery(ctx); err != nil { return nil, err } @@ -202,7 +199,7 @@ func (sq *SpecQuery) AllX(ctx context.Context) []*Spec { // IDs executes the query and returns a list of Spec IDs. func (sq *SpecQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeSpec, "IDs") + ctx = setContextOp(ctx, sq.ctx, "IDs") if err := sq.Select(spec.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -220,7 +217,7 @@ func (sq *SpecQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (sq *SpecQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeSpec, "Count") + ctx = setContextOp(ctx, sq.ctx, "Count") if err := sq.prepareQuery(ctx); err != nil { return 0, err } @@ -238,7 +235,7 @@ func (sq *SpecQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (sq *SpecQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeSpec, "Exist") + ctx = setContextOp(ctx, sq.ctx, "Exist") switch _, err := sq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -266,8 +263,7 @@ func (sq *SpecQuery) Clone() *SpecQuery { } return &SpecQuery{ config: sq.config, - limit: sq.limit, - offset: sq.offset, + ctx: sq.ctx.Clone(), order: append([]OrderFunc{}, sq.order...), inters: append([]Interceptor{}, sq.inters...), predicates: append([]predicate.Spec{}, sq.predicates...), @@ -275,7 +271,6 @@ func (sq *SpecQuery) Clone() *SpecQuery { // clone intermediate query. gremlin: sq.gremlin.Clone(), path: sq.path, - unique: sq.unique, } } @@ -293,9 +288,9 @@ func (sq *SpecQuery) WithCard(opts ...func(*CardQuery)) *SpecQuery { // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (sq *SpecQuery) GroupBy(field string, fields ...string) *SpecGroupBy { - sq.fields = append([]string{field}, fields...) + sq.ctx.Fields = append([]string{field}, fields...) grbuild := &SpecGroupBy{build: sq} - grbuild.flds = &sq.fields + grbuild.flds = &sq.ctx.Fields grbuild.label = spec.Label grbuild.scan = grbuild.Scan return grbuild @@ -304,10 +299,10 @@ func (sq *SpecQuery) GroupBy(field string, fields ...string) *SpecGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (sq *SpecQuery) Select(fields ...string) *SpecSelect { - sq.fields = append(sq.fields, fields...) + sq.ctx.Fields = append(sq.ctx.Fields, fields...) sbuild := &SpecSelect{SpecQuery: sq} sbuild.label = spec.Label - sbuild.flds, sbuild.scan = &sq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &sq.ctx.Fields, sbuild.Scan return sbuild } @@ -340,9 +335,9 @@ func (sq *SpecQuery) prepareQuery(ctx context.Context) error { func (sq *SpecQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*Spec, error) { res := &gremlin.Response{} traversal := sq.gremlinQuery(ctx) - if len(sq.fields) > 0 { - fields := make([]any, len(sq.fields)) - for i, f := range sq.fields { + if len(sq.ctx.Fields) > 0 { + fields := make([]any, len(sq.ctx.Fields)) + for i, f := range sq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -384,7 +379,7 @@ func (sq *SpecQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := sq.limit, sq.offset; { + switch limit, offset := sq.ctx.Limit, sq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -392,7 +387,7 @@ func (sq *SpecQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := sq.unique; unique == nil || *unique { + if unique := sq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -412,7 +407,7 @@ func (sgb *SpecGroupBy) Aggregate(fns ...AggregateFunc) *SpecGroupBy { // Scan applies the selector query and scans the result into the given value. func (sgb *SpecGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeSpec, "GroupBy") + ctx = setContextOp(ctx, sgb.build.ctx, "GroupBy") if err := sgb.build.prepareQuery(ctx); err != nil { return err } @@ -467,7 +462,7 @@ func (ss *SpecSelect) Aggregate(fns ...AggregateFunc) *SpecSelect { // Scan applies the selector query and scans the result into the given value. func (ss *SpecSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeSpec, "Select") + ctx = setContextOp(ctx, ss.ctx, "Select") if err := ss.prepareQuery(ctx); err != nil { return err } @@ -479,15 +474,15 @@ func (ss *SpecSelect) gremlinScan(ctx context.Context, root *SpecQuery, v any) e res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(ss.fields) == 1 { - if ss.fields[0] != spec.FieldID { - traversal = traversal.Values(ss.fields...) + if fields := ss.ctx.Fields; len(fields) == 1 { + if fields[0] != spec.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(ss.fields)) - for i, f := range ss.fields { + fields := make([]any, len(ss.ctx.Fields)) + for i, f := range ss.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -496,7 +491,7 @@ func (ss *SpecSelect) gremlinScan(ctx context.Context, root *SpecQuery, v any) e if err := ss.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/task_query.go b/entc/integration/gremlin/ent/task_query.go index 440dafa74..eb04963d8 100644 --- a/entc/integration/gremlin/ent/task_query.go +++ b/entc/integration/gremlin/ent/task_query.go @@ -23,11 +23,8 @@ import ( // TaskQuery is the builder for querying Task entities. type TaskQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Task // intermediate query (i.e. traversal path). @@ -43,20 +40,20 @@ func (tq *TaskQuery) Where(ps ...predicate.Task) *TaskQuery { // Limit the number of records to be returned by this query. func (tq *TaskQuery) Limit(limit int) *TaskQuery { - tq.limit = &limit + tq.ctx.Limit = &limit return tq } // Offset to start from. func (tq *TaskQuery) Offset(offset int) *TaskQuery { - tq.offset = &offset + tq.ctx.Offset = &offset return tq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (tq *TaskQuery) Unique(unique bool) *TaskQuery { - tq.unique = &unique + tq.ctx.Unique = &unique return tq } @@ -69,7 +66,7 @@ func (tq *TaskQuery) Order(o ...OrderFunc) *TaskQuery { // First returns the first Task entity from the query. // Returns a *NotFoundError when no Task was found. func (tq *TaskQuery) First(ctx context.Context) (*Task, error) { - nodes, err := tq.Limit(1).All(newQueryContext(ctx, TypeTask, "First")) + nodes, err := tq.Limit(1).All(setContextOp(ctx, tq.ctx, "First")) if err != nil { return nil, err } @@ -92,7 +89,7 @@ func (tq *TaskQuery) FirstX(ctx context.Context) *Task { // Returns a *NotFoundError when no Task ID was found. func (tq *TaskQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = tq.Limit(1).IDs(newQueryContext(ctx, TypeTask, "FirstID")); err != nil { + if ids, err = tq.Limit(1).IDs(setContextOp(ctx, tq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -115,7 +112,7 @@ func (tq *TaskQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one Task entity is found. // Returns a *NotFoundError when no Task entities are found. func (tq *TaskQuery) Only(ctx context.Context) (*Task, error) { - nodes, err := tq.Limit(2).All(newQueryContext(ctx, TypeTask, "Only")) + nodes, err := tq.Limit(2).All(setContextOp(ctx, tq.ctx, "Only")) if err != nil { return nil, err } @@ -143,7 +140,7 @@ func (tq *TaskQuery) OnlyX(ctx context.Context) *Task { // Returns a *NotFoundError when no entities are found. func (tq *TaskQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = tq.Limit(2).IDs(newQueryContext(ctx, TypeTask, "OnlyID")); err != nil { + if ids, err = tq.Limit(2).IDs(setContextOp(ctx, tq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -168,7 +165,7 @@ func (tq *TaskQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Tasks. func (tq *TaskQuery) All(ctx context.Context) ([]*Task, error) { - ctx = newQueryContext(ctx, TypeTask, "All") + ctx = setContextOp(ctx, tq.ctx, "All") if err := tq.prepareQuery(ctx); err != nil { return nil, err } @@ -188,7 +185,7 @@ func (tq *TaskQuery) AllX(ctx context.Context) []*Task { // IDs executes the query and returns a list of Task IDs. func (tq *TaskQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeTask, "IDs") + ctx = setContextOp(ctx, tq.ctx, "IDs") if err := tq.Select(enttask.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -206,7 +203,7 @@ func (tq *TaskQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (tq *TaskQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeTask, "Count") + ctx = setContextOp(ctx, tq.ctx, "Count") if err := tq.prepareQuery(ctx); err != nil { return 0, err } @@ -224,7 +221,7 @@ func (tq *TaskQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (tq *TaskQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeTask, "Exist") + ctx = setContextOp(ctx, tq.ctx, "Exist") switch _, err := tq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -252,15 +249,13 @@ func (tq *TaskQuery) Clone() *TaskQuery { } return &TaskQuery{ config: tq.config, - limit: tq.limit, - offset: tq.offset, + ctx: tq.ctx.Clone(), order: append([]OrderFunc{}, tq.order...), inters: append([]Interceptor{}, tq.inters...), predicates: append([]predicate.Task{}, tq.predicates...), // clone intermediate query. gremlin: tq.gremlin.Clone(), path: tq.path, - unique: tq.unique, } } @@ -279,9 +274,9 @@ func (tq *TaskQuery) Clone() *TaskQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (tq *TaskQuery) GroupBy(field string, fields ...string) *TaskGroupBy { - tq.fields = append([]string{field}, fields...) + tq.ctx.Fields = append([]string{field}, fields...) grbuild := &TaskGroupBy{build: tq} - grbuild.flds = &tq.fields + grbuild.flds = &tq.ctx.Fields grbuild.label = enttask.Label grbuild.scan = grbuild.Scan return grbuild @@ -300,10 +295,10 @@ func (tq *TaskQuery) GroupBy(field string, fields ...string) *TaskGroupBy { // Select(enttask.FieldPriority). // Scan(ctx, &v) func (tq *TaskQuery) Select(fields ...string) *TaskSelect { - tq.fields = append(tq.fields, fields...) + tq.ctx.Fields = append(tq.ctx.Fields, fields...) sbuild := &TaskSelect{TaskQuery: tq} sbuild.label = enttask.Label - sbuild.flds, sbuild.scan = &tq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &tq.ctx.Fields, sbuild.Scan return sbuild } @@ -336,9 +331,9 @@ func (tq *TaskQuery) prepareQuery(ctx context.Context) error { func (tq *TaskQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*Task, error) { res := &gremlin.Response{} traversal := tq.gremlinQuery(ctx) - if len(tq.fields) > 0 { - fields := make([]any, len(tq.fields)) - for i, f := range tq.fields { + if len(tq.ctx.Fields) > 0 { + fields := make([]any, len(tq.ctx.Fields)) + for i, f := range tq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -380,7 +375,7 @@ func (tq *TaskQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := tq.limit, tq.offset; { + switch limit, offset := tq.ctx.Limit, tq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -388,7 +383,7 @@ func (tq *TaskQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := tq.unique; unique == nil || *unique { + if unique := tq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -408,7 +403,7 @@ func (tgb *TaskGroupBy) Aggregate(fns ...AggregateFunc) *TaskGroupBy { // Scan applies the selector query and scans the result into the given value. func (tgb *TaskGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTask, "GroupBy") + ctx = setContextOp(ctx, tgb.build.ctx, "GroupBy") if err := tgb.build.prepareQuery(ctx); err != nil { return err } @@ -463,7 +458,7 @@ func (ts *TaskSelect) Aggregate(fns ...AggregateFunc) *TaskSelect { // Scan applies the selector query and scans the result into the given value. func (ts *TaskSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTask, "Select") + ctx = setContextOp(ctx, ts.ctx, "Select") if err := ts.prepareQuery(ctx); err != nil { return err } @@ -475,15 +470,15 @@ func (ts *TaskSelect) gremlinScan(ctx context.Context, root *TaskQuery, v any) e res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(ts.fields) == 1 { - if ts.fields[0] != enttask.FieldID { - traversal = traversal.Values(ts.fields...) + if fields := ts.ctx.Fields; len(fields) == 1 { + if fields[0] != enttask.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(ts.fields)) - for i, f := range ts.fields { + fields := make([]any, len(ts.ctx.Fields)) + for i, f := range ts.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -492,7 +487,7 @@ func (ts *TaskSelect) gremlinScan(ctx context.Context, root *TaskQuery, v any) e if err := ts.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/gremlin/ent/user_query.go b/entc/integration/gremlin/ent/user_query.go index e5183ddca..6a355184f 100644 --- a/entc/integration/gremlin/ent/user_query.go +++ b/entc/integration/gremlin/ent/user_query.go @@ -22,11 +22,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withCard *CardQuery @@ -53,20 +50,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -233,7 +230,7 @@ func (uq *UserQuery) QueryParent() *UserQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -256,7 +253,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -279,7 +276,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) string { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -307,7 +304,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id string, err error) { var ids []string - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -332,7 +329,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) string { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -352,7 +349,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]string, error) { var ids []string - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -370,7 +367,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []string { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -388,7 +385,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -416,8 +413,7 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), @@ -435,7 +431,6 @@ func (uq *UserQuery) Clone() *UserQuery { // clone intermediate query. gremlin: uq.gremlin.Clone(), path: uq.path, - unique: uq.unique, } } @@ -575,9 +570,9 @@ func (uq *UserQuery) WithParent(opts ...func(*UserQuery)) *UserQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -596,10 +591,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldOptionalInt). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -632,9 +627,9 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { func (uq *UserQuery) gremlinAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { res := &gremlin.Response{} traversal := uq.gremlinQuery(ctx) - if len(uq.fields) > 0 { - fields := make([]any, len(uq.fields)) - for i, f := range uq.fields { + if len(uq.ctx.Fields) > 0 { + fields := make([]any, len(uq.ctx.Fields)) + for i, f := range uq.ctx.Fields { fields[i] = f } traversal.ValueMap(fields...) @@ -676,7 +671,7 @@ func (uq *UserQuery) gremlinQuery(context.Context) *dsl.Traversal { p(v) } } - switch limit, offset := uq.limit, uq.offset; { + switch limit, offset := uq.ctx.Limit, uq.ctx.Offset; { case limit != nil && offset != nil: v.Range(*offset, *offset+*limit) case offset != nil: @@ -684,7 +679,7 @@ func (uq *UserQuery) gremlinQuery(context.Context) *dsl.Traversal { case limit != nil: v.Limit(*limit) } - if unique := uq.unique; unique == nil || *unique { + if unique := uq.ctx.Unique; unique == nil || *unique { v.Dedup() } return v @@ -704,7 +699,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -759,7 +754,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } @@ -771,15 +766,15 @@ func (us *UserSelect) gremlinScan(ctx context.Context, root *UserQuery, v any) e res = &gremlin.Response{} traversal = root.gremlinQuery(ctx) ) - if len(us.fields) == 1 { - if us.fields[0] != user.FieldID { - traversal = traversal.Values(us.fields...) + if fields := us.ctx.Fields; len(fields) == 1 { + if fields[0] != user.FieldID { + traversal = traversal.Values(fields...) } else { traversal = traversal.ID() } } else { - fields := make([]any, len(us.fields)) - for i, f := range us.fields { + fields := make([]any, len(us.ctx.Fields)) + for i, f := range us.ctx.Fields { fields[i] = f } traversal = traversal.ValueMap(fields...) @@ -788,7 +783,7 @@ func (us *UserSelect) gremlinScan(ctx context.Context, root *UserQuery, v any) e if err := us.driver.Exec(ctx, query, bindings, res); err != nil { return err } - if len(root.fields) == 1 { + if len(root.ctx.Fields) == 1 { return res.ReadVal(v) } vm, err := res.ReadValueMap() diff --git a/entc/integration/hooks/ent/card_query.go b/entc/integration/hooks/ent/card_query.go index b31b1a1b4..ae454a27e 100644 --- a/entc/integration/hooks/ent/card_query.go +++ b/entc/integration/hooks/ent/card_query.go @@ -22,11 +22,8 @@ import ( // CardQuery is the builder for querying Card entities. type CardQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Card withOwner *UserQuery @@ -44,20 +41,20 @@ func (cq *CardQuery) Where(ps ...predicate.Card) *CardQuery { // Limit the number of records to be returned by this query. func (cq *CardQuery) Limit(limit int) *CardQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *CardQuery) Offset(offset int) *CardQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *CardQuery) Unique(unique bool) *CardQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -92,7 +89,7 @@ func (cq *CardQuery) QueryOwner() *UserQuery { // First returns the first Card entity from the query. // Returns a *NotFoundError when no Card was found. func (cq *CardQuery) First(ctx context.Context) (*Card, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeCard, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -115,7 +112,7 @@ func (cq *CardQuery) FirstX(ctx context.Context) *Card { // Returns a *NotFoundError when no Card ID was found. func (cq *CardQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeCard, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -138,7 +135,7 @@ func (cq *CardQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Card entity is found. // Returns a *NotFoundError when no Card entities are found. func (cq *CardQuery) Only(ctx context.Context) (*Card, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeCard, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -166,7 +163,7 @@ func (cq *CardQuery) OnlyX(ctx context.Context) *Card { // Returns a *NotFoundError when no entities are found. func (cq *CardQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeCard, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -191,7 +188,7 @@ func (cq *CardQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Cards. func (cq *CardQuery) All(ctx context.Context) ([]*Card, error) { - ctx = newQueryContext(ctx, TypeCard, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -211,7 +208,7 @@ func (cq *CardQuery) AllX(ctx context.Context) []*Card { // IDs executes the query and returns a list of Card IDs. func (cq *CardQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeCard, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(card.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -229,7 +226,7 @@ func (cq *CardQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (cq *CardQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeCard, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -247,7 +244,7 @@ func (cq *CardQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *CardQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeCard, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -275,16 +272,14 @@ func (cq *CardQuery) Clone() *CardQuery { } return &CardQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Card{}, cq.predicates...), withOwner: cq.withOwner.Clone(), // clone intermediate query. - sql: cq.sql.Clone(), - path: cq.path, - unique: cq.unique, + sql: cq.sql.Clone(), + path: cq.path, } } @@ -314,9 +309,9 @@ func (cq *CardQuery) WithOwner(opts ...func(*UserQuery)) *CardQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (cq *CardQuery) GroupBy(field string, fields ...string) *CardGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &CardGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = card.Label grbuild.scan = grbuild.Scan return grbuild @@ -335,10 +330,10 @@ func (cq *CardQuery) GroupBy(field string, fields ...string) *CardGroupBy { // Select(card.FieldNumber). // Scan(ctx, &v) func (cq *CardQuery) Select(fields ...string) *CardSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &CardSelect{CardQuery: cq} sbuild.label = card.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -358,7 +353,7 @@ func (cq *CardQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range cq.fields { + for _, f := range cq.ctx.Fields { if !card.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -450,9 +445,9 @@ func (cq *CardQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*C func (cq *CardQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() - _spec.Node.Columns = cq.fields - if len(cq.fields) > 0 { - _spec.Unique = cq.unique != nil && *cq.unique + _spec.Node.Columns = cq.ctx.Fields + if len(cq.ctx.Fields) > 0 { + _spec.Unique = cq.ctx.Unique != nil && *cq.ctx.Unique } return sqlgraph.CountNodes(ctx, cq.driver, _spec) } @@ -470,10 +465,10 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { From: cq.sql, Unique: true, } - if unique := cq.unique; unique != nil { + if unique := cq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := cq.fields; len(fields) > 0 { + if fields := cq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, card.FieldID) for i := range fields { @@ -489,10 +484,10 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := cq.order; len(ps) > 0 { @@ -508,7 +503,7 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(card.Table) - columns := cq.fields + columns := cq.ctx.Fields if len(columns) == 0 { columns = card.Columns } @@ -517,7 +512,7 @@ func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = cq.sql selector.Select(selector.Columns(columns...)...) } - if cq.unique != nil && *cq.unique { + if cq.ctx.Unique != nil && *cq.ctx.Unique { selector.Distinct() } for _, p := range cq.predicates { @@ -526,12 +521,12 @@ func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range cq.order { p(selector) } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -551,7 +546,7 @@ func (cgb *CardGroupBy) Aggregate(fns ...AggregateFunc) *CardGroupBy { // Scan applies the selector query and scans the result into the given value. func (cgb *CardGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCard, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -599,7 +594,7 @@ func (cs *CardSelect) Aggregate(fns ...AggregateFunc) *CardSelect { // Scan applies the selector query and scans the result into the given value. func (cs *CardSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCard, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/hooks/ent/client.go b/entc/integration/hooks/ent/client.go index 7fa66d2cf..1095202aa 100644 --- a/entc/integration/hooks/ent/client.go +++ b/entc/integration/hooks/ent/client.go @@ -237,6 +237,7 @@ func (c *CardClient) DeleteOneID(id int) *CardDeleteOne { func (c *CardClient) Query() *CardQuery { return &CardQuery{ config: c.config, + ctx: &QueryContext{Type: TypeCard}, inters: c.Interceptors(), } } @@ -372,6 +373,7 @@ func (c *PetClient) DeleteOneID(id int) *PetDeleteOne { func (c *PetClient) Query() *PetQuery { return &PetQuery{ config: c.config, + ctx: &QueryContext{Type: TypePet}, inters: c.Interceptors(), } } @@ -507,6 +509,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/hooks/ent/ent.go b/entc/integration/hooks/ent/ent.go index 6b23f9fdf..fb1c56a9f 100644 --- a/entc/integration/hooks/ent/ent.go +++ b/entc/integration/hooks/ent/ent.go @@ -26,6 +26,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -511,10 +512,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/hooks/ent/pet_query.go b/entc/integration/hooks/ent/pet_query.go index cf3ad2852..28c1ce313 100644 --- a/entc/integration/hooks/ent/pet_query.go +++ b/entc/integration/hooks/ent/pet_query.go @@ -22,11 +22,8 @@ import ( // PetQuery is the builder for querying Pet entities. type PetQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Pet withOwner *UserQuery @@ -44,20 +41,20 @@ func (pq *PetQuery) Where(ps ...predicate.Pet) *PetQuery { // Limit the number of records to be returned by this query. func (pq *PetQuery) Limit(limit int) *PetQuery { - pq.limit = &limit + pq.ctx.Limit = &limit return pq } // Offset to start from. func (pq *PetQuery) Offset(offset int) *PetQuery { - pq.offset = &offset + pq.ctx.Offset = &offset return pq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (pq *PetQuery) Unique(unique bool) *PetQuery { - pq.unique = &unique + pq.ctx.Unique = &unique return pq } @@ -92,7 +89,7 @@ func (pq *PetQuery) QueryOwner() *UserQuery { // First returns the first Pet entity from the query. // Returns a *NotFoundError when no Pet was found. func (pq *PetQuery) First(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(1).All(newQueryContext(ctx, TypePet, "First")) + nodes, err := pq.Limit(1).All(setContextOp(ctx, pq.ctx, "First")) if err != nil { return nil, err } @@ -115,7 +112,7 @@ func (pq *PetQuery) FirstX(ctx context.Context) *Pet { // Returns a *NotFoundError when no Pet ID was found. func (pq *PetQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(1).IDs(newQueryContext(ctx, TypePet, "FirstID")); err != nil { + if ids, err = pq.Limit(1).IDs(setContextOp(ctx, pq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -138,7 +135,7 @@ func (pq *PetQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Pet entity is found. // Returns a *NotFoundError when no Pet entities are found. func (pq *PetQuery) Only(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(2).All(newQueryContext(ctx, TypePet, "Only")) + nodes, err := pq.Limit(2).All(setContextOp(ctx, pq.ctx, "Only")) if err != nil { return nil, err } @@ -166,7 +163,7 @@ func (pq *PetQuery) OnlyX(ctx context.Context) *Pet { // Returns a *NotFoundError when no entities are found. func (pq *PetQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(2).IDs(newQueryContext(ctx, TypePet, "OnlyID")); err != nil { + if ids, err = pq.Limit(2).IDs(setContextOp(ctx, pq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -191,7 +188,7 @@ func (pq *PetQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Pets. func (pq *PetQuery) All(ctx context.Context) ([]*Pet, error) { - ctx = newQueryContext(ctx, TypePet, "All") + ctx = setContextOp(ctx, pq.ctx, "All") if err := pq.prepareQuery(ctx); err != nil { return nil, err } @@ -211,7 +208,7 @@ func (pq *PetQuery) AllX(ctx context.Context) []*Pet { // IDs executes the query and returns a list of Pet IDs. func (pq *PetQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypePet, "IDs") + ctx = setContextOp(ctx, pq.ctx, "IDs") if err := pq.Select(pet.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -229,7 +226,7 @@ func (pq *PetQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (pq *PetQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypePet, "Count") + ctx = setContextOp(ctx, pq.ctx, "Count") if err := pq.prepareQuery(ctx); err != nil { return 0, err } @@ -247,7 +244,7 @@ func (pq *PetQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (pq *PetQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypePet, "Exist") + ctx = setContextOp(ctx, pq.ctx, "Exist") switch _, err := pq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -275,16 +272,14 @@ func (pq *PetQuery) Clone() *PetQuery { } return &PetQuery{ config: pq.config, - limit: pq.limit, - offset: pq.offset, + ctx: pq.ctx.Clone(), order: append([]OrderFunc{}, pq.order...), inters: append([]Interceptor{}, pq.inters...), predicates: append([]predicate.Pet{}, pq.predicates...), withOwner: pq.withOwner.Clone(), // clone intermediate query. - sql: pq.sql.Clone(), - path: pq.path, - unique: pq.unique, + sql: pq.sql.Clone(), + path: pq.path, } } @@ -314,9 +309,9 @@ func (pq *PetQuery) WithOwner(opts ...func(*UserQuery)) *PetQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { - pq.fields = append([]string{field}, fields...) + pq.ctx.Fields = append([]string{field}, fields...) grbuild := &PetGroupBy{build: pq} - grbuild.flds = &pq.fields + grbuild.flds = &pq.ctx.Fields grbuild.label = pet.Label grbuild.scan = grbuild.Scan return grbuild @@ -335,10 +330,10 @@ func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { // Select(pet.FieldDeleteTime). // Scan(ctx, &v) func (pq *PetQuery) Select(fields ...string) *PetSelect { - pq.fields = append(pq.fields, fields...) + pq.ctx.Fields = append(pq.ctx.Fields, fields...) sbuild := &PetSelect{PetQuery: pq} sbuild.label = pet.Label - sbuild.flds, sbuild.scan = &pq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &pq.ctx.Fields, sbuild.Scan return sbuild } @@ -358,7 +353,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range pq.fields { + for _, f := range pq.ctx.Fields { if !pet.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -450,9 +445,9 @@ func (pq *PetQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Pe func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() - _spec.Node.Columns = pq.fields - if len(pq.fields) > 0 { - _spec.Unique = pq.unique != nil && *pq.unique + _spec.Node.Columns = pq.ctx.Fields + if len(pq.ctx.Fields) > 0 { + _spec.Unique = pq.ctx.Unique != nil && *pq.ctx.Unique } return sqlgraph.CountNodes(ctx, pq.driver, _spec) } @@ -470,10 +465,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { From: pq.sql, Unique: true, } - if unique := pq.unique; unique != nil { + if unique := pq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := pq.fields; len(fields) > 0 { + if fields := pq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, pet.FieldID) for i := range fields { @@ -489,10 +484,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := pq.order; len(ps) > 0 { @@ -508,7 +503,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(pq.driver.Dialect()) t1 := builder.Table(pet.Table) - columns := pq.fields + columns := pq.ctx.Fields if len(columns) == 0 { columns = pet.Columns } @@ -517,7 +512,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = pq.sql selector.Select(selector.Columns(columns...)...) } - if pq.unique != nil && *pq.unique { + if pq.ctx.Unique != nil && *pq.ctx.Unique { selector.Distinct() } for _, p := range pq.predicates { @@ -526,12 +521,12 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range pq.order { p(selector) } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -551,7 +546,7 @@ func (pgb *PetGroupBy) Aggregate(fns ...AggregateFunc) *PetGroupBy { // Scan applies the selector query and scans the result into the given value. func (pgb *PetGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "GroupBy") + ctx = setContextOp(ctx, pgb.build.ctx, "GroupBy") if err := pgb.build.prepareQuery(ctx); err != nil { return err } @@ -599,7 +594,7 @@ func (ps *PetSelect) Aggregate(fns ...AggregateFunc) *PetSelect { // Scan applies the selector query and scans the result into the given value. func (ps *PetSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "Select") + ctx = setContextOp(ctx, ps.ctx, "Select") if err := ps.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/hooks/ent/user_query.go b/entc/integration/hooks/ent/user_query.go index d5fe22d39..40d4e7c8f 100644 --- a/entc/integration/hooks/ent/user_query.go +++ b/entc/integration/hooks/ent/user_query.go @@ -24,11 +24,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withCards *CardQuery @@ -49,20 +46,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -163,7 +160,7 @@ func (uq *UserQuery) QueryBestFriend() *UserQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -186,7 +183,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -209,7 +206,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -237,7 +234,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -262,7 +259,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -282,7 +279,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -300,7 +297,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -318,7 +315,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -346,8 +343,7 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), @@ -356,9 +352,8 @@ func (uq *UserQuery) Clone() *UserQuery { withFriends: uq.withFriends.Clone(), withBestFriend: uq.withBestFriend.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -421,9 +416,9 @@ func (uq *UserQuery) WithBestFriend(opts ...func(*UserQuery)) *UserQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -442,10 +437,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldVersion). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -465,7 +460,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -701,9 +696,9 @@ func (uq *UserQuery) loadBestFriend(ctx context.Context, query *UserQuery, nodes func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -721,10 +716,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -740,10 +735,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -759,7 +754,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -768,7 +763,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -777,12 +772,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -802,7 +797,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -850,7 +845,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/hooks/hooks_test.go b/entc/integration/hooks/hooks_test.go index 3650bcc23..7c848b96f 100644 --- a/entc/integration/hooks/hooks_test.go +++ b/entc/integration/hooks/hooks_test.go @@ -706,13 +706,17 @@ func TestTraverseUnique(t *testing.T) { // Disable unique traversal using interceptors. client.User.Intercept( - intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error { - q.Unique(false) + intercept.Func(func(ctx context.Context, q intercept.Query) error { + // Skip setting the Unique if the modifier was set explicitly. + if entgo.QueryFromContext(ctx).Unique == nil { + q.Unique(false) + } return nil }), ) // The JOIN with pets will return the same owner twice, one for each pet. require.Equal(t, 2, client.Pet.Query().QueryOwner().CountX(ctx)) + require.Equal(t, 1, client.Pet.Query().QueryOwner().Unique(true).CountX(ctx)) } // The following example demonstrates how to write interceptors that diff --git a/entc/integration/idtype/ent/client.go b/entc/integration/idtype/ent/client.go index 22255e618..c1955c4ad 100644 --- a/entc/integration/idtype/ent/client.go +++ b/entc/integration/idtype/ent/client.go @@ -217,6 +217,7 @@ func (c *UserClient) DeleteOneID(id uint64) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/idtype/ent/ent.go b/entc/integration/idtype/ent/ent.go index d15f22927..de4b0c041 100644 --- a/entc/integration/idtype/ent/ent.go +++ b/entc/integration/idtype/ent/ent.go @@ -24,6 +24,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -507,10 +508,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/idtype/ent/user_query.go b/entc/integration/idtype/ent/user_query.go index 1a1ab5edb..dc958893f 100644 --- a/entc/integration/idtype/ent/user_query.go +++ b/entc/integration/idtype/ent/user_query.go @@ -22,11 +22,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withSpouse *UserQuery @@ -46,20 +43,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -138,7 +135,7 @@ func (uq *UserQuery) QueryFollowing() *UserQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -161,7 +158,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id uint64, err error) { var ids []uint64 - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -184,7 +181,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) uint64 { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -212,7 +209,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id uint64, err error) { var ids []uint64 - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -237,7 +234,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) uint64 { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -257,7 +254,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]uint64, error) { var ids []uint64 - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -275,7 +272,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []uint64 { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -293,7 +290,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -321,8 +318,7 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), @@ -330,9 +326,8 @@ func (uq *UserQuery) Clone() *UserQuery { withFollowers: uq.withFollowers.Clone(), withFollowing: uq.withFollowing.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -384,9 +379,9 @@ func (uq *UserQuery) WithFollowing(opts ...func(*UserQuery)) *UserQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -405,10 +400,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldName). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -428,7 +423,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -652,9 +647,9 @@ func (uq *UserQuery) loadFollowing(ctx context.Context, query *UserQuery, nodes func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -672,10 +667,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -691,10 +686,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -710,7 +705,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -719,7 +714,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -728,12 +723,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -753,7 +748,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -801,7 +796,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/json/ent/client.go b/entc/integration/json/ent/client.go index 4625dc88a..a37376aee 100644 --- a/entc/integration/json/ent/client.go +++ b/entc/integration/json/ent/client.go @@ -216,6 +216,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/json/ent/ent.go b/entc/integration/json/ent/ent.go index 8b933b46a..d2eddb164 100644 --- a/entc/integration/json/ent/ent.go +++ b/entc/integration/json/ent/ent.go @@ -24,6 +24,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -507,10 +508,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/json/ent/user_query.go b/entc/integration/json/ent/user_query.go index f019a24c9..8138bb42f 100644 --- a/entc/integration/json/ent/user_query.go +++ b/entc/integration/json/ent/user_query.go @@ -21,11 +21,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User modifiers []func(*sql.Selector) @@ -42,20 +39,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -68,7 +65,7 @@ func (uq *UserQuery) Order(o ...OrderFunc) *UserQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -91,7 +88,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -114,7 +111,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -142,7 +139,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -167,7 +164,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -187,7 +184,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -205,7 +202,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -223,7 +220,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -251,15 +248,13 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -278,9 +273,9 @@ func (uq *UserQuery) Clone() *UserQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -299,10 +294,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldT). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -322,7 +317,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -370,9 +365,9 @@ func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { if len(uq.modifiers) > 0 { _spec.Modifiers = uq.modifiers } - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -390,10 +385,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -409,10 +404,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -428,7 +423,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -437,7 +432,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, m := range uq.modifiers { @@ -449,12 +444,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -480,7 +475,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -528,7 +523,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv1/car_query.go b/entc/integration/migrate/entv1/car_query.go index 3842c9a24..dc1ef0953 100644 --- a/entc/integration/migrate/entv1/car_query.go +++ b/entc/integration/migrate/entv1/car_query.go @@ -22,11 +22,8 @@ import ( // CarQuery is the builder for querying Car entities. type CarQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Car withOwner *UserQuery @@ -44,20 +41,20 @@ func (cq *CarQuery) Where(ps ...predicate.Car) *CarQuery { // Limit the number of records to be returned by this query. func (cq *CarQuery) Limit(limit int) *CarQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *CarQuery) Offset(offset int) *CarQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *CarQuery) Unique(unique bool) *CarQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -92,7 +89,7 @@ func (cq *CarQuery) QueryOwner() *UserQuery { // First returns the first Car entity from the query. // Returns a *NotFoundError when no Car was found. func (cq *CarQuery) First(ctx context.Context) (*Car, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeCar, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -115,7 +112,7 @@ func (cq *CarQuery) FirstX(ctx context.Context) *Car { // Returns a *NotFoundError when no Car ID was found. func (cq *CarQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeCar, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -138,7 +135,7 @@ func (cq *CarQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Car entity is found. // Returns a *NotFoundError when no Car entities are found. func (cq *CarQuery) Only(ctx context.Context) (*Car, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeCar, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -166,7 +163,7 @@ func (cq *CarQuery) OnlyX(ctx context.Context) *Car { // Returns a *NotFoundError when no entities are found. func (cq *CarQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeCar, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -191,7 +188,7 @@ func (cq *CarQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Cars. func (cq *CarQuery) All(ctx context.Context) ([]*Car, error) { - ctx = newQueryContext(ctx, TypeCar, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -211,7 +208,7 @@ func (cq *CarQuery) AllX(ctx context.Context) []*Car { // IDs executes the query and returns a list of Car IDs. func (cq *CarQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeCar, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(car.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -229,7 +226,7 @@ func (cq *CarQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (cq *CarQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeCar, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -247,7 +244,7 @@ func (cq *CarQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *CarQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeCar, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -275,16 +272,14 @@ func (cq *CarQuery) Clone() *CarQuery { } return &CarQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Car{}, cq.predicates...), withOwner: cq.withOwner.Clone(), // clone intermediate query. - sql: cq.sql.Clone(), - path: cq.path, - unique: cq.unique, + sql: cq.sql.Clone(), + path: cq.path, } } @@ -302,9 +297,9 @@ func (cq *CarQuery) WithOwner(opts ...func(*UserQuery)) *CarQuery { // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (cq *CarQuery) GroupBy(field string, fields ...string) *CarGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &CarGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = car.Label grbuild.scan = grbuild.Scan return grbuild @@ -313,10 +308,10 @@ func (cq *CarQuery) GroupBy(field string, fields ...string) *CarGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (cq *CarQuery) Select(fields ...string) *CarSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &CarSelect{CarQuery: cq} sbuild.label = car.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -336,7 +331,7 @@ func (cq *CarQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range cq.fields { + for _, f := range cq.ctx.Fields { if !car.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv1: invalid field %q for query", f)} } @@ -428,9 +423,9 @@ func (cq *CarQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Ca func (cq *CarQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() - _spec.Node.Columns = cq.fields - if len(cq.fields) > 0 { - _spec.Unique = cq.unique != nil && *cq.unique + _spec.Node.Columns = cq.ctx.Fields + if len(cq.ctx.Fields) > 0 { + _spec.Unique = cq.ctx.Unique != nil && *cq.ctx.Unique } return sqlgraph.CountNodes(ctx, cq.driver, _spec) } @@ -448,10 +443,10 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { From: cq.sql, Unique: true, } - if unique := cq.unique; unique != nil { + if unique := cq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := cq.fields; len(fields) > 0 { + if fields := cq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, car.FieldID) for i := range fields { @@ -467,10 +462,10 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := cq.order; len(ps) > 0 { @@ -486,7 +481,7 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(car.Table) - columns := cq.fields + columns := cq.ctx.Fields if len(columns) == 0 { columns = car.Columns } @@ -495,7 +490,7 @@ func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = cq.sql selector.Select(selector.Columns(columns...)...) } - if cq.unique != nil && *cq.unique { + if cq.ctx.Unique != nil && *cq.ctx.Unique { selector.Distinct() } for _, p := range cq.predicates { @@ -504,12 +499,12 @@ func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range cq.order { p(selector) } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -529,7 +524,7 @@ func (cgb *CarGroupBy) Aggregate(fns ...AggregateFunc) *CarGroupBy { // Scan applies the selector query and scans the result into the given value. func (cgb *CarGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCar, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -577,7 +572,7 @@ func (cs *CarSelect) Aggregate(fns ...AggregateFunc) *CarSelect { // Scan applies the selector query and scans the result into the given value. func (cs *CarSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCar, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv1/client.go b/entc/integration/migrate/entv1/client.go index 4f69d7d6a..22ee2a9a4 100644 --- a/entc/integration/migrate/entv1/client.go +++ b/entc/integration/migrate/entv1/client.go @@ -247,6 +247,7 @@ func (c *CarClient) DeleteOneID(id int) *CarDeleteOne { func (c *CarClient) Query() *CarQuery { return &CarQuery{ config: c.config, + ctx: &QueryContext{Type: TypeCar}, inters: c.Interceptors(), } } @@ -380,6 +381,7 @@ func (c *ConversionClient) DeleteOneID(id int) *ConversionDeleteOne { func (c *ConversionClient) Query() *ConversionQuery { return &ConversionQuery{ config: c.config, + ctx: &QueryContext{Type: TypeConversion}, inters: c.Interceptors(), } } @@ -497,6 +499,7 @@ func (c *CustomTypeClient) DeleteOneID(id int) *CustomTypeDeleteOne { func (c *CustomTypeClient) Query() *CustomTypeQuery { return &CustomTypeQuery{ config: c.config, + ctx: &QueryContext{Type: TypeCustomType}, inters: c.Interceptors(), } } @@ -614,6 +617,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/migrate/entv1/conversion_query.go b/entc/integration/migrate/entv1/conversion_query.go index c5c0d651f..743f3b67b 100644 --- a/entc/integration/migrate/entv1/conversion_query.go +++ b/entc/integration/migrate/entv1/conversion_query.go @@ -21,11 +21,8 @@ import ( // ConversionQuery is the builder for querying Conversion entities. type ConversionQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Conversion // intermediate query (i.e. traversal path). @@ -41,20 +38,20 @@ func (cq *ConversionQuery) Where(ps ...predicate.Conversion) *ConversionQuery { // Limit the number of records to be returned by this query. func (cq *ConversionQuery) Limit(limit int) *ConversionQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *ConversionQuery) Offset(offset int) *ConversionQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *ConversionQuery) Unique(unique bool) *ConversionQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -67,7 +64,7 @@ func (cq *ConversionQuery) Order(o ...OrderFunc) *ConversionQuery { // First returns the first Conversion entity from the query. // Returns a *NotFoundError when no Conversion was found. func (cq *ConversionQuery) First(ctx context.Context) (*Conversion, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeConversion, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -90,7 +87,7 @@ func (cq *ConversionQuery) FirstX(ctx context.Context) *Conversion { // Returns a *NotFoundError when no Conversion ID was found. func (cq *ConversionQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeConversion, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -113,7 +110,7 @@ func (cq *ConversionQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Conversion entity is found. // Returns a *NotFoundError when no Conversion entities are found. func (cq *ConversionQuery) Only(ctx context.Context) (*Conversion, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeConversion, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (cq *ConversionQuery) OnlyX(ctx context.Context) *Conversion { // Returns a *NotFoundError when no entities are found. func (cq *ConversionQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeConversion, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -166,7 +163,7 @@ func (cq *ConversionQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Conversions. func (cq *ConversionQuery) All(ctx context.Context) ([]*Conversion, error) { - ctx = newQueryContext(ctx, TypeConversion, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (cq *ConversionQuery) AllX(ctx context.Context) []*Conversion { // IDs executes the query and returns a list of Conversion IDs. func (cq *ConversionQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeConversion, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(conversion.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -204,7 +201,7 @@ func (cq *ConversionQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (cq *ConversionQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeConversion, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -222,7 +219,7 @@ func (cq *ConversionQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *ConversionQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeConversion, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -250,15 +247,13 @@ func (cq *ConversionQuery) Clone() *ConversionQuery { } return &ConversionQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Conversion{}, cq.predicates...), // clone intermediate query. - sql: cq.sql.Clone(), - path: cq.path, - unique: cq.unique, + sql: cq.sql.Clone(), + path: cq.path, } } @@ -277,9 +272,9 @@ func (cq *ConversionQuery) Clone() *ConversionQuery { // Aggregate(entv1.Count()). // Scan(ctx, &v) func (cq *ConversionQuery) GroupBy(field string, fields ...string) *ConversionGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &ConversionGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = conversion.Label grbuild.scan = grbuild.Scan return grbuild @@ -298,10 +293,10 @@ func (cq *ConversionQuery) GroupBy(field string, fields ...string) *ConversionGr // Select(conversion.FieldName). // Scan(ctx, &v) func (cq *ConversionQuery) Select(fields ...string) *ConversionSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &ConversionSelect{ConversionQuery: cq} sbuild.label = conversion.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -321,7 +316,7 @@ func (cq *ConversionQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range cq.fields { + for _, f := range cq.ctx.Fields { if !conversion.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv1: invalid field %q for query", f)} } @@ -363,9 +358,9 @@ func (cq *ConversionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*C func (cq *ConversionQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() - _spec.Node.Columns = cq.fields - if len(cq.fields) > 0 { - _spec.Unique = cq.unique != nil && *cq.unique + _spec.Node.Columns = cq.ctx.Fields + if len(cq.ctx.Fields) > 0 { + _spec.Unique = cq.ctx.Unique != nil && *cq.ctx.Unique } return sqlgraph.CountNodes(ctx, cq.driver, _spec) } @@ -383,10 +378,10 @@ func (cq *ConversionQuery) querySpec() *sqlgraph.QuerySpec { From: cq.sql, Unique: true, } - if unique := cq.unique; unique != nil { + if unique := cq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := cq.fields; len(fields) > 0 { + if fields := cq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, conversion.FieldID) for i := range fields { @@ -402,10 +397,10 @@ func (cq *ConversionQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := cq.order; len(ps) > 0 { @@ -421,7 +416,7 @@ func (cq *ConversionQuery) querySpec() *sqlgraph.QuerySpec { func (cq *ConversionQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(conversion.Table) - columns := cq.fields + columns := cq.ctx.Fields if len(columns) == 0 { columns = conversion.Columns } @@ -430,7 +425,7 @@ func (cq *ConversionQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = cq.sql selector.Select(selector.Columns(columns...)...) } - if cq.unique != nil && *cq.unique { + if cq.ctx.Unique != nil && *cq.ctx.Unique { selector.Distinct() } for _, p := range cq.predicates { @@ -439,12 +434,12 @@ func (cq *ConversionQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range cq.order { p(selector) } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -464,7 +459,7 @@ func (cgb *ConversionGroupBy) Aggregate(fns ...AggregateFunc) *ConversionGroupBy // Scan applies the selector query and scans the result into the given value. func (cgb *ConversionGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeConversion, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -512,7 +507,7 @@ func (cs *ConversionSelect) Aggregate(fns ...AggregateFunc) *ConversionSelect { // Scan applies the selector query and scans the result into the given value. func (cs *ConversionSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeConversion, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv1/customtype_query.go b/entc/integration/migrate/entv1/customtype_query.go index 89975b8d0..7d394e298 100644 --- a/entc/integration/migrate/entv1/customtype_query.go +++ b/entc/integration/migrate/entv1/customtype_query.go @@ -21,11 +21,8 @@ import ( // CustomTypeQuery is the builder for querying CustomType entities. type CustomTypeQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.CustomType // intermediate query (i.e. traversal path). @@ -41,20 +38,20 @@ func (ctq *CustomTypeQuery) Where(ps ...predicate.CustomType) *CustomTypeQuery { // Limit the number of records to be returned by this query. func (ctq *CustomTypeQuery) Limit(limit int) *CustomTypeQuery { - ctq.limit = &limit + ctq.ctx.Limit = &limit return ctq } // Offset to start from. func (ctq *CustomTypeQuery) Offset(offset int) *CustomTypeQuery { - ctq.offset = &offset + ctq.ctx.Offset = &offset return ctq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (ctq *CustomTypeQuery) Unique(unique bool) *CustomTypeQuery { - ctq.unique = &unique + ctq.ctx.Unique = &unique return ctq } @@ -67,7 +64,7 @@ func (ctq *CustomTypeQuery) Order(o ...OrderFunc) *CustomTypeQuery { // First returns the first CustomType entity from the query. // Returns a *NotFoundError when no CustomType was found. func (ctq *CustomTypeQuery) First(ctx context.Context) (*CustomType, error) { - nodes, err := ctq.Limit(1).All(newQueryContext(ctx, TypeCustomType, "First")) + nodes, err := ctq.Limit(1).All(setContextOp(ctx, ctq.ctx, "First")) if err != nil { return nil, err } @@ -90,7 +87,7 @@ func (ctq *CustomTypeQuery) FirstX(ctx context.Context) *CustomType { // Returns a *NotFoundError when no CustomType ID was found. func (ctq *CustomTypeQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ctq.Limit(1).IDs(newQueryContext(ctx, TypeCustomType, "FirstID")); err != nil { + if ids, err = ctq.Limit(1).IDs(setContextOp(ctx, ctq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -113,7 +110,7 @@ func (ctq *CustomTypeQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one CustomType entity is found. // Returns a *NotFoundError when no CustomType entities are found. func (ctq *CustomTypeQuery) Only(ctx context.Context) (*CustomType, error) { - nodes, err := ctq.Limit(2).All(newQueryContext(ctx, TypeCustomType, "Only")) + nodes, err := ctq.Limit(2).All(setContextOp(ctx, ctq.ctx, "Only")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (ctq *CustomTypeQuery) OnlyX(ctx context.Context) *CustomType { // Returns a *NotFoundError when no entities are found. func (ctq *CustomTypeQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ctq.Limit(2).IDs(newQueryContext(ctx, TypeCustomType, "OnlyID")); err != nil { + if ids, err = ctq.Limit(2).IDs(setContextOp(ctx, ctq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -166,7 +163,7 @@ func (ctq *CustomTypeQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of CustomTypes. func (ctq *CustomTypeQuery) All(ctx context.Context) ([]*CustomType, error) { - ctx = newQueryContext(ctx, TypeCustomType, "All") + ctx = setContextOp(ctx, ctq.ctx, "All") if err := ctq.prepareQuery(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (ctq *CustomTypeQuery) AllX(ctx context.Context) []*CustomType { // IDs executes the query and returns a list of CustomType IDs. func (ctq *CustomTypeQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeCustomType, "IDs") + ctx = setContextOp(ctx, ctq.ctx, "IDs") if err := ctq.Select(customtype.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -204,7 +201,7 @@ func (ctq *CustomTypeQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (ctq *CustomTypeQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeCustomType, "Count") + ctx = setContextOp(ctx, ctq.ctx, "Count") if err := ctq.prepareQuery(ctx); err != nil { return 0, err } @@ -222,7 +219,7 @@ func (ctq *CustomTypeQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (ctq *CustomTypeQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeCustomType, "Exist") + ctx = setContextOp(ctx, ctq.ctx, "Exist") switch _, err := ctq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -250,15 +247,13 @@ func (ctq *CustomTypeQuery) Clone() *CustomTypeQuery { } return &CustomTypeQuery{ config: ctq.config, - limit: ctq.limit, - offset: ctq.offset, + ctx: ctq.ctx.Clone(), order: append([]OrderFunc{}, ctq.order...), inters: append([]Interceptor{}, ctq.inters...), predicates: append([]predicate.CustomType{}, ctq.predicates...), // clone intermediate query. - sql: ctq.sql.Clone(), - path: ctq.path, - unique: ctq.unique, + sql: ctq.sql.Clone(), + path: ctq.path, } } @@ -277,9 +272,9 @@ func (ctq *CustomTypeQuery) Clone() *CustomTypeQuery { // Aggregate(entv1.Count()). // Scan(ctx, &v) func (ctq *CustomTypeQuery) GroupBy(field string, fields ...string) *CustomTypeGroupBy { - ctq.fields = append([]string{field}, fields...) + ctq.ctx.Fields = append([]string{field}, fields...) grbuild := &CustomTypeGroupBy{build: ctq} - grbuild.flds = &ctq.fields + grbuild.flds = &ctq.ctx.Fields grbuild.label = customtype.Label grbuild.scan = grbuild.Scan return grbuild @@ -298,10 +293,10 @@ func (ctq *CustomTypeQuery) GroupBy(field string, fields ...string) *CustomTypeG // Select(customtype.FieldCustom). // Scan(ctx, &v) func (ctq *CustomTypeQuery) Select(fields ...string) *CustomTypeSelect { - ctq.fields = append(ctq.fields, fields...) + ctq.ctx.Fields = append(ctq.ctx.Fields, fields...) sbuild := &CustomTypeSelect{CustomTypeQuery: ctq} sbuild.label = customtype.Label - sbuild.flds, sbuild.scan = &ctq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &ctq.ctx.Fields, sbuild.Scan return sbuild } @@ -321,7 +316,7 @@ func (ctq *CustomTypeQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range ctq.fields { + for _, f := range ctq.ctx.Fields { if !customtype.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv1: invalid field %q for query", f)} } @@ -363,9 +358,9 @@ func (ctq *CustomTypeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]* func (ctq *CustomTypeQuery) sqlCount(ctx context.Context) (int, error) { _spec := ctq.querySpec() - _spec.Node.Columns = ctq.fields - if len(ctq.fields) > 0 { - _spec.Unique = ctq.unique != nil && *ctq.unique + _spec.Node.Columns = ctq.ctx.Fields + if len(ctq.ctx.Fields) > 0 { + _spec.Unique = ctq.ctx.Unique != nil && *ctq.ctx.Unique } return sqlgraph.CountNodes(ctx, ctq.driver, _spec) } @@ -383,10 +378,10 @@ func (ctq *CustomTypeQuery) querySpec() *sqlgraph.QuerySpec { From: ctq.sql, Unique: true, } - if unique := ctq.unique; unique != nil { + if unique := ctq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := ctq.fields; len(fields) > 0 { + if fields := ctq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, customtype.FieldID) for i := range fields { @@ -402,10 +397,10 @@ func (ctq *CustomTypeQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := ctq.limit; limit != nil { + if limit := ctq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := ctq.offset; offset != nil { + if offset := ctq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := ctq.order; len(ps) > 0 { @@ -421,7 +416,7 @@ func (ctq *CustomTypeQuery) querySpec() *sqlgraph.QuerySpec { func (ctq *CustomTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(ctq.driver.Dialect()) t1 := builder.Table(customtype.Table) - columns := ctq.fields + columns := ctq.ctx.Fields if len(columns) == 0 { columns = customtype.Columns } @@ -430,7 +425,7 @@ func (ctq *CustomTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = ctq.sql selector.Select(selector.Columns(columns...)...) } - if ctq.unique != nil && *ctq.unique { + if ctq.ctx.Unique != nil && *ctq.ctx.Unique { selector.Distinct() } for _, p := range ctq.predicates { @@ -439,12 +434,12 @@ func (ctq *CustomTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range ctq.order { p(selector) } - if offset := ctq.offset; offset != nil { + if offset := ctq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := ctq.limit; limit != nil { + if limit := ctq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -464,7 +459,7 @@ func (ctgb *CustomTypeGroupBy) Aggregate(fns ...AggregateFunc) *CustomTypeGroupB // Scan applies the selector query and scans the result into the given value. func (ctgb *CustomTypeGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCustomType, "GroupBy") + ctx = setContextOp(ctx, ctgb.build.ctx, "GroupBy") if err := ctgb.build.prepareQuery(ctx); err != nil { return err } @@ -512,7 +507,7 @@ func (cts *CustomTypeSelect) Aggregate(fns ...AggregateFunc) *CustomTypeSelect { // Scan applies the selector query and scans the result into the given value. func (cts *CustomTypeSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCustomType, "Select") + ctx = setContextOp(ctx, cts.ctx, "Select") if err := cts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv1/ent.go b/entc/integration/migrate/entv1/ent.go index 5c7909f26..0fbdbfd15 100644 --- a/entc/integration/migrate/entv1/ent.go +++ b/entc/integration/migrate/entv1/ent.go @@ -27,6 +27,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -513,10 +514,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/migrate/entv1/user_query.go b/entc/integration/migrate/entv1/user_query.go index bf778c9d9..58482f90f 100644 --- a/entc/integration/migrate/entv1/user_query.go +++ b/entc/integration/migrate/entv1/user_query.go @@ -23,11 +23,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withParent *UserQuery @@ -48,20 +45,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -162,7 +159,7 @@ func (uq *UserQuery) QueryCar() *CarQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -185,7 +182,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -208,7 +205,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -236,7 +233,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -261,7 +258,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -281,7 +278,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -299,7 +296,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -317,7 +314,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -345,8 +342,7 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), @@ -355,9 +351,8 @@ func (uq *UserQuery) Clone() *UserQuery { withSpouse: uq.withSpouse.Clone(), withCar: uq.withCar.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -420,9 +415,9 @@ func (uq *UserQuery) WithCar(opts ...func(*CarQuery)) *UserQuery { // Aggregate(entv1.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -441,10 +436,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldAge). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -464,7 +459,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv1: invalid field %q for query", f)} } @@ -669,9 +664,9 @@ func (uq *UserQuery) loadCar(ctx context.Context, query *CarQuery, nodes []*User func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -689,10 +684,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -708,10 +703,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -727,7 +722,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -736,7 +731,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -745,12 +740,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -770,7 +765,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -818,7 +813,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv2/blog_query.go b/entc/integration/migrate/entv2/blog_query.go index 92a5e4c99..09eae6a2c 100644 --- a/entc/integration/migrate/entv2/blog_query.go +++ b/entc/integration/migrate/entv2/blog_query.go @@ -23,11 +23,8 @@ import ( // BlogQuery is the builder for querying Blog entities. type BlogQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Blog withAdmins *UserQuery @@ -44,20 +41,20 @@ func (bq *BlogQuery) Where(ps ...predicate.Blog) *BlogQuery { // Limit the number of records to be returned by this query. func (bq *BlogQuery) Limit(limit int) *BlogQuery { - bq.limit = &limit + bq.ctx.Limit = &limit return bq } // Offset to start from. func (bq *BlogQuery) Offset(offset int) *BlogQuery { - bq.offset = &offset + bq.ctx.Offset = &offset return bq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (bq *BlogQuery) Unique(unique bool) *BlogQuery { - bq.unique = &unique + bq.ctx.Unique = &unique return bq } @@ -92,7 +89,7 @@ func (bq *BlogQuery) QueryAdmins() *UserQuery { // First returns the first Blog entity from the query. // Returns a *NotFoundError when no Blog was found. func (bq *BlogQuery) First(ctx context.Context) (*Blog, error) { - nodes, err := bq.Limit(1).All(newQueryContext(ctx, TypeBlog, "First")) + nodes, err := bq.Limit(1).All(setContextOp(ctx, bq.ctx, "First")) if err != nil { return nil, err } @@ -115,7 +112,7 @@ func (bq *BlogQuery) FirstX(ctx context.Context) *Blog { // Returns a *NotFoundError when no Blog ID was found. func (bq *BlogQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = bq.Limit(1).IDs(newQueryContext(ctx, TypeBlog, "FirstID")); err != nil { + if ids, err = bq.Limit(1).IDs(setContextOp(ctx, bq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -138,7 +135,7 @@ func (bq *BlogQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Blog entity is found. // Returns a *NotFoundError when no Blog entities are found. func (bq *BlogQuery) Only(ctx context.Context) (*Blog, error) { - nodes, err := bq.Limit(2).All(newQueryContext(ctx, TypeBlog, "Only")) + nodes, err := bq.Limit(2).All(setContextOp(ctx, bq.ctx, "Only")) if err != nil { return nil, err } @@ -166,7 +163,7 @@ func (bq *BlogQuery) OnlyX(ctx context.Context) *Blog { // Returns a *NotFoundError when no entities are found. func (bq *BlogQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = bq.Limit(2).IDs(newQueryContext(ctx, TypeBlog, "OnlyID")); err != nil { + if ids, err = bq.Limit(2).IDs(setContextOp(ctx, bq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -191,7 +188,7 @@ func (bq *BlogQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Blogs. func (bq *BlogQuery) All(ctx context.Context) ([]*Blog, error) { - ctx = newQueryContext(ctx, TypeBlog, "All") + ctx = setContextOp(ctx, bq.ctx, "All") if err := bq.prepareQuery(ctx); err != nil { return nil, err } @@ -211,7 +208,7 @@ func (bq *BlogQuery) AllX(ctx context.Context) []*Blog { // IDs executes the query and returns a list of Blog IDs. func (bq *BlogQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeBlog, "IDs") + ctx = setContextOp(ctx, bq.ctx, "IDs") if err := bq.Select(blog.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -229,7 +226,7 @@ func (bq *BlogQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (bq *BlogQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeBlog, "Count") + ctx = setContextOp(ctx, bq.ctx, "Count") if err := bq.prepareQuery(ctx); err != nil { return 0, err } @@ -247,7 +244,7 @@ func (bq *BlogQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (bq *BlogQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeBlog, "Exist") + ctx = setContextOp(ctx, bq.ctx, "Exist") switch _, err := bq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -275,16 +272,14 @@ func (bq *BlogQuery) Clone() *BlogQuery { } return &BlogQuery{ config: bq.config, - limit: bq.limit, - offset: bq.offset, + ctx: bq.ctx.Clone(), order: append([]OrderFunc{}, bq.order...), inters: append([]Interceptor{}, bq.inters...), predicates: append([]predicate.Blog{}, bq.predicates...), withAdmins: bq.withAdmins.Clone(), // clone intermediate query. - sql: bq.sql.Clone(), - path: bq.path, - unique: bq.unique, + sql: bq.sql.Clone(), + path: bq.path, } } @@ -314,9 +309,9 @@ func (bq *BlogQuery) WithAdmins(opts ...func(*UserQuery)) *BlogQuery { // Aggregate(entv2.Count()). // Scan(ctx, &v) func (bq *BlogQuery) GroupBy(field string, fields ...string) *BlogGroupBy { - bq.fields = append([]string{field}, fields...) + bq.ctx.Fields = append([]string{field}, fields...) grbuild := &BlogGroupBy{build: bq} - grbuild.flds = &bq.fields + grbuild.flds = &bq.ctx.Fields grbuild.label = blog.Label grbuild.scan = grbuild.Scan return grbuild @@ -335,10 +330,10 @@ func (bq *BlogQuery) GroupBy(field string, fields ...string) *BlogGroupBy { // Select(blog.FieldOid). // Scan(ctx, &v) func (bq *BlogQuery) Select(fields ...string) *BlogSelect { - bq.fields = append(bq.fields, fields...) + bq.ctx.Fields = append(bq.ctx.Fields, fields...) sbuild := &BlogSelect{BlogQuery: bq} sbuild.label = blog.Label - sbuild.flds, sbuild.scan = &bq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &bq.ctx.Fields, sbuild.Scan return sbuild } @@ -358,7 +353,7 @@ func (bq *BlogQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range bq.fields { + for _, f := range bq.ctx.Fields { if !blog.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv2: invalid field %q for query", f)} } @@ -443,9 +438,9 @@ func (bq *BlogQuery) loadAdmins(ctx context.Context, query *UserQuery, nodes []* func (bq *BlogQuery) sqlCount(ctx context.Context) (int, error) { _spec := bq.querySpec() - _spec.Node.Columns = bq.fields - if len(bq.fields) > 0 { - _spec.Unique = bq.unique != nil && *bq.unique + _spec.Node.Columns = bq.ctx.Fields + if len(bq.ctx.Fields) > 0 { + _spec.Unique = bq.ctx.Unique != nil && *bq.ctx.Unique } return sqlgraph.CountNodes(ctx, bq.driver, _spec) } @@ -463,10 +458,10 @@ func (bq *BlogQuery) querySpec() *sqlgraph.QuerySpec { From: bq.sql, Unique: true, } - if unique := bq.unique; unique != nil { + if unique := bq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := bq.fields; len(fields) > 0 { + if fields := bq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, blog.FieldID) for i := range fields { @@ -482,10 +477,10 @@ func (bq *BlogQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := bq.limit; limit != nil { + if limit := bq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := bq.offset; offset != nil { + if offset := bq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := bq.order; len(ps) > 0 { @@ -501,7 +496,7 @@ func (bq *BlogQuery) querySpec() *sqlgraph.QuerySpec { func (bq *BlogQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(bq.driver.Dialect()) t1 := builder.Table(blog.Table) - columns := bq.fields + columns := bq.ctx.Fields if len(columns) == 0 { columns = blog.Columns } @@ -510,7 +505,7 @@ func (bq *BlogQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = bq.sql selector.Select(selector.Columns(columns...)...) } - if bq.unique != nil && *bq.unique { + if bq.ctx.Unique != nil && *bq.ctx.Unique { selector.Distinct() } for _, p := range bq.predicates { @@ -519,12 +514,12 @@ func (bq *BlogQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range bq.order { p(selector) } - if offset := bq.offset; offset != nil { + if offset := bq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := bq.limit; limit != nil { + if limit := bq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -544,7 +539,7 @@ func (bgb *BlogGroupBy) Aggregate(fns ...AggregateFunc) *BlogGroupBy { // Scan applies the selector query and scans the result into the given value. func (bgb *BlogGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeBlog, "GroupBy") + ctx = setContextOp(ctx, bgb.build.ctx, "GroupBy") if err := bgb.build.prepareQuery(ctx); err != nil { return err } @@ -592,7 +587,7 @@ func (bs *BlogSelect) Aggregate(fns ...AggregateFunc) *BlogSelect { // Scan applies the selector query and scans the result into the given value. func (bs *BlogSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeBlog, "Select") + ctx = setContextOp(ctx, bs.ctx, "Select") if err := bs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv2/car_query.go b/entc/integration/migrate/entv2/car_query.go index 2326e1239..a695ec453 100644 --- a/entc/integration/migrate/entv2/car_query.go +++ b/entc/integration/migrate/entv2/car_query.go @@ -22,11 +22,8 @@ import ( // CarQuery is the builder for querying Car entities. type CarQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Car withOwner *UserQuery @@ -44,20 +41,20 @@ func (cq *CarQuery) Where(ps ...predicate.Car) *CarQuery { // Limit the number of records to be returned by this query. func (cq *CarQuery) Limit(limit int) *CarQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *CarQuery) Offset(offset int) *CarQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *CarQuery) Unique(unique bool) *CarQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -92,7 +89,7 @@ func (cq *CarQuery) QueryOwner() *UserQuery { // First returns the first Car entity from the query. // Returns a *NotFoundError when no Car was found. func (cq *CarQuery) First(ctx context.Context) (*Car, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeCar, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -115,7 +112,7 @@ func (cq *CarQuery) FirstX(ctx context.Context) *Car { // Returns a *NotFoundError when no Car ID was found. func (cq *CarQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeCar, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -138,7 +135,7 @@ func (cq *CarQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Car entity is found. // Returns a *NotFoundError when no Car entities are found. func (cq *CarQuery) Only(ctx context.Context) (*Car, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeCar, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -166,7 +163,7 @@ func (cq *CarQuery) OnlyX(ctx context.Context) *Car { // Returns a *NotFoundError when no entities are found. func (cq *CarQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeCar, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -191,7 +188,7 @@ func (cq *CarQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Cars. func (cq *CarQuery) All(ctx context.Context) ([]*Car, error) { - ctx = newQueryContext(ctx, TypeCar, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -211,7 +208,7 @@ func (cq *CarQuery) AllX(ctx context.Context) []*Car { // IDs executes the query and returns a list of Car IDs. func (cq *CarQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeCar, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(car.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -229,7 +226,7 @@ func (cq *CarQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (cq *CarQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeCar, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -247,7 +244,7 @@ func (cq *CarQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *CarQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeCar, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -275,16 +272,14 @@ func (cq *CarQuery) Clone() *CarQuery { } return &CarQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Car{}, cq.predicates...), withOwner: cq.withOwner.Clone(), // clone intermediate query. - sql: cq.sql.Clone(), - path: cq.path, - unique: cq.unique, + sql: cq.sql.Clone(), + path: cq.path, } } @@ -314,9 +309,9 @@ func (cq *CarQuery) WithOwner(opts ...func(*UserQuery)) *CarQuery { // Aggregate(entv2.Count()). // Scan(ctx, &v) func (cq *CarQuery) GroupBy(field string, fields ...string) *CarGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &CarGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = car.Label grbuild.scan = grbuild.Scan return grbuild @@ -335,10 +330,10 @@ func (cq *CarQuery) GroupBy(field string, fields ...string) *CarGroupBy { // Select(car.FieldName). // Scan(ctx, &v) func (cq *CarQuery) Select(fields ...string) *CarSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &CarSelect{CarQuery: cq} sbuild.label = car.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -358,7 +353,7 @@ func (cq *CarQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range cq.fields { + for _, f := range cq.ctx.Fields { if !car.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv2: invalid field %q for query", f)} } @@ -450,9 +445,9 @@ func (cq *CarQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Ca func (cq *CarQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() - _spec.Node.Columns = cq.fields - if len(cq.fields) > 0 { - _spec.Unique = cq.unique != nil && *cq.unique + _spec.Node.Columns = cq.ctx.Fields + if len(cq.ctx.Fields) > 0 { + _spec.Unique = cq.ctx.Unique != nil && *cq.ctx.Unique } return sqlgraph.CountNodes(ctx, cq.driver, _spec) } @@ -470,10 +465,10 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { From: cq.sql, Unique: true, } - if unique := cq.unique; unique != nil { + if unique := cq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := cq.fields; len(fields) > 0 { + if fields := cq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, car.FieldID) for i := range fields { @@ -489,10 +484,10 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := cq.order; len(ps) > 0 { @@ -508,7 +503,7 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(car.Table) - columns := cq.fields + columns := cq.ctx.Fields if len(columns) == 0 { columns = car.Columns } @@ -517,7 +512,7 @@ func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = cq.sql selector.Select(selector.Columns(columns...)...) } - if cq.unique != nil && *cq.unique { + if cq.ctx.Unique != nil && *cq.ctx.Unique { selector.Distinct() } for _, p := range cq.predicates { @@ -526,12 +521,12 @@ func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range cq.order { p(selector) } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -551,7 +546,7 @@ func (cgb *CarGroupBy) Aggregate(fns ...AggregateFunc) *CarGroupBy { // Scan applies the selector query and scans the result into the given value. func (cgb *CarGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCar, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -599,7 +594,7 @@ func (cs *CarSelect) Aggregate(fns ...AggregateFunc) *CarSelect { // Scan applies the selector query and scans the result into the given value. func (cs *CarSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCar, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv2/client.go b/entc/integration/migrate/entv2/client.go index 4cbfa284f..09508d315 100644 --- a/entc/integration/migrate/entv2/client.go +++ b/entc/integration/migrate/entv2/client.go @@ -297,6 +297,7 @@ func (c *BlogClient) DeleteOneID(id int) *BlogDeleteOne { func (c *BlogClient) Query() *BlogQuery { return &BlogQuery{ config: c.config, + ctx: &QueryContext{Type: TypeBlog}, inters: c.Interceptors(), } } @@ -430,6 +431,7 @@ func (c *CarClient) DeleteOneID(id int) *CarDeleteOne { func (c *CarClient) Query() *CarQuery { return &CarQuery{ config: c.config, + ctx: &QueryContext{Type: TypeCar}, inters: c.Interceptors(), } } @@ -563,6 +565,7 @@ func (c *ConversionClient) DeleteOneID(id int) *ConversionDeleteOne { func (c *ConversionClient) Query() *ConversionQuery { return &ConversionQuery{ config: c.config, + ctx: &QueryContext{Type: TypeConversion}, inters: c.Interceptors(), } } @@ -680,6 +683,7 @@ func (c *CustomTypeClient) DeleteOneID(id int) *CustomTypeDeleteOne { func (c *CustomTypeClient) Query() *CustomTypeQuery { return &CustomTypeQuery{ config: c.config, + ctx: &QueryContext{Type: TypeCustomType}, inters: c.Interceptors(), } } @@ -797,6 +801,7 @@ func (c *GroupClient) DeleteOneID(id int) *GroupDeleteOne { func (c *GroupClient) Query() *GroupQuery { return &GroupQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGroup}, inters: c.Interceptors(), } } @@ -914,6 +919,7 @@ func (c *MediaClient) DeleteOneID(id int) *MediaDeleteOne { func (c *MediaClient) Query() *MediaQuery { return &MediaQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMedia}, inters: c.Interceptors(), } } @@ -1031,6 +1037,7 @@ func (c *PetClient) DeleteOneID(id int) *PetDeleteOne { func (c *PetClient) Query() *PetQuery { return &PetQuery{ config: c.config, + ctx: &QueryContext{Type: TypePet}, inters: c.Interceptors(), } } @@ -1164,6 +1171,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } @@ -1329,6 +1337,7 @@ func (c *ZooClient) DeleteOneID(id int) *ZooDeleteOne { func (c *ZooClient) Query() *ZooQuery { return &ZooQuery{ config: c.config, + ctx: &QueryContext{Type: TypeZoo}, inters: c.Interceptors(), } } diff --git a/entc/integration/migrate/entv2/conversion_query.go b/entc/integration/migrate/entv2/conversion_query.go index a7068eedc..42d699989 100644 --- a/entc/integration/migrate/entv2/conversion_query.go +++ b/entc/integration/migrate/entv2/conversion_query.go @@ -21,11 +21,8 @@ import ( // ConversionQuery is the builder for querying Conversion entities. type ConversionQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Conversion // intermediate query (i.e. traversal path). @@ -41,20 +38,20 @@ func (cq *ConversionQuery) Where(ps ...predicate.Conversion) *ConversionQuery { // Limit the number of records to be returned by this query. func (cq *ConversionQuery) Limit(limit int) *ConversionQuery { - cq.limit = &limit + cq.ctx.Limit = &limit return cq } // Offset to start from. func (cq *ConversionQuery) Offset(offset int) *ConversionQuery { - cq.offset = &offset + cq.ctx.Offset = &offset return cq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (cq *ConversionQuery) Unique(unique bool) *ConversionQuery { - cq.unique = &unique + cq.ctx.Unique = &unique return cq } @@ -67,7 +64,7 @@ func (cq *ConversionQuery) Order(o ...OrderFunc) *ConversionQuery { // First returns the first Conversion entity from the query. // Returns a *NotFoundError when no Conversion was found. func (cq *ConversionQuery) First(ctx context.Context) (*Conversion, error) { - nodes, err := cq.Limit(1).All(newQueryContext(ctx, TypeConversion, "First")) + nodes, err := cq.Limit(1).All(setContextOp(ctx, cq.ctx, "First")) if err != nil { return nil, err } @@ -90,7 +87,7 @@ func (cq *ConversionQuery) FirstX(ctx context.Context) *Conversion { // Returns a *NotFoundError when no Conversion ID was found. func (cq *ConversionQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(1).IDs(newQueryContext(ctx, TypeConversion, "FirstID")); err != nil { + if ids, err = cq.Limit(1).IDs(setContextOp(ctx, cq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -113,7 +110,7 @@ func (cq *ConversionQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Conversion entity is found. // Returns a *NotFoundError when no Conversion entities are found. func (cq *ConversionQuery) Only(ctx context.Context) (*Conversion, error) { - nodes, err := cq.Limit(2).All(newQueryContext(ctx, TypeConversion, "Only")) + nodes, err := cq.Limit(2).All(setContextOp(ctx, cq.ctx, "Only")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (cq *ConversionQuery) OnlyX(ctx context.Context) *Conversion { // Returns a *NotFoundError when no entities are found. func (cq *ConversionQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = cq.Limit(2).IDs(newQueryContext(ctx, TypeConversion, "OnlyID")); err != nil { + if ids, err = cq.Limit(2).IDs(setContextOp(ctx, cq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -166,7 +163,7 @@ func (cq *ConversionQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Conversions. func (cq *ConversionQuery) All(ctx context.Context) ([]*Conversion, error) { - ctx = newQueryContext(ctx, TypeConversion, "All") + ctx = setContextOp(ctx, cq.ctx, "All") if err := cq.prepareQuery(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (cq *ConversionQuery) AllX(ctx context.Context) []*Conversion { // IDs executes the query and returns a list of Conversion IDs. func (cq *ConversionQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeConversion, "IDs") + ctx = setContextOp(ctx, cq.ctx, "IDs") if err := cq.Select(conversion.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -204,7 +201,7 @@ func (cq *ConversionQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (cq *ConversionQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeConversion, "Count") + ctx = setContextOp(ctx, cq.ctx, "Count") if err := cq.prepareQuery(ctx); err != nil { return 0, err } @@ -222,7 +219,7 @@ func (cq *ConversionQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (cq *ConversionQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeConversion, "Exist") + ctx = setContextOp(ctx, cq.ctx, "Exist") switch _, err := cq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -250,15 +247,13 @@ func (cq *ConversionQuery) Clone() *ConversionQuery { } return &ConversionQuery{ config: cq.config, - limit: cq.limit, - offset: cq.offset, + ctx: cq.ctx.Clone(), order: append([]OrderFunc{}, cq.order...), inters: append([]Interceptor{}, cq.inters...), predicates: append([]predicate.Conversion{}, cq.predicates...), // clone intermediate query. - sql: cq.sql.Clone(), - path: cq.path, - unique: cq.unique, + sql: cq.sql.Clone(), + path: cq.path, } } @@ -277,9 +272,9 @@ func (cq *ConversionQuery) Clone() *ConversionQuery { // Aggregate(entv2.Count()). // Scan(ctx, &v) func (cq *ConversionQuery) GroupBy(field string, fields ...string) *ConversionGroupBy { - cq.fields = append([]string{field}, fields...) + cq.ctx.Fields = append([]string{field}, fields...) grbuild := &ConversionGroupBy{build: cq} - grbuild.flds = &cq.fields + grbuild.flds = &cq.ctx.Fields grbuild.label = conversion.Label grbuild.scan = grbuild.Scan return grbuild @@ -298,10 +293,10 @@ func (cq *ConversionQuery) GroupBy(field string, fields ...string) *ConversionGr // Select(conversion.FieldName). // Scan(ctx, &v) func (cq *ConversionQuery) Select(fields ...string) *ConversionSelect { - cq.fields = append(cq.fields, fields...) + cq.ctx.Fields = append(cq.ctx.Fields, fields...) sbuild := &ConversionSelect{ConversionQuery: cq} sbuild.label = conversion.Label - sbuild.flds, sbuild.scan = &cq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &cq.ctx.Fields, sbuild.Scan return sbuild } @@ -321,7 +316,7 @@ func (cq *ConversionQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range cq.fields { + for _, f := range cq.ctx.Fields { if !conversion.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv2: invalid field %q for query", f)} } @@ -363,9 +358,9 @@ func (cq *ConversionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*C func (cq *ConversionQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() - _spec.Node.Columns = cq.fields - if len(cq.fields) > 0 { - _spec.Unique = cq.unique != nil && *cq.unique + _spec.Node.Columns = cq.ctx.Fields + if len(cq.ctx.Fields) > 0 { + _spec.Unique = cq.ctx.Unique != nil && *cq.ctx.Unique } return sqlgraph.CountNodes(ctx, cq.driver, _spec) } @@ -383,10 +378,10 @@ func (cq *ConversionQuery) querySpec() *sqlgraph.QuerySpec { From: cq.sql, Unique: true, } - if unique := cq.unique; unique != nil { + if unique := cq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := cq.fields; len(fields) > 0 { + if fields := cq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, conversion.FieldID) for i := range fields { @@ -402,10 +397,10 @@ func (cq *ConversionQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := cq.order; len(ps) > 0 { @@ -421,7 +416,7 @@ func (cq *ConversionQuery) querySpec() *sqlgraph.QuerySpec { func (cq *ConversionQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(conversion.Table) - columns := cq.fields + columns := cq.ctx.Fields if len(columns) == 0 { columns = conversion.Columns } @@ -430,7 +425,7 @@ func (cq *ConversionQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = cq.sql selector.Select(selector.Columns(columns...)...) } - if cq.unique != nil && *cq.unique { + if cq.ctx.Unique != nil && *cq.ctx.Unique { selector.Distinct() } for _, p := range cq.predicates { @@ -439,12 +434,12 @@ func (cq *ConversionQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range cq.order { p(selector) } - if offset := cq.offset; offset != nil { + if offset := cq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := cq.limit; limit != nil { + if limit := cq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -464,7 +459,7 @@ func (cgb *ConversionGroupBy) Aggregate(fns ...AggregateFunc) *ConversionGroupBy // Scan applies the selector query and scans the result into the given value. func (cgb *ConversionGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeConversion, "GroupBy") + ctx = setContextOp(ctx, cgb.build.ctx, "GroupBy") if err := cgb.build.prepareQuery(ctx); err != nil { return err } @@ -512,7 +507,7 @@ func (cs *ConversionSelect) Aggregate(fns ...AggregateFunc) *ConversionSelect { // Scan applies the selector query and scans the result into the given value. func (cs *ConversionSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeConversion, "Select") + ctx = setContextOp(ctx, cs.ctx, "Select") if err := cs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv2/customtype_query.go b/entc/integration/migrate/entv2/customtype_query.go index 9dbd92bf1..1a59c8b7a 100644 --- a/entc/integration/migrate/entv2/customtype_query.go +++ b/entc/integration/migrate/entv2/customtype_query.go @@ -21,11 +21,8 @@ import ( // CustomTypeQuery is the builder for querying CustomType entities. type CustomTypeQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.CustomType // intermediate query (i.e. traversal path). @@ -41,20 +38,20 @@ func (ctq *CustomTypeQuery) Where(ps ...predicate.CustomType) *CustomTypeQuery { // Limit the number of records to be returned by this query. func (ctq *CustomTypeQuery) Limit(limit int) *CustomTypeQuery { - ctq.limit = &limit + ctq.ctx.Limit = &limit return ctq } // Offset to start from. func (ctq *CustomTypeQuery) Offset(offset int) *CustomTypeQuery { - ctq.offset = &offset + ctq.ctx.Offset = &offset return ctq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (ctq *CustomTypeQuery) Unique(unique bool) *CustomTypeQuery { - ctq.unique = &unique + ctq.ctx.Unique = &unique return ctq } @@ -67,7 +64,7 @@ func (ctq *CustomTypeQuery) Order(o ...OrderFunc) *CustomTypeQuery { // First returns the first CustomType entity from the query. // Returns a *NotFoundError when no CustomType was found. func (ctq *CustomTypeQuery) First(ctx context.Context) (*CustomType, error) { - nodes, err := ctq.Limit(1).All(newQueryContext(ctx, TypeCustomType, "First")) + nodes, err := ctq.Limit(1).All(setContextOp(ctx, ctq.ctx, "First")) if err != nil { return nil, err } @@ -90,7 +87,7 @@ func (ctq *CustomTypeQuery) FirstX(ctx context.Context) *CustomType { // Returns a *NotFoundError when no CustomType ID was found. func (ctq *CustomTypeQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ctq.Limit(1).IDs(newQueryContext(ctx, TypeCustomType, "FirstID")); err != nil { + if ids, err = ctq.Limit(1).IDs(setContextOp(ctx, ctq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -113,7 +110,7 @@ func (ctq *CustomTypeQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one CustomType entity is found. // Returns a *NotFoundError when no CustomType entities are found. func (ctq *CustomTypeQuery) Only(ctx context.Context) (*CustomType, error) { - nodes, err := ctq.Limit(2).All(newQueryContext(ctx, TypeCustomType, "Only")) + nodes, err := ctq.Limit(2).All(setContextOp(ctx, ctq.ctx, "Only")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (ctq *CustomTypeQuery) OnlyX(ctx context.Context) *CustomType { // Returns a *NotFoundError when no entities are found. func (ctq *CustomTypeQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ctq.Limit(2).IDs(newQueryContext(ctx, TypeCustomType, "OnlyID")); err != nil { + if ids, err = ctq.Limit(2).IDs(setContextOp(ctx, ctq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -166,7 +163,7 @@ func (ctq *CustomTypeQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of CustomTypes. func (ctq *CustomTypeQuery) All(ctx context.Context) ([]*CustomType, error) { - ctx = newQueryContext(ctx, TypeCustomType, "All") + ctx = setContextOp(ctx, ctq.ctx, "All") if err := ctq.prepareQuery(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (ctq *CustomTypeQuery) AllX(ctx context.Context) []*CustomType { // IDs executes the query and returns a list of CustomType IDs. func (ctq *CustomTypeQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeCustomType, "IDs") + ctx = setContextOp(ctx, ctq.ctx, "IDs") if err := ctq.Select(customtype.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -204,7 +201,7 @@ func (ctq *CustomTypeQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (ctq *CustomTypeQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeCustomType, "Count") + ctx = setContextOp(ctx, ctq.ctx, "Count") if err := ctq.prepareQuery(ctx); err != nil { return 0, err } @@ -222,7 +219,7 @@ func (ctq *CustomTypeQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (ctq *CustomTypeQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeCustomType, "Exist") + ctx = setContextOp(ctx, ctq.ctx, "Exist") switch _, err := ctq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -250,15 +247,13 @@ func (ctq *CustomTypeQuery) Clone() *CustomTypeQuery { } return &CustomTypeQuery{ config: ctq.config, - limit: ctq.limit, - offset: ctq.offset, + ctx: ctq.ctx.Clone(), order: append([]OrderFunc{}, ctq.order...), inters: append([]Interceptor{}, ctq.inters...), predicates: append([]predicate.CustomType{}, ctq.predicates...), // clone intermediate query. - sql: ctq.sql.Clone(), - path: ctq.path, - unique: ctq.unique, + sql: ctq.sql.Clone(), + path: ctq.path, } } @@ -277,9 +272,9 @@ func (ctq *CustomTypeQuery) Clone() *CustomTypeQuery { // Aggregate(entv2.Count()). // Scan(ctx, &v) func (ctq *CustomTypeQuery) GroupBy(field string, fields ...string) *CustomTypeGroupBy { - ctq.fields = append([]string{field}, fields...) + ctq.ctx.Fields = append([]string{field}, fields...) grbuild := &CustomTypeGroupBy{build: ctq} - grbuild.flds = &ctq.fields + grbuild.flds = &ctq.ctx.Fields grbuild.label = customtype.Label grbuild.scan = grbuild.Scan return grbuild @@ -298,10 +293,10 @@ func (ctq *CustomTypeQuery) GroupBy(field string, fields ...string) *CustomTypeG // Select(customtype.FieldCustom). // Scan(ctx, &v) func (ctq *CustomTypeQuery) Select(fields ...string) *CustomTypeSelect { - ctq.fields = append(ctq.fields, fields...) + ctq.ctx.Fields = append(ctq.ctx.Fields, fields...) sbuild := &CustomTypeSelect{CustomTypeQuery: ctq} sbuild.label = customtype.Label - sbuild.flds, sbuild.scan = &ctq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &ctq.ctx.Fields, sbuild.Scan return sbuild } @@ -321,7 +316,7 @@ func (ctq *CustomTypeQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range ctq.fields { + for _, f := range ctq.ctx.Fields { if !customtype.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv2: invalid field %q for query", f)} } @@ -363,9 +358,9 @@ func (ctq *CustomTypeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]* func (ctq *CustomTypeQuery) sqlCount(ctx context.Context) (int, error) { _spec := ctq.querySpec() - _spec.Node.Columns = ctq.fields - if len(ctq.fields) > 0 { - _spec.Unique = ctq.unique != nil && *ctq.unique + _spec.Node.Columns = ctq.ctx.Fields + if len(ctq.ctx.Fields) > 0 { + _spec.Unique = ctq.ctx.Unique != nil && *ctq.ctx.Unique } return sqlgraph.CountNodes(ctx, ctq.driver, _spec) } @@ -383,10 +378,10 @@ func (ctq *CustomTypeQuery) querySpec() *sqlgraph.QuerySpec { From: ctq.sql, Unique: true, } - if unique := ctq.unique; unique != nil { + if unique := ctq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := ctq.fields; len(fields) > 0 { + if fields := ctq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, customtype.FieldID) for i := range fields { @@ -402,10 +397,10 @@ func (ctq *CustomTypeQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := ctq.limit; limit != nil { + if limit := ctq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := ctq.offset; offset != nil { + if offset := ctq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := ctq.order; len(ps) > 0 { @@ -421,7 +416,7 @@ func (ctq *CustomTypeQuery) querySpec() *sqlgraph.QuerySpec { func (ctq *CustomTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(ctq.driver.Dialect()) t1 := builder.Table(customtype.Table) - columns := ctq.fields + columns := ctq.ctx.Fields if len(columns) == 0 { columns = customtype.Columns } @@ -430,7 +425,7 @@ func (ctq *CustomTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = ctq.sql selector.Select(selector.Columns(columns...)...) } - if ctq.unique != nil && *ctq.unique { + if ctq.ctx.Unique != nil && *ctq.ctx.Unique { selector.Distinct() } for _, p := range ctq.predicates { @@ -439,12 +434,12 @@ func (ctq *CustomTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range ctq.order { p(selector) } - if offset := ctq.offset; offset != nil { + if offset := ctq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := ctq.limit; limit != nil { + if limit := ctq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -464,7 +459,7 @@ func (ctgb *CustomTypeGroupBy) Aggregate(fns ...AggregateFunc) *CustomTypeGroupB // Scan applies the selector query and scans the result into the given value. func (ctgb *CustomTypeGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCustomType, "GroupBy") + ctx = setContextOp(ctx, ctgb.build.ctx, "GroupBy") if err := ctgb.build.prepareQuery(ctx); err != nil { return err } @@ -512,7 +507,7 @@ func (cts *CustomTypeSelect) Aggregate(fns ...AggregateFunc) *CustomTypeSelect { // Scan applies the selector query and scans the result into the given value. func (cts *CustomTypeSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeCustomType, "Select") + ctx = setContextOp(ctx, cts.ctx, "Select") if err := cts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv2/ent.go b/entc/integration/migrate/entv2/ent.go index 0a36f3efc..21d977227 100644 --- a/entc/integration/migrate/entv2/ent.go +++ b/entc/integration/migrate/entv2/ent.go @@ -32,6 +32,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -523,10 +524,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/migrate/entv2/group_query.go b/entc/integration/migrate/entv2/group_query.go index a866fa128..9a9e8363d 100644 --- a/entc/integration/migrate/entv2/group_query.go +++ b/entc/integration/migrate/entv2/group_query.go @@ -21,11 +21,8 @@ import ( // GroupQuery is the builder for querying Group entities. type GroupQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Group // intermediate query (i.e. traversal path). @@ -41,20 +38,20 @@ func (gq *GroupQuery) Where(ps ...predicate.Group) *GroupQuery { // Limit the number of records to be returned by this query. func (gq *GroupQuery) Limit(limit int) *GroupQuery { - gq.limit = &limit + gq.ctx.Limit = &limit return gq } // Offset to start from. func (gq *GroupQuery) Offset(offset int) *GroupQuery { - gq.offset = &offset + gq.ctx.Offset = &offset return gq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (gq *GroupQuery) Unique(unique bool) *GroupQuery { - gq.unique = &unique + gq.ctx.Unique = &unique return gq } @@ -67,7 +64,7 @@ func (gq *GroupQuery) Order(o ...OrderFunc) *GroupQuery { // First returns the first Group entity from the query. // Returns a *NotFoundError when no Group was found. func (gq *GroupQuery) First(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(1).All(newQueryContext(ctx, TypeGroup, "First")) + nodes, err := gq.Limit(1).All(setContextOp(ctx, gq.ctx, "First")) if err != nil { return nil, err } @@ -90,7 +87,7 @@ func (gq *GroupQuery) FirstX(ctx context.Context) *Group { // Returns a *NotFoundError when no Group ID was found. func (gq *GroupQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(1).IDs(newQueryContext(ctx, TypeGroup, "FirstID")); err != nil { + if ids, err = gq.Limit(1).IDs(setContextOp(ctx, gq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -113,7 +110,7 @@ func (gq *GroupQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Group entity is found. // Returns a *NotFoundError when no Group entities are found. func (gq *GroupQuery) Only(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(2).All(newQueryContext(ctx, TypeGroup, "Only")) + nodes, err := gq.Limit(2).All(setContextOp(ctx, gq.ctx, "Only")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (gq *GroupQuery) OnlyX(ctx context.Context) *Group { // Returns a *NotFoundError when no entities are found. func (gq *GroupQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(2).IDs(newQueryContext(ctx, TypeGroup, "OnlyID")); err != nil { + if ids, err = gq.Limit(2).IDs(setContextOp(ctx, gq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -166,7 +163,7 @@ func (gq *GroupQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Groups. func (gq *GroupQuery) All(ctx context.Context) ([]*Group, error) { - ctx = newQueryContext(ctx, TypeGroup, "All") + ctx = setContextOp(ctx, gq.ctx, "All") if err := gq.prepareQuery(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (gq *GroupQuery) AllX(ctx context.Context) []*Group { // IDs executes the query and returns a list of Group IDs. func (gq *GroupQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeGroup, "IDs") + ctx = setContextOp(ctx, gq.ctx, "IDs") if err := gq.Select(group.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -204,7 +201,7 @@ func (gq *GroupQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (gq *GroupQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGroup, "Count") + ctx = setContextOp(ctx, gq.ctx, "Count") if err := gq.prepareQuery(ctx); err != nil { return 0, err } @@ -222,7 +219,7 @@ func (gq *GroupQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (gq *GroupQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGroup, "Exist") + ctx = setContextOp(ctx, gq.ctx, "Exist") switch _, err := gq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -250,24 +247,22 @@ func (gq *GroupQuery) Clone() *GroupQuery { } return &GroupQuery{ config: gq.config, - limit: gq.limit, - offset: gq.offset, + ctx: gq.ctx.Clone(), order: append([]OrderFunc{}, gq.order...), inters: append([]Interceptor{}, gq.inters...), predicates: append([]predicate.Group{}, gq.predicates...), // clone intermediate query. - sql: gq.sql.Clone(), - path: gq.path, - unique: gq.unique, + sql: gq.sql.Clone(), + path: gq.path, } } // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { - gq.fields = append([]string{field}, fields...) + gq.ctx.Fields = append([]string{field}, fields...) grbuild := &GroupGroupBy{build: gq} - grbuild.flds = &gq.fields + grbuild.flds = &gq.ctx.Fields grbuild.label = group.Label grbuild.scan = grbuild.Scan return grbuild @@ -276,10 +271,10 @@ func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (gq *GroupQuery) Select(fields ...string) *GroupSelect { - gq.fields = append(gq.fields, fields...) + gq.ctx.Fields = append(gq.ctx.Fields, fields...) sbuild := &GroupSelect{GroupQuery: gq} sbuild.label = group.Label - sbuild.flds, sbuild.scan = &gq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &gq.ctx.Fields, sbuild.Scan return sbuild } @@ -299,7 +294,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range gq.fields { + for _, f := range gq.ctx.Fields { if !group.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv2: invalid field %q for query", f)} } @@ -341,9 +336,9 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { _spec := gq.querySpec() - _spec.Node.Columns = gq.fields - if len(gq.fields) > 0 { - _spec.Unique = gq.unique != nil && *gq.unique + _spec.Node.Columns = gq.ctx.Fields + if len(gq.ctx.Fields) > 0 { + _spec.Unique = gq.ctx.Unique != nil && *gq.ctx.Unique } return sqlgraph.CountNodes(ctx, gq.driver, _spec) } @@ -361,10 +356,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { From: gq.sql, Unique: true, } - if unique := gq.unique; unique != nil { + if unique := gq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := gq.fields; len(fields) > 0 { + if fields := gq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, group.FieldID) for i := range fields { @@ -380,10 +375,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := gq.order; len(ps) > 0 { @@ -399,7 +394,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(gq.driver.Dialect()) t1 := builder.Table(group.Table) - columns := gq.fields + columns := gq.ctx.Fields if len(columns) == 0 { columns = group.Columns } @@ -408,7 +403,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = gq.sql selector.Select(selector.Columns(columns...)...) } - if gq.unique != nil && *gq.unique { + if gq.ctx.Unique != nil && *gq.ctx.Unique { selector.Distinct() } for _, p := range gq.predicates { @@ -417,12 +412,12 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range gq.order { p(selector) } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -442,7 +437,7 @@ func (ggb *GroupGroupBy) Aggregate(fns ...AggregateFunc) *GroupGroupBy { // Scan applies the selector query and scans the result into the given value. func (ggb *GroupGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "GroupBy") + ctx = setContextOp(ctx, ggb.build.ctx, "GroupBy") if err := ggb.build.prepareQuery(ctx); err != nil { return err } @@ -490,7 +485,7 @@ func (gs *GroupSelect) Aggregate(fns ...AggregateFunc) *GroupSelect { // Scan applies the selector query and scans the result into the given value. func (gs *GroupSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "Select") + ctx = setContextOp(ctx, gs.ctx, "Select") if err := gs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv2/media_query.go b/entc/integration/migrate/entv2/media_query.go index c0611992d..9db5f1c6b 100644 --- a/entc/integration/migrate/entv2/media_query.go +++ b/entc/integration/migrate/entv2/media_query.go @@ -21,11 +21,8 @@ import ( // MediaQuery is the builder for querying Media entities. type MediaQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Media // intermediate query (i.e. traversal path). @@ -41,20 +38,20 @@ func (mq *MediaQuery) Where(ps ...predicate.Media) *MediaQuery { // Limit the number of records to be returned by this query. func (mq *MediaQuery) Limit(limit int) *MediaQuery { - mq.limit = &limit + mq.ctx.Limit = &limit return mq } // Offset to start from. func (mq *MediaQuery) Offset(offset int) *MediaQuery { - mq.offset = &offset + mq.ctx.Offset = &offset return mq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mq *MediaQuery) Unique(unique bool) *MediaQuery { - mq.unique = &unique + mq.ctx.Unique = &unique return mq } @@ -67,7 +64,7 @@ func (mq *MediaQuery) Order(o ...OrderFunc) *MediaQuery { // First returns the first Media entity from the query. // Returns a *NotFoundError when no Media was found. func (mq *MediaQuery) First(ctx context.Context) (*Media, error) { - nodes, err := mq.Limit(1).All(newQueryContext(ctx, TypeMedia, "First")) + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) if err != nil { return nil, err } @@ -90,7 +87,7 @@ func (mq *MediaQuery) FirstX(ctx context.Context) *Media { // Returns a *NotFoundError when no Media ID was found. func (mq *MediaQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(1).IDs(newQueryContext(ctx, TypeMedia, "FirstID")); err != nil { + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -113,7 +110,7 @@ func (mq *MediaQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Media entity is found. // Returns a *NotFoundError when no Media entities are found. func (mq *MediaQuery) Only(ctx context.Context) (*Media, error) { - nodes, err := mq.Limit(2).All(newQueryContext(ctx, TypeMedia, "Only")) + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (mq *MediaQuery) OnlyX(ctx context.Context) *Media { // Returns a *NotFoundError when no entities are found. func (mq *MediaQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(2).IDs(newQueryContext(ctx, TypeMedia, "OnlyID")); err != nil { + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -166,7 +163,7 @@ func (mq *MediaQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MediaSlice. func (mq *MediaQuery) All(ctx context.Context) ([]*Media, error) { - ctx = newQueryContext(ctx, TypeMedia, "All") + ctx = setContextOp(ctx, mq.ctx, "All") if err := mq.prepareQuery(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (mq *MediaQuery) AllX(ctx context.Context) []*Media { // IDs executes the query and returns a list of Media IDs. func (mq *MediaQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeMedia, "IDs") + ctx = setContextOp(ctx, mq.ctx, "IDs") if err := mq.Select(media.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -204,7 +201,7 @@ func (mq *MediaQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mq *MediaQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeMedia, "Count") + ctx = setContextOp(ctx, mq.ctx, "Count") if err := mq.prepareQuery(ctx); err != nil { return 0, err } @@ -222,7 +219,7 @@ func (mq *MediaQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mq *MediaQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeMedia, "Exist") + ctx = setContextOp(ctx, mq.ctx, "Exist") switch _, err := mq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -250,15 +247,13 @@ func (mq *MediaQuery) Clone() *MediaQuery { } return &MediaQuery{ config: mq.config, - limit: mq.limit, - offset: mq.offset, + ctx: mq.ctx.Clone(), order: append([]OrderFunc{}, mq.order...), inters: append([]Interceptor{}, mq.inters...), predicates: append([]predicate.Media{}, mq.predicates...), // clone intermediate query. - sql: mq.sql.Clone(), - path: mq.path, - unique: mq.unique, + sql: mq.sql.Clone(), + path: mq.path, } } @@ -277,9 +272,9 @@ func (mq *MediaQuery) Clone() *MediaQuery { // Aggregate(entv2.Count()). // Scan(ctx, &v) func (mq *MediaQuery) GroupBy(field string, fields ...string) *MediaGroupBy { - mq.fields = append([]string{field}, fields...) + mq.ctx.Fields = append([]string{field}, fields...) grbuild := &MediaGroupBy{build: mq} - grbuild.flds = &mq.fields + grbuild.flds = &mq.ctx.Fields grbuild.label = media.Label grbuild.scan = grbuild.Scan return grbuild @@ -298,10 +293,10 @@ func (mq *MediaQuery) GroupBy(field string, fields ...string) *MediaGroupBy { // Select(media.FieldSource). // Scan(ctx, &v) func (mq *MediaQuery) Select(fields ...string) *MediaSelect { - mq.fields = append(mq.fields, fields...) + mq.ctx.Fields = append(mq.ctx.Fields, fields...) sbuild := &MediaSelect{MediaQuery: mq} sbuild.label = media.Label - sbuild.flds, sbuild.scan = &mq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan return sbuild } @@ -321,7 +316,7 @@ func (mq *MediaQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range mq.fields { + for _, f := range mq.ctx.Fields { if !media.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv2: invalid field %q for query", f)} } @@ -363,9 +358,9 @@ func (mq *MediaQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Media, func (mq *MediaQuery) sqlCount(ctx context.Context) (int, error) { _spec := mq.querySpec() - _spec.Node.Columns = mq.fields - if len(mq.fields) > 0 { - _spec.Unique = mq.unique != nil && *mq.unique + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique } return sqlgraph.CountNodes(ctx, mq.driver, _spec) } @@ -383,10 +378,10 @@ func (mq *MediaQuery) querySpec() *sqlgraph.QuerySpec { From: mq.sql, Unique: true, } - if unique := mq.unique; unique != nil { + if unique := mq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := mq.fields; len(fields) > 0 { + if fields := mq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, media.FieldID) for i := range fields { @@ -402,10 +397,10 @@ func (mq *MediaQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mq.order; len(ps) > 0 { @@ -421,7 +416,7 @@ func (mq *MediaQuery) querySpec() *sqlgraph.QuerySpec { func (mq *MediaQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mq.driver.Dialect()) t1 := builder.Table(media.Table) - columns := mq.fields + columns := mq.ctx.Fields if len(columns) == 0 { columns = media.Columns } @@ -430,7 +425,7 @@ func (mq *MediaQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mq.sql selector.Select(selector.Columns(columns...)...) } - if mq.unique != nil && *mq.unique { + if mq.ctx.Unique != nil && *mq.ctx.Unique { selector.Distinct() } for _, p := range mq.predicates { @@ -439,12 +434,12 @@ func (mq *MediaQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mq.order { p(selector) } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -464,7 +459,7 @@ func (mgb *MediaGroupBy) Aggregate(fns ...AggregateFunc) *MediaGroupBy { // Scan applies the selector query and scans the result into the given value. func (mgb *MediaGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeMedia, "GroupBy") + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") if err := mgb.build.prepareQuery(ctx); err != nil { return err } @@ -512,7 +507,7 @@ func (ms *MediaSelect) Aggregate(fns ...AggregateFunc) *MediaSelect { // Scan applies the selector query and scans the result into the given value. func (ms *MediaSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeMedia, "Select") + ctx = setContextOp(ctx, ms.ctx, "Select") if err := ms.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv2/pet_query.go b/entc/integration/migrate/entv2/pet_query.go index 37dacc389..9837d8234 100644 --- a/entc/integration/migrate/entv2/pet_query.go +++ b/entc/integration/migrate/entv2/pet_query.go @@ -22,11 +22,8 @@ import ( // PetQuery is the builder for querying Pet entities. type PetQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Pet withOwner *UserQuery @@ -44,20 +41,20 @@ func (pq *PetQuery) Where(ps ...predicate.Pet) *PetQuery { // Limit the number of records to be returned by this query. func (pq *PetQuery) Limit(limit int) *PetQuery { - pq.limit = &limit + pq.ctx.Limit = &limit return pq } // Offset to start from. func (pq *PetQuery) Offset(offset int) *PetQuery { - pq.offset = &offset + pq.ctx.Offset = &offset return pq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (pq *PetQuery) Unique(unique bool) *PetQuery { - pq.unique = &unique + pq.ctx.Unique = &unique return pq } @@ -92,7 +89,7 @@ func (pq *PetQuery) QueryOwner() *UserQuery { // First returns the first Pet entity from the query. // Returns a *NotFoundError when no Pet was found. func (pq *PetQuery) First(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(1).All(newQueryContext(ctx, TypePet, "First")) + nodes, err := pq.Limit(1).All(setContextOp(ctx, pq.ctx, "First")) if err != nil { return nil, err } @@ -115,7 +112,7 @@ func (pq *PetQuery) FirstX(ctx context.Context) *Pet { // Returns a *NotFoundError when no Pet ID was found. func (pq *PetQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(1).IDs(newQueryContext(ctx, TypePet, "FirstID")); err != nil { + if ids, err = pq.Limit(1).IDs(setContextOp(ctx, pq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -138,7 +135,7 @@ func (pq *PetQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Pet entity is found. // Returns a *NotFoundError when no Pet entities are found. func (pq *PetQuery) Only(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(2).All(newQueryContext(ctx, TypePet, "Only")) + nodes, err := pq.Limit(2).All(setContextOp(ctx, pq.ctx, "Only")) if err != nil { return nil, err } @@ -166,7 +163,7 @@ func (pq *PetQuery) OnlyX(ctx context.Context) *Pet { // Returns a *NotFoundError when no entities are found. func (pq *PetQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(2).IDs(newQueryContext(ctx, TypePet, "OnlyID")); err != nil { + if ids, err = pq.Limit(2).IDs(setContextOp(ctx, pq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -191,7 +188,7 @@ func (pq *PetQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Pets. func (pq *PetQuery) All(ctx context.Context) ([]*Pet, error) { - ctx = newQueryContext(ctx, TypePet, "All") + ctx = setContextOp(ctx, pq.ctx, "All") if err := pq.prepareQuery(ctx); err != nil { return nil, err } @@ -211,7 +208,7 @@ func (pq *PetQuery) AllX(ctx context.Context) []*Pet { // IDs executes the query and returns a list of Pet IDs. func (pq *PetQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypePet, "IDs") + ctx = setContextOp(ctx, pq.ctx, "IDs") if err := pq.Select(pet.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -229,7 +226,7 @@ func (pq *PetQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (pq *PetQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypePet, "Count") + ctx = setContextOp(ctx, pq.ctx, "Count") if err := pq.prepareQuery(ctx); err != nil { return 0, err } @@ -247,7 +244,7 @@ func (pq *PetQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (pq *PetQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypePet, "Exist") + ctx = setContextOp(ctx, pq.ctx, "Exist") switch _, err := pq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -275,16 +272,14 @@ func (pq *PetQuery) Clone() *PetQuery { } return &PetQuery{ config: pq.config, - limit: pq.limit, - offset: pq.offset, + ctx: pq.ctx.Clone(), order: append([]OrderFunc{}, pq.order...), inters: append([]Interceptor{}, pq.inters...), predicates: append([]predicate.Pet{}, pq.predicates...), withOwner: pq.withOwner.Clone(), // clone intermediate query. - sql: pq.sql.Clone(), - path: pq.path, - unique: pq.unique, + sql: pq.sql.Clone(), + path: pq.path, } } @@ -314,9 +309,9 @@ func (pq *PetQuery) WithOwner(opts ...func(*UserQuery)) *PetQuery { // Aggregate(entv2.Count()). // Scan(ctx, &v) func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { - pq.fields = append([]string{field}, fields...) + pq.ctx.Fields = append([]string{field}, fields...) grbuild := &PetGroupBy{build: pq} - grbuild.flds = &pq.fields + grbuild.flds = &pq.ctx.Fields grbuild.label = pet.Label grbuild.scan = grbuild.Scan return grbuild @@ -335,10 +330,10 @@ func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { // Select(pet.FieldName). // Scan(ctx, &v) func (pq *PetQuery) Select(fields ...string) *PetSelect { - pq.fields = append(pq.fields, fields...) + pq.ctx.Fields = append(pq.ctx.Fields, fields...) sbuild := &PetSelect{PetQuery: pq} sbuild.label = pet.Label - sbuild.flds, sbuild.scan = &pq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &pq.ctx.Fields, sbuild.Scan return sbuild } @@ -358,7 +353,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range pq.fields { + for _, f := range pq.ctx.Fields { if !pet.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv2: invalid field %q for query", f)} } @@ -450,9 +445,9 @@ func (pq *PetQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Pe func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() - _spec.Node.Columns = pq.fields - if len(pq.fields) > 0 { - _spec.Unique = pq.unique != nil && *pq.unique + _spec.Node.Columns = pq.ctx.Fields + if len(pq.ctx.Fields) > 0 { + _spec.Unique = pq.ctx.Unique != nil && *pq.ctx.Unique } return sqlgraph.CountNodes(ctx, pq.driver, _spec) } @@ -470,10 +465,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { From: pq.sql, Unique: true, } - if unique := pq.unique; unique != nil { + if unique := pq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := pq.fields; len(fields) > 0 { + if fields := pq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, pet.FieldID) for i := range fields { @@ -489,10 +484,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := pq.order; len(ps) > 0 { @@ -508,7 +503,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(pq.driver.Dialect()) t1 := builder.Table(pet.Table) - columns := pq.fields + columns := pq.ctx.Fields if len(columns) == 0 { columns = pet.Columns } @@ -517,7 +512,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = pq.sql selector.Select(selector.Columns(columns...)...) } - if pq.unique != nil && *pq.unique { + if pq.ctx.Unique != nil && *pq.ctx.Unique { selector.Distinct() } for _, p := range pq.predicates { @@ -526,12 +521,12 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range pq.order { p(selector) } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -551,7 +546,7 @@ func (pgb *PetGroupBy) Aggregate(fns ...AggregateFunc) *PetGroupBy { // Scan applies the selector query and scans the result into the given value. func (pgb *PetGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "GroupBy") + ctx = setContextOp(ctx, pgb.build.ctx, "GroupBy") if err := pgb.build.prepareQuery(ctx); err != nil { return err } @@ -599,7 +594,7 @@ func (ps *PetSelect) Aggregate(fns ...AggregateFunc) *PetSelect { // Scan applies the selector query and scans the result into the given value. func (ps *PetSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "Select") + ctx = setContextOp(ctx, ps.ctx, "Select") if err := ps.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv2/user_query.go b/entc/integration/migrate/entv2/user_query.go index e4603eda3..ec7af789b 100644 --- a/entc/integration/migrate/entv2/user_query.go +++ b/entc/integration/migrate/entv2/user_query.go @@ -24,11 +24,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withCar *CarQuery @@ -48,20 +45,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -140,7 +137,7 @@ func (uq *UserQuery) QueryFriends() *UserQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -163,7 +160,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -186,7 +183,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -214,7 +211,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -239,7 +236,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -259,7 +256,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -277,7 +274,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -295,7 +292,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -323,8 +320,7 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), @@ -332,9 +328,8 @@ func (uq *UserQuery) Clone() *UserQuery { withPets: uq.withPets.Clone(), withFriends: uq.withFriends.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -386,9 +381,9 @@ func (uq *UserQuery) WithFriends(opts ...func(*UserQuery)) *UserQuery { // Aggregate(entv2.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -407,10 +402,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldMixedString). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -430,7 +425,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv2: invalid field %q for query", f)} } @@ -620,9 +615,9 @@ func (uq *UserQuery) loadFriends(ctx context.Context, query *UserQuery, nodes [] func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -640,10 +635,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -659,10 +654,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -678,7 +673,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -687,7 +682,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -696,12 +691,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -721,7 +716,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -769,7 +764,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/entv2/zoo_query.go b/entc/integration/migrate/entv2/zoo_query.go index 82392cb9f..346587eab 100644 --- a/entc/integration/migrate/entv2/zoo_query.go +++ b/entc/integration/migrate/entv2/zoo_query.go @@ -21,11 +21,8 @@ import ( // ZooQuery is the builder for querying Zoo entities. type ZooQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Zoo // intermediate query (i.e. traversal path). @@ -41,20 +38,20 @@ func (zq *ZooQuery) Where(ps ...predicate.Zoo) *ZooQuery { // Limit the number of records to be returned by this query. func (zq *ZooQuery) Limit(limit int) *ZooQuery { - zq.limit = &limit + zq.ctx.Limit = &limit return zq } // Offset to start from. func (zq *ZooQuery) Offset(offset int) *ZooQuery { - zq.offset = &offset + zq.ctx.Offset = &offset return zq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (zq *ZooQuery) Unique(unique bool) *ZooQuery { - zq.unique = &unique + zq.ctx.Unique = &unique return zq } @@ -67,7 +64,7 @@ func (zq *ZooQuery) Order(o ...OrderFunc) *ZooQuery { // First returns the first Zoo entity from the query. // Returns a *NotFoundError when no Zoo was found. func (zq *ZooQuery) First(ctx context.Context) (*Zoo, error) { - nodes, err := zq.Limit(1).All(newQueryContext(ctx, TypeZoo, "First")) + nodes, err := zq.Limit(1).All(setContextOp(ctx, zq.ctx, "First")) if err != nil { return nil, err } @@ -90,7 +87,7 @@ func (zq *ZooQuery) FirstX(ctx context.Context) *Zoo { // Returns a *NotFoundError when no Zoo ID was found. func (zq *ZooQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = zq.Limit(1).IDs(newQueryContext(ctx, TypeZoo, "FirstID")); err != nil { + if ids, err = zq.Limit(1).IDs(setContextOp(ctx, zq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -113,7 +110,7 @@ func (zq *ZooQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Zoo entity is found. // Returns a *NotFoundError when no Zoo entities are found. func (zq *ZooQuery) Only(ctx context.Context) (*Zoo, error) { - nodes, err := zq.Limit(2).All(newQueryContext(ctx, TypeZoo, "Only")) + nodes, err := zq.Limit(2).All(setContextOp(ctx, zq.ctx, "Only")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (zq *ZooQuery) OnlyX(ctx context.Context) *Zoo { // Returns a *NotFoundError when no entities are found. func (zq *ZooQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = zq.Limit(2).IDs(newQueryContext(ctx, TypeZoo, "OnlyID")); err != nil { + if ids, err = zq.Limit(2).IDs(setContextOp(ctx, zq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -166,7 +163,7 @@ func (zq *ZooQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Zoos. func (zq *ZooQuery) All(ctx context.Context) ([]*Zoo, error) { - ctx = newQueryContext(ctx, TypeZoo, "All") + ctx = setContextOp(ctx, zq.ctx, "All") if err := zq.prepareQuery(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (zq *ZooQuery) AllX(ctx context.Context) []*Zoo { // IDs executes the query and returns a list of Zoo IDs. func (zq *ZooQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeZoo, "IDs") + ctx = setContextOp(ctx, zq.ctx, "IDs") if err := zq.Select(zoo.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -204,7 +201,7 @@ func (zq *ZooQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (zq *ZooQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeZoo, "Count") + ctx = setContextOp(ctx, zq.ctx, "Count") if err := zq.prepareQuery(ctx); err != nil { return 0, err } @@ -222,7 +219,7 @@ func (zq *ZooQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (zq *ZooQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeZoo, "Exist") + ctx = setContextOp(ctx, zq.ctx, "Exist") switch _, err := zq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -250,24 +247,22 @@ func (zq *ZooQuery) Clone() *ZooQuery { } return &ZooQuery{ config: zq.config, - limit: zq.limit, - offset: zq.offset, + ctx: zq.ctx.Clone(), order: append([]OrderFunc{}, zq.order...), inters: append([]Interceptor{}, zq.inters...), predicates: append([]predicate.Zoo{}, zq.predicates...), // clone intermediate query. - sql: zq.sql.Clone(), - path: zq.path, - unique: zq.unique, + sql: zq.sql.Clone(), + path: zq.path, } } // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (zq *ZooQuery) GroupBy(field string, fields ...string) *ZooGroupBy { - zq.fields = append([]string{field}, fields...) + zq.ctx.Fields = append([]string{field}, fields...) grbuild := &ZooGroupBy{build: zq} - grbuild.flds = &zq.fields + grbuild.flds = &zq.ctx.Fields grbuild.label = zoo.Label grbuild.scan = grbuild.Scan return grbuild @@ -276,10 +271,10 @@ func (zq *ZooQuery) GroupBy(field string, fields ...string) *ZooGroupBy { // Select allows the selection one or more fields/columns for the given query, // instead of selecting all fields in the entity. func (zq *ZooQuery) Select(fields ...string) *ZooSelect { - zq.fields = append(zq.fields, fields...) + zq.ctx.Fields = append(zq.ctx.Fields, fields...) sbuild := &ZooSelect{ZooQuery: zq} sbuild.label = zoo.Label - sbuild.flds, sbuild.scan = &zq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &zq.ctx.Fields, sbuild.Scan return sbuild } @@ -299,7 +294,7 @@ func (zq *ZooQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range zq.fields { + for _, f := range zq.ctx.Fields { if !zoo.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("entv2: invalid field %q for query", f)} } @@ -341,9 +336,9 @@ func (zq *ZooQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Zoo, err func (zq *ZooQuery) sqlCount(ctx context.Context) (int, error) { _spec := zq.querySpec() - _spec.Node.Columns = zq.fields - if len(zq.fields) > 0 { - _spec.Unique = zq.unique != nil && *zq.unique + _spec.Node.Columns = zq.ctx.Fields + if len(zq.ctx.Fields) > 0 { + _spec.Unique = zq.ctx.Unique != nil && *zq.ctx.Unique } return sqlgraph.CountNodes(ctx, zq.driver, _spec) } @@ -361,10 +356,10 @@ func (zq *ZooQuery) querySpec() *sqlgraph.QuerySpec { From: zq.sql, Unique: true, } - if unique := zq.unique; unique != nil { + if unique := zq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := zq.fields; len(fields) > 0 { + if fields := zq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, zoo.FieldID) for i := range fields { @@ -380,10 +375,10 @@ func (zq *ZooQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := zq.limit; limit != nil { + if limit := zq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := zq.offset; offset != nil { + if offset := zq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := zq.order; len(ps) > 0 { @@ -399,7 +394,7 @@ func (zq *ZooQuery) querySpec() *sqlgraph.QuerySpec { func (zq *ZooQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(zq.driver.Dialect()) t1 := builder.Table(zoo.Table) - columns := zq.fields + columns := zq.ctx.Fields if len(columns) == 0 { columns = zoo.Columns } @@ -408,7 +403,7 @@ func (zq *ZooQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = zq.sql selector.Select(selector.Columns(columns...)...) } - if zq.unique != nil && *zq.unique { + if zq.ctx.Unique != nil && *zq.ctx.Unique { selector.Distinct() } for _, p := range zq.predicates { @@ -417,12 +412,12 @@ func (zq *ZooQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range zq.order { p(selector) } - if offset := zq.offset; offset != nil { + if offset := zq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := zq.limit; limit != nil { + if limit := zq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -442,7 +437,7 @@ func (zgb *ZooGroupBy) Aggregate(fns ...AggregateFunc) *ZooGroupBy { // Scan applies the selector query and scans the result into the given value. func (zgb *ZooGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeZoo, "GroupBy") + ctx = setContextOp(ctx, zgb.build.ctx, "GroupBy") if err := zgb.build.prepareQuery(ctx); err != nil { return err } @@ -490,7 +485,7 @@ func (zs *ZooSelect) Aggregate(fns ...AggregateFunc) *ZooSelect { // Scan applies the selector query and scans the result into the given value. func (zs *ZooSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeZoo, "Select") + ctx = setContextOp(ctx, zs.ctx, "Select") if err := zs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/versioned/client.go b/entc/integration/migrate/versioned/client.go index ce5a3b0ff..10913af7f 100644 --- a/entc/integration/migrate/versioned/client.go +++ b/entc/integration/migrate/versioned/client.go @@ -226,6 +226,7 @@ func (c *GroupClient) DeleteOneID(id int) *GroupDeleteOne { func (c *GroupClient) Query() *GroupQuery { return &GroupQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGroup}, inters: c.Interceptors(), } } @@ -343,6 +344,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/migrate/versioned/ent.go b/entc/integration/migrate/versioned/ent.go index 9a89b1e2e..715ab41fd 100644 --- a/entc/integration/migrate/versioned/ent.go +++ b/entc/integration/migrate/versioned/ent.go @@ -25,6 +25,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -509,10 +510,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/migrate/versioned/group_query.go b/entc/integration/migrate/versioned/group_query.go index fa6b63e07..922c45d35 100644 --- a/entc/integration/migrate/versioned/group_query.go +++ b/entc/integration/migrate/versioned/group_query.go @@ -21,11 +21,8 @@ import ( // GroupQuery is the builder for querying Group entities. type GroupQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Group // intermediate query (i.e. traversal path). @@ -41,20 +38,20 @@ func (gq *GroupQuery) Where(ps ...predicate.Group) *GroupQuery { // Limit the number of records to be returned by this query. func (gq *GroupQuery) Limit(limit int) *GroupQuery { - gq.limit = &limit + gq.ctx.Limit = &limit return gq } // Offset to start from. func (gq *GroupQuery) Offset(offset int) *GroupQuery { - gq.offset = &offset + gq.ctx.Offset = &offset return gq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (gq *GroupQuery) Unique(unique bool) *GroupQuery { - gq.unique = &unique + gq.ctx.Unique = &unique return gq } @@ -67,7 +64,7 @@ func (gq *GroupQuery) Order(o ...OrderFunc) *GroupQuery { // First returns the first Group entity from the query. // Returns a *NotFoundError when no Group was found. func (gq *GroupQuery) First(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(1).All(newQueryContext(ctx, TypeGroup, "First")) + nodes, err := gq.Limit(1).All(setContextOp(ctx, gq.ctx, "First")) if err != nil { return nil, err } @@ -90,7 +87,7 @@ func (gq *GroupQuery) FirstX(ctx context.Context) *Group { // Returns a *NotFoundError when no Group ID was found. func (gq *GroupQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(1).IDs(newQueryContext(ctx, TypeGroup, "FirstID")); err != nil { + if ids, err = gq.Limit(1).IDs(setContextOp(ctx, gq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -113,7 +110,7 @@ func (gq *GroupQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Group entity is found. // Returns a *NotFoundError when no Group entities are found. func (gq *GroupQuery) Only(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(2).All(newQueryContext(ctx, TypeGroup, "Only")) + nodes, err := gq.Limit(2).All(setContextOp(ctx, gq.ctx, "Only")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (gq *GroupQuery) OnlyX(ctx context.Context) *Group { // Returns a *NotFoundError when no entities are found. func (gq *GroupQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(2).IDs(newQueryContext(ctx, TypeGroup, "OnlyID")); err != nil { + if ids, err = gq.Limit(2).IDs(setContextOp(ctx, gq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -166,7 +163,7 @@ func (gq *GroupQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Groups. func (gq *GroupQuery) All(ctx context.Context) ([]*Group, error) { - ctx = newQueryContext(ctx, TypeGroup, "All") + ctx = setContextOp(ctx, gq.ctx, "All") if err := gq.prepareQuery(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (gq *GroupQuery) AllX(ctx context.Context) []*Group { // IDs executes the query and returns a list of Group IDs. func (gq *GroupQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeGroup, "IDs") + ctx = setContextOp(ctx, gq.ctx, "IDs") if err := gq.Select(group.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -204,7 +201,7 @@ func (gq *GroupQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (gq *GroupQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGroup, "Count") + ctx = setContextOp(ctx, gq.ctx, "Count") if err := gq.prepareQuery(ctx); err != nil { return 0, err } @@ -222,7 +219,7 @@ func (gq *GroupQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (gq *GroupQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGroup, "Exist") + ctx = setContextOp(ctx, gq.ctx, "Exist") switch _, err := gq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -250,15 +247,13 @@ func (gq *GroupQuery) Clone() *GroupQuery { } return &GroupQuery{ config: gq.config, - limit: gq.limit, - offset: gq.offset, + ctx: gq.ctx.Clone(), order: append([]OrderFunc{}, gq.order...), inters: append([]Interceptor{}, gq.inters...), predicates: append([]predicate.Group{}, gq.predicates...), // clone intermediate query. - sql: gq.sql.Clone(), - path: gq.path, - unique: gq.unique, + sql: gq.sql.Clone(), + path: gq.path, } } @@ -277,9 +272,9 @@ func (gq *GroupQuery) Clone() *GroupQuery { // Aggregate(versioned.Count()). // Scan(ctx, &v) func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { - gq.fields = append([]string{field}, fields...) + gq.ctx.Fields = append([]string{field}, fields...) grbuild := &GroupGroupBy{build: gq} - grbuild.flds = &gq.fields + grbuild.flds = &gq.ctx.Fields grbuild.label = group.Label grbuild.scan = grbuild.Scan return grbuild @@ -298,10 +293,10 @@ func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { // Select(group.FieldName). // Scan(ctx, &v) func (gq *GroupQuery) Select(fields ...string) *GroupSelect { - gq.fields = append(gq.fields, fields...) + gq.ctx.Fields = append(gq.ctx.Fields, fields...) sbuild := &GroupSelect{GroupQuery: gq} sbuild.label = group.Label - sbuild.flds, sbuild.scan = &gq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &gq.ctx.Fields, sbuild.Scan return sbuild } @@ -321,7 +316,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range gq.fields { + for _, f := range gq.ctx.Fields { if !group.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("versioned: invalid field %q for query", f)} } @@ -363,9 +358,9 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { _spec := gq.querySpec() - _spec.Node.Columns = gq.fields - if len(gq.fields) > 0 { - _spec.Unique = gq.unique != nil && *gq.unique + _spec.Node.Columns = gq.ctx.Fields + if len(gq.ctx.Fields) > 0 { + _spec.Unique = gq.ctx.Unique != nil && *gq.ctx.Unique } return sqlgraph.CountNodes(ctx, gq.driver, _spec) } @@ -383,10 +378,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { From: gq.sql, Unique: true, } - if unique := gq.unique; unique != nil { + if unique := gq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := gq.fields; len(fields) > 0 { + if fields := gq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, group.FieldID) for i := range fields { @@ -402,10 +397,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := gq.order; len(ps) > 0 { @@ -421,7 +416,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(gq.driver.Dialect()) t1 := builder.Table(group.Table) - columns := gq.fields + columns := gq.ctx.Fields if len(columns) == 0 { columns = group.Columns } @@ -430,7 +425,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = gq.sql selector.Select(selector.Columns(columns...)...) } - if gq.unique != nil && *gq.unique { + if gq.ctx.Unique != nil && *gq.ctx.Unique { selector.Distinct() } for _, p := range gq.predicates { @@ -439,12 +434,12 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range gq.order { p(selector) } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -464,7 +459,7 @@ func (ggb *GroupGroupBy) Aggregate(fns ...AggregateFunc) *GroupGroupBy { // Scan applies the selector query and scans the result into the given value. func (ggb *GroupGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "GroupBy") + ctx = setContextOp(ctx, ggb.build.ctx, "GroupBy") if err := ggb.build.prepareQuery(ctx); err != nil { return err } @@ -512,7 +507,7 @@ func (gs *GroupSelect) Aggregate(fns ...AggregateFunc) *GroupSelect { // Scan applies the selector query and scans the result into the given value. func (gs *GroupSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "Select") + ctx = setContextOp(ctx, gs.ctx, "Select") if err := gs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/migrate/versioned/user_query.go b/entc/integration/migrate/versioned/user_query.go index 370167033..711503ab6 100644 --- a/entc/integration/migrate/versioned/user_query.go +++ b/entc/integration/migrate/versioned/user_query.go @@ -21,11 +21,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User // intermediate query (i.e. traversal path). @@ -41,20 +38,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -67,7 +64,7 @@ func (uq *UserQuery) Order(o ...OrderFunc) *UserQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -90,7 +87,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -113,7 +110,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -166,7 +163,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -204,7 +201,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -222,7 +219,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -250,15 +247,13 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -277,9 +272,9 @@ func (uq *UserQuery) Clone() *UserQuery { // Aggregate(versioned.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -298,10 +293,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldAge). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -321,7 +316,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("versioned: invalid field %q for query", f)} } @@ -363,9 +358,9 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -383,10 +378,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -402,10 +397,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -421,7 +416,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -430,7 +425,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -439,12 +434,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -464,7 +459,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -512,7 +507,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/multischema/ent/client.go b/entc/integration/multischema/ent/client.go index 4f7c8d696..f53d56fcc 100644 --- a/entc/integration/multischema/ent/client.go +++ b/entc/integration/multischema/ent/client.go @@ -247,6 +247,7 @@ func (c *FriendshipClient) DeleteOneID(id int) *FriendshipDeleteOne { func (c *FriendshipClient) Query() *FriendshipQuery { return &FriendshipQuery{ config: c.config, + ctx: &QueryContext{Type: TypeFriendship}, inters: c.Interceptors(), } } @@ -402,6 +403,7 @@ func (c *GroupClient) DeleteOneID(id int) *GroupDeleteOne { func (c *GroupClient) Query() *GroupQuery { return &GroupQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGroup}, inters: c.Interceptors(), } } @@ -538,6 +540,7 @@ func (c *PetClient) DeleteOneID(id int) *PetDeleteOne { func (c *PetClient) Query() *PetQuery { return &PetQuery{ config: c.config, + ctx: &QueryContext{Type: TypePet}, inters: c.Interceptors(), } } @@ -674,6 +677,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/multischema/ent/ent.go b/entc/integration/multischema/ent/ent.go index 141619838..6cc6cffe3 100644 --- a/entc/integration/multischema/ent/ent.go +++ b/entc/integration/multischema/ent/ent.go @@ -27,6 +27,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -513,10 +514,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/multischema/ent/friendship_query.go b/entc/integration/multischema/ent/friendship_query.go index 5c5f849a5..951008595 100644 --- a/entc/integration/multischema/ent/friendship_query.go +++ b/entc/integration/multischema/ent/friendship_query.go @@ -23,11 +23,8 @@ import ( // FriendshipQuery is the builder for querying Friendship entities. type FriendshipQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Friendship withUser *UserQuery @@ -46,20 +43,20 @@ func (fq *FriendshipQuery) Where(ps ...predicate.Friendship) *FriendshipQuery { // Limit the number of records to be returned by this query. func (fq *FriendshipQuery) Limit(limit int) *FriendshipQuery { - fq.limit = &limit + fq.ctx.Limit = &limit return fq } // Offset to start from. func (fq *FriendshipQuery) Offset(offset int) *FriendshipQuery { - fq.offset = &offset + fq.ctx.Offset = &offset return fq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (fq *FriendshipQuery) Unique(unique bool) *FriendshipQuery { - fq.unique = &unique + fq.ctx.Unique = &unique return fq } @@ -122,7 +119,7 @@ func (fq *FriendshipQuery) QueryFriend() *UserQuery { // First returns the first Friendship entity from the query. // Returns a *NotFoundError when no Friendship was found. func (fq *FriendshipQuery) First(ctx context.Context) (*Friendship, error) { - nodes, err := fq.Limit(1).All(newQueryContext(ctx, TypeFriendship, "First")) + nodes, err := fq.Limit(1).All(setContextOp(ctx, fq.ctx, "First")) if err != nil { return nil, err } @@ -145,7 +142,7 @@ func (fq *FriendshipQuery) FirstX(ctx context.Context) *Friendship { // Returns a *NotFoundError when no Friendship ID was found. func (fq *FriendshipQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = fq.Limit(1).IDs(newQueryContext(ctx, TypeFriendship, "FirstID")); err != nil { + if ids, err = fq.Limit(1).IDs(setContextOp(ctx, fq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -168,7 +165,7 @@ func (fq *FriendshipQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Friendship entity is found. // Returns a *NotFoundError when no Friendship entities are found. func (fq *FriendshipQuery) Only(ctx context.Context) (*Friendship, error) { - nodes, err := fq.Limit(2).All(newQueryContext(ctx, TypeFriendship, "Only")) + nodes, err := fq.Limit(2).All(setContextOp(ctx, fq.ctx, "Only")) if err != nil { return nil, err } @@ -196,7 +193,7 @@ func (fq *FriendshipQuery) OnlyX(ctx context.Context) *Friendship { // Returns a *NotFoundError when no entities are found. func (fq *FriendshipQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = fq.Limit(2).IDs(newQueryContext(ctx, TypeFriendship, "OnlyID")); err != nil { + if ids, err = fq.Limit(2).IDs(setContextOp(ctx, fq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -221,7 +218,7 @@ func (fq *FriendshipQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Friendships. func (fq *FriendshipQuery) All(ctx context.Context) ([]*Friendship, error) { - ctx = newQueryContext(ctx, TypeFriendship, "All") + ctx = setContextOp(ctx, fq.ctx, "All") if err := fq.prepareQuery(ctx); err != nil { return nil, err } @@ -241,7 +238,7 @@ func (fq *FriendshipQuery) AllX(ctx context.Context) []*Friendship { // IDs executes the query and returns a list of Friendship IDs. func (fq *FriendshipQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeFriendship, "IDs") + ctx = setContextOp(ctx, fq.ctx, "IDs") if err := fq.Select(friendship.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -259,7 +256,7 @@ func (fq *FriendshipQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (fq *FriendshipQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeFriendship, "Count") + ctx = setContextOp(ctx, fq.ctx, "Count") if err := fq.prepareQuery(ctx); err != nil { return 0, err } @@ -277,7 +274,7 @@ func (fq *FriendshipQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (fq *FriendshipQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeFriendship, "Exist") + ctx = setContextOp(ctx, fq.ctx, "Exist") switch _, err := fq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -305,17 +302,15 @@ func (fq *FriendshipQuery) Clone() *FriendshipQuery { } return &FriendshipQuery{ config: fq.config, - limit: fq.limit, - offset: fq.offset, + ctx: fq.ctx.Clone(), order: append([]OrderFunc{}, fq.order...), inters: append([]Interceptor{}, fq.inters...), predicates: append([]predicate.Friendship{}, fq.predicates...), withUser: fq.withUser.Clone(), withFriend: fq.withFriend.Clone(), // clone intermediate query. - sql: fq.sql.Clone(), - path: fq.path, - unique: fq.unique, + sql: fq.sql.Clone(), + path: fq.path, } } @@ -356,9 +351,9 @@ func (fq *FriendshipQuery) WithFriend(opts ...func(*UserQuery)) *FriendshipQuery // Aggregate(ent.Count()). // Scan(ctx, &v) func (fq *FriendshipQuery) GroupBy(field string, fields ...string) *FriendshipGroupBy { - fq.fields = append([]string{field}, fields...) + fq.ctx.Fields = append([]string{field}, fields...) grbuild := &FriendshipGroupBy{build: fq} - grbuild.flds = &fq.fields + grbuild.flds = &fq.ctx.Fields grbuild.label = friendship.Label grbuild.scan = grbuild.Scan return grbuild @@ -377,10 +372,10 @@ func (fq *FriendshipQuery) GroupBy(field string, fields ...string) *FriendshipGr // Select(friendship.FieldWeight). // Scan(ctx, &v) func (fq *FriendshipQuery) Select(fields ...string) *FriendshipSelect { - fq.fields = append(fq.fields, fields...) + fq.ctx.Fields = append(fq.ctx.Fields, fields...) sbuild := &FriendshipSelect{FriendshipQuery: fq} sbuild.label = friendship.Label - sbuild.flds, sbuild.scan = &fq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &fq.ctx.Fields, sbuild.Scan return sbuild } @@ -400,7 +395,7 @@ func (fq *FriendshipQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range fq.fields { + for _, f := range fq.ctx.Fields { if !friendship.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -528,9 +523,9 @@ func (fq *FriendshipQuery) sqlCount(ctx context.Context) (int, error) { if len(fq.modifiers) > 0 { _spec.Modifiers = fq.modifiers } - _spec.Node.Columns = fq.fields - if len(fq.fields) > 0 { - _spec.Unique = fq.unique != nil && *fq.unique + _spec.Node.Columns = fq.ctx.Fields + if len(fq.ctx.Fields) > 0 { + _spec.Unique = fq.ctx.Unique != nil && *fq.ctx.Unique } return sqlgraph.CountNodes(ctx, fq.driver, _spec) } @@ -548,10 +543,10 @@ func (fq *FriendshipQuery) querySpec() *sqlgraph.QuerySpec { From: fq.sql, Unique: true, } - if unique := fq.unique; unique != nil { + if unique := fq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := fq.fields; len(fields) > 0 { + if fields := fq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, friendship.FieldID) for i := range fields { @@ -567,10 +562,10 @@ func (fq *FriendshipQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := fq.limit; limit != nil { + if limit := fq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := fq.offset; offset != nil { + if offset := fq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := fq.order; len(ps) > 0 { @@ -586,7 +581,7 @@ func (fq *FriendshipQuery) querySpec() *sqlgraph.QuerySpec { func (fq *FriendshipQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(fq.driver.Dialect()) t1 := builder.Table(friendship.Table) - columns := fq.fields + columns := fq.ctx.Fields if len(columns) == 0 { columns = friendship.Columns } @@ -595,7 +590,7 @@ func (fq *FriendshipQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = fq.sql selector.Select(selector.Columns(columns...)...) } - if fq.unique != nil && *fq.unique { + if fq.ctx.Unique != nil && *fq.ctx.Unique { selector.Distinct() } t1.Schema(fq.schemaConfig.Friendship) @@ -610,12 +605,12 @@ func (fq *FriendshipQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range fq.order { p(selector) } - if offset := fq.offset; offset != nil { + if offset := fq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := fq.limit; limit != nil { + if limit := fq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -641,7 +636,7 @@ func (fgb *FriendshipGroupBy) Aggregate(fns ...AggregateFunc) *FriendshipGroupBy // Scan applies the selector query and scans the result into the given value. func (fgb *FriendshipGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFriendship, "GroupBy") + ctx = setContextOp(ctx, fgb.build.ctx, "GroupBy") if err := fgb.build.prepareQuery(ctx); err != nil { return err } @@ -689,7 +684,7 @@ func (fs *FriendshipSelect) Aggregate(fns ...AggregateFunc) *FriendshipSelect { // Scan applies the selector query and scans the result into the given value. func (fs *FriendshipSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeFriendship, "Select") + ctx = setContextOp(ctx, fs.ctx, "Select") if err := fs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/multischema/ent/group_query.go b/entc/integration/multischema/ent/group_query.go index 6e9fde295..36204d2aa 100644 --- a/entc/integration/multischema/ent/group_query.go +++ b/entc/integration/multischema/ent/group_query.go @@ -24,11 +24,8 @@ import ( // GroupQuery is the builder for querying Group entities. type GroupQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Group withUsers *UserQuery @@ -46,20 +43,20 @@ func (gq *GroupQuery) Where(ps ...predicate.Group) *GroupQuery { // Limit the number of records to be returned by this query. func (gq *GroupQuery) Limit(limit int) *GroupQuery { - gq.limit = &limit + gq.ctx.Limit = &limit return gq } // Offset to start from. func (gq *GroupQuery) Offset(offset int) *GroupQuery { - gq.offset = &offset + gq.ctx.Offset = &offset return gq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (gq *GroupQuery) Unique(unique bool) *GroupQuery { - gq.unique = &unique + gq.ctx.Unique = &unique return gq } @@ -97,7 +94,7 @@ func (gq *GroupQuery) QueryUsers() *UserQuery { // First returns the first Group entity from the query. // Returns a *NotFoundError when no Group was found. func (gq *GroupQuery) First(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(1).All(newQueryContext(ctx, TypeGroup, "First")) + nodes, err := gq.Limit(1).All(setContextOp(ctx, gq.ctx, "First")) if err != nil { return nil, err } @@ -120,7 +117,7 @@ func (gq *GroupQuery) FirstX(ctx context.Context) *Group { // Returns a *NotFoundError when no Group ID was found. func (gq *GroupQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(1).IDs(newQueryContext(ctx, TypeGroup, "FirstID")); err != nil { + if ids, err = gq.Limit(1).IDs(setContextOp(ctx, gq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -143,7 +140,7 @@ func (gq *GroupQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Group entity is found. // Returns a *NotFoundError when no Group entities are found. func (gq *GroupQuery) Only(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(2).All(newQueryContext(ctx, TypeGroup, "Only")) + nodes, err := gq.Limit(2).All(setContextOp(ctx, gq.ctx, "Only")) if err != nil { return nil, err } @@ -171,7 +168,7 @@ func (gq *GroupQuery) OnlyX(ctx context.Context) *Group { // Returns a *NotFoundError when no entities are found. func (gq *GroupQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(2).IDs(newQueryContext(ctx, TypeGroup, "OnlyID")); err != nil { + if ids, err = gq.Limit(2).IDs(setContextOp(ctx, gq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -196,7 +193,7 @@ func (gq *GroupQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Groups. func (gq *GroupQuery) All(ctx context.Context) ([]*Group, error) { - ctx = newQueryContext(ctx, TypeGroup, "All") + ctx = setContextOp(ctx, gq.ctx, "All") if err := gq.prepareQuery(ctx); err != nil { return nil, err } @@ -216,7 +213,7 @@ func (gq *GroupQuery) AllX(ctx context.Context) []*Group { // IDs executes the query and returns a list of Group IDs. func (gq *GroupQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeGroup, "IDs") + ctx = setContextOp(ctx, gq.ctx, "IDs") if err := gq.Select(group.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -234,7 +231,7 @@ func (gq *GroupQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (gq *GroupQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGroup, "Count") + ctx = setContextOp(ctx, gq.ctx, "Count") if err := gq.prepareQuery(ctx); err != nil { return 0, err } @@ -252,7 +249,7 @@ func (gq *GroupQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (gq *GroupQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGroup, "Exist") + ctx = setContextOp(ctx, gq.ctx, "Exist") switch _, err := gq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -280,16 +277,14 @@ func (gq *GroupQuery) Clone() *GroupQuery { } return &GroupQuery{ config: gq.config, - limit: gq.limit, - offset: gq.offset, + ctx: gq.ctx.Clone(), order: append([]OrderFunc{}, gq.order...), inters: append([]Interceptor{}, gq.inters...), predicates: append([]predicate.Group{}, gq.predicates...), withUsers: gq.withUsers.Clone(), // clone intermediate query. - sql: gq.sql.Clone(), - path: gq.path, - unique: gq.unique, + sql: gq.sql.Clone(), + path: gq.path, } } @@ -319,9 +314,9 @@ func (gq *GroupQuery) WithUsers(opts ...func(*UserQuery)) *GroupQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { - gq.fields = append([]string{field}, fields...) + gq.ctx.Fields = append([]string{field}, fields...) grbuild := &GroupGroupBy{build: gq} - grbuild.flds = &gq.fields + grbuild.flds = &gq.ctx.Fields grbuild.label = group.Label grbuild.scan = grbuild.Scan return grbuild @@ -340,10 +335,10 @@ func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { // Select(group.FieldName). // Scan(ctx, &v) func (gq *GroupQuery) Select(fields ...string) *GroupSelect { - gq.fields = append(gq.fields, fields...) + gq.ctx.Fields = append(gq.ctx.Fields, fields...) sbuild := &GroupSelect{GroupQuery: gq} sbuild.label = group.Label - sbuild.flds, sbuild.scan = &gq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &gq.ctx.Fields, sbuild.Scan return sbuild } @@ -363,7 +358,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range gq.fields { + for _, f := range gq.ctx.Fields { if !group.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -486,9 +481,9 @@ func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { if len(gq.modifiers) > 0 { _spec.Modifiers = gq.modifiers } - _spec.Node.Columns = gq.fields - if len(gq.fields) > 0 { - _spec.Unique = gq.unique != nil && *gq.unique + _spec.Node.Columns = gq.ctx.Fields + if len(gq.ctx.Fields) > 0 { + _spec.Unique = gq.ctx.Unique != nil && *gq.ctx.Unique } return sqlgraph.CountNodes(ctx, gq.driver, _spec) } @@ -506,10 +501,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { From: gq.sql, Unique: true, } - if unique := gq.unique; unique != nil { + if unique := gq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := gq.fields; len(fields) > 0 { + if fields := gq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, group.FieldID) for i := range fields { @@ -525,10 +520,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := gq.order; len(ps) > 0 { @@ -544,7 +539,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(gq.driver.Dialect()) t1 := builder.Table(group.Table) - columns := gq.fields + columns := gq.ctx.Fields if len(columns) == 0 { columns = group.Columns } @@ -553,7 +548,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = gq.sql selector.Select(selector.Columns(columns...)...) } - if gq.unique != nil && *gq.unique { + if gq.ctx.Unique != nil && *gq.ctx.Unique { selector.Distinct() } t1.Schema(gq.schemaConfig.Group) @@ -568,12 +563,12 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range gq.order { p(selector) } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -599,7 +594,7 @@ func (ggb *GroupGroupBy) Aggregate(fns ...AggregateFunc) *GroupGroupBy { // Scan applies the selector query and scans the result into the given value. func (ggb *GroupGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "GroupBy") + ctx = setContextOp(ctx, ggb.build.ctx, "GroupBy") if err := ggb.build.prepareQuery(ctx); err != nil { return err } @@ -647,7 +642,7 @@ func (gs *GroupSelect) Aggregate(fns ...AggregateFunc) *GroupSelect { // Scan applies the selector query and scans the result into the given value. func (gs *GroupSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "Select") + ctx = setContextOp(ctx, gs.ctx, "Select") if err := gs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/multischema/ent/pet_query.go b/entc/integration/multischema/ent/pet_query.go index 0f6ae89ed..494b09dab 100644 --- a/entc/integration/multischema/ent/pet_query.go +++ b/entc/integration/multischema/ent/pet_query.go @@ -23,11 +23,8 @@ import ( // PetQuery is the builder for querying Pet entities. type PetQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Pet withOwner *UserQuery @@ -45,20 +42,20 @@ func (pq *PetQuery) Where(ps ...predicate.Pet) *PetQuery { // Limit the number of records to be returned by this query. func (pq *PetQuery) Limit(limit int) *PetQuery { - pq.limit = &limit + pq.ctx.Limit = &limit return pq } // Offset to start from. func (pq *PetQuery) Offset(offset int) *PetQuery { - pq.offset = &offset + pq.ctx.Offset = &offset return pq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (pq *PetQuery) Unique(unique bool) *PetQuery { - pq.unique = &unique + pq.ctx.Unique = &unique return pq } @@ -96,7 +93,7 @@ func (pq *PetQuery) QueryOwner() *UserQuery { // First returns the first Pet entity from the query. // Returns a *NotFoundError when no Pet was found. func (pq *PetQuery) First(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(1).All(newQueryContext(ctx, TypePet, "First")) + nodes, err := pq.Limit(1).All(setContextOp(ctx, pq.ctx, "First")) if err != nil { return nil, err } @@ -119,7 +116,7 @@ func (pq *PetQuery) FirstX(ctx context.Context) *Pet { // Returns a *NotFoundError when no Pet ID was found. func (pq *PetQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(1).IDs(newQueryContext(ctx, TypePet, "FirstID")); err != nil { + if ids, err = pq.Limit(1).IDs(setContextOp(ctx, pq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -142,7 +139,7 @@ func (pq *PetQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Pet entity is found. // Returns a *NotFoundError when no Pet entities are found. func (pq *PetQuery) Only(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(2).All(newQueryContext(ctx, TypePet, "Only")) + nodes, err := pq.Limit(2).All(setContextOp(ctx, pq.ctx, "Only")) if err != nil { return nil, err } @@ -170,7 +167,7 @@ func (pq *PetQuery) OnlyX(ctx context.Context) *Pet { // Returns a *NotFoundError when no entities are found. func (pq *PetQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(2).IDs(newQueryContext(ctx, TypePet, "OnlyID")); err != nil { + if ids, err = pq.Limit(2).IDs(setContextOp(ctx, pq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -195,7 +192,7 @@ func (pq *PetQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Pets. func (pq *PetQuery) All(ctx context.Context) ([]*Pet, error) { - ctx = newQueryContext(ctx, TypePet, "All") + ctx = setContextOp(ctx, pq.ctx, "All") if err := pq.prepareQuery(ctx); err != nil { return nil, err } @@ -215,7 +212,7 @@ func (pq *PetQuery) AllX(ctx context.Context) []*Pet { // IDs executes the query and returns a list of Pet IDs. func (pq *PetQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypePet, "IDs") + ctx = setContextOp(ctx, pq.ctx, "IDs") if err := pq.Select(pet.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -233,7 +230,7 @@ func (pq *PetQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (pq *PetQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypePet, "Count") + ctx = setContextOp(ctx, pq.ctx, "Count") if err := pq.prepareQuery(ctx); err != nil { return 0, err } @@ -251,7 +248,7 @@ func (pq *PetQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (pq *PetQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypePet, "Exist") + ctx = setContextOp(ctx, pq.ctx, "Exist") switch _, err := pq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -279,16 +276,14 @@ func (pq *PetQuery) Clone() *PetQuery { } return &PetQuery{ config: pq.config, - limit: pq.limit, - offset: pq.offset, + ctx: pq.ctx.Clone(), order: append([]OrderFunc{}, pq.order...), inters: append([]Interceptor{}, pq.inters...), predicates: append([]predicate.Pet{}, pq.predicates...), withOwner: pq.withOwner.Clone(), // clone intermediate query. - sql: pq.sql.Clone(), - path: pq.path, - unique: pq.unique, + sql: pq.sql.Clone(), + path: pq.path, } } @@ -318,9 +313,9 @@ func (pq *PetQuery) WithOwner(opts ...func(*UserQuery)) *PetQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { - pq.fields = append([]string{field}, fields...) + pq.ctx.Fields = append([]string{field}, fields...) grbuild := &PetGroupBy{build: pq} - grbuild.flds = &pq.fields + grbuild.flds = &pq.ctx.Fields grbuild.label = pet.Label grbuild.scan = grbuild.Scan return grbuild @@ -339,10 +334,10 @@ func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { // Select(pet.FieldName). // Scan(ctx, &v) func (pq *PetQuery) Select(fields ...string) *PetSelect { - pq.fields = append(pq.fields, fields...) + pq.ctx.Fields = append(pq.ctx.Fields, fields...) sbuild := &PetSelect{PetQuery: pq} sbuild.label = pet.Label - sbuild.flds, sbuild.scan = &pq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &pq.ctx.Fields, sbuild.Scan return sbuild } @@ -362,7 +357,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range pq.fields { + for _, f := range pq.ctx.Fields { if !pet.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -454,9 +449,9 @@ func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { if len(pq.modifiers) > 0 { _spec.Modifiers = pq.modifiers } - _spec.Node.Columns = pq.fields - if len(pq.fields) > 0 { - _spec.Unique = pq.unique != nil && *pq.unique + _spec.Node.Columns = pq.ctx.Fields + if len(pq.ctx.Fields) > 0 { + _spec.Unique = pq.ctx.Unique != nil && *pq.ctx.Unique } return sqlgraph.CountNodes(ctx, pq.driver, _spec) } @@ -474,10 +469,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { From: pq.sql, Unique: true, } - if unique := pq.unique; unique != nil { + if unique := pq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := pq.fields; len(fields) > 0 { + if fields := pq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, pet.FieldID) for i := range fields { @@ -493,10 +488,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := pq.order; len(ps) > 0 { @@ -512,7 +507,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(pq.driver.Dialect()) t1 := builder.Table(pet.Table) - columns := pq.fields + columns := pq.ctx.Fields if len(columns) == 0 { columns = pet.Columns } @@ -521,7 +516,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = pq.sql selector.Select(selector.Columns(columns...)...) } - if pq.unique != nil && *pq.unique { + if pq.ctx.Unique != nil && *pq.ctx.Unique { selector.Distinct() } t1.Schema(pq.schemaConfig.Pet) @@ -536,12 +531,12 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range pq.order { p(selector) } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -567,7 +562,7 @@ func (pgb *PetGroupBy) Aggregate(fns ...AggregateFunc) *PetGroupBy { // Scan applies the selector query and scans the result into the given value. func (pgb *PetGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "GroupBy") + ctx = setContextOp(ctx, pgb.build.ctx, "GroupBy") if err := pgb.build.prepareQuery(ctx); err != nil { return err } @@ -615,7 +610,7 @@ func (ps *PetSelect) Aggregate(fns ...AggregateFunc) *PetSelect { // Scan applies the selector query and scans the result into the given value. func (ps *PetSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "Select") + ctx = setContextOp(ctx, ps.ctx, "Select") if err := ps.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/multischema/ent/user_query.go b/entc/integration/multischema/ent/user_query.go index 1f87d6e44..d12717e2f 100644 --- a/entc/integration/multischema/ent/user_query.go +++ b/entc/integration/multischema/ent/user_query.go @@ -26,11 +26,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withPets *PetQuery @@ -51,20 +48,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -177,7 +174,7 @@ func (uq *UserQuery) QueryFriendships() *FriendshipQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -200,7 +197,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -223,7 +220,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -251,7 +248,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -276,7 +273,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -296,7 +293,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -314,7 +311,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -332,7 +329,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -360,8 +357,7 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), @@ -370,9 +366,8 @@ func (uq *UserQuery) Clone() *UserQuery { withFriends: uq.withFriends.Clone(), withFriendships: uq.withFriendships.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -435,9 +430,9 @@ func (uq *UserQuery) WithFriendships(opts ...func(*FriendshipQuery)) *UserQuery // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -456,10 +451,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldName). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -479,7 +474,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -739,9 +734,9 @@ func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { if len(uq.modifiers) > 0 { _spec.Modifiers = uq.modifiers } - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -759,10 +754,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -778,10 +773,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -797,7 +792,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -806,7 +801,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } t1.Schema(uq.schemaConfig.User) @@ -821,12 +816,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -852,7 +847,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -900,7 +895,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/privacy/ent/client.go b/entc/integration/privacy/ent/client.go index 860a5894f..a07be1ffb 100644 --- a/entc/integration/privacy/ent/client.go +++ b/entc/integration/privacy/ent/client.go @@ -237,6 +237,7 @@ func (c *TaskClient) DeleteOneID(id int) *TaskDeleteOne { func (c *TaskClient) Query() *TaskQuery { return &TaskQuery{ config: c.config, + ctx: &QueryContext{Type: TypeTask}, inters: c.Interceptors(), } } @@ -387,6 +388,7 @@ func (c *TeamClient) DeleteOneID(id int) *TeamDeleteOne { func (c *TeamClient) Query() *TeamQuery { return &TeamQuery{ config: c.config, + ctx: &QueryContext{Type: TypeTeam}, inters: c.Interceptors(), } } @@ -537,6 +539,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), } } diff --git a/entc/integration/privacy/ent/ent.go b/entc/integration/privacy/ent/ent.go index 198dff1b6..417a84e29 100644 --- a/entc/integration/privacy/ent/ent.go +++ b/entc/integration/privacy/ent/ent.go @@ -26,6 +26,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -511,10 +512,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/privacy/ent/task_query.go b/entc/integration/privacy/ent/task_query.go index 64195917e..c4b4534fd 100644 --- a/entc/integration/privacy/ent/task_query.go +++ b/entc/integration/privacy/ent/task_query.go @@ -25,11 +25,8 @@ import ( // TaskQuery is the builder for querying Task entities. type TaskQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Task withTeams *TeamQuery @@ -48,20 +45,20 @@ func (tq *TaskQuery) Where(ps ...predicate.Task) *TaskQuery { // Limit the number of records to be returned by this query. func (tq *TaskQuery) Limit(limit int) *TaskQuery { - tq.limit = &limit + tq.ctx.Limit = &limit return tq } // Offset to start from. func (tq *TaskQuery) Offset(offset int) *TaskQuery { - tq.offset = &offset + tq.ctx.Offset = &offset return tq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (tq *TaskQuery) Unique(unique bool) *TaskQuery { - tq.unique = &unique + tq.ctx.Unique = &unique return tq } @@ -118,7 +115,7 @@ func (tq *TaskQuery) QueryOwner() *UserQuery { // First returns the first Task entity from the query. // Returns a *NotFoundError when no Task was found. func (tq *TaskQuery) First(ctx context.Context) (*Task, error) { - nodes, err := tq.Limit(1).All(newQueryContext(ctx, TypeTask, "First")) + nodes, err := tq.Limit(1).All(setContextOp(ctx, tq.ctx, "First")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (tq *TaskQuery) FirstX(ctx context.Context) *Task { // Returns a *NotFoundError when no Task ID was found. func (tq *TaskQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = tq.Limit(1).IDs(newQueryContext(ctx, TypeTask, "FirstID")); err != nil { + if ids, err = tq.Limit(1).IDs(setContextOp(ctx, tq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -164,7 +161,7 @@ func (tq *TaskQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Task entity is found. // Returns a *NotFoundError when no Task entities are found. func (tq *TaskQuery) Only(ctx context.Context) (*Task, error) { - nodes, err := tq.Limit(2).All(newQueryContext(ctx, TypeTask, "Only")) + nodes, err := tq.Limit(2).All(setContextOp(ctx, tq.ctx, "Only")) if err != nil { return nil, err } @@ -192,7 +189,7 @@ func (tq *TaskQuery) OnlyX(ctx context.Context) *Task { // Returns a *NotFoundError when no entities are found. func (tq *TaskQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = tq.Limit(2).IDs(newQueryContext(ctx, TypeTask, "OnlyID")); err != nil { + if ids, err = tq.Limit(2).IDs(setContextOp(ctx, tq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -217,7 +214,7 @@ func (tq *TaskQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Tasks. func (tq *TaskQuery) All(ctx context.Context) ([]*Task, error) { - ctx = newQueryContext(ctx, TypeTask, "All") + ctx = setContextOp(ctx, tq.ctx, "All") if err := tq.prepareQuery(ctx); err != nil { return nil, err } @@ -237,7 +234,7 @@ func (tq *TaskQuery) AllX(ctx context.Context) []*Task { // IDs executes the query and returns a list of Task IDs. func (tq *TaskQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeTask, "IDs") + ctx = setContextOp(ctx, tq.ctx, "IDs") if err := tq.Select(task.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -255,7 +252,7 @@ func (tq *TaskQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (tq *TaskQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeTask, "Count") + ctx = setContextOp(ctx, tq.ctx, "Count") if err := tq.prepareQuery(ctx); err != nil { return 0, err } @@ -273,7 +270,7 @@ func (tq *TaskQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (tq *TaskQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeTask, "Exist") + ctx = setContextOp(ctx, tq.ctx, "Exist") switch _, err := tq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -301,17 +298,15 @@ func (tq *TaskQuery) Clone() *TaskQuery { } return &TaskQuery{ config: tq.config, - limit: tq.limit, - offset: tq.offset, + ctx: tq.ctx.Clone(), order: append([]OrderFunc{}, tq.order...), inters: append([]Interceptor{}, tq.inters...), predicates: append([]predicate.Task{}, tq.predicates...), withTeams: tq.withTeams.Clone(), withOwner: tq.withOwner.Clone(), // clone intermediate query. - sql: tq.sql.Clone(), - path: tq.path, - unique: tq.unique, + sql: tq.sql.Clone(), + path: tq.path, } } @@ -352,9 +347,9 @@ func (tq *TaskQuery) WithOwner(opts ...func(*UserQuery)) *TaskQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (tq *TaskQuery) GroupBy(field string, fields ...string) *TaskGroupBy { - tq.fields = append([]string{field}, fields...) + tq.ctx.Fields = append([]string{field}, fields...) grbuild := &TaskGroupBy{build: tq} - grbuild.flds = &tq.fields + grbuild.flds = &tq.ctx.Fields grbuild.label = task.Label grbuild.scan = grbuild.Scan return grbuild @@ -373,10 +368,10 @@ func (tq *TaskQuery) GroupBy(field string, fields ...string) *TaskGroupBy { // Select(task.FieldTitle). // Scan(ctx, &v) func (tq *TaskQuery) Select(fields ...string) *TaskSelect { - tq.fields = append(tq.fields, fields...) + tq.ctx.Fields = append(tq.ctx.Fields, fields...) sbuild := &TaskSelect{TaskQuery: tq} sbuild.label = task.Label - sbuild.flds, sbuild.scan = &tq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &tq.ctx.Fields, sbuild.Scan return sbuild } @@ -396,7 +391,7 @@ func (tq *TaskQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range tq.fields { + for _, f := range tq.ctx.Fields { if !task.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -560,9 +555,9 @@ func (tq *TaskQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*T func (tq *TaskQuery) sqlCount(ctx context.Context) (int, error) { _spec := tq.querySpec() - _spec.Node.Columns = tq.fields - if len(tq.fields) > 0 { - _spec.Unique = tq.unique != nil && *tq.unique + _spec.Node.Columns = tq.ctx.Fields + if len(tq.ctx.Fields) > 0 { + _spec.Unique = tq.ctx.Unique != nil && *tq.ctx.Unique } return sqlgraph.CountNodes(ctx, tq.driver, _spec) } @@ -580,10 +575,10 @@ func (tq *TaskQuery) querySpec() *sqlgraph.QuerySpec { From: tq.sql, Unique: true, } - if unique := tq.unique; unique != nil { + if unique := tq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := tq.fields; len(fields) > 0 { + if fields := tq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, task.FieldID) for i := range fields { @@ -599,10 +594,10 @@ func (tq *TaskQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := tq.order; len(ps) > 0 { @@ -618,7 +613,7 @@ func (tq *TaskQuery) querySpec() *sqlgraph.QuerySpec { func (tq *TaskQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(tq.driver.Dialect()) t1 := builder.Table(task.Table) - columns := tq.fields + columns := tq.ctx.Fields if len(columns) == 0 { columns = task.Columns } @@ -627,7 +622,7 @@ func (tq *TaskQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = tq.sql selector.Select(selector.Columns(columns...)...) } - if tq.unique != nil && *tq.unique { + if tq.ctx.Unique != nil && *tq.ctx.Unique { selector.Distinct() } for _, p := range tq.predicates { @@ -636,12 +631,12 @@ func (tq *TaskQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range tq.order { p(selector) } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -661,7 +656,7 @@ func (tgb *TaskGroupBy) Aggregate(fns ...AggregateFunc) *TaskGroupBy { // Scan applies the selector query and scans the result into the given value. func (tgb *TaskGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTask, "GroupBy") + ctx = setContextOp(ctx, tgb.build.ctx, "GroupBy") if err := tgb.build.prepareQuery(ctx); err != nil { return err } @@ -709,7 +704,7 @@ func (ts *TaskSelect) Aggregate(fns ...AggregateFunc) *TaskSelect { // Scan applies the selector query and scans the result into the given value. func (ts *TaskSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTask, "Select") + ctx = setContextOp(ctx, ts.ctx, "Select") if err := ts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/privacy/ent/team_query.go b/entc/integration/privacy/ent/team_query.go index 2983713b6..716451af4 100644 --- a/entc/integration/privacy/ent/team_query.go +++ b/entc/integration/privacy/ent/team_query.go @@ -25,11 +25,8 @@ import ( // TeamQuery is the builder for querying Team entities. type TeamQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Team withTasks *TaskQuery @@ -47,20 +44,20 @@ func (tq *TeamQuery) Where(ps ...predicate.Team) *TeamQuery { // Limit the number of records to be returned by this query. func (tq *TeamQuery) Limit(limit int) *TeamQuery { - tq.limit = &limit + tq.ctx.Limit = &limit return tq } // Offset to start from. func (tq *TeamQuery) Offset(offset int) *TeamQuery { - tq.offset = &offset + tq.ctx.Offset = &offset return tq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (tq *TeamQuery) Unique(unique bool) *TeamQuery { - tq.unique = &unique + tq.ctx.Unique = &unique return tq } @@ -117,7 +114,7 @@ func (tq *TeamQuery) QueryUsers() *UserQuery { // First returns the first Team entity from the query. // Returns a *NotFoundError when no Team was found. func (tq *TeamQuery) First(ctx context.Context) (*Team, error) { - nodes, err := tq.Limit(1).All(newQueryContext(ctx, TypeTeam, "First")) + nodes, err := tq.Limit(1).All(setContextOp(ctx, tq.ctx, "First")) if err != nil { return nil, err } @@ -140,7 +137,7 @@ func (tq *TeamQuery) FirstX(ctx context.Context) *Team { // Returns a *NotFoundError when no Team ID was found. func (tq *TeamQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = tq.Limit(1).IDs(newQueryContext(ctx, TypeTeam, "FirstID")); err != nil { + if ids, err = tq.Limit(1).IDs(setContextOp(ctx, tq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -163,7 +160,7 @@ func (tq *TeamQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Team entity is found. // Returns a *NotFoundError when no Team entities are found. func (tq *TeamQuery) Only(ctx context.Context) (*Team, error) { - nodes, err := tq.Limit(2).All(newQueryContext(ctx, TypeTeam, "Only")) + nodes, err := tq.Limit(2).All(setContextOp(ctx, tq.ctx, "Only")) if err != nil { return nil, err } @@ -191,7 +188,7 @@ func (tq *TeamQuery) OnlyX(ctx context.Context) *Team { // Returns a *NotFoundError when no entities are found. func (tq *TeamQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = tq.Limit(2).IDs(newQueryContext(ctx, TypeTeam, "OnlyID")); err != nil { + if ids, err = tq.Limit(2).IDs(setContextOp(ctx, tq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -216,7 +213,7 @@ func (tq *TeamQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Teams. func (tq *TeamQuery) All(ctx context.Context) ([]*Team, error) { - ctx = newQueryContext(ctx, TypeTeam, "All") + ctx = setContextOp(ctx, tq.ctx, "All") if err := tq.prepareQuery(ctx); err != nil { return nil, err } @@ -236,7 +233,7 @@ func (tq *TeamQuery) AllX(ctx context.Context) []*Team { // IDs executes the query and returns a list of Team IDs. func (tq *TeamQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeTeam, "IDs") + ctx = setContextOp(ctx, tq.ctx, "IDs") if err := tq.Select(team.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -254,7 +251,7 @@ func (tq *TeamQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (tq *TeamQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeTeam, "Count") + ctx = setContextOp(ctx, tq.ctx, "Count") if err := tq.prepareQuery(ctx); err != nil { return 0, err } @@ -272,7 +269,7 @@ func (tq *TeamQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (tq *TeamQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeTeam, "Exist") + ctx = setContextOp(ctx, tq.ctx, "Exist") switch _, err := tq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -300,17 +297,15 @@ func (tq *TeamQuery) Clone() *TeamQuery { } return &TeamQuery{ config: tq.config, - limit: tq.limit, - offset: tq.offset, + ctx: tq.ctx.Clone(), order: append([]OrderFunc{}, tq.order...), inters: append([]Interceptor{}, tq.inters...), predicates: append([]predicate.Team{}, tq.predicates...), withTasks: tq.withTasks.Clone(), withUsers: tq.withUsers.Clone(), // clone intermediate query. - sql: tq.sql.Clone(), - path: tq.path, - unique: tq.unique, + sql: tq.sql.Clone(), + path: tq.path, } } @@ -351,9 +346,9 @@ func (tq *TeamQuery) WithUsers(opts ...func(*UserQuery)) *TeamQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (tq *TeamQuery) GroupBy(field string, fields ...string) *TeamGroupBy { - tq.fields = append([]string{field}, fields...) + tq.ctx.Fields = append([]string{field}, fields...) grbuild := &TeamGroupBy{build: tq} - grbuild.flds = &tq.fields + grbuild.flds = &tq.ctx.Fields grbuild.label = team.Label grbuild.scan = grbuild.Scan return grbuild @@ -372,10 +367,10 @@ func (tq *TeamQuery) GroupBy(field string, fields ...string) *TeamGroupBy { // Select(team.FieldName). // Scan(ctx, &v) func (tq *TeamQuery) Select(fields ...string) *TeamSelect { - tq.fields = append(tq.fields, fields...) + tq.ctx.Fields = append(tq.ctx.Fields, fields...) sbuild := &TeamSelect{TeamQuery: tq} sbuild.label = team.Label - sbuild.flds, sbuild.scan = &tq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &tq.ctx.Fields, sbuild.Scan return sbuild } @@ -395,7 +390,7 @@ func (tq *TeamQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range tq.fields { + for _, f := range tq.ctx.Fields { if !team.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -579,9 +574,9 @@ func (tq *TeamQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*T func (tq *TeamQuery) sqlCount(ctx context.Context) (int, error) { _spec := tq.querySpec() - _spec.Node.Columns = tq.fields - if len(tq.fields) > 0 { - _spec.Unique = tq.unique != nil && *tq.unique + _spec.Node.Columns = tq.ctx.Fields + if len(tq.ctx.Fields) > 0 { + _spec.Unique = tq.ctx.Unique != nil && *tq.ctx.Unique } return sqlgraph.CountNodes(ctx, tq.driver, _spec) } @@ -599,10 +594,10 @@ func (tq *TeamQuery) querySpec() *sqlgraph.QuerySpec { From: tq.sql, Unique: true, } - if unique := tq.unique; unique != nil { + if unique := tq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := tq.fields; len(fields) > 0 { + if fields := tq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, team.FieldID) for i := range fields { @@ -618,10 +613,10 @@ func (tq *TeamQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := tq.order; len(ps) > 0 { @@ -637,7 +632,7 @@ func (tq *TeamQuery) querySpec() *sqlgraph.QuerySpec { func (tq *TeamQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(tq.driver.Dialect()) t1 := builder.Table(team.Table) - columns := tq.fields + columns := tq.ctx.Fields if len(columns) == 0 { columns = team.Columns } @@ -646,7 +641,7 @@ func (tq *TeamQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = tq.sql selector.Select(selector.Columns(columns...)...) } - if tq.unique != nil && *tq.unique { + if tq.ctx.Unique != nil && *tq.ctx.Unique { selector.Distinct() } for _, p := range tq.predicates { @@ -655,12 +650,12 @@ func (tq *TeamQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range tq.order { p(selector) } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -680,7 +675,7 @@ func (tgb *TeamGroupBy) Aggregate(fns ...AggregateFunc) *TeamGroupBy { // Scan applies the selector query and scans the result into the given value. func (tgb *TeamGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTeam, "GroupBy") + ctx = setContextOp(ctx, tgb.build.ctx, "GroupBy") if err := tgb.build.prepareQuery(ctx); err != nil { return err } @@ -728,7 +723,7 @@ func (ts *TeamSelect) Aggregate(fns ...AggregateFunc) *TeamSelect { // Scan applies the selector query and scans the result into the given value. func (ts *TeamSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeTeam, "Select") + ctx = setContextOp(ctx, ts.ctx, "Select") if err := ts.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/privacy/ent/user_query.go b/entc/integration/privacy/ent/user_query.go index 6a54cff86..0c27825b7 100644 --- a/entc/integration/privacy/ent/user_query.go +++ b/entc/integration/privacy/ent/user_query.go @@ -25,11 +25,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withTeams *TeamQuery @@ -47,20 +44,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -117,7 +114,7 @@ func (uq *UserQuery) QueryTasks() *TaskQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -140,7 +137,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -163,7 +160,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -191,7 +188,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -216,7 +213,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -236,7 +233,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -254,7 +251,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -272,7 +269,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -300,17 +297,15 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), withTeams: uq.withTeams.Clone(), withTasks: uq.withTasks.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -351,9 +346,9 @@ func (uq *UserQuery) WithTasks(opts ...func(*TaskQuery)) *UserQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -372,10 +367,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldName). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -395,7 +390,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -552,9 +547,9 @@ func (uq *UserQuery) loadTasks(ctx context.Context, query *TaskQuery, nodes []*U func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -572,10 +567,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -591,10 +586,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -610,7 +605,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -619,7 +614,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -628,12 +623,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -653,7 +648,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -701,7 +696,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/template/ent/client.go b/entc/integration/template/ent/client.go index 8c12a93da..94b62790f 100644 --- a/entc/integration/template/ent/client.go +++ b/entc/integration/template/ent/client.go @@ -246,6 +246,7 @@ func (c *GroupClient) DeleteOneID(id int) *GroupDeleteOne { func (c *GroupClient) Query() *GroupQuery { return &GroupQuery{ config: c.config, + ctx: &QueryContext{Type: TypeGroup}, inters: c.Interceptors(), extra: "Group", } @@ -364,6 +365,7 @@ func (c *PetClient) DeleteOneID(id int) *PetDeleteOne { func (c *PetClient) Query() *PetQuery { return &PetQuery{ config: c.config, + ctx: &QueryContext{Type: TypePet}, inters: c.Interceptors(), extra: "Pet", } @@ -498,6 +500,7 @@ func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { func (c *UserClient) Query() *UserQuery { return &UserQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUser}, inters: c.Interceptors(), extra: "User", } diff --git a/entc/integration/template/ent/ent.go b/entc/integration/template/ent/ent.go index c8ec65854..5912f0287 100644 --- a/entc/integration/template/ent/ent.go +++ b/entc/integration/template/ent/ent.go @@ -26,6 +26,7 @@ type ( Hook = ent.Hook Value = ent.Value Query = ent.Query + QueryContext = ent.QueryContext Querier = ent.Querier QuerierFunc = ent.QuerierFunc Interceptor = ent.Interceptor @@ -511,10 +512,11 @@ func withHooks[V Value, M any, PM interface { return nv, nil } -// newQueryContext returns a new context with the given QueryContext attached in case it does not exist. -func newQueryContext(ctx context.Context, typ, op string) context.Context { +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { if ent.QueryFromContext(ctx) == nil { - ctx = ent.NewQueryContext(ctx, &ent.QueryContext{Type: typ, Op: op}) + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) } return ctx } diff --git a/entc/integration/template/ent/group_query.go b/entc/integration/template/ent/group_query.go index 71054db19..93f5cb998 100644 --- a/entc/integration/template/ent/group_query.go +++ b/entc/integration/template/ent/group_query.go @@ -21,11 +21,8 @@ import ( // GroupQuery is the builder for querying Group entities. type GroupQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Group // additional query fields. @@ -44,20 +41,20 @@ func (gq *GroupQuery) Where(ps ...predicate.Group) *GroupQuery { // Limit the number of records to be returned by this query. func (gq *GroupQuery) Limit(limit int) *GroupQuery { - gq.limit = &limit + gq.ctx.Limit = &limit return gq } // Offset to start from. func (gq *GroupQuery) Offset(offset int) *GroupQuery { - gq.offset = &offset + gq.ctx.Offset = &offset return gq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (gq *GroupQuery) Unique(unique bool) *GroupQuery { - gq.unique = &unique + gq.ctx.Unique = &unique return gq } @@ -70,7 +67,7 @@ func (gq *GroupQuery) Order(o ...OrderFunc) *GroupQuery { // First returns the first Group entity from the query. // Returns a *NotFoundError when no Group was found. func (gq *GroupQuery) First(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(1).All(newQueryContext(ctx, TypeGroup, "First")) + nodes, err := gq.Limit(1).All(setContextOp(ctx, gq.ctx, "First")) if err != nil { return nil, err } @@ -93,7 +90,7 @@ func (gq *GroupQuery) FirstX(ctx context.Context) *Group { // Returns a *NotFoundError when no Group ID was found. func (gq *GroupQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(1).IDs(newQueryContext(ctx, TypeGroup, "FirstID")); err != nil { + if ids, err = gq.Limit(1).IDs(setContextOp(ctx, gq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -116,7 +113,7 @@ func (gq *GroupQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Group entity is found. // Returns a *NotFoundError when no Group entities are found. func (gq *GroupQuery) Only(ctx context.Context) (*Group, error) { - nodes, err := gq.Limit(2).All(newQueryContext(ctx, TypeGroup, "Only")) + nodes, err := gq.Limit(2).All(setContextOp(ctx, gq.ctx, "Only")) if err != nil { return nil, err } @@ -144,7 +141,7 @@ func (gq *GroupQuery) OnlyX(ctx context.Context) *Group { // Returns a *NotFoundError when no entities are found. func (gq *GroupQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = gq.Limit(2).IDs(newQueryContext(ctx, TypeGroup, "OnlyID")); err != nil { + if ids, err = gq.Limit(2).IDs(setContextOp(ctx, gq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -169,7 +166,7 @@ func (gq *GroupQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Groups. func (gq *GroupQuery) All(ctx context.Context) ([]*Group, error) { - ctx = newQueryContext(ctx, TypeGroup, "All") + ctx = setContextOp(ctx, gq.ctx, "All") if err := gq.prepareQuery(ctx); err != nil { return nil, err } @@ -189,7 +186,7 @@ func (gq *GroupQuery) AllX(ctx context.Context) []*Group { // IDs executes the query and returns a list of Group IDs. func (gq *GroupQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeGroup, "IDs") + ctx = setContextOp(ctx, gq.ctx, "IDs") if err := gq.Select(group.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -207,7 +204,7 @@ func (gq *GroupQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (gq *GroupQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeGroup, "Count") + ctx = setContextOp(ctx, gq.ctx, "Count") if err := gq.prepareQuery(ctx); err != nil { return 0, err } @@ -225,7 +222,7 @@ func (gq *GroupQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (gq *GroupQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeGroup, "Exist") + ctx = setContextOp(ctx, gq.ctx, "Exist") switch _, err := gq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -253,15 +250,13 @@ func (gq *GroupQuery) Clone() *GroupQuery { } return &GroupQuery{ config: gq.config, - limit: gq.limit, - offset: gq.offset, + ctx: gq.ctx.Clone(), order: append([]OrderFunc{}, gq.order...), inters: append([]Interceptor{}, gq.inters...), predicates: append([]predicate.Group{}, gq.predicates...), // clone intermediate query. - sql: gq.sql.Clone(), - path: gq.path, - unique: gq.unique, + sql: gq.sql.Clone(), + path: gq.path, } } @@ -280,9 +275,9 @@ func (gq *GroupQuery) Clone() *GroupQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { - gq.fields = append([]string{field}, fields...) + gq.ctx.Fields = append([]string{field}, fields...) grbuild := &GroupGroupBy{build: gq} - grbuild.flds = &gq.fields + grbuild.flds = &gq.ctx.Fields grbuild.label = group.Label grbuild.scan = grbuild.Scan return grbuild @@ -301,10 +296,10 @@ func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { // Select(group.FieldMaxUsers). // Scan(ctx, &v) func (gq *GroupQuery) Select(fields ...string) *GroupSelect { - gq.fields = append(gq.fields, fields...) + gq.ctx.Fields = append(gq.ctx.Fields, fields...) sbuild := &GroupSelect{GroupQuery: gq} sbuild.label = group.Label - sbuild.flds, sbuild.scan = &gq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &gq.ctx.Fields, sbuild.Scan return sbuild } @@ -324,7 +319,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range gq.fields { + for _, f := range gq.ctx.Fields { if !group.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -372,9 +367,9 @@ func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { if len(gq.modifiers) > 0 { _spec.Modifiers = gq.modifiers } - _spec.Node.Columns = gq.fields - if len(gq.fields) > 0 { - _spec.Unique = gq.unique != nil && *gq.unique + _spec.Node.Columns = gq.ctx.Fields + if len(gq.ctx.Fields) > 0 { + _spec.Unique = gq.ctx.Unique != nil && *gq.ctx.Unique } return sqlgraph.CountNodes(ctx, gq.driver, _spec) } @@ -392,10 +387,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { From: gq.sql, Unique: true, } - if unique := gq.unique; unique != nil { + if unique := gq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := gq.fields; len(fields) > 0 { + if fields := gq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, group.FieldID) for i := range fields { @@ -411,10 +406,10 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := gq.order; len(ps) > 0 { @@ -430,7 +425,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(gq.driver.Dialect()) t1 := builder.Table(group.Table) - columns := gq.fields + columns := gq.ctx.Fields if len(columns) == 0 { columns = group.Columns } @@ -439,7 +434,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = gq.sql selector.Select(selector.Columns(columns...)...) } - if gq.unique != nil && *gq.unique { + if gq.ctx.Unique != nil && *gq.ctx.Unique { selector.Distinct() } for _, m := range gq.modifiers { @@ -451,12 +446,12 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range gq.order { p(selector) } - if offset := gq.offset; offset != nil { + if offset := gq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := gq.limit; limit != nil { + if limit := gq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -481,7 +476,7 @@ func (ggb *GroupGroupBy) Aggregate(fns ...AggregateFunc) *GroupGroupBy { // Scan applies the selector query and scans the result into the given value. func (ggb *GroupGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "GroupBy") + ctx = setContextOp(ctx, ggb.build.ctx, "GroupBy") if err := ggb.build.prepareQuery(ctx); err != nil { return err } @@ -529,7 +524,7 @@ func (gs *GroupSelect) Aggregate(fns ...AggregateFunc) *GroupSelect { // Scan applies the selector query and scans the result into the given value. func (gs *GroupSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeGroup, "Select") + ctx = setContextOp(ctx, gs.ctx, "Select") if err := gs.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/template/ent/pet_query.go b/entc/integration/template/ent/pet_query.go index 769c81836..1549c0453 100644 --- a/entc/integration/template/ent/pet_query.go +++ b/entc/integration/template/ent/pet_query.go @@ -22,11 +22,8 @@ import ( // PetQuery is the builder for querying Pet entities. type PetQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.Pet withOwner *UserQuery @@ -47,20 +44,20 @@ func (pq *PetQuery) Where(ps ...predicate.Pet) *PetQuery { // Limit the number of records to be returned by this query. func (pq *PetQuery) Limit(limit int) *PetQuery { - pq.limit = &limit + pq.ctx.Limit = &limit return pq } // Offset to start from. func (pq *PetQuery) Offset(offset int) *PetQuery { - pq.offset = &offset + pq.ctx.Offset = &offset return pq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (pq *PetQuery) Unique(unique bool) *PetQuery { - pq.unique = &unique + pq.ctx.Unique = &unique return pq } @@ -95,7 +92,7 @@ func (pq *PetQuery) QueryOwner() *UserQuery { // First returns the first Pet entity from the query. // Returns a *NotFoundError when no Pet was found. func (pq *PetQuery) First(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(1).All(newQueryContext(ctx, TypePet, "First")) + nodes, err := pq.Limit(1).All(setContextOp(ctx, pq.ctx, "First")) if err != nil { return nil, err } @@ -118,7 +115,7 @@ func (pq *PetQuery) FirstX(ctx context.Context) *Pet { // Returns a *NotFoundError when no Pet ID was found. func (pq *PetQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(1).IDs(newQueryContext(ctx, TypePet, "FirstID")); err != nil { + if ids, err = pq.Limit(1).IDs(setContextOp(ctx, pq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -141,7 +138,7 @@ func (pq *PetQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Pet entity is found. // Returns a *NotFoundError when no Pet entities are found. func (pq *PetQuery) Only(ctx context.Context) (*Pet, error) { - nodes, err := pq.Limit(2).All(newQueryContext(ctx, TypePet, "Only")) + nodes, err := pq.Limit(2).All(setContextOp(ctx, pq.ctx, "Only")) if err != nil { return nil, err } @@ -169,7 +166,7 @@ func (pq *PetQuery) OnlyX(ctx context.Context) *Pet { // Returns a *NotFoundError when no entities are found. func (pq *PetQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pq.Limit(2).IDs(newQueryContext(ctx, TypePet, "OnlyID")); err != nil { + if ids, err = pq.Limit(2).IDs(setContextOp(ctx, pq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -194,7 +191,7 @@ func (pq *PetQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Pets. func (pq *PetQuery) All(ctx context.Context) ([]*Pet, error) { - ctx = newQueryContext(ctx, TypePet, "All") + ctx = setContextOp(ctx, pq.ctx, "All") if err := pq.prepareQuery(ctx); err != nil { return nil, err } @@ -214,7 +211,7 @@ func (pq *PetQuery) AllX(ctx context.Context) []*Pet { // IDs executes the query and returns a list of Pet IDs. func (pq *PetQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypePet, "IDs") + ctx = setContextOp(ctx, pq.ctx, "IDs") if err := pq.Select(pet.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -232,7 +229,7 @@ func (pq *PetQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (pq *PetQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypePet, "Count") + ctx = setContextOp(ctx, pq.ctx, "Count") if err := pq.prepareQuery(ctx); err != nil { return 0, err } @@ -250,7 +247,7 @@ func (pq *PetQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (pq *PetQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypePet, "Exist") + ctx = setContextOp(ctx, pq.ctx, "Exist") switch _, err := pq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -278,16 +275,14 @@ func (pq *PetQuery) Clone() *PetQuery { } return &PetQuery{ config: pq.config, - limit: pq.limit, - offset: pq.offset, + ctx: pq.ctx.Clone(), order: append([]OrderFunc{}, pq.order...), inters: append([]Interceptor{}, pq.inters...), predicates: append([]predicate.Pet{}, pq.predicates...), withOwner: pq.withOwner.Clone(), // clone intermediate query. - sql: pq.sql.Clone(), - path: pq.path, - unique: pq.unique, + sql: pq.sql.Clone(), + path: pq.path, } } @@ -317,9 +312,9 @@ func (pq *PetQuery) WithOwner(opts ...func(*UserQuery)) *PetQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { - pq.fields = append([]string{field}, fields...) + pq.ctx.Fields = append([]string{field}, fields...) grbuild := &PetGroupBy{build: pq} - grbuild.flds = &pq.fields + grbuild.flds = &pq.ctx.Fields grbuild.label = pet.Label grbuild.scan = grbuild.Scan return grbuild @@ -338,10 +333,10 @@ func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { // Select(pet.FieldAge). // Scan(ctx, &v) func (pq *PetQuery) Select(fields ...string) *PetSelect { - pq.fields = append(pq.fields, fields...) + pq.ctx.Fields = append(pq.ctx.Fields, fields...) sbuild := &PetSelect{PetQuery: pq} sbuild.label = pet.Label - sbuild.flds, sbuild.scan = &pq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &pq.ctx.Fields, sbuild.Scan return sbuild } @@ -361,7 +356,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range pq.fields { + for _, f := range pq.ctx.Fields { if !pet.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -459,9 +454,9 @@ func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { if len(pq.modifiers) > 0 { _spec.Modifiers = pq.modifiers } - _spec.Node.Columns = pq.fields - if len(pq.fields) > 0 { - _spec.Unique = pq.unique != nil && *pq.unique + _spec.Node.Columns = pq.ctx.Fields + if len(pq.ctx.Fields) > 0 { + _spec.Unique = pq.ctx.Unique != nil && *pq.ctx.Unique } return sqlgraph.CountNodes(ctx, pq.driver, _spec) } @@ -479,10 +474,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { From: pq.sql, Unique: true, } - if unique := pq.unique; unique != nil { + if unique := pq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := pq.fields; len(fields) > 0 { + if fields := pq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, pet.FieldID) for i := range fields { @@ -498,10 +493,10 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := pq.order; len(ps) > 0 { @@ -517,7 +512,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(pq.driver.Dialect()) t1 := builder.Table(pet.Table) - columns := pq.fields + columns := pq.ctx.Fields if len(columns) == 0 { columns = pet.Columns } @@ -526,7 +521,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = pq.sql selector.Select(selector.Columns(columns...)...) } - if pq.unique != nil && *pq.unique { + if pq.ctx.Unique != nil && *pq.ctx.Unique { selector.Distinct() } for _, m := range pq.modifiers { @@ -538,12 +533,12 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range pq.order { p(selector) } - if offset := pq.offset; offset != nil { + if offset := pq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := pq.limit; limit != nil { + if limit := pq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -568,7 +563,7 @@ func (pgb *PetGroupBy) Aggregate(fns ...AggregateFunc) *PetGroupBy { // Scan applies the selector query and scans the result into the given value. func (pgb *PetGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "GroupBy") + ctx = setContextOp(ctx, pgb.build.ctx, "GroupBy") if err := pgb.build.prepareQuery(ctx); err != nil { return err } @@ -616,7 +611,7 @@ func (ps *PetSelect) Aggregate(fns ...AggregateFunc) *PetSelect { // Scan applies the selector query and scans the result into the given value. func (ps *PetSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypePet, "Select") + ctx = setContextOp(ctx, ps.ctx, "Select") if err := ps.prepareQuery(ctx); err != nil { return err } diff --git a/entc/integration/template/ent/user_query.go b/entc/integration/template/ent/user_query.go index 610339532..54fdb51ba 100644 --- a/entc/integration/template/ent/user_query.go +++ b/entc/integration/template/ent/user_query.go @@ -23,11 +23,8 @@ import ( // UserQuery is the builder for querying User entities. type UserQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string inters []Interceptor predicates []predicate.User withPets *PetQuery @@ -48,20 +45,20 @@ func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { // Limit the number of records to be returned by this query. func (uq *UserQuery) Limit(limit int) *UserQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } // Offset to start from. func (uq *UserQuery) Offset(offset int) *UserQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UserQuery) Unique(unique bool) *UserQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } @@ -118,7 +115,7 @@ func (uq *UserQuery) QueryFriends() *UserQuery { // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(1).All(newQueryContext(ctx, TypeUser, "First")) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -141,7 +138,7 @@ func (uq *UserQuery) FirstX(ctx context.Context) *User { // Returns a *NotFoundError when no User ID was found. func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(newQueryContext(ctx, TypeUser, "FirstID")); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -164,7 +161,7 @@ func (uq *UserQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one User entity is found. // Returns a *NotFoundError when no User entities are found. func (uq *UserQuery) Only(ctx context.Context) (*User, error) { - nodes, err := uq.Limit(2).All(newQueryContext(ctx, TypeUser, "Only")) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -192,7 +189,7 @@ func (uq *UserQuery) OnlyX(ctx context.Context) *User { // Returns a *NotFoundError when no entities are found. func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(newQueryContext(ctx, TypeUser, "OnlyID")); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -217,7 +214,7 @@ func (uq *UserQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Users. func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { - ctx = newQueryContext(ctx, TypeUser, "All") + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } @@ -237,7 +234,7 @@ func (uq *UserQuery) AllX(ctx context.Context) []*User { // IDs executes the query and returns a list of User IDs. func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { var ids []int - ctx = newQueryContext(ctx, TypeUser, "IDs") + ctx = setContextOp(ctx, uq.ctx, "IDs") if err := uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { return nil, err } @@ -255,7 +252,7 @@ func (uq *UserQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UserQuery) Count(ctx context.Context) (int, error) { - ctx = newQueryContext(ctx, TypeUser, "Count") + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } @@ -273,7 +270,7 @@ func (uq *UserQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { - ctx = newQueryContext(ctx, TypeUser, "Exist") + ctx = setContextOp(ctx, uq.ctx, "Exist") switch _, err := uq.FirstID(ctx); { case IsNotFound(err): return false, nil @@ -301,17 +298,15 @@ func (uq *UserQuery) Clone() *UserQuery { } return &UserQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), withPets: uq.withPets.Clone(), withFriends: uq.withFriends.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } @@ -352,9 +347,9 @@ func (uq *UserQuery) WithFriends(opts ...func(*UserQuery)) *UserQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { - uq.fields = append([]string{field}, fields...) + uq.ctx.Fields = append([]string{field}, fields...) grbuild := &UserGroupBy{build: uq} - grbuild.flds = &uq.fields + grbuild.flds = &uq.ctx.Fields grbuild.label = user.Label grbuild.scan = grbuild.Scan return grbuild @@ -373,10 +368,10 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Select(user.FieldName). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { - uq.fields = append(uq.fields, fields...) + uq.ctx.Fields = append(uq.ctx.Fields, fields...) sbuild := &UserSelect{UserQuery: uq} sbuild.label = user.Label - sbuild.flds, sbuild.scan = &uq.fields, sbuild.Scan + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan return sbuild } @@ -396,7 +391,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { } } } - for _, f := range uq.fields { + for _, f := range uq.ctx.Fields { if !user.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -553,9 +548,9 @@ func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { if len(uq.modifiers) > 0 { _spec.Modifiers = uq.modifiers } - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } @@ -573,10 +568,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { From: uq.sql, Unique: true, } - if unique := uq.unique; unique != nil { + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) for i := range fields { @@ -592,10 +587,10 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -611,7 +606,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = user.Columns } @@ -620,7 +615,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, m := range uq.modifiers { @@ -632,12 +627,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -662,7 +657,7 @@ func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { // Scan applies the selector query and scans the result into the given value. func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "GroupBy") + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") if err := ugb.build.prepareQuery(ctx); err != nil { return err } @@ -710,7 +705,7 @@ func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { // Scan applies the selector query and scans the result into the given value. func (us *UserSelect) Scan(ctx context.Context, v any) error { - ctx = newQueryContext(ctx, TypeUser, "Select") + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err }