entc/integration/json: add example for using interfaces in JSON fields (#2497)

This commit is contained in:
Ariel Mashraki
2022-04-25 13:34:05 +03:00
committed by GitHub
parent 04e0dc936b
commit 879bb8a905
11 changed files with 266 additions and 6 deletions

View File

@@ -22,6 +22,7 @@ var (
{Name: "ints", Type: field.TypeJSON, Nullable: true},
{Name: "floats", Type: field.TypeJSON, Nullable: true},
{Name: "strings", Type: field.TypeJSON, Nullable: true},
{Name: "addr", Type: field.TypeJSON, Nullable: true},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{

View File

@@ -47,6 +47,7 @@ type UserMutation struct {
ints *[]int
floats *[]float64
strings *[]string
addr *schema.Addr
clearedFields map[string]struct{}
done bool
oldValue func(context.Context) (*User, error)
@@ -481,6 +482,55 @@ func (m *UserMutation) ResetStrings() {
delete(m.clearedFields, user.FieldStrings)
}
// SetAddr sets the "addr" field.
func (m *UserMutation) SetAddr(s schema.Addr) {
m.addr = &s
}
// Addr returns the value of the "addr" field in the mutation.
func (m *UserMutation) Addr() (r schema.Addr, exists bool) {
v := m.addr
if v == nil {
return
}
return *v, true
}
// OldAddr returns the old "addr" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldAddr(ctx context.Context) (v schema.Addr, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldAddr is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldAddr requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldAddr: %w", err)
}
return oldValue.Addr, nil
}
// ClearAddr clears the value of the "addr" field.
func (m *UserMutation) ClearAddr() {
m.addr = nil
m.clearedFields[user.FieldAddr] = struct{}{}
}
// AddrCleared returns if the "addr" field was cleared in this mutation.
func (m *UserMutation) AddrCleared() bool {
_, ok := m.clearedFields[user.FieldAddr]
return ok
}
// ResetAddr resets all changes to the "addr" field.
func (m *UserMutation) ResetAddr() {
m.addr = nil
delete(m.clearedFields, user.FieldAddr)
}
// Where appends a list predicates to the UserMutation builder.
func (m *UserMutation) Where(ps ...predicate.User) {
m.predicates = append(m.predicates, ps...)
@@ -500,7 +550,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
fields := make([]string, 0, 7)
fields := make([]string, 0, 8)
if m.t != nil {
fields = append(fields, user.FieldT)
}
@@ -522,6 +572,9 @@ func (m *UserMutation) Fields() []string {
if m.strings != nil {
fields = append(fields, user.FieldStrings)
}
if m.addr != nil {
fields = append(fields, user.FieldAddr)
}
return fields
}
@@ -544,6 +597,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.Floats()
case user.FieldStrings:
return m.Strings()
case user.FieldAddr:
return m.Addr()
}
return nil, false
}
@@ -567,6 +622,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldFloats(ctx)
case user.FieldStrings:
return m.OldStrings(ctx)
case user.FieldAddr:
return m.OldAddr(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -625,6 +682,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetStrings(v)
return nil
case user.FieldAddr:
v, ok := value.(schema.Addr)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetAddr(v)
return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -673,6 +737,9 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldStrings) {
fields = append(fields, user.FieldStrings)
}
if m.FieldCleared(user.FieldAddr) {
fields = append(fields, user.FieldAddr)
}
return fields
}
@@ -705,6 +772,9 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldStrings:
m.ClearStrings()
return nil
case user.FieldAddr:
m.ClearAddr()
return nil
}
return fmt.Errorf("unknown User nullable field %s", name)
}
@@ -734,6 +804,9 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldStrings:
m.ResetStrings()
return nil
case user.FieldAddr:
m.ResetAddr()
return nil
}
return fmt.Errorf("unknown User field %s", name)
}

View File

