mirror of
https://github.com/ent/ent.git
synced 2026-03-05 19:35:23 +03:00
schema/field: add support for ValueScanner on field.Bytes() (#4067)
This commit is contained in:
@@ -691,6 +691,15 @@ func (b *bytesBuilder) GoType(typ any) *bytesBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// ValueScanner provides an external value scanner for the given GoType.
|
||||
// Using this option allow users to use field types that do not implement
|
||||
// the sql.Scanner and driver.Valuer interfaces, such as slices and maps
|
||||
// or types exist in external packages (e.g., url.URL).
|
||||
func (b *bytesBuilder) ValueScanner(vs any) *bytesBuilder {
|
||||
b.desc.ValueScanner = vs
|
||||
return b
|
||||
}
|
||||
|
||||
// Annotations adds a list of annotations to the field object to be used by
|
||||
// codegen extensions.
|
||||
func (b *bytesBuilder) Annotations(annotations ...schema.Annotation) *bytesBuilder {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -289,6 +290,58 @@ func TestBytes_DefaultFunc(t *testing.T) {
|
||||
assert.EqualError(t, fd.Err, `field.Bytes("ip").DefaultFunc expects func but got slice`)
|
||||
}
|
||||
|
||||
type nullBytes []byte
|
||||
|
||||
func (b *nullBytes) Scan(v any) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case []byte:
|
||||
*b = v
|
||||
return nil
|
||||
case string:
|
||||
*b = []byte(v)
|
||||
return nil
|
||||
default:
|
||||
return errors.New("unexpected type")
|
||||
}
|
||||
}
|
||||
|
||||
func (b nullBytes) Value() (driver.Value, error) { return b, nil }
|
||||
|
||||
func TestBytes_ValueScanner(t *testing.T) {
|
||||
fd := field.Bytes("dir").
|
||||
ValueScanner(field.ValueScannerFunc[[]byte, *nullBytes]{
|
||||
V: func(s []byte) (driver.Value, error) {
|
||||
return []byte(hex.EncodeToString(s)), nil
|
||||
},
|
||||
S: func(ns *nullBytes) ([]byte, error) {
|
||||
if ns == nil {
|
||||
return nil, nil
|
||||
}
|
||||
b, err := hex.DecodeString(string(*ns))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
},
|
||||
}).Descriptor()
|
||||
require.NoError(t, fd.Err)
|
||||
require.NotNil(t, fd.ValueScanner)
|
||||
_, ok := fd.ValueScanner.(field.ValueScannerFunc[[]byte, *nullBytes])
|
||||
require.True(t, ok)
|
||||
|
||||
fd = field.Bytes("url").
|
||||
GoType(&url.URL{}).
|
||||
ValueScanner(field.BinaryValueScanner[*url.URL]{}).
|
||||
Descriptor()
|
||||
require.NoError(t, fd.Err)
|
||||
require.NotNil(t, fd.ValueScanner)
|
||||
_, ok = fd.ValueScanner.(field.TypeValueScanner[*url.URL])
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestString_DefaultFunc(t *testing.T) {
|
||||
f1 := func() http.Dir { return "/tmp" }
|
||||
fd := field.String("dir").GoType(http.Dir("/tmp")).DefaultFunc(f1).Descriptor()
|
||||
|
||||
Reference in New Issue
Block a user