entc/gen/integ: add example for using query modifiers in multischema mode

This commit is contained in:
Ariel Mashraki
2021-10-18 18:43:31 +03:00
committed by Ariel Mashraki
parent 2c9a175f06
commit aa8d2ecb58
17 changed files with 342 additions and 103 deletions

View File

@@ -7,6 +7,8 @@
package ent
import (
"context"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/entc/integration/multischema/ent/internal"
@@ -67,6 +69,12 @@ func Driver(driver dialect.Driver) Option {
}
}
// SchemaConfigFromContext exports the internal.SchemaConfigFromContext
// for external usage (inside custom predicates or modifiers).
func SchemaConfigFromContext(ctx context.Context) SchemaConfig {
return internal.SchemaConfigFromContext(ctx)
}
// SchemaConfig represents alternative schema names for all tables
// that can be passed at runtime.
type SchemaConfig = internal.SchemaConfig

View File

@@ -4,4 +4,4 @@
package ent
//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/schemaconfig --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by entc, DO NOT EDIT." ./schema
//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/modifier,sql/schemaconfig --template ./template --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by entc, DO NOT EDIT." ./schema

View File

@@ -33,6 +33,7 @@ type GroupQuery struct {
predicates []predicate.Group
// eager-loading edges.
withUsers *UserQuery
modifiers []func(s *sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@@ -377,6 +378,9 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) {
}
_spec.Node.Schema = gq.schemaConfig.Group
ctx = internal.NewSchemaConfigContext(ctx, gq.schemaConfig)
if len(gq.modifiers) > 0 {
_spec.Modifiers = gq.modifiers
}
if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil {
return nil, err
}
@@ -457,6 +461,9 @@ func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) {
_spec := gq.querySpec()
_spec.Node.Schema = gq.schemaConfig.Group
ctx = internal.NewSchemaConfigContext(ctx, gq.schemaConfig)
if len(gq.modifiers) > 0 {
_spec.Modifiers = gq.modifiers
}
_spec.Node.Columns = gq.fields
if len(gq.fields) > 0 {
_spec.Unique = gq.unique != nil && *gq.unique
@@ -538,6 +545,9 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
t1.Schema(gq.schemaConfig.Group)
ctx = internal.NewSchemaConfigContext(ctx, gq.schemaConfig)
selector.WithContext(ctx)
for _, m := range gq.modifiers {
m(selector)
}
for _, p := range gq.predicates {
p(selector)
}
@@ -555,6 +565,12 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// Modify adds a query modifier for attaching custom logic to queries.
func (gq *GroupQuery) Modify(modifiers ...func(s *sql.Selector)) *GroupSelect {
gq.modifiers = append(gq.modifiers, modifiers...)
return gq.Select()
}
// GroupGroupBy is the group-by builder for Group entities.
type GroupGroupBy struct {
config
@@ -1044,3 +1060,9 @@ func (gs *GroupSelect) sqlScan(ctx context.Context, v interface{}) error {
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Modify adds a query modifier for attaching custom logic to queries.
func (gs *GroupSelect) Modify(modifiers ...func(s *sql.Selector)) *GroupSelect {
gs.modifiers = append(gs.modifiers, modifiers...)
return gs
}

View File

@@ -27,7 +27,7 @@ var (
PetsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "name", Type: field.TypeString, Default: "unknown"},
{Name: "user_pets", Type: field.TypeInt, Nullable: true},
{Name: "owner_id", Type: field.TypeInt, Nullable: true},
}
// PetsTable holds the schema information for the "pets" table.
PetsTable = &schema.Table{

View File

@@ -548,9 +548,53 @@ func (m *PetMutation) ResetName() {
m.name = nil
}
// SetOwnerID sets the "owner" edge to the User entity by id.
func (m *PetMutation) SetOwnerID(id int) {
m.owner = &id
// SetOwnerID sets the "owner_id" field.
func (m *PetMutation) SetOwnerID(i int) {
m.owner = &i
}
// OwnerID returns the value of the "owner_id" field in the mutation.
func (m *PetMutation) OwnerID() (r int, exists bool) {
v := m.owner
if v == nil {
return
}
return *v, true
}
// OldOwnerID returns the old "owner_id" field's value of the Pet entity.
// If the Pet object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *PetMutation) OldOwnerID(ctx context.Context) (v int, err error) {
if !m.op.Is(OpUpdateOne) {
return v, fmt.Errorf("OldOwnerID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, fmt.Errorf("OldOwnerID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldOwnerID: %w", err)
}
return oldValue.OwnerID, nil
}
// ClearOwnerID clears the value of the "owner_id" field.
func (m *PetMutation) ClearOwnerID() {
m.owner = nil
m.clearedFields[pet.FieldOwnerID] = struct{}{}
}
// OwnerIDCleared returns if the "owner_id" field was cleared in this mutation.
func (m *PetMutation) OwnerIDCleared() bool {
_, ok := m.clearedFields[pet.FieldOwnerID]
return ok
}
// ResetOwnerID resets all changes to the "owner_id" field.
func (m *PetMutation) ResetOwnerID() {
m.owner = nil
delete(m.clearedFields, pet.FieldOwnerID)
}
// ClearOwner clears the "owner" edge to the User entity.
@@ -560,15 +604,7 @@ func (m *PetMutation) ClearOwner() {
// OwnerCleared reports if the "owner" edge to the User entity was cleared.
func (m *PetMutation) OwnerCleared() bool {
return m.clearedowner
}
// OwnerID returns the "owner" edge ID in the mutation.
func (m *PetMutation) OwnerID() (id int, exists bool) {
if m.owner != nil {
return *m.owner, true
}
return
return m.OwnerIDCleared() || m.clearedowner
}
// OwnerIDs returns the "owner" edge IDs in the mutation.
@@ -606,10 +642,13 @@ func (m *PetMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *PetMutation) Fields() []string {
fields := make([]string, 0, 1)
fields := make([]string, 0, 2)
if m.name != nil {
fields = append(fields, pet.FieldName)
}
if m.owner != nil {
fields = append(fields, pet.FieldOwnerID)
}
return fields
}
@@ -620,6 +659,8 @@ func (m *PetMutation) Field(name string) (ent.Value, bool) {
switch name {
case pet.FieldName:
return m.Name()
case pet.FieldOwnerID:
return m.OwnerID()
}
return nil, false
}
@@ -631,6 +672,8 @@ func (m *PetMutation) OldField(ctx context.Context, name string) (ent.Value, err
switch name {
case pet.FieldName:
return m.OldName(ctx)
case pet.FieldOwnerID:
return m.OldOwnerID(ctx)
}
return nil, fmt.Errorf("unknown Pet field %s", name)
}
@@ -647,6 +690,13 @@ func (m *PetMutation) SetField(name string, value ent.Value) error {
}
m.SetName(v)
return nil
case pet.FieldOwnerID:
v, ok := value.(int)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetOwnerID(v)
return nil
}
return fmt.Errorf("unknown Pet field %s", name)
}
@@ -654,13 +704,16 @@ func (m *PetMutation) SetField(name string, value ent.Value) error {
// AddedFields returns all numeric fields that were incremented/decremented during
// this mutation.
func (m *PetMutation) AddedFields() []string {
return nil
var fields []string
return fields
}
// AddedField returns the numeric value that was incremented/decremented on a field
// with the given name. The second boolean return value indicates that this field
// was not set, or was not defined in the schema.
func (m *PetMutation) AddedField(name string) (ent.Value, bool) {
switch name {
}
return nil, false
}
@@ -676,7 +729,11 @@ func (m *PetMutation) AddField(name string, value ent.Value) error {
// ClearedFields returns all nullable fields that were cleared during this
// mutation.
func (m *PetMutation) ClearedFields() []string {
return nil
var fields []string
if m.FieldCleared(pet.FieldOwnerID) {
fields = append(fields, pet.FieldOwnerID)
}
return fields
}
// FieldCleared returns a boolean indicating if a field with the given name was
@@ -689,6 +746,11 @@ func (m *PetMutation) FieldCleared(name string) bool {
// ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema.
func (m *PetMutation) ClearField(name string) error {
switch name {
case pet.FieldOwnerID:
m.ClearOwnerID()
return nil
}
return fmt.Errorf("unknown Pet nullable field %s", name)
}
@@ -699,6 +761,9 @@ func (m *PetMutation) ResetField(name string) error {
case pet.FieldName:
m.ResetName()
return nil
case pet.FieldOwnerID:
m.ResetOwnerID()
return nil
}
return fmt.Errorf("unknown Pet field %s", name)
}

View File

@@ -22,10 +22,11 @@ type Pet struct {
ID int `json:"id,omitempty"`
// Name holds the value of the "name" field.
Name string `json:"name,omitempty"`
// OwnerID holds the value of the "owner_id" field.
OwnerID int `json:"owner_id,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the PetQuery when eager-loading is set.
Edges PetEdges `json:"edges"`
user_pets *int
Edges PetEdges `json:"edges"`
}
// PetEdges holds the relations/edges for other nodes in the graph.
@@ -56,12 +57,10 @@ func (*Pet) scanValues(columns []string) ([]interface{}, error) {
values := make([]interface{}, len(columns))
for i := range columns {
switch columns[i] {
case pet.FieldID:
case pet.FieldID, pet.FieldOwnerID:
values[i] = new(sql.NullInt64)
case pet.FieldName:
values[i] = new(sql.NullString)
case pet.ForeignKeys[0]: // user_pets
values[i] = new(sql.NullInt64)
default:
return nil, fmt.Errorf("unexpected column %q for type Pet", columns[i])
}
@@ -89,12 +88,11 @@ func (pe *Pet) assignValues(columns []string, values []interface{}) error {
} else if value.Valid {
pe.Name = value.String
}
case pet.ForeignKeys[0]:
case pet.FieldOwnerID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for edge-field user_pets", value)
return fmt.Errorf("unexpected type %T for field owner_id", values[i])
} else if value.Valid {
pe.user_pets = new(int)
*pe.user_pets = int(value.Int64)
pe.OwnerID = int(value.Int64)
}
}
}
@@ -131,6 +129,8 @@ func (pe *Pet) String() string {
builder.WriteString(fmt.Sprintf("id=%v", pe.ID))
builder.WriteString(", name=")
builder.WriteString(pe.Name)
builder.WriteString(", owner_id=")
builder.WriteString(fmt.Sprintf("%v", pe.OwnerID))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -13,6 +13,8 @@ const (
FieldID = "id"
// FieldName holds the string denoting the name field in the database.
FieldName = "name"
// FieldOwnerID holds the string denoting the owner_id field in the database.
FieldOwnerID = "owner_id"
// EdgeOwner holds the string denoting the owner edge name in mutations.
EdgeOwner = "owner"
// Table holds the table name of the pet in the database.
@@ -23,19 +25,14 @@ const (
// It exists in this package in order to avoid circular dependency with the "user" package.
OwnerInverseTable = "users"
// OwnerColumn is the table column denoting the owner relation/edge.
OwnerColumn = "user_pets"
OwnerColumn = "owner_id"
)
// Columns holds all SQL columns for pet fields.
var Columns = []string{
FieldID,
FieldName,
}
// ForeignKeys holds the SQL foreign-keys that are owned by the "pets"
// table and are not defined as standalone fields in the schema.
var ForeignKeys = []string{
"user_pets",
FieldOwnerID,
}
// ValidColumn reports if the column name is valid (part of the table columns).
@@ -45,11 +42,6 @@ func ValidColumn(column string) bool {
return true
}
}
for i := range ForeignKeys {
if column == ForeignKeys[i] {
return true
}
}
return false
}

View File

@@ -103,6 +103,13 @@ func Name(v string) predicate.Pet {
})
}
// OwnerID applies equality check predicate on the "owner_id" field. It's identical to OwnerIDEQ.
func OwnerID(v int) predicate.Pet {
return predicate.Pet(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldOwnerID), v))
})
}
// NameEQ applies the EQ predicate on the "name" field.
func NameEQ(v string) predicate.Pet {
return predicate.Pet(func(s *sql.Selector) {
@@ -214,6 +221,68 @@ func NameContainsFold(v string) predicate.Pet {
})
}
// OwnerIDEQ applies the EQ predicate on the "owner_id" field.
func OwnerIDEQ(v int) predicate.Pet {
return predicate.Pet(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldOwnerID), v))
})
}
// OwnerIDNEQ applies the NEQ predicate on the "owner_id" field.
func OwnerIDNEQ(v int) predicate.Pet {
return predicate.Pet(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldOwnerID), v))
})
}
// OwnerIDIn applies the In predicate on the "owner_id" field.
func OwnerIDIn(vs ...int) predicate.Pet {
v := make([]interface{}, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Pet(func(s *sql.Selector) {
// if not arguments were provided, append the FALSE constants,
// since we can't apply "IN ()". This will make this predicate falsy.
if len(v) == 0 {
s.Where(sql.False())
return
}
s.Where(sql.In(s.C(FieldOwnerID), v...))
})
}
// OwnerIDNotIn applies the NotIn predicate on the "owner_id" field.
func OwnerIDNotIn(vs ...int) predicate.Pet {
v := make([]interface{}, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Pet(func(s *sql.Selector) {
// if not arguments were provided, append the FALSE constants,
// since we can't apply "IN ()". This will make this predicate falsy.
if len(v) == 0 {
s.Where(sql.False())
return
}
s.Where(sql.NotIn(s.C(FieldOwnerID), v...))
})
}
// OwnerIDIsNil applies the IsNil predicate on the "owner_id" field.
func OwnerIDIsNil() predicate.Pet {
return predicate.Pet(func(s *sql.Selector) {
s.Where(sql.IsNull(s.C(FieldOwnerID)))
})
}
// OwnerIDNotNil applies the NotNil predicate on the "owner_id" field.
func OwnerIDNotNil() predicate.Pet {
return predicate.Pet(func(s *sql.Selector) {
s.Where(sql.NotNull(s.C(FieldOwnerID)))
})
}
// HasOwner applies the HasEdge predicate on the "owner" edge.
func HasOwner() predicate.Pet {
return predicate.Pet(func(s *sql.Selector) {

View File

@@ -38,16 +38,16 @@ func (pc *PetCreate) SetNillableName(s *string) *PetCreate {
return pc
}
// SetOwnerID sets the "owner" edge to the User entity by ID.
func (pc *PetCreate) SetOwnerID(id int) *PetCreate {
pc.mutation.SetOwnerID(id)
// SetOwnerID sets the "owner_id" field.
func (pc *PetCreate) SetOwnerID(i int) *PetCreate {
pc.mutation.SetOwnerID(i)
return pc
}
// SetNillableOwnerID sets the "owner" edge to the User entity by ID if the given value is not nil.
func (pc *PetCreate) SetNillableOwnerID(id *int) *PetCreate {
if id != nil {
pc = pc.SetOwnerID(*id)
// SetNillableOwnerID sets the "owner_id" field if the given value is not nil.
func (pc *PetCreate) SetNillableOwnerID(i *int) *PetCreate {
if i != nil {
pc.SetOwnerID(*i)
}
return pc
}
@@ -193,7 +193,7 @@ func (pc *PetCreate) createSpec() (*Pet, *sqlgraph.CreateSpec) {
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.user_pets = &nodes[0]
_node.OwnerID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec

View File

@@ -32,7 +32,7 @@ type PetQuery struct {
predicates []predicate.Pet
// eager-loading edges.
withOwner *UserQuery
withFKs bool
modifiers []func(s *sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@@ -357,18 +357,11 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error {
func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) {
var (
nodes = []*Pet{}
withFKs = pq.withFKs
_spec = pq.querySpec()
loadedTypes = [1]bool{
pq.withOwner != nil,
}
)
if pq.withOwner != nil {
withFKs = true
}
if withFKs {
_spec.Node.Columns = append(_spec.Node.Columns, pet.ForeignKeys...)
}
_spec.ScanValues = func(columns []string) ([]interface{}, error) {
node := &Pet{config: pq.config}
nodes = append(nodes, node)
@@ -384,6 +377,9 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) {
}
_spec.Node.Schema = pq.schemaConfig.Pet
ctx = internal.NewSchemaConfigContext(ctx, pq.schemaConfig)
if len(pq.modifiers) > 0 {
_spec.Modifiers = pq.modifiers
}
if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil {
return nil, err
}
@@ -395,10 +391,7 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) {
ids := make([]int, 0, len(nodes))
nodeids := make(map[int][]*Pet)
for i := range nodes {
if nodes[i].user_pets == nil {
continue
}
fk := *nodes[i].user_pets
fk := nodes[i].OwnerID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
@@ -412,7 +405,7 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) {
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v`, n.ID)
return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID)
}
for i := range nodes {
nodes[i].Edges.Owner = n
@@ -427,6 +420,9 @@ func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) {
_spec := pq.querySpec()
_spec.Node.Schema = pq.schemaConfig.Pet
ctx = internal.NewSchemaConfigContext(ctx, pq.schemaConfig)
if len(pq.modifiers) > 0 {
_spec.Modifiers = pq.modifiers
}
_spec.Node.Columns = pq.fields
if len(pq.fields) > 0 {
_spec.Unique = pq.unique != nil && *pq.unique
@@ -508,6 +504,9 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector {
t1.Schema(pq.schemaConfig.Pet)
ctx = internal.NewSchemaConfigContext(ctx, pq.schemaConfig)
selector.WithContext(ctx)
for _, m := range pq.modifiers {
m(selector)
}
for _, p := range pq.predicates {
p(selector)
}
@@ -525,6 +524,12 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// Modify adds a query modifier for attaching custom logic to queries.
func (pq *PetQuery) Modify(modifiers ...func(s *sql.Selector)) *PetSelect {
pq.modifiers = append(pq.modifiers, modifiers...)
return pq.Select()
}
// PetGroupBy is the group-by builder for Pet entities.
type PetGroupBy struct {
config
@@ -1014,3 +1019,9 @@ func (ps *PetSelect) sqlScan(ctx context.Context, v interface{}) error {
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Modify adds a query modifier for attaching custom logic to queries.
func (ps *PetSelect) Modify(modifiers ...func(s *sql.Selector)) *PetSelect {
ps.modifiers = append(ps.modifiers, modifiers...)
return ps
}

View File

@@ -47,20 +47,26 @@ func (pu *PetUpdate) SetNillableName(s *string) *PetUpdate {
return pu
}
// SetOwnerID sets the "owner" edge to the User entity by ID.
func (pu *PetUpdate) SetOwnerID(id int) *PetUpdate {
pu.mutation.SetOwnerID(id)
// SetOwnerID sets the "owner_id" field.
func (pu *PetUpdate) SetOwnerID(i int) *PetUpdate {
pu.mutation.SetOwnerID(i)
return pu
}
// SetNillableOwnerID sets the "owner" edge to the User entity by ID if the given value is not nil.
func (pu *PetUpdate) SetNillableOwnerID(id *int) *PetUpdate {
if id != nil {
pu = pu.SetOwnerID(*id)
// SetNillableOwnerID sets the "owner_id" field if the given value is not nil.
func (pu *PetUpdate) SetNillableOwnerID(i *int) *PetUpdate {
if i != nil {
pu.SetOwnerID(*i)
}
return pu
}
// ClearOwnerID clears the value of the "owner_id" field.
func (pu *PetUpdate) ClearOwnerID() *PetUpdate {
pu.mutation.ClearOwnerID()
return pu
}
// SetOwner sets the "owner" edge to the User entity.
func (pu *PetUpdate) SetOwner(u *User) *PetUpdate {
return pu.SetOwnerID(u.ID)
@@ -228,20 +234,26 @@ func (puo *PetUpdateOne) SetNillableName(s *string) *PetUpdateOne {
return puo
}
// SetOwnerID sets the "owner" edge to the User entity by ID.
func (puo *PetUpdateOne) SetOwnerID(id int) *PetUpdateOne {
puo.mutation.SetOwnerID(id)
// SetOwnerID sets the "owner_id" field.
func (puo *PetUpdateOne) SetOwnerID(i int) *PetUpdateOne {
puo.mutation.SetOwnerID(i)
return puo
}
// SetNillableOwnerID sets the "owner" edge to the User entity by ID if the given value is not nil.
func (puo *PetUpdateOne) SetNillableOwnerID(id *int) *PetUpdateOne {
if id != nil {
puo = puo.SetOwnerID(*id)
// SetNillableOwnerID sets the "owner_id" field if the given value is not nil.
func (puo *PetUpdateOne) SetNillableOwnerID(i *int) *PetUpdateOne {
if i != nil {
puo.SetOwnerID(*i)
}
return puo
}
// ClearOwnerID clears the value of the "owner_id" field.
func (puo *PetUpdateOne) ClearOwnerID() *PetUpdateOne {
puo.mutation.ClearOwnerID()
return puo
}
// SetOwner sets the "owner" edge to the User entity.
func (puo *PetUpdateOne) SetOwner(u *User) *PetUpdateOne {
return puo.SetOwnerID(u.ID)

View File

@@ -20,6 +20,8 @@ func (Pet) Fields() []ent.Field {
return []ent.Field{
field.String("name").
Default("unknown"),
field.Int("owner_id").
Optional(),
}
}
@@ -28,6 +30,7 @@ func (Pet) Edges() []ent.Edge {
return []ent.Edge{
edge.From("owner", User.Type).
Ref("pets").
Field("owner_id").
Unique(),
}
}

View File

@@ -0,0 +1,16 @@
{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}
{{/* The line below tells Intellij/GoLand to enable the autocompletion based *gen.Graph type. */}}
{{/* gotype: entgo.io/ent/entc/gen.Graph */}}
{{ define "config/options/schemaconfig" }}
// SchemaConfigFromContext exports the internal.SchemaConfigFromContext
// for external usage (inside custom predicates or modifiers).
func SchemaConfigFromContext(ctx context.Context) SchemaConfig {
return internal.SchemaConfigFromContext(ctx)
}
{{ end }}

View File

@@ -25,7 +25,7 @@ const (
// It exists in this package in order to avoid circular dependency with the "pet" package.
PetsInverseTable = "pets"
// PetsColumn is the table column denoting the pets relation/edge.
PetsColumn = "user_pets"
PetsColumn = "owner_id"
// GroupsTable is the table that holds the groups relation/edge. The primary key declared below.
GroupsTable = "group_users"
// GroupsInverseTable is the table name for the Group entity.

View File

@@ -35,6 +35,7 @@ type UserQuery struct {
// eager-loading edges.
withPets *PetQuery
withGroups *GroupQuery
modifiers []func(s *sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@@ -417,6 +418,9 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
}
_spec.Node.Schema = uq.schemaConfig.User
ctx = internal.NewSchemaConfigContext(ctx, uq.schemaConfig)
if len(uq.modifiers) > 0 {
_spec.Modifiers = uq.modifiers
}
if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil {
return nil, err
}
@@ -432,7 +436,6 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
nodeids[nodes[i].ID] = nodes[i]
nodes[i].Edges.Pets = []*Pet{}
}
query.withFKs = true
query.Where(predicate.Pet(func(s *sql.Selector) {
s.Where(sql.InValues(user.PetsColumn, fks...))
}))
@@ -441,13 +444,10 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) {
return nil, err
}
for _, n := range neighbors {
fk := n.user_pets
if fk == nil {
return nil, fmt.Errorf(`foreign-key "user_pets" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
fk := n.OwnerID
node, ok := nodeids[fk]
if !ok {
return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v for node %v`, *fk, n.ID)
return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v for node %v`, fk, n.ID)
}
node.Edges.Pets = append(node.Edges.Pets, n)
}
@@ -526,6 +526,9 @@ func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) {
_spec := uq.querySpec()
_spec.Node.Schema = uq.schemaConfig.User
ctx = internal.NewSchemaConfigContext(ctx, uq.schemaConfig)
if len(uq.modifiers) > 0 {
_spec.Modifiers = uq.modifiers
}
_spec.Node.Columns = uq.fields
if len(uq.fields) > 0 {
_spec.Unique = uq.unique != nil && *uq.unique
@@ -607,6 +610,9 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
t1.Schema(uq.schemaConfig.User)
ctx = internal.NewSchemaConfigContext(ctx, uq.schemaConfig)
selector.WithContext(ctx)
for _, m := range uq.modifiers {
m(selector)
}
for _, p := range uq.predicates {
p(selector)
}
@@ -624,6 +630,12 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// Modify adds a query modifier for attaching custom logic to queries.
func (uq *UserQuery) Modify(modifiers ...func(s *sql.Selector)) *UserSelect {
uq.modifiers = append(uq.modifiers, modifiers...)
return uq.Select()
}
// UserGroupBy is the group-by builder for User entities.
type UserGroupBy struct {
config
@@ -1113,3 +1125,9 @@ func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error {
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Modify adds a query modifier for attaching custom logic to queries.
func (us *UserSelect) Modify(modifiers ...func(s *sql.Selector)) *UserSelect {
us.modifiers = append(us.modifiers, modifiers...)
return us
}

View File

@@ -45,7 +45,29 @@ func TestMySQL(t *testing.T) {
client.Group.Create().SetName("GitHub"),
client.Group.Create().SetName("GitLab"),
).SaveX(ctx)
usr := client.User.Create().AddPets(pedro).AddGroups(groups...).SaveX(ctx)
usr := client.User.Create().SetName("a8m").AddPets(pedro).AddGroups(groups...).SaveX(ctx)
// Custom modifier with schema config.
var names []struct {
User string `sql:"user_name"`
Pet string `sql:"pet_name"`
}
client.Pet.Query().
Modify(func(s *sql.Selector) {
// The below function is exported using a custom
// template defined in ent/template/config.tmpl.
cfg := ent.SchemaConfigFromContext(s.Context())
t := sql.Table(user.Table).Schema(cfg.User)
s.Join(t).On(s.C(pet.FieldOwnerID), t.C(user.FieldID))
s.Select(
sql.As(t.C(user.FieldName), "user_name"),
sql.As(s.C(pet.FieldName), "pet_name"),
)
}).
ScanX(ctx, &names)
require.Len(t, names, 1)
require.Equal(t, "a8m", names[0].User)
require.Equal(t, "Pedro", names[0].Pet)
id := client.Group.Query().
Where(group.HasUsersWith(user.ID(usr.ID))).