mirror of
https://github.com/ent/ent.git
synced 2026-04-29 06:00:55 +03:00
entc/gen: allow selecting specific fields (#1075)
This commit is contained in:
@@ -7,6 +7,7 @@ package sqlgraph
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -839,18 +840,40 @@ type user struct {
|
||||
}
|
||||
}
|
||||
|
||||
func (*user) values() []interface{} {
|
||||
return []interface{}{&sql.NullInt64{}, &sql.NullInt64{}, &sql.NullString{}}
|
||||
func (*user) values(columns []string) ([]interface{}, error) {
|
||||
values := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
switch c := columns[i]; c {
|
||||
case "id", "age", "fk1", "fk2":
|
||||
values[i] = &sql.NullInt64{}
|
||||
case "name":
|
||||
values[i] = &sql.NullString{}
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected column %q", c)
|
||||
}
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func (u *user) assign(values ...interface{}) error {
|
||||
u.id = int(values[0].(*sql.NullInt64).Int64)
|
||||
u.age = int(values[1].(*sql.NullInt64).Int64)
|
||||
u.name = values[2].(*sql.NullString).String
|
||||
// loaded with foreign-keys.
|
||||
if len(values) > 3 {
|
||||
u.edges.fk1 = int(values[3].(*sql.NullInt64).Int64)
|
||||
u.edges.fk2 = int(values[4].(*sql.NullInt64).Int64)
|
||||
func (u *user) assign(columns []string, values []interface{}) error {
|
||||
if len(columns) != len(values) {
|
||||
return fmt.Errorf("mismatch number of values")
|
||||
}
|
||||
for i, c := range columns {
|
||||
switch c {
|
||||
case "id":
|
||||
u.id = int(values[i].(*sql.NullInt64).Int64)
|
||||
case "age":
|
||||
u.age = int(values[i].(*sql.NullInt64).Int64)
|
||||
case "name":
|
||||
u.name = values[i].(*sql.NullString).String
|
||||
case "fk1":
|
||||
u.edges.fk1 = int(values[i].(*sql.NullInt64).Int64)
|
||||
case "fk2":
|
||||
u.edges.fk2 = int(values[i].(*sql.NullInt64).Int64)
|
||||
default:
|
||||
return fmt.Errorf("unknown column %q", c)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1070,7 +1093,7 @@ func TestUpdateNode(t *testing.T) {
|
||||
tt.prepare(mock)
|
||||
usr := &user{}
|
||||
tt.spec.Assign = usr.assign
|
||||
tt.spec.ScanValues = usr.values()
|
||||
tt.spec.ScanValues = usr.values
|
||||
err = UpdateNode(context.Background(), sql.OpenDB("", db), tt.spec)
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
require.Equal(t, tt.wantUser, usr)
|
||||
@@ -1325,13 +1348,13 @@ func TestQueryNodes(t *testing.T) {
|
||||
Predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.LT("age", 40))
|
||||
},
|
||||
ScanValues: func() []interface{} {
|
||||
ScanValues: func(columns []string) ([]interface{}, error) {
|
||||
u := &user{}
|
||||
users = append(users, u)
|
||||
return append(u.values(), &sql.NullInt64{}, &sql.NullInt64{}) // extra values for fks.
|
||||
return u.values(columns)
|
||||
},
|
||||
Assign: func(values ...interface{}) error {
|
||||
return users[len(users)-1].assign(values...)
|
||||
Assign: func(columns []string, values []interface{}) error {
|
||||
return users[len(users)-1].assign(columns, values)
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user