diff --git a/cmd/entc/entc.go b/cmd/entc/entc.go index d6e556cd5..6e8feaf1e 100644 --- a/cmd/entc/entc.go +++ b/cmd/entc/entc.go @@ -95,6 +95,13 @@ func main() { for _, tmpl := range template { opts = append(opts, entc.TemplateDir(tmpl)) } + // If the target directory is not inferred from + // the schema path, resolve its package path. + if cfg.Target != "" { + pkgPath, err := PkgPath(DefaultConfig, cfg.Target) + failOnErr(err) + cfg.Package = pkgPath + } cfg.IDType = &field.TypeInfo{Type: field.Type(idtype)} err := entc.Generate(path[0], &cfg, opts...) failOnErr(err) diff --git a/cmd/entc/packages.go b/cmd/entc/packages.go index 6413525d4..d4be5487d 100644 --- a/cmd/entc/packages.go +++ b/cmd/entc/packages.go @@ -20,27 +20,32 @@ func PkgPath(config *packages.Config, target string) (string, error) { if config == nil { config = DefaultConfig } - abs, err := filepath.Abs(target) + pathCheck, err := filepath.Abs(target) if err != nil { return "", err } - pathCheck := abs + var parts []string if _, err := os.Stat(pathCheck); os.IsNotExist(err) { - pathCheck = filepath.Dir(abs) + parts = append(parts, filepath.Base(pathCheck)) + pathCheck = filepath.Dir(pathCheck) } - pkgs, err := packages.Load(config, pathCheck) - if err != nil { - return "", fmt.Errorf("load package info: %v", err) + // Try maximum 2 directories above the given target + // to find the root packages or module. + for i := 0; i < 2; i++ { + pkgs, err := packages.Load(config, pathCheck) + if err != nil { + return "", fmt.Errorf("load package info: %v", err) + } + if len(pkgs) == 0 || len(pkgs[0].Errors) != 0 { + parts = append(parts, filepath.Base(pathCheck)) + pathCheck = filepath.Dir(pathCheck) + continue + } + pkgPath := pkgs[0].PkgPath + for j := len(parts) - 1; j >= 0; j-- { + pkgPath = path.Join(pkgPath, parts[j]) + } + return pkgPath, nil } - if len(pkgs) == 0 { - return "", fmt.Errorf("no package was found for: %s", pathCheck) - } - if errs := pkgs[0].Errors; len(errs) != 0 { - return "", errs[0] - } - pkgPath := pkgs[0].PkgPath - if abs != pathCheck { - pkgPath = path.Join(pkgPath, filepath.Base(abs)) - } - return pkgPath, nil + return "", fmt.Errorf("root package or module was not found for: %s", target) } diff --git a/cmd/entc/packages_test.go b/cmd/entc/packages_test.go index a4c61086b..98228b971 100644 --- a/cmd/entc/packages_test.go +++ b/cmd/entc/packages_test.go @@ -35,6 +35,11 @@ func testPkgPath(t *testing.T, x packagestest.Exporter) { target = filepath.Join(e.Config.Dir, "z/ent") pkgPath, err = PkgPath(e.Config, target) + require.NoError(t, err) + require.Equal(t, "golang.org/x/y/z/ent", pkgPath) + + target = filepath.Join(e.Config.Dir, "z/e/n/t") + pkgPath, err = PkgPath(e.Config, target) require.Error(t, err) require.Empty(t, pkgPath) }