cmd/entc: support custom target on codegen (#189)

Summary:
Pull Request resolved: https://github.com/facebookincubator/ent/pull/189

Fixes #61

Reviewed By: alexsn

Differential Revision: D18676988

fbshipit-source-id: 00d415e14d1278a45edea49c69abe4916303f55d
This commit is contained in:
Ariel Mashraki
2019-11-24 07:57:34 -08:00
committed by Facebook Github Bot
parent 67c3fd2db9
commit 0344904a4e
3 changed files with 34 additions and 17 deletions

View File

@@ -95,6 +95,13 @@ func main() {
for _, tmpl := range template {
opts = append(opts, entc.TemplateDir(tmpl))
}
// If the target directory is not inferred from
// the schema path, resolve its package path.
if cfg.Target != "" {
pkgPath, err := PkgPath(DefaultConfig, cfg.Target)
failOnErr(err)
cfg.Package = pkgPath
}
cfg.IDType = &field.TypeInfo{Type: field.Type(idtype)}
err := entc.Generate(path[0], &cfg, opts...)
failOnErr(err)

View File

@@ -20,27 +20,32 @@ func PkgPath(config *packages.Config, target string) (string, error) {
if config == nil {
config = DefaultConfig
}
abs, err := filepath.Abs(target)
pathCheck, err := filepath.Abs(target)
if err != nil {
return "", err
}
pathCheck := abs
var parts []string
if _, err := os.Stat(pathCheck); os.IsNotExist(err) {
pathCheck = filepath.Dir(abs)
parts = append(parts, filepath.Base(pathCheck))
pathCheck = filepath.Dir(pathCheck)
}
pkgs, err := packages.Load(config, pathCheck)
if err != nil {
return "", fmt.Errorf("load package info: %v", err)
// Try maximum 2 directories above the given target
// to find the root packages or module.
for i := 0; i < 2; i++ {
pkgs, err := packages.Load(config, pathCheck)
if err != nil {
return "", fmt.Errorf("load package info: %v", err)
}
if len(pkgs) == 0 || len(pkgs[0].Errors) != 0 {
parts = append(parts, filepath.Base(pathCheck))
pathCheck = filepath.Dir(pathCheck)
continue
}
pkgPath := pkgs[0].PkgPath
for j := len(parts) - 1; j >= 0; j-- {
pkgPath = path.Join(pkgPath, parts[j])
}
return pkgPath, nil
}
if len(pkgs) == 0 {
return "", fmt.Errorf("no package was found for: %s", pathCheck)
}
if errs := pkgs[0].Errors; len(errs) != 0 {
return "", errs[0]
}
pkgPath := pkgs[0].PkgPath
if abs != pathCheck {
pkgPath = path.Join(pkgPath, filepath.Base(abs))
}
return pkgPath, nil
return "", fmt.Errorf("root package or module was not found for: %s", target)
}

View File

@@ -35,6 +35,11 @@ func testPkgPath(t *testing.T, x packagestest.Exporter) {
target = filepath.Join(e.Config.Dir, "z/ent")
pkgPath, err = PkgPath(e.Config, target)
require.NoError(t, err)
require.Equal(t, "golang.org/x/y/z/ent", pkgPath)
target = filepath.Join(e.Config.Dir, "z/e/n/t")
pkgPath, err = PkgPath(e.Config, target)
require.Error(t, err)
require.Empty(t, pkgPath)
}