From 48362e79cdd9255d477cdca093ade2af55df7eb8 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Tue, 6 Oct 2020 16:29:09 +0300 Subject: [PATCH] entc/gen: derive the id-type from the schema (#823) if it was not provided --- entc/entc.go | 4 ---- entc/gen/graph.go | 26 ++++++++++++++++++++++++++ entc/gen/type.go | 6 +++++- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/entc/entc.go b/entc/entc.go index 6d37aec5a..89bb33081 100644 --- a/entc/entc.go +++ b/entc/entc.go @@ -17,7 +17,6 @@ import ( "github.com/facebook/ent/entc/gen" "github.com/facebook/ent/entc/load" - "github.com/facebook/ent/schema/field" ) // LoadGraph loads the schema package from the given schema path, @@ -58,9 +57,6 @@ func Generate(schemaPath string, cfg *gen.Config, options ...Option) (err error) // the schema. cfg.Target = filepath.Dir(abs) } - if cfg.IDType == nil { - cfg.IDType = &field.TypeInfo{Type: field.TypeInt} - } for _, opt := range options { if err := opt(cfg); err != nil { return err diff --git a/entc/gen/graph.go b/entc/gen/graph.go index f664e7e24..abd0bcf73 100644 --- a/entc/gen/graph.go +++ b/entc/gen/graph.go @@ -98,9 +98,35 @@ func NewGraph(c *Config, schemas ...*load.Schema) (g *Graph, err error) { for i := range schemas { g.addIndexes(schemas[i]) } + g.defaults() return } +// defaultIDType holds the default value for IDType. +var defaultIDType = &field.TypeInfo{Type: field.TypeInt} + +// defaults sets the default value of the IDType. The IDType field is used +// by multiple templates. If the IDType wasn't provided, it will fallback to +// int, or the one used in the schema (if all schemas share the same IDType). +func (g *Graph) defaults() { + if g.IDType != nil { + return + } + if len(g.Nodes) == 0 { + g.IDType = defaultIDType + return + } + // Check that all nodes have the same type for the ID field. + for i := 0; i < len(g.Nodes)-1; i++ { + cid, nid := g.Nodes[i].ID.Type, g.Nodes[i+1].ID.Type + if cid.Type != nid.Type { + g.IDType = defaultIDType + return + } + } + g.IDType = g.Nodes[0].ID.Type +} + // Gen generates the artifacts for the graph. func (g *Graph) Gen() error { var ( diff --git a/entc/gen/type.go b/entc/gen/type.go index dba1de5a7..a1ec32ec8 100644 --- a/entc/gen/type.go +++ b/entc/gen/type.go @@ -160,11 +160,15 @@ type ( // NewType creates a new type and its fields from the given schema. func NewType(c *Config, schema *load.Schema) (*Type, error) { + idType := c.IDType + if idType == nil { + idType = defaultIDType + } typ := &Type{ Config: c, ID: &Field{ Name: "id", - Type: c.IDType, + Type: idType, StructTag: structTag("id", ""), }, schema: schema,