mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
361 lines
10 KiB
Go
361 lines
10 KiB
Go
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
// Package base defines shared basic pieces of the ent command.
|
|
package base
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"text/template"
|
|
"unicode"
|
|
|
|
"entgo.io/ent/cmd/internal/printer"
|
|
"entgo.io/ent/dialect/sql/schema"
|
|
"entgo.io/ent/entc"
|
|
"entgo.io/ent/entc/gen"
|
|
"entgo.io/ent/schema/field"
|
|
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
// IDType is a custom ID implementation for pflag.
|
|
type IDType field.Type
|
|
|
|
// Set implements the Set method of the flag.Value interface.
|
|
func (t *IDType) Set(s string) error {
|
|
switch s {
|
|
case field.TypeInt.String():
|
|
*t = IDType(field.TypeInt)
|
|
case field.TypeInt64.String():
|
|
*t = IDType(field.TypeInt64)
|
|
case field.TypeUint.String():
|
|
*t = IDType(field.TypeUint)
|
|
case field.TypeUint64.String():
|
|
*t = IDType(field.TypeUint64)
|
|
case field.TypeString.String():
|
|
*t = IDType(field.TypeString)
|
|
default:
|
|
return fmt.Errorf("invalid type %q", s)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Type returns the type representation of the id option for help command.
|
|
func (IDType) Type() string {
|
|
return fmt.Sprintf("%v", []field.Type{
|
|
field.TypeInt,
|
|
field.TypeInt64,
|
|
field.TypeUint,
|
|
field.TypeUint64,
|
|
field.TypeString,
|
|
})
|
|
}
|
|
|
|
// String returns the default value for the help command.
|
|
func (IDType) String() string {
|
|
return field.TypeInt.String()
|
|
}
|
|
|
|
// InitCmd returns the init command for ent/c packages.
|
|
func InitCmd() *cobra.Command {
|
|
c := NewCmd()
|
|
c.Use = "init [flags] [schemas]"
|
|
c.Short = "initialize an environment with zero or more schemas"
|
|
c.Example = examples(
|
|
"ent init Example",
|
|
"ent init --target entv1/schema User Group",
|
|
"ent init --template ./path/to/file.tmpl User",
|
|
)
|
|
c.Deprecated = `use "ent new" instead`
|
|
return c
|
|
}
|
|
|
|
// NewCmd returns the new command for ent/c packages.
|
|
func NewCmd() *cobra.Command {
|
|
var target, tmplPath string
|
|
cmd := &cobra.Command{
|
|
Use: "new [flags] [schemas]",
|
|
Short: "initialize a new environment with zero or more schemas",
|
|
Example: examples(
|
|
"ent new Example",
|
|
"ent new --target entv1/schema User Group",
|
|
"ent new --template ./path/to/file.tmpl User",
|
|
),
|
|
Args: func(_ *cobra.Command, names []string) error {
|
|
for _, name := range names {
|
|
if !unicode.IsUpper(rune(name[0])) {
|
|
return errors.New("schema names must begin with uppercase")
|
|
}
|
|
}
|
|
return nil
|
|
},
|
|
Run: func(cmd *cobra.Command, names []string) {
|
|
var (
|
|
err error
|
|
tmpl *template.Template
|
|
)
|
|
if tmplPath != "" {
|
|
tmpl = template.New(filepath.Base(tmplPath)).Funcs(gen.Funcs)
|
|
tmpl, err = tmpl.ParseFiles(tmplPath)
|
|
} else {
|
|
tmpl = template.New("schema").Funcs(gen.Funcs)
|
|
tmpl, err = tmpl.Parse(defaultTemplate)
|
|
}
|
|
if err != nil {
|
|
log.Fatalln(fmt.Errorf("ent/new: could not parse template %w", err))
|
|
}
|
|
if err := newEnv(target, names, tmpl); err != nil {
|
|
log.Fatalln(fmt.Errorf("ent/new: %w", err))
|
|
}
|
|
},
|
|
}
|
|
cmd.Flags().StringVar(&target, "target", defaultSchema, "target directory for schemas")
|
|
cmd.Flags().StringVar(&tmplPath, "template", "", "template to use for new schemas")
|
|
return cmd
|
|
}
|
|
|
|
// DescribeCmd returns the describe command for ent/c packages.
|
|
func DescribeCmd() *cobra.Command {
|
|
return &cobra.Command{
|
|
Use: "describe [flags] path",
|
|
Short: "print a description of the graph schema",
|
|
Example: examples(
|
|
"ent describe ./ent/schema",
|
|
"ent describe github.com/a8m/x",
|
|
),
|
|
Args: cobra.ExactArgs(1),
|
|
Run: func(cmd *cobra.Command, path []string) {
|
|
graph, err := entc.LoadGraph(path[0], &gen.Config{})
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
printer.Fprint(os.Stdout, graph)
|
|
},
|
|
}
|
|
}
|
|
|
|
// GenerateCmd returns the generate command for ent/c packages.
|
|
func GenerateCmd(postRun ...func(*gen.Config)) *cobra.Command {
|
|
var (
|
|
cfg gen.Config
|
|
storage string
|
|
features []string
|
|
templates []string
|
|
idtype = IDType(field.TypeInt)
|
|
cmd = &cobra.Command{
|
|
Use: "generate [flags] path",
|
|
Short: "generate go code for the schema directory",
|
|
Example: examples(
|
|
"ent generate ./ent/schema",
|
|
"ent generate github.com/a8m/x",
|
|
),
|
|
Args: cobra.ExactArgs(1),
|
|
Run: func(cmd *cobra.Command, path []string) {
|
|
opts := []entc.Option{
|
|
entc.Storage(storage),
|
|
entc.FeatureNames(features...),
|
|
}
|
|
for _, tmpl := range templates {
|
|
typ := "dir"
|
|
if parts := strings.SplitN(tmpl, "=", 2); len(parts) > 1 {
|
|
typ, tmpl = parts[0], parts[1]
|
|
}
|
|
switch typ {
|
|
case "dir":
|
|
opts = append(opts, entc.TemplateDir(tmpl))
|
|
case "file":
|
|
opts = append(opts, entc.TemplateFiles(tmpl))
|
|
case "glob":
|
|
opts = append(opts, entc.TemplateGlob(tmpl))
|
|
default:
|
|
log.Fatalln("unsupported template type", typ)
|
|
}
|
|
}
|
|
// If the target directory is not inferred from
|
|
// the schema path, resolve its package path.
|
|
if cfg.Target != "" {
|
|
pkgPath, err := PkgPath(DefaultConfig, cfg.Target)
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
cfg.Package = pkgPath
|
|
}
|
|
cfg.IDType = &field.TypeInfo{Type: field.Type(idtype)}
|
|
if err := entc.Generate(path[0], &cfg, opts...); err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
for _, fn := range postRun {
|
|
fn(&cfg)
|
|
}
|
|
},
|
|
}
|
|
)
|
|
cmd.Flags().Var(&idtype, "idtype", "type of the id field")
|
|
cmd.Flags().StringVar(&storage, "storage", "sql", "storage driver to support in codegen")
|
|
cmd.Flags().StringVar(&cfg.Header, "header", "", "override codegen header")
|
|
cmd.Flags().StringVar(&cfg.Target, "target", "", "target directory for codegen")
|
|
cmd.Flags().StringSliceVarP(&features, "feature", "", nil, "extend codegen with additional features")
|
|
cmd.Flags().StringSliceVarP(&templates, "template", "", nil, "external templates to execute")
|
|
// The --idtype flag predates the field.<Type>("id") option.
|
|
// See, https://entgo.io/docs/schema-fields#id-field.
|
|
cobra.CheckErr(cmd.Flags().MarkHidden("idtype"))
|
|
return cmd
|
|
}
|
|
|
|
// SchemaCmd returns DDL to use Ent as an Atlas schema loader.
|
|
func SchemaCmd() *cobra.Command {
|
|
var (
|
|
cfg gen.Config
|
|
dlct, version string
|
|
features, buildTags []string
|
|
hashSymbols bool
|
|
cmd = &cobra.Command{
|
|
Use: "schema [flags] path",
|
|
Short: "dump the DDL for the schema directory",
|
|
Example: examples(
|
|
"ent schema ./ent/schema --dialect mysql --version 5.6",
|
|
"ent schema ./ent/schema --dialect sqlite3",
|
|
"ent schema github.com/a8m/x --dialect postgres --version 15",
|
|
),
|
|
Args: cobra.ExactArgs(1),
|
|
Run: func(cmd *cobra.Command, path []string) {
|
|
for _, o := range []entc.Option{
|
|
entc.FeatureNames(features...),
|
|
entc.BuildTags(buildTags...),
|
|
} {
|
|
if err := o(&cfg); err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
}
|
|
// If the target directory is not inferred from
|
|
// the schema path, resolve its package path.
|
|
if cfg.Target != "" {
|
|
pkgPath, err := PkgPath(DefaultConfig, cfg.Target)
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
cfg.Package = pkgPath
|
|
}
|
|
g, err := entc.LoadGraph(path[0], &cfg)
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
t, err := g.Tables()
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
v, err := g.Views()
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
ddl, err := schema.DDL(cmd.Context(), schema.DDLArgs{
|
|
Dialect: dlct,
|
|
Version: version,
|
|
HashSymbols: hashSymbols,
|
|
Tables: append(t, v...),
|
|
})
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
fmt.Println(ddl)
|
|
},
|
|
}
|
|
)
|
|
cmd.Flags().StringVar(&dlct, "dialect", "", "database dialect to use")
|
|
cmd.Flags().StringVar(&version, "version", "", "database version to assume")
|
|
cmd.Flags().StringSliceVarP(&features, "feature", "", nil, "extend codegen with additional features")
|
|
cmd.Flags().StringSliceVarP(&buildTags, "build-tags", "", nil, "go build tags to use when loading the schema graph")
|
|
cmd.Flags().BoolVar(&hashSymbols, "hash-symbols", false, "whether to hash long symbols")
|
|
cobra.CheckErr(cmd.MarkFlagRequired("dialect"))
|
|
return cmd
|
|
}
|
|
|
|
// newEnv create a new environment for ent codegen.
|
|
func newEnv(target string, names []string, tmpl *template.Template) error {
|
|
if err := createDir(target); err != nil {
|
|
return fmt.Errorf("create dir %s: %w", target, err)
|
|
}
|
|
for _, name := range names {
|
|
if err := gen.ValidSchemaName(name); err != nil {
|
|
return fmt.Errorf("new schema %s: %w", name, err)
|
|
}
|
|
if fileExists(target, name) {
|
|
return fmt.Errorf("new schema %s: already exists", name)
|
|
}
|
|
b := bytes.NewBuffer(nil)
|
|
if err := tmpl.Execute(b, name); err != nil {
|
|
return fmt.Errorf("executing template %s: %w", name, err)
|
|
}
|
|
newFileTarget := filepath.Join(target, strings.ToLower(name+".go"))
|
|
if err := os.WriteFile(newFileTarget, b.Bytes(), 0644); err != nil {
|
|
return fmt.Errorf("writing file %s: %w", newFileTarget, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func createDir(target string) error {
|
|
_, err := os.Stat(target)
|
|
if err == nil || !os.IsNotExist(err) {
|
|
return err
|
|
}
|
|
if err := os.MkdirAll(target, os.ModePerm); err != nil {
|
|
return fmt.Errorf("creating schema directory: %w", err)
|
|
}
|
|
if target != defaultSchema {
|
|
return nil
|
|
}
|
|
if err := os.WriteFile("ent/generate.go", []byte(genFile), 0644); err != nil {
|
|
return fmt.Errorf("creating generate.go file: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func fileExists(target, name string) bool {
|
|
var _, err = os.Stat(filepath.Join(target, strings.ToLower(name+".go")))
|
|
|
|
return err == nil
|
|
}
|
|
|
|
const (
|
|
// default schema package path.
|
|
defaultSchema = "ent/schema"
|
|
// ent/generate.go file used for "go generate" command.
|
|
genFile = "package ent\n\n//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate ./schema\n"
|
|
// schema template for the "init" command.
|
|
defaultTemplate = `package schema
|
|
|
|
import "entgo.io/ent"
|
|
|
|
// {{ . }} holds the schema definition for the {{ . }} entity.
|
|
type {{ . }} struct {
|
|
ent.Schema
|
|
}
|
|
|
|
// Fields of the {{ . }}.
|
|
func ({{ . }}) Fields() []ent.Field {
|
|
return nil
|
|
}
|
|
|
|
// Edges of the {{ . }}.
|
|
func ({{ . }}) Edges() []ent.Edge {
|
|
return nil
|
|
}
|
|
`
|
|
)
|
|
|
|
// examples formats the given examples to the cli.
|
|
func examples(ex ...string) string {
|
|
for i := range ex {
|
|
ex[i] = " " + ex[i] // indent each row with 2 spaces.
|
|
}
|
|
return strings.Join(ex, "\n")
|
|
}
|