entc/gen: allow selecting specific fields (#1075)

This commit is contained in:
Ariel Mashraki
2020-12-23 17:35:39 +02:00
committed by GitHub
parent 0902673b40
commit da34571560
204 changed files with 3453 additions and 2891 deletions

View File

@@ -390,8 +390,8 @@ type (
Fields FieldMut
Predicate func(*sql.Selector)
ScanValues []interface{}
Assign func(...interface{}) error
ScanValues func(columns []string) ([]interface{}, error)
Assign func(columns []string, values []interface{}) error
}
)
@@ -480,8 +480,8 @@ type QuerySpec struct {
Order func(*sql.Selector)
Predicate func(*sql.Selector)
ScanValues func() []interface{}
Assign func(...interface{}) error
ScanValues func(columns []string) ([]interface{}, error)
Assign func(columns []string, values []interface{}) error
}
// QueryNodes queries the nodes in the graph query and scans them to the given values.
@@ -556,12 +556,19 @@ func (q *query) nodes(ctx context.Context, drv dialect.Driver) error {
return err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return err
}
for rows.Next() {
values := q.ScanValues()
values, err := q.ScanValues(columns)
if err != nil {
return err
}
if err := rows.Scan(values...); err != nil {
return err
}
if err := q.Assign(values...); err != nil {
if err := q.Assign(columns, values); err != nil {
return err
}
}
@@ -786,16 +793,27 @@ func (u *updater) setTableColumns(update *sql.UpdateBuilder, addEdges, clearEdge
func (u *updater) scan(rows *sql.Rows) error {
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return err
}
if !rows.Next() {
if err := rows.Err(); err != nil {
return err
}
return &NotFoundError{table: u.Node.Table, id: u.Node.ID.Value}
}
if err := rows.Scan(u.ScanValues...); err != nil {
values, err := u.ScanValues(columns)
if err != nil {
return err
}
if err := rows.Scan(values...); err != nil {
return fmt.Errorf("failed scanning rows: %v", err)
}
return u.Assign(u.ScanValues...)
if err := u.Assign(columns, values); err != nil {
return err
}
return nil
}
type creator struct {

View File

@@ -7,6 +7,7 @@ package sqlgraph
import (
"context"
"database/sql/driver"
"fmt"
"regexp"
"strings"
"testing"
@@ -839,18 +840,40 @@ type user struct {
}
}
func (*user) values() []interface{} {
return []interface{}{&sql.NullInt64{}, &sql.NullInt64{}, &sql.NullString{}}
func (*user) values(columns []string) ([]interface{}, error) {
values := make([]interface{}, len(columns))
for i := range columns {
switch c := columns[i]; c {
case "id", "age", "fk1", "fk2":
values[i] = &sql.NullInt64{}
case "name":
values[i] = &sql.NullString{}
default:
return nil, fmt.Errorf("unexpected column %q", c)
}
}
return values, nil
}
func (u *user) assign(values ...interface{}) error {
u.id = int(values[0].(*sql.NullInt64).Int64)
u.age = int(values[1].(*sql.NullInt64).Int64)
u.name = values[2].(*sql.NullString).String
// loaded with foreign-keys.
if len(values) > 3 {
u.edges.fk1 = int(values[3].(*sql.NullInt64).Int64)
u.edges.fk2 = int(values[4].(*sql.NullInt64).Int64)
func (u *user) assign(columns []string, values []interface{}) error {
if len(columns) != len(values) {
return fmt.Errorf("mismatch number of values")
}
for i, c := range columns {
switch c {
case "id":
u.id = int(values[i].(*sql.NullInt64).Int64)
case "age":
u.age = int(values[i].(*sql.NullInt64).Int64)
case "name":
u.name = values[i].(*sql.NullString).String
case "fk1":
u.edges.fk1 = int(values[i].(*sql.NullInt64).Int64)
case "fk2":
u.edges.fk2 = int(values[i].(*sql.NullInt64).Int64)
default:
return fmt.Errorf("unknown column %q", c)
}
}
return nil
}
@@ -1070,7 +1093,7 @@ func TestUpdateNode(t *testing.T) {
tt.prepare(mock)
usr := &user{}
tt.spec.Assign = usr.assign
tt.spec.ScanValues = usr.values()
tt.spec.ScanValues = usr.values
err = UpdateNode(context.Background(), sql.OpenDB("", db), tt.spec)
require.Equal(t, tt.wantErr, err != nil, err)
require.Equal(t, tt.wantUser, usr)
@@ -1325,13 +1348,13 @@ func TestQueryNodes(t *testing.T) {
Predicate: func(s *sql.Selector) {
s.Where(sql.LT("age", 40))
},
ScanValues: func() []interface{} {
ScanValues: func(columns []string) ([]interface{}, error) {
u := &user{}
users = append(users, u)
return append(u.values(), &sql.NullInt64{}, &sql.NullInt64{}) // extra values for fks.
return u.values(columns)
},
Assign: func(values ...interface{}) error {
return users[len(users)-1].assign(values...)
Assign: func(columns []string, values []interface{}) error {
return users[len(users)-1].assign(columns, values)
},
}
)