diff --git a/entc/gen/template/dialect/sql/feature/schemaconfig.tmpl b/entc/gen/template/dialect/sql/feature/schemaconfig.tmpl index 95ae96e30..df690dc91 100644 --- a/entc/gen/template/dialect/sql/feature/schemaconfig.tmpl +++ b/entc/gen/template/dialect/sql/feature/schemaconfig.tmpl @@ -69,22 +69,22 @@ func NewSchemaConfigContext(parent context.Context, config SchemaConfig) context {{- end }} {{- end }} -{{- define "dialect/sql/delete/spec/schemaconfig" }} - {{- template "dialect/sql/spec/schemaconfig" $ }} +{{- define "dialect/sql/delete/spec/ctxschemaconfig" }} + {{- template "dialect/sql/spec/ctxschemaconfig" $ }} {{- end }} -{{- define "dialect/sql/update/spec/schemaconfig" }} - {{- template "dialect/sql/spec/schemaconfig" $ }} +{{- define "dialect/sql/update/spec/ctxschemaconfig" }} + {{- template "dialect/sql/spec/ctxschemaconfig" $ }} {{- end }} -{{- define "dialect/sql/create/spec/schemaconfig" }} +{{- define "dialect/sql/create/spec/ctxschemaconfig" }} {{- with extend $ "Ident" "_spec" "SkipContext" true }} - {{- template "dialect/sql/spec/schemaconfig" . }} + {{- template "dialect/sql/spec/ctxschemaconfig" . }} {{- end }} {{- end }} -{{- define "dialect/sql/query/spec/schemaconfig" }} - {{- template "dialect/sql/spec/schemaconfig" . }} +{{- define "dialect/sql/query/spec/ctxschemaconfig" }} + {{- template "dialect/sql/spec/ctxschemaconfig" . }} {{- end }} {{- define "dialect/sql/query/eagerloading/spec/schemaconfig" }} @@ -113,7 +113,8 @@ func NewSchemaConfigContext(parent context.Context, config SchemaConfig) context {{- end }} {{- end }} -{{- define "dialect/sql/spec/schemaconfig" -}} +{{/* A template for injecting the SchemaConfig to the context. Should be executed before other templates. */}} +{{- define "dialect/sql/spec/ctxschemaconfig" -}} {{- $builder := pascal $.Scope.Builder }} {{- $receiver := receiver $builder }} {{- $ident := "_spec.Node" }}{{ with $.Scope.Ident }}{{ $ident = . }}{{ end }} @@ -125,7 +126,7 @@ func NewSchemaConfigContext(parent context.Context, config SchemaConfig) context {{- end }} {{- end -}} -{{- define "dialect/sql/query/selector/schemaconfig" -}} +{{- define "dialect/sql/query/selector/ctxschemaconfig" -}} {{- $builder := pascal $.Scope.Builder }} {{- $receiver := receiver $builder }} {{- if $.FeatureEnabled "sql/schemaconfig" }} @@ -135,36 +136,36 @@ func NewSchemaConfigContext(parent context.Context, config SchemaConfig) context {{- end }} {{- end -}} -{{- define "dialect/sql/query/path/schemaconfig" }} +{{- define "dialect/sql/query/path/ctxschemaconfig" }} {{- if $.FeatureEnabled "sql/schemaconfig" }} schemaConfig := {{ $.Scope.Receiver }}.schemaConfig - {{- template "dialect/sql/query/step/schemaconfig" . }} + {{- template "dialect/sql/query/step/ctxschemaconfig" . }} {{- end -}} {{- end -}} -{{- define "dialect/sql/query/from/schemaconfig" }} +{{- define "dialect/sql/query/from/ctxschemaconfig" }} {{- if $.FeatureEnabled "sql/schemaconfig" }} schemaConfig := {{ $.Scope.Receiver }}.schemaConfig - {{- template "dialect/sql/query/step/schemaconfig" . }} + {{- template "dialect/sql/query/step/ctxschemaconfig" . }} {{- end -}} {{- end -}} -{{- define "dialect/sql/predicate/edge/has/schemaconfig" -}} - {{- template "dialect/sql/predicate/edge/schemaconfig" . }} +{{- define "dialect/sql/predicate/edge/has/ctxschemaconfig" -}} + {{- template "dialect/sql/predicate/edge/ctxschemaconfig" . }} {{- end -}} -{{- define "dialect/sql/predicate/edge/haswith/schemaconfig" -}} - {{- template "dialect/sql/predicate/edge/schemaconfig" . }} +{{- define "dialect/sql/predicate/edge/haswith/ctxschemaconfig" -}} + {{- template "dialect/sql/predicate/edge/ctxschemaconfig" . }} {{- end -}} -{{- define "dialect/sql/predicate/edge/schemaconfig" -}} +{{- define "dialect/sql/predicate/edge/ctxschemaconfig" -}} {{- if $.FeatureEnabled "sql/schemaconfig" }} schemaConfig := internal.SchemaConfigFromContext(s.Context()) - {{- template "dialect/sql/query/step/schemaconfig" . }} + {{- template "dialect/sql/query/step/ctxschemaconfig" . }} {{- end -}} {{- end -}} -{{- define "dialect/sql/query/step/schemaconfig" -}} +{{- define "dialect/sql/query/step/ctxschemaconfig" -}} {{- $e := $.Scope.Edge }} step.To.Schema = schemaConfig.{{ $e.Type.Name }} {{- $schema := $e.Type.Name }} diff --git a/entc/integration/multischema/ent/config.go b/entc/integration/multischema/ent/config.go index 00e70df8f..063cc0f55 100644 --- a/entc/integration/multischema/ent/config.go +++ b/entc/integration/multischema/ent/config.go @@ -7,6 +7,8 @@ package ent import ( + "context" + "entgo.io/ent" "entgo.io/ent/dialect" "entgo.io/ent/entc/integration/multischema/ent/internal" @@ -67,6 +69,12 @@ func Driver(driver dialect.Driver) Option { } } +// SchemaConfigFromContext exports the internal.SchemaConfigFromContext +// for external usage (inside custom predicates or modifiers). +func SchemaConfigFromContext(ctx context.Context) SchemaConfig { + return internal.SchemaConfigFromContext(ctx) +} + // SchemaConfig represents alternative schema names for all tables // that can be passed at runtime. type SchemaConfig = internal.SchemaConfig diff --git a/entc/integration/multischema/ent/generate.go b/entc/integration/multischema/ent/generate.go index c6a376934..f36c45723 100644 --- a/entc/integration/multischema/ent/generate.go +++ b/entc/integration/multischema/ent/generate.go @@ -4,4 +4,4 @@ package ent -//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/schemaconfig --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by entc, DO NOT EDIT." ./schema +//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/modifier,sql/schemaconfig --template ./template --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by entc, DO NOT EDIT." ./schema diff --git a/entc/integration/multischema/ent/group_query.go b/entc/integration/multischema/ent/group_query.go index ccf440499..2a081f0a5 100644 --- a/entc/integration/multischema/ent/group_query.go +++ b/entc/integration/multischema/ent/group_query.go @@ -33,6 +33,7 @@ type GroupQuery struct { predicates []predicate.Group // eager-loading edges. withUsers *UserQuery + modifiers []func(s *sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -377,6 +378,9 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { } _spec.Node.Schema = gq.schemaConfig.Group ctx = internal.NewSchemaConfigContext(ctx, gq.schemaConfig) + if len(gq.modifiers) > 0 { + _spec.Modifiers = gq.modifiers + } if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil { return nil, err } @@ -457,6 +461,9 @@ func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { _spec := gq.querySpec() _spec.Node.Schema = gq.schemaConfig.Group ctx = internal.NewSchemaConfigContext(ctx, gq.schemaConfig) + if len(gq.modifiers) > 0 { + _spec.Modifiers = gq.modifiers + } _spec.Node.Columns = gq.fields if len(gq.fields) > 0 { _spec.Unique = gq.unique != nil && *gq.unique @@ -538,6 +545,9 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { t1.Schema(gq.schemaConfig.Group) ctx = internal.NewSchemaConfigContext(ctx, gq.schemaConfig) selector.WithContext(ctx) + for _, m := range gq.modifiers { + m(selector) + } for _, p := range gq.predicates { p(selector) } @@ -555,6 +565,12 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { return selector } +// Modify adds a query modifier for attaching custom logic to queries. +func (gq *GroupQuery) Modify(modifiers ...func(s *sql.Selector)) *GroupSelect { + gq.modifiers = append(gq.modifiers, modifiers...) + return gq.Select() +} + // GroupGroupBy is the group-by builder for Group entities. type GroupGroupBy struct { config @@ -1044,3 +1060,9 @@ func (gs *GroupSelect) sqlScan(ctx context.Context, v interface{}) error { defer rows.Close() return sql.ScanSlice(rows, v) } + +// Modify adds a query modifier for attaching custom logic to queries. +func (gs *GroupSelect) Modify(modifiers ...func(s *sql.Selector)) *GroupSelect { + gs.modifiers = append(gs.modifiers, modifiers...) + return gs +} diff --git a/entc/integration/multischema/ent/migrate/schema.go b/entc/integration/multischema/ent/migrate/schema.go index 8aceab15c..2cad2cda2 100644 --- a/entc/integration/multischema/ent/migrate/schema.go +++ b/entc/integration/multischema/ent/migrate/schema.go @@ -27,7 +27,7 @@ var ( PetsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Default: "unknown"}, - {Name: "user_pets", Type: field.TypeInt, Nullable: true}, + {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } // PetsTable holds the schema information for the "pets" table. PetsTable = &schema.Table{ diff --git a/entc/integration/multischema/ent/mutation.go b/entc/integration/multischema/ent/mutation.go index 787499af1..8d55285d5 100644 --- a/entc/integration/multischema/ent/mutation.go +++ b/entc/integration/multischema/ent/mutation.go @@ -548,9 +548,53 @@ func (m *PetMutation) ResetName() { m.name = nil } -// SetOwnerID sets the "owner" edge to the User entity by id. -func (m *PetMutation) SetOwnerID(id int) { - m.owner = &id +// SetOwnerID sets the "owner_id" field. +func (m *PetMutation) SetOwnerID(i int) { + m.owner = &i +} + +// OwnerID returns the value of the "owner_id" field in the mutation. +func (m *PetMutation) OwnerID() (r int, exists bool) { + v := m.owner + if v == nil { + return + } + return *v, true +} + +// OldOwnerID returns the old "owner_id" field's value of the Pet entity. +// If the Pet object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PetMutation) OldOwnerID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, fmt.Errorf("OldOwnerID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, fmt.Errorf("OldOwnerID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOwnerID: %w", err) + } + return oldValue.OwnerID, nil +} + +// ClearOwnerID clears the value of the "owner_id" field. +func (m *PetMutation) ClearOwnerID() { + m.owner = nil + m.clearedFields[pet.FieldOwnerID] = struct{}{} +} + +// OwnerIDCleared returns if the "owner_id" field was cleared in this mutation. +func (m *PetMutation) OwnerIDCleared() bool { + _, ok := m.clearedFields[pet.FieldOwnerID] + return ok +} + +// ResetOwnerID resets all changes to the "owner_id" field. +func (m *PetMutation) ResetOwnerID() { + m.owner = nil + delete(m.clearedFields, pet.FieldOwnerID) } // ClearOwner clears the "owner" edge to the User entity. @@ -560,15 +604,7 @@ func (m *PetMutation) ClearOwner() { // OwnerCleared reports if the "owner" edge to the User entity was cleared. func (m *PetMutation) OwnerCleared() bool { - return m.clearedowner -} - -// OwnerID returns the "owner" edge ID in the mutation. -func (m *PetMutation) OwnerID() (id int, exists bool) { - if m.owner != nil { - return *m.owner, true - } - return + return m.OwnerIDCleared() || m.clearedowner } // OwnerIDs returns the "owner" edge IDs in the mutation. @@ -606,10 +642,13 @@ func (m *PetMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *PetMutation) Fields() []string { - fields := make([]string, 0, 1) + fields := make([]string, 0, 2) if m.name != nil { fields = append(fields, pet.FieldName) } + if m.owner != nil { + fields = append(fields, pet.FieldOwnerID) + } return fields } @@ -620,6 +659,8 @@ func (m *PetMutation) Field(name string) (ent.Value, bool) { switch name { case pet.FieldName: return m.Name() + case pet.FieldOwnerID: + return m.OwnerID() } return nil, false } @@ -631,6 +672,8 @@ func (m *PetMutation) OldField(ctx context.Context, name string) (ent.Value, err switch name { case pet.FieldName: return m.OldName(ctx) + case pet.FieldOwnerID: + return m.OldOwnerID(ctx) } return nil, fmt.Errorf("unknown Pet field %s", name) } @@ -647,6 +690,13 @@ func (m *PetMutation) SetField(name string, value ent.Value) error { } m.SetName(v) return nil + case pet.FieldOwnerID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOwnerID(v) + return nil } return fmt.Errorf("unknown Pet field %s", name) } @@ -654,13 +704,16 @@ func (m *PetMutation) SetField(name string, value ent.Value) error { // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. func (m *PetMutation) AddedFields() []string { - return nil + var fields []string + return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. func (m *PetMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } return nil, false } @@ -676,7 +729,11 @@ func (m *PetMutation) AddField(name string, value ent.Value) error { // ClearedFields returns all nullable fields that were cleared during this // mutation. func (m *PetMutation) ClearedFields() []string { - return nil + var fields []string + if m.FieldCleared(pet.FieldOwnerID) { + fields = append(fields, pet.FieldOwnerID) + } + return fields } // FieldCleared returns a boolean indicating if a field with the given name was @@ -689,6 +746,11 @@ func (m *PetMutation) FieldCleared(name string) bool { // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. func (m *PetMutation) ClearField(name string) error { + switch name { + case pet.FieldOwnerID: + m.ClearOwnerID() + return nil + } return fmt.Errorf("unknown Pet nullable field %s", name) } @@ -699,6 +761,9 @@ func (m *PetMutation) ResetField(name string) error { case pet.FieldName: m.ResetName() return nil + case pet.FieldOwnerID: + m.ResetOwnerID() + return nil } return fmt.Errorf("unknown Pet field %s", name) } diff --git a/entc/integration/multischema/ent/pet.go b/entc/integration/multischema/ent/pet.go index baff75180..052c52523 100644 --- a/entc/integration/multischema/ent/pet.go +++ b/entc/integration/multischema/ent/pet.go @@ -22,10 +22,11 @@ type Pet struct { ID int `json:"id,omitempty"` // Name holds the value of the "name" field. Name string `json:"name,omitempty"` + // OwnerID holds the value of the "owner_id" field. + OwnerID int `json:"owner_id,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the PetQuery when eager-loading is set. - Edges PetEdges `json:"edges"` - user_pets *int + Edges PetEdges `json:"edges"` } // PetEdges holds the relations/edges for other nodes in the graph. @@ -56,12 +57,10 @@ func (*Pet) scanValues(columns []string) ([]interface{}, error) { values := make([]interface{}, len(columns)) for i := range columns { switch columns[i] { - case pet.FieldID: + case pet.FieldID, pet.FieldOwnerID: values[i] = new(sql.NullInt64) case pet.FieldName: values[i] = new(sql.NullString) - case pet.ForeignKeys[0]: // user_pets - values[i] = new(sql.NullInt64) default: return nil, fmt.Errorf("unexpected column %q for type Pet", columns[i]) } @@ -89,12 +88,11 @@ func (pe *Pet) assignValues(columns []string, values []interface{}) error { } else if value.Valid { pe.Name = value.String } - case pet.ForeignKeys[0]: + case pet.FieldOwnerID: if value, ok := values[i].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for edge-field user_pets", value) + return fmt.Errorf("unexpected type %T for field owner_id", values[i]) } else if value.Valid { - pe.user_pets = new(int) - *pe.user_pets = int(value.Int64) + pe.OwnerID = int(value.Int64) } } } @@ -131,6 +129,8 @@ func (pe *Pet) String() string { builder.WriteString(fmt.Sprintf("id=%v", pe.ID)) builder.WriteString(", name=") builder.WriteString(pe.Name) + builder.WriteString(", owner_id=") + builder.WriteString(fmt.Sprintf("%v", pe.OwnerID)) builder.WriteByte(')') return builder.String() } diff --git a/entc/integration/multischema/ent/pet/pet.go b/entc/integration/multischema/ent/pet/pet.go index 35ce11dbe..0991dc50e 100644 --- a/entc/integration/multischema/ent/pet/pet.go +++ b/entc/integration/multischema/ent/pet/pet.go @@ -13,6 +13,8 @@ const ( FieldID = "id" // FieldName holds the string denoting the name field in the database. FieldName = "name" + // FieldOwnerID holds the string denoting the owner_id field in the database. + FieldOwnerID = "owner_id" // EdgeOwner holds the string denoting the owner edge name in mutations. EdgeOwner = "owner" // Table holds the table name of the pet in the database. @@ -23,19 +25,14 @@ const ( // It exists in this package in order to avoid circular dependency with the "user" package. OwnerInverseTable = "users" // OwnerColumn is the table column denoting the owner relation/edge. - OwnerColumn = "user_pets" + OwnerColumn = "owner_id" ) // Columns holds all SQL columns for pet fields. var Columns = []string{ FieldID, FieldName, -} - -// ForeignKeys holds the SQL foreign-keys that are owned by the "pets" -// table and are not defined as standalone fields in the schema. -var ForeignKeys = []string{ - "user_pets", + FieldOwnerID, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -45,11 +42,6 @@ func ValidColumn(column string) bool { return true } } - for i := range ForeignKeys { - if column == ForeignKeys[i] { - return true - } - } return false } diff --git a/entc/integration/multischema/ent/pet/where.go b/entc/integration/multischema/ent/pet/where.go index a27cd8394..1411a8cb0 100644 --- a/entc/integration/multischema/ent/pet/where.go +++ b/entc/integration/multischema/ent/pet/where.go @@ -103,6 +103,13 @@ func Name(v string) predicate.Pet { }) } +// OwnerID applies equality check predicate on the "owner_id" field. It's identical to OwnerIDEQ. +func OwnerID(v int) predicate.Pet { + return predicate.Pet(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldOwnerID), v)) + }) +} + // NameEQ applies the EQ predicate on the "name" field. func NameEQ(v string) predicate.Pet { return predicate.Pet(func(s *sql.Selector) { @@ -214,6 +221,68 @@ func NameContainsFold(v string) predicate.Pet { }) } +// OwnerIDEQ applies the EQ predicate on the "owner_id" field. +func OwnerIDEQ(v int) predicate.Pet { + return predicate.Pet(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldOwnerID), v)) + }) +} + +// OwnerIDNEQ applies the NEQ predicate on the "owner_id" field. +func OwnerIDNEQ(v int) predicate.Pet { + return predicate.Pet(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldOwnerID), v)) + }) +} + +// OwnerIDIn applies the In predicate on the "owner_id" field. +func OwnerIDIn(vs ...int) predicate.Pet { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Pet(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldOwnerID), v...)) + }) +} + +// OwnerIDNotIn applies the NotIn predicate on the "owner_id" field. +func OwnerIDNotIn(vs ...int) predicate.Pet { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Pet(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldOwnerID), v...)) + }) +} + +// OwnerIDIsNil applies the IsNil predicate on the "owner_id" field. +func OwnerIDIsNil() predicate.Pet { + return predicate.Pet(func(s *sql.Selector) { + s.Where(sql.IsNull(s.C(FieldOwnerID))) + }) +} + +// OwnerIDNotNil applies the NotNil predicate on the "owner_id" field. +func OwnerIDNotNil() predicate.Pet { + return predicate.Pet(func(s *sql.Selector) { + s.Where(sql.NotNull(s.C(FieldOwnerID))) + }) +} + // HasOwner applies the HasEdge predicate on the "owner" edge. func HasOwner() predicate.Pet { return predicate.Pet(func(s *sql.Selector) { diff --git a/entc/integration/multischema/ent/pet_create.go b/entc/integration/multischema/ent/pet_create.go index f94afb889..d55af1f22 100644 --- a/entc/integration/multischema/ent/pet_create.go +++ b/entc/integration/multischema/ent/pet_create.go @@ -38,16 +38,16 @@ func (pc *PetCreate) SetNillableName(s *string) *PetCreate { return pc } -// SetOwnerID sets the "owner" edge to the User entity by ID. -func (pc *PetCreate) SetOwnerID(id int) *PetCreate { - pc.mutation.SetOwnerID(id) +// SetOwnerID sets the "owner_id" field. +func (pc *PetCreate) SetOwnerID(i int) *PetCreate { + pc.mutation.SetOwnerID(i) return pc } -// SetNillableOwnerID sets the "owner" edge to the User entity by ID if the given value is not nil. -func (pc *PetCreate) SetNillableOwnerID(id *int) *PetCreate { - if id != nil { - pc = pc.SetOwnerID(*id) +// SetNillableOwnerID sets the "owner_id" field if the given value is not nil. +func (pc *PetCreate) SetNillableOwnerID(i *int) *PetCreate { + if i != nil { + pc.SetOwnerID(*i) } return pc } @@ -193,7 +193,7 @@ func (pc *PetCreate) createSpec() (*Pet, *sqlgraph.CreateSpec) { for _, k := range nodes { edge.Target.Nodes = append(edge.Target.Nodes, k) } - _node.user_pets = &nodes[0] + _node.OwnerID = nodes[0] _spec.Edges = append(_spec.Edges, edge) } return _node, _spec diff --git a/entc/integration/multischema/ent/pet_query.go b/entc/integration/multischema/ent/pet_query.go index 6aace0e19..070891663 100644 --- a/entc/integration/multischema/ent/pet_query.go +++ b/entc/integration/multischema/ent/pet_query.go @@ -32,7 +32,7 @@ type PetQuery struct { predicates []predicate.Pet // eager-loading edges. withOwner *UserQuery - withFKs bool + modifiers []func(s *sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -357,18 +357,11 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { var ( nodes = []*Pet{} - withFKs = pq.withFKs _spec = pq.querySpec() loadedTypes = [1]bool{ pq.withOwner != nil, } ) - if pq.withOwner != nil { - withFKs = true - } - if withFKs { - _spec.Node.Columns = append(_spec.Node.Columns, pet.ForeignKeys...) - } _spec.ScanValues = func(columns []string) ([]interface{}, error) { node := &Pet{config: pq.config} nodes = append(nodes, node) @@ -384,6 +377,9 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { } _spec.Node.Schema = pq.schemaConfig.Pet ctx = internal.NewSchemaConfigContext(ctx, pq.schemaConfig) + if len(pq.modifiers) > 0 { + _spec.Modifiers = pq.modifiers + } if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil { return nil, err } @@ -395,10 +391,7 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { ids := make([]int, 0, len(nodes)) nodeids := make(map[int][]*Pet) for i := range nodes { - if nodes[i].user_pets == nil { - continue - } - fk := *nodes[i].user_pets + fk := nodes[i].OwnerID if _, ok := nodeids[fk]; !ok { ids = append(ids, fk) } @@ -412,7 +405,7 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { for _, n := range neighbors { nodes, ok := nodeids[n.ID] if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v`, n.ID) + return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID) } for i := range nodes { nodes[i].Edges.Owner = n @@ -427,6 +420,9 @@ func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() _spec.Node.Schema = pq.schemaConfig.Pet ctx = internal.NewSchemaConfigContext(ctx, pq.schemaConfig) + if len(pq.modifiers) > 0 { + _spec.Modifiers = pq.modifiers + } _spec.Node.Columns = pq.fields if len(pq.fields) > 0 { _spec.Unique = pq.unique != nil && *pq.unique @@ -508,6 +504,9 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { t1.Schema(pq.schemaConfig.Pet) ctx = internal.NewSchemaConfigContext(ctx, pq.schemaConfig) selector.WithContext(ctx) + for _, m := range pq.modifiers { + m(selector) + } for _, p := range pq.predicates { p(selector) } @@ -525,6 +524,12 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { return selector } +// Modify adds a query modifier for attaching custom logic to queries. +func (pq *PetQuery) Modify(modifiers ...func(s *sql.Selector)) *PetSelect { + pq.modifiers = append(pq.modifiers, modifiers...) + return pq.Select() +} + // PetGroupBy is the group-by builder for Pet entities. type PetGroupBy struct { config @@ -1014,3 +1019,9 @@ func (ps *PetSelect) sqlScan(ctx context.Context, v interface{}) error { defer rows.Close() return sql.ScanSlice(rows, v) } + +// Modify adds a query modifier for attaching custom logic to queries. +func (ps *PetSelect) Modify(modifiers ...func(s *sql.Selector)) *PetSelect { + ps.modifiers = append(ps.modifiers, modifiers...) + return ps +} diff --git a/entc/integration/multischema/ent/pet_update.go b/entc/integration/multischema/ent/pet_update.go index d22b37d7c..1590d50ce 100644 --- a/entc/integration/multischema/ent/pet_update.go +++ b/entc/integration/multischema/ent/pet_update.go @@ -47,20 +47,26 @@ func (pu *PetUpdate) SetNillableName(s *string) *PetUpdate { return pu } -// SetOwnerID sets the "owner" edge to the User entity by ID. -func (pu *PetUpdate) SetOwnerID(id int) *PetUpdate { - pu.mutation.SetOwnerID(id) +// SetOwnerID sets the "owner_id" field. +func (pu *PetUpdate) SetOwnerID(i int) *PetUpdate { + pu.mutation.SetOwnerID(i) return pu } -// SetNillableOwnerID sets the "owner" edge to the User entity by ID if the given value is not nil. -func (pu *PetUpdate) SetNillableOwnerID(id *int) *PetUpdate { - if id != nil { - pu = pu.SetOwnerID(*id) +// SetNillableOwnerID sets the "owner_id" field if the given value is not nil. +func (pu *PetUpdate) SetNillableOwnerID(i *int) *PetUpdate { + if i != nil { + pu.SetOwnerID(*i) } return pu } +// ClearOwnerID clears the value of the "owner_id" field. +func (pu *PetUpdate) ClearOwnerID() *PetUpdate { + pu.mutation.ClearOwnerID() + return pu +} + // SetOwner sets the "owner" edge to the User entity. func (pu *PetUpdate) SetOwner(u *User) *PetUpdate { return pu.SetOwnerID(u.ID) @@ -228,20 +234,26 @@ func (puo *PetUpdateOne) SetNillableName(s *string) *PetUpdateOne { return puo } -// SetOwnerID sets the "owner" edge to the User entity by ID. -func (puo *PetUpdateOne) SetOwnerID(id int) *PetUpdateOne { - puo.mutation.SetOwnerID(id) +// SetOwnerID sets the "owner_id" field. +func (puo *PetUpdateOne) SetOwnerID(i int) *PetUpdateOne { + puo.mutation.SetOwnerID(i) return puo } -// SetNillableOwnerID sets the "owner" edge to the User entity by ID if the given value is not nil. -func (puo *PetUpdateOne) SetNillableOwnerID(id *int) *PetUpdateOne { - if id != nil { - puo = puo.SetOwnerID(*id) +// SetNillableOwnerID sets the "owner_id" field if the given value is not nil. +func (puo *PetUpdateOne) SetNillableOwnerID(i *int) *PetUpdateOne { + if i != nil { + puo.SetOwnerID(*i) } return puo } +// ClearOwnerID clears the value of the "owner_id" field. +func (puo *PetUpdateOne) ClearOwnerID() *PetUpdateOne { + puo.mutation.ClearOwnerID() + return puo +} + // SetOwner sets the "owner" edge to the User entity. func (puo *PetUpdateOne) SetOwner(u *User) *PetUpdateOne { return puo.SetOwnerID(u.ID) diff --git a/entc/integration/multischema/ent/schema/pet.go b/entc/integration/multischema/ent/schema/pet.go index eb93499df..3e5514278 100644 --- a/entc/integration/multischema/ent/schema/pet.go +++ b/entc/integration/multischema/ent/schema/pet.go @@ -20,6 +20,8 @@ func (Pet) Fields() []ent.Field { return []ent.Field{ field.String("name"). Default("unknown"), + field.Int("owner_id"). + Optional(), } } @@ -28,6 +30,7 @@ func (Pet) Edges() []ent.Edge { return []ent.Edge{ edge.From("owner", User.Type). Ref("pets"). + Field("owner_id"). Unique(), } } diff --git a/entc/integration/multischema/ent/template/config.tmpl b/entc/integration/multischema/ent/template/config.tmpl new file mode 100644 index 000000000..fe0ec8dbc --- /dev/null +++ b/entc/integration/multischema/ent/template/config.tmpl @@ -0,0 +1,16 @@ +{{/* +Copyright 2019-present Facebook Inc. All rights reserved. +This source code is licensed under the Apache 2.0 license found +in the LICENSE file in the root directory of this source tree. +*/}} + +{{/* The line below tells Intellij/GoLand to enable the autocompletion based *gen.Graph type. */}} +{{/* gotype: entgo.io/ent/entc/gen.Graph */}} + +{{ define "config/options/schemaconfig" }} +// SchemaConfigFromContext exports the internal.SchemaConfigFromContext +// for external usage (inside custom predicates or modifiers). +func SchemaConfigFromContext(ctx context.Context) SchemaConfig { + return internal.SchemaConfigFromContext(ctx) +} +{{ end }} diff --git a/entc/integration/multischema/ent/user/user.go b/entc/integration/multischema/ent/user/user.go index 1ee2ee3f4..ad39b2eb3 100644 --- a/entc/integration/multischema/ent/user/user.go +++ b/entc/integration/multischema/ent/user/user.go @@ -25,7 +25,7 @@ const ( // It exists in this package in order to avoid circular dependency with the "pet" package. PetsInverseTable = "pets" // PetsColumn is the table column denoting the pets relation/edge. - PetsColumn = "user_pets" + PetsColumn = "owner_id" // GroupsTable is the table that holds the groups relation/edge. The primary key declared below. GroupsTable = "group_users" // GroupsInverseTable is the table name for the Group entity. diff --git a/entc/integration/multischema/ent/user_query.go b/entc/integration/multischema/ent/user_query.go index 1b0c0a2f0..b9512520b 100644 --- a/entc/integration/multischema/ent/user_query.go +++ b/entc/integration/multischema/ent/user_query.go @@ -35,6 +35,7 @@ type UserQuery struct { // eager-loading edges. withPets *PetQuery withGroups *GroupQuery + modifiers []func(s *sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -417,6 +418,9 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } _spec.Node.Schema = uq.schemaConfig.User ctx = internal.NewSchemaConfigContext(ctx, uq.schemaConfig) + if len(uq.modifiers) > 0 { + _spec.Modifiers = uq.modifiers + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -432,7 +436,6 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { nodeids[nodes[i].ID] = nodes[i] nodes[i].Edges.Pets = []*Pet{} } - query.withFKs = true query.Where(predicate.Pet(func(s *sql.Selector) { s.Where(sql.InValues(user.PetsColumn, fks...)) })) @@ -441,13 +444,10 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { return nil, err } for _, n := range neighbors { - fk := n.user_pets - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_pets" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] + fk := n.OwnerID + node, ok := nodeids[fk] if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v for node %v`, *fk, n.ID) + return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v for node %v`, fk, n.ID) } node.Edges.Pets = append(node.Edges.Pets, n) } @@ -526,6 +526,9 @@ func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() _spec.Node.Schema = uq.schemaConfig.User ctx = internal.NewSchemaConfigContext(ctx, uq.schemaConfig) + if len(uq.modifiers) > 0 { + _spec.Modifiers = uq.modifiers + } _spec.Node.Columns = uq.fields if len(uq.fields) > 0 { _spec.Unique = uq.unique != nil && *uq.unique @@ -607,6 +610,9 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { t1.Schema(uq.schemaConfig.User) ctx = internal.NewSchemaConfigContext(ctx, uq.schemaConfig) selector.WithContext(ctx) + for _, m := range uq.modifiers { + m(selector) + } for _, p := range uq.predicates { p(selector) } @@ -624,6 +630,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { return selector } +// Modify adds a query modifier for attaching custom logic to queries. +func (uq *UserQuery) Modify(modifiers ...func(s *sql.Selector)) *UserSelect { + uq.modifiers = append(uq.modifiers, modifiers...) + return uq.Select() +} + // UserGroupBy is the group-by builder for User entities. type UserGroupBy struct { config @@ -1113,3 +1125,9 @@ func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error { defer rows.Close() return sql.ScanSlice(rows, v) } + +// Modify adds a query modifier for attaching custom logic to queries. +func (us *UserSelect) Modify(modifiers ...func(s *sql.Selector)) *UserSelect { + us.modifiers = append(us.modifiers, modifiers...) + return us +} diff --git a/entc/integration/multischema/multischema_test.go b/entc/integration/multischema/multischema_test.go index d4944bf0b..5ac33f7cb 100644 --- a/entc/integration/multischema/multischema_test.go +++ b/entc/integration/multischema/multischema_test.go @@ -45,7 +45,29 @@ func TestMySQL(t *testing.T) { client.Group.Create().SetName("GitHub"), client.Group.Create().SetName("GitLab"), ).SaveX(ctx) - usr := client.User.Create().AddPets(pedro).AddGroups(groups...).SaveX(ctx) + usr := client.User.Create().SetName("a8m").AddPets(pedro).AddGroups(groups...).SaveX(ctx) + + // Custom modifier with schema config. + var names []struct { + User string `sql:"user_name"` + Pet string `sql:"pet_name"` + } + client.Pet.Query(). + Modify(func(s *sql.Selector) { + // The below function is exported using a custom + // template defined in ent/template/config.tmpl. + cfg := ent.SchemaConfigFromContext(s.Context()) + t := sql.Table(user.Table).Schema(cfg.User) + s.Join(t).On(s.C(pet.FieldOwnerID), t.C(user.FieldID)) + s.Select( + sql.As(t.C(user.FieldName), "user_name"), + sql.As(s.C(pet.FieldName), "pet_name"), + ) + }). + ScanX(ctx, &names) + require.Len(t, names, 1) + require.Equal(t, "a8m", names[0].User) + require.Equal(t, "Pedro", names[0].Pet) id := client.Group.Query(). Where(group.HasUsersWith(user.ID(usr.ID))).