@@ -6,6 +6,9 @@ package schema
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
@@ -38,6 +41,8 @@ func (User) Fields() []ent.Field {
Optional(),
field.Strings("strings").
Optional(),
field.JSON("addr", Addr{}).
Optional(),
}
}
@@ -50,3 +55,47 @@ type T struct {
Li []int `json:"li,omitempty"`
Ls []string `json:"ls,omitempty"`
}
type Addr struct{ net.Addr }
func (a *Addr) UnmarshalJSON(data []byte) error {
var types struct {
TCP *net.TCPAddr `json:"tcp,omitempty"`
UDP *net.UDPAddr `json:"udp,omitempty"`
}
if err := json.Unmarshal(data, &types); err != nil {
return err
}
switch {
case types.TCP != nil && types.UDP != nil:
return errors.New("TCP and UDP addresses are mutually exclusive")
case types.TCP != nil:
a.Addr = types.TCP
case types.UDP != nil:
a.Addr = types.UDP
}
return nil
}
func (a Addr) MarshalJSON() ([]byte, error) {
var types struct {
TCP *net.TCPAddr `json:"tcp,omitempty"`
UDP *net.UDPAddr `json:"udp,omitempty"`
}
switch a := a.Addr.(type) {
case *net.TCPAddr:
types.TCP = a
case *net.UDPAddr:
types.UDP = a
default:
return nil, fmt.Errorf("unsupported address type: %T", a)
}
return json.Marshal(types)
}
func (a Addr) String() string {
if a.Addr == nil {
return ""
}
return a.Addr.String()
}

View File

@@ -37,6 +37,8 @@ type User struct {
Floats []float64 `json:"floats,omitempty"`
// Strings holds the value of the "strings" field.
Strings []string `json:"strings,omitempty"`
// Addr holds the value of the "addr" field.
Addr schema.Addr `json:"addr,omitempty"`
}
// scanValues returns the types for scanning values from sql.Rows.
@@ -44,7 +46,7 @@ func (*User) scanValues(columns []string) ([]interface{}, error) {
values := make([]interface{}, len(columns))
for i := range columns {
switch columns[i] {
case user.FieldT, user.FieldURL, user.FieldRaw, user.FieldDirs, user.FieldInts, user.FieldFloats, user.FieldStrings:
case user.FieldT, user.FieldURL, user.FieldRaw, user.FieldDirs, user.FieldInts, user.FieldFloats, user.FieldStrings, user.FieldAddr:
values[i] = new([]byte)
case user.FieldID:
values[i] = new(sql.NullInt64)
@@ -125,6 +127,14 @@ func (u *User) assignValues(columns []string, values []interface{}) error {
return fmt.Errorf("unmarshal field strings: %w", err)
}
}
case user.FieldAddr:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field addr", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &u.Addr); err != nil {
return fmt.Errorf("unmarshal field addr: %w", err)
}
}
}
}
return nil
@@ -167,6 +177,8 @@ func (u *User) String() string {
builder.WriteString(fmt.Sprintf("%v", u.Floats))
builder.WriteString(", strings=")
builder.WriteString(fmt.Sprintf("%v", u.Strings))
builder.WriteString(", addr=")
builder.WriteString(fmt.Sprintf("%v", u.Addr))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -29,6 +29,8 @@ const (
FieldFloats = "floats"
// FieldStrings holds the string denoting the strings field in the database.
FieldStrings = "strings"
// FieldAddr holds the string denoting the addr field in the database.
FieldAddr = "addr"
// Table holds the table name of the user in the database.
Table = "users"
)
@@ -43,6 +45,7 @@ var Columns = []string{
FieldInts,
FieldFloats,
FieldStrings,
FieldAddr,
}
// ValidColumn reports if the column name is valid (part of the table columns).

View File

@@ -178,6 +178,20 @@ func StringsNotNil() predicate.User {
})
}
// AddrIsNil applies the IsNil predicate on the "addr" field.
func AddrIsNil() predicate.User {
return predicate.User(func(s *sql.Selector) {
s.Where(sql.IsNull(s.C(FieldAddr)))
})
}
// AddrNotNil applies the NotNil predicate on the "addr" field.
func AddrNotNil() predicate.User {
return predicate.User(func(s *sql.Selector) {
s.Where(sql.NotNull(s.C(FieldAddr)))
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.User) predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@@ -69,6 +69,20 @@ func (uc *UserCreate) SetStrings(s []string) *UserCreate {
return uc
}
// SetAddr sets the "addr" field.
func (uc *UserCreate) SetAddr(s schema.Addr) *UserCreate {
uc.mutation.SetAddr(s)
return uc
}
// SetNillableAddr sets the "addr" field if the given value is not nil.
func (uc *UserCreate) SetNillableAddr(s *schema.Addr) *UserCreate {
if s != nil {
uc.SetAddr(*s)
}
return uc
}
// Mutation returns the UserMutation object of the builder.
func (uc *UserCreate) Mutation() *UserMutation {
return uc.mutation
@@ -238,6 +252,14 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
})
_node.Strings = value
}
if value, ok := uc.mutation.Addr(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeJSON,
Value: value,
Column: user.FieldAddr,
})
_node.Addr = value
}
return _node, _spec
}

View File

@@ -113,6 +113,26 @@ func (uu *UserUpdate) ClearStrings() *UserUpdate {
return uu
}
// SetAddr sets the "addr" field.
func (uu *UserUpdate) SetAddr(s schema.Addr) *UserUpdate {
uu.mutation.SetAddr(s)
return uu
}
// SetNillableAddr sets the "addr" field if the given value is not nil.
func (uu *UserUpdate) SetNillableAddr(s *schema.Addr) *UserUpdate {
if s != nil {
uu.SetAddr(*s)
}
return uu
}
// ClearAddr clears the value of the "addr" field.
func (uu *UserUpdate) ClearAddr() *UserUpdate {
uu.mutation.ClearAddr()
return uu
}
// Mutation returns the UserMutation object of the builder.
func (uu *UserUpdate) Mutation() *UserMutation {
return uu.mutation
@@ -275,6 +295,19 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
Column: user.FieldStrings,
})
}
if value, ok := uu.mutation.Addr(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeJSON,
Value: value,
Column: user.FieldAddr,
})
}
if uu.mutation.AddrCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeJSON,
Column: user.FieldAddr,
})
}
if n, err = sqlgraph.UpdateNodes(ctx, uu.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@@ -372,6 +405,26 @@ func (uuo *UserUpdateOne) ClearStrings() *UserUpdateOne {
return uuo
}
// SetAddr sets the "addr" field.
func (uuo *UserUpdateOne) SetAddr(s schema.Addr) *UserUpdateOne {
uuo.mutation.SetAddr(s)
return uuo
}
// SetNillableAddr sets the "addr" field if the given value is not nil.
func (uuo *UserUpdateOne) SetNillableAddr(s *schema.Addr) *UserUpdateOne {
if s != nil {
uuo.SetAddr(*s)
}
return uuo
}
// ClearAddr clears the value of the "addr" field.
func (uuo *UserUpdateOne) ClearAddr() *UserUpdateOne {
uuo.mutation.ClearAddr()
return uuo
}
// Mutation returns the UserMutation object of the builder.
func (uuo *UserUpdateOne) Mutation() *UserMutation {
return uuo.mutation
@@ -558,6 +611,19 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error)
Column: user.FieldStrings,
})
}
if value, ok := uuo.mutation.Addr(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeJSON,
Value: value,
Column: user.FieldAddr,
})
}
if uuo.mutation.AddrCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeJSON,
Column: user.FieldAddr,
})
}
_node = &User{config: uuo.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

View File

@@ -8,6 +8,7 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"testing"
@@ -46,6 +47,7 @@ func TestMySQL(t *testing.T) {
Ints(t, client)
Floats(t, client)
Strings(t, client)
NetAddr(t, client)
RawMessage(t, client)
// Skip predicates test for MySQL old versions.
if version != "56" {
@@ -79,6 +81,7 @@ func TestMaria(t *testing.T) {
Ints(t, client)
Floats(t, client)
Strings(t, client)
NetAddr(t, client)
RawMessage(t, client)
Predicates(t, client)
})
@@ -108,6 +111,7 @@ func TestPostgres(t *testing.T) {
Ints(t, client)
Floats(t, client)
Strings(t, client)
NetAddr(t, client)
RawMessage(t, client)
Predicates(t, client)
})
@@ -126,6 +130,7 @@ func TestSQLite(t *testing.T) {
Ints(t, client)
Floats(t, client)
Strings(t, client)
NetAddr(t, client)
RawMessage(t, client)
Predicates(t, client)
}
@@ -187,6 +192,15 @@ func RawMessage(t *testing.T, client *ent.Client) {
require.Equal(t, raw, client.User.GetX(ctx, usr.ID).Raw)
}
func NetAddr(t *testing.T, client *ent.Client) {
ctx := context.Background()
ip := net.ParseIP("127.0.0.1")
usr := client.User.Create().SetAddr(schema.Addr{Addr: &net.TCPAddr{IP: ip, Port: 80}}).SaveX(ctx)
require.Equal(t, "127.0.0.1:80", client.User.GetX(ctx, usr.ID).Addr.String())
usr.Update().SetAddr(schema.Addr{Addr: &net.UDPAddr{IP: ip, Port: 1812}}).ExecX(ctx)
require.Equal(t, "127.0.0.1:1812", client.User.GetX(ctx, usr.ID).Addr.String())
}
func Dirs(t *testing.T, client *ent.Client) {
ctx := context.Background()
dirs := []http.Dir{"dev", "usr"}