mirror of
https://github.com/ent/ent.git
synced 2026-04-28 13:40:56 +03:00
entc/gen: allow selecting specific fields (#1075)
This commit is contained in:
@@ -390,8 +390,8 @@ type (
|
||||
Fields FieldMut
|
||||
Predicate func(*sql.Selector)
|
||||
|
||||
ScanValues []interface{}
|
||||
Assign func(...interface{}) error
|
||||
ScanValues func(columns []string) ([]interface{}, error)
|
||||
Assign func(columns []string, values []interface{}) error
|
||||
}
|
||||
)
|
||||
|
||||
@@ -480,8 +480,8 @@ type QuerySpec struct {
|
||||
Order func(*sql.Selector)
|
||||
Predicate func(*sql.Selector)
|
||||
|
||||
ScanValues func() []interface{}
|
||||
Assign func(...interface{}) error
|
||||
ScanValues func(columns []string) ([]interface{}, error)
|
||||
Assign func(columns []string, values []interface{}) error
|
||||
}
|
||||
|
||||
// QueryNodes queries the nodes in the graph query and scans them to the given values.
|
||||
@@ -556,12 +556,19 @@ func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for rows.Next() {
|
||||
values := q.ScanValues()
|
||||
values, err := q.ScanValues(columns)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.Assign(values...); err != nil {
|
||||
if err := q.Assign(columns, values); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -786,16 +793,27 @@ func (u *updater) setTableColumns(update *sql.UpdateBuilder, addEdges, clearEdge
|
||||
|
||||
func (u *updater) scan(rows *sql.Rows) error {
|
||||
defer rows.Close()
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return &NotFoundError{table: u.Node.Table, id: u.Node.ID.Value}
|
||||
}
|
||||
if err := rows.Scan(u.ScanValues...); err != nil {
|
||||
values, err := u.ScanValues(columns)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
return fmt.Errorf("failed scanning rows: %v", err)
|
||||
}
|
||||
return u.Assign(u.ScanValues...)
|
||||
if err := u.Assign(columns, values); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type creator struct {
|
||||
|
||||
@@ -7,6 +7,7 @@ package sqlgraph
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -839,18 +840,40 @@ type user struct {
|
||||
}
|
||||
}
|
||||
|
||||
func (*user) values() []interface{} {
|
||||
return []interface{}{&sql.NullInt64{}, &sql.NullInt64{}, &sql.NullString{}}
|
||||
func (*user) values(columns []string) ([]interface{}, error) {
|
||||
values := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
switch c := columns[i]; c {
|
||||
case "id", "age", "fk1", "fk2":
|
||||
values[i] = &sql.NullInt64{}
|
||||
case "name":
|
||||
values[i] = &sql.NullString{}
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected column %q", c)
|
||||
}
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func (u *user) assign(values ...interface{}) error {
|
||||
u.id = int(values[0].(*sql.NullInt64).Int64)
|
||||
u.age = int(values[1].(*sql.NullInt64).Int64)
|
||||
u.name = values[2].(*sql.NullString).String
|
||||
// loaded with foreign-keys.
|
||||
if len(values) > 3 {
|
||||
u.edges.fk1 = int(values[3].(*sql.NullInt64).Int64)
|
||||
u.edges.fk2 = int(values[4].(*sql.NullInt64).Int64)
|
||||
func (u *user) assign(columns []string, values []interface{}) error {
|
||||
if len(columns) != len(values) {
|
||||
return fmt.Errorf("mismatch number of values")
|
||||
}
|
||||
for i, c := range columns {
|
||||
switch c {
|
||||
case "id":
|
||||
u.id = int(values[i].(*sql.NullInt64).Int64)
|
||||
case "age":
|
||||
u.age = int(values[i].(*sql.NullInt64).Int64)
|
||||
case "name":
|
||||
u.name = values[i].(*sql.NullString).String
|
||||
case "fk1":
|
||||
u.edges.fk1 = int(values[i].(*sql.NullInt64).Int64)
|
||||
case "fk2":
|
||||
u.edges.fk2 = int(values[i].(*sql.NullInt64).Int64)
|
||||
default:
|
||||
return fmt.Errorf("unknown column %q", c)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1070,7 +1093,7 @@ func TestUpdateNode(t *testing.T) {
|
||||
tt.prepare(mock)
|
||||
usr := &user{}
|
||||
tt.spec.Assign = usr.assign
|
||||
tt.spec.ScanValues = usr.values()
|
||||
tt.spec.ScanValues = usr.values
|
||||
err = UpdateNode(context.Background(), sql.OpenDB("", db), tt.spec)
|
||||
require.Equal(t, tt.wantErr, err != nil, err)
|
||||
require.Equal(t, tt.wantUser, usr)
|
||||
@@ -1325,13 +1348,13 @@ func TestQueryNodes(t *testing.T) {
|
||||
Predicate: func(s *sql.Selector) {
|
||||
s.Where(sql.LT("age", 40))
|
||||
},
|
||||
ScanValues: func() []interface{} {
|
||||
ScanValues: func(columns []string) ([]interface{}, error) {
|
||||
u := &user{}
|
||||
users = append(users, u)
|
||||
return append(u.values(), &sql.NullInt64{}, &sql.NullInt64{}) // extra values for fks.
|
||||
return u.values(columns)
|
||||
},
|
||||
Assign: func(values ...interface{}) error {
|
||||
return users[len(users)-1].assign(values...)
|
||||
Assign: func(columns []string, values []interface{}) error {
|
||||
return users[len(users)-1].assign(columns, values)
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user