Files
ent/entc/gen/func.go
2019-12-22 13:32:51 +02:00

309 lines
7.3 KiB
Go

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package gen
import (
"bytes"
"fmt"
"path/filepath"
"reflect"
"sort"
"strconv"
"strings"
"text/template"
"unicode"
"github.com/facebookincubator/ent/schema/field"
"github.com/go-openapi/inflect"
)
var (
// Funcs are the predefined template
// functions used by the codegen.
Funcs = template.FuncMap{
"ops": ops,
"add": add,
"append": reflect.Append,
"appends": reflect.AppendSlice,
"order": order,
"snake": snake,
"pascal": pascal,
"extend": extend,
"xrange": xrange,
"receiver": receiver,
"plural": plural,
"aggregate": aggregate,
"primitives": primitives,
"singular": rules.Singularize,
"quote": strconv.Quote,
"base": filepath.Base,
"keys": keys,
"join": join,
"lower": strings.ToLower,
"upper": strings.ToUpper,
"hasField": hasField,
"indirect": indirect,
"hasPrefix": strings.HasPrefix,
"hasSuffix": strings.HasSuffix,
"trimPackage": trimPackage,
"xtemplate": xtemplate,
"hasTemplate": hasTemplate,
}
rules = ruleset()
acronym = make(map[string]bool)
)
// ops returns all operations for given field.
func ops(f *Field) (op []Op) {
switch t := f.Type.Type; {
case t == field.TypeJSON:
case t == field.TypeBool:
op = boolOps
case t == field.TypeString && strings.ToLower(f.Name) != "id":
op = stringOps
case t == field.TypeEnum:
op = enumOps
default:
op = numericOps
}
if f.Optional {
op = append(op, nillableOps...)
}
return op
}
// xrange generates a slice of len n.
func xrange(n int) (a []int) {
for i := 0; i < n; i++ {
a = append(a, i)
}
return
}
// plural a name.
func plural(name string) string {
p := rules.Pluralize(name)
if p == name {
p += "Slice"
}
return p
}
// pascal converts the given column name into a PascalCase.
//
// user_info => UserInfo
// full_name => FullName
// user_id => UserID
//
func pascal(s string) string {
words := strings.Split(s, "_")
for i, w := range words {
upper := strings.ToUpper(w)
if acronym[upper] {
words[i] = upper
} else {
words[i] = rules.Capitalize(w)
}
}
return strings.Join(words, "")
}
// snake converts the given struct or field name into a snake_case.
//
// Username => username
// FullName => full_name
// HTTPCode => http_code
//
func snake(s string) string {
var b strings.Builder
for i := 0; i < len(s); i++ {
r := rune(s[i])
// put '_' if it is not a start or end of a word, current letter is an uppercase letter,
// and previous letter is a lowercase letter (cases like: "UserInfo"), or next letter is
// also a lowercase letter and previous letter is not "_".
if i > 0 && i < len(s)-1 && unicode.IsUpper(r) &&
(unicode.IsLower(rune(s[i-1])) ||
unicode.IsLower(rune(s[i+1])) && unicode.IsLetter(rune(s[i-1]))) {
b.WriteString("_")
}
b.WriteRune(unicode.ToLower(r))
}
return b.String()
}
// receiver returns the receiver name of the given type.
//
// []T => t
// [1]T => t
// User => u
// UserQuery => uq
//
func receiver(s string) (r string) {
// trim invalid tokens for identifier prefix.
s = strings.Trim(s, "[]*&0123456789")
parts := strings.Split(snake(s), "_")
min := len(parts[0])
for _, w := range parts[1:] {
if len(w) < min {
min = len(w)
}
}
for i := 1; i < min; i++ {
r := parts[0][:i]
for _, w := range parts[1:] {
r += w[:i]
}
if _, ok := importPkg[r]; !ok {
return r
}
}
return strings.ToLower(s)
}
// typeScope wraps the Type object with extended scope.
type typeScope struct {
*Type
Scope map[interface{}]interface{}
}
// graphScope wraps the Graph object with extended scope.
type graphScope struct {
*Graph
Scope map[interface{}]interface{}
}
// extend extends the parent block with a KV pairs.
//
// {{ with $scope := extend $ "key" "value" }}
// {{ template "setters" $scope }}
// {{ end}}
//
func extend(v interface{}, kv ...interface{}) (interface{}, error) {
scope := make(map[interface{}]interface{})
if len(kv)%2 != 0 {
return nil, fmt.Errorf("invalid number of parameters: %d", len(kv))
}
for i := 0; i < len(kv); i += 2 {
scope[kv[i]] = kv[i+1]
}
switch v := v.(type) {
case *Type:
return &typeScope{Type: v, Scope: scope}, nil
case *Graph:
return &graphScope{Graph: v, Scope: scope}, nil
case *typeScope:
for k := range v.Scope {
scope[k] = v.Scope[k]
}
return &typeScope{Type: v.Type, Scope: scope}, nil
default:
return nil, fmt.Errorf("invalid type for extend: %T", v)
}
}
// add calculates summarize list of variables.
func add(xs ...int) (n int) {
for _, x := range xs {
n += x
}
return
}
func ruleset() *inflect.Ruleset {
rules := inflect.NewDefaultRuleset()
// add common initialisms. copied from golint.
for _, w := range []string{
"API", "ASCII", "CPU", "CSS", "DNS", "GUID", "UID", "UI",
"RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "HTML",
"HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM",
"UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS",
} {
acronym[w] = true
rules.AddAcronym(w)
}
return rules
}
// order returns a map of sort orders.
// The key is the function name, and the value its database keyword.
func order() map[string]string {
return map[string]string{
"asc": "incr",
"desc": "decr",
}
}
// aggregate returns a map between all agg-functions and if they accept a field name as a parameter or not.
func aggregate() map[string]bool {
return map[string]bool{
"min": true,
"max": true,
"sum": true,
"mean": true,
"count": false,
}
}
// keys returns the given map keys.
func keys(v reflect.Value) ([]string, error) {
if k := v.Type().Kind(); k != reflect.Map {
return nil, fmt.Errorf("expect map for keys, got: %v", k)
}
keys := make([]string, v.Len())
for i, v := range v.MapKeys() {
keys[i] = v.String()
}
return keys, nil
}
// primitives returns all primitives types.
func primitives() []string {
return []string{field.TypeString.String(), field.TypeInt.String(), field.TypeFloat64.String(), field.TypeBool.String()}
}
// join is a wrapper around strings.Join to provide consistent output.
func join(a []string, sep string) string {
sort.Strings(a)
return strings.Join(a, sep)
}
// xtemplate dynamically executes templates by their names.
func xtemplate(name string, v interface{}) (string, error) {
buf := bytes.NewBuffer(nil)
if err := templates.ExecuteTemplate(buf, name, v); err != nil {
return "", err
}
return buf.String(), nil
}
// hasTemplate checks whether a template exists in the loaded templates.
func hasTemplate(name string) bool {
for _, t := range templates.Templates() {
if t.Name() == name {
return true
}
}
return false
}
// hasField determines if a struct has a field with the given name.
func hasField(v interface{}, name string) bool {
vr := reflect.Indirect(reflect.ValueOf(v))
return vr.FieldByName(name).IsValid()
}
// trimPackage trims the package name from the given identifier.
func trimPackage(ident, pkg string) string {
return strings.TrimPrefix(ident, pkg+".")
}
// indirect returns the item at the end of indirection.
func indirect(v reflect.Value) reflect.Value {
for ; v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface; v = v.Elem() {
}
return v
}