diff --git a/entc/integration/template/ent/node.go b/entc/integration/template/ent/node.go index ed9fe720d..c76d936d6 100644 --- a/entc/integration/template/ent/node.go +++ b/entc/integration/template/ent/node.go @@ -9,6 +9,7 @@ package ent import ( "context" "encoding/json" + "fmt" "sync" "github.com/facebookincubator/ent/dialect/sql" @@ -146,12 +147,20 @@ func (u *User) Node(ctx context.Context) (node *Node, err error) { } var ( - once sync.Once - types []string - typeNodes = make(map[string]func(context.Context, int) (*Node, error)) + once sync.Once + types []string + noders = make(map[string]func(context.Context, int) (Noder, error)) ) func (c *Client) Node(ctx context.Context, id int) (*Node, error) { + noder, err := c.Noder(ctx, id) + if err != nil { + return nil, err + } + return noder.Node(ctx) +} + +func (c *Client) Noder(ctx context.Context, id int) (Noder, error) { var err error once.Do(func() { err = c.loadTypes(ctx) @@ -160,7 +169,12 @@ func (c *Client) Node(ctx context.Context, id int) (*Node, error) { return nil, err } idx := id / (1<<32 - 1) - return typeNodes[types[idx]](ctx, id) + if idx >= 0 && idx < len(types) { + if fn, ok := noders[types[idx]]; ok { + return fn(ctx, id) + } + } + return nil, fmt.Errorf("cannot resolve node type for id %v", id) } func (c *Client) loadTypes(ctx context.Context) error { @@ -176,26 +190,14 @@ func (c *Client) loadTypes(ctx context.Context) error { if err := sql.ScanSlice(rows, &types); err != nil { return err } - typeNodes[group.Table] = func(ctx context.Context, id int) (*Node, error) { - nv, err := c.Group.Get(ctx, id) - if err != nil { - return nil, err - } - return nv.Node(ctx) + noders[group.Table] = func(ctx context.Context, id int) (Noder, error) { + return c.Group.Get(ctx, id) } - typeNodes[pet.Table] = func(ctx context.Context, id int) (*Node, error) { - nv, err := c.Pet.Get(ctx, id) - if err != nil { - return nil, err - } - return nv.Node(ctx) + noders[pet.Table] = func(ctx context.Context, id int) (Noder, error) { + return c.Pet.Get(ctx, id) } - typeNodes[user.Table] = func(ctx context.Context, id int) (*Node, error) { - nv, err := c.User.Get(ctx, id) - if err != nil { - return nil, err - } - return nv.Node(ctx) + noders[user.Table] = func(ctx context.Context, id int) (Noder, error) { + return c.User.Get(ctx, id) } return nil } diff --git a/entc/integration/template/ent/template/node.tmpl b/entc/integration/template/ent/template/node.tmpl index c53c4308b..20b6c6e00 100644 --- a/entc/integration/template/ent/template/node.tmpl +++ b/entc/integration/template/ent/template/node.tmpl @@ -8,6 +8,11 @@ in the LICENSE file in the root directory of this source tree. {{ $pkg := base $.Config.Package }} {{ template "header" $ }} +import ( + "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/dialect/sql/schema" +) + // Noder wraps the basic Node method. type Noder interface { Node(context.Context) (*Node, error) @@ -83,10 +88,18 @@ type Edge struct { var ( once sync.Once types []string - typeNodes = make(map[string]func(context.Context, {{ $.IDType }})(*Node, error)) + noders = make(map[string]func(context.Context, {{ $.IDType }}) (Noder, error)) ) func (c *Client) Node(ctx context.Context, id {{ $.IDType }}) (*Node, error) { + noder, err := c.Noder(ctx, id) + if err != nil { + return nil, err + } + return noder.Node(ctx) +} + +func (c *Client) Noder(ctx context.Context, id {{ $.IDType }}) (Noder, error) { var err error once.Do(func() { err = c.loadTypes(ctx) @@ -103,7 +116,12 @@ func (c *Client) Node(ctx context.Context, id {{ $.IDType }}) (*Node, error) { {{- else }} idx := id/(1<<32 - 1) {{- end }} - return typeNodes[types[idx]](ctx, id) + if idx >= 0 && idx < len(types) { + if fn, ok := noders[types[idx]]; ok { + return fn(ctx, id) + } + } + return nil, fmt.Errorf("cannot resolve node type for id %v", id) } func (c *Client) loadTypes(ctx context.Context) error { @@ -120,12 +138,8 @@ func (c *Client) loadTypes(ctx context.Context) error { return err } {{- range $_, $n := $.Nodes }} - typeNodes[{{ $n.Package }}.Table] = func(ctx context.Context, id {{ $.IDType }})(*Node, error) { - nv, err := c.{{ $n.Name }}.Get(ctx, id) - if err != nil { - return nil, err - } - return nv.Node(ctx) + noders[{{ $n.Package }}.Table] = func(ctx context.Context, id {{ $.IDType }}) (Noder, error) { + return c.{{ $n.Name }}.Get(ctx, id) } {{- end }} return nil