From 17ee19e23a5c4ed52aac79a7d97289e02108167a Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Sun, 18 Oct 2020 14:17:20 +0300 Subject: [PATCH] entc/gen: add gen.Template for ent extensions (#859) --- entc/entc.go | 41 +++------------ entc/gen/bench_test.go | 5 +- entc/gen/graph.go | 33 +++--------- entc/gen/graph_test.go | 13 +++-- entc/gen/template.go | 98 +++++++++++++++++++++++++++++++++--- examples/entcpkg/ent/entc.go | 13 +++-- 6 files changed, 119 insertions(+), 84 deletions(-) diff --git a/entc/entc.go b/entc/entc.go index 89bb33081..25f5d40f7 100644 --- a/entc/entc.go +++ b/entc/entc.go @@ -9,11 +9,9 @@ package entc import ( "fmt" "go/token" - "os" "path" "path/filepath" "strings" - "text/template" "github.com/facebook/ent/entc/gen" "github.com/facebook/ent/entc/load" @@ -129,24 +127,10 @@ func FeatureNames(names ...string) Option { } } -// Funcs specifies external functions to add to the template execution. -func Funcs(funcMap template.FuncMap) Option { - return func(cfg *gen.Config) error { - if cfg.Funcs == nil { - cfg.Funcs = funcMap - return nil - } - for name, fn := range funcMap { - cfg.Funcs[name] = fn - } - return nil - } -} - // TemplateFiles parses the named files and associates the resulting templates // with codegen templates. func TemplateFiles(filenames ...string) Option { - return templateOption(func(t *template.Template) (*template.Template, error) { + return templateOption(func(t *gen.Template) (*gen.Template, error) { return t.ParseFiles(filenames...) }) } @@ -154,7 +138,7 @@ func TemplateFiles(filenames ...string) Option { // TemplateGlob parses the template definitions from the files identified // by the pattern and associates the resulting templates with codegen templates. func TemplateGlob(pattern string) Option { - return templateOption(func(t *template.Template) (*template.Template, error) { + return templateOption(func(t *gen.Template) (*gen.Template, error) { return t.ParseGlob(pattern) }) } @@ -162,29 +146,16 @@ func TemplateGlob(pattern string) Option { // TemplateDir parses the template definitions from the files in the directory // and associates the resulting templates with codegen templates. func TemplateDir(path string) Option { - return templateOption(func(t *template.Template) (*template.Template, error) { - err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error { - if err != nil { - return fmt.Errorf("load template: %v", err) - } - if info.IsDir() || strings.HasSuffix(path, ".go") { - return nil - } - t, err = t.ParseFiles(path) - return err - }) - if err != nil { - return nil, err - } - return t, nil + return templateOption(func(t *gen.Template) (*gen.Template, error) { + return t.ParseDir(path) }) } // templateOption ensures the template instantiate // once for config and execute the given Option. -func templateOption(next func(t *template.Template) (*template.Template, error)) Option { +func templateOption(next func(t *gen.Template) (*gen.Template, error)) Option { return func(cfg *gen.Config) (err error) { - tmpl, err := next(template.New("external").Funcs(gen.Funcs).Funcs(cfg.Funcs)) + tmpl, err := next(gen.NewTemplate("external")) if err != nil { return err } diff --git a/entc/gen/bench_test.go b/entc/gen/bench_test.go index d81e8f23e..fbf021960 100644 --- a/entc/gen/bench_test.go +++ b/entc/gen/bench_test.go @@ -8,7 +8,6 @@ import ( "os" "path/filepath" "testing" - "text/template" "github.com/facebook/ent/entc" "github.com/facebook/ent/entc/gen" @@ -27,8 +26,8 @@ func BenchmarkGraph_Gen(b *testing.B) { IDType: &field.TypeInfo{Type: field.TypeInt}, Target: target, Package: "github.com/facebook/ent/entc/integration/ent", - Templates: []*template.Template{ - template.Must(template.New("template"). + Templates: []*gen.Template{ + gen.MustParse(gen.NewTemplate("template"). Funcs(gen.Funcs). ParseGlob("../integration/ent/template/*.tmpl")), }, diff --git a/entc/gen/graph.go b/entc/gen/graph.go index 682a1befc..8b6151039 100644 --- a/entc/gen/graph.go +++ b/entc/gen/graph.go @@ -14,7 +14,6 @@ import ( "os" "path/filepath" "runtime/debug" - "text/template" "text/template/parse" "github.com/facebook/ent/dialect/sql/schema" @@ -42,26 +41,12 @@ type ( // The supported types are string and int, which also the default. IDType *field.TypeInfo - // Template specifies an alternative template to execute or - // to override the default. If nil, the default template is used. - // - // Deprecated: the Template option predates the Templates option and it - // is planned be removed in v0.5.0. New code should use Templates instead. - Template *template.Template - // Templates specifies a list of alternative templates to execute or // to override the default. If nil, the default template is used. // // Note that, additional templates are executed on the Graph object and // the execution output is stored in a file derived by the template name. - Templates []*template.Template - - // Funcs specifies external functions to add to the template execution. - // - // Templates that use custom functions and override (or extend) the default - // templates will need to provide the same FuncMap that was used for parsing - // the template. - Funcs template.FuncMap + Templates []*Template // Features defines a list of additional features to add to the codegen phase. // For example, the PrivacyFeature. @@ -468,29 +453,25 @@ func (g *Graph) typ(name string) (*Type, bool) { // templates returns the template.Template for the code and external templates // to execute on the Graph object if provided. -func (g *Graph) templates() (*template.Template, []GraphTemplate) { - if g.Template != nil { - g.Templates = append(g.Templates, g.Template) - } - templates.Funcs(g.Funcs) +func (g *Graph) templates() (*Template, []GraphTemplate) { + initTemplates() external := make([]GraphTemplate, 0, len(g.Templates)) for _, rootT := range g.Templates { - rootT.Funcs(Funcs) - rootT.Funcs(g.Funcs) + templates.Funcs(rootT.FuncMap) for _, tmpl := range rootT.Templates() { if parse.IsEmptyTree(tmpl.Root) { continue } name := tmpl.Name() - // If this template doesn't override or extend one of the - // default templates, generate it in a new file. + // If the template does not override or extend one of + // the builtin templates, generate it in a new file. if templates.Lookup(name) == nil && !extendExisting(name) { external = append(external, GraphTemplate{ Name: name, Format: snake(name) + ".go", }) } - templates = template.Must(templates.AddParseTree(name, tmpl.Tree)) + templates = MustParse(templates.AddParseTree(name, tmpl.Tree)) } } return templates, external diff --git a/entc/gen/graph_test.go b/entc/gen/graph_test.go index 19f384844..a493b18ee 100644 --- a/entc/gen/graph_test.go +++ b/entc/gen/graph_test.go @@ -9,7 +9,6 @@ import ( "os" "path/filepath" "testing" - "text/template" "github.com/facebook/ent/entc/load" "github.com/facebook/ent/schema/field" @@ -270,13 +269,13 @@ func TestGraph_Gen(t *testing.T) { target := filepath.Join(os.TempDir(), "ent") require.NoError(os.MkdirAll(target, os.ModePerm), "creating tmpdir") defer os.RemoveAll(target) - external := template.Must(template.New("external").Parse("package external")) + external := MustParse(NewTemplate("external").Parse("package external")) graph, err := NewGraph(&Config{ - Package: "entc/gen", - Target: target, - Storage: drivers[0], - Template: external, - IDType: &field.TypeInfo{Type: field.TypeInt}, + Package: "entc/gen", + Target: target, + Storage: drivers[0], + Templates: []*Template{external}, + IDType: &field.TypeInfo{Type: field.TypeInt}, }, &load.Schema{ Name: "T1", Fields: []*load.Field{ diff --git a/entc/gen/template.go b/entc/gen/template.go index 06e4ed089..3ef32c6dd 100644 --- a/entc/gen/template.go +++ b/entc/gen/template.go @@ -9,9 +9,12 @@ import ( "fmt" "go/parser" "go/token" + "os" "path/filepath" "strconv" + "strings" "text/template" + "text/template/parse" "github.com/facebook/ent/entc/gen/internal" ) @@ -151,17 +154,15 @@ var ( }, } // templates holds the Go templates for the code generation. - // the init function below initializes the templates and its - // funcs to avoid initialization loop. - templates = template.New("templates") + templates *Template // importPkg are the import packages used for code generation. importPkg = make(map[string]string) ) -func init() { - templates.Funcs(Funcs) +func initTemplates() { + templates = NewTemplate("templates") for _, asset := range internal.AssetNames() { - templates = template.Must(templates.Parse(string(internal.MustAsset(asset)))) + templates = MustParse(templates.Parse(string(internal.MustAsset(asset)))) } b := bytes.NewBuffer([]byte("package main\n")) check(templates.ExecuteTemplate(b, "import", Type{Config: &Config{}}), "load imports") @@ -179,6 +180,91 @@ func init() { } } +// Template wraps the standard template.Template to +// provide additional functionality for ent extensions. +type Template struct { + *template.Template + FuncMap template.FuncMap +} + +// NewTemplate creates an empty template with the standard codegen functions. +func NewTemplate(name string) *Template { + t := &Template{Template: template.New(name)} + return t.Funcs(Funcs) +} + +// Funcs merges the given funcMap to the template functions. +func (t *Template) Funcs(funcMap template.FuncMap) *Template { + t.Template.Funcs(funcMap) + if t.FuncMap == nil { + t.FuncMap = template.FuncMap{} + } + for name, f := range funcMap { + if _, ok := t.FuncMap[name]; !ok { + t.FuncMap[name] = f + } + } + return t +} + +// Parse parses text as a template body for t. +func (t *Template) Parse(text string) (*Template, error) { + if _, err := t.Template.Parse(text); err != nil { + return nil, err + } + return t, nil +} + +// ParseFiles parses a list of files as templates and associate them with t. +// Each file can be a standalone template. +func (t *Template) ParseFiles(filenames ...string) (*Template, error) { + if _, err := t.Template.ParseFiles(filenames...); err != nil { + return nil, err + } + return t, nil +} + +// ParseGlob parses the files that match the given pattern as templates and +// associate them with t. +func (t *Template) ParseGlob(pattern string) (*Template, error) { + if _, err := t.Template.ParseGlob(pattern); err != nil { + return nil, err + } + return t, nil +} + +// ParseDir walks on the given dir path and parses the given matches with aren't Go files. +func (t *Template) ParseDir(path string) (*Template, error) { + err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error { + if err != nil { + return fmt.Errorf("walk path %s: %v", path, err) + } + if info.IsDir() || strings.HasSuffix(path, ".go") { + return nil + } + _, err = t.ParseFiles(path) + return err + }) + return t, err +} + +// AddParseTree adds the given parse tree to the template. +func (t *Template) AddParseTree(name string, tree *parse.Tree) (*Template, error) { + if _, err := t.Template.AddParseTree(name, tree); err != nil { + return nil, err + } + return t, nil +} + +// MustParse is a helper that wraps a call to a function returning (*Template, error) +// and panics if the error is non-nil. +func MustParse(t *Template, err error) *Template { + if err != nil { + panic(err) + } + return t +} + func pkgf(s string) func(t *Type) string { return func(t *Type) string { return fmt.Sprintf(s, t.Package()) } } diff --git a/examples/entcpkg/ent/entc.go b/examples/entcpkg/ent/entc.go index 6e6eb874d..e15119ea4 100644 --- a/examples/entcpkg/ent/entc.go +++ b/examples/entcpkg/ent/entc.go @@ -18,10 +18,9 @@ import ( func main() { // A usage for custom templates with external functions. // One template is defined in the option below, and the - // second template is provided with the `Templates` field. + // second template is provided with the `Templates` option. opts := []entc.Option{ - entc.Funcs(template.FuncMap{"title": strings.ToTitle}), - entc.TemplateFiles("template/static.tmpl"), + entc.TemplateFiles("template/debug.tmpl"), } err := entc.Generate("./schema", &gen.Config{ Header: ` @@ -31,10 +30,10 @@ func main() { // Code generated by entc, DO NOT EDIT. `, - Templates: []*template.Template{ - template.Must(template.New("debug"). - Funcs(gen.Funcs). - ParseFiles("template/debug.tmpl")), + Templates: []*gen.Template{ + gen.MustParse(gen.NewTemplate("static"). + Funcs(template.FuncMap{"title": strings.ToTitle}). + ParseFiles("template/static.tmpl")), }, }, opts...) if err != nil {