mirror of
https://github.com/ent/ent.git
synced 2026-04-28 05:30:56 +03:00
schema/field: expose RType.Implements method (#2379)
Also, add both (T) and (*T) methods for RType
This commit is contained in:
@@ -171,7 +171,7 @@ func (m *Migrate) Diff(ctx context.Context, tables ...*Table) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Skip if the plan has no changes
|
||||
// Skip if the plan has no changes.
|
||||
if len(plan.Changes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1165,37 +1165,43 @@ func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) {
|
||||
Methods: make(map[string]struct{ In, Out []*RType }, t.NumMethod()),
|
||||
},
|
||||
}
|
||||
methods(t, info.RType)
|
||||
switch t.Kind() {
|
||||
case reflect.Slice, reflect.Ptr, reflect.Map:
|
||||
info.Nillable = true
|
||||
}
|
||||
switch pt := reflect.PtrTo(t); {
|
||||
case pt.Implements(valueScannerType):
|
||||
t = pt
|
||||
fallthrough
|
||||
case t.Implements(valueScannerType):
|
||||
n := t.NumMethod()
|
||||
for i := 0; i < n; i++ {
|
||||
m := t.Method(i)
|
||||
in := make([]*RType, m.Type.NumIn()-1)
|
||||
for j := range in {
|
||||
arg := m.Type.In(j + 1)
|
||||
in[j] = &RType{Name: arg.Name(), Ident: arg.String(), Kind: arg.Kind(), PkgPath: arg.PkgPath()}
|
||||
}
|
||||
out := make([]*RType, m.Type.NumOut())
|
||||
for j := range out {
|
||||
ret := m.Type.Out(j)
|
||||
out[j] = &RType{Name: ret.Name(), Ident: ret.String(), Kind: ret.Kind(), PkgPath: ret.PkgPath()}
|
||||
}
|
||||
info.RType.Methods[m.Name] = struct{ In, Out []*RType }{in, out}
|
||||
}
|
||||
case t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType):
|
||||
case pt.Implements(valueScannerType), t.Implements(valueScannerType),
|
||||
t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType):
|
||||
default:
|
||||
d.Err = fmt.Errorf("GoType must be a %q type or ValueScanner", expectType)
|
||||
}
|
||||
d.Info = info
|
||||
}
|
||||
|
||||
func methods(t reflect.Type, rtype *RType) {
|
||||
// For type T, add methods with
|
||||
// pointer receiver as well (*T).
|
||||
if t.Kind() != reflect.Ptr {
|
||||
t = reflect.PtrTo(t)
|
||||
}
|
||||
n := t.NumMethod()
|
||||
for i := 0; i < n; i++ {
|
||||
m := t.Method(i)
|
||||
in := make([]*RType, m.Type.NumIn()-1)
|
||||
for j := range in {
|
||||
arg := m.Type.In(j + 1)
|
||||
in[j] = &RType{Name: arg.Name(), Ident: arg.String(), Kind: arg.Kind(), PkgPath: arg.PkgPath()}
|
||||
}
|
||||
out := make([]*RType, m.Type.NumOut())
|
||||
for j := range out {
|
||||
ret := m.Type.Out(j)
|
||||
out[j] = &RType{Name: ret.Name(), Ident: ret.String(), Kind: ret.Kind(), PkgPath: ret.PkgPath()}
|
||||
}
|
||||
rtype.Methods[m.Name] = struct{ In, Out []*RType }{in, out}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Descriptor) checkDefaultFunc(expectType reflect.Type) {
|
||||
for _, typ := range []reflect.Type{reflect.TypeOf(d.Default), reflect.TypeOf(d.UpdateDefault)} {
|
||||
if typ == nil || typ.Kind() != reflect.Func || d.Err != nil {
|
||||
|
||||
@@ -8,14 +8,18 @@ import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/schema/field"
|
||||
|
||||
@@ -680,6 +684,77 @@ func TestField_Other(t *testing.T) {
|
||||
assert.Error(t, fd.Err, "invalid default value")
|
||||
}
|
||||
|
||||
type UserRole string
|
||||
|
||||
const (
|
||||
Admin UserRole = "ADMIN"
|
||||
User UserRole = "USER"
|
||||
Unknown UserRole = "UNKNOWN"
|
||||
)
|
||||
|
||||
func (UserRole) Values() (roles []string) {
|
||||
for _, r := range []UserRole{Admin, User, Unknown} {
|
||||
roles = append(roles, string(r))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e UserRole) String() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// MarshalGQL implements graphql.Marshaler interface.
|
||||
func (e UserRole) MarshalGQL(w io.Writer) {
|
||||
_, _ = io.WriteString(w, strconv.Quote(e.String()))
|
||||
}
|
||||
|
||||
// UnmarshalGQL implements graphql.Unmarshaler interface.
|
||||
func (e *UserRole) UnmarshalGQL(val interface{}) error {
|
||||
str, ok := val.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("enum %T must be a string", val)
|
||||
}
|
||||
*e = UserRole(str)
|
||||
switch *e {
|
||||
case Admin, User, Unknown:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("%s is not a valid Role", str)
|
||||
}
|
||||
}
|
||||
|
||||
type Scalar struct{}
|
||||
|
||||
func (Scalar) MarshalGQL(io.Writer) {}
|
||||
func (*Scalar) UnmarshalGQL(interface{}) error { return nil }
|
||||
func (Scalar) Value() (driver.Value, error) { return nil, nil }
|
||||
|
||||
func TestRType_Implements(t *testing.T) {
|
||||
type (
|
||||
marshaler interface{ MarshalGQL(w io.Writer) }
|
||||
unmarshaler interface{ UnmarshalGQL(v interface{}) error }
|
||||
codec interface {
|
||||
marshaler
|
||||
unmarshaler
|
||||
}
|
||||
)
|
||||
var (
|
||||
codecType = reflect.TypeOf((*codec)(nil)).Elem()
|
||||
marshalType = reflect.TypeOf((*marshaler)(nil)).Elem()
|
||||
unmarshalType = reflect.TypeOf((*unmarshaler)(nil)).Elem()
|
||||
)
|
||||
for _, f := range []ent.Field{
|
||||
field.Enum("role").GoType(Admin),
|
||||
field.Other("scalar", &Scalar{}),
|
||||
field.Other("scalar", Scalar{}),
|
||||
} {
|
||||
fd := f.Descriptor()
|
||||
assert.True(t, fd.Info.RType.Implements(codecType))
|
||||
assert.True(t, fd.Info.RType.Implements(marshalType))
|
||||
assert.True(t, fd.Info.RType.Implements(unmarshalType))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeString(t *testing.T) {
|
||||
typ := field.TypeBool
|
||||
assert.Equal(t, "bool", typ.String())
|
||||
|
||||
@@ -120,12 +120,12 @@ func (t TypeInfo) ConstName() string {
|
||||
|
||||
// ValueScanner indicates if this type implements the ValueScanner interface.
|
||||
func (t TypeInfo) ValueScanner() bool {
|
||||
return t.RType.implements(valueScannerType)
|
||||
return t.RType.Implements(valueScannerType)
|
||||
}
|
||||
|
||||
// Valuer indicates if this type implements the driver.Valuer interface.
|
||||
func (t TypeInfo) Valuer() bool {
|
||||
return t.RType.implements(valuerType)
|
||||
return t.RType.Implements(valuerType)
|
||||
}
|
||||
|
||||
// Comparable reports whether values of this type are comparable.
|
||||
@@ -147,7 +147,7 @@ var stringerType = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()
|
||||
|
||||
// Stringer indicates if this type implements the Stringer interface.
|
||||
func (t TypeInfo) Stringer() bool {
|
||||
return t.RType.implements(stringerType)
|
||||
return t.RType.Implements(stringerType)
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -215,7 +215,8 @@ func (r *RType) IsPtr() bool {
|
||||
return r != nil && r.Kind == reflect.Ptr
|
||||
}
|
||||
|
||||
func (r *RType) implements(typ reflect.Type) bool {
|
||||
// Implements reports whether the RType ~implements the given interface type.
|
||||
func (r *RType) Implements(typ reflect.Type) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user