diff --git a/dialect/entsql/annotation.go b/dialect/entsql/annotation.go index 6eb3c8a1b..ef4901c1f 100644 --- a/dialect/entsql/annotation.go +++ b/dialect/entsql/annotation.go @@ -159,10 +159,10 @@ func (a Annotation) Merge(other schema.Annotation) schema.Annotation { return a } -var ( - _ schema.Annotation = (*Annotation)(nil) - _ schema.Merger = (*Annotation)(nil) -) +var _ interface { + schema.Annotation + schema.Merger +} = (*Annotation)(nil) // ReferenceOption for constraint actions. type ReferenceOption string diff --git a/entc/entc.go b/entc/entc.go index 9e2d085f0..b9ac652e0 100644 --- a/entc/entc.go +++ b/entc/entc.go @@ -142,10 +142,13 @@ func Annotations(annotations ...Annotation) Option { } for _, ant := range annotations { name := ant.Name() - if _, ok := cfg.Annotations[name]; ok { + if curr, ok := cfg.Annotations[name]; !ok { + cfg.Annotations[name] = ant + } else if m, ok := curr.(schema.Merger); ok { + cfg.Annotations[name] = m.Merge(ant) + } else { return fmt.Errorf("duplicate annotations with name %q", name) } - cfg.Annotations[name] = ant } return nil } @@ -264,12 +267,12 @@ func (DefaultExtension) Options() []Option { return nil } var _ Extension = (*DefaultExtension)(nil) // DependencyOption allows configuring optional dependencies using functional options. -type DependencyOption func(*gen.DependencyAnnotation) error +type DependencyOption func(*gen.Dependency) error // DependencyType sets the type of the struct field in // the generated builders for the configured dependency. func DependencyType(v interface{}) DependencyOption { - return func(d *gen.DependencyAnnotation) error { + return func(d *gen.Dependency) error { if v == nil { return errors.New("nil dependency type") } @@ -292,7 +295,7 @@ func DependencyType(v interface{}) DependencyOption { // DependencyTypeInfo is similar to DependencyType, but // allows setting the field.TypeInfo explicitly. func DependencyTypeInfo(t *field.TypeInfo) DependencyOption { - return func(d *gen.DependencyAnnotation) error { + return func(d *gen.Dependency) error { if t == nil { return errors.New("nil dependency type info") } @@ -304,7 +307,7 @@ func DependencyTypeInfo(t *field.TypeInfo) DependencyOption { // DependencyField sets the struct field and the option name // of the dependency in the generated builders. func DependencyName(name string) DependencyOption { - return func(d *gen.DependencyAnnotation) error { + return func(d *gen.Dependency) error { d.Field = name d.Option = name return nil @@ -329,7 +332,7 @@ func DependencyName(name string) DependencyOption { // func Dependency(opts ...DependencyOption) Option { return func(cfg *gen.Config) error { - d := &gen.DependencyAnnotation{} + d := &gen.Dependency{} for _, opt := range opts { if err := opt(d); err != nil { return err @@ -338,17 +341,7 @@ func Dependency(opts ...DependencyOption) Option { if err := d.Build(); err != nil { return err } - if cfg.Annotations == nil { - cfg.Annotations = gen.Annotations{} - } - v, ok := cfg.Annotations[d.Name()] - if !ok { - v = []*gen.DependencyAnnotation{d} - } else { - v = append(v.([]*gen.DependencyAnnotation), d) - } - cfg.Annotations[d.Name()] = v - return nil + return Annotations(gen.Dependencies{d})(cfg) } } diff --git a/entc/gen/graph.go b/entc/gen/graph.go index 38864efcf..9abe662ba 100644 --- a/entc/gen/graph.go +++ b/entc/gen/graph.go @@ -8,13 +8,11 @@ package gen import ( "bytes" "encoding/json" - "errors" "fmt" "go/parser" "go/token" "os" "path/filepath" - "reflect" "runtime/debug" "strings" "text/template/parse" @@ -756,76 +754,6 @@ func (a assets) format() error { return nil } -// DependencyAnnotation allows configuring optional dependencies as struct fields on the -// generated builders. For example: -// -// DependencyAnnotation{ -// Field: "HTTPClient", -// Type: "*http.Client", -// Option: "WithClient", -// } -// -// Although the DependencyAnnotation is exported, used should use the entc.OptionalDependency -// option in order to build this annotation. -type DependencyAnnotation struct { - // Field defines the struct field name on the builders. - // It defaults to the full type name. For example: - // - // http.Client => HTTPClient - // net.Conn => NetConn - // url.URL => URL - // - Field string - // Type defines the type identifier. For example, `*http.Client`. - Type *field.TypeInfo - // Option defines the name of the config option. - // It defaults to the field name. - Option string -} - -// Name describes the annotation name. -func (DependencyAnnotation) Name() string { - return "Dependencies" -} - -// Build builds the annotation and fails if it is invalid. -func (d *DependencyAnnotation) Build() error { - if d.Type == nil { - return errors.New("entc/gen: missing dependency type") - } - if d.Field == "" { - name, err := d.defaultName() - if err != nil { - return err - } - d.Field = name - } - if d.Option == "" { - d.Option = d.Field - } - return nil -} - -func (d *DependencyAnnotation) defaultName() (string, error) { - var pkg, name string - switch parts := strings.Split(strings.TrimLeft(d.Type.Ident, "[]*"), "."); len(parts) { - case 1: - name = parts[0] - case 2: - name = parts[1] - // Avoid stuttering. - if !strings.EqualFold(parts[0], name) { - pkg = parts[0] - } - default: - return "", fmt.Errorf("entc/gen: unexpected number of parts: %q", parts) - } - if r := d.Type.RType; r != nil && (r.Kind == reflect.Array || r.Kind == reflect.Slice) { - name = plural(name) - } - return pascal(pkg) + pascal(name), nil -} - // expect panics if the condition is false. func expect(cond bool, msg string, args ...interface{}) { if !cond { diff --git a/entc/gen/graph_test.go b/entc/gen/graph_test.go index 151af9af2..2d82f24c8 100644 --- a/entc/gen/graph_test.go +++ b/entc/gen/graph_test.go @@ -434,7 +434,7 @@ func TestDependencyAnnotation_Build(t *testing.T) { }, } for _, tt := range tests { - d := &DependencyAnnotation{Type: tt.typ} + d := &Dependency{Type: tt.typ} require.NoError(t, d.Build()) require.Equal(t, tt.field, d.Field) } diff --git a/entc/gen/template.go b/entc/gen/template.go index 6f6745343..660b42ea0 100644 --- a/entc/gen/template.go +++ b/entc/gen/template.go @@ -7,16 +7,21 @@ package gen import ( "bytes" "embed" + "errors" "fmt" "go/parser" "go/token" "io/fs" "os" "path/filepath" + "reflect" "strconv" "strings" "text/template" "text/template/parse" + + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" ) type ( @@ -310,6 +315,95 @@ func MustParse(t *Template, err error) *Template { return t } +type ( + // Dependencies wraps a list of dependencies as codegen + // annotation. + Dependencies []*Dependency + + // Dependency allows configuring optional dependencies as struct fields on the + // generated builders. For example: + // + // DependencyAnnotation{ + // Field: "HTTPClient", + // Type: "*http.Client", + // Option: "WithClient", + // } + // + // Although the Dependency and the DependencyAnnotation are exported, used should + // use the entc.Dependency option in order to build this annotation. + Dependency struct { + // Field defines the struct field name on the builders. + // It defaults to the full type name. For example: + // + // http.Client => HTTPClient + // net.Conn => NetConn + // url.URL => URL + // + Field string + // Type defines the type identifier. For example, `*http.Client`. + Type *field.TypeInfo + // Option defines the name of the config option. + // It defaults to the field name. + Option string + } +) + +// Name describes the annotation name. +func (Dependencies) Name() string { + return "Dependencies" +} + +// Merge implements the schema.Merger interface. +func (d Dependencies) Merge(other schema.Annotation) schema.Annotation { + if deps, ok := other.(Dependencies); ok { + return append(d, deps...) + } + return d +} + +var _ interface { + schema.Annotation + schema.Merger +} = (*Dependencies)(nil) + +// Build builds the annotation and fails if it is invalid. +func (d *Dependency) Build() error { + if d.Type == nil { + return errors.New("entc/gen: missing dependency type") + } + if d.Field == "" { + name, err := d.defaultName() + if err != nil { + return err + } + d.Field = name + } + if d.Option == "" { + d.Option = d.Field + } + return nil +} + +func (d *Dependency) defaultName() (string, error) { + var pkg, name string + switch parts := strings.Split(strings.TrimLeft(d.Type.Ident, "[]*"), "."); len(parts) { + case 1: + name = parts[0] + case 2: + name = parts[1] + // Avoid stuttering. + if !strings.EqualFold(parts[0], name) { + pkg = parts[0] + } + default: + return "", fmt.Errorf("entc/gen: unexpected number of parts: %q", parts) + } + if r := d.Type.RType; r != nil && (r.Kind == reflect.Array || r.Kind == reflect.Slice) { + name = plural(name) + } + return pascal(pkg) + pascal(name), nil +} + func pkgf(s string) func(t *Type) string { return func(t *Type) string { return fmt.Sprintf(s, t.Package()) } }