ent/entc: correctly cache type info in node.tmpl

Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/156

Reviewed By: a8m

Differential Revision: D18429543

fbshipit-source-id: 11bbd9c426878f819ebb2b89978e10948f0730bd
This commit is contained in:
Alex Snast
2019-11-12 00:12:17 -08:00
committed by Facebook Github Bot
parent 73e294a21e
commit aa57d732c1
4 changed files with 123 additions and 68 deletions

View File

@@ -10,7 +10,6 @@ import (
"context"
"fmt"
"log"
"sync"
"github.com/facebookincubator/ent/entc/integration/template/ent/migrate"
@@ -34,9 +33,8 @@ type Client struct {
// User is the client for interacting with the User builders.
User *UserClient
// additional fields.
sync.Mutex
tables []string
// additional fields for node api
tables tables
}
// NewClient creates a new client configured with the given options.

View File

@@ -11,12 +11,15 @@ import (
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"github.com/facebookincubator/ent/dialect/sql"
"github.com/facebookincubator/ent/dialect/sql/schema"
"github.com/facebookincubator/ent/entc/integration/template/ent/group"
"github.com/facebookincubator/ent/entc/integration/template/ent/pet"
"github.com/facebookincubator/ent/entc/integration/template/ent/user"
"golang.org/x/sync/semaphore"
)
// Noder wraps the basic Node method.
@@ -146,58 +149,80 @@ func (u *User) Node(ctx context.Context) (node *Node, err error) {
return node, nil
}
var (
once sync.Once
types []string
noders = make(map[string]func(context.Context, int) (Noder, error))
)
func (c *Client) Node(ctx context.Context, id int) (*Node, error) {
noder, err := c.Noder(ctx, id)
n, err := c.Noder(ctx, id)
if err != nil {
return nil, err
}
return noder.Node(ctx)
return n.Node(ctx)
}
func (c *Client) Noder(ctx context.Context, id int) (Noder, error) {
var err error
once.Do(func() {
err = c.loadTypes(ctx)
})
tables, err := c.tables.Load(ctx, c.driver)
if err != nil {
return nil, err
}
idx := id / (1<<32 - 1)
if idx >= 0 && idx < len(types) {
if fn, ok := noders[types[idx]]; ok {
return fn(ctx, id)
}
if idx < 0 && idx >= len(tables) {
return nil, fmt.Errorf("cannot resolve table from id %v", id)
}
return nil, fmt.Errorf("cannot resolve node type for id %v", id)
return c.noder(ctx, tables[idx], id)
}
func (c *Client) loadTypes(ctx context.Context) error {
func (c *Client) noder(ctx context.Context, tbl string, id int) (Noder, error) {
switch tbl {
case group.Table:
return c.Group.Get(ctx, id)
case pet.Table:
return c.Pet.Get(ctx, id)
case user.Table:
return c.User.Get(ctx, id)
default:
return nil, fmt.Errorf("cannot resolve noder from table %q", tbl)
}
}
type (
tables struct {
once sync.Once
sem *semaphore.Weighted
value atomic.Value
}
querier interface {
Query(ctx context.Context, query string, args, v interface{}) error
}
)
func (t *tables) Load(ctx context.Context, querier querier) ([]string, error) {
if tables := t.value.Load(); tables != nil {
return tables.([]string), nil
}
t.once.Do(func() { t.sem = semaphore.NewWeighted(1) })
if err := t.sem.Acquire(ctx, 1); err != nil {
return nil, err
}
defer t.sem.Release(1)
if tables := t.value.Load(); tables != nil {
return tables.([]string), nil
}
tables, err := t.load(ctx, querier)
if err == nil {
t.value.Store(tables)
}
return tables, err
}
func (tables) load(ctx context.Context, querier querier) ([]string, error) {
rows := &sql.Rows{}
query, args := sql.Select("type").
From(sql.Table(schema.TypeTable)).
OrderBy(sql.Asc("id")).
Query()
if err := c.driver.Query(ctx, query, args, rows); err != nil {
return err
if err := querier.Query(ctx, query, args, rows); err != nil {
return nil, err
}
defer rows.Close()
if err := sql.ScanSlice(rows, &types); err != nil {
return err
}
noders[group.Table] = func(ctx context.Context, id int) (Noder, error) {
return c.Group.Get(ctx, id)
}
noders[pet.Table] = func(ctx context.Context, id int) (Noder, error) {
return c.Pet.Get(ctx, id)
}
noders[user.Table] = func(ctx context.Context, id int) (Noder, error) {
return c.User.Get(ctx, id)
}
return nil
var tables []string
return tables, sql.ScanSlice(rows, &tables)
}

View File

@@ -11,6 +11,8 @@ in the LICENSE file in the root directory of this source tree.
import (
"github.com/facebookincubator/ent/dialect/sql"
"github.com/facebookincubator/ent/dialect/sql/schema"
"golang.org/x/sync/semaphore"
)
// Noder wraps the basic Node method.
@@ -85,25 +87,16 @@ type Edge struct {
{{/* add the node api to the client */}}
var (
once sync.Once
types []string
noders = make(map[string]func(context.Context, {{ $.IDType }}) (Noder, error))
)
func (c *Client) Node(ctx context.Context, id {{ $.IDType }}) (*Node, error) {
noder, err := c.Noder(ctx, id)
n, err := c.Noder(ctx, id)
if err != nil {
return nil, err
}
return noder.Node(ctx)
return n.Node(ctx)
}
func (c *Client) Noder(ctx context.Context, id {{ $.IDType }}) (Noder, error) {
var err error
once.Do(func() {
err = c.loadTypes(ctx)
})
tables, err := c.tables.Load(ctx, c.driver)
if err != nil {
return nil, err
}
@@ -116,32 +109,70 @@ func (c *Client) Noder(ctx context.Context, id {{ $.IDType }}) (Noder, error) {
{{- else }}
idx := id/(1<<32 - 1)
{{- end }}
if idx >= 0 && idx < len(types) {
if fn, ok := noders[types[idx]]; ok {
return fn(ctx, id)
}
if idx < 0 && idx >= len(tables) {
return nil, fmt.Errorf("cannot resolve table from id %v", id)
}
return nil, fmt.Errorf("cannot resolve node type for id %v", id)
return c.noder(ctx, tables[idx], id)
}
func (c *Client) loadTypes(ctx context.Context) error {
func (c *Client) noder(ctx context.Context, tbl string, id {{ $.IDType }}) (Noder, error) {
switch tbl {
{{- range $_, $n := $.Nodes }}
case {{ $n.Package }}.Table:
return c.{{ $n.Name }}.Get(ctx, id)
{{- end }}
default:
return nil, fmt.Errorf("cannot resolve noder from table %q", tbl)
}
}
type (
tables struct {
once sync.Once
sem *semaphore.Weighted
value atomic.Value
}
querier interface {
Query(ctx context.Context, query string, args, v interface{}) error
}
)
func (t *tables) Load(ctx context.Context, querier querier) ([]string, error) {
if tables := t.value.Load(); tables != nil {
return tables.([]string), nil
}
t.once.Do(func() { t.sem = semaphore.NewWeighted(1) })
if err := t.sem.Acquire(ctx, 1); err != nil {
return nil, err
}
defer t.sem.Release(1)
if tables := t.value.Load(); tables != nil {
return tables.([]string), nil
}
tables, err := t.load(ctx, querier)
if err == nil {
t.value.Store(tables)
}
return tables, err
}
func (tables) load(ctx context.Context, querier querier) ([]string, error) {
rows := &sql.Rows{}
query, args := sql.Select("type").
From(sql.Table(schema.TypeTable)).
OrderBy(sql.Asc("id")).
Query()
if err := c.driver.Query(ctx, query, args, rows); err != nil {
return err
if err := querier.Query(ctx, query, args, rows); err != nil {
return nil, err
}
defer rows.Close()
if err := sql.ScanSlice(rows, &types); err != nil {
return err
}
{{- range $_, $n := $.Nodes }}
noders[{{ $n.Package }}.Table] = func(ctx context.Context, id {{ $.IDType }}) (Noder, error) {
return c.{{ $n.Name }}.Get(ctx, id)
}
{{- end }}
return nil
var tables []string
return tables, sql.ScanSlice(rows, &tables)
}
{{ end }}
{{ define "client/fields/additional" }}
// additional fields for node api
tables tables
{{ end }}

View File

@@ -6,6 +6,7 @@ package template
import (
"context"
"reflect"
"testing"
"github.com/facebookincubator/ent/entc/integration/template/ent"
@@ -43,6 +44,6 @@ func TestCustomTemplate(t *testing.T) {
require.Equal(t, g.ID, node.ID)
require.Equal(t, &ent.Field{Type: "int", Name: "MaxUsers", Value: "10"}, node.Fields[0])
// compile time check for client fields.
_ = &client.Mutex
// check for client additional fields.
require.True(t, reflect.ValueOf(client).Elem().FieldByName("tables").IsValid())
}