mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
249
cmd/internal/base/base.go
Normal file
249
cmd/internal/base/base.go
Normal file
@@ -0,0 +1,249 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
// Package base defines shared basic pieces of the ent command.
|
||||
package base
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/template"
|
||||
"unicode"
|
||||
|
||||
"github.com/facebook/ent/cmd/internal/printer"
|
||||
"github.com/facebook/ent/entc"
|
||||
"github.com/facebook/ent/entc/gen"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// custom implementation for pflag.
|
||||
type IDType field.Type
|
||||
|
||||
// Set implements the Set method of the flag.Value interface.
|
||||
func (t *IDType) Set(s string) error {
|
||||
switch s {
|
||||
case field.TypeInt.String():
|
||||
*t = IDType(field.TypeInt)
|
||||
case field.TypeInt64.String():
|
||||
*t = IDType(field.TypeInt64)
|
||||
case field.TypeUint.String():
|
||||
*t = IDType(field.TypeUint)
|
||||
case field.TypeUint64.String():
|
||||
*t = IDType(field.TypeUint64)
|
||||
case field.TypeString.String():
|
||||
*t = IDType(field.TypeString)
|
||||
default:
|
||||
return fmt.Errorf("invalid type %q", s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Type returns the type representation of the id option for help command.
|
||||
func (IDType) Type() string {
|
||||
return fmt.Sprintf("%v", []field.Type{
|
||||
field.TypeInt,
|
||||
field.TypeInt64,
|
||||
field.TypeUint,
|
||||
field.TypeUint64,
|
||||
field.TypeString,
|
||||
})
|
||||
}
|
||||
|
||||
// String returns the default value for the help command.
|
||||
func (IDType) String() string {
|
||||
return field.TypeInt.String()
|
||||
}
|
||||
|
||||
// InitCmd returns the init command for ent/c packages.
|
||||
func InitCmd() *cobra.Command {
|
||||
var target string
|
||||
cmd := &cobra.Command{
|
||||
Use: "init [flags] [schemas]",
|
||||
Short: "initialize an environment with zero or more schemas",
|
||||
Example: examples(
|
||||
"ent init Example",
|
||||
"ent init --target entv1/schema User Group",
|
||||
),
|
||||
Args: func(_ *cobra.Command, names []string) error {
|
||||
for _, name := range names {
|
||||
if !unicode.IsUpper(rune(name[0])) {
|
||||
return errors.New("schema names must begin with uppercase")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, names []string) {
|
||||
if err := initEnv(target, names); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
},
|
||||
}
|
||||
cmd.Flags().StringVar(&target, "target", defaultSchema, "target directory for schemas")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// DescribeCmd returns the describe command for ent/c packages.
|
||||
func DescribeCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "describe [flags] path",
|
||||
Short: "printer a description of the graph schema",
|
||||
Example: examples(
|
||||
"ent describe ./ent/schema",
|
||||
"ent describe github.com/a8m/x",
|
||||
),
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, path []string) {
|
||||
graph, err := entc.LoadGraph(path[0], &gen.Config{})
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
printer.Fprint(os.Stdout, graph)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateCmd returns the generate command for ent/c packages.
|
||||
func GenerateCmd(postRun ...func(*gen.Config)) *cobra.Command {
|
||||
var (
|
||||
cfg gen.Config
|
||||
storage string
|
||||
features []string
|
||||
templates []string
|
||||
idtype = IDType(field.TypeInt)
|
||||
cmd = &cobra.Command{
|
||||
Use: "generate [flags] path",
|
||||
Short: "generate go code for the schema directory",
|
||||
Example: examples(
|
||||
"ent generate ./ent/schema",
|
||||
"ent generate github.com/a8m/x",
|
||||
),
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, path []string) {
|
||||
opts := []entc.Option{
|
||||
entc.Storage(storage),
|
||||
entc.FeatureNames(features...),
|
||||
}
|
||||
for _, tmpl := range templates {
|
||||
typ := "dir"
|
||||
if parts := strings.SplitN(tmpl, "=", 2); len(parts) > 1 {
|
||||
typ, tmpl = parts[0], parts[1]
|
||||
}
|
||||
switch typ {
|
||||
case "dir":
|
||||
opts = append(opts, entc.TemplateDir(tmpl))
|
||||
case "file":
|
||||
opts = append(opts, entc.TemplateFiles(tmpl))
|
||||
case "glob":
|
||||
opts = append(opts, entc.TemplateGlob(tmpl))
|
||||
default:
|
||||
log.Fatalln("unsupported template type", typ)
|
||||
}
|
||||
}
|
||||
// If the target directory is not inferred from
|
||||
// the schema path, resolve its package path.
|
||||
if cfg.Target != "" {
|
||||
pkgPath, err := PkgPath(DefaultConfig, cfg.Target)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
cfg.Package = pkgPath
|
||||
}
|
||||
cfg.IDType = &field.TypeInfo{Type: field.Type(idtype)}
|
||||
if err := entc.Generate(path[0], &cfg, opts...); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
for _, fn := range postRun {
|
||||
fn(&cfg)
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
cmd.Flags().Var(&idtype, "idtype", "type of the id field")
|
||||
cmd.Flags().StringVar(&storage, "storage", "sql", "storage driver to support in codegen")
|
||||
cmd.Flags().StringVar(&cfg.Header, "header", "", "override codegen header")
|
||||
cmd.Flags().StringVar(&cfg.Target, "target", "", "target directory for codegen")
|
||||
cmd.Flags().StringSliceVarP(&features, "feature", "", nil, "extend codegen with additional features")
|
||||
cmd.Flags().StringSliceVarP(&templates, "template", "", nil, "external templates to execute")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// initEnv initialize an environment for ent codegen.
|
||||
func initEnv(target string, names []string) error {
|
||||
if err := createDir(target); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, name := range names {
|
||||
b := bytes.NewBuffer(nil)
|
||||
if err := tmpl.Execute(b, name); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
target := filepath.Join(target, strings.ToLower(name+".go"))
|
||||
if err := ioutil.WriteFile(target, b.Bytes(), 0644); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createDir(target string) error {
|
||||
_, err := os.Stat(target)
|
||||
if err == nil || !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(target, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("creating schema directory: %w", err)
|
||||
}
|
||||
if target != defaultSchema {
|
||||
return nil
|
||||
}
|
||||
if err := ioutil.WriteFile("ent/generate.go", []byte(genFile), 0644); err != nil {
|
||||
return fmt.Errorf("creating generate.go file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// schema template for the "init" command.
|
||||
var tmpl = template.Must(template.New("schema").
|
||||
Parse(`package schema
|
||||
|
||||
import "github.com/facebook/ent"
|
||||
|
||||
// {{ . }} holds the schema definition for the {{ . }} entity.
|
||||
type {{ . }} struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Fields of the {{ . }}.
|
||||
func ({{ . }}) Fields() []ent.Field {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Edges of the {{ . }}.
|
||||
func ({{ . }}) Edges() []ent.Edge {
|
||||
return nil
|
||||
}
|
||||
`))
|
||||
|
||||
const (
|
||||
// default schema package path.
|
||||
defaultSchema = "ent/schema"
|
||||
// ent/generate.go file used for "go generate" command.
|
||||
genFile = "package ent\n\n//go:generate go run github.com/facebook/ent/cmd/ent generate ./schema\n"
|
||||
)
|
||||
|
||||
// examples formats the given examples to the cli.
|
||||
func examples(ex ...string) string {
|
||||
for i := range ex {
|
||||
ex[i] = " " + ex[i] // indent each row with 2 spaces.
|
||||
}
|
||||
return strings.Join(ex, "\n")
|
||||
}
|
||||
55
cmd/internal/base/packages.go
Normal file
55
cmd/internal/base/packages.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/tools/go/packages"
|
||||
)
|
||||
|
||||
// DefaultConfig for loading Go base.
|
||||
var DefaultConfig = &packages.Config{Mode: packages.NeedName}
|
||||
|
||||
// PkgPath returns the Go package name for given target path.
|
||||
// Even if the existing path is not exist yet in the filesystem.
|
||||
//
|
||||
// If base.Config is nil, DefaultConfig will be used to load base.
|
||||
func PkgPath(config *packages.Config, target string) (string, error) {
|
||||
if config == nil {
|
||||
config = DefaultConfig
|
||||
}
|
||||
pathCheck, err := filepath.Abs(target)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var parts []string
|
||||
if _, err := os.Stat(pathCheck); os.IsNotExist(err) {
|
||||
parts = append(parts, filepath.Base(pathCheck))
|
||||
pathCheck = filepath.Dir(pathCheck)
|
||||
}
|
||||
// Try maximum 2 directories above the given
|
||||
// target to find the root package or module.
|
||||
for i := 0; i < 2; i++ {
|
||||
pkgs, err := packages.Load(config, pathCheck)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("load package info: %v", err)
|
||||
}
|
||||
if len(pkgs) == 0 || len(pkgs[0].Errors) != 0 {
|
||||
parts = append(parts, filepath.Base(pathCheck))
|
||||
pathCheck = filepath.Dir(pathCheck)
|
||||
continue
|
||||
}
|
||||
pkgPath := pkgs[0].PkgPath
|
||||
for j := len(parts) - 1; j >= 0; j-- {
|
||||
pkgPath = path.Join(pkgPath, parts[j])
|
||||
}
|
||||
return pkgPath, nil
|
||||
}
|
||||
return "", fmt.Errorf("root package or module was not found for: %s", target)
|
||||
}
|
||||
49
cmd/internal/base/packages_test.go
Normal file
49
cmd/internal/base/packages_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/tools/go/packages/packagestest"
|
||||
)
|
||||
|
||||
func TestPkgPath(t *testing.T) { packagestest.TestAll(t, testPkgPath) }
|
||||
func testPkgPath(t *testing.T, x packagestest.Exporter) {
|
||||
e := packagestest.Export(t, x, []packagestest.Module{
|
||||
{
|
||||
Name: "golang.org/x",
|
||||
Files: map[string]interface{}{
|
||||
"x.go": "package x",
|
||||
"y/y.go": "package y",
|
||||
},
|
||||
},
|
||||
})
|
||||
defer e.Cleanup()
|
||||
|
||||
e.Config.Dir = filepath.Dir(e.File("golang.org/x", "x.go"))
|
||||
target := filepath.Join(e.Config.Dir, "ent")
|
||||
pkgPath, err := PkgPath(e.Config, target)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "golang.org/x/ent", pkgPath)
|
||||
|
||||
e.Config.Dir = filepath.Dir(e.File("golang.org/x", "y/y.go"))
|
||||
target = filepath.Join(e.Config.Dir, "ent")
|
||||
pkgPath, err = PkgPath(e.Config, target)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "golang.org/x/y/ent", pkgPath)
|
||||
|
||||
target = filepath.Join(e.Config.Dir, "z/ent")
|
||||
pkgPath, err = PkgPath(e.Config, target)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "golang.org/x/y/z/ent", pkgPath)
|
||||
|
||||
target = filepath.Join(e.Config.Dir, "z/e/n/t")
|
||||
pkgPath, err = PkgPath(e.Config, target)
|
||||
require.Error(t, err)
|
||||
require.Empty(t, pkgPath)
|
||||
}
|
||||
Reference in New Issue
Block a user