From aa57d732c1f68992542f4d819a6e8ee3f402a5f7 Mon Sep 17 00:00:00 2001 From: Alex Snast Date: Tue, 12 Nov 2019 00:12:17 -0800 Subject: [PATCH] ent/entc: correctly cache type info in node.tmpl Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/156 Reviewed By: a8m Differential Revision: D18429543 fbshipit-source-id: 11bbd9c426878f819ebb2b89978e10948f0730bd --- entc/integration/template/ent/client.go | 6 +- entc/integration/template/ent/node.go | 91 ++++++++++++------- .../template/ent/template/node.tmpl | 89 ++++++++++++------ entc/integration/template/template_test.go | 5 +- 4 files changed, 123 insertions(+), 68 deletions(-) diff --git a/entc/integration/template/ent/client.go b/entc/integration/template/ent/client.go index ec828b5d2..c124aaabe 100644 --- a/entc/integration/template/ent/client.go +++ b/entc/integration/template/ent/client.go @@ -10,7 +10,6 @@ import ( "context" "fmt" "log" - "sync" "github.com/facebookincubator/ent/entc/integration/template/ent/migrate" @@ -34,9 +33,8 @@ type Client struct { // User is the client for interacting with the User builders. User *UserClient - // additional fields. - sync.Mutex - tables []string + // additional fields for node api + tables tables } // NewClient creates a new client configured with the given options. diff --git a/entc/integration/template/ent/node.go b/entc/integration/template/ent/node.go index c76d936d6..bbdb8d7a9 100644 --- a/entc/integration/template/ent/node.go +++ b/entc/integration/template/ent/node.go @@ -11,12 +11,15 @@ import ( "encoding/json" "fmt" "sync" + "sync/atomic" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/schema" "github.com/facebookincubator/ent/entc/integration/template/ent/group" "github.com/facebookincubator/ent/entc/integration/template/ent/pet" "github.com/facebookincubator/ent/entc/integration/template/ent/user" + + "golang.org/x/sync/semaphore" ) // Noder wraps the basic Node method. @@ -146,58 +149,80 @@ func (u *User) Node(ctx context.Context) (node *Node, err error) { return node, nil } -var ( - once sync.Once - types []string - noders = make(map[string]func(context.Context, int) (Noder, error)) -) - func (c *Client) Node(ctx context.Context, id int) (*Node, error) { - noder, err := c.Noder(ctx, id) + n, err := c.Noder(ctx, id) if err != nil { return nil, err } - return noder.Node(ctx) + return n.Node(ctx) } func (c *Client) Noder(ctx context.Context, id int) (Noder, error) { - var err error - once.Do(func() { - err = c.loadTypes(ctx) - }) + tables, err := c.tables.Load(ctx, c.driver) if err != nil { return nil, err } idx := id / (1<<32 - 1) - if idx >= 0 && idx < len(types) { - if fn, ok := noders[types[idx]]; ok { - return fn(ctx, id) - } + if idx < 0 && idx >= len(tables) { + return nil, fmt.Errorf("cannot resolve table from id %v", id) } - return nil, fmt.Errorf("cannot resolve node type for id %v", id) + return c.noder(ctx, tables[idx], id) } -func (c *Client) loadTypes(ctx context.Context) error { +func (c *Client) noder(ctx context.Context, tbl string, id int) (Noder, error) { + switch tbl { + case group.Table: + return c.Group.Get(ctx, id) + case pet.Table: + return c.Pet.Get(ctx, id) + case user.Table: + return c.User.Get(ctx, id) + default: + return nil, fmt.Errorf("cannot resolve noder from table %q", tbl) + } +} + +type ( + tables struct { + once sync.Once + sem *semaphore.Weighted + value atomic.Value + } + + querier interface { + Query(ctx context.Context, query string, args, v interface{}) error + } +) + +func (t *tables) Load(ctx context.Context, querier querier) ([]string, error) { + if tables := t.value.Load(); tables != nil { + return tables.([]string), nil + } + t.once.Do(func() { t.sem = semaphore.NewWeighted(1) }) + if err := t.sem.Acquire(ctx, 1); err != nil { + return nil, err + } + defer t.sem.Release(1) + if tables := t.value.Load(); tables != nil { + return tables.([]string), nil + } + tables, err := t.load(ctx, querier) + if err == nil { + t.value.Store(tables) + } + return tables, err +} + +func (tables) load(ctx context.Context, querier querier) ([]string, error) { rows := &sql.Rows{} query, args := sql.Select("type"). From(sql.Table(schema.TypeTable)). OrderBy(sql.Asc("id")). Query() - if err := c.driver.Query(ctx, query, args, rows); err != nil { - return err + if err := querier.Query(ctx, query, args, rows); err != nil { + return nil, err } defer rows.Close() - if err := sql.ScanSlice(rows, &types); err != nil { - return err - } - noders[group.Table] = func(ctx context.Context, id int) (Noder, error) { - return c.Group.Get(ctx, id) - } - noders[pet.Table] = func(ctx context.Context, id int) (Noder, error) { - return c.Pet.Get(ctx, id) - } - noders[user.Table] = func(ctx context.Context, id int) (Noder, error) { - return c.User.Get(ctx, id) - } - return nil + var tables []string + return tables, sql.ScanSlice(rows, &tables) } diff --git a/entc/integration/template/ent/template/node.tmpl b/entc/integration/template/ent/template/node.tmpl index 20b6c6e00..5691f9eda 100644 --- a/entc/integration/template/ent/template/node.tmpl +++ b/entc/integration/template/ent/template/node.tmpl @@ -11,6 +11,8 @@ in the LICENSE file in the root directory of this source tree. import ( "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/schema" + + "golang.org/x/sync/semaphore" ) // Noder wraps the basic Node method. @@ -85,25 +87,16 @@ type Edge struct { {{/* add the node api to the client */}} -var ( - once sync.Once - types []string - noders = make(map[string]func(context.Context, {{ $.IDType }}) (Noder, error)) -) - func (c *Client) Node(ctx context.Context, id {{ $.IDType }}) (*Node, error) { - noder, err := c.Noder(ctx, id) + n, err := c.Noder(ctx, id) if err != nil { return nil, err } - return noder.Node(ctx) + return n.Node(ctx) } func (c *Client) Noder(ctx context.Context, id {{ $.IDType }}) (Noder, error) { - var err error - once.Do(func() { - err = c.loadTypes(ctx) - }) + tables, err := c.tables.Load(ctx, c.driver) if err != nil { return nil, err } @@ -116,32 +109,70 @@ func (c *Client) Noder(ctx context.Context, id {{ $.IDType }}) (Noder, error) { {{- else }} idx := id/(1<<32 - 1) {{- end }} - if idx >= 0 && idx < len(types) { - if fn, ok := noders[types[idx]]; ok { - return fn(ctx, id) - } + if idx < 0 && idx >= len(tables) { + return nil, fmt.Errorf("cannot resolve table from id %v", id) } - return nil, fmt.Errorf("cannot resolve node type for id %v", id) + return c.noder(ctx, tables[idx], id) } -func (c *Client) loadTypes(ctx context.Context) error { +func (c *Client) noder(ctx context.Context, tbl string, id {{ $.IDType }}) (Noder, error) { + switch tbl { + {{- range $_, $n := $.Nodes }} + case {{ $n.Package }}.Table: + return c.{{ $n.Name }}.Get(ctx, id) + {{- end }} + default: + return nil, fmt.Errorf("cannot resolve noder from table %q", tbl) + } +} + +type ( + tables struct { + once sync.Once + sem *semaphore.Weighted + value atomic.Value + } + + querier interface { + Query(ctx context.Context, query string, args, v interface{}) error + } +) + +func (t *tables) Load(ctx context.Context, querier querier) ([]string, error) { + if tables := t.value.Load(); tables != nil { + return tables.([]string), nil + } + t.once.Do(func() { t.sem = semaphore.NewWeighted(1) }) + if err := t.sem.Acquire(ctx, 1); err != nil { + return nil, err + } + defer t.sem.Release(1) + if tables := t.value.Load(); tables != nil { + return tables.([]string), nil + } + tables, err := t.load(ctx, querier) + if err == nil { + t.value.Store(tables) + } + return tables, err +} + +func (tables) load(ctx context.Context, querier querier) ([]string, error) { rows := &sql.Rows{} query, args := sql.Select("type"). From(sql.Table(schema.TypeTable)). OrderBy(sql.Asc("id")). Query() - if err := c.driver.Query(ctx, query, args, rows); err != nil { - return err + if err := querier.Query(ctx, query, args, rows); err != nil { + return nil, err } defer rows.Close() - if err := sql.ScanSlice(rows, &types); err != nil { - return err - } - {{- range $_, $n := $.Nodes }} - noders[{{ $n.Package }}.Table] = func(ctx context.Context, id {{ $.IDType }}) (Noder, error) { - return c.{{ $n.Name }}.Get(ctx, id) - } - {{- end }} - return nil + var tables []string + return tables, sql.ScanSlice(rows, &tables) } {{ end }} + +{{ define "client/fields/additional" }} + // additional fields for node api + tables tables +{{ end }} diff --git a/entc/integration/template/template_test.go b/entc/integration/template/template_test.go index 67e3a039a..4af934b00 100644 --- a/entc/integration/template/template_test.go +++ b/entc/integration/template/template_test.go @@ -6,6 +6,7 @@ package template import ( "context" + "reflect" "testing" "github.com/facebookincubator/ent/entc/integration/template/ent" @@ -43,6 +44,6 @@ func TestCustomTemplate(t *testing.T) { require.Equal(t, g.ID, node.ID) require.Equal(t, &ent.Field{Type: "int", Name: "MaxUsers", Value: "10"}, node.Fields[0]) - // compile time check for client fields. - _ = &client.Mutex + // check for client additional fields. + require.True(t, reflect.ValueOf(client).Elem().FieldByName("tables").IsValid()) }