From 01f97334af13857dd9bd11d73ab3d1f78f4ab8b1 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Fri, 15 Jul 2022 11:01:13 +0300 Subject: [PATCH] entc/gen: use custom schema-type in join tables foreign-keys (#2760) Fixed https://github.com/ent/ent/issues/2753 --- entc/gen/graph.go | 4 +- entc/integration/customid/ent/client.go | 16 ++ entc/integration/customid/ent/doc.go | 18 +- entc/integration/customid/ent/doc/doc.go | 10 + entc/integration/customid/ent/doc/where.go | 28 +++ entc/integration/customid/ent/doc_create.go | 34 ++++ entc/integration/customid/ent/doc_query.go | 91 ++++++++- entc/integration/customid/ent/doc_update.go | 180 ++++++++++++++++++ entc/integration/customid/ent/entql.go | 26 +++ .../customid/ent/migrate/schema.go | 32 +++- entc/integration/customid/ent/mutation.go | 89 ++++++++- entc/integration/customid/ent/schema/doc.go | 7 + .../integration/edgeschema/ent/schema/role.go | 3 +- .../edgeschema/ent/schema/role_user.go | 3 +- 14 files changed, 529 insertions(+), 12 deletions(-) diff --git a/entc/gen/graph.go b/entc/gen/graph.go index 64bc5421a..8e15725a5 100644 --- a/entc/gen/graph.go +++ b/entc/gen/graph.go @@ -623,12 +623,12 @@ func (g *Graph) Tables() (all []*schema.Table, err error) { continue } t1, t2 := tables[n.Table()], tables[e.Type.Table()] - c1 := &schema.Column{Name: e.Rel.Columns[0], Type: field.TypeInt} + c1 := &schema.Column{Name: e.Rel.Columns[0], Type: field.TypeInt, SchemaType: n.ID.def.SchemaType} if ref := n.ID; ref.UserDefined { c1.Type = ref.Type.Type c1.Size = ref.size() } - c2 := &schema.Column{Name: e.Rel.Columns[1], Type: field.TypeInt} + c2 := &schema.Column{Name: e.Rel.Columns[1], Type: field.TypeInt, SchemaType: e.Type.ID.def.SchemaType} if ref := e.Type.ID; ref.UserDefined { c2.Type = ref.Type.Type c2.Size = ref.size() diff --git a/entc/integration/customid/ent/client.go b/entc/integration/customid/ent/client.go index e0ec3d42a..c18c830b7 100644 --- a/entc/integration/customid/ent/client.go +++ b/entc/integration/customid/ent/client.go @@ -897,6 +897,22 @@ func (c *DocClient) QueryChildren(d *Doc) *DocQuery { return query } +// QueryRelated queries the related edge of a Doc. +func (c *DocClient) QueryRelated(d *Doc) *DocQuery { + query := &DocQuery{config: c.config} + query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + id := d.ID + step := sqlgraph.NewStep( + sqlgraph.From(doc.Table, doc.FieldID, id), + sqlgraph.To(doc.Table, doc.FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, doc.RelatedTable, doc.RelatedPrimaryKey...), + ) + fromV = sqlgraph.Neighbors(d.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *DocClient) Hooks() []Hook { return c.hooks.Doc diff --git a/entc/integration/customid/ent/doc.go b/entc/integration/customid/ent/doc.go index d746b05a1..58b41a0c4 100644 --- a/entc/integration/customid/ent/doc.go +++ b/entc/integration/customid/ent/doc.go @@ -34,9 +34,11 @@ type DocEdges struct { Parent *Doc `json:"parent,omitempty"` // Children holds the value of the children edge. Children []*Doc `json:"children,omitempty"` + // Related holds the value of the related edge. + Related []*Doc `json:"related,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [2]bool + loadedTypes [3]bool } // ParentOrErr returns the Parent value or an error if the edge @@ -62,6 +64,15 @@ func (e DocEdges) ChildrenOrErr() ([]*Doc, error) { return nil, &NotLoadedError{edge: "children"} } +// RelatedOrErr returns the Related value or an error if the edge +// was not loaded in eager-loading. +func (e DocEdges) RelatedOrErr() ([]*Doc, error) { + if e.loadedTypes[2] { + return e.Related, nil + } + return nil, &NotLoadedError{edge: "related"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*Doc) scanValues(columns []string) ([]interface{}, error) { values := make([]interface{}, len(columns)) @@ -122,6 +133,11 @@ func (d *Doc) QueryChildren() *DocQuery { return (&DocClient{config: d.config}).QueryChildren(d) } +// QueryRelated queries the "related" edge of the Doc entity. +func (d *Doc) QueryRelated() *DocQuery { + return (&DocClient{config: d.config}).QueryRelated(d) +} + // Update returns a builder for updating this Doc. // Note that you need to call Doc.Unwrap() before calling this method if this Doc // was returned from a transaction, and the transaction was committed or rolled back. diff --git a/entc/integration/customid/ent/doc/doc.go b/entc/integration/customid/ent/doc/doc.go index bc4e9ef35..c094a4455 100644 --- a/entc/integration/customid/ent/doc/doc.go +++ b/entc/integration/customid/ent/doc/doc.go @@ -21,6 +21,8 @@ const ( EdgeParent = "parent" // EdgeChildren holds the string denoting the children edge name in mutations. EdgeChildren = "children" + // EdgeRelated holds the string denoting the related edge name in mutations. + EdgeRelated = "related" // Table holds the table name of the doc in the database. Table = "docs" // ParentTable is the table that holds the parent relation/edge. @@ -31,6 +33,8 @@ const ( ChildrenTable = "docs" // ChildrenColumn is the table column denoting the children relation/edge. ChildrenColumn = "doc_children" + // RelatedTable is the table that holds the related relation/edge. The primary key declared below. + RelatedTable = "doc_related" ) // Columns holds all SQL columns for doc fields. @@ -45,6 +49,12 @@ var ForeignKeys = []string{ "doc_children", } +var ( + // RelatedPrimaryKey and RelatedColumn2 are the table columns denoting the + // primary key for the related relation (M2M). + RelatedPrimaryKey = []string{"doc_id", "related_id"} +) + // ValidColumn reports if the column name is valid (part of the table columns). func ValidColumn(column string) bool { for i := range Columns { diff --git a/entc/integration/customid/ent/doc/where.go b/entc/integration/customid/ent/doc/where.go index 79b91b7c4..bf2172f87 100644 --- a/entc/integration/customid/ent/doc/where.go +++ b/entc/integration/customid/ent/doc/where.go @@ -260,6 +260,34 @@ func HasChildrenWith(preds ...predicate.Doc) predicate.Doc { }) } +// HasRelated applies the HasEdge predicate on the "related" edge. +func HasRelated() predicate.Doc { + return predicate.Doc(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(RelatedTable, FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, RelatedTable, RelatedPrimaryKey...), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasRelatedWith applies the HasEdge predicate on the "related" edge with a given conditions (other predicates). +func HasRelatedWith(preds ...predicate.Doc) predicate.Doc { + return predicate.Doc(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, RelatedTable, RelatedPrimaryKey...), + ) + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.Doc) predicate.Doc { return predicate.Doc(func(s *sql.Selector) { diff --git a/entc/integration/customid/ent/doc_create.go b/entc/integration/customid/ent/doc_create.go index 89dc570bf..eb6112bdf 100644 --- a/entc/integration/customid/ent/doc_create.go +++ b/entc/integration/customid/ent/doc_create.go @@ -89,6 +89,21 @@ func (dc *DocCreate) AddChildren(d ...*Doc) *DocCreate { return dc.AddChildIDs(ids...) } +// AddRelatedIDs adds the "related" edge to the Doc entity by IDs. +func (dc *DocCreate) AddRelatedIDs(ids ...schema.DocID) *DocCreate { + dc.mutation.AddRelatedIDs(ids...) + return dc +} + +// AddRelated adds the "related" edges to the Doc entity. +func (dc *DocCreate) AddRelated(d ...*Doc) *DocCreate { + ids := make([]schema.DocID, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return dc.AddRelatedIDs(ids...) +} + // Mutation returns the DocMutation object of the builder. func (dc *DocCreate) Mutation() *DocMutation { return dc.mutation @@ -263,6 +278,25 @@ func (dc *DocCreate) createSpec() (*Doc, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := dc.mutation.RelatedIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: doc.RelatedTable, + Columns: doc.RelatedPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeString, + Column: doc.FieldID, + }, + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/entc/integration/customid/ent/doc_query.go b/entc/integration/customid/ent/doc_query.go index 991fa30cb..24b6f343b 100644 --- a/entc/integration/customid/ent/doc_query.go +++ b/entc/integration/customid/ent/doc_query.go @@ -32,6 +32,7 @@ type DocQuery struct { // eager-loading edges. withParent *DocQuery withChildren *DocQuery + withRelated *DocQuery withFKs bool // intermediate query (i.e. traversal path). sql *sql.Selector @@ -113,6 +114,28 @@ func (dq *DocQuery) QueryChildren() *DocQuery { return query } +// QueryRelated chains the current query on the "related" edge. +func (dq *DocQuery) QueryRelated() *DocQuery { + query := &DocQuery{config: dq.config} + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := dq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := dq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(doc.Table, doc.FieldID, selector), + sqlgraph.To(doc.Table, doc.FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, doc.RelatedTable, doc.RelatedPrimaryKey...), + ) + fromU = sqlgraph.SetNeighbors(dq.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first Doc entity from the query. // Returns a *NotFoundError when no Doc was found. func (dq *DocQuery) First(ctx context.Context) (*Doc, error) { @@ -296,6 +319,7 @@ func (dq *DocQuery) Clone() *DocQuery { predicates: append([]predicate.Doc{}, dq.predicates...), withParent: dq.withParent.Clone(), withChildren: dq.withChildren.Clone(), + withRelated: dq.withRelated.Clone(), // clone intermediate query. sql: dq.sql.Clone(), path: dq.path, @@ -325,6 +349,17 @@ func (dq *DocQuery) WithChildren(opts ...func(*DocQuery)) *DocQuery { return dq } +// WithRelated tells the query-builder to eager-load the nodes that are connected to +// the "related" edge. The optional arguments are used to configure the query builder of the edge. +func (dq *DocQuery) WithRelated(opts ...func(*DocQuery)) *DocQuery { + query := &DocQuery{config: dq.config} + for _, opt := range opts { + opt(query) + } + dq.withRelated = query + return dq +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -396,9 +431,10 @@ func (dq *DocQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Doc, err nodes = []*Doc{} withFKs = dq.withFKs _spec = dq.querySpec() - loadedTypes = [2]bool{ + loadedTypes = [3]bool{ dq.withParent != nil, dq.withChildren != nil, + dq.withRelated != nil, } ) if dq.withParent != nil { @@ -484,6 +520,59 @@ func (dq *DocQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Doc, err } } + if query := dq.withRelated; query != nil { + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[schema.DocID]*Doc) + nids := make(map[schema.DocID]map[*Doc]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node + node.Edges.Related = []*Doc{} + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(doc.RelatedTable) + s.Join(joinT).On(s.C(doc.FieldID), joinT.C(doc.RelatedPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(doc.RelatedPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(doc.RelatedPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]interface{}{new(schema.DocID)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := *values[0].(*schema.DocID) + inValue := *values[1].(*schema.DocID) + if nids[inValue] == nil { + nids[inValue] = map[*Doc]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) + } + nids[inValue][byid[outValue]] = struct{}{} + return nil + } + }) + if err != nil { + return nil, err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return nil, fmt.Errorf(`unexpected "related" node returned %v`, n.ID) + } + for kn := range nodes { + kn.Edges.Related = append(kn.Edges.Related, n) + } + } + } + return nodes, nil } diff --git a/entc/integration/customid/ent/doc_update.go b/entc/integration/customid/ent/doc_update.go index 16589a318..7facb09ba 100644 --- a/entc/integration/customid/ent/doc_update.go +++ b/entc/integration/customid/ent/doc_update.go @@ -86,6 +86,21 @@ func (du *DocUpdate) AddChildren(d ...*Doc) *DocUpdate { return du.AddChildIDs(ids...) } +// AddRelatedIDs adds the "related" edge to the Doc entity by IDs. +func (du *DocUpdate) AddRelatedIDs(ids ...schema.DocID) *DocUpdate { + du.mutation.AddRelatedIDs(ids...) + return du +} + +// AddRelated adds the "related" edges to the Doc entity. +func (du *DocUpdate) AddRelated(d ...*Doc) *DocUpdate { + ids := make([]schema.DocID, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return du.AddRelatedIDs(ids...) +} + // Mutation returns the DocMutation object of the builder. func (du *DocUpdate) Mutation() *DocMutation { return du.mutation @@ -118,6 +133,27 @@ func (du *DocUpdate) RemoveChildren(d ...*Doc) *DocUpdate { return du.RemoveChildIDs(ids...) } +// ClearRelated clears all "related" edges to the Doc entity. +func (du *DocUpdate) ClearRelated() *DocUpdate { + du.mutation.ClearRelated() + return du +} + +// RemoveRelatedIDs removes the "related" edge to Doc entities by IDs. +func (du *DocUpdate) RemoveRelatedIDs(ids ...schema.DocID) *DocUpdate { + du.mutation.RemoveRelatedIDs(ids...) + return du +} + +// RemoveRelated removes "related" edges to Doc entities. +func (du *DocUpdate) RemoveRelated(d ...*Doc) *DocUpdate { + ids := make([]schema.DocID, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return du.RemoveRelatedIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (du *DocUpdate) Save(ctx context.Context) (int, error) { var ( @@ -292,6 +328,60 @@ func (du *DocUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if du.mutation.RelatedCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: doc.RelatedTable, + Columns: doc.RelatedPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeString, + Column: doc.FieldID, + }, + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := du.mutation.RemovedRelatedIDs(); len(nodes) > 0 && !du.mutation.RelatedCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: doc.RelatedTable, + Columns: doc.RelatedPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeString, + Column: doc.FieldID, + }, + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := du.mutation.RelatedIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: doc.RelatedTable, + Columns: doc.RelatedPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeString, + Column: doc.FieldID, + }, + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if n, err = sqlgraph.UpdateNodes(ctx, du.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{doc.Label} @@ -365,6 +455,21 @@ func (duo *DocUpdateOne) AddChildren(d ...*Doc) *DocUpdateOne { return duo.AddChildIDs(ids...) } +// AddRelatedIDs adds the "related" edge to the Doc entity by IDs. +func (duo *DocUpdateOne) AddRelatedIDs(ids ...schema.DocID) *DocUpdateOne { + duo.mutation.AddRelatedIDs(ids...) + return duo +} + +// AddRelated adds the "related" edges to the Doc entity. +func (duo *DocUpdateOne) AddRelated(d ...*Doc) *DocUpdateOne { + ids := make([]schema.DocID, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return duo.AddRelatedIDs(ids...) +} + // Mutation returns the DocMutation object of the builder. func (duo *DocUpdateOne) Mutation() *DocMutation { return duo.mutation @@ -397,6 +502,27 @@ func (duo *DocUpdateOne) RemoveChildren(d ...*Doc) *DocUpdateOne { return duo.RemoveChildIDs(ids...) } +// ClearRelated clears all "related" edges to the Doc entity. +func (duo *DocUpdateOne) ClearRelated() *DocUpdateOne { + duo.mutation.ClearRelated() + return duo +} + +// RemoveRelatedIDs removes the "related" edge to Doc entities by IDs. +func (duo *DocUpdateOne) RemoveRelatedIDs(ids ...schema.DocID) *DocUpdateOne { + duo.mutation.RemoveRelatedIDs(ids...) + return duo +} + +// RemoveRelated removes "related" edges to Doc entities. +func (duo *DocUpdateOne) RemoveRelated(d ...*Doc) *DocUpdateOne { + ids := make([]schema.DocID, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return duo.RemoveRelatedIDs(ids...) +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (duo *DocUpdateOne) Select(field string, fields ...string) *DocUpdateOne { @@ -601,6 +727,60 @@ func (duo *DocUpdateOne) sqlSave(ctx context.Context) (_node *Doc, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if duo.mutation.RelatedCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: doc.RelatedTable, + Columns: doc.RelatedPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeString, + Column: doc.FieldID, + }, + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := duo.mutation.RemovedRelatedIDs(); len(nodes) > 0 && !duo.mutation.RelatedCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: doc.RelatedTable, + Columns: doc.RelatedPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeString, + Column: doc.FieldID, + }, + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := duo.mutation.RelatedIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: doc.RelatedTable, + Columns: doc.RelatedPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeString, + Column: doc.FieldID, + }, + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &Doc{config: duo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/entc/integration/customid/ent/entql.go b/entc/integration/customid/ent/entql.go index 35093f734..8e7863874 100644 --- a/entc/integration/customid/ent/entql.go +++ b/entc/integration/customid/ent/entql.go @@ -376,6 +376,18 @@ var schemaGraph = func() *sqlgraph.Schema { "Doc", "Doc", ) + graph.MustAddE( + "related", + &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: doc.RelatedTable, + Columns: doc.RelatedPrimaryKey, + Bidi: true, + }, + "Doc", + "Doc", + ) graph.MustAddE( "users", &sqlgraph.EdgeSpec{ @@ -1004,6 +1016,20 @@ func (f *DocFilter) WhereHasChildrenWith(preds ...predicate.Doc) { }))) } +// WhereHasRelated applies a predicate to check if query has an edge related. +func (f *DocFilter) WhereHasRelated() { + f.Where(entql.HasEdge("related")) +} + +// WhereHasRelatedWith applies a predicate to check if query has an edge related with a given conditions (other predicates). +func (f *DocFilter) WhereHasRelatedWith(preds ...predicate.Doc) { + f.Where(entql.HasEdgeWith("related", sqlgraph.WrapFunc(func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }))) +} + // addPredicate implements the predicateAdder interface. func (gq *GroupQuery) addPredicate(pred func(s *sql.Selector)) { gq.predicates = append(gq.predicates, pred) diff --git a/entc/integration/customid/ent/migrate/schema.go b/entc/integration/customid/ent/migrate/schema.go index b08f3bbed..0d05dcc31 100644 --- a/entc/integration/customid/ent/migrate/schema.go +++ b/entc/integration/customid/ent/migrate/schema.go @@ -113,9 +113,9 @@ var ( } // DocsColumns holds the columns for the "docs" table. DocsColumns = []*schema.Column{ - {Name: "id", Type: field.TypeString, Unique: true, Size: 36}, + {Name: "id", Type: field.TypeString, Unique: true, Size: 36, SchemaType: map[string]string{"postgres": "uuid"}}, {Name: "text", Type: field.TypeString, Nullable: true}, - {Name: "doc_children", Type: field.TypeString, Nullable: true, Size: 36}, + {Name: "doc_children", Type: field.TypeString, Nullable: true, Size: 36, SchemaType: map[string]string{"postgres": "uuid"}}, } // DocsTable holds the schema information for the "docs" table. DocsTable = &schema.Table{ @@ -318,6 +318,31 @@ var ( }, }, } + // DocRelatedColumns holds the columns for the "doc_related" table. + DocRelatedColumns = []*schema.Column{ + {Name: "doc_id", Type: field.TypeString, Size: 36, SchemaType: map[string]string{"postgres": "uuid"}}, + {Name: "related_id", Type: field.TypeString, Size: 36, SchemaType: map[string]string{"postgres": "uuid"}}, + } + // DocRelatedTable holds the schema information for the "doc_related" table. + DocRelatedTable = &schema.Table{ + Name: "doc_related", + Columns: DocRelatedColumns, + PrimaryKey: []*schema.Column{DocRelatedColumns[0], DocRelatedColumns[1]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "doc_related_doc_id", + Columns: []*schema.Column{DocRelatedColumns[0]}, + RefColumns: []*schema.Column{DocsColumns[0]}, + OnDelete: schema.Cascade, + }, + { + Symbol: "doc_related_related_id", + Columns: []*schema.Column{DocRelatedColumns[1]}, + RefColumns: []*schema.Column{DocsColumns[0]}, + OnDelete: schema.Cascade, + }, + }, + } // GroupUsersColumns holds the columns for the "group_users" table. GroupUsersColumns = []*schema.Column{ {Name: "group_id", Type: field.TypeInt}, @@ -386,6 +411,7 @@ var ( SessionsTable, TokensTable, UsersTable, + DocRelatedTable, GroupUsersTable, PetFriendsTable, } @@ -405,6 +431,8 @@ func init() { SessionsTable.ForeignKeys[0].RefTable = DevicesTable TokensTable.ForeignKeys[0].RefTable = AccountsTable UsersTable.ForeignKeys[0].RefTable = UsersTable + DocRelatedTable.ForeignKeys[0].RefTable = DocsTable + DocRelatedTable.ForeignKeys[1].RefTable = DocsTable GroupUsersTable.ForeignKeys[0].RefTable = GroupsTable GroupUsersTable.ForeignKeys[1].RefTable = UsersTable PetFriendsTable.ForeignKeys[0].RefTable = PetsTable diff --git a/entc/integration/customid/ent/mutation.go b/entc/integration/customid/ent/mutation.go index 07258a4a5..c37b763c6 100644 --- a/entc/integration/customid/ent/mutation.go +++ b/entc/integration/customid/ent/mutation.go @@ -2456,6 +2456,9 @@ type DocMutation struct { children map[schema.DocID]struct{} removedchildren map[schema.DocID]struct{} clearedchildren bool + related map[schema.DocID]struct{} + removedrelated map[schema.DocID]struct{} + clearedrelated bool done bool oldValue func(context.Context) (*Doc, error) predicates []predicate.Doc @@ -2707,6 +2710,60 @@ func (m *DocMutation) ResetChildren() { m.removedchildren = nil } +// AddRelatedIDs adds the "related" edge to the Doc entity by ids. +func (m *DocMutation) AddRelatedIDs(ids ...schema.DocID) { + if m.related == nil { + m.related = make(map[schema.DocID]struct{}) + } + for i := range ids { + m.related[ids[i]] = struct{}{} + } +} + +// ClearRelated clears the "related" edge to the Doc entity. +func (m *DocMutation) ClearRelated() { + m.clearedrelated = true +} + +// RelatedCleared reports if the "related" edge to the Doc entity was cleared. +func (m *DocMutation) RelatedCleared() bool { + return m.clearedrelated +} + +// RemoveRelatedIDs removes the "related" edge to the Doc entity by IDs. +func (m *DocMutation) RemoveRelatedIDs(ids ...schema.DocID) { + if m.removedrelated == nil { + m.removedrelated = make(map[schema.DocID]struct{}) + } + for i := range ids { + delete(m.related, ids[i]) + m.removedrelated[ids[i]] = struct{}{} + } +} + +// RemovedRelated returns the removed IDs of the "related" edge to the Doc entity. +func (m *DocMutation) RemovedRelatedIDs() (ids []schema.DocID) { + for id := range m.removedrelated { + ids = append(ids, id) + } + return +} + +// RelatedIDs returns the "related" edge IDs in the mutation. +func (m *DocMutation) RelatedIDs() (ids []schema.DocID) { + for id := range m.related { + ids = append(ids, id) + } + return +} + +// ResetRelated resets all changes to the "related" edge. +func (m *DocMutation) ResetRelated() { + m.related = nil + m.clearedrelated = false + m.removedrelated = nil +} + // Where appends a list predicates to the DocMutation builder. func (m *DocMutation) Where(ps ...predicate.Doc) { m.predicates = append(m.predicates, ps...) @@ -2834,13 +2891,16 @@ func (m *DocMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *DocMutation) AddedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.parent != nil { edges = append(edges, doc.EdgeParent) } if m.children != nil { edges = append(edges, doc.EdgeChildren) } + if m.related != nil { + edges = append(edges, doc.EdgeRelated) + } return edges } @@ -2858,16 +2918,25 @@ func (m *DocMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case doc.EdgeRelated: + ids := make([]ent.Value, 0, len(m.related)) + for id := range m.related { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *DocMutation) RemovedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.removedchildren != nil { edges = append(edges, doc.EdgeChildren) } + if m.removedrelated != nil { + edges = append(edges, doc.EdgeRelated) + } return edges } @@ -2881,19 +2950,28 @@ func (m *DocMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case doc.EdgeRelated: + ids := make([]ent.Value, 0, len(m.removedrelated)) + for id := range m.removedrelated { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *DocMutation) ClearedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.clearedparent { edges = append(edges, doc.EdgeParent) } if m.clearedchildren { edges = append(edges, doc.EdgeChildren) } + if m.clearedrelated { + edges = append(edges, doc.EdgeRelated) + } return edges } @@ -2905,6 +2983,8 @@ func (m *DocMutation) EdgeCleared(name string) bool { return m.clearedparent case doc.EdgeChildren: return m.clearedchildren + case doc.EdgeRelated: + return m.clearedrelated } return false } @@ -2930,6 +3010,9 @@ func (m *DocMutation) ResetEdge(name string) error { case doc.EdgeChildren: m.ResetChildren() return nil + case doc.EdgeRelated: + m.ResetRelated() + return nil } return fmt.Errorf("unknown Doc edge %s", name) } diff --git a/entc/integration/customid/ent/schema/doc.go b/entc/integration/customid/ent/schema/doc.go index 30fbd1bb6..fab3a54c0 100644 --- a/entc/integration/customid/ent/schema/doc.go +++ b/entc/integration/customid/ent/schema/doc.go @@ -9,9 +9,12 @@ import ( "fmt" "entgo.io/ent" + "entgo.io/ent/dialect" "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "github.com/google/uuid" + + "ariga.io/atlas/sql/postgres" ) // Doc holds the schema definition for the Doc entity. @@ -30,6 +33,9 @@ func (Doc) Fields() []ent.Field { Immutable(). DefaultFunc(func() DocID { return DocID(uuid.NewString()) + }). + SchemaType(map[string]string{ + dialect.Postgres: postgres.TypeUUID, }), field.String("text"). Optional(), @@ -42,6 +48,7 @@ func (Doc) Edges() []ent.Edge { edge.To("children", Doc.Type). From("parent"). Unique(), + edge.To("related", Doc.Type), } } diff --git a/entc/integration/edgeschema/ent/schema/role.go b/entc/integration/edgeschema/ent/schema/role.go index 0473053c3..ead48af7f 100644 --- a/entc/integration/edgeschema/ent/schema/role.go +++ b/entc/integration/edgeschema/ent/schema/role.go @@ -20,7 +20,8 @@ type Role struct { // Fields of the Role. func (Role) Fields() []ent.Field { return []ent.Field{ - field.String("name").Unique(), + field.String("name"). + Unique(), field.Time("created_at"). Default(time.Now), } diff --git a/entc/integration/edgeschema/ent/schema/role_user.go b/entc/integration/edgeschema/ent/schema/role_user.go index 2efb4e260..2462f527e 100644 --- a/entc/integration/edgeschema/ent/schema/role_user.go +++ b/entc/integration/edgeschema/ent/schema/role_user.go @@ -29,8 +29,7 @@ func (RoleUser) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(time.Now), - - // Edge fields + // Edge fields. field.Int("role_id"), field.Int("user_id"), }