diff --git a/examples/edgeindex/ent/city.go b/examples/edgeindex/ent/city.go index 0f61a57c7..08802fdf7 100644 --- a/examples/edgeindex/ent/city.go +++ b/examples/edgeindex/ent/city.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/edgeindex/ent/city" ) // City is the model entity for the City schema. @@ -22,21 +23,31 @@ type City struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into City. -func (c *City) FromRows(rows *sql.Rows) error { - var scanc struct { - ID int - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*City) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `city.Columns`. - if err := rows.Scan( - &scanc.ID, - &scanc.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the City fields. +func (c *City) assignValues(values ...interface{}) error { + if m, n := len(values), len(city.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + c.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[0]) + } else if value.Valid { + c.Name = value.String } - c.ID = scanc.ID - c.Name = scanc.Name.String return nil } @@ -77,18 +88,6 @@ func (c *City) String() string { // Cities is a parsable slice of City. type Cities []*City -// FromRows scans the sql response data into Cities. -func (c *Cities) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanc := &City{} - if err := scanc.FromRows(rows); err != nil { - return err - } - *c = append(*c, scanc) - } - return nil -} - func (c Cities) config(cfg config) { for _i := range c { c[_i].config = cfg diff --git a/examples/edgeindex/ent/city_query.go b/examples/edgeindex/ent/city_query.go index bdce36491..d37cd7973 100644 --- a/examples/edgeindex/ent/city_query.go +++ b/examples/edgeindex/ent/city_query.go @@ -17,6 +17,7 @@ import ( "github.com/facebookincubator/ent/examples/edgeindex/ent/city" "github.com/facebookincubator/ent/examples/edgeindex/ent/predicate" "github.com/facebookincubator/ent/examples/edgeindex/ent/street" + "github.com/facebookincubator/ent/schema/field" ) // CityQuery is the builder for querying City entities. @@ -278,45 +279,31 @@ func (cq *CityQuery) Select(field string, fields ...string) *CitySelect { } func (cq *CityQuery) sqlAll(ctx context.Context) ([]*City, error) { - rows := &sql.Rows{} - selector := cq.sqlQuery() - if unique := cq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*City + spec = cq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &City{config: cq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := cq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, cq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var cs Cities - if err := cs.FromRows(rows); err != nil { - return nil, err - } - cs.config(cq.config) - return cs, nil + return nodes, nil } func (cq *CityQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := cq.sqlQuery() - unique := []string{city.FieldID} - if len(cq.unique) > 0 { - unique = cq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := cq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := cq.querySpec() + return sqlgraph.CountNodes(ctx, cq.driver, spec) } func (cq *CityQuery) sqlExist(ctx context.Context) (bool, error) { @@ -327,6 +314,42 @@ func (cq *CityQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (cq *CityQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: city.Table, + Columns: city.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: city.FieldID, + }, + }, + From: cq.sql, + Unique: true, + } + if ps := cq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := cq.limit; limit != nil { + spec.Limit = *limit + } + if offset := cq.offset; offset != nil { + spec.Offset = *offset + } + if ps := cq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (cq *CityQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(city.Table) @@ -598,7 +621,7 @@ func (cs *CitySelect) sqlScan(ctx context.Context, v interface{}) error { } func (cs *CitySelect) sqlQuery() sql.Querier { - view := "city_view" - return sql.Dialect(cs.driver.Dialect()). - Select(cs.fields...).From(cs.sql.As(view)) + selector := cs.sql + selector.Select(selector.Columns(cs.fields...)...) + return selector } diff --git a/examples/edgeindex/ent/city_update.go b/examples/edgeindex/ent/city_update.go index 6b4c7e629..f5380ee1a 100644 --- a/examples/edgeindex/ent/city_update.go +++ b/examples/edgeindex/ent/city_update.go @@ -8,7 +8,6 @@ package ent import ( "context" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -318,27 +317,8 @@ func (cuo *CityUpdateOne) sqlSave(ctx context.Context) (c *City, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } c = &City{config: cuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - c.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[0]) - } else if value.Valid { - c.Name = value.String - } - return nil - } + spec.Assign = c.assignValues + spec.ScanValues = c.scanValues() if err = sqlgraph.UpdateNode(ctx, cuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/edgeindex/ent/street.go b/examples/edgeindex/ent/street.go index 906d3f134..2221807be 100644 --- a/examples/edgeindex/ent/street.go +++ b/examples/edgeindex/ent/street.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/edgeindex/ent/street" ) // Street is the model entity for the Street schema. @@ -22,21 +23,31 @@ type Street struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into Street. -func (s *Street) FromRows(rows *sql.Rows) error { - var scans struct { - ID int - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*Street) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `street.Columns`. - if err := rows.Scan( - &scans.ID, - &scans.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Street fields. +func (s *Street) assignValues(values ...interface{}) error { + if m, n := len(values), len(street.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + s.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[0]) + } else if value.Valid { + s.Name = value.String } - s.ID = scans.ID - s.Name = scans.Name.String return nil } @@ -77,18 +88,6 @@ func (s *Street) String() string { // Streets is a parsable slice of Street. type Streets []*Street -// FromRows scans the sql response data into Streets. -func (s *Streets) FromRows(rows *sql.Rows) error { - for rows.Next() { - scans := &Street{} - if err := scans.FromRows(rows); err != nil { - return err - } - *s = append(*s, scans) - } - return nil -} - func (s Streets) config(cfg config) { for _i := range s { s[_i].config = cfg diff --git a/examples/edgeindex/ent/street_query.go b/examples/edgeindex/ent/street_query.go index a5842438e..b9ce562f7 100644 --- a/examples/edgeindex/ent/street_query.go +++ b/examples/edgeindex/ent/street_query.go @@ -17,6 +17,7 @@ import ( "github.com/facebookincubator/ent/examples/edgeindex/ent/city" "github.com/facebookincubator/ent/examples/edgeindex/ent/predicate" "github.com/facebookincubator/ent/examples/edgeindex/ent/street" + "github.com/facebookincubator/ent/schema/field" ) // StreetQuery is the builder for querying Street entities. @@ -278,45 +279,31 @@ func (sq *StreetQuery) Select(field string, fields ...string) *StreetSelect { } func (sq *StreetQuery) sqlAll(ctx context.Context) ([]*Street, error) { - rows := &sql.Rows{} - selector := sq.sqlQuery() - if unique := sq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*Street + spec = sq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &Street{config: sq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := sq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, sq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var sSlice Streets - if err := sSlice.FromRows(rows); err != nil { - return nil, err - } - sSlice.config(sq.config) - return sSlice, nil + return nodes, nil } func (sq *StreetQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := sq.sqlQuery() - unique := []string{street.FieldID} - if len(sq.unique) > 0 { - unique = sq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := sq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := sq.querySpec() + return sqlgraph.CountNodes(ctx, sq.driver, spec) } func (sq *StreetQuery) sqlExist(ctx context.Context) (bool, error) { @@ -327,6 +314,42 @@ func (sq *StreetQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (sq *StreetQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: street.Table, + Columns: street.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: street.FieldID, + }, + }, + From: sq.sql, + Unique: true, + } + if ps := sq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := sq.limit; limit != nil { + spec.Limit = *limit + } + if offset := sq.offset; offset != nil { + spec.Offset = *offset + } + if ps := sq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (sq *StreetQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(sq.driver.Dialect()) t1 := builder.Table(street.Table) @@ -598,7 +621,7 @@ func (ss *StreetSelect) sqlScan(ctx context.Context, v interface{}) error { } func (ss *StreetSelect) sqlQuery() sql.Querier { - view := "street_view" - return sql.Dialect(ss.driver.Dialect()). - Select(ss.fields...).From(ss.sql.As(view)) + selector := ss.sql + selector.Select(selector.Columns(ss.fields...)...) + return selector } diff --git a/examples/edgeindex/ent/street_update.go b/examples/edgeindex/ent/street_update.go index 4d255bd9a..6e8da07ed 100644 --- a/examples/edgeindex/ent/street_update.go +++ b/examples/edgeindex/ent/street_update.go @@ -9,7 +9,6 @@ package ent import ( "context" "errors" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -295,27 +294,8 @@ func (suo *StreetUpdateOne) sqlSave(ctx context.Context) (s *Street, err error) spec.Edges.Add = append(spec.Edges.Add, edge) } s = &Street{config: suo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - s.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[0]) - } else if value.Valid { - s.Name = value.String - } - return nil - } + spec.Assign = s.assignValues + spec.ScanValues = s.scanValues() if err = sqlgraph.UpdateNode(ctx, suo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/entcpkg/ent/user.go b/examples/entcpkg/ent/user.go index d6e04e0f2..e5d87ba83 100644 --- a/examples/entcpkg/ent/user.go +++ b/examples/entcpkg/ent/user.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/entcpkg/ent/user" ) // User is the model entity for the User schema. @@ -20,18 +21,25 @@ type User struct { ID int `json:"id,omitempty"` } -// FromRows scans the sql response data into User. -func (u *User) FromRows(rows *sql.Rows) error { - var scanu struct { - ID int +// scanValues returns the types for scanning values from sql.Rows. +func (*User) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, } - // the order here should be the same as in the `user.Columns`. - if err := rows.Scan( - &scanu.ID, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the User fields. +func (u *User) assignValues(values ...interface{}) error { + if m, n := len(values), len(user.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } - u.ID = scanu.ID + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + u.ID = int(value.Int64) + values = values[1:] return nil } @@ -65,18 +73,6 @@ func (u *User) String() string { // Users is a parsable slice of User. type Users []*User -// FromRows scans the sql response data into Users. -func (u *Users) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanu := &User{} - if err := scanu.FromRows(rows); err != nil { - return err - } - *u = append(*u, scanu) - } - return nil -} - func (u Users) config(cfg config) { for _i := range u { u[_i].config = cfg diff --git a/examples/entcpkg/ent/user_query.go b/examples/entcpkg/ent/user_query.go index f31dd6d2c..2eee933ef 100644 --- a/examples/entcpkg/ent/user_query.go +++ b/examples/entcpkg/ent/user_query.go @@ -13,8 +13,10 @@ import ( "math" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/dialect/sql/sqlgraph" "github.com/facebookincubator/ent/examples/entcpkg/ent/predicate" "github.com/facebookincubator/ent/examples/entcpkg/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // UserQuery is the builder for querying User entities. @@ -240,45 +242,31 @@ func (uq *UserQuery) Select(field string, fields ...string) *UserSelect { } func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - if unique := uq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*User + spec = uq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &User{config: uq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, uq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var us Users - if err := us.FromRows(rows); err != nil { - return nil, err - } - us.config(uq.config) - return us, nil + return nodes, nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - unique := []string{user.FieldID} - if len(uq.unique) > 0 { - unique = uq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := uq.querySpec() + return sqlgraph.CountNodes(ctx, uq.driver, spec) } func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { @@ -289,6 +277,42 @@ func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: user.Table, + Columns: user.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + From: uq.sql, + Unique: true, + } + if ps := uq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := uq.limit; limit != nil { + spec.Limit = *limit + } + if offset := uq.offset; offset != nil { + spec.Offset = *offset + } + if ps := uq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (uq *UserQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) @@ -560,7 +584,7 @@ func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error { } func (us *UserSelect) sqlQuery() sql.Querier { - view := "user_view" - return sql.Dialect(us.driver.Dialect()). - Select(us.fields...).From(us.sql.As(view)) + selector := us.sql + selector.Select(selector.Columns(us.fields...)...) + return selector } diff --git a/examples/entcpkg/ent/user_update.go b/examples/entcpkg/ent/user_update.go index 1b704c07f..56dc963e6 100644 --- a/examples/entcpkg/ent/user_update.go +++ b/examples/entcpkg/ent/user_update.go @@ -8,7 +8,6 @@ package ent import ( "context" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -129,21 +128,8 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) { }, } u = &User{config: uuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - u.ID = int(value.Int64) - values = values[1:] - return nil - } + spec.Assign = u.assignValues + spec.ScanValues = u.scanValues() if err = sqlgraph.UpdateNode(ctx, uuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/m2m2types/ent/group.go b/examples/m2m2types/ent/group.go index 27245a60e..00a01931d 100644 --- a/examples/m2m2types/ent/group.go +++ b/examples/m2m2types/ent/group.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/m2m2types/ent/group" ) // Group is the model entity for the Group schema. @@ -22,21 +23,31 @@ type Group struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into Group. -func (gr *Group) FromRows(rows *sql.Rows) error { - var scangr struct { - ID int - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*Group) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `group.Columns`. - if err := rows.Scan( - &scangr.ID, - &scangr.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Group fields. +func (gr *Group) assignValues(values ...interface{}) error { + if m, n := len(values), len(group.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + gr.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[0]) + } else if value.Valid { + gr.Name = value.String } - gr.ID = scangr.ID - gr.Name = scangr.Name.String return nil } @@ -77,18 +88,6 @@ func (gr *Group) String() string { // Groups is a parsable slice of Group. type Groups []*Group -// FromRows scans the sql response data into Groups. -func (gr *Groups) FromRows(rows *sql.Rows) error { - for rows.Next() { - scangr := &Group{} - if err := scangr.FromRows(rows); err != nil { - return err - } - *gr = append(*gr, scangr) - } - return nil -} - func (gr Groups) config(cfg config) { for _i := range gr { gr[_i].config = cfg diff --git a/examples/m2m2types/ent/group_query.go b/examples/m2m2types/ent/group_query.go index 26b95aee4..9f1550e55 100644 --- a/examples/m2m2types/ent/group_query.go +++ b/examples/m2m2types/ent/group_query.go @@ -17,6 +17,7 @@ import ( "github.com/facebookincubator/ent/examples/m2m2types/ent/group" "github.com/facebookincubator/ent/examples/m2m2types/ent/predicate" "github.com/facebookincubator/ent/examples/m2m2types/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // GroupQuery is the builder for querying Group entities. @@ -278,45 +279,31 @@ func (gq *GroupQuery) Select(field string, fields ...string) *GroupSelect { } func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { - rows := &sql.Rows{} - selector := gq.sqlQuery() - if unique := gq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*Group + spec = gq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &Group{config: gq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := gq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, gq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var grs Groups - if err := grs.FromRows(rows); err != nil { - return nil, err - } - grs.config(gq.config) - return grs, nil + return nodes, nil } func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := gq.sqlQuery() - unique := []string{group.FieldID} - if len(gq.unique) > 0 { - unique = gq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := gq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := gq.querySpec() + return sqlgraph.CountNodes(ctx, gq.driver, spec) } func (gq *GroupQuery) sqlExist(ctx context.Context) (bool, error) { @@ -327,6 +314,42 @@ func (gq *GroupQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: group.Table, + Columns: group.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: group.FieldID, + }, + }, + From: gq.sql, + Unique: true, + } + if ps := gq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := gq.limit; limit != nil { + spec.Limit = *limit + } + if offset := gq.offset; offset != nil { + spec.Offset = *offset + } + if ps := gq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (gq *GroupQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(gq.driver.Dialect()) t1 := builder.Table(group.Table) @@ -598,7 +621,7 @@ func (gs *GroupSelect) sqlScan(ctx context.Context, v interface{}) error { } func (gs *GroupSelect) sqlQuery() sql.Querier { - view := "group_view" - return sql.Dialect(gs.driver.Dialect()). - Select(gs.fields...).From(gs.sql.As(view)) + selector := gs.sql + selector.Select(selector.Columns(gs.fields...)...) + return selector } diff --git a/examples/m2m2types/ent/group_update.go b/examples/m2m2types/ent/group_update.go index a07dc7c9c..0d9539a20 100644 --- a/examples/m2m2types/ent/group_update.go +++ b/examples/m2m2types/ent/group_update.go @@ -8,7 +8,6 @@ package ent import ( "context" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -318,27 +317,8 @@ func (guo *GroupUpdateOne) sqlSave(ctx context.Context) (gr *Group, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } gr = &Group{config: guo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - gr.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[0]) - } else if value.Valid { - gr.Name = value.String - } - return nil - } + spec.Assign = gr.assignValues + spec.ScanValues = gr.scanValues() if err = sqlgraph.UpdateNode(ctx, guo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/m2m2types/ent/user.go b/examples/m2m2types/ent/user.go index 41cee9d25..df3b95b53 100644 --- a/examples/m2m2types/ent/user.go +++ b/examples/m2m2types/ent/user.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/m2m2types/ent/user" ) // User is the model entity for the User schema. @@ -24,24 +25,37 @@ type User struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into User. -func (u *User) FromRows(rows *sql.Rows) error { - var scanu struct { - ID int - Age sql.NullInt64 - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*User) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `user.Columns`. - if err := rows.Scan( - &scanu.ID, - &scanu.Age, - &scanu.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the User fields. +func (u *User) assignValues(values ...interface{}) error { + if m, n := len(values), len(user.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + u.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field age", values[0]) + } else if value.Valid { + u.Age = int(value.Int64) + } + if value, ok := values[1].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[1]) + } else if value.Valid { + u.Name = value.String } - u.ID = scanu.ID - u.Age = int(scanu.Age.Int64) - u.Name = scanu.Name.String return nil } @@ -84,18 +98,6 @@ func (u *User) String() string { // Users is a parsable slice of User. type Users []*User -// FromRows scans the sql response data into Users. -func (u *Users) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanu := &User{} - if err := scanu.FromRows(rows); err != nil { - return err - } - *u = append(*u, scanu) - } - return nil -} - func (u Users) config(cfg config) { for _i := range u { u[_i].config = cfg diff --git a/examples/m2m2types/ent/user_query.go b/examples/m2m2types/ent/user_query.go index 8dd4e5513..20d4aba49 100644 --- a/examples/m2m2types/ent/user_query.go +++ b/examples/m2m2types/ent/user_query.go @@ -17,6 +17,7 @@ import ( "github.com/facebookincubator/ent/examples/m2m2types/ent/group" "github.com/facebookincubator/ent/examples/m2m2types/ent/predicate" "github.com/facebookincubator/ent/examples/m2m2types/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // UserQuery is the builder for querying User entities. @@ -278,45 +279,31 @@ func (uq *UserQuery) Select(field string, fields ...string) *UserSelect { } func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - if unique := uq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*User + spec = uq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &User{config: uq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, uq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var us Users - if err := us.FromRows(rows); err != nil { - return nil, err - } - us.config(uq.config) - return us, nil + return nodes, nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - unique := []string{user.FieldID} - if len(uq.unique) > 0 { - unique = uq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := uq.querySpec() + return sqlgraph.CountNodes(ctx, uq.driver, spec) } func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { @@ -327,6 +314,42 @@ func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: user.Table, + Columns: user.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + From: uq.sql, + Unique: true, + } + if ps := uq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := uq.limit; limit != nil { + spec.Limit = *limit + } + if offset := uq.offset; offset != nil { + spec.Offset = *offset + } + if ps := uq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (uq *UserQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) @@ -598,7 +621,7 @@ func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error { } func (us *UserSelect) sqlQuery() sql.Querier { - view := "user_view" - return sql.Dialect(us.driver.Dialect()). - Select(us.fields...).From(us.sql.As(view)) + selector := us.sql + selector.Select(selector.Columns(us.fields...)...) + return selector } diff --git a/examples/m2m2types/ent/user_update.go b/examples/m2m2types/ent/user_update.go index 30990ea62..4b05092e2 100644 --- a/examples/m2m2types/ent/user_update.go +++ b/examples/m2m2types/ent/user_update.go @@ -8,7 +8,6 @@ package ent import ( "context" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -384,33 +383,8 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } u = &User{config: uuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - u.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field age", values[0]) - } else if value.Valid { - u.Age = int(value.Int64) - } - if value, ok := values[1].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[1]) - } else if value.Valid { - u.Name = value.String - } - return nil - } + spec.Assign = u.assignValues + spec.ScanValues = u.scanValues() if err = sqlgraph.UpdateNode(ctx, uuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/m2mbidi/ent/user.go b/examples/m2mbidi/ent/user.go index 3889299b4..d7bd7f051 100644 --- a/examples/m2mbidi/ent/user.go +++ b/examples/m2mbidi/ent/user.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/m2mbidi/ent/user" ) // User is the model entity for the User schema. @@ -24,24 +25,37 @@ type User struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into User. -func (u *User) FromRows(rows *sql.Rows) error { - var scanu struct { - ID int - Age sql.NullInt64 - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*User) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `user.Columns`. - if err := rows.Scan( - &scanu.ID, - &scanu.Age, - &scanu.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the User fields. +func (u *User) assignValues(values ...interface{}) error { + if m, n := len(values), len(user.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + u.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field age", values[0]) + } else if value.Valid { + u.Age = int(value.Int64) + } + if value, ok := values[1].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[1]) + } else if value.Valid { + u.Name = value.String } - u.ID = scanu.ID - u.Age = int(scanu.Age.Int64) - u.Name = scanu.Name.String return nil } @@ -84,18 +98,6 @@ func (u *User) String() string { // Users is a parsable slice of User. type Users []*User -// FromRows scans the sql response data into Users. -func (u *Users) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanu := &User{} - if err := scanu.FromRows(rows); err != nil { - return err - } - *u = append(*u, scanu) - } - return nil -} - func (u Users) config(cfg config) { for _i := range u { u[_i].config = cfg diff --git a/examples/m2mbidi/ent/user_query.go b/examples/m2mbidi/ent/user_query.go index d564d2b66..4351a0df6 100644 --- a/examples/m2mbidi/ent/user_query.go +++ b/examples/m2mbidi/ent/user_query.go @@ -16,6 +16,7 @@ import ( "github.com/facebookincubator/ent/dialect/sql/sqlgraph" "github.com/facebookincubator/ent/examples/m2mbidi/ent/predicate" "github.com/facebookincubator/ent/examples/m2mbidi/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // UserQuery is the builder for querying User entities. @@ -277,45 +278,31 @@ func (uq *UserQuery) Select(field string, fields ...string) *UserSelect { } func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - if unique := uq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*User + spec = uq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &User{config: uq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, uq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var us Users - if err := us.FromRows(rows); err != nil { - return nil, err - } - us.config(uq.config) - return us, nil + return nodes, nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - unique := []string{user.FieldID} - if len(uq.unique) > 0 { - unique = uq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := uq.querySpec() + return sqlgraph.CountNodes(ctx, uq.driver, spec) } func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { @@ -326,6 +313,42 @@ func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: user.Table, + Columns: user.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + From: uq.sql, + Unique: true, + } + if ps := uq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := uq.limit; limit != nil { + spec.Limit = *limit + } + if offset := uq.offset; offset != nil { + spec.Offset = *offset + } + if ps := uq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (uq *UserQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) @@ -597,7 +620,7 @@ func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error { } func (us *UserSelect) sqlQuery() sql.Querier { - view := "user_view" - return sql.Dialect(us.driver.Dialect()). - Select(us.fields...).From(us.sql.As(view)) + selector := us.sql + selector.Select(selector.Columns(us.fields...)...) + return selector } diff --git a/examples/m2mbidi/ent/user_update.go b/examples/m2mbidi/ent/user_update.go index e7e87a47b..c7b1d0833 100644 --- a/examples/m2mbidi/ent/user_update.go +++ b/examples/m2mbidi/ent/user_update.go @@ -8,7 +8,6 @@ package ent import ( "context" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -383,33 +382,8 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } u = &User{config: uuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - u.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field age", values[0]) - } else if value.Valid { - u.Age = int(value.Int64) - } - if value, ok := values[1].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[1]) - } else if value.Valid { - u.Name = value.String - } - return nil - } + spec.Assign = u.assignValues + spec.ScanValues = u.scanValues() if err = sqlgraph.UpdateNode(ctx, uuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/m2mrecur/ent/user.go b/examples/m2mrecur/ent/user.go index 859424101..0af01a579 100644 --- a/examples/m2mrecur/ent/user.go +++ b/examples/m2mrecur/ent/user.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/m2mrecur/ent/user" ) // User is the model entity for the User schema. @@ -24,24 +25,37 @@ type User struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into User. -func (u *User) FromRows(rows *sql.Rows) error { - var scanu struct { - ID int - Age sql.NullInt64 - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*User) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `user.Columns`. - if err := rows.Scan( - &scanu.ID, - &scanu.Age, - &scanu.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the User fields. +func (u *User) assignValues(values ...interface{}) error { + if m, n := len(values), len(user.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + u.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field age", values[0]) + } else if value.Valid { + u.Age = int(value.Int64) + } + if value, ok := values[1].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[1]) + } else if value.Valid { + u.Name = value.String } - u.ID = scanu.ID - u.Age = int(scanu.Age.Int64) - u.Name = scanu.Name.String return nil } @@ -89,18 +103,6 @@ func (u *User) String() string { // Users is a parsable slice of User. type Users []*User -// FromRows scans the sql response data into Users. -func (u *Users) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanu := &User{} - if err := scanu.FromRows(rows); err != nil { - return err - } - *u = append(*u, scanu) - } - return nil -} - func (u Users) config(cfg config) { for _i := range u { u[_i].config = cfg diff --git a/examples/m2mrecur/ent/user_query.go b/examples/m2mrecur/ent/user_query.go index 4cff10b6d..cafe6c5cf 100644 --- a/examples/m2mrecur/ent/user_query.go +++ b/examples/m2mrecur/ent/user_query.go @@ -16,6 +16,7 @@ import ( "github.com/facebookincubator/ent/dialect/sql/sqlgraph" "github.com/facebookincubator/ent/examples/m2mrecur/ent/predicate" "github.com/facebookincubator/ent/examples/m2mrecur/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // UserQuery is the builder for querying User entities. @@ -289,45 +290,31 @@ func (uq *UserQuery) Select(field string, fields ...string) *UserSelect { } func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - if unique := uq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*User + spec = uq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &User{config: uq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, uq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var us Users - if err := us.FromRows(rows); err != nil { - return nil, err - } - us.config(uq.config) - return us, nil + return nodes, nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - unique := []string{user.FieldID} - if len(uq.unique) > 0 { - unique = uq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := uq.querySpec() + return sqlgraph.CountNodes(ctx, uq.driver, spec) } func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { @@ -338,6 +325,42 @@ func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: user.Table, + Columns: user.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + From: uq.sql, + Unique: true, + } + if ps := uq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := uq.limit; limit != nil { + spec.Limit = *limit + } + if offset := uq.offset; offset != nil { + spec.Offset = *offset + } + if ps := uq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (uq *UserQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) @@ -609,7 +632,7 @@ func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error { } func (us *UserSelect) sqlQuery() sql.Querier { - view := "user_view" - return sql.Dialect(us.driver.Dialect()). - Select(us.fields...).From(us.sql.As(view)) + selector := us.sql + selector.Select(selector.Columns(us.fields...)...) + return selector } diff --git a/examples/m2mrecur/ent/user_update.go b/examples/m2mrecur/ent/user_update.go index 80845fd14..f5f909f2a 100644 --- a/examples/m2mrecur/ent/user_update.go +++ b/examples/m2mrecur/ent/user_update.go @@ -8,7 +8,6 @@ package ent import ( "context" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -543,33 +542,8 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } u = &User{config: uuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - u.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field age", values[0]) - } else if value.Valid { - u.Age = int(value.Int64) - } - if value, ok := values[1].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[1]) - } else if value.Valid { - u.Name = value.String - } - return nil - } + spec.Assign = u.assignValues + spec.ScanValues = u.scanValues() if err = sqlgraph.UpdateNode(ctx, uuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/o2m2types/ent/pet.go b/examples/o2m2types/ent/pet.go index 06ada8bde..b07ed2f31 100644 --- a/examples/o2m2types/ent/pet.go +++ b/examples/o2m2types/ent/pet.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/o2m2types/ent/pet" ) // Pet is the model entity for the Pet schema. @@ -22,21 +23,31 @@ type Pet struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into Pet. -func (pe *Pet) FromRows(rows *sql.Rows) error { - var scanpe struct { - ID int - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*Pet) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `pet.Columns`. - if err := rows.Scan( - &scanpe.ID, - &scanpe.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Pet fields. +func (pe *Pet) assignValues(values ...interface{}) error { + if m, n := len(values), len(pet.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + pe.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[0]) + } else if value.Valid { + pe.Name = value.String } - pe.ID = scanpe.ID - pe.Name = scanpe.Name.String return nil } @@ -77,18 +88,6 @@ func (pe *Pet) String() string { // Pets is a parsable slice of Pet. type Pets []*Pet -// FromRows scans the sql response data into Pets. -func (pe *Pets) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanpe := &Pet{} - if err := scanpe.FromRows(rows); err != nil { - return err - } - *pe = append(*pe, scanpe) - } - return nil -} - func (pe Pets) config(cfg config) { for _i := range pe { pe[_i].config = cfg diff --git a/examples/o2m2types/ent/pet_query.go b/examples/o2m2types/ent/pet_query.go index 304ace3f1..72a0648ec 100644 --- a/examples/o2m2types/ent/pet_query.go +++ b/examples/o2m2types/ent/pet_query.go @@ -17,6 +17,7 @@ import ( "github.com/facebookincubator/ent/examples/o2m2types/ent/pet" "github.com/facebookincubator/ent/examples/o2m2types/ent/predicate" "github.com/facebookincubator/ent/examples/o2m2types/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // PetQuery is the builder for querying Pet entities. @@ -278,45 +279,31 @@ func (pq *PetQuery) Select(field string, fields ...string) *PetSelect { } func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { - rows := &sql.Rows{} - selector := pq.sqlQuery() - if unique := pq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*Pet + spec = pq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &Pet{config: pq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := pq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, pq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var pes Pets - if err := pes.FromRows(rows); err != nil { - return nil, err - } - pes.config(pq.config) - return pes, nil + return nodes, nil } func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := pq.sqlQuery() - unique := []string{pet.FieldID} - if len(pq.unique) > 0 { - unique = pq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := pq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := pq.querySpec() + return sqlgraph.CountNodes(ctx, pq.driver, spec) } func (pq *PetQuery) sqlExist(ctx context.Context) (bool, error) { @@ -327,6 +314,42 @@ func (pq *PetQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: pet.Table, + Columns: pet.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: pet.FieldID, + }, + }, + From: pq.sql, + Unique: true, + } + if ps := pq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := pq.limit; limit != nil { + spec.Limit = *limit + } + if offset := pq.offset; offset != nil { + spec.Offset = *offset + } + if ps := pq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (pq *PetQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(pq.driver.Dialect()) t1 := builder.Table(pet.Table) @@ -598,7 +621,7 @@ func (ps *PetSelect) sqlScan(ctx context.Context, v interface{}) error { } func (ps *PetSelect) sqlQuery() sql.Querier { - view := "pet_view" - return sql.Dialect(ps.driver.Dialect()). - Select(ps.fields...).From(ps.sql.As(view)) + selector := ps.sql + selector.Select(selector.Columns(ps.fields...)...) + return selector } diff --git a/examples/o2m2types/ent/pet_update.go b/examples/o2m2types/ent/pet_update.go index a8683c024..2769925e4 100644 --- a/examples/o2m2types/ent/pet_update.go +++ b/examples/o2m2types/ent/pet_update.go @@ -9,7 +9,6 @@ package ent import ( "context" "errors" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -295,27 +294,8 @@ func (puo *PetUpdateOne) sqlSave(ctx context.Context) (pe *Pet, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } pe = &Pet{config: puo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - pe.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[0]) - } else if value.Valid { - pe.Name = value.String - } - return nil - } + spec.Assign = pe.assignValues + spec.ScanValues = pe.scanValues() if err = sqlgraph.UpdateNode(ctx, puo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/o2m2types/ent/user.go b/examples/o2m2types/ent/user.go index 8cae7c7e1..701b37812 100644 --- a/examples/o2m2types/ent/user.go +++ b/examples/o2m2types/ent/user.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/o2m2types/ent/user" ) // User is the model entity for the User schema. @@ -24,24 +25,37 @@ type User struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into User. -func (u *User) FromRows(rows *sql.Rows) error { - var scanu struct { - ID int - Age sql.NullInt64 - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*User) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `user.Columns`. - if err := rows.Scan( - &scanu.ID, - &scanu.Age, - &scanu.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the User fields. +func (u *User) assignValues(values ...interface{}) error { + if m, n := len(values), len(user.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + u.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field age", values[0]) + } else if value.Valid { + u.Age = int(value.Int64) + } + if value, ok := values[1].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[1]) + } else if value.Valid { + u.Name = value.String } - u.ID = scanu.ID - u.Age = int(scanu.Age.Int64) - u.Name = scanu.Name.String return nil } @@ -84,18 +98,6 @@ func (u *User) String() string { // Users is a parsable slice of User. type Users []*User -// FromRows scans the sql response data into Users. -func (u *Users) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanu := &User{} - if err := scanu.FromRows(rows); err != nil { - return err - } - *u = append(*u, scanu) - } - return nil -} - func (u Users) config(cfg config) { for _i := range u { u[_i].config = cfg diff --git a/examples/o2m2types/ent/user_query.go b/examples/o2m2types/ent/user_query.go index 2e8365638..6834013be 100644 --- a/examples/o2m2types/ent/user_query.go +++ b/examples/o2m2types/ent/user_query.go @@ -17,6 +17,7 @@ import ( "github.com/facebookincubator/ent/examples/o2m2types/ent/pet" "github.com/facebookincubator/ent/examples/o2m2types/ent/predicate" "github.com/facebookincubator/ent/examples/o2m2types/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // UserQuery is the builder for querying User entities. @@ -278,45 +279,31 @@ func (uq *UserQuery) Select(field string, fields ...string) *UserSelect { } func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - if unique := uq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*User + spec = uq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &User{config: uq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, uq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var us Users - if err := us.FromRows(rows); err != nil { - return nil, err - } - us.config(uq.config) - return us, nil + return nodes, nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - unique := []string{user.FieldID} - if len(uq.unique) > 0 { - unique = uq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := uq.querySpec() + return sqlgraph.CountNodes(ctx, uq.driver, spec) } func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { @@ -327,6 +314,42 @@ func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: user.Table, + Columns: user.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + From: uq.sql, + Unique: true, + } + if ps := uq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := uq.limit; limit != nil { + spec.Limit = *limit + } + if offset := uq.offset; offset != nil { + spec.Offset = *offset + } + if ps := uq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (uq *UserQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) @@ -598,7 +621,7 @@ func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error { } func (us *UserSelect) sqlQuery() sql.Querier { - view := "user_view" - return sql.Dialect(us.driver.Dialect()). - Select(us.fields...).From(us.sql.As(view)) + selector := us.sql + selector.Select(selector.Columns(us.fields...)...) + return selector } diff --git a/examples/o2m2types/ent/user_update.go b/examples/o2m2types/ent/user_update.go index 566c21930..1893bdfbb 100644 --- a/examples/o2m2types/ent/user_update.go +++ b/examples/o2m2types/ent/user_update.go @@ -8,7 +8,6 @@ package ent import ( "context" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -384,33 +383,8 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } u = &User{config: uuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - u.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field age", values[0]) - } else if value.Valid { - u.Age = int(value.Int64) - } - if value, ok := values[1].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[1]) - } else if value.Valid { - u.Name = value.String - } - return nil - } + spec.Assign = u.assignValues + spec.ScanValues = u.scanValues() if err = sqlgraph.UpdateNode(ctx, uuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/o2mrecur/ent/node.go b/examples/o2mrecur/ent/node.go index 4da635695..57258adb0 100644 --- a/examples/o2mrecur/ent/node.go +++ b/examples/o2mrecur/ent/node.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/o2mrecur/ent/node" ) // Node is the model entity for the Node schema. @@ -22,21 +23,31 @@ type Node struct { Value int `json:"value,omitempty"` } -// FromRows scans the sql response data into Node. -func (n *Node) FromRows(rows *sql.Rows) error { - var scann struct { - ID int - Value sql.NullInt64 +// scanValues returns the types for scanning values from sql.Rows. +func (*Node) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullInt64{}, } - // the order here should be the same as in the `node.Columns`. - if err := rows.Scan( - &scann.ID, - &scann.Value, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Node fields. +func (n *Node) assignValues(values ...interface{}) error { + if m, n := len(values), len(node.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + n.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field value", values[0]) + } else if value.Valid { + n.Value = int(value.Int64) } - n.ID = scann.ID - n.Value = int(scann.Value.Int64) return nil } @@ -82,18 +93,6 @@ func (n *Node) String() string { // Nodes is a parsable slice of Node. type Nodes []*Node -// FromRows scans the sql response data into Nodes. -func (n *Nodes) FromRows(rows *sql.Rows) error { - for rows.Next() { - scann := &Node{} - if err := scann.FromRows(rows); err != nil { - return err - } - *n = append(*n, scann) - } - return nil -} - func (n Nodes) config(cfg config) { for _i := range n { n[_i].config = cfg diff --git a/examples/o2mrecur/ent/node_query.go b/examples/o2mrecur/ent/node_query.go index 77431ba35..a8db7c359 100644 --- a/examples/o2mrecur/ent/node_query.go +++ b/examples/o2mrecur/ent/node_query.go @@ -16,6 +16,7 @@ import ( "github.com/facebookincubator/ent/dialect/sql/sqlgraph" "github.com/facebookincubator/ent/examples/o2mrecur/ent/node" "github.com/facebookincubator/ent/examples/o2mrecur/ent/predicate" + "github.com/facebookincubator/ent/schema/field" ) // NodeQuery is the builder for querying Node entities. @@ -289,45 +290,31 @@ func (nq *NodeQuery) Select(field string, fields ...string) *NodeSelect { } func (nq *NodeQuery) sqlAll(ctx context.Context) ([]*Node, error) { - rows := &sql.Rows{} - selector := nq.sqlQuery() - if unique := nq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*Node + spec = nq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &Node{config: nq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := nq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, nq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var ns Nodes - if err := ns.FromRows(rows); err != nil { - return nil, err - } - ns.config(nq.config) - return ns, nil + return nodes, nil } func (nq *NodeQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := nq.sqlQuery() - unique := []string{node.FieldID} - if len(nq.unique) > 0 { - unique = nq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := nq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := nq.querySpec() + return sqlgraph.CountNodes(ctx, nq.driver, spec) } func (nq *NodeQuery) sqlExist(ctx context.Context) (bool, error) { @@ -338,6 +325,42 @@ func (nq *NodeQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (nq *NodeQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: node.Table, + Columns: node.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: node.FieldID, + }, + }, + From: nq.sql, + Unique: true, + } + if ps := nq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := nq.limit; limit != nil { + spec.Limit = *limit + } + if offset := nq.offset; offset != nil { + spec.Offset = *offset + } + if ps := nq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (nq *NodeQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(nq.driver.Dialect()) t1 := builder.Table(node.Table) @@ -609,7 +632,7 @@ func (ns *NodeSelect) sqlScan(ctx context.Context, v interface{}) error { } func (ns *NodeSelect) sqlQuery() sql.Querier { - view := "node_view" - return sql.Dialect(ns.driver.Dialect()). - Select(ns.fields...).From(ns.sql.As(view)) + selector := ns.sql + selector.Select(selector.Columns(ns.fields...)...) + return selector } diff --git a/examples/o2mrecur/ent/node_update.go b/examples/o2mrecur/ent/node_update.go index 67cdac708..3a2091800 100644 --- a/examples/o2mrecur/ent/node_update.go +++ b/examples/o2mrecur/ent/node_update.go @@ -9,7 +9,6 @@ package ent import ( "context" "errors" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -492,27 +491,8 @@ func (nuo *NodeUpdateOne) sqlSave(ctx context.Context) (n *Node, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } n = &Node{config: nuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullInt64{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - n.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field value", values[0]) - } else if value.Valid { - n.Value = int(value.Int64) - } - return nil - } + spec.Assign = n.assignValues + spec.ScanValues = n.scanValues() if err = sqlgraph.UpdateNode(ctx, nuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/o2o2types/ent/card.go b/examples/o2o2types/ent/card.go index c454170a3..12097d610 100644 --- a/examples/o2o2types/ent/card.go +++ b/examples/o2o2types/ent/card.go @@ -12,6 +12,7 @@ import ( "time" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/o2o2types/ent/card" ) // Card is the model entity for the Card schema. @@ -25,24 +26,37 @@ type Card struct { Number string `json:"number,omitempty"` } -// FromRows scans the sql response data into Card. -func (c *Card) FromRows(rows *sql.Rows) error { - var scanc struct { - ID int - Expired sql.NullTime - Number sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*Card) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullTime{}, + &sql.NullString{}, } - // the order here should be the same as in the `card.Columns`. - if err := rows.Scan( - &scanc.ID, - &scanc.Expired, - &scanc.Number, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Card fields. +func (c *Card) assignValues(values ...interface{}) error { + if m, n := len(values), len(card.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + c.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expired", values[0]) + } else if value.Valid { + c.Expired = value.Time + } + if value, ok := values[1].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field number", values[1]) + } else if value.Valid { + c.Number = value.String } - c.ID = scanc.ID - c.Expired = scanc.Expired.Time - c.Number = scanc.Number.String return nil } @@ -85,18 +99,6 @@ func (c *Card) String() string { // Cards is a parsable slice of Card. type Cards []*Card -// FromRows scans the sql response data into Cards. -func (c *Cards) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanc := &Card{} - if err := scanc.FromRows(rows); err != nil { - return err - } - *c = append(*c, scanc) - } - return nil -} - func (c Cards) config(cfg config) { for _i := range c { c[_i].config = cfg diff --git a/examples/o2o2types/ent/card_query.go b/examples/o2o2types/ent/card_query.go index 77d78c3e2..8cbcaaafc 100644 --- a/examples/o2o2types/ent/card_query.go +++ b/examples/o2o2types/ent/card_query.go @@ -17,6 +17,7 @@ import ( "github.com/facebookincubator/ent/examples/o2o2types/ent/card" "github.com/facebookincubator/ent/examples/o2o2types/ent/predicate" "github.com/facebookincubator/ent/examples/o2o2types/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // CardQuery is the builder for querying Card entities. @@ -278,45 +279,31 @@ func (cq *CardQuery) Select(field string, fields ...string) *CardSelect { } func (cq *CardQuery) sqlAll(ctx context.Context) ([]*Card, error) { - rows := &sql.Rows{} - selector := cq.sqlQuery() - if unique := cq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*Card + spec = cq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &Card{config: cq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := cq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, cq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var cs Cards - if err := cs.FromRows(rows); err != nil { - return nil, err - } - cs.config(cq.config) - return cs, nil + return nodes, nil } func (cq *CardQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := cq.sqlQuery() - unique := []string{card.FieldID} - if len(cq.unique) > 0 { - unique = cq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := cq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := cq.querySpec() + return sqlgraph.CountNodes(ctx, cq.driver, spec) } func (cq *CardQuery) sqlExist(ctx context.Context) (bool, error) { @@ -327,6 +314,42 @@ func (cq *CardQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: card.Table, + Columns: card.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: card.FieldID, + }, + }, + From: cq.sql, + Unique: true, + } + if ps := cq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := cq.limit; limit != nil { + spec.Limit = *limit + } + if offset := cq.offset; offset != nil { + spec.Offset = *offset + } + if ps := cq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (cq *CardQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(card.Table) @@ -598,7 +621,7 @@ func (cs *CardSelect) sqlScan(ctx context.Context, v interface{}) error { } func (cs *CardSelect) sqlQuery() sql.Querier { - view := "card_view" - return sql.Dialect(cs.driver.Dialect()). - Select(cs.fields...).From(cs.sql.As(view)) + selector := cs.sql + selector.Select(selector.Columns(cs.fields...)...) + return selector } diff --git a/examples/o2o2types/ent/card_update.go b/examples/o2o2types/ent/card_update.go index 9acff830b..727fecefd 100644 --- a/examples/o2o2types/ent/card_update.go +++ b/examples/o2o2types/ent/card_update.go @@ -9,7 +9,6 @@ package ent import ( "context" "errors" - "fmt" "time" "github.com/facebookincubator/ent/dialect/sql" @@ -314,33 +313,8 @@ func (cuo *CardUpdateOne) sqlSave(ctx context.Context) (c *Card, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } c = &Card{config: cuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullTime{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - c.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullTime); !ok { - return fmt.Errorf("unexpected type %T for field expired", values[0]) - } else if value.Valid { - c.Expired = value.Time - } - if value, ok := values[1].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field number", values[1]) - } else if value.Valid { - c.Number = value.String - } - return nil - } + spec.Assign = c.assignValues + spec.ScanValues = c.scanValues() if err = sqlgraph.UpdateNode(ctx, cuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/o2o2types/ent/user.go b/examples/o2o2types/ent/user.go index 227debdc5..33ea8ddd4 100644 --- a/examples/o2o2types/ent/user.go +++ b/examples/o2o2types/ent/user.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/o2o2types/ent/user" ) // User is the model entity for the User schema. @@ -24,24 +25,37 @@ type User struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into User. -func (u *User) FromRows(rows *sql.Rows) error { - var scanu struct { - ID int - Age sql.NullInt64 - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*User) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `user.Columns`. - if err := rows.Scan( - &scanu.ID, - &scanu.Age, - &scanu.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the User fields. +func (u *User) assignValues(values ...interface{}) error { + if m, n := len(values), len(user.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + u.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field age", values[0]) + } else if value.Valid { + u.Age = int(value.Int64) + } + if value, ok := values[1].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[1]) + } else if value.Valid { + u.Name = value.String } - u.ID = scanu.ID - u.Age = int(scanu.Age.Int64) - u.Name = scanu.Name.String return nil } @@ -84,18 +98,6 @@ func (u *User) String() string { // Users is a parsable slice of User. type Users []*User -// FromRows scans the sql response data into Users. -func (u *Users) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanu := &User{} - if err := scanu.FromRows(rows); err != nil { - return err - } - *u = append(*u, scanu) - } - return nil -} - func (u Users) config(cfg config) { for _i := range u { u[_i].config = cfg diff --git a/examples/o2o2types/ent/user_query.go b/examples/o2o2types/ent/user_query.go index 3bb4c44f5..4fe507174 100644 --- a/examples/o2o2types/ent/user_query.go +++ b/examples/o2o2types/ent/user_query.go @@ -17,6 +17,7 @@ import ( "github.com/facebookincubator/ent/examples/o2o2types/ent/card" "github.com/facebookincubator/ent/examples/o2o2types/ent/predicate" "github.com/facebookincubator/ent/examples/o2o2types/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // UserQuery is the builder for querying User entities. @@ -278,45 +279,31 @@ func (uq *UserQuery) Select(field string, fields ...string) *UserSelect { } func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - if unique := uq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*User + spec = uq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &User{config: uq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, uq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var us Users - if err := us.FromRows(rows); err != nil { - return nil, err - } - us.config(uq.config) - return us, nil + return nodes, nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - unique := []string{user.FieldID} - if len(uq.unique) > 0 { - unique = uq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := uq.querySpec() + return sqlgraph.CountNodes(ctx, uq.driver, spec) } func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { @@ -327,6 +314,42 @@ func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: user.Table, + Columns: user.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + From: uq.sql, + Unique: true, + } + if ps := uq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := uq.limit; limit != nil { + spec.Limit = *limit + } + if offset := uq.offset; offset != nil { + spec.Offset = *offset + } + if ps := uq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (uq *UserQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) @@ -598,7 +621,7 @@ func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error { } func (us *UserSelect) sqlQuery() sql.Querier { - view := "user_view" - return sql.Dialect(us.driver.Dialect()). - Select(us.fields...).From(us.sql.As(view)) + selector := us.sql + selector.Select(selector.Columns(us.fields...)...) + return selector } diff --git a/examples/o2o2types/ent/user_update.go b/examples/o2o2types/ent/user_update.go index f0ebb4d4c..d3404e5e6 100644 --- a/examples/o2o2types/ent/user_update.go +++ b/examples/o2o2types/ent/user_update.go @@ -9,7 +9,6 @@ package ent import ( "context" "errors" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -361,33 +360,8 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } u = &User{config: uuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - u.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field age", values[0]) - } else if value.Valid { - u.Age = int(value.Int64) - } - if value, ok := values[1].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[1]) - } else if value.Valid { - u.Name = value.String - } - return nil - } + spec.Assign = u.assignValues + spec.ScanValues = u.scanValues() if err = sqlgraph.UpdateNode(ctx, uuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/o2obidi/ent/user.go b/examples/o2obidi/ent/user.go index 81d28d0b9..45747e74f 100644 --- a/examples/o2obidi/ent/user.go +++ b/examples/o2obidi/ent/user.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/o2obidi/ent/user" ) // User is the model entity for the User schema. @@ -24,24 +25,37 @@ type User struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into User. -func (u *User) FromRows(rows *sql.Rows) error { - var scanu struct { - ID int - Age sql.NullInt64 - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*User) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `user.Columns`. - if err := rows.Scan( - &scanu.ID, - &scanu.Age, - &scanu.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the User fields. +func (u *User) assignValues(values ...interface{}) error { + if m, n := len(values), len(user.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + u.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field age", values[0]) + } else if value.Valid { + u.Age = int(value.Int64) + } + if value, ok := values[1].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[1]) + } else if value.Valid { + u.Name = value.String } - u.ID = scanu.ID - u.Age = int(scanu.Age.Int64) - u.Name = scanu.Name.String return nil } @@ -84,18 +98,6 @@ func (u *User) String() string { // Users is a parsable slice of User. type Users []*User -// FromRows scans the sql response data into Users. -func (u *Users) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanu := &User{} - if err := scanu.FromRows(rows); err != nil { - return err - } - *u = append(*u, scanu) - } - return nil -} - func (u Users) config(cfg config) { for _i := range u { u[_i].config = cfg diff --git a/examples/o2obidi/ent/user_query.go b/examples/o2obidi/ent/user_query.go index 98d210dcc..d4c01ee48 100644 --- a/examples/o2obidi/ent/user_query.go +++ b/examples/o2obidi/ent/user_query.go @@ -16,6 +16,7 @@ import ( "github.com/facebookincubator/ent/dialect/sql/sqlgraph" "github.com/facebookincubator/ent/examples/o2obidi/ent/predicate" "github.com/facebookincubator/ent/examples/o2obidi/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // UserQuery is the builder for querying User entities. @@ -277,45 +278,31 @@ func (uq *UserQuery) Select(field string, fields ...string) *UserSelect { } func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - if unique := uq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*User + spec = uq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &User{config: uq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, uq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var us Users - if err := us.FromRows(rows); err != nil { - return nil, err - } - us.config(uq.config) - return us, nil + return nodes, nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - unique := []string{user.FieldID} - if len(uq.unique) > 0 { - unique = uq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := uq.querySpec() + return sqlgraph.CountNodes(ctx, uq.driver, spec) } func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { @@ -326,6 +313,42 @@ func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: user.Table, + Columns: user.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + From: uq.sql, + Unique: true, + } + if ps := uq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := uq.limit; limit != nil { + spec.Limit = *limit + } + if offset := uq.offset; offset != nil { + spec.Offset = *offset + } + if ps := uq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (uq *UserQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) @@ -597,7 +620,7 @@ func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error { } func (us *UserSelect) sqlQuery() sql.Querier { - view := "user_view" - return sql.Dialect(us.driver.Dialect()). - Select(us.fields...).From(us.sql.As(view)) + selector := us.sql + selector.Select(selector.Columns(us.fields...)...) + return selector } diff --git a/examples/o2obidi/ent/user_update.go b/examples/o2obidi/ent/user_update.go index 41f5aeb24..566a21027 100644 --- a/examples/o2obidi/ent/user_update.go +++ b/examples/o2obidi/ent/user_update.go @@ -9,7 +9,6 @@ package ent import ( "context" "errors" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -360,33 +359,8 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } u = &User{config: uuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - u.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field age", values[0]) - } else if value.Valid { - u.Age = int(value.Int64) - } - if value, ok := values[1].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[1]) - } else if value.Valid { - u.Name = value.String - } - return nil - } + spec.Assign = u.assignValues + spec.ScanValues = u.scanValues() if err = sqlgraph.UpdateNode(ctx, uuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/o2orecur/ent/node.go b/examples/o2orecur/ent/node.go index 1a426c9e2..fdd708fba 100644 --- a/examples/o2orecur/ent/node.go +++ b/examples/o2orecur/ent/node.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/o2orecur/ent/node" ) // Node is the model entity for the Node schema. @@ -22,21 +23,31 @@ type Node struct { Value int `json:"value,omitempty"` } -// FromRows scans the sql response data into Node. -func (n *Node) FromRows(rows *sql.Rows) error { - var scann struct { - ID int - Value sql.NullInt64 +// scanValues returns the types for scanning values from sql.Rows. +func (*Node) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullInt64{}, } - // the order here should be the same as in the `node.Columns`. - if err := rows.Scan( - &scann.ID, - &scann.Value, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Node fields. +func (n *Node) assignValues(values ...interface{}) error { + if m, n := len(values), len(node.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + n.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field value", values[0]) + } else if value.Valid { + n.Value = int(value.Int64) } - n.ID = scann.ID - n.Value = int(scann.Value.Int64) return nil } @@ -82,18 +93,6 @@ func (n *Node) String() string { // Nodes is a parsable slice of Node. type Nodes []*Node -// FromRows scans the sql response data into Nodes. -func (n *Nodes) FromRows(rows *sql.Rows) error { - for rows.Next() { - scann := &Node{} - if err := scann.FromRows(rows); err != nil { - return err - } - *n = append(*n, scann) - } - return nil -} - func (n Nodes) config(cfg config) { for _i := range n { n[_i].config = cfg diff --git a/examples/o2orecur/ent/node_query.go b/examples/o2orecur/ent/node_query.go index b82b545be..79053f401 100644 --- a/examples/o2orecur/ent/node_query.go +++ b/examples/o2orecur/ent/node_query.go @@ -16,6 +16,7 @@ import ( "github.com/facebookincubator/ent/dialect/sql/sqlgraph" "github.com/facebookincubator/ent/examples/o2orecur/ent/node" "github.com/facebookincubator/ent/examples/o2orecur/ent/predicate" + "github.com/facebookincubator/ent/schema/field" ) // NodeQuery is the builder for querying Node entities. @@ -289,45 +290,31 @@ func (nq *NodeQuery) Select(field string, fields ...string) *NodeSelect { } func (nq *NodeQuery) sqlAll(ctx context.Context) ([]*Node, error) { - rows := &sql.Rows{} - selector := nq.sqlQuery() - if unique := nq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*Node + spec = nq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &Node{config: nq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := nq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, nq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var ns Nodes - if err := ns.FromRows(rows); err != nil { - return nil, err - } - ns.config(nq.config) - return ns, nil + return nodes, nil } func (nq *NodeQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := nq.sqlQuery() - unique := []string{node.FieldID} - if len(nq.unique) > 0 { - unique = nq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := nq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := nq.querySpec() + return sqlgraph.CountNodes(ctx, nq.driver, spec) } func (nq *NodeQuery) sqlExist(ctx context.Context) (bool, error) { @@ -338,6 +325,42 @@ func (nq *NodeQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (nq *NodeQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: node.Table, + Columns: node.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: node.FieldID, + }, + }, + From: nq.sql, + Unique: true, + } + if ps := nq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := nq.limit; limit != nil { + spec.Limit = *limit + } + if offset := nq.offset; offset != nil { + spec.Offset = *offset + } + if ps := nq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (nq *NodeQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(nq.driver.Dialect()) t1 := builder.Table(node.Table) @@ -609,7 +632,7 @@ func (ns *NodeSelect) sqlScan(ctx context.Context, v interface{}) error { } func (ns *NodeSelect) sqlQuery() sql.Querier { - view := "node_view" - return sql.Dialect(ns.driver.Dialect()). - Select(ns.fields...).From(ns.sql.As(view)) + selector := ns.sql + selector.Select(selector.Columns(ns.fields...)...) + return selector } diff --git a/examples/o2orecur/ent/node_update.go b/examples/o2orecur/ent/node_update.go index f92ebf40e..4f37a46ad 100644 --- a/examples/o2orecur/ent/node_update.go +++ b/examples/o2orecur/ent/node_update.go @@ -9,7 +9,6 @@ package ent import ( "context" "errors" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -468,27 +467,8 @@ func (nuo *NodeUpdateOne) sqlSave(ctx context.Context) (n *Node, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } n = &Node{config: nuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullInt64{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - n.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field value", values[0]) - } else if value.Valid { - n.Value = int(value.Int64) - } - return nil - } + spec.Assign = n.assignValues + spec.ScanValues = n.scanValues() if err = sqlgraph.UpdateNode(ctx, nuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/start/ent/car.go b/examples/start/ent/car.go index 7485d5c30..c9281d2cf 100644 --- a/examples/start/ent/car.go +++ b/examples/start/ent/car.go @@ -12,6 +12,7 @@ import ( "time" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/start/ent/car" ) // Car is the model entity for the Car schema. @@ -25,24 +26,37 @@ type Car struct { RegisteredAt time.Time `json:"registered_at,omitempty"` } -// FromRows scans the sql response data into Car. -func (c *Car) FromRows(rows *sql.Rows) error { - var scanc struct { - ID int - Model sql.NullString - RegisteredAt sql.NullTime +// scanValues returns the types for scanning values from sql.Rows. +func (*Car) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullString{}, + &sql.NullTime{}, } - // the order here should be the same as in the `car.Columns`. - if err := rows.Scan( - &scanc.ID, - &scanc.Model, - &scanc.RegisteredAt, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Car fields. +func (c *Car) assignValues(values ...interface{}) error { + if m, n := len(values), len(car.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + c.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field model", values[0]) + } else if value.Valid { + c.Model = value.String + } + if value, ok := values[1].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field registered_at", values[1]) + } else if value.Valid { + c.RegisteredAt = value.Time } - c.ID = scanc.ID - c.Model = scanc.Model.String - c.RegisteredAt = scanc.RegisteredAt.Time return nil } @@ -85,18 +99,6 @@ func (c *Car) String() string { // Cars is a parsable slice of Car. type Cars []*Car -// FromRows scans the sql response data into Cars. -func (c *Cars) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanc := &Car{} - if err := scanc.FromRows(rows); err != nil { - return err - } - *c = append(*c, scanc) - } - return nil -} - func (c Cars) config(cfg config) { for _i := range c { c[_i].config = cfg diff --git a/examples/start/ent/car_query.go b/examples/start/ent/car_query.go index 325b71445..86a507232 100644 --- a/examples/start/ent/car_query.go +++ b/examples/start/ent/car_query.go @@ -17,6 +17,7 @@ import ( "github.com/facebookincubator/ent/examples/start/ent/car" "github.com/facebookincubator/ent/examples/start/ent/predicate" "github.com/facebookincubator/ent/examples/start/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // CarQuery is the builder for querying Car entities. @@ -278,45 +279,31 @@ func (cq *CarQuery) Select(field string, fields ...string) *CarSelect { } func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { - rows := &sql.Rows{} - selector := cq.sqlQuery() - if unique := cq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*Car + spec = cq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &Car{config: cq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := cq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, cq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var cs Cars - if err := cs.FromRows(rows); err != nil { - return nil, err - } - cs.config(cq.config) - return cs, nil + return nodes, nil } func (cq *CarQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := cq.sqlQuery() - unique := []string{car.FieldID} - if len(cq.unique) > 0 { - unique = cq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := cq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := cq.querySpec() + return sqlgraph.CountNodes(ctx, cq.driver, spec) } func (cq *CarQuery) sqlExist(ctx context.Context) (bool, error) { @@ -327,6 +314,42 @@ func (cq *CarQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: car.Table, + Columns: car.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: car.FieldID, + }, + }, + From: cq.sql, + Unique: true, + } + if ps := cq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := cq.limit; limit != nil { + spec.Limit = *limit + } + if offset := cq.offset; offset != nil { + spec.Offset = *offset + } + if ps := cq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (cq *CarQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(car.Table) @@ -598,7 +621,7 @@ func (cs *CarSelect) sqlScan(ctx context.Context, v interface{}) error { } func (cs *CarSelect) sqlQuery() sql.Querier { - view := "car_view" - return sql.Dialect(cs.driver.Dialect()). - Select(cs.fields...).From(cs.sql.As(view)) + selector := cs.sql + selector.Select(selector.Columns(cs.fields...)...) + return selector } diff --git a/examples/start/ent/car_update.go b/examples/start/ent/car_update.go index 69bf1362b..282a00b70 100644 --- a/examples/start/ent/car_update.go +++ b/examples/start/ent/car_update.go @@ -9,7 +9,6 @@ package ent import ( "context" "errors" - "fmt" "time" "github.com/facebookincubator/ent/dialect/sql" @@ -324,33 +323,8 @@ func (cuo *CarUpdateOne) sqlSave(ctx context.Context) (c *Car, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } c = &Car{config: cuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullString{}, - &sql.NullTime{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - c.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field model", values[0]) - } else if value.Valid { - c.Model = value.String - } - if value, ok := values[1].(*sql.NullTime); !ok { - return fmt.Errorf("unexpected type %T for field registered_at", values[1]) - } else if value.Valid { - c.RegisteredAt = value.Time - } - return nil - } + spec.Assign = c.assignValues + spec.ScanValues = c.scanValues() if err = sqlgraph.UpdateNode(ctx, cuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/start/ent/group.go b/examples/start/ent/group.go index 27245a60e..15b92a849 100644 --- a/examples/start/ent/group.go +++ b/examples/start/ent/group.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/start/ent/group" ) // Group is the model entity for the Group schema. @@ -22,21 +23,31 @@ type Group struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into Group. -func (gr *Group) FromRows(rows *sql.Rows) error { - var scangr struct { - ID int - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*Group) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `group.Columns`. - if err := rows.Scan( - &scangr.ID, - &scangr.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Group fields. +func (gr *Group) assignValues(values ...interface{}) error { + if m, n := len(values), len(group.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + gr.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[0]) + } else if value.Valid { + gr.Name = value.String } - gr.ID = scangr.ID - gr.Name = scangr.Name.String return nil } @@ -77,18 +88,6 @@ func (gr *Group) String() string { // Groups is a parsable slice of Group. type Groups []*Group -// FromRows scans the sql response data into Groups. -func (gr *Groups) FromRows(rows *sql.Rows) error { - for rows.Next() { - scangr := &Group{} - if err := scangr.FromRows(rows); err != nil { - return err - } - *gr = append(*gr, scangr) - } - return nil -} - func (gr Groups) config(cfg config) { for _i := range gr { gr[_i].config = cfg diff --git a/examples/start/ent/group_query.go b/examples/start/ent/group_query.go index 065a8cdd8..fc166621a 100644 --- a/examples/start/ent/group_query.go +++ b/examples/start/ent/group_query.go @@ -17,6 +17,7 @@ import ( "github.com/facebookincubator/ent/examples/start/ent/group" "github.com/facebookincubator/ent/examples/start/ent/predicate" "github.com/facebookincubator/ent/examples/start/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // GroupQuery is the builder for querying Group entities. @@ -278,45 +279,31 @@ func (gq *GroupQuery) Select(field string, fields ...string) *GroupSelect { } func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { - rows := &sql.Rows{} - selector := gq.sqlQuery() - if unique := gq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*Group + spec = gq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &Group{config: gq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := gq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, gq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var grs Groups - if err := grs.FromRows(rows); err != nil { - return nil, err - } - grs.config(gq.config) - return grs, nil + return nodes, nil } func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := gq.sqlQuery() - unique := []string{group.FieldID} - if len(gq.unique) > 0 { - unique = gq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := gq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := gq.querySpec() + return sqlgraph.CountNodes(ctx, gq.driver, spec) } func (gq *GroupQuery) sqlExist(ctx context.Context) (bool, error) { @@ -327,6 +314,42 @@ func (gq *GroupQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: group.Table, + Columns: group.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: group.FieldID, + }, + }, + From: gq.sql, + Unique: true, + } + if ps := gq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := gq.limit; limit != nil { + spec.Limit = *limit + } + if offset := gq.offset; offset != nil { + spec.Offset = *offset + } + if ps := gq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (gq *GroupQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(gq.driver.Dialect()) t1 := builder.Table(group.Table) @@ -598,7 +621,7 @@ func (gs *GroupSelect) sqlScan(ctx context.Context, v interface{}) error { } func (gs *GroupSelect) sqlQuery() sql.Querier { - view := "group_view" - return sql.Dialect(gs.driver.Dialect()). - Select(gs.fields...).From(gs.sql.As(view)) + selector := gs.sql + selector.Select(selector.Columns(gs.fields...)...) + return selector } diff --git a/examples/start/ent/group_update.go b/examples/start/ent/group_update.go index 2d66a2550..d817e484c 100644 --- a/examples/start/ent/group_update.go +++ b/examples/start/ent/group_update.go @@ -328,27 +328,8 @@ func (guo *GroupUpdateOne) sqlSave(ctx context.Context) (gr *Group, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } gr = &Group{config: guo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - gr.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[0]) - } else if value.Valid { - gr.Name = value.String - } - return nil - } + spec.Assign = gr.assignValues + spec.ScanValues = gr.scanValues() if err = sqlgraph.UpdateNode(ctx, guo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/start/ent/user.go b/examples/start/ent/user.go index a34a88e85..1fbd500e8 100644 --- a/examples/start/ent/user.go +++ b/examples/start/ent/user.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/start/ent/user" ) // User is the model entity for the User schema. @@ -24,24 +25,37 @@ type User struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into User. -func (u *User) FromRows(rows *sql.Rows) error { - var scanu struct { - ID int - Age sql.NullInt64 - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*User) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `user.Columns`. - if err := rows.Scan( - &scanu.ID, - &scanu.Age, - &scanu.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the User fields. +func (u *User) assignValues(values ...interface{}) error { + if m, n := len(values), len(user.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + u.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field age", values[0]) + } else if value.Valid { + u.Age = int(value.Int64) + } + if value, ok := values[1].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[1]) + } else if value.Valid { + u.Name = value.String } - u.ID = scanu.ID - u.Age = int(scanu.Age.Int64) - u.Name = scanu.Name.String return nil } @@ -89,18 +103,6 @@ func (u *User) String() string { // Users is a parsable slice of User. type Users []*User -// FromRows scans the sql response data into Users. -func (u *Users) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanu := &User{} - if err := scanu.FromRows(rows); err != nil { - return err - } - *u = append(*u, scanu) - } - return nil -} - func (u Users) config(cfg config) { for _i := range u { u[_i].config = cfg diff --git a/examples/start/ent/user_query.go b/examples/start/ent/user_query.go index d4ed79e48..c706064c5 100644 --- a/examples/start/ent/user_query.go +++ b/examples/start/ent/user_query.go @@ -18,6 +18,7 @@ import ( "github.com/facebookincubator/ent/examples/start/ent/group" "github.com/facebookincubator/ent/examples/start/ent/predicate" "github.com/facebookincubator/ent/examples/start/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // UserQuery is the builder for querying User entities. @@ -291,45 +292,31 @@ func (uq *UserQuery) Select(field string, fields ...string) *UserSelect { } func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - if unique := uq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*User + spec = uq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &User{config: uq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, uq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var us Users - if err := us.FromRows(rows); err != nil { - return nil, err - } - us.config(uq.config) - return us, nil + return nodes, nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - unique := []string{user.FieldID} - if len(uq.unique) > 0 { - unique = uq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := uq.querySpec() + return sqlgraph.CountNodes(ctx, uq.driver, spec) } func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { @@ -340,6 +327,42 @@ func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: user.Table, + Columns: user.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + From: uq.sql, + Unique: true, + } + if ps := uq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := uq.limit; limit != nil { + spec.Limit = *limit + } + if offset := uq.offset; offset != nil { + spec.Offset = *offset + } + if ps := uq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (uq *UserQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) @@ -611,7 +634,7 @@ func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error { } func (us *UserSelect) sqlQuery() sql.Querier { - view := "user_view" - return sql.Dialect(us.driver.Dialect()). - Select(us.fields...).From(us.sql.As(view)) + selector := us.sql + selector.Select(selector.Columns(us.fields...)...) + return selector } diff --git a/examples/start/ent/user_update.go b/examples/start/ent/user_update.go index 2208ef15e..856f0c1d9 100644 --- a/examples/start/ent/user_update.go +++ b/examples/start/ent/user_update.go @@ -571,33 +571,8 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } u = &User{config: uuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - u.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field age", values[0]) - } else if value.Valid { - u.Age = int(value.Int64) - } - if value, ok := values[1].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[1]) - } else if value.Valid { - u.Name = value.String - } - return nil - } + spec.Assign = u.assignValues + spec.ScanValues = u.scanValues() if err = sqlgraph.UpdateNode(ctx, uuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/traversal/ent/group.go b/examples/traversal/ent/group.go index 6edee4566..9f65c54c4 100644 --- a/examples/traversal/ent/group.go +++ b/examples/traversal/ent/group.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/traversal/ent/group" ) // Group is the model entity for the Group schema. @@ -22,21 +23,31 @@ type Group struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into Group. -func (gr *Group) FromRows(rows *sql.Rows) error { - var scangr struct { - ID int - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*Group) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `group.Columns`. - if err := rows.Scan( - &scangr.ID, - &scangr.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Group fields. +func (gr *Group) assignValues(values ...interface{}) error { + if m, n := len(values), len(group.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + gr.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[0]) + } else if value.Valid { + gr.Name = value.String } - gr.ID = scangr.ID - gr.Name = scangr.Name.String return nil } @@ -82,18 +93,6 @@ func (gr *Group) String() string { // Groups is a parsable slice of Group. type Groups []*Group -// FromRows scans the sql response data into Groups. -func (gr *Groups) FromRows(rows *sql.Rows) error { - for rows.Next() { - scangr := &Group{} - if err := scangr.FromRows(rows); err != nil { - return err - } - *gr = append(*gr, scangr) - } - return nil -} - func (gr Groups) config(cfg config) { for _i := range gr { gr[_i].config = cfg diff --git a/examples/traversal/ent/group_query.go b/examples/traversal/ent/group_query.go index 93474fa40..00bb879ae 100644 --- a/examples/traversal/ent/group_query.go +++ b/examples/traversal/ent/group_query.go @@ -17,6 +17,7 @@ import ( "github.com/facebookincubator/ent/examples/traversal/ent/group" "github.com/facebookincubator/ent/examples/traversal/ent/predicate" "github.com/facebookincubator/ent/examples/traversal/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // GroupQuery is the builder for querying Group entities. @@ -290,45 +291,31 @@ func (gq *GroupQuery) Select(field string, fields ...string) *GroupSelect { } func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { - rows := &sql.Rows{} - selector := gq.sqlQuery() - if unique := gq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*Group + spec = gq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &Group{config: gq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := gq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, gq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var grs Groups - if err := grs.FromRows(rows); err != nil { - return nil, err - } - grs.config(gq.config) - return grs, nil + return nodes, nil } func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := gq.sqlQuery() - unique := []string{group.FieldID} - if len(gq.unique) > 0 { - unique = gq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := gq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := gq.querySpec() + return sqlgraph.CountNodes(ctx, gq.driver, spec) } func (gq *GroupQuery) sqlExist(ctx context.Context) (bool, error) { @@ -339,6 +326,42 @@ func (gq *GroupQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: group.Table, + Columns: group.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: group.FieldID, + }, + }, + From: gq.sql, + Unique: true, + } + if ps := gq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := gq.limit; limit != nil { + spec.Limit = *limit + } + if offset := gq.offset; offset != nil { + spec.Offset = *offset + } + if ps := gq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (gq *GroupQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(gq.driver.Dialect()) t1 := builder.Table(group.Table) @@ -610,7 +633,7 @@ func (gs *GroupSelect) sqlScan(ctx context.Context, v interface{}) error { } func (gs *GroupSelect) sqlQuery() sql.Querier { - view := "group_view" - return sql.Dialect(gs.driver.Dialect()). - Select(gs.fields...).From(gs.sql.As(view)) + selector := gs.sql + selector.Select(selector.Columns(gs.fields...)...) + return selector } diff --git a/examples/traversal/ent/group_update.go b/examples/traversal/ent/group_update.go index 1d8445124..976aed090 100644 --- a/examples/traversal/ent/group_update.go +++ b/examples/traversal/ent/group_update.go @@ -9,7 +9,6 @@ package ent import ( "context" "errors" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -455,27 +454,8 @@ func (guo *GroupUpdateOne) sqlSave(ctx context.Context) (gr *Group, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } gr = &Group{config: guo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - gr.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[0]) - } else if value.Valid { - gr.Name = value.String - } - return nil - } + spec.Assign = gr.assignValues + spec.ScanValues = gr.scanValues() if err = sqlgraph.UpdateNode(ctx, guo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/traversal/ent/pet.go b/examples/traversal/ent/pet.go index 8916f1297..dd624cbe2 100644 --- a/examples/traversal/ent/pet.go +++ b/examples/traversal/ent/pet.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/traversal/ent/pet" ) // Pet is the model entity for the Pet schema. @@ -22,21 +23,31 @@ type Pet struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into Pet. -func (pe *Pet) FromRows(rows *sql.Rows) error { - var scanpe struct { - ID int - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*Pet) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `pet.Columns`. - if err := rows.Scan( - &scanpe.ID, - &scanpe.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Pet fields. +func (pe *Pet) assignValues(values ...interface{}) error { + if m, n := len(values), len(pet.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + pe.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[0]) + } else if value.Valid { + pe.Name = value.String } - pe.ID = scanpe.ID - pe.Name = scanpe.Name.String return nil } @@ -82,18 +93,6 @@ func (pe *Pet) String() string { // Pets is a parsable slice of Pet. type Pets []*Pet -// FromRows scans the sql response data into Pets. -func (pe *Pets) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanpe := &Pet{} - if err := scanpe.FromRows(rows); err != nil { - return err - } - *pe = append(*pe, scanpe) - } - return nil -} - func (pe Pets) config(cfg config) { for _i := range pe { pe[_i].config = cfg diff --git a/examples/traversal/ent/pet_query.go b/examples/traversal/ent/pet_query.go index 7c627af10..a7e19ff97 100644 --- a/examples/traversal/ent/pet_query.go +++ b/examples/traversal/ent/pet_query.go @@ -17,6 +17,7 @@ import ( "github.com/facebookincubator/ent/examples/traversal/ent/pet" "github.com/facebookincubator/ent/examples/traversal/ent/predicate" "github.com/facebookincubator/ent/examples/traversal/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // PetQuery is the builder for querying Pet entities. @@ -290,45 +291,31 @@ func (pq *PetQuery) Select(field string, fields ...string) *PetSelect { } func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { - rows := &sql.Rows{} - selector := pq.sqlQuery() - if unique := pq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*Pet + spec = pq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &Pet{config: pq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := pq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, pq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var pes Pets - if err := pes.FromRows(rows); err != nil { - return nil, err - } - pes.config(pq.config) - return pes, nil + return nodes, nil } func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := pq.sqlQuery() - unique := []string{pet.FieldID} - if len(pq.unique) > 0 { - unique = pq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := pq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := pq.querySpec() + return sqlgraph.CountNodes(ctx, pq.driver, spec) } func (pq *PetQuery) sqlExist(ctx context.Context) (bool, error) { @@ -339,6 +326,42 @@ func (pq *PetQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: pet.Table, + Columns: pet.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: pet.FieldID, + }, + }, + From: pq.sql, + Unique: true, + } + if ps := pq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := pq.limit; limit != nil { + spec.Limit = *limit + } + if offset := pq.offset; offset != nil { + spec.Offset = *offset + } + if ps := pq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (pq *PetQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(pq.driver.Dialect()) t1 := builder.Table(pet.Table) @@ -610,7 +633,7 @@ func (ps *PetSelect) sqlScan(ctx context.Context, v interface{}) error { } func (ps *PetSelect) sqlQuery() sql.Querier { - view := "pet_view" - return sql.Dialect(ps.driver.Dialect()). - Select(ps.fields...).From(ps.sql.As(view)) + selector := ps.sql + selector.Select(selector.Columns(ps.fields...)...) + return selector } diff --git a/examples/traversal/ent/pet_update.go b/examples/traversal/ent/pet_update.go index 7e22207cd..e6de6bdc3 100644 --- a/examples/traversal/ent/pet_update.go +++ b/examples/traversal/ent/pet_update.go @@ -9,7 +9,6 @@ package ent import ( "context" "errors" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -455,27 +454,8 @@ func (puo *PetUpdateOne) sqlSave(ctx context.Context) (pe *Pet, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } pe = &Pet{config: puo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - pe.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[0]) - } else if value.Valid { - pe.Name = value.String - } - return nil - } + spec.Assign = pe.assignValues + spec.ScanValues = pe.scanValues() if err = sqlgraph.UpdateNode(ctx, puo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr diff --git a/examples/traversal/ent/user.go b/examples/traversal/ent/user.go index 6ba66355d..a44248c2f 100644 --- a/examples/traversal/ent/user.go +++ b/examples/traversal/ent/user.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/examples/traversal/ent/user" ) // User is the model entity for the User schema. @@ -24,24 +25,37 @@ type User struct { Name string `json:"name,omitempty"` } -// FromRows scans the sql response data into User. -func (u *User) FromRows(rows *sql.Rows) error { - var scanu struct { - ID int - Age sql.NullInt64 - Name sql.NullString +// scanValues returns the types for scanning values from sql.Rows. +func (*User) scanValues() []interface{} { + return []interface{}{ + &sql.NullInt64{}, + &sql.NullInt64{}, + &sql.NullString{}, } - // the order here should be the same as in the `user.Columns`. - if err := rows.Scan( - &scanu.ID, - &scanu.Age, - &scanu.Name, - ); err != nil { - return err +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the User fields. +func (u *User) assignValues(values ...interface{}) error { + if m, n := len(values), len(user.Columns); m != n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + value, ok := values[0].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + u.ID = int(value.Int64) + values = values[1:] + if value, ok := values[0].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field age", values[0]) + } else if value.Valid { + u.Age = int(value.Int64) + } + if value, ok := values[1].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[1]) + } else if value.Valid { + u.Name = value.String } - u.ID = scanu.ID - u.Age = int(scanu.Age.Int64) - u.Name = scanu.Name.String return nil } @@ -99,18 +113,6 @@ func (u *User) String() string { // Users is a parsable slice of User. type Users []*User -// FromRows scans the sql response data into Users. -func (u *Users) FromRows(rows *sql.Rows) error { - for rows.Next() { - scanu := &User{} - if err := scanu.FromRows(rows); err != nil { - return err - } - *u = append(*u, scanu) - } - return nil -} - func (u Users) config(cfg config) { for _i := range u { u[_i].config = cfg diff --git a/examples/traversal/ent/user_query.go b/examples/traversal/ent/user_query.go index 4b05a0331..0448edcb0 100644 --- a/examples/traversal/ent/user_query.go +++ b/examples/traversal/ent/user_query.go @@ -18,6 +18,7 @@ import ( "github.com/facebookincubator/ent/examples/traversal/ent/pet" "github.com/facebookincubator/ent/examples/traversal/ent/predicate" "github.com/facebookincubator/ent/examples/traversal/ent/user" + "github.com/facebookincubator/ent/schema/field" ) // UserQuery is the builder for querying User entities. @@ -315,45 +316,31 @@ func (uq *UserQuery) Select(field string, fields ...string) *UserSelect { } func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - if unique := uq.unique; len(unique) == 0 { - selector.Distinct() + var ( + nodes []*User + spec = uq.querySpec() + ) + spec.ScanValues = func() []interface{} { + node := &User{config: uq.config} + nodes = append(nodes, node) + return node.scanValues() } - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { + spec.Assign = func(values ...interface{}) error { + if len(nodes) == 0 { + return fmt.Errorf("ent: Assign called without calling ScanValues") + } + node := nodes[len(nodes)-1] + return node.assignValues(values...) + } + if err := sqlgraph.QueryNodes(ctx, uq.driver, spec); err != nil { return nil, err } - defer rows.Close() - var us Users - if err := us.FromRows(rows); err != nil { - return nil, err - } - us.config(uq.config) - return us, nil + return nodes, nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { - rows := &sql.Rows{} - selector := uq.sqlQuery() - unique := []string{user.FieldID} - if len(uq.unique) > 0 { - unique = uq.unique - } - selector.Count(sql.Distinct(selector.Columns(unique...)...)) - query, args := selector.Query() - if err := uq.driver.Query(ctx, query, args, rows); err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errors.New("ent: no rows found") - } - var n int - if err := rows.Scan(&n); err != nil { - return 0, fmt.Errorf("ent: failed reading count: %v", err) - } - return n, nil + spec := uq.querySpec() + return sqlgraph.CountNodes(ctx, uq.driver, spec) } func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { @@ -364,6 +351,42 @@ func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { return n > 0, nil } +func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { + spec := &sqlgraph.QuerySpec{ + Node: &sqlgraph.NodeSpec{ + Table: user.Table, + Columns: user.Columns, + ID: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + From: uq.sql, + Unique: true, + } + if ps := uq.predicates; len(ps) > 0 { + spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := uq.limit; limit != nil { + spec.Limit = *limit + } + if offset := uq.offset; offset != nil { + spec.Offset = *offset + } + if ps := uq.order; len(ps) > 0 { + spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return spec +} + func (uq *UserQuery) sqlQuery() *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(user.Table) @@ -635,7 +658,7 @@ func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error { } func (us *UserSelect) sqlQuery() sql.Querier { - view := "user_view" - return sql.Dialect(us.driver.Dialect()). - Select(us.fields...).From(us.sql.As(view)) + selector := us.sql + selector.Select(selector.Columns(us.fields...)...) + return selector } diff --git a/examples/traversal/ent/user_update.go b/examples/traversal/ent/user_update.go index b2bffb0e5..0a3b78ee7 100644 --- a/examples/traversal/ent/user_update.go +++ b/examples/traversal/ent/user_update.go @@ -8,7 +8,6 @@ package ent import ( "context" - "fmt" "github.com/facebookincubator/ent/dialect/sql" "github.com/facebookincubator/ent/dialect/sql/sqlgraph" @@ -865,33 +864,8 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) { spec.Edges.Add = append(spec.Edges.Add, edge) } u = &User{config: uuo.config} - spec.ScanTypes = []interface{}{ - &sql.NullInt64{}, - &sql.NullInt64{}, - &sql.NullString{}, - } - spec.Assign = func(values ...interface{}) error { - if m, n := len(values), len(spec.ScanTypes); m != n { - return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) - } - value, ok := values[0].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) - } - u.ID = int(value.Int64) - values = values[1:] - if value, ok := values[0].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field age", values[0]) - } else if value.Valid { - u.Age = int(value.Int64) - } - if value, ok := values[1].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field name", values[1]) - } else if value.Valid { - u.Name = value.String - } - return nil - } + spec.Assign = u.assignValues + spec.ScanValues = u.scanValues() if err = sqlgraph.UpdateNode(ctx, uuo.driver, spec); err != nil { if cerr, ok := isSQLConstraintError(err); ok { err = cerr