diff --git a/entc/cmd/entc/entc.go b/entc/cmd/entc/entc.go index 4f58bdacb..46c664f47 100644 --- a/entc/cmd/entc/entc.go +++ b/entc/cmd/entc/entc.go @@ -77,10 +77,11 @@ func main() { }, func() *cobra.Command { var ( - cfg gen.Config - storage []string - idtype = idType(field.TypeInt) - cmd = &cobra.Command{ + cfg gen.Config + storage []string + template []string + idtype = idType(field.TypeInt) + cmd = &cobra.Command{ Use: "generate [flags] path", Short: "generate go code for the schema directory", Example: examples( @@ -99,6 +100,9 @@ func main() { failOnErr(err) cfg.Storage = append(cfg.Storage, sr) } + if len(template) > 0 { + cfg.Template = loadTemplate(template) + } cfg.IDType = field.Type(idtype) graph, err := loadGraph(path[0], cfg) failOnErr(err) @@ -109,6 +113,7 @@ func main() { cmd.Flags().Var(&idtype, "idtype", "type of the id field") cmd.Flags().StringVar(&cfg.Header, "header", "", "override codegen header") cmd.Flags().StringVar(&cfg.Target, "target", "", "target directory for codegen") + cmd.Flags().StringSliceVarP(&template, "template", "", nil, "external templates to execute") cmd.Flags().StringSliceVarP(&storage, "storage", "", []string{"sql"}, "list of storage drivers to support") return cmd }(), @@ -131,6 +136,34 @@ func loadGraph(path string, cfg gen.Config) (*gen.Graph, error) { return gen.NewGraph(cfg, spec.Schemas...) } +// loadTemplate loads templates from files or directory. +func loadTemplate(paths []string) *template.Template { + t := template.New("external"). + Funcs(gen.Funcs) + for _, path := range paths { + info, err := os.Stat(path) + failOnErr(err) + if !info.IsDir() { + buf, err := ioutil.ReadFile(path) + failOnErr(err) + t, err = t.Parse(string(buf)) + failOnErr(err) + continue + } + infos, err := ioutil.ReadDir(path) + failOnErr(err) + paths := make([]string, len(infos)) + for i := range infos { + paths[i] = filepath.Join(path, infos[0].Name()) + } + for _, tt := range loadTemplate(paths).Templates() { + t, err = t.AddParseTree(tt.Name(), tt.Tree) + failOnErr(err) + } + } + return t +} + // schema template for the "init" command. var tmpl = template.Must(template.New("schema"). Parse(`package schema diff --git a/entc/gen/func.go b/entc/gen/func.go index 9a5d43b2b..58d25aa9b 100644 --- a/entc/gen/func.go +++ b/entc/gen/func.go @@ -21,9 +21,9 @@ import ( ) var ( - rules = ruleset() - acronym = make(map[string]bool) - funcs = template.FuncMap{ + // Funcs are the predefined template + // functions used by the codegen. + Funcs = template.FuncMap{ "ops": ops, "add": add, "append": reflect.AppendSlice, @@ -49,6 +49,8 @@ var ( "xtemplate": xtemplate, "hasTemplate": hasTemplate, } + rules = ruleset() + acronym = make(map[string]bool) ) // ops returns all operations for given field. diff --git a/entc/gen/graph.go b/entc/gen/graph.go index 08ee0e9cc..de8b1f650 100644 --- a/entc/gen/graph.go +++ b/entc/gen/graph.go @@ -14,6 +14,7 @@ import ( "os/exec" "path/filepath" "text/template" + "text/template/parse" "github.com/facebookincubator/ent/dialect/sql/schema" "github.com/facebookincubator/ent/entc/load" @@ -77,21 +78,7 @@ func NewGraph(c Config, schemas ...*load.Schema) (g *Graph, err error) { // Gen generates the artifacts for the graph. func (g *Graph) Gen() (err error) { defer catch(&err) - var ( - external []GraphTemplate - templates = template.Must(templates.Clone()) - ) - if g.Template != nil { - for _, tmpl := range g.Template.Templates() { - if name := tmpl.Name(); templates.Lookup(name) == nil { - external = append(external, GraphTemplate{ - Name: name, - Format: snake(name) + ".go", - }) - } - templates = template.Must(templates.AddParseTree(tmpl.Name(), tmpl.Tree)) - } - } + templates, external := g.templates() for _, n := range g.Nodes { path := filepath.Join(g.Config.Target, n.Package()) check(os.MkdirAll(path, os.ModePerm), "create dir %q", path) @@ -394,6 +381,29 @@ func (g *Graph) typ(name string) (*Type, bool) { return nil, false } +// templates returns the template.Template for the code and external templates +// to execute on the Graph object if provided. +func (g *Graph) templates() (*template.Template, []GraphTemplate) { + templates = template.Must(templates.Clone()) + if g.Template == nil { + return templates, nil + } + external := make([]GraphTemplate, 0) + for _, tmpl := range g.Template.Templates() { + name := tmpl.Name() + // check that is not defined in the default templates + // it's not the root. + if templates.Lookup(name) == nil && !parse.IsEmptyTree(tmpl.Root) { + external = append(external, GraphTemplate{ + Name: name, + Format: snake(name) + ".go", + }) + } + templates = template.Must(templates.AddParseTree(name, tmpl.Tree)) + } + return templates, external +} + // expect panic if the condition is false. func expect(cond bool, msg string, args ...interface{}) { if !cond { diff --git a/entc/gen/template.go b/entc/gen/template.go index 2d3796c94..39b083edc 100644 --- a/entc/gen/template.go +++ b/entc/gen/template.go @@ -118,7 +118,7 @@ var ( ) func init() { - templates.Funcs(funcs) + templates.Funcs(Funcs) for _, asset := range internal.AssetNames() { templates = template.Must(templates.Parse(string(internal.MustAsset(asset)))) } diff --git a/entc/integration/generate.go b/entc/integration/generate.go index 4a8d574ee..bfcd7d03a 100644 --- a/entc/integration/generate.go +++ b/entc/integration/generate.go @@ -7,3 +7,4 @@ package integration //go:generate go run ../cmd/entc/entc.go generate --storage=sql,gremlin --idtype string --header "Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated (@generated) by entc, DO NOT EDIT." ./ent/schema //go:generate go run ../cmd/entc/entc.go generate --header "Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated (@generated) by entc, DO NOT EDIT." ./migrate/entv1/schema //go:generate go run ../cmd/entc/entc.go generate --header "Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated (@generated) by entc, DO NOT EDIT." ./migrate/entv2/schema +//go:generate go run ../cmd/entc/entc.go generate --header "Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated (@generated) by entc, DO NOT EDIT." ./template/ent/schema --template=template/ent/template diff --git a/entc/integration/template/ent/client.go b/entc/integration/template/ent/client.go new file mode 100644 index 000000000..5da4516fb --- /dev/null +++ b/entc/integration/template/ent/client.go @@ -0,0 +1,280 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "log" + + "github.com/facebookincubator/ent/entc/integration/template/ent/migrate" + + "github.com/facebookincubator/ent/entc/integration/template/ent/group" + "github.com/facebookincubator/ent/entc/integration/template/ent/pet" + "github.com/facebookincubator/ent/entc/integration/template/ent/user" + + "github.com/facebookincubator/ent/dialect" + "github.com/facebookincubator/ent/dialect/sql" +) + +// Client is the client that holds all ent builders. +type Client struct { + config + // Schema is the client for creating, migrating and dropping schema. + Schema *migrate.Schema + // Group is the client for interacting with the Group builders. + Group *GroupClient + // Pet is the client for interacting with the Pet builders. + Pet *PetClient + // User is the client for interacting with the User builders. + User *UserClient +} + +// NewClient creates a new client configured with the given options. +func NewClient(opts ...Option) *Client { + c := config{log: log.Println} + c.options(opts...) + return &Client{ + config: c, + Schema: migrate.NewSchema(c.driver), + Group: NewGroupClient(c), + Pet: NewPetClient(c), + User: NewUserClient(c), + } +} + +// Tx returns a new transactional client. +func (c *Client) Tx(ctx context.Context) (*Tx, error) { + if _, ok := c.driver.(*txDriver); ok { + return nil, fmt.Errorf("ent: cannot start a transaction within a transaction") + } + tx, err := newTx(ctx, c.driver) + if err != nil { + return nil, fmt.Errorf("ent: starting a transaction: %v", err) + } + cfg := config{driver: tx, log: c.log, debug: c.debug} + return &Tx{ + config: cfg, + Group: NewGroupClient(cfg), + Pet: NewPetClient(cfg), + User: NewUserClient(cfg), + }, nil +} + +// Debug returns a new debug-client. It's used to get verbose logging on specific operations. +// +// client.Debug(). +// Group. +// Query(). +// Count(ctx) +// +func (c *Client) Debug() *Client { + if c.debug { + return c + } + cfg := config{driver: dialect.Debug(c.driver, c.log), log: c.log, debug: true} + return &Client{ + config: cfg, + Schema: migrate.NewSchema(cfg.driver), + Group: NewGroupClient(cfg), + Pet: NewPetClient(cfg), + User: NewUserClient(cfg), + } +} + +// GroupClient is a client for the Group schema. +type GroupClient struct { + config +} + +// NewGroupClient returns a client for the Group from the given config. +func NewGroupClient(c config) *GroupClient { + return &GroupClient{config: c} +} + +// Create returns a create builder for Group. +func (c *GroupClient) Create() *GroupCreate { + return &GroupCreate{config: c.config} +} + +// Update returns an update builder for Group. +func (c *GroupClient) Update() *GroupUpdate { + return &GroupUpdate{config: c.config} +} + +// UpdateOne returns an update builder for the given entity. +func (c *GroupClient) UpdateOne(gr *Group) *GroupUpdateOne { + return c.UpdateOneID(gr.ID) +} + +// UpdateOneID returns an update builder for the given id. +func (c *GroupClient) UpdateOneID(id int) *GroupUpdateOne { + return &GroupUpdateOne{config: c.config, id: id} +} + +// Delete returns a delete builder for Group. +func (c *GroupClient) Delete() *GroupDelete { + return &GroupDelete{config: c.config} +} + +// DeleteOne returns a delete builder for the given entity. +func (c *GroupClient) DeleteOne(gr *Group) *GroupDeleteOne { + return c.DeleteOneID(gr.ID) +} + +// DeleteOneID returns a delete builder for the given id. +func (c *GroupClient) DeleteOneID(id int) *GroupDeleteOne { + return &GroupDeleteOne{c.Delete().Where(group.ID(id))} +} + +// Create returns a query builder for Group. +func (c *GroupClient) Query() *GroupQuery { + return &GroupQuery{config: c.config} +} + +// PetClient is a client for the Pet schema. +type PetClient struct { + config +} + +// NewPetClient returns a client for the Pet from the given config. +func NewPetClient(c config) *PetClient { + return &PetClient{config: c} +} + +// Create returns a create builder for Pet. +func (c *PetClient) Create() *PetCreate { + return &PetCreate{config: c.config} +} + +// Update returns an update builder for Pet. +func (c *PetClient) Update() *PetUpdate { + return &PetUpdate{config: c.config} +} + +// UpdateOne returns an update builder for the given entity. +func (c *PetClient) UpdateOne(pe *Pet) *PetUpdateOne { + return c.UpdateOneID(pe.ID) +} + +// UpdateOneID returns an update builder for the given id. +func (c *PetClient) UpdateOneID(id int) *PetUpdateOne { + return &PetUpdateOne{config: c.config, id: id} +} + +// Delete returns a delete builder for Pet. +func (c *PetClient) Delete() *PetDelete { + return &PetDelete{config: c.config} +} + +// DeleteOne returns a delete builder for the given entity. +func (c *PetClient) DeleteOne(pe *Pet) *PetDeleteOne { + return c.DeleteOneID(pe.ID) +} + +// DeleteOneID returns a delete builder for the given id. +func (c *PetClient) DeleteOneID(id int) *PetDeleteOne { + return &PetDeleteOne{c.Delete().Where(pet.ID(id))} +} + +// Create returns a query builder for Pet. +func (c *PetClient) Query() *PetQuery { + return &PetQuery{config: c.config} +} + +// QueryOwner queries the owner edge of a Pet. +func (c *PetClient) QueryOwner(pe *Pet) *UserQuery { + query := &UserQuery{config: c.config} + id := pe.ID + t1 := sql.Table(user.Table) + t2 := sql.Select(pet.OwnerColumn). + From(sql.Table(pet.OwnerTable)). + Where(sql.EQ(pet.FieldID, id)) + query.sql = sql.Select().From(t1).Join(t2).On(t1.C(user.FieldID), t2.C(pet.OwnerColumn)) + + return query +} + +// UserClient is a client for the User schema. +type UserClient struct { + config +} + +// NewUserClient returns a client for the User from the given config. +func NewUserClient(c config) *UserClient { + return &UserClient{config: c} +} + +// Create returns a create builder for User. +func (c *UserClient) Create() *UserCreate { + return &UserCreate{config: c.config} +} + +// Update returns an update builder for User. +func (c *UserClient) Update() *UserUpdate { + return &UserUpdate{config: c.config} +} + +// UpdateOne returns an update builder for the given entity. +func (c *UserClient) UpdateOne(u *User) *UserUpdateOne { + return c.UpdateOneID(u.ID) +} + +// UpdateOneID returns an update builder for the given id. +func (c *UserClient) UpdateOneID(id int) *UserUpdateOne { + return &UserUpdateOne{config: c.config, id: id} +} + +// Delete returns a delete builder for User. +func (c *UserClient) Delete() *UserDelete { + return &UserDelete{config: c.config} +} + +// DeleteOne returns a delete builder for the given entity. +func (c *UserClient) DeleteOne(u *User) *UserDeleteOne { + return c.DeleteOneID(u.ID) +} + +// DeleteOneID returns a delete builder for the given id. +func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { + return &UserDeleteOne{c.Delete().Where(user.ID(id))} +} + +// Create returns a query builder for User. +func (c *UserClient) Query() *UserQuery { + return &UserQuery{config: c.config} +} + +// QueryPets queries the pets edge of a User. +func (c *UserClient) QueryPets(u *User) *PetQuery { + query := &PetQuery{config: c.config} + id := u.ID + query.sql = sql.Select().From(sql.Table(pet.Table)). + Where(sql.EQ(user.PetsColumn, id)) + + return query +} + +// QueryFriends queries the friends edge of a User. +func (c *UserClient) QueryFriends(u *User) *UserQuery { + query := &UserQuery{config: c.config} + id := u.ID + t1 := sql.Table(user.Table) + t2 := sql.Table(user.Table) + t3 := sql.Table(user.FriendsTable) + t4 := sql.Select(t3.C(user.FriendsPrimaryKey[1])). + From(t3). + Join(t2). + On(t3.C(user.FriendsPrimaryKey[0]), t2.C(user.FieldID)). + Where(sql.EQ(t2.C(user.FieldID), id)) + query.sql = sql.Select(). + From(t1). + Join(t4). + On(t1.C(user.FieldID), t4.C(user.FriendsPrimaryKey[1])) + + return query +} diff --git a/entc/integration/template/ent/config.go b/entc/integration/template/ent/config.go new file mode 100644 index 000000000..1a5d585d1 --- /dev/null +++ b/entc/integration/template/ent/config.go @@ -0,0 +1,55 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "github.com/facebookincubator/ent/dialect" +) + +// Option function to configure the client. +type Option func(*config) + +// Config is the configuration for the client and its builder. +type config struct { + // driver used for executing database requests. + driver dialect.Driver + // debug enable a debug logging. + debug bool + // log used for logging on debug mode. + log func(...interface{}) +} + +// Options applies the options on the config object. +func (c *config) options(opts ...Option) { + for _, opt := range opts { + opt(c) + } + if c.debug { + c.driver = dialect.Debug(c.driver, c.log) + } +} + +// Debug enables debug logging on the ent.Driver. +func Debug() Option { + return func(c *config) { + c.debug = true + } +} + +// Log sets the logging function for debug mode. +func Log(fn func(...interface{})) Option { + return func(c *config) { + c.log = fn + } +} + +// Driver configures the client driver. +func Driver(driver dialect.Driver) Option { + return func(c *config) { + c.driver = driver + } +} diff --git a/entc/integration/template/ent/context.go b/entc/integration/template/ent/context.go new file mode 100644 index 000000000..c6a16805e --- /dev/null +++ b/entc/integration/template/ent/context.go @@ -0,0 +1,24 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" +) + +type contextKey struct{} + +// FromContext returns the Client stored in a context, or nil if there isn't one. +func FromContext(ctx context.Context) *Client { + c, _ := ctx.Value(contextKey{}).(*Client) + return c +} + +// NewContext returns a new context with the given Client attached. +func NewContext(parent context.Context, c *Client) context.Context { + return context.WithValue(parent, contextKey{}, c) +} diff --git a/entc/integration/template/ent/ent.go b/entc/integration/template/ent/ent.go new file mode 100644 index 000000000..bb27d36a4 --- /dev/null +++ b/entc/integration/template/ent/ent.go @@ -0,0 +1,196 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + + "github.com/facebookincubator/ent/dialect" + "github.com/facebookincubator/ent/dialect/sql" +) + +// Order applies an ordering on either graph traversal or sql selector. +type Order func(*sql.Selector) + +// Asc applies the given fields in ASC order. +func Asc(fields ...string) Order { + return Order( + func(s *sql.Selector) { + for _, f := range fields { + s.OrderBy(sql.Asc(f)) + } + }, + ) +} + +// Desc applies the given fields in DESC order. +func Desc(fields ...string) Order { + return Order( + func(s *sql.Selector) { + for _, f := range fields { + s.OrderBy(sql.Desc(f)) + } + }, + ) +} + +// Aggregate applies an aggregation step on the group-by traversal/selector. +type Aggregate struct { + // SQL the column wrapped with the aggregation function. + SQL func(*sql.Selector) string +} + +// As is a pseudo aggregation function for renaming another other functions with custom names. For example: +// +// GroupBy(field1, field2). +// Aggregate(ent.As(ent.Sum(field1), "sum_field1"), (ent.As(ent.Sum(field2), "sum_field2")). +// Scan(ctx, &v) +// +func As(fn Aggregate, end string) Aggregate { + return Aggregate{ + SQL: func(s *sql.Selector) string { + return sql.As(fn.SQL(s), end) + }, + } +} + +// Count applies the "count" aggregation function on each group. +func Count() Aggregate { + return Aggregate{ + SQL: func(s *sql.Selector) string { + return sql.Count("*") + }, + } +} + +// Max applies the "max" aggregation function on the given field of each group. +func Max(field string) Aggregate { + return Aggregate{ + SQL: func(s *sql.Selector) string { + return sql.Max(s.C(field)) + }, + } +} + +// Mean applies the "mean" aggregation function on the given field of each group. +func Mean(field string) Aggregate { + return Aggregate{ + SQL: func(s *sql.Selector) string { + return sql.Avg(s.C(field)) + }, + } +} + +// Min applies the "min" aggregation function on the given field of each group. +func Min(field string) Aggregate { + return Aggregate{ + SQL: func(s *sql.Selector) string { + return sql.Min(s.C(field)) + }, + } +} + +// Sum applies the "sum" aggregation function on the given field of each group. +func Sum(field string) Aggregate { + return Aggregate{ + SQL: func(s *sql.Selector) string { + return sql.Sum(s.C(field)) + }, + } +} + +// ErrNotFound returns when trying to fetch a specific entity and it was not found in the database. +type ErrNotFound struct { + label string +} + +// Error implements the error interface. +func (e *ErrNotFound) Error() string { + return fmt.Sprintf("ent: %s not found", e.label) +} + +// IsNotFound returns a boolean indicating whether the error is a not found error. +func IsNotFound(err error) bool { + _, ok := err.(*ErrNotFound) + return ok +} + +// MaskNotFound masks nor found error. +func MaskNotFound(err error) error { + if IsNotFound(err) { + return nil + } + return err +} + +// ErrNotSingular returns when trying to fetch a singular entity and more then one was found in the database. +type ErrNotSingular struct { + label string +} + +// Error implements the error interface. +func (e *ErrNotSingular) Error() string { + return fmt.Sprintf("ent: %s not singular", e.label) +} + +// IsNotSingular returns a boolean indicating whether the error is a not singular error. +func IsNotSingular(err error) bool { + _, ok := err.(*ErrNotSingular) + return ok +} + +// ErrConstraintFailed returns when trying to create/update one or more entities and +// one or more of their constraints failed. For example, violation of edge or field uniqueness. +type ErrConstraintFailed struct { + msg string + wrap error +} + +// Error implements the error interface. +func (e ErrConstraintFailed) Error() string { + return fmt.Sprintf("ent: unique constraint failed: %s", e.msg) +} + +// Unwrap implements the errors.Wrapper interface. +func (e *ErrConstraintFailed) Unwrap() error { + return e.wrap +} + +// IsConstraintFailure returns a boolean indicating whether the error is a constraint failure. +func IsConstraintFailure(err error) bool { + _, ok := err.(*ErrConstraintFailed) + return ok +} + +func isSQLConstraintError(err error) (*ErrConstraintFailed, bool) { + // Error number 1062 is ER_DUP_ENTRY in mysql, and "UNIQUE constraint failed" is SQLite prefix. + if msg := err.Error(); strings.HasPrefix(msg, "Error 1062") || strings.HasPrefix(msg, "UNIQUE constraint failed") { + return &ErrConstraintFailed{msg, err}, true + } + return nil, false +} + +// rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred. +func rollback(tx dialect.Tx, err error) error { + if rerr := tx.Rollback(); rerr != nil { + err = fmt.Errorf("%s: %v", err.Error(), rerr) + } + if err, ok := isSQLConstraintError(err); ok { + return err + } + return err +} + +// keys returns the keys/ids from the edge map. +func keys(m map[int]struct{}) []int { + s := make([]int, 0, len(m)) + for id, _ := range m { + s = append(s, id) + } + return s +} diff --git a/entc/integration/template/ent/example_test.go b/entc/integration/template/ent/example_test.go new file mode 100644 index 000000000..84c2c26ab --- /dev/null +++ b/entc/integration/template/ent/example_test.go @@ -0,0 +1,116 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + "log" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// dsn for the database. In order to run the tests locally, run the following command: +// +// ENT_INTEGRATION_ENDPOINT="root:pass@tcp(localhost:3306)/test?parseTime=True" go test -v +// +var dsn string + +func ExampleGroup() { + if dsn == "" { + return + } + ctx := context.Background() + drv, err := sql.Open("mysql", dsn) + if err != nil { + log.Fatalf("failed creating database client: %v", err) + } + defer drv.Close() + client := NewClient(Driver(drv)) + // creating vertices for the group's edges. + + // create group vertex with its edges. + gr := client.Group. + Create(). + SetMaxUsers(1). + SaveX(ctx) + log.Println("group created:", gr) + + // query edges. + + // Output: +} +func ExamplePet() { + if dsn == "" { + return + } + ctx := context.Background() + drv, err := sql.Open("mysql", dsn) + if err != nil { + log.Fatalf("failed creating database client: %v", err) + } + defer drv.Close() + client := NewClient(Driver(drv)) + // creating vertices for the pet's edges. + + // create pet vertex with its edges. + pe := client.Pet. + Create(). + SetAge(1). + SaveX(ctx) + log.Println("pet created:", pe) + + // query edges. + + // Output: +} +func ExampleUser() { + if dsn == "" { + return + } + ctx := context.Background() + drv, err := sql.Open("mysql", dsn) + if err != nil { + log.Fatalf("failed creating database client: %v", err) + } + defer drv.Close() + client := NewClient(Driver(drv)) + // creating vertices for the user's edges. + pe0 := client.Pet. + Create(). + SetAge(1). + SaveX(ctx) + log.Println("pet created:", pe0) + u1 := client.User. + Create(). + SetName("string"). + SaveX(ctx) + log.Println("user created:", u1) + + // create user vertex with its edges. + u := client.User. + Create(). + SetName("string"). + AddPets(pe0). + AddFriends(u1). + SaveX(ctx) + log.Println("user created:", u) + + // query edges. + pe0, err = u.QueryPets().First(ctx) + if err != nil { + log.Fatalf("failed querying pets: %v", err) + } + log.Println("pets found:", pe0) + + u1, err = u.QueryFriends().First(ctx) + if err != nil { + log.Fatalf("failed querying friends: %v", err) + } + log.Println("friends found:", u1) + + // Output: +} diff --git a/entc/integration/template/ent/group.go b/entc/integration/template/ent/group.go new file mode 100644 index 000000000..198abf41c --- /dev/null +++ b/entc/integration/template/ent/group.go @@ -0,0 +1,90 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "bytes" + "fmt" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// Group is the model entity for the Group schema. +type Group struct { + config + // ID of the ent. + ID int `json:"id,omitempty"` + // MaxUsers holds the value of the "max_users" field. + MaxUsers int `json:"max_users,omitempty"` +} + +// FromRows scans the sql response data into Group. +func (gr *Group) FromRows(rows *sql.Rows) error { + var vgr struct { + ID int + MaxUsers sql.NullInt64 + } + // the order here should be the same as in the `group.Columns`. + if err := rows.Scan( + &vgr.ID, + &vgr.MaxUsers, + ); err != nil { + return err + } + gr.ID = vgr.ID + gr.MaxUsers = int(vgr.MaxUsers.Int64) + return nil +} + +// Update returns a builder for updating this Group. +// Note that, you need to call Group.Unwrap() before calling this method, if this Group +// was returned from a transaction, and the transaction was committed or rolled back. +func (gr *Group) Update() *GroupUpdateOne { + return (&GroupClient{gr.config}).UpdateOne(gr) +} + +// Unwrap unwraps the entity that was returned from a transaction after it was closed, +// so that all next queries will be executed through the driver which created the transaction. +func (gr *Group) Unwrap() *Group { + tx, ok := gr.config.driver.(*txDriver) + if !ok { + panic("ent: Group is not a transactional entity") + } + gr.config.driver = tx.drv + return gr +} + +// String implements the fmt.Stringer. +func (gr *Group) String() string { + buf := bytes.NewBuffer(nil) + buf.WriteString("Group(") + buf.WriteString(fmt.Sprintf("id=%v", gr.ID)) + buf.WriteString(fmt.Sprintf(", max_users=%v", gr.MaxUsers)) + buf.WriteString(")") + return buf.String() +} + +// Groups is a parsable slice of Group. +type Groups []*Group + +// FromRows scans the sql response data into Groups. +func (gr *Groups) FromRows(rows *sql.Rows) error { + for rows.Next() { + vgr := &Group{} + if err := vgr.FromRows(rows); err != nil { + return err + } + *gr = append(*gr, vgr) + } + return nil +} + +func (gr Groups) config(cfg config) { + for i := range gr { + gr[i].config = cfg + } +} diff --git a/entc/integration/template/ent/group/group.go b/entc/integration/template/ent/group/group.go new file mode 100644 index 000000000..0835cd132 --- /dev/null +++ b/entc/integration/template/ent/group/group.go @@ -0,0 +1,25 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package group + +const ( + // Label holds the string label denoting the group type in the database. + Label = "group" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldMaxUsers holds the string denoting the max_users vertex property in the database. + FieldMaxUsers = "max_users" + + // Table holds the table name of the group in the database. + Table = "groups" +) + +// Columns holds all SQL columns are group fields. +var Columns = []string{ + FieldID, + FieldMaxUsers, +} diff --git a/entc/integration/template/ent/group/where.go b/entc/integration/template/ent/group/where.go new file mode 100644 index 000000000..11dc15464 --- /dev/null +++ b/entc/integration/template/ent/group/where.go @@ -0,0 +1,249 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package group + +import ( + "github.com/facebookincubator/ent/entc/integration/template/ent/predicate" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// ID filters vertices based on their identifier. +func ID(id int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldID), id)) + }, + ) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldID), id)) + }, + ) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldID), id)) + }, + ) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldID), id)) + }, + ) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldID), id)) + }, + ) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldID), id)) + }, + ) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldID), id)) + }, + ) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(ids) == 0 { + s.Where(sql.False()) + return + } + v := make([]interface{}, len(ids)) + for i := range v { + v[i] = ids[i] + } + s.Where(sql.In(s.C(FieldID), v...)) + }, + ) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(ids) == 0 { + s.Where(sql.False()) + return + } + v := make([]interface{}, len(ids)) + for i := range v { + v[i] = ids[i] + } + s.Where(sql.NotIn(s.C(FieldID), v...)) + }, + ) +} + +// MaxUsers applies equality check predicate on the "max_users" field. It's identical to MaxUsersEQ. +func MaxUsers(v int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldMaxUsers), v)) + }, + ) +} + +// MaxUsersEQ applies the EQ predicate on the "max_users" field. +func MaxUsersEQ(v int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldMaxUsers), v)) + }, + ) +} + +// MaxUsersNEQ applies the NEQ predicate on the "max_users" field. +func MaxUsersNEQ(v int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldMaxUsers), v)) + }, + ) +} + +// MaxUsersGT applies the GT predicate on the "max_users" field. +func MaxUsersGT(v int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldMaxUsers), v)) + }, + ) +} + +// MaxUsersGTE applies the GTE predicate on the "max_users" field. +func MaxUsersGTE(v int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldMaxUsers), v)) + }, + ) +} + +// MaxUsersLT applies the LT predicate on the "max_users" field. +func MaxUsersLT(v int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldMaxUsers), v)) + }, + ) +} + +// MaxUsersLTE applies the LTE predicate on the "max_users" field. +func MaxUsersLTE(v int) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldMaxUsers), v)) + }, + ) +} + +// MaxUsersIn applies the In predicate on the "max_users" field. +func MaxUsersIn(vs ...int) predicate.Group { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Group( + func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(vs) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldMaxUsers), v...)) + }, + ) +} + +// MaxUsersNotIn applies the NotIn predicate on the "max_users" field. +func MaxUsersNotIn(vs ...int) predicate.Group { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Group( + func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(vs) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldMaxUsers), v...)) + }, + ) +} + +// And groups list of predicates with the AND operator between them. +func And(predicates ...predicate.Group) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + for _, p := range predicates { + p(s) + } + }, + ) +} + +// Or groups list of predicates with the OR operator between them. +func Or(predicates ...predicate.Group) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + for i, p := range predicates { + if i > 0 { + s.Or() + } + p(s) + } + }, + ) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Group) predicate.Group { + return predicate.Group( + func(s *sql.Selector) { + p(s.Not()) + }, + ) +} diff --git a/entc/integration/template/ent/group_create.go b/entc/integration/template/ent/group_create.go new file mode 100644 index 000000000..49539e925 --- /dev/null +++ b/entc/integration/template/ent/group_create.go @@ -0,0 +1,74 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + + "github.com/facebookincubator/ent/entc/integration/template/ent/group" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// GroupCreate is the builder for creating a Group entity. +type GroupCreate struct { + config + max_users *int +} + +// SetMaxUsers sets the max_users field. +func (gc *GroupCreate) SetMaxUsers(i int) *GroupCreate { + gc.max_users = &i + return gc +} + +// Save creates the Group in the database. +func (gc *GroupCreate) Save(ctx context.Context) (*Group, error) { + if gc.max_users == nil { + return nil, errors.New("ent: missing required field \"max_users\"") + } + return gc.sqlSave(ctx) +} + +// SaveX calls Save and panics if Save returns an error. +func (gc *GroupCreate) SaveX(ctx context.Context) *Group { + v, err := gc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +func (gc *GroupCreate) sqlSave(ctx context.Context) (*Group, error) { + var ( + res sql.Result + gr = &Group{config: gc.config} + ) + tx, err := gc.driver.Tx(ctx) + if err != nil { + return nil, err + } + builder := sql.Insert(group.Table).Default(gc.driver.Dialect()) + if gc.max_users != nil { + builder.Set(group.FieldMaxUsers, *gc.max_users) + gr.MaxUsers = *gc.max_users + } + query, args := builder.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + id, err := res.LastInsertId() + if err != nil { + return nil, rollback(tx, err) + } + gr.ID = int(id) + if err := tx.Commit(); err != nil { + return nil, err + } + return gr, nil +} diff --git a/entc/integration/template/ent/group_delete.go b/entc/integration/template/ent/group_delete.go new file mode 100644 index 000000000..d11908ea6 --- /dev/null +++ b/entc/integration/template/ent/group_delete.go @@ -0,0 +1,65 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + + "github.com/facebookincubator/ent/entc/integration/template/ent/group" + "github.com/facebookincubator/ent/entc/integration/template/ent/predicate" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// GroupDelete is the builder for deleting a Group entity. +type GroupDelete struct { + config + predicates []predicate.Group +} + +// Where adds a new predicate for the builder. +func (gd *GroupDelete) Where(ps ...predicate.Group) *GroupDelete { + gd.predicates = append(gd.predicates, ps...) + return gd +} + +// Exec executes the deletion query. +func (gd *GroupDelete) Exec(ctx context.Context) error { + return gd.sqlExec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (gd *GroupDelete) ExecX(ctx context.Context) { + if err := gd.Exec(ctx); err != nil { + panic(err) + } +} + +func (gd *GroupDelete) sqlExec(ctx context.Context) error { + var res sql.Result + selector := sql.Select().From(sql.Table(group.Table)) + for _, p := range gd.predicates { + p(selector) + } + query, args := sql.Delete(group.Table).FromSelect(selector).Query() + return gd.driver.Exec(ctx, query, args, &res) +} + +// GroupDeleteOne is the builder for deleting a single Group entity. +type GroupDeleteOne struct { + gd *GroupDelete +} + +// Exec executes the deletion query. +func (gdo *GroupDeleteOne) Exec(ctx context.Context) error { + return gdo.gd.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (gdo *GroupDeleteOne) ExecX(ctx context.Context) { + gdo.gd.ExecX(ctx) +} diff --git a/entc/integration/template/ent/group_query.go b/entc/integration/template/ent/group_query.go new file mode 100644 index 000000000..967a4cc7a --- /dev/null +++ b/entc/integration/template/ent/group_query.go @@ -0,0 +1,611 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "math" + + "github.com/facebookincubator/ent/entc/integration/template/ent/group" + "github.com/facebookincubator/ent/entc/integration/template/ent/predicate" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// GroupQuery is the builder for querying Group entities. +type GroupQuery struct { + config + limit *int + offset *int + order []Order + unique []string + predicates []predicate.Group + // intermediate queries. + sql *sql.Selector +} + +// Where adds a new predicate for the builder. +func (gq *GroupQuery) Where(ps ...predicate.Group) *GroupQuery { + gq.predicates = append(gq.predicates, ps...) + return gq +} + +// Limit adds a limit step to the query. +func (gq *GroupQuery) Limit(limit int) *GroupQuery { + gq.limit = &limit + return gq +} + +// Offset adds an offset step to the query. +func (gq *GroupQuery) Offset(offset int) *GroupQuery { + gq.offset = &offset + return gq +} + +// Order adds an order step to the query. +func (gq *GroupQuery) Order(o ...Order) *GroupQuery { + gq.order = append(gq.order, o...) + return gq +} + +// Get returns a Group entity by its id. +func (gq *GroupQuery) Get(ctx context.Context, id int) (*Group, error) { + return gq.Where(group.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (gq *GroupQuery) GetX(ctx context.Context, id int) *Group { + gr, err := gq.Get(ctx, id) + if err != nil { + panic(err) + } + return gr +} + +// First returns the first Group entity in the query. Returns *ErrNotFound when no group was found. +func (gq *GroupQuery) First(ctx context.Context) (*Group, error) { + grs, err := gq.Limit(1).All(ctx) + if err != nil { + return nil, err + } + if len(grs) == 0 { + return nil, &ErrNotFound{group.Label} + } + return grs[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (gq *GroupQuery) FirstX(ctx context.Context) *Group { + gr, err := gq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return gr +} + +// FirstID returns the first Group id in the query. Returns *ErrNotFound when no id was found. +func (gq *GroupQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = gq.Limit(1).IDs(ctx); err != nil { + return + } + if len(ids) == 0 { + err = &ErrNotFound{group.Label} + return + } + return ids[0], nil +} + +// FirstXID is like FirstID, but panics if an error occurs. +func (gq *GroupQuery) FirstXID(ctx context.Context) int { + id, err := gq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns the only Group entity in the query, returns an error if not exactly one entity was returned. +func (gq *GroupQuery) Only(ctx context.Context) (*Group, error) { + grs, err := gq.Limit(2).All(ctx) + if err != nil { + return nil, err + } + switch len(grs) { + case 1: + return grs[0], nil + case 0: + return nil, &ErrNotFound{group.Label} + default: + return nil, &ErrNotSingular{group.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (gq *GroupQuery) OnlyX(ctx context.Context) *Group { + gr, err := gq.Only(ctx) + if err != nil { + panic(err) + } + return gr +} + +// OnlyID returns the only Group id in the query, returns an error if not exactly one id was returned. +func (gq *GroupQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = gq.Limit(2).IDs(ctx); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &ErrNotFound{group.Label} + default: + err = &ErrNotSingular{group.Label} + } + return +} + +// OnlyXID is like OnlyID, but panics if an error occurs. +func (gq *GroupQuery) OnlyXID(ctx context.Context) int { + id, err := gq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Groups. +func (gq *GroupQuery) All(ctx context.Context) ([]*Group, error) { + return gq.sqlAll(ctx) +} + +// AllX is like All, but panics if an error occurs. +func (gq *GroupQuery) AllX(ctx context.Context) []*Group { + grs, err := gq.All(ctx) + if err != nil { + panic(err) + } + return grs +} + +// IDs executes the query and returns a list of Group ids. +func (gq *GroupQuery) IDs(ctx context.Context) ([]int, error) { + return gq.sqlIDs(ctx) +} + +// IDsX is like IDs, but panics if an error occurs. +func (gq *GroupQuery) IDsX(ctx context.Context) []int { + ids, err := gq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (gq *GroupQuery) Count(ctx context.Context) (int, error) { + return gq.sqlCount(ctx) +} + +// CountX is like Count, but panics if an error occurs. +func (gq *GroupQuery) CountX(ctx context.Context) int { + count, err := gq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (gq *GroupQuery) Exist(ctx context.Context) (bool, error) { + return gq.sqlExist(ctx) +} + +// ExistX is like Exist, but panics if an error occurs. +func (gq *GroupQuery) ExistX(ctx context.Context) bool { + exist, err := gq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the query builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (gq *GroupQuery) Clone() *GroupQuery { + return &GroupQuery{ + config: gq.config, + limit: gq.limit, + offset: gq.offset, + order: append([]Order{}, gq.order...), + unique: append([]string{}, gq.unique...), + predicates: append([]predicate.Group{}, gq.predicates...), + // clone intermediate queries. + sql: gq.sql.Clone(), + } +} + +// GroupBy used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// MaxUsers int `json:"max_users,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Group.Query(). +// GroupBy(group.FieldMaxUsers). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +// +func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { + group := &GroupGroupBy{config: gq.config} + group.fields = append([]string{field}, fields...) + group.sql = gq.sqlQuery() + return group +} + +// Select one or more fields from the given query. +// +// Example: +// +// var v []struct { +// MaxUsers int `json:"max_users,omitempty"` +// } +// +// client.Group.Query(). +// Select(group.FieldMaxUsers). +// Scan(ctx, &v) +// +func (gq *GroupQuery) Select(field string, fields ...string) *GroupSelect { + selector := &GroupSelect{config: gq.config} + selector.fields = append([]string{field}, fields...) + selector.sql = gq.sqlQuery() + return selector +} + +func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { + rows := &sql.Rows{} + selector := gq.sqlQuery() + if unique := gq.unique; len(unique) == 0 { + selector.Distinct() + } + query, args := selector.Query() + if err := gq.driver.Query(ctx, query, args, rows); err != nil { + return nil, err + } + defer rows.Close() + var grs Groups + if err := grs.FromRows(rows); err != nil { + return nil, err + } + grs.config(gq.config) + return grs, nil +} + +func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { + rows := &sql.Rows{} + selector := gq.sqlQuery() + unique := []string{group.FieldID} + if len(gq.unique) > 0 { + unique = gq.unique + } + selector.Count(sql.Distinct(selector.Columns(unique...)...)) + query, args := selector.Query() + if err := gq.driver.Query(ctx, query, args, rows); err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errors.New("ent: no rows found") + } + var n int + if err := rows.Scan(&n); err != nil { + return 0, fmt.Errorf("ent: failed reading count: %v", err) + } + return n, nil +} + +func (gq *GroupQuery) sqlExist(ctx context.Context) (bool, error) { + n, err := gq.sqlCount(ctx) + if err != nil { + return false, fmt.Errorf("ent: check existence: %v", err) + } + return n > 0, nil +} + +func (gq *GroupQuery) sqlIDs(ctx context.Context) ([]int, error) { + vs, err := gq.sqlAll(ctx) + if err != nil { + return nil, err + } + var ids []int + for _, v := range vs { + ids = append(ids, v.ID) + } + return ids, nil +} + +func (gq *GroupQuery) sqlQuery() *sql.Selector { + t1 := sql.Table(group.Table) + selector := sql.Select(t1.Columns(group.Columns...)...).From(t1) + if gq.sql != nil { + selector = gq.sql + selector.Select(selector.Columns(group.Columns...)...) + } + for _, p := range gq.predicates { + p(selector) + } + for _, p := range gq.order { + p(selector) + } + if offset := gq.offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt64) + } + if limit := gq.limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// GroupGroupBy is the builder for group-by Group entities. +type GroupGroupBy struct { + config + fields []string + fns []Aggregate + // intermediate queries. + sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (ggb *GroupGroupBy) Aggregate(fns ...Aggregate) *GroupGroupBy { + ggb.fns = append(ggb.fns, fns...) + return ggb +} + +// Scan applies the group-by query and scan the result into the given value. +func (ggb *GroupGroupBy) Scan(ctx context.Context, v interface{}) error { + return ggb.sqlScan(ctx, v) +} + +// ScanX is like Scan, but panics if an error occurs. +func (ggb *GroupGroupBy) ScanX(ctx context.Context, v interface{}) { + if err := ggb.Scan(ctx, v); err != nil { + panic(err) + } +} + +// Strings returns list of strings from group-by. It is only allowed when querying group-by with one field. +func (ggb *GroupGroupBy) Strings(ctx context.Context) ([]string, error) { + if len(ggb.fields) > 1 { + return nil, errors.New("ent: GroupGroupBy.Strings is not achievable when grouping more than 1 field") + } + var v []string + if err := ggb.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// StringsX is like Strings, but panics if an error occurs. +func (ggb *GroupGroupBy) StringsX(ctx context.Context) []string { + v, err := ggb.Strings(ctx) + if err != nil { + panic(err) + } + return v +} + +// Ints returns list of ints from group-by. It is only allowed when querying group-by with one field. +func (ggb *GroupGroupBy) Ints(ctx context.Context) ([]int, error) { + if len(ggb.fields) > 1 { + return nil, errors.New("ent: GroupGroupBy.Ints is not achievable when grouping more than 1 field") + } + var v []int + if err := ggb.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// IntsX is like Ints, but panics if an error occurs. +func (ggb *GroupGroupBy) IntsX(ctx context.Context) []int { + v, err := ggb.Ints(ctx) + if err != nil { + panic(err) + } + return v +} + +// Float64s returns list of float64s from group-by. It is only allowed when querying group-by with one field. +func (ggb *GroupGroupBy) Float64s(ctx context.Context) ([]float64, error) { + if len(ggb.fields) > 1 { + return nil, errors.New("ent: GroupGroupBy.Float64s is not achievable when grouping more than 1 field") + } + var v []float64 + if err := ggb.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// Float64sX is like Float64s, but panics if an error occurs. +func (ggb *GroupGroupBy) Float64sX(ctx context.Context) []float64 { + v, err := ggb.Float64s(ctx) + if err != nil { + panic(err) + } + return v +} + +// Bools returns list of bools from group-by. It is only allowed when querying group-by with one field. +func (ggb *GroupGroupBy) Bools(ctx context.Context) ([]bool, error) { + if len(ggb.fields) > 1 { + return nil, errors.New("ent: GroupGroupBy.Bools is not achievable when grouping more than 1 field") + } + var v []bool + if err := ggb.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// BoolsX is like Bools, but panics if an error occurs. +func (ggb *GroupGroupBy) BoolsX(ctx context.Context) []bool { + v, err := ggb.Bools(ctx) + if err != nil { + panic(err) + } + return v +} + +func (ggb *GroupGroupBy) sqlScan(ctx context.Context, v interface{}) error { + rows := &sql.Rows{} + query, args := ggb.sqlQuery().Query() + if err := ggb.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +func (ggb *GroupGroupBy) sqlQuery() *sql.Selector { + selector := ggb.sql + columns := make([]string, 0, len(ggb.fields)+len(ggb.fns)) + columns = append(columns, ggb.fields...) + for _, fn := range ggb.fns { + columns = append(columns, fn.SQL(selector)) + } + return selector.Select(columns...).GroupBy(ggb.fields...) +} + +// GroupSelect is the builder for select fields of Group entities. +type GroupSelect struct { + config + fields []string + // intermediate queries. + sql *sql.Selector +} + +// Scan applies the selector query and scan the result into the given value. +func (gs *GroupSelect) Scan(ctx context.Context, v interface{}) error { + return gs.sqlScan(ctx, v) +} + +// ScanX is like Scan, but panics if an error occurs. +func (gs *GroupSelect) ScanX(ctx context.Context, v interface{}) { + if err := gs.Scan(ctx, v); err != nil { + panic(err) + } +} + +// Strings returns list of strings from selector. It is only allowed when selecting one field. +func (gs *GroupSelect) Strings(ctx context.Context) ([]string, error) { + if len(gs.fields) > 1 { + return nil, errors.New("ent: GroupSelect.Strings is not achievable when selecting more than 1 field") + } + var v []string + if err := gs.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// StringsX is like Strings, but panics if an error occurs. +func (gs *GroupSelect) StringsX(ctx context.Context) []string { + v, err := gs.Strings(ctx) + if err != nil { + panic(err) + } + return v +} + +// Ints returns list of ints from selector. It is only allowed when selecting one field. +func (gs *GroupSelect) Ints(ctx context.Context) ([]int, error) { + if len(gs.fields) > 1 { + return nil, errors.New("ent: GroupSelect.Ints is not achievable when selecting more than 1 field") + } + var v []int + if err := gs.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// IntsX is like Ints, but panics if an error occurs. +func (gs *GroupSelect) IntsX(ctx context.Context) []int { + v, err := gs.Ints(ctx) + if err != nil { + panic(err) + } + return v +} + +// Float64s returns list of float64s from selector. It is only allowed when selecting one field. +func (gs *GroupSelect) Float64s(ctx context.Context) ([]float64, error) { + if len(gs.fields) > 1 { + return nil, errors.New("ent: GroupSelect.Float64s is not achievable when selecting more than 1 field") + } + var v []float64 + if err := gs.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// Float64sX is like Float64s, but panics if an error occurs. +func (gs *GroupSelect) Float64sX(ctx context.Context) []float64 { + v, err := gs.Float64s(ctx) + if err != nil { + panic(err) + } + return v +} + +// Bools returns list of bools from selector. It is only allowed when selecting one field. +func (gs *GroupSelect) Bools(ctx context.Context) ([]bool, error) { + if len(gs.fields) > 1 { + return nil, errors.New("ent: GroupSelect.Bools is not achievable when selecting more than 1 field") + } + var v []bool + if err := gs.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// BoolsX is like Bools, but panics if an error occurs. +func (gs *GroupSelect) BoolsX(ctx context.Context) []bool { + v, err := gs.Bools(ctx) + if err != nil { + panic(err) + } + return v +} + +func (gs *GroupSelect) sqlScan(ctx context.Context, v interface{}) error { + rows := &sql.Rows{} + query, args := gs.sqlQuery().Query() + if err := gs.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +func (gs *GroupSelect) sqlQuery() sql.Querier { + view := "group_view" + return sql.Select(gs.fields...).From(gs.sql.As(view)) +} diff --git a/entc/integration/template/ent/group_update.go b/entc/integration/template/ent/group_update.go new file mode 100644 index 000000000..00cf11d53 --- /dev/null +++ b/entc/integration/template/ent/group_update.go @@ -0,0 +1,226 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + + "github.com/facebookincubator/ent/entc/integration/template/ent/group" + "github.com/facebookincubator/ent/entc/integration/template/ent/predicate" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// GroupUpdate is the builder for updating Group entities. +type GroupUpdate struct { + config + max_users *int + addmax_users *int + predicates []predicate.Group +} + +// Where adds a new predicate for the builder. +func (gu *GroupUpdate) Where(ps ...predicate.Group) *GroupUpdate { + gu.predicates = append(gu.predicates, ps...) + return gu +} + +// SetMaxUsers sets the max_users field. +func (gu *GroupUpdate) SetMaxUsers(i int) *GroupUpdate { + gu.max_users = &i + return gu +} + +// AddMaxUsers adds i to max_users. +func (gu *GroupUpdate) AddMaxUsers(i int) *GroupUpdate { + gu.addmax_users = &i + return gu +} + +// Save executes the query and returns the number of rows/vertices matched by this operation. +func (gu *GroupUpdate) Save(ctx context.Context) (int, error) { + return gu.sqlSave(ctx) +} + +// SaveX is like Save, but panics if an error occurs. +func (gu *GroupUpdate) SaveX(ctx context.Context) int { + affected, err := gu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (gu *GroupUpdate) Exec(ctx context.Context) error { + _, err := gu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (gu *GroupUpdate) ExecX(ctx context.Context) { + if err := gu.Exec(ctx); err != nil { + panic(err) + } +} + +func (gu *GroupUpdate) sqlSave(ctx context.Context) (n int, err error) { + selector := sql.Select(group.FieldID).From(sql.Table(group.Table)) + for _, p := range gu.predicates { + p(selector) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err = gu.driver.Query(ctx, query, args, rows); err != nil { + return 0, err + } + defer rows.Close() + var ids []int + for rows.Next() { + var id int + if err := rows.Scan(&id); err != nil { + return 0, fmt.Errorf("ent: failed reading id: %v", err) + } + ids = append(ids, id) + } + if len(ids) == 0 { + return 0, nil + } + + tx, err := gu.driver.Tx(ctx) + if err != nil { + return 0, err + } + var ( + update bool + res sql.Result + builder = sql.Update(group.Table).Where(sql.InInts(group.FieldID, ids...)) + ) + if value := gu.max_users; value != nil { + update = true + builder.Set(group.FieldMaxUsers, *value) + } + if value := gu.addmax_users; value != nil { + update = true + builder.Add(group.FieldMaxUsers, *value) + } + if update { + query, args := builder.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, rollback(tx, err) + } + } + if err = tx.Commit(); err != nil { + return 0, err + } + return len(ids), nil +} + +// GroupUpdateOne is the builder for updating a single Group entity. +type GroupUpdateOne struct { + config + id int + max_users *int + addmax_users *int +} + +// SetMaxUsers sets the max_users field. +func (guo *GroupUpdateOne) SetMaxUsers(i int) *GroupUpdateOne { + guo.max_users = &i + return guo +} + +// AddMaxUsers adds i to max_users. +func (guo *GroupUpdateOne) AddMaxUsers(i int) *GroupUpdateOne { + guo.addmax_users = &i + return guo +} + +// Save executes the query and returns the updated entity. +func (guo *GroupUpdateOne) Save(ctx context.Context) (*Group, error) { + return guo.sqlSave(ctx) +} + +// SaveX is like Save, but panics if an error occurs. +func (guo *GroupUpdateOne) SaveX(ctx context.Context) *Group { + gr, err := guo.Save(ctx) + if err != nil { + panic(err) + } + return gr +} + +// Exec executes the query on the entity. +func (guo *GroupUpdateOne) Exec(ctx context.Context) error { + _, err := guo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (guo *GroupUpdateOne) ExecX(ctx context.Context) { + if err := guo.Exec(ctx); err != nil { + panic(err) + } +} + +func (guo *GroupUpdateOne) sqlSave(ctx context.Context) (gr *Group, err error) { + selector := sql.Select(group.Columns...).From(sql.Table(group.Table)) + group.ID(guo.id)(selector) + rows := &sql.Rows{} + query, args := selector.Query() + if err = guo.driver.Query(ctx, query, args, rows); err != nil { + return nil, err + } + defer rows.Close() + var ids []int + for rows.Next() { + var id int + gr = &Group{config: guo.config} + if err := gr.FromRows(rows); err != nil { + return nil, fmt.Errorf("ent: failed scanning row into Group: %v", err) + } + id = gr.ID + ids = append(ids, id) + } + switch n := len(ids); { + case n == 0: + return nil, fmt.Errorf("ent: Group not found with id: %v", guo.id) + case n > 1: + return nil, fmt.Errorf("ent: more than one Group with the same id: %v", guo.id) + } + + tx, err := guo.driver.Tx(ctx) + if err != nil { + return nil, err + } + var ( + update bool + res sql.Result + builder = sql.Update(group.Table).Where(sql.InInts(group.FieldID, ids...)) + ) + if value := guo.max_users; value != nil { + update = true + builder.Set(group.FieldMaxUsers, *value) + gr.MaxUsers = *value + } + if value := guo.addmax_users; value != nil { + update = true + builder.Add(group.FieldMaxUsers, *value) + gr.MaxUsers += *value + } + if update { + query, args := builder.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + } + if err = tx.Commit(); err != nil { + return nil, err + } + return gr, nil +} diff --git a/entc/integration/template/ent/migrate/migrate.go b/entc/integration/template/ent/migrate/migrate.go new file mode 100644 index 000000000..7991000b1 --- /dev/null +++ b/entc/integration/template/ent/migrate/migrate.go @@ -0,0 +1,71 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package migrate + +import ( + "context" + "fmt" + "io" + + "github.com/facebookincubator/ent/dialect" + "github.com/facebookincubator/ent/dialect/sql/schema" +) + +var ( + // WithGlobalUniqueID sets the universal ids options to the migration. + // If this option is enabled, ent migration will allocate a 1<<32 range + // for the ids of each entity (table). + // Note that this option cannot be applied on tables that already exist. + WithGlobalUniqueID = schema.WithGlobalUniqueID + // WithDropColumn sets the drop column option to the migration. + // If this option is enabled, ent migration will drop old columns + // that were used for both fields and edges. This defaults to false. + WithDropColumn = schema.WithDropColumn + // WithDropIndex sets the drop index option to the migration. + // If this option is enabled, ent migration will drop old indexes + // that were defined in the schema. This defaults to false. + // Note that unique constraints are defined using `UNIQUE INDEX`, + // and therefore, it's recommended to enable this option to get more + // flexibility in the schema changes. + WithDropIndex = schema.WithDropIndex +) + +// Schema is the API for creating, migrating and dropping a schema. +type Schema struct { + drv dialect.Driver + universalID bool +} + +// NewSchema creates a new schema client. +func NewSchema(drv dialect.Driver) *Schema { return &Schema{drv: drv} } + +// Create creates all schema resources. +func (s *Schema) Create(ctx context.Context, opts ...schema.MigrateOption) error { + migrate, err := schema.NewMigrate(s.drv, opts...) + if err != nil { + return fmt.Errorf("ent/migrate: %v", err) + } + return migrate.Create(ctx, Tables...) +} + +// WriteTo writes the schema changes to w instead of running them against the database. +// +// if err := client.Schema.WriteTo(context.Background(), os.Stdout); err != nil { +// log.Fatal(err) +// } +// +func (s *Schema) WriteTo(ctx context.Context, w io.Writer, opts ...schema.MigrateOption) error { + drv := &schema.WriteDriver{ + Writer: w, + Driver: s.drv, + } + migrate, err := schema.NewMigrate(drv, opts...) + if err != nil { + return fmt.Errorf("ent/migrate: %v", err) + } + return migrate.Create(ctx, Tables...) +} diff --git a/entc/integration/template/ent/migrate/schema.go b/entc/integration/template/ent/migrate/schema.go new file mode 100644 index 000000000..b0dd0bd97 --- /dev/null +++ b/entc/integration/template/ent/migrate/schema.go @@ -0,0 +1,100 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package migrate + +import ( + "github.com/facebookincubator/ent/dialect/sql/schema" + "github.com/facebookincubator/ent/schema/field" +) + +var ( + // GroupsColumns holds the columns for the "groups" table. + GroupsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "max_users", Type: field.TypeInt}, + } + // GroupsTable holds the schema information for the "groups" table. + GroupsTable = &schema.Table{ + Name: "groups", + Columns: GroupsColumns, + PrimaryKey: []*schema.Column{GroupsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{}, + } + // PetsColumns holds the columns for the "pets" table. + PetsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "age", Type: field.TypeInt}, + {Name: "owner_id", Type: field.TypeInt, Nullable: true}, + } + // PetsTable holds the schema information for the "pets" table. + PetsTable = &schema.Table{ + Name: "pets", + Columns: PetsColumns, + PrimaryKey: []*schema.Column{PetsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "pets_users_pets", + Columns: []*schema.Column{PetsColumns[2]}, + + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + } + // UsersColumns holds the columns for the "users" table. + UsersColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "name", Type: field.TypeString}, + } + // UsersTable holds the schema information for the "users" table. + UsersTable = &schema.Table{ + Name: "users", + Columns: UsersColumns, + PrimaryKey: []*schema.Column{UsersColumns[0]}, + ForeignKeys: []*schema.ForeignKey{}, + } + // UserFriendsColumns holds the columns for the "user_friends" table. + UserFriendsColumns = []*schema.Column{ + {Name: "user_id", Type: field.TypeInt}, + {Name: "friend_id", Type: field.TypeInt}, + } + // UserFriendsTable holds the schema information for the "user_friends" table. + UserFriendsTable = &schema.Table{ + Name: "user_friends", + Columns: UserFriendsColumns, + PrimaryKey: []*schema.Column{UserFriendsColumns[0], UserFriendsColumns[1]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "user_friends_user_id", + Columns: []*schema.Column{UserFriendsColumns[0]}, + + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.Cascade, + }, + { + Symbol: "user_friends_friend_id", + Columns: []*schema.Column{UserFriendsColumns[1]}, + + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.Cascade, + }, + }, + } + // Tables holds all the tables in the schema. + Tables = []*schema.Table{ + GroupsTable, + PetsTable, + UsersTable, + UserFriendsTable, + } +) + +func init() { + PetsTable.ForeignKeys[0].RefTable = UsersTable + UserFriendsTable.ForeignKeys[0].RefTable = UsersTable + UserFriendsTable.ForeignKeys[1].RefTable = UsersTable +} diff --git a/entc/integration/template/ent/node.go b/entc/integration/template/ent/node.go new file mode 100644 index 000000000..bb8d8ce5a --- /dev/null +++ b/entc/integration/template/ent/node.go @@ -0,0 +1,193 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + "encoding/json" + "sync" + + "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/dialect/sql/schema" + "github.com/facebookincubator/ent/entc/integration/template/ent/group" + "github.com/facebookincubator/ent/entc/integration/template/ent/pet" + "github.com/facebookincubator/ent/entc/integration/template/ent/user" +) + +// Noder wraps the basic Node method. +type Noder interface { + Node(context.Context) (*Node, error) +} + +// Node in the graph. +type Node struct { + ID int `json:"id,omitemty"` // node id. + Type string `json:"type,omitempty"` // node type. + Fields []*Field `json:"fields,omitempty"` // node fields. + Edges []*Edge `json:"edges,omitempty"` // node edges. +} + +// Field of a node. +type Field struct { + Type string `json:"type,omitempty"` // field type. + Name string `json:"name,omitempty"` // field name (as in struct). + Value string `json:"value,omitempty"` // stringified value. +} + +// Edges between two nodes. +type Edge struct { + Type string `json:"type,omitempty"` // edge type. + Name string `json:"name,omitempty"` // edge name. + IDs []int `json:"ids,omitempty"` // node ids (where this edge point to). +} + +func (gr *Group) Node(ctx context.Context) (node *Node, err error) { + node = &Node{ + ID: gr.ID, + Type: "Group", + Fields: make([]*Field, 1), + Edges: make([]*Edge, 0), + } + var buf []byte + if buf, err = json.Marshal(gr.MaxUsers); err != nil { + return nil, err + } + node.Fields[0] = &Field{ + Type: "int", + Name: "MaxUsers", + Value: string(buf), + } + return node, nil +} + +func (pe *Pet) Node(ctx context.Context) (node *Node, err error) { + node = &Node{ + ID: pe.ID, + Type: "Pet", + Fields: make([]*Field, 1), + Edges: make([]*Edge, 1), + } + var buf []byte + if buf, err = json.Marshal(pe.Age); err != nil { + return nil, err + } + node.Fields[0] = &Field{ + Type: "int", + Name: "Age", + Value: string(buf), + } + var ids []int + ids, err = pe.QueryOwner(). + Select(user.FieldID). + Ints(ctx) + if err != nil { + return nil, err + } + node.Edges[0] = &Edge{ + IDs: ids, + Type: "User", + Name: "Owner", + } + return node, nil +} + +func (u *User) Node(ctx context.Context) (node *Node, err error) { + node = &Node{ + ID: u.ID, + Type: "User", + Fields: make([]*Field, 1), + Edges: make([]*Edge, 2), + } + var buf []byte + if buf, err = json.Marshal(u.Name); err != nil { + return nil, err + } + node.Fields[0] = &Field{ + Type: "string", + Name: "Name", + Value: string(buf), + } + var ids []int + ids, err = u.QueryPets(). + Select(pet.FieldID). + Ints(ctx) + if err != nil { + return nil, err + } + node.Edges[0] = &Edge{ + IDs: ids, + Type: "Pet", + Name: "Pets", + } + ids, err = u.QueryFriends(). + Select(user.FieldID). + Ints(ctx) + if err != nil { + return nil, err + } + node.Edges[1] = &Edge{ + IDs: ids, + Type: "User", + Name: "Friends", + } + return node, nil +} + +var ( + once sync.Once + types []string + typeNodes = make(map[string]func(ctx context.Context, id int) (*Node, error)) +) + +func (c *Client) Node(ctx context.Context, id int) (*Node, error) { + var err error + once.Do(func() { + err = c.loadTypes(ctx) + }) + if err != nil { + return nil, err + } + idx := id / (1<<32 - 1) + return typeNodes[types[idx]](ctx, id) +} + +func (c *Client) loadTypes(ctx context.Context) error { + rows := &sql.Rows{} + query, args := sql.Select("type"). + From(sql.Table(schema.TypeTable)). + OrderBy(sql.Asc("id")). + Query() + if err := c.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + if err := sql.ScanSlice(rows, &types); err != nil { + return err + } + typeNodes[group.Table] = func(ctx context.Context, id int) (*Node, error) { + nv, err := c.Group.Query().Get(ctx, id) + if err != nil { + return nil, err + } + return nv.Node(ctx) + } + typeNodes[pet.Table] = func(ctx context.Context, id int) (*Node, error) { + nv, err := c.Pet.Query().Get(ctx, id) + if err != nil { + return nil, err + } + return nv.Node(ctx) + } + typeNodes[user.Table] = func(ctx context.Context, id int) (*Node, error) { + nv, err := c.User.Query().Get(ctx, id) + if err != nil { + return nil, err + } + return nv.Node(ctx) + } + return nil +} diff --git a/entc/integration/template/ent/pet.go b/entc/integration/template/ent/pet.go new file mode 100644 index 000000000..0fbb14725 --- /dev/null +++ b/entc/integration/template/ent/pet.go @@ -0,0 +1,95 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "bytes" + "fmt" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// Pet is the model entity for the Pet schema. +type Pet struct { + config + // ID of the ent. + ID int `json:"id,omitempty"` + // Age holds the value of the "age" field. + Age int `json:"age,omitempty"` +} + +// FromRows scans the sql response data into Pet. +func (pe *Pet) FromRows(rows *sql.Rows) error { + var vpe struct { + ID int + Age sql.NullInt64 + } + // the order here should be the same as in the `pet.Columns`. + if err := rows.Scan( + &vpe.ID, + &vpe.Age, + ); err != nil { + return err + } + pe.ID = vpe.ID + pe.Age = int(vpe.Age.Int64) + return nil +} + +// QueryOwner queries the owner edge of the Pet. +func (pe *Pet) QueryOwner() *UserQuery { + return (&PetClient{pe.config}).QueryOwner(pe) +} + +// Update returns a builder for updating this Pet. +// Note that, you need to call Pet.Unwrap() before calling this method, if this Pet +// was returned from a transaction, and the transaction was committed or rolled back. +func (pe *Pet) Update() *PetUpdateOne { + return (&PetClient{pe.config}).UpdateOne(pe) +} + +// Unwrap unwraps the entity that was returned from a transaction after it was closed, +// so that all next queries will be executed through the driver which created the transaction. +func (pe *Pet) Unwrap() *Pet { + tx, ok := pe.config.driver.(*txDriver) + if !ok { + panic("ent: Pet is not a transactional entity") + } + pe.config.driver = tx.drv + return pe +} + +// String implements the fmt.Stringer. +func (pe *Pet) String() string { + buf := bytes.NewBuffer(nil) + buf.WriteString("Pet(") + buf.WriteString(fmt.Sprintf("id=%v", pe.ID)) + buf.WriteString(fmt.Sprintf(", age=%v", pe.Age)) + buf.WriteString(")") + return buf.String() +} + +// Pets is a parsable slice of Pet. +type Pets []*Pet + +// FromRows scans the sql response data into Pets. +func (pe *Pets) FromRows(rows *sql.Rows) error { + for rows.Next() { + vpe := &Pet{} + if err := vpe.FromRows(rows); err != nil { + return err + } + *pe = append(*pe, vpe) + } + return nil +} + +func (pe Pets) config(cfg config) { + for i := range pe { + pe[i].config = cfg + } +} diff --git a/entc/integration/template/ent/pet/pet.go b/entc/integration/template/ent/pet/pet.go new file mode 100644 index 000000000..f67600701 --- /dev/null +++ b/entc/integration/template/ent/pet/pet.go @@ -0,0 +1,32 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package pet + +const ( + // Label holds the string label denoting the pet type in the database. + Label = "pet" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldAge holds the string denoting the age vertex property in the database. + FieldAge = "age" + + // Table holds the table name of the pet in the database. + Table = "pets" + // OwnerTable is the table the holds the owner relation/edge. + OwnerTable = "pets" + // OwnerInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + OwnerInverseTable = "users" + // OwnerColumn is the table column denoting the owner relation/edge. + OwnerColumn = "owner_id" +) + +// Columns holds all SQL columns are pet fields. +var Columns = []string{ + FieldID, + FieldAge, +} diff --git a/entc/integration/template/ent/pet/where.go b/entc/integration/template/ent/pet/where.go new file mode 100644 index 000000000..cc014e5c1 --- /dev/null +++ b/entc/integration/template/ent/pet/where.go @@ -0,0 +1,273 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package pet + +import ( + "github.com/facebookincubator/ent/entc/integration/template/ent/predicate" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// ID filters vertices based on their identifier. +func ID(id int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldID), id)) + }, + ) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldID), id)) + }, + ) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldID), id)) + }, + ) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldID), id)) + }, + ) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldID), id)) + }, + ) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldID), id)) + }, + ) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldID), id)) + }, + ) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(ids) == 0 { + s.Where(sql.False()) + return + } + v := make([]interface{}, len(ids)) + for i := range v { + v[i] = ids[i] + } + s.Where(sql.In(s.C(FieldID), v...)) + }, + ) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(ids) == 0 { + s.Where(sql.False()) + return + } + v := make([]interface{}, len(ids)) + for i := range v { + v[i] = ids[i] + } + s.Where(sql.NotIn(s.C(FieldID), v...)) + }, + ) +} + +// Age applies equality check predicate on the "age" field. It's identical to AgeEQ. +func Age(v int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldAge), v)) + }, + ) +} + +// AgeEQ applies the EQ predicate on the "age" field. +func AgeEQ(v int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldAge), v)) + }, + ) +} + +// AgeNEQ applies the NEQ predicate on the "age" field. +func AgeNEQ(v int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldAge), v)) + }, + ) +} + +// AgeGT applies the GT predicate on the "age" field. +func AgeGT(v int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldAge), v)) + }, + ) +} + +// AgeGTE applies the GTE predicate on the "age" field. +func AgeGTE(v int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldAge), v)) + }, + ) +} + +// AgeLT applies the LT predicate on the "age" field. +func AgeLT(v int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldAge), v)) + }, + ) +} + +// AgeLTE applies the LTE predicate on the "age" field. +func AgeLTE(v int) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldAge), v)) + }, + ) +} + +// AgeIn applies the In predicate on the "age" field. +func AgeIn(vs ...int) predicate.Pet { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Pet( + func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(vs) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldAge), v...)) + }, + ) +} + +// AgeNotIn applies the NotIn predicate on the "age" field. +func AgeNotIn(vs ...int) predicate.Pet { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Pet( + func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(vs) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldAge), v...)) + }, + ) +} + +// HasOwner applies the HasEdge predicate on the "owner" edge. +func HasOwner() predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + t1 := s.Table() + s.Where(sql.NotNull(t1.C(OwnerColumn))) + }, + ) +} + +// HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). +func HasOwnerWith(preds ...predicate.User) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + t1 := s.Table() + t2 := sql.Select(FieldID).From(sql.Table(OwnerInverseTable)) + for _, p := range preds { + p(t2) + } + s.Where(sql.In(t1.C(OwnerColumn), t2)) + }, + ) +} + +// And groups list of predicates with the AND operator between them. +func And(predicates ...predicate.Pet) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + for _, p := range predicates { + p(s) + } + }, + ) +} + +// Or groups list of predicates with the OR operator between them. +func Or(predicates ...predicate.Pet) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + for i, p := range predicates { + if i > 0 { + s.Or() + } + p(s) + } + }, + ) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Pet) predicate.Pet { + return predicate.Pet( + func(s *sql.Selector) { + p(s.Not()) + }, + ) +} diff --git a/entc/integration/template/ent/pet_create.go b/entc/integration/template/ent/pet_create.go new file mode 100644 index 000000000..c76b29171 --- /dev/null +++ b/entc/integration/template/ent/pet_create.go @@ -0,0 +1,111 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + + "github.com/facebookincubator/ent/entc/integration/template/ent/pet" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// PetCreate is the builder for creating a Pet entity. +type PetCreate struct { + config + age *int + owner map[int]struct{} +} + +// SetAge sets the age field. +func (pc *PetCreate) SetAge(i int) *PetCreate { + pc.age = &i + return pc +} + +// SetOwnerID sets the owner edge to User by id. +func (pc *PetCreate) SetOwnerID(id int) *PetCreate { + if pc.owner == nil { + pc.owner = make(map[int]struct{}) + } + pc.owner[id] = struct{}{} + return pc +} + +// SetNillableOwnerID sets the owner edge to User by id if the given value is not nil. +func (pc *PetCreate) SetNillableOwnerID(id *int) *PetCreate { + if id != nil { + pc = pc.SetOwnerID(*id) + } + return pc +} + +// SetOwner sets the owner edge to User. +func (pc *PetCreate) SetOwner(u *User) *PetCreate { + return pc.SetOwnerID(u.ID) +} + +// Save creates the Pet in the database. +func (pc *PetCreate) Save(ctx context.Context) (*Pet, error) { + if pc.age == nil { + return nil, errors.New("ent: missing required field \"age\"") + } + if len(pc.owner) > 1 { + return nil, errors.New("ent: multiple assignments on a unique edge \"owner\"") + } + return pc.sqlSave(ctx) +} + +// SaveX calls Save and panics if Save returns an error. +func (pc *PetCreate) SaveX(ctx context.Context) *Pet { + v, err := pc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +func (pc *PetCreate) sqlSave(ctx context.Context) (*Pet, error) { + var ( + res sql.Result + pe = &Pet{config: pc.config} + ) + tx, err := pc.driver.Tx(ctx) + if err != nil { + return nil, err + } + builder := sql.Insert(pet.Table).Default(pc.driver.Dialect()) + if pc.age != nil { + builder.Set(pet.FieldAge, *pc.age) + pe.Age = *pc.age + } + query, args := builder.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + id, err := res.LastInsertId() + if err != nil { + return nil, rollback(tx, err) + } + pe.ID = int(id) + if len(pc.owner) > 0 { + for eid := range pc.owner { + query, args := sql.Update(pet.OwnerTable). + Set(pet.OwnerColumn, eid). + Where(sql.EQ(pet.FieldID, id)). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + } + } + if err := tx.Commit(); err != nil { + return nil, err + } + return pe, nil +} diff --git a/entc/integration/template/ent/pet_delete.go b/entc/integration/template/ent/pet_delete.go new file mode 100644 index 000000000..94629260a --- /dev/null +++ b/entc/integration/template/ent/pet_delete.go @@ -0,0 +1,65 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + + "github.com/facebookincubator/ent/entc/integration/template/ent/pet" + "github.com/facebookincubator/ent/entc/integration/template/ent/predicate" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// PetDelete is the builder for deleting a Pet entity. +type PetDelete struct { + config + predicates []predicate.Pet +} + +// Where adds a new predicate for the builder. +func (pd *PetDelete) Where(ps ...predicate.Pet) *PetDelete { + pd.predicates = append(pd.predicates, ps...) + return pd +} + +// Exec executes the deletion query. +func (pd *PetDelete) Exec(ctx context.Context) error { + return pd.sqlExec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (pd *PetDelete) ExecX(ctx context.Context) { + if err := pd.Exec(ctx); err != nil { + panic(err) + } +} + +func (pd *PetDelete) sqlExec(ctx context.Context) error { + var res sql.Result + selector := sql.Select().From(sql.Table(pet.Table)) + for _, p := range pd.predicates { + p(selector) + } + query, args := sql.Delete(pet.Table).FromSelect(selector).Query() + return pd.driver.Exec(ctx, query, args, &res) +} + +// PetDeleteOne is the builder for deleting a single Pet entity. +type PetDeleteOne struct { + pd *PetDelete +} + +// Exec executes the deletion query. +func (pdo *PetDeleteOne) Exec(ctx context.Context) error { + return pdo.pd.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (pdo *PetDeleteOne) ExecX(ctx context.Context) { + pdo.pd.ExecX(ctx) +} diff --git a/entc/integration/template/ent/pet_query.go b/entc/integration/template/ent/pet_query.go new file mode 100644 index 000000000..b1e19d029 --- /dev/null +++ b/entc/integration/template/ent/pet_query.go @@ -0,0 +1,625 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "math" + + "github.com/facebookincubator/ent/entc/integration/template/ent/pet" + "github.com/facebookincubator/ent/entc/integration/template/ent/predicate" + "github.com/facebookincubator/ent/entc/integration/template/ent/user" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// PetQuery is the builder for querying Pet entities. +type PetQuery struct { + config + limit *int + offset *int + order []Order + unique []string + predicates []predicate.Pet + // intermediate queries. + sql *sql.Selector +} + +// Where adds a new predicate for the builder. +func (pq *PetQuery) Where(ps ...predicate.Pet) *PetQuery { + pq.predicates = append(pq.predicates, ps...) + return pq +} + +// Limit adds a limit step to the query. +func (pq *PetQuery) Limit(limit int) *PetQuery { + pq.limit = &limit + return pq +} + +// Offset adds an offset step to the query. +func (pq *PetQuery) Offset(offset int) *PetQuery { + pq.offset = &offset + return pq +} + +// Order adds an order step to the query. +func (pq *PetQuery) Order(o ...Order) *PetQuery { + pq.order = append(pq.order, o...) + return pq +} + +// QueryOwner chains the current query on the owner edge. +func (pq *PetQuery) QueryOwner() *UserQuery { + query := &UserQuery{config: pq.config} + t1 := sql.Table(user.Table) + t2 := pq.sqlQuery() + t2.Select(t2.C(pet.OwnerColumn)) + query.sql = sql.Select(t1.Columns(user.Columns...)...). + From(t1). + Join(t2). + On(t1.C(user.FieldID), t2.C(pet.OwnerColumn)) + return query +} + +// Get returns a Pet entity by its id. +func (pq *PetQuery) Get(ctx context.Context, id int) (*Pet, error) { + return pq.Where(pet.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (pq *PetQuery) GetX(ctx context.Context, id int) *Pet { + pe, err := pq.Get(ctx, id) + if err != nil { + panic(err) + } + return pe +} + +// First returns the first Pet entity in the query. Returns *ErrNotFound when no pet was found. +func (pq *PetQuery) First(ctx context.Context) (*Pet, error) { + pes, err := pq.Limit(1).All(ctx) + if err != nil { + return nil, err + } + if len(pes) == 0 { + return nil, &ErrNotFound{pet.Label} + } + return pes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (pq *PetQuery) FirstX(ctx context.Context) *Pet { + pe, err := pq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return pe +} + +// FirstID returns the first Pet id in the query. Returns *ErrNotFound when no id was found. +func (pq *PetQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = pq.Limit(1).IDs(ctx); err != nil { + return + } + if len(ids) == 0 { + err = &ErrNotFound{pet.Label} + return + } + return ids[0], nil +} + +// FirstXID is like FirstID, but panics if an error occurs. +func (pq *PetQuery) FirstXID(ctx context.Context) int { + id, err := pq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns the only Pet entity in the query, returns an error if not exactly one entity was returned. +func (pq *PetQuery) Only(ctx context.Context) (*Pet, error) { + pes, err := pq.Limit(2).All(ctx) + if err != nil { + return nil, err + } + switch len(pes) { + case 1: + return pes[0], nil + case 0: + return nil, &ErrNotFound{pet.Label} + default: + return nil, &ErrNotSingular{pet.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (pq *PetQuery) OnlyX(ctx context.Context) *Pet { + pe, err := pq.Only(ctx) + if err != nil { + panic(err) + } + return pe +} + +// OnlyID returns the only Pet id in the query, returns an error if not exactly one id was returned. +func (pq *PetQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = pq.Limit(2).IDs(ctx); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &ErrNotFound{pet.Label} + default: + err = &ErrNotSingular{pet.Label} + } + return +} + +// OnlyXID is like OnlyID, but panics if an error occurs. +func (pq *PetQuery) OnlyXID(ctx context.Context) int { + id, err := pq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Pets. +func (pq *PetQuery) All(ctx context.Context) ([]*Pet, error) { + return pq.sqlAll(ctx) +} + +// AllX is like All, but panics if an error occurs. +func (pq *PetQuery) AllX(ctx context.Context) []*Pet { + pes, err := pq.All(ctx) + if err != nil { + panic(err) + } + return pes +} + +// IDs executes the query and returns a list of Pet ids. +func (pq *PetQuery) IDs(ctx context.Context) ([]int, error) { + return pq.sqlIDs(ctx) +} + +// IDsX is like IDs, but panics if an error occurs. +func (pq *PetQuery) IDsX(ctx context.Context) []int { + ids, err := pq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (pq *PetQuery) Count(ctx context.Context) (int, error) { + return pq.sqlCount(ctx) +} + +// CountX is like Count, but panics if an error occurs. +func (pq *PetQuery) CountX(ctx context.Context) int { + count, err := pq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (pq *PetQuery) Exist(ctx context.Context) (bool, error) { + return pq.sqlExist(ctx) +} + +// ExistX is like Exist, but panics if an error occurs. +func (pq *PetQuery) ExistX(ctx context.Context) bool { + exist, err := pq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the query builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (pq *PetQuery) Clone() *PetQuery { + return &PetQuery{ + config: pq.config, + limit: pq.limit, + offset: pq.offset, + order: append([]Order{}, pq.order...), + unique: append([]string{}, pq.unique...), + predicates: append([]predicate.Pet{}, pq.predicates...), + // clone intermediate queries. + sql: pq.sql.Clone(), + } +} + +// GroupBy used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// Age int `json:"age,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Pet.Query(). +// GroupBy(pet.FieldAge). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +// +func (pq *PetQuery) GroupBy(field string, fields ...string) *PetGroupBy { + group := &PetGroupBy{config: pq.config} + group.fields = append([]string{field}, fields...) + group.sql = pq.sqlQuery() + return group +} + +// Select one or more fields from the given query. +// +// Example: +// +// var v []struct { +// Age int `json:"age,omitempty"` +// } +// +// client.Pet.Query(). +// Select(pet.FieldAge). +// Scan(ctx, &v) +// +func (pq *PetQuery) Select(field string, fields ...string) *PetSelect { + selector := &PetSelect{config: pq.config} + selector.fields = append([]string{field}, fields...) + selector.sql = pq.sqlQuery() + return selector +} + +func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { + rows := &sql.Rows{} + selector := pq.sqlQuery() + if unique := pq.unique; len(unique) == 0 { + selector.Distinct() + } + query, args := selector.Query() + if err := pq.driver.Query(ctx, query, args, rows); err != nil { + return nil, err + } + defer rows.Close() + var pes Pets + if err := pes.FromRows(rows); err != nil { + return nil, err + } + pes.config(pq.config) + return pes, nil +} + +func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { + rows := &sql.Rows{} + selector := pq.sqlQuery() + unique := []string{pet.FieldID} + if len(pq.unique) > 0 { + unique = pq.unique + } + selector.Count(sql.Distinct(selector.Columns(unique...)...)) + query, args := selector.Query() + if err := pq.driver.Query(ctx, query, args, rows); err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errors.New("ent: no rows found") + } + var n int + if err := rows.Scan(&n); err != nil { + return 0, fmt.Errorf("ent: failed reading count: %v", err) + } + return n, nil +} + +func (pq *PetQuery) sqlExist(ctx context.Context) (bool, error) { + n, err := pq.sqlCount(ctx) + if err != nil { + return false, fmt.Errorf("ent: check existence: %v", err) + } + return n > 0, nil +} + +func (pq *PetQuery) sqlIDs(ctx context.Context) ([]int, error) { + vs, err := pq.sqlAll(ctx) + if err != nil { + return nil, err + } + var ids []int + for _, v := range vs { + ids = append(ids, v.ID) + } + return ids, nil +} + +func (pq *PetQuery) sqlQuery() *sql.Selector { + t1 := sql.Table(pet.Table) + selector := sql.Select(t1.Columns(pet.Columns...)...).From(t1) + if pq.sql != nil { + selector = pq.sql + selector.Select(selector.Columns(pet.Columns...)...) + } + for _, p := range pq.predicates { + p(selector) + } + for _, p := range pq.order { + p(selector) + } + if offset := pq.offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt64) + } + if limit := pq.limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// PetGroupBy is the builder for group-by Pet entities. +type PetGroupBy struct { + config + fields []string + fns []Aggregate + // intermediate queries. + sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (pgb *PetGroupBy) Aggregate(fns ...Aggregate) *PetGroupBy { + pgb.fns = append(pgb.fns, fns...) + return pgb +} + +// Scan applies the group-by query and scan the result into the given value. +func (pgb *PetGroupBy) Scan(ctx context.Context, v interface{}) error { + return pgb.sqlScan(ctx, v) +} + +// ScanX is like Scan, but panics if an error occurs. +func (pgb *PetGroupBy) ScanX(ctx context.Context, v interface{}) { + if err := pgb.Scan(ctx, v); err != nil { + panic(err) + } +} + +// Strings returns list of strings from group-by. It is only allowed when querying group-by with one field. +func (pgb *PetGroupBy) Strings(ctx context.Context) ([]string, error) { + if len(pgb.fields) > 1 { + return nil, errors.New("ent: PetGroupBy.Strings is not achievable when grouping more than 1 field") + } + var v []string + if err := pgb.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// StringsX is like Strings, but panics if an error occurs. +func (pgb *PetGroupBy) StringsX(ctx context.Context) []string { + v, err := pgb.Strings(ctx) + if err != nil { + panic(err) + } + return v +} + +// Ints returns list of ints from group-by. It is only allowed when querying group-by with one field. +func (pgb *PetGroupBy) Ints(ctx context.Context) ([]int, error) { + if len(pgb.fields) > 1 { + return nil, errors.New("ent: PetGroupBy.Ints is not achievable when grouping more than 1 field") + } + var v []int + if err := pgb.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// IntsX is like Ints, but panics if an error occurs. +func (pgb *PetGroupBy) IntsX(ctx context.Context) []int { + v, err := pgb.Ints(ctx) + if err != nil { + panic(err) + } + return v +} + +// Float64s returns list of float64s from group-by. It is only allowed when querying group-by with one field. +func (pgb *PetGroupBy) Float64s(ctx context.Context) ([]float64, error) { + if len(pgb.fields) > 1 { + return nil, errors.New("ent: PetGroupBy.Float64s is not achievable when grouping more than 1 field") + } + var v []float64 + if err := pgb.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// Float64sX is like Float64s, but panics if an error occurs. +func (pgb *PetGroupBy) Float64sX(ctx context.Context) []float64 { + v, err := pgb.Float64s(ctx) + if err != nil { + panic(err) + } + return v +} + +// Bools returns list of bools from group-by. It is only allowed when querying group-by with one field. +func (pgb *PetGroupBy) Bools(ctx context.Context) ([]bool, error) { + if len(pgb.fields) > 1 { + return nil, errors.New("ent: PetGroupBy.Bools is not achievable when grouping more than 1 field") + } + var v []bool + if err := pgb.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// BoolsX is like Bools, but panics if an error occurs. +func (pgb *PetGroupBy) BoolsX(ctx context.Context) []bool { + v, err := pgb.Bools(ctx) + if err != nil { + panic(err) + } + return v +} + +func (pgb *PetGroupBy) sqlScan(ctx context.Context, v interface{}) error { + rows := &sql.Rows{} + query, args := pgb.sqlQuery().Query() + if err := pgb.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +func (pgb *PetGroupBy) sqlQuery() *sql.Selector { + selector := pgb.sql + columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) + columns = append(columns, pgb.fields...) + for _, fn := range pgb.fns { + columns = append(columns, fn.SQL(selector)) + } + return selector.Select(columns...).GroupBy(pgb.fields...) +} + +// PetSelect is the builder for select fields of Pet entities. +type PetSelect struct { + config + fields []string + // intermediate queries. + sql *sql.Selector +} + +// Scan applies the selector query and scan the result into the given value. +func (ps *PetSelect) Scan(ctx context.Context, v interface{}) error { + return ps.sqlScan(ctx, v) +} + +// ScanX is like Scan, but panics if an error occurs. +func (ps *PetSelect) ScanX(ctx context.Context, v interface{}) { + if err := ps.Scan(ctx, v); err != nil { + panic(err) + } +} + +// Strings returns list of strings from selector. It is only allowed when selecting one field. +func (ps *PetSelect) Strings(ctx context.Context) ([]string, error) { + if len(ps.fields) > 1 { + return nil, errors.New("ent: PetSelect.Strings is not achievable when selecting more than 1 field") + } + var v []string + if err := ps.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// StringsX is like Strings, but panics if an error occurs. +func (ps *PetSelect) StringsX(ctx context.Context) []string { + v, err := ps.Strings(ctx) + if err != nil { + panic(err) + } + return v +} + +// Ints returns list of ints from selector. It is only allowed when selecting one field. +func (ps *PetSelect) Ints(ctx context.Context) ([]int, error) { + if len(ps.fields) > 1 { + return nil, errors.New("ent: PetSelect.Ints is not achievable when selecting more than 1 field") + } + var v []int + if err := ps.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// IntsX is like Ints, but panics if an error occurs. +func (ps *PetSelect) IntsX(ctx context.Context) []int { + v, err := ps.Ints(ctx) + if err != nil { + panic(err) + } + return v +} + +// Float64s returns list of float64s from selector. It is only allowed when selecting one field. +func (ps *PetSelect) Float64s(ctx context.Context) ([]float64, error) { + if len(ps.fields) > 1 { + return nil, errors.New("ent: PetSelect.Float64s is not achievable when selecting more than 1 field") + } + var v []float64 + if err := ps.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// Float64sX is like Float64s, but panics if an error occurs. +func (ps *PetSelect) Float64sX(ctx context.Context) []float64 { + v, err := ps.Float64s(ctx) + if err != nil { + panic(err) + } + return v +} + +// Bools returns list of bools from selector. It is only allowed when selecting one field. +func (ps *PetSelect) Bools(ctx context.Context) ([]bool, error) { + if len(ps.fields) > 1 { + return nil, errors.New("ent: PetSelect.Bools is not achievable when selecting more than 1 field") + } + var v []bool + if err := ps.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// BoolsX is like Bools, but panics if an error occurs. +func (ps *PetSelect) BoolsX(ctx context.Context) []bool { + v, err := ps.Bools(ctx) + if err != nil { + panic(err) + } + return v +} + +func (ps *PetSelect) sqlScan(ctx context.Context, v interface{}) error { + rows := &sql.Rows{} + query, args := ps.sqlQuery().Query() + if err := ps.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +func (ps *PetSelect) sqlQuery() sql.Querier { + view := "pet_view" + return sql.Select(ps.fields...).From(ps.sql.As(view)) +} diff --git a/entc/integration/template/ent/pet_update.go b/entc/integration/template/ent/pet_update.go new file mode 100644 index 000000000..337e7a0e6 --- /dev/null +++ b/entc/integration/template/ent/pet_update.go @@ -0,0 +1,334 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + + "github.com/facebookincubator/ent/entc/integration/template/ent/pet" + "github.com/facebookincubator/ent/entc/integration/template/ent/predicate" + "github.com/facebookincubator/ent/entc/integration/template/ent/user" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// PetUpdate is the builder for updating Pet entities. +type PetUpdate struct { + config + age *int + addage *int + owner map[int]struct{} + clearedOwner bool + predicates []predicate.Pet +} + +// Where adds a new predicate for the builder. +func (pu *PetUpdate) Where(ps ...predicate.Pet) *PetUpdate { + pu.predicates = append(pu.predicates, ps...) + return pu +} + +// SetAge sets the age field. +func (pu *PetUpdate) SetAge(i int) *PetUpdate { + pu.age = &i + return pu +} + +// AddAge adds i to age. +func (pu *PetUpdate) AddAge(i int) *PetUpdate { + pu.addage = &i + return pu +} + +// SetOwnerID sets the owner edge to User by id. +func (pu *PetUpdate) SetOwnerID(id int) *PetUpdate { + if pu.owner == nil { + pu.owner = make(map[int]struct{}) + } + pu.owner[id] = struct{}{} + return pu +} + +// SetNillableOwnerID sets the owner edge to User by id if the given value is not nil. +func (pu *PetUpdate) SetNillableOwnerID(id *int) *PetUpdate { + if id != nil { + pu = pu.SetOwnerID(*id) + } + return pu +} + +// SetOwner sets the owner edge to User. +func (pu *PetUpdate) SetOwner(u *User) *PetUpdate { + return pu.SetOwnerID(u.ID) +} + +// ClearOwner clears the owner edge to User. +func (pu *PetUpdate) ClearOwner() *PetUpdate { + pu.clearedOwner = true + return pu +} + +// Save executes the query and returns the number of rows/vertices matched by this operation. +func (pu *PetUpdate) Save(ctx context.Context) (int, error) { + if len(pu.owner) > 1 { + return 0, errors.New("ent: multiple assignments on a unique edge \"owner\"") + } + return pu.sqlSave(ctx) +} + +// SaveX is like Save, but panics if an error occurs. +func (pu *PetUpdate) SaveX(ctx context.Context) int { + affected, err := pu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (pu *PetUpdate) Exec(ctx context.Context) error { + _, err := pu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (pu *PetUpdate) ExecX(ctx context.Context) { + if err := pu.Exec(ctx); err != nil { + panic(err) + } +} + +func (pu *PetUpdate) sqlSave(ctx context.Context) (n int, err error) { + selector := sql.Select(pet.FieldID).From(sql.Table(pet.Table)) + for _, p := range pu.predicates { + p(selector) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err = pu.driver.Query(ctx, query, args, rows); err != nil { + return 0, err + } + defer rows.Close() + var ids []int + for rows.Next() { + var id int + if err := rows.Scan(&id); err != nil { + return 0, fmt.Errorf("ent: failed reading id: %v", err) + } + ids = append(ids, id) + } + if len(ids) == 0 { + return 0, nil + } + + tx, err := pu.driver.Tx(ctx) + if err != nil { + return 0, err + } + var ( + update bool + res sql.Result + builder = sql.Update(pet.Table).Where(sql.InInts(pet.FieldID, ids...)) + ) + if value := pu.age; value != nil { + update = true + builder.Set(pet.FieldAge, *value) + } + if value := pu.addage; value != nil { + update = true + builder.Add(pet.FieldAge, *value) + } + if update { + query, args := builder.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, rollback(tx, err) + } + } + if pu.clearedOwner { + query, args := sql.Update(pet.OwnerTable). + SetNull(pet.OwnerColumn). + Where(sql.InInts(user.FieldID, ids...)). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, rollback(tx, err) + } + } + if len(pu.owner) > 0 { + for eid := range pu.owner { + query, args := sql.Update(pet.OwnerTable). + Set(pet.OwnerColumn, eid). + Where(sql.InInts(pet.FieldID, ids...)). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, rollback(tx, err) + } + } + } + if err = tx.Commit(); err != nil { + return 0, err + } + return len(ids), nil +} + +// PetUpdateOne is the builder for updating a single Pet entity. +type PetUpdateOne struct { + config + id int + age *int + addage *int + owner map[int]struct{} + clearedOwner bool +} + +// SetAge sets the age field. +func (puo *PetUpdateOne) SetAge(i int) *PetUpdateOne { + puo.age = &i + return puo +} + +// AddAge adds i to age. +func (puo *PetUpdateOne) AddAge(i int) *PetUpdateOne { + puo.addage = &i + return puo +} + +// SetOwnerID sets the owner edge to User by id. +func (puo *PetUpdateOne) SetOwnerID(id int) *PetUpdateOne { + if puo.owner == nil { + puo.owner = make(map[int]struct{}) + } + puo.owner[id] = struct{}{} + return puo +} + +// SetNillableOwnerID sets the owner edge to User by id if the given value is not nil. +func (puo *PetUpdateOne) SetNillableOwnerID(id *int) *PetUpdateOne { + if id != nil { + puo = puo.SetOwnerID(*id) + } + return puo +} + +// SetOwner sets the owner edge to User. +func (puo *PetUpdateOne) SetOwner(u *User) *PetUpdateOne { + return puo.SetOwnerID(u.ID) +} + +// ClearOwner clears the owner edge to User. +func (puo *PetUpdateOne) ClearOwner() *PetUpdateOne { + puo.clearedOwner = true + return puo +} + +// Save executes the query and returns the updated entity. +func (puo *PetUpdateOne) Save(ctx context.Context) (*Pet, error) { + if len(puo.owner) > 1 { + return nil, errors.New("ent: multiple assignments on a unique edge \"owner\"") + } + return puo.sqlSave(ctx) +} + +// SaveX is like Save, but panics if an error occurs. +func (puo *PetUpdateOne) SaveX(ctx context.Context) *Pet { + pe, err := puo.Save(ctx) + if err != nil { + panic(err) + } + return pe +} + +// Exec executes the query on the entity. +func (puo *PetUpdateOne) Exec(ctx context.Context) error { + _, err := puo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (puo *PetUpdateOne) ExecX(ctx context.Context) { + if err := puo.Exec(ctx); err != nil { + panic(err) + } +} + +func (puo *PetUpdateOne) sqlSave(ctx context.Context) (pe *Pet, err error) { + selector := sql.Select(pet.Columns...).From(sql.Table(pet.Table)) + pet.ID(puo.id)(selector) + rows := &sql.Rows{} + query, args := selector.Query() + if err = puo.driver.Query(ctx, query, args, rows); err != nil { + return nil, err + } + defer rows.Close() + var ids []int + for rows.Next() { + var id int + pe = &Pet{config: puo.config} + if err := pe.FromRows(rows); err != nil { + return nil, fmt.Errorf("ent: failed scanning row into Pet: %v", err) + } + id = pe.ID + ids = append(ids, id) + } + switch n := len(ids); { + case n == 0: + return nil, fmt.Errorf("ent: Pet not found with id: %v", puo.id) + case n > 1: + return nil, fmt.Errorf("ent: more than one Pet with the same id: %v", puo.id) + } + + tx, err := puo.driver.Tx(ctx) + if err != nil { + return nil, err + } + var ( + update bool + res sql.Result + builder = sql.Update(pet.Table).Where(sql.InInts(pet.FieldID, ids...)) + ) + if value := puo.age; value != nil { + update = true + builder.Set(pet.FieldAge, *value) + pe.Age = *value + } + if value := puo.addage; value != nil { + update = true + builder.Add(pet.FieldAge, *value) + pe.Age += *value + } + if update { + query, args := builder.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + } + if puo.clearedOwner { + query, args := sql.Update(pet.OwnerTable). + SetNull(pet.OwnerColumn). + Where(sql.InInts(user.FieldID, ids...)). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + } + if len(puo.owner) > 0 { + for eid := range puo.owner { + query, args := sql.Update(pet.OwnerTable). + Set(pet.OwnerColumn, eid). + Where(sql.InInts(pet.FieldID, ids...)). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + } + } + if err = tx.Commit(); err != nil { + return nil, err + } + return pe, nil +} diff --git a/entc/integration/template/ent/predicate/predicate.go b/entc/integration/template/ent/predicate/predicate.go new file mode 100644 index 000000000..a904498f5 --- /dev/null +++ b/entc/integration/template/ent/predicate/predicate.go @@ -0,0 +1,20 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package predicate + +import ( + "github.com/facebookincubator/ent/dialect/sql" +) + +// Group is the predicate function for group builders. +type Group func(*sql.Selector) + +// Pet is the predicate function for pet builders. +type Pet func(*sql.Selector) + +// User is the predicate function for user builders. +type User func(*sql.Selector) diff --git a/entc/integration/template/ent/schema/group.go b/entc/integration/template/ent/schema/group.go new file mode 100644 index 000000000..871bd8c04 --- /dev/null +++ b/entc/integration/template/ent/schema/group.go @@ -0,0 +1,27 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package schema + +import ( + "github.com/facebookincubator/ent" + "github.com/facebookincubator/ent/schema/field" +) + +// Group holds the schema definition for the Group entity. +type Group struct { + ent.Schema +} + +// Fields of the Group. +func (Group) Fields() []ent.Field { + return []ent.Field{ + field.Int("max_users"), + } +} + +// Edges of the Group. +func (Group) Edges() []ent.Edge { + return nil +} diff --git a/entc/integration/template/ent/schema/pet.go b/entc/integration/template/ent/schema/pet.go new file mode 100644 index 000000000..e29ca4ed5 --- /dev/null +++ b/entc/integration/template/ent/schema/pet.go @@ -0,0 +1,32 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package schema + +import ( + "github.com/facebookincubator/ent" + "github.com/facebookincubator/ent/schema/edge" + "github.com/facebookincubator/ent/schema/field" +) + +// Pet holds the schema definition for the Pet entity. +type Pet struct { + ent.Schema +} + +// Fields of the Pet. +func (Pet) Fields() []ent.Field { + return []ent.Field{ + field.Int("age"), + } +} + +// Edges of the Pet. +func (Pet) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("owner", User.Type). + Ref("pets"). + Unique(), + } +} diff --git a/entc/integration/template/ent/schema/user.go b/entc/integration/template/ent/schema/user.go new file mode 100644 index 000000000..c4c527dca --- /dev/null +++ b/entc/integration/template/ent/schema/user.go @@ -0,0 +1,31 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package schema + +import ( + "github.com/facebookincubator/ent" + "github.com/facebookincubator/ent/schema/edge" + "github.com/facebookincubator/ent/schema/field" +) + +// User holds the schema definition for the User entity. +type User struct { + ent.Schema +} + +// Fields of the User. +func (User) Fields() []ent.Field { + return []ent.Field{ + field.String("name"), + } +} + +// Edges of the User. +func (User) Edges() []ent.Edge { + return []ent.Edge{ + edge.To("pets", Pet.Type), + edge.To("friends", User.Type), + } +} diff --git a/entc/integration/template/ent/template/node.tmpl b/entc/integration/template/ent/template/node.tmpl new file mode 100644 index 000000000..c080e7e38 --- /dev/null +++ b/entc/integration/template/ent/template/node.tmpl @@ -0,0 +1,133 @@ +{{/* +Copyright 2019-present Facebook Inc. All rights reserved. +This source code is licensed under the Apache 2.0 license found +in the LICENSE file in the root directory of this source tree. +*/}} + +{{ define "node" }} +{{ $pkg := base $.Config.Package }} +{{ template "header" $ }} + +// Noder wraps the basic Node method. +type Noder interface { + Node(context.Context) (*Node, error) +} + +// Node in the graph. +type Node struct { + ID {{ $.IDType }} `json:"id,omitemty"` // node id. + Type string `json:"type,omitempty"` // node type. + Fields []*Field `json:"fields,omitempty"` // node fields. + Edges []*Edge `json:"edges,omitempty"` // node edges. +} + +// Field of a node. +type Field struct { + Type string `json:"type,omitempty"` // field type. + Name string `json:"name,omitempty"` // field name (as in struct). + Value string `json:"value,omitempty"` // stringified value. +} + +// Edges between two nodes. +type Edge struct { + Type string `json:"type,omitempty"` // edge type. + Name string `json:"name,omitempty"` // edge name. + IDs []{{ $.IDType }} `json:"ids,omitempty"` // node ids (where this edge point to). +} + +{{/* loop over all types and add implement the Node interface. */}} +{{ range $_, $n := $.Nodes -}} + {{ $receiver := $n.Receiver }} + func ({{ $receiver }} *{{ $n.Name }}) Node(ctx context.Context) (node *Node, err error) { + node = &Node{ + ID: {{ $receiver }}.ID, + Type: "{{ $n.Name }}", + Fields: make([]*Field, {{ len $n.Fields }}), + Edges: make([]*Edge, {{ len $n.Edges }}), + } + {{- with $n.Fields }} + var buf []byte + {{- range $i, $f := $n.Fields }} + if buf, err = json.Marshal({{ $receiver }}.{{ pascal $f.Name }}); err != nil { + return nil, err + } + node.Fields[{{ $i }}] = &Field{ + Type: "{{ $f.Type }}", + Name: "{{ pascal $f.Name }}", + Value: string(buf), + } + {{- end }} + {{- end }} + {{- with $n.Edges }} + var ids []{{ $.IDType }} + {{- range $i, $e := $n.Edges }} + ids, err = {{ $receiver }}.{{ print "Query" (pascal $e.Name) }}(). + Select({{ $e.Type.Package }}.FieldID). + {{ pascal $.IDType.String }}s(ctx) + if err != nil { + return nil, err + } + node.Edges[{{ $i }}] = &Edge{ + IDs: ids, + Type: "{{ $e.Type.Name }}", + Name: "{{ pascal $e.Name }}", + } + {{- end }} + {{- end }} + return node, nil + } +{{ end }} + +{{/* add the node api to the client */}} + +var ( + once sync.Once + types []string + typeNodes = make(map[string]func(ctx context.Context, id {{ $.IDType }})(*Node, error)) +) + +func (c *Client) Node(ctx context.Context, id {{ $.IDType }}) (*Node, error) { + var err error + once.Do(func() { + err = c.loadTypes(ctx) + }) + if err != nil { + return nil, err + } + {{- if not $.IDType.Numeric }} + idv, err := strconv.Atoi(id) + if err != nil { + return nil, err + } + idx := idv/(1<<32 - 1) + {{- else }} + idx := id/(1<<32 - 1) + {{- end }} + return typeNodes[types[idx]](ctx, id) +} + +func (c *Client) loadTypes(ctx context.Context) error { + rows := &sql.Rows{} + query, args := sql.Select("type"). + From(sql.Table(schema.TypeTable)). + OrderBy(sql.Asc("id")). + Query() + if err := c.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + if err := sql.ScanSlice(rows, &types); err != nil { + return err + } + {{- range $_, $n := $.Nodes }} + typeNodes[{{ $n.Package }}.Table] = func(ctx context.Context, id {{ $.IDType }})(*Node, error) { + nv, err := c.{{ $n.Name }}.Query().Get(ctx, id) + if err != nil { + return nil, err + } + return nv.Node(ctx) + } + {{- end }} + return nil +} +{{ end }} \ No newline at end of file diff --git a/entc/integration/template/ent/tx.go b/entc/integration/template/ent/tx.go new file mode 100644 index 000000000..0ee09a46a --- /dev/null +++ b/entc/integration/template/ent/tx.go @@ -0,0 +1,103 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + + "github.com/facebookincubator/ent/dialect" + "github.com/facebookincubator/ent/entc/integration/template/ent/migrate" +) + +// Tx is a transactional client that is created by calling Client.Tx(). +type Tx struct { + config + // Group is the client for interacting with the Group builders. + Group *GroupClient + // Pet is the client for interacting with the Pet builders. + Pet *PetClient + // User is the client for interacting with the User builders. + User *UserClient +} + +// Commit commits the transaction. +func (tx *Tx) Commit() error { + return tx.config.driver.(*txDriver).tx.Commit() +} + +// Rollback rollbacks the transaction. +func (tx *Tx) Rollback() error { + return tx.config.driver.(*txDriver).tx.Rollback() +} + +// Client returns a Client that binds to current transaction. +func (tx *Tx) Client() *Client { + return &Client{ + config: tx.config, + Schema: migrate.NewSchema(tx.driver), + Group: NewGroupClient(tx.config), + Pet: NewPetClient(tx.config), + User: NewUserClient(tx.config), + } +} + +// txDriver wraps the given dialect.Tx with a nop dialect.Driver implementation. +// The idea is to support transactions without adding any extra code to the builders. +// When a builder calls to driver.Tx(), it gets the same dialect.Tx instance. +// Commit and Rollback are nop for the internal builders and the user must call one +// of them in order to commit or rollback the transaction. +// +// If a closed transaction is embedded in one of the generated entities, and the entity +// applies a query, for example: Group.QueryXXX(), the query will be executed +// through the driver which created this transaction. +// +// Note that txDriver is not goroutine safe. +type txDriver struct { + // the driver we started the transaction from. + drv dialect.Driver + // tx is the underlying transaction. + tx dialect.Tx +} + +// newTx creates a new transactional driver. +func newTx(ctx context.Context, drv dialect.Driver) (*txDriver, error) { + tx, err := drv.Tx(ctx) + if err != nil { + return nil, err + } + return &txDriver{tx: tx, drv: drv}, nil +} + +// Tx returns the transaction wrapper (txDriver) to avoid Commit or Rollback calls +// from the internal builders. Should be called only by the internal builders. +func (tx *txDriver) Tx(context.Context) (dialect.Tx, error) { return tx, nil } + +// Dialect returns the dialect of the driver we started the transaction from. +func (tx *txDriver) Dialect() string { return tx.drv.Dialect() } + +// Close is a nop close. +func (*txDriver) Close() error { return nil } + +// Commit is a nop commit for the internal builders. +// User must call `Tx.Commit` in order to commit the transaction. +func (*txDriver) Commit() error { return nil } + +// Rollback is a nop rollback for the internal builders. +// User must call `Tx.Rollback` in order to rollback the transaction. +func (*txDriver) Rollback() error { return nil } + +// Exec calls tx.Exec. +func (tx *txDriver) Exec(ctx context.Context, query string, args, v interface{}) error { + return tx.tx.Exec(ctx, query, args, v) +} + +// Query calls tx.Query. +func (tx *txDriver) Query(ctx context.Context, query string, args, v interface{}) error { + return tx.tx.Query(ctx, query, args, v) +} + +var _ dialect.Driver = (*txDriver)(nil) diff --git a/entc/integration/template/ent/user.go b/entc/integration/template/ent/user.go new file mode 100644 index 000000000..28cbdda5a --- /dev/null +++ b/entc/integration/template/ent/user.go @@ -0,0 +1,100 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "bytes" + "fmt" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// User is the model entity for the User schema. +type User struct { + config + // ID of the ent. + ID int `json:"id,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` +} + +// FromRows scans the sql response data into User. +func (u *User) FromRows(rows *sql.Rows) error { + var vu struct { + ID int + Name sql.NullString + } + // the order here should be the same as in the `user.Columns`. + if err := rows.Scan( + &vu.ID, + &vu.Name, + ); err != nil { + return err + } + u.ID = vu.ID + u.Name = vu.Name.String + return nil +} + +// QueryPets queries the pets edge of the User. +func (u *User) QueryPets() *PetQuery { + return (&UserClient{u.config}).QueryPets(u) +} + +// QueryFriends queries the friends edge of the User. +func (u *User) QueryFriends() *UserQuery { + return (&UserClient{u.config}).QueryFriends(u) +} + +// Update returns a builder for updating this User. +// Note that, you need to call User.Unwrap() before calling this method, if this User +// was returned from a transaction, and the transaction was committed or rolled back. +func (u *User) Update() *UserUpdateOne { + return (&UserClient{u.config}).UpdateOne(u) +} + +// Unwrap unwraps the entity that was returned from a transaction after it was closed, +// so that all next queries will be executed through the driver which created the transaction. +func (u *User) Unwrap() *User { + tx, ok := u.config.driver.(*txDriver) + if !ok { + panic("ent: User is not a transactional entity") + } + u.config.driver = tx.drv + return u +} + +// String implements the fmt.Stringer. +func (u *User) String() string { + buf := bytes.NewBuffer(nil) + buf.WriteString("User(") + buf.WriteString(fmt.Sprintf("id=%v", u.ID)) + buf.WriteString(fmt.Sprintf(", name=%v", u.Name)) + buf.WriteString(")") + return buf.String() +} + +// Users is a parsable slice of User. +type Users []*User + +// FromRows scans the sql response data into Users. +func (u *Users) FromRows(rows *sql.Rows) error { + for rows.Next() { + vu := &User{} + if err := vu.FromRows(rows); err != nil { + return err + } + *u = append(*u, vu) + } + return nil +} + +func (u Users) config(cfg config) { + for i := range u { + u[i].config = cfg + } +} diff --git a/entc/integration/template/ent/user/user.go b/entc/integration/template/ent/user/user.go new file mode 100644 index 000000000..a89160295 --- /dev/null +++ b/entc/integration/template/ent/user/user.go @@ -0,0 +1,40 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package user + +const ( + // Label holds the string label denoting the user type in the database. + Label = "user" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldName holds the string denoting the name vertex property in the database. + FieldName = "name" + + // Table holds the table name of the user in the database. + Table = "users" + // PetsTable is the table the holds the pets relation/edge. + PetsTable = "pets" + // PetsInverseTable is the table name for the Pet entity. + // It exists in this package in order to avoid circular dependency with the "pet" package. + PetsInverseTable = "pets" + // PetsColumn is the table column denoting the pets relation/edge. + PetsColumn = "owner_id" + // FriendsTable is the table the holds the friends relation/edge. The primary key declared below. + FriendsTable = "user_friends" +) + +// Columns holds all SQL columns are user fields. +var Columns = []string{ + FieldID, + FieldName, +} + +var ( + // FriendsPrimaryKey and FriendsColumn2 are the table columns denoting the + // primary key for the friends relation (M2M). + FriendsPrimaryKey = []string{"user_id", "friend_id"} +) diff --git a/entc/integration/template/ent/user/where.go b/entc/integration/template/ent/user/where.go new file mode 100644 index 000000000..0b51179c8 --- /dev/null +++ b/entc/integration/template/ent/user/where.go @@ -0,0 +1,361 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package user + +import ( + "github.com/facebookincubator/ent/entc/integration/template/ent/predicate" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// ID filters vertices based on their identifier. +func ID(id int) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldID), id)) + }, + ) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldID), id)) + }, + ) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldID), id)) + }, + ) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldID), id)) + }, + ) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldID), id)) + }, + ) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldID), id)) + }, + ) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldID), id)) + }, + ) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.User { + return predicate.User( + func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(ids) == 0 { + s.Where(sql.False()) + return + } + v := make([]interface{}, len(ids)) + for i := range v { + v[i] = ids[i] + } + s.Where(sql.In(s.C(FieldID), v...)) + }, + ) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.User { + return predicate.User( + func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(ids) == 0 { + s.Where(sql.False()) + return + } + v := make([]interface{}, len(ids)) + for i := range v { + v[i] = ids[i] + } + s.Where(sql.NotIn(s.C(FieldID), v...)) + }, + ) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldName), v)) + }, + ) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldName), v)) + }, + ) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldName), v)) + }, + ) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldName), v)) + }, + ) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldName), v)) + }, + ) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldName), v)) + }, + ) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldName), v)) + }, + ) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.User { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User( + func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(vs) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldName), v...)) + }, + ) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.User { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User( + func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(vs) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldName), v...)) + }, + ) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.Contains(s.C(FieldName), v)) + }, + ) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.HasPrefix(s.C(FieldName), v)) + }, + ) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.HasSuffix(s.C(FieldName), v)) + }, + ) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.EqualFold(s.C(FieldName), v)) + }, + ) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.User { + return predicate.User( + func(s *sql.Selector) { + s.Where(sql.ContainsFold(s.C(FieldName), v)) + }, + ) +} + +// HasPets applies the HasEdge predicate on the "pets" edge. +func HasPets() predicate.User { + return predicate.User( + func(s *sql.Selector) { + t1 := s.Table() + s.Where( + sql.In( + t1.C(FieldID), + sql.Select(PetsColumn). + From(sql.Table(PetsTable)). + Where(sql.NotNull(PetsColumn)), + ), + ) + }, + ) +} + +// HasPetsWith applies the HasEdge predicate on the "pets" edge with a given conditions (other predicates). +func HasPetsWith(preds ...predicate.Pet) predicate.User { + return predicate.User( + func(s *sql.Selector) { + t1 := s.Table() + t2 := sql.Select(PetsColumn).From(sql.Table(PetsTable)) + for _, p := range preds { + p(t2) + } + s.Where(sql.In(t1.C(FieldID), t2)) + }, + ) +} + +// HasFriends applies the HasEdge predicate on the "friends" edge. +func HasFriends() predicate.User { + return predicate.User( + func(s *sql.Selector) { + t1 := s.Table() + s.Where( + sql.In( + t1.C(FieldID), + sql.Select(FriendsPrimaryKey[0]).From(sql.Table(FriendsTable)), + ), + ) + }, + ) +} + +// HasFriendsWith applies the HasEdge predicate on the "friends" edge with a given conditions (other predicates). +func HasFriendsWith(preds ...predicate.User) predicate.User { + return predicate.User( + func(s *sql.Selector) { + t1 := s.Table() + t2 := sql.Table(Table) + t3 := sql.Table(FriendsTable) + t4 := sql.Select(t3.C(FriendsPrimaryKey[0])). + From(t3). + Join(t2). + On(t3.C(FriendsPrimaryKey[1]), t2.C(FieldID)) + t5 := sql.Select().From(t2) + for _, p := range preds { + p(t5) + } + t4.FromSelect(t5) + s.Where(sql.In(t1.C(FieldID), t4)) + }, + ) +} + +// And groups list of predicates with the AND operator between them. +func And(predicates ...predicate.User) predicate.User { + return predicate.User( + func(s *sql.Selector) { + for _, p := range predicates { + p(s) + } + }, + ) +} + +// Or groups list of predicates with the OR operator between them. +func Or(predicates ...predicate.User) predicate.User { + return predicate.User( + func(s *sql.Selector) { + for i, p := range predicates { + if i > 0 { + s.Or() + } + p(s) + } + }, + ) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.User) predicate.User { + return predicate.User( + func(s *sql.Selector) { + p(s.Not()) + }, + ) +} diff --git a/entc/integration/template/ent/user_create.go b/entc/integration/template/ent/user_create.go new file mode 100644 index 000000000..fb288f94c --- /dev/null +++ b/entc/integration/template/ent/user_create.go @@ -0,0 +1,151 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + + "github.com/facebookincubator/ent/entc/integration/template/ent/pet" + "github.com/facebookincubator/ent/entc/integration/template/ent/user" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// UserCreate is the builder for creating a User entity. +type UserCreate struct { + config + name *string + pets map[int]struct{} + friends map[int]struct{} +} + +// SetName sets the name field. +func (uc *UserCreate) SetName(s string) *UserCreate { + uc.name = &s + return uc +} + +// AddPetIDs adds the pets edge to Pet by ids. +func (uc *UserCreate) AddPetIDs(ids ...int) *UserCreate { + if uc.pets == nil { + uc.pets = make(map[int]struct{}) + } + for i := range ids { + uc.pets[ids[i]] = struct{}{} + } + return uc +} + +// AddPets adds the pets edges to Pet. +func (uc *UserCreate) AddPets(p ...*Pet) *UserCreate { + ids := make([]int, len(p)) + for i := range p { + ids[i] = p[i].ID + } + return uc.AddPetIDs(ids...) +} + +// AddFriendIDs adds the friends edge to User by ids. +func (uc *UserCreate) AddFriendIDs(ids ...int) *UserCreate { + if uc.friends == nil { + uc.friends = make(map[int]struct{}) + } + for i := range ids { + uc.friends[ids[i]] = struct{}{} + } + return uc +} + +// AddFriends adds the friends edges to User. +func (uc *UserCreate) AddFriends(u ...*User) *UserCreate { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return uc.AddFriendIDs(ids...) +} + +// Save creates the User in the database. +func (uc *UserCreate) Save(ctx context.Context) (*User, error) { + if uc.name == nil { + return nil, errors.New("ent: missing required field \"name\"") + } + return uc.sqlSave(ctx) +} + +// SaveX calls Save and panics if Save returns an error. +func (uc *UserCreate) SaveX(ctx context.Context) *User { + v, err := uc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +func (uc *UserCreate) sqlSave(ctx context.Context) (*User, error) { + var ( + res sql.Result + u = &User{config: uc.config} + ) + tx, err := uc.driver.Tx(ctx) + if err != nil { + return nil, err + } + builder := sql.Insert(user.Table).Default(uc.driver.Dialect()) + if uc.name != nil { + builder.Set(user.FieldName, *uc.name) + u.Name = *uc.name + } + query, args := builder.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + id, err := res.LastInsertId() + if err != nil { + return nil, rollback(tx, err) + } + u.ID = int(id) + if len(uc.pets) > 0 { + p := sql.P() + for eid := range uc.pets { + p.Or().EQ(pet.FieldID, eid) + } + query, args := sql.Update(user.PetsTable). + Set(user.PetsColumn, id). + Where(sql.And(p, sql.IsNull(user.PetsColumn))). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + affected, err := res.RowsAffected() + if err != nil { + return nil, rollback(tx, err) + } + if int(affected) < len(uc.pets) { + return nil, rollback(tx, &ErrConstraintFailed{msg: fmt.Sprintf("one of \"pets\" %v already connected to a different \"User\"", keys(uc.pets))}) + } + } + if len(uc.friends) > 0 { + for eid := range uc.friends { + + query, args := sql.Insert(user.FriendsTable). + Columns(user.FriendsPrimaryKey[0], user.FriendsPrimaryKey[1]). + Values(id, eid). + Values(eid, id). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + } + } + if err := tx.Commit(); err != nil { + return nil, err + } + return u, nil +} diff --git a/entc/integration/template/ent/user_delete.go b/entc/integration/template/ent/user_delete.go new file mode 100644 index 000000000..d98f18c92 --- /dev/null +++ b/entc/integration/template/ent/user_delete.go @@ -0,0 +1,65 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + + "github.com/facebookincubator/ent/entc/integration/template/ent/predicate" + "github.com/facebookincubator/ent/entc/integration/template/ent/user" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// UserDelete is the builder for deleting a User entity. +type UserDelete struct { + config + predicates []predicate.User +} + +// Where adds a new predicate for the builder. +func (ud *UserDelete) Where(ps ...predicate.User) *UserDelete { + ud.predicates = append(ud.predicates, ps...) + return ud +} + +// Exec executes the deletion query. +func (ud *UserDelete) Exec(ctx context.Context) error { + return ud.sqlExec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (ud *UserDelete) ExecX(ctx context.Context) { + if err := ud.Exec(ctx); err != nil { + panic(err) + } +} + +func (ud *UserDelete) sqlExec(ctx context.Context) error { + var res sql.Result + selector := sql.Select().From(sql.Table(user.Table)) + for _, p := range ud.predicates { + p(selector) + } + query, args := sql.Delete(user.Table).FromSelect(selector).Query() + return ud.driver.Exec(ctx, query, args, &res) +} + +// UserDeleteOne is the builder for deleting a single User entity. +type UserDeleteOne struct { + ud *UserDelete +} + +// Exec executes the deletion query. +func (udo *UserDeleteOne) Exec(ctx context.Context) error { + return udo.ud.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (udo *UserDeleteOne) ExecX(ctx context.Context) { + udo.ud.ExecX(ctx) +} diff --git a/entc/integration/template/ent/user_query.go b/entc/integration/template/ent/user_query.go new file mode 100644 index 000000000..7e7f56c2a --- /dev/null +++ b/entc/integration/template/ent/user_query.go @@ -0,0 +1,643 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "math" + + "github.com/facebookincubator/ent/entc/integration/template/ent/pet" + "github.com/facebookincubator/ent/entc/integration/template/ent/predicate" + "github.com/facebookincubator/ent/entc/integration/template/ent/user" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// UserQuery is the builder for querying User entities. +type UserQuery struct { + config + limit *int + offset *int + order []Order + unique []string + predicates []predicate.User + // intermediate queries. + sql *sql.Selector +} + +// Where adds a new predicate for the builder. +func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { + uq.predicates = append(uq.predicates, ps...) + return uq +} + +// Limit adds a limit step to the query. +func (uq *UserQuery) Limit(limit int) *UserQuery { + uq.limit = &limit + return uq +} + +// Offset adds an offset step to the query. +func (uq *UserQuery) Offset(offset int) *UserQuery { + uq.offset = &offset + return uq +} + +// Order adds an order step to the query. +func (uq *UserQuery) Order(o ...Order) *UserQuery { + uq.order = append(uq.order, o...) + return uq +} + +// QueryPets chains the current query on the pets edge. +func (uq *UserQuery) QueryPets() *PetQuery { + query := &PetQuery{config: uq.config} + t1 := sql.Table(pet.Table) + t2 := uq.sqlQuery() + t2.Select(t2.C(user.FieldID)) + query.sql = sql.Select(). + From(t1). + Join(t2). + On(t1.C(user.PetsColumn), t2.C(user.FieldID)) + return query +} + +// QueryFriends chains the current query on the friends edge. +func (uq *UserQuery) QueryFriends() *UserQuery { + query := &UserQuery{config: uq.config} + t1 := sql.Table(user.Table) + t2 := uq.sqlQuery() + t2.Select(t2.C(user.FieldID)) + t3 := sql.Table(user.FriendsTable) + t4 := sql.Select(t3.C(user.FriendsPrimaryKey[1])). + From(t3). + Join(t2). + On(t3.C(user.FriendsPrimaryKey[0]), t2.C(user.FieldID)) + query.sql = sql.Select(). + From(t1). + Join(t4). + On(t1.C(user.FieldID), t4.C(user.FriendsPrimaryKey[1])) + return query +} + +// Get returns a User entity by its id. +func (uq *UserQuery) Get(ctx context.Context, id int) (*User, error) { + return uq.Where(user.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (uq *UserQuery) GetX(ctx context.Context, id int) *User { + u, err := uq.Get(ctx, id) + if err != nil { + panic(err) + } + return u +} + +// First returns the first User entity in the query. Returns *ErrNotFound when no user was found. +func (uq *UserQuery) First(ctx context.Context) (*User, error) { + us, err := uq.Limit(1).All(ctx) + if err != nil { + return nil, err + } + if len(us) == 0 { + return nil, &ErrNotFound{user.Label} + } + return us[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (uq *UserQuery) FirstX(ctx context.Context) *User { + u, err := uq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return u +} + +// FirstID returns the first User id in the query. Returns *ErrNotFound when no id was found. +func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = uq.Limit(1).IDs(ctx); err != nil { + return + } + if len(ids) == 0 { + err = &ErrNotFound{user.Label} + return + } + return ids[0], nil +} + +// FirstXID is like FirstID, but panics if an error occurs. +func (uq *UserQuery) FirstXID(ctx context.Context) int { + id, err := uq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns the only User entity in the query, returns an error if not exactly one entity was returned. +func (uq *UserQuery) Only(ctx context.Context) (*User, error) { + us, err := uq.Limit(2).All(ctx) + if err != nil { + return nil, err + } + switch len(us) { + case 1: + return us[0], nil + case 0: + return nil, &ErrNotFound{user.Label} + default: + return nil, &ErrNotSingular{user.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (uq *UserQuery) OnlyX(ctx context.Context) *User { + u, err := uq.Only(ctx) + if err != nil { + panic(err) + } + return u +} + +// OnlyID returns the only User id in the query, returns an error if not exactly one id was returned. +func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = uq.Limit(2).IDs(ctx); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &ErrNotFound{user.Label} + default: + err = &ErrNotSingular{user.Label} + } + return +} + +// OnlyXID is like OnlyID, but panics if an error occurs. +func (uq *UserQuery) OnlyXID(ctx context.Context) int { + id, err := uq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Users. +func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { + return uq.sqlAll(ctx) +} + +// AllX is like All, but panics if an error occurs. +func (uq *UserQuery) AllX(ctx context.Context) []*User { + us, err := uq.All(ctx) + if err != nil { + panic(err) + } + return us +} + +// IDs executes the query and returns a list of User ids. +func (uq *UserQuery) IDs(ctx context.Context) ([]int, error) { + return uq.sqlIDs(ctx) +} + +// IDsX is like IDs, but panics if an error occurs. +func (uq *UserQuery) IDsX(ctx context.Context) []int { + ids, err := uq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (uq *UserQuery) Count(ctx context.Context) (int, error) { + return uq.sqlCount(ctx) +} + +// CountX is like Count, but panics if an error occurs. +func (uq *UserQuery) CountX(ctx context.Context) int { + count, err := uq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { + return uq.sqlExist(ctx) +} + +// ExistX is like Exist, but panics if an error occurs. +func (uq *UserQuery) ExistX(ctx context.Context) bool { + exist, err := uq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the query builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (uq *UserQuery) Clone() *UserQuery { + return &UserQuery{ + config: uq.config, + limit: uq.limit, + offset: uq.offset, + order: append([]Order{}, uq.order...), + unique: append([]string{}, uq.unique...), + predicates: append([]predicate.User{}, uq.predicates...), + // clone intermediate queries. + sql: uq.sql.Clone(), + } +} + +// GroupBy used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// Name string `json:"name,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.User.Query(). +// GroupBy(user.FieldName). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +// +func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { + group := &UserGroupBy{config: uq.config} + group.fields = append([]string{field}, fields...) + group.sql = uq.sqlQuery() + return group +} + +// Select one or more fields from the given query. +// +// Example: +// +// var v []struct { +// Name string `json:"name,omitempty"` +// } +// +// client.User.Query(). +// Select(user.FieldName). +// Scan(ctx, &v) +// +func (uq *UserQuery) Select(field string, fields ...string) *UserSelect { + selector := &UserSelect{config: uq.config} + selector.fields = append([]string{field}, fields...) + selector.sql = uq.sqlQuery() + return selector +} + +func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { + rows := &sql.Rows{} + selector := uq.sqlQuery() + if unique := uq.unique; len(unique) == 0 { + selector.Distinct() + } + query, args := selector.Query() + if err := uq.driver.Query(ctx, query, args, rows); err != nil { + return nil, err + } + defer rows.Close() + var us Users + if err := us.FromRows(rows); err != nil { + return nil, err + } + us.config(uq.config) + return us, nil +} + +func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { + rows := &sql.Rows{} + selector := uq.sqlQuery() + unique := []string{user.FieldID} + if len(uq.unique) > 0 { + unique = uq.unique + } + selector.Count(sql.Distinct(selector.Columns(unique...)...)) + query, args := selector.Query() + if err := uq.driver.Query(ctx, query, args, rows); err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errors.New("ent: no rows found") + } + var n int + if err := rows.Scan(&n); err != nil { + return 0, fmt.Errorf("ent: failed reading count: %v", err) + } + return n, nil +} + +func (uq *UserQuery) sqlExist(ctx context.Context) (bool, error) { + n, err := uq.sqlCount(ctx) + if err != nil { + return false, fmt.Errorf("ent: check existence: %v", err) + } + return n > 0, nil +} + +func (uq *UserQuery) sqlIDs(ctx context.Context) ([]int, error) { + vs, err := uq.sqlAll(ctx) + if err != nil { + return nil, err + } + var ids []int + for _, v := range vs { + ids = append(ids, v.ID) + } + return ids, nil +} + +func (uq *UserQuery) sqlQuery() *sql.Selector { + t1 := sql.Table(user.Table) + selector := sql.Select(t1.Columns(user.Columns...)...).From(t1) + if uq.sql != nil { + selector = uq.sql + selector.Select(selector.Columns(user.Columns...)...) + } + for _, p := range uq.predicates { + p(selector) + } + for _, p := range uq.order { + p(selector) + } + if offset := uq.offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt64) + } + if limit := uq.limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// UserGroupBy is the builder for group-by User entities. +type UserGroupBy struct { + config + fields []string + fns []Aggregate + // intermediate queries. + sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (ugb *UserGroupBy) Aggregate(fns ...Aggregate) *UserGroupBy { + ugb.fns = append(ugb.fns, fns...) + return ugb +} + +// Scan applies the group-by query and scan the result into the given value. +func (ugb *UserGroupBy) Scan(ctx context.Context, v interface{}) error { + return ugb.sqlScan(ctx, v) +} + +// ScanX is like Scan, but panics if an error occurs. +func (ugb *UserGroupBy) ScanX(ctx context.Context, v interface{}) { + if err := ugb.Scan(ctx, v); err != nil { + panic(err) + } +} + +// Strings returns list of strings from group-by. It is only allowed when querying group-by with one field. +func (ugb *UserGroupBy) Strings(ctx context.Context) ([]string, error) { + if len(ugb.fields) > 1 { + return nil, errors.New("ent: UserGroupBy.Strings is not achievable when grouping more than 1 field") + } + var v []string + if err := ugb.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// StringsX is like Strings, but panics if an error occurs. +func (ugb *UserGroupBy) StringsX(ctx context.Context) []string { + v, err := ugb.Strings(ctx) + if err != nil { + panic(err) + } + return v +} + +// Ints returns list of ints from group-by. It is only allowed when querying group-by with one field. +func (ugb *UserGroupBy) Ints(ctx context.Context) ([]int, error) { + if len(ugb.fields) > 1 { + return nil, errors.New("ent: UserGroupBy.Ints is not achievable when grouping more than 1 field") + } + var v []int + if err := ugb.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// IntsX is like Ints, but panics if an error occurs. +func (ugb *UserGroupBy) IntsX(ctx context.Context) []int { + v, err := ugb.Ints(ctx) + if err != nil { + panic(err) + } + return v +} + +// Float64s returns list of float64s from group-by. It is only allowed when querying group-by with one field. +func (ugb *UserGroupBy) Float64s(ctx context.Context) ([]float64, error) { + if len(ugb.fields) > 1 { + return nil, errors.New("ent: UserGroupBy.Float64s is not achievable when grouping more than 1 field") + } + var v []float64 + if err := ugb.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// Float64sX is like Float64s, but panics if an error occurs. +func (ugb *UserGroupBy) Float64sX(ctx context.Context) []float64 { + v, err := ugb.Float64s(ctx) + if err != nil { + panic(err) + } + return v +} + +// Bools returns list of bools from group-by. It is only allowed when querying group-by with one field. +func (ugb *UserGroupBy) Bools(ctx context.Context) ([]bool, error) { + if len(ugb.fields) > 1 { + return nil, errors.New("ent: UserGroupBy.Bools is not achievable when grouping more than 1 field") + } + var v []bool + if err := ugb.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// BoolsX is like Bools, but panics if an error occurs. +func (ugb *UserGroupBy) BoolsX(ctx context.Context) []bool { + v, err := ugb.Bools(ctx) + if err != nil { + panic(err) + } + return v +} + +func (ugb *UserGroupBy) sqlScan(ctx context.Context, v interface{}) error { + rows := &sql.Rows{} + query, args := ugb.sqlQuery().Query() + if err := ugb.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +func (ugb *UserGroupBy) sqlQuery() *sql.Selector { + selector := ugb.sql + columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) + columns = append(columns, ugb.fields...) + for _, fn := range ugb.fns { + columns = append(columns, fn.SQL(selector)) + } + return selector.Select(columns...).GroupBy(ugb.fields...) +} + +// UserSelect is the builder for select fields of User entities. +type UserSelect struct { + config + fields []string + // intermediate queries. + sql *sql.Selector +} + +// Scan applies the selector query and scan the result into the given value. +func (us *UserSelect) Scan(ctx context.Context, v interface{}) error { + return us.sqlScan(ctx, v) +} + +// ScanX is like Scan, but panics if an error occurs. +func (us *UserSelect) ScanX(ctx context.Context, v interface{}) { + if err := us.Scan(ctx, v); err != nil { + panic(err) + } +} + +// Strings returns list of strings from selector. It is only allowed when selecting one field. +func (us *UserSelect) Strings(ctx context.Context) ([]string, error) { + if len(us.fields) > 1 { + return nil, errors.New("ent: UserSelect.Strings is not achievable when selecting more than 1 field") + } + var v []string + if err := us.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// StringsX is like Strings, but panics if an error occurs. +func (us *UserSelect) StringsX(ctx context.Context) []string { + v, err := us.Strings(ctx) + if err != nil { + panic(err) + } + return v +} + +// Ints returns list of ints from selector. It is only allowed when selecting one field. +func (us *UserSelect) Ints(ctx context.Context) ([]int, error) { + if len(us.fields) > 1 { + return nil, errors.New("ent: UserSelect.Ints is not achievable when selecting more than 1 field") + } + var v []int + if err := us.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// IntsX is like Ints, but panics if an error occurs. +func (us *UserSelect) IntsX(ctx context.Context) []int { + v, err := us.Ints(ctx) + if err != nil { + panic(err) + } + return v +} + +// Float64s returns list of float64s from selector. It is only allowed when selecting one field. +func (us *UserSelect) Float64s(ctx context.Context) ([]float64, error) { + if len(us.fields) > 1 { + return nil, errors.New("ent: UserSelect.Float64s is not achievable when selecting more than 1 field") + } + var v []float64 + if err := us.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// Float64sX is like Float64s, but panics if an error occurs. +func (us *UserSelect) Float64sX(ctx context.Context) []float64 { + v, err := us.Float64s(ctx) + if err != nil { + panic(err) + } + return v +} + +// Bools returns list of bools from selector. It is only allowed when selecting one field. +func (us *UserSelect) Bools(ctx context.Context) ([]bool, error) { + if len(us.fields) > 1 { + return nil, errors.New("ent: UserSelect.Bools is not achievable when selecting more than 1 field") + } + var v []bool + if err := us.Scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// BoolsX is like Bools, but panics if an error occurs. +func (us *UserSelect) BoolsX(ctx context.Context) []bool { + v, err := us.Bools(ctx) + if err != nil { + panic(err) + } + return v +} + +func (us *UserSelect) sqlScan(ctx context.Context, v interface{}) error { + rows := &sql.Rows{} + query, args := us.sqlQuery().Query() + if err := us.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +func (us *UserSelect) sqlQuery() sql.Querier { + view := "user_view" + return sql.Select(us.fields...).From(us.sql.As(view)) +} diff --git a/entc/integration/template/ent/user_update.go b/entc/integration/template/ent/user_update.go new file mode 100644 index 000000000..06aa2ea97 --- /dev/null +++ b/entc/integration/template/ent/user_update.go @@ -0,0 +1,518 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated (@generated) by entc, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + + "github.com/facebookincubator/ent/entc/integration/template/ent/pet" + "github.com/facebookincubator/ent/entc/integration/template/ent/predicate" + "github.com/facebookincubator/ent/entc/integration/template/ent/user" + + "github.com/facebookincubator/ent/dialect/sql" +) + +// UserUpdate is the builder for updating User entities. +type UserUpdate struct { + config + name *string + pets map[int]struct{} + friends map[int]struct{} + removedPets map[int]struct{} + removedFriends map[int]struct{} + predicates []predicate.User +} + +// Where adds a new predicate for the builder. +func (uu *UserUpdate) Where(ps ...predicate.User) *UserUpdate { + uu.predicates = append(uu.predicates, ps...) + return uu +} + +// SetName sets the name field. +func (uu *UserUpdate) SetName(s string) *UserUpdate { + uu.name = &s + return uu +} + +// AddPetIDs adds the pets edge to Pet by ids. +func (uu *UserUpdate) AddPetIDs(ids ...int) *UserUpdate { + if uu.pets == nil { + uu.pets = make(map[int]struct{}) + } + for i := range ids { + uu.pets[ids[i]] = struct{}{} + } + return uu +} + +// AddPets adds the pets edges to Pet. +func (uu *UserUpdate) AddPets(p ...*Pet) *UserUpdate { + ids := make([]int, len(p)) + for i := range p { + ids[i] = p[i].ID + } + return uu.AddPetIDs(ids...) +} + +// AddFriendIDs adds the friends edge to User by ids. +func (uu *UserUpdate) AddFriendIDs(ids ...int) *UserUpdate { + if uu.friends == nil { + uu.friends = make(map[int]struct{}) + } + for i := range ids { + uu.friends[ids[i]] = struct{}{} + } + return uu +} + +// AddFriends adds the friends edges to User. +func (uu *UserUpdate) AddFriends(u ...*User) *UserUpdate { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return uu.AddFriendIDs(ids...) +} + +// RemovePetIDs removes the pets edge to Pet by ids. +func (uu *UserUpdate) RemovePetIDs(ids ...int) *UserUpdate { + if uu.removedPets == nil { + uu.removedPets = make(map[int]struct{}) + } + for i := range ids { + uu.removedPets[ids[i]] = struct{}{} + } + return uu +} + +// RemovePets removes pets edges to Pet. +func (uu *UserUpdate) RemovePets(p ...*Pet) *UserUpdate { + ids := make([]int, len(p)) + for i := range p { + ids[i] = p[i].ID + } + return uu.RemovePetIDs(ids...) +} + +// RemoveFriendIDs removes the friends edge to User by ids. +func (uu *UserUpdate) RemoveFriendIDs(ids ...int) *UserUpdate { + if uu.removedFriends == nil { + uu.removedFriends = make(map[int]struct{}) + } + for i := range ids { + uu.removedFriends[ids[i]] = struct{}{} + } + return uu +} + +// RemoveFriends removes friends edges to User. +func (uu *UserUpdate) RemoveFriends(u ...*User) *UserUpdate { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return uu.RemoveFriendIDs(ids...) +} + +// Save executes the query and returns the number of rows/vertices matched by this operation. +func (uu *UserUpdate) Save(ctx context.Context) (int, error) { + return uu.sqlSave(ctx) +} + +// SaveX is like Save, but panics if an error occurs. +func (uu *UserUpdate) SaveX(ctx context.Context) int { + affected, err := uu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (uu *UserUpdate) Exec(ctx context.Context) error { + _, err := uu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (uu *UserUpdate) ExecX(ctx context.Context) { + if err := uu.Exec(ctx); err != nil { + panic(err) + } +} + +func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { + selector := sql.Select(user.FieldID).From(sql.Table(user.Table)) + for _, p := range uu.predicates { + p(selector) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err = uu.driver.Query(ctx, query, args, rows); err != nil { + return 0, err + } + defer rows.Close() + var ids []int + for rows.Next() { + var id int + if err := rows.Scan(&id); err != nil { + return 0, fmt.Errorf("ent: failed reading id: %v", err) + } + ids = append(ids, id) + } + if len(ids) == 0 { + return 0, nil + } + + tx, err := uu.driver.Tx(ctx) + if err != nil { + return 0, err + } + var ( + update bool + res sql.Result + builder = sql.Update(user.Table).Where(sql.InInts(user.FieldID, ids...)) + ) + if value := uu.name; value != nil { + update = true + builder.Set(user.FieldName, *value) + } + if update { + query, args := builder.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, rollback(tx, err) + } + } + if len(uu.removedPets) > 0 { + eids := make([]int, len(uu.removedPets)) + for eid := range uu.removedPets { + eids = append(eids, eid) + } + query, args := sql.Update(user.PetsTable). + SetNull(user.PetsColumn). + Where(sql.InInts(user.PetsColumn, ids...)). + Where(sql.InInts(pet.FieldID, eids...)). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, rollback(tx, err) + } + } + if len(uu.pets) > 0 { + for _, id := range ids { + p := sql.P() + for eid := range uu.pets { + p.Or().EQ(pet.FieldID, eid) + } + query, args := sql.Update(user.PetsTable). + Set(user.PetsColumn, id). + Where(sql.And(p, sql.IsNull(user.PetsColumn))). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, rollback(tx, err) + } + affected, err := res.RowsAffected() + if err != nil { + return 0, rollback(tx, err) + } + if int(affected) < len(uu.pets) { + return 0, rollback(tx, &ErrConstraintFailed{msg: fmt.Sprintf("one of \"pets\" %v already connected to a different \"User\"", keys(uu.pets))}) + } + } + } + if len(uu.removedFriends) > 0 { + eids := make([]int, len(uu.removedFriends)) + for eid := range uu.removedFriends { + eids = append(eids, eid) + } + query, args := sql.Delete(user.FriendsTable). + Where(sql.InInts(user.FriendsPrimaryKey[0], ids...)). + Where(sql.InInts(user.FriendsPrimaryKey[1], eids...)). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, rollback(tx, err) + } + query, args = sql.Delete(user.FriendsTable). + Where(sql.InInts(user.FriendsPrimaryKey[1], ids...)). + Where(sql.InInts(user.FriendsPrimaryKey[0], eids...)). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, rollback(tx, err) + } + } + if len(uu.friends) > 0 { + values := make([][]int, 0, len(ids)) + for _, id := range ids { + for eid := range uu.friends { + values = append(values, []int{id, eid}, []int{eid, id}) + } + } + builder := sql.Insert(user.FriendsTable). + Columns(user.FriendsPrimaryKey[0], user.FriendsPrimaryKey[1]) + for _, v := range values { + builder.Values(v[0], v[1]) + } + query, args := builder.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, rollback(tx, err) + } + } + if err = tx.Commit(); err != nil { + return 0, err + } + return len(ids), nil +} + +// UserUpdateOne is the builder for updating a single User entity. +type UserUpdateOne struct { + config + id int + name *string + pets map[int]struct{} + friends map[int]struct{} + removedPets map[int]struct{} + removedFriends map[int]struct{} +} + +// SetName sets the name field. +func (uuo *UserUpdateOne) SetName(s string) *UserUpdateOne { + uuo.name = &s + return uuo +} + +// AddPetIDs adds the pets edge to Pet by ids. +func (uuo *UserUpdateOne) AddPetIDs(ids ...int) *UserUpdateOne { + if uuo.pets == nil { + uuo.pets = make(map[int]struct{}) + } + for i := range ids { + uuo.pets[ids[i]] = struct{}{} + } + return uuo +} + +// AddPets adds the pets edges to Pet. +func (uuo *UserUpdateOne) AddPets(p ...*Pet) *UserUpdateOne { + ids := make([]int, len(p)) + for i := range p { + ids[i] = p[i].ID + } + return uuo.AddPetIDs(ids...) +} + +// AddFriendIDs adds the friends edge to User by ids. +func (uuo *UserUpdateOne) AddFriendIDs(ids ...int) *UserUpdateOne { + if uuo.friends == nil { + uuo.friends = make(map[int]struct{}) + } + for i := range ids { + uuo.friends[ids[i]] = struct{}{} + } + return uuo +} + +// AddFriends adds the friends edges to User. +func (uuo *UserUpdateOne) AddFriends(u ...*User) *UserUpdateOne { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return uuo.AddFriendIDs(ids...) +} + +// RemovePetIDs removes the pets edge to Pet by ids. +func (uuo *UserUpdateOne) RemovePetIDs(ids ...int) *UserUpdateOne { + if uuo.removedPets == nil { + uuo.removedPets = make(map[int]struct{}) + } + for i := range ids { + uuo.removedPets[ids[i]] = struct{}{} + } + return uuo +} + +// RemovePets removes pets edges to Pet. +func (uuo *UserUpdateOne) RemovePets(p ...*Pet) *UserUpdateOne { + ids := make([]int, len(p)) + for i := range p { + ids[i] = p[i].ID + } + return uuo.RemovePetIDs(ids...) +} + +// RemoveFriendIDs removes the friends edge to User by ids. +func (uuo *UserUpdateOne) RemoveFriendIDs(ids ...int) *UserUpdateOne { + if uuo.removedFriends == nil { + uuo.removedFriends = make(map[int]struct{}) + } + for i := range ids { + uuo.removedFriends[ids[i]] = struct{}{} + } + return uuo +} + +// RemoveFriends removes friends edges to User. +func (uuo *UserUpdateOne) RemoveFriends(u ...*User) *UserUpdateOne { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return uuo.RemoveFriendIDs(ids...) +} + +// Save executes the query and returns the updated entity. +func (uuo *UserUpdateOne) Save(ctx context.Context) (*User, error) { + return uuo.sqlSave(ctx) +} + +// SaveX is like Save, but panics if an error occurs. +func (uuo *UserUpdateOne) SaveX(ctx context.Context) *User { + u, err := uuo.Save(ctx) + if err != nil { + panic(err) + } + return u +} + +// Exec executes the query on the entity. +func (uuo *UserUpdateOne) Exec(ctx context.Context) error { + _, err := uuo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (uuo *UserUpdateOne) ExecX(ctx context.Context) { + if err := uuo.Exec(ctx); err != nil { + panic(err) + } +} + +func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) { + selector := sql.Select(user.Columns...).From(sql.Table(user.Table)) + user.ID(uuo.id)(selector) + rows := &sql.Rows{} + query, args := selector.Query() + if err = uuo.driver.Query(ctx, query, args, rows); err != nil { + return nil, err + } + defer rows.Close() + var ids []int + for rows.Next() { + var id int + u = &User{config: uuo.config} + if err := u.FromRows(rows); err != nil { + return nil, fmt.Errorf("ent: failed scanning row into User: %v", err) + } + id = u.ID + ids = append(ids, id) + } + switch n := len(ids); { + case n == 0: + return nil, fmt.Errorf("ent: User not found with id: %v", uuo.id) + case n > 1: + return nil, fmt.Errorf("ent: more than one User with the same id: %v", uuo.id) + } + + tx, err := uuo.driver.Tx(ctx) + if err != nil { + return nil, err + } + var ( + update bool + res sql.Result + builder = sql.Update(user.Table).Where(sql.InInts(user.FieldID, ids...)) + ) + if value := uuo.name; value != nil { + update = true + builder.Set(user.FieldName, *value) + u.Name = *value + } + if update { + query, args := builder.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + } + if len(uuo.removedPets) > 0 { + eids := make([]int, len(uuo.removedPets)) + for eid := range uuo.removedPets { + eids = append(eids, eid) + } + query, args := sql.Update(user.PetsTable). + SetNull(user.PetsColumn). + Where(sql.InInts(user.PetsColumn, ids...)). + Where(sql.InInts(pet.FieldID, eids...)). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + } + if len(uuo.pets) > 0 { + for _, id := range ids { + p := sql.P() + for eid := range uuo.pets { + p.Or().EQ(pet.FieldID, eid) + } + query, args := sql.Update(user.PetsTable). + Set(user.PetsColumn, id). + Where(sql.And(p, sql.IsNull(user.PetsColumn))). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + affected, err := res.RowsAffected() + if err != nil { + return nil, rollback(tx, err) + } + if int(affected) < len(uuo.pets) { + return nil, rollback(tx, &ErrConstraintFailed{msg: fmt.Sprintf("one of \"pets\" %v already connected to a different \"User\"", keys(uuo.pets))}) + } + } + } + if len(uuo.removedFriends) > 0 { + eids := make([]int, len(uuo.removedFriends)) + for eid := range uuo.removedFriends { + eids = append(eids, eid) + } + query, args := sql.Delete(user.FriendsTable). + Where(sql.InInts(user.FriendsPrimaryKey[0], ids...)). + Where(sql.InInts(user.FriendsPrimaryKey[1], eids...)). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + query, args = sql.Delete(user.FriendsTable). + Where(sql.InInts(user.FriendsPrimaryKey[1], ids...)). + Where(sql.InInts(user.FriendsPrimaryKey[0], eids...)). + Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + } + if len(uuo.friends) > 0 { + values := make([][]int, 0, len(ids)) + for _, id := range ids { + for eid := range uuo.friends { + values = append(values, []int{id, eid}, []int{eid, id}) + } + } + builder := sql.Insert(user.FriendsTable). + Columns(user.FriendsPrimaryKey[0], user.FriendsPrimaryKey[1]) + for _, v := range values { + builder.Values(v[0], v[1]) + } + query, args := builder.Query() + if err := tx.Exec(ctx, query, args, &res); err != nil { + return nil, rollback(tx, err) + } + } + if err = tx.Commit(); err != nil { + return nil, err + } + return u, nil +} diff --git a/entc/integration/template/template_test.go b/entc/integration/template/template_test.go new file mode 100644 index 000000000..679a8f6dc --- /dev/null +++ b/entc/integration/template/template_test.go @@ -0,0 +1,44 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package template + +import ( + "context" + "testing" + + "github.com/facebookincubator/ent/dialect/sql" + "github.com/facebookincubator/ent/entc/integration/template/ent" + "github.com/facebookincubator/ent/entc/integration/template/ent/migrate" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" +) + +func TestCustomTemplate(t *testing.T) { + drv, err := sql.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + require.NoError(t, err) + defer drv.Close() + ctx := context.Background() + client := ent.NewClient(ent.Driver(drv)) + require.NoError(t, client.Schema.Create(ctx, migrate.WithGlobalUniqueID(true))) + + p := client.Pet.Create().SetAge(1).SaveX(ctx) + u := client.User.Create().SetName("a8m").AddPets(p).SaveX(ctx) + g := client.Group.Create().SetMaxUsers(10).SaveX(ctx) + + node, err := client.Node(ctx, p.ID) + require.Equal(t, p.ID, node.ID) + require.Equal(t, &ent.Field{Type: "int", Name: "Age", Value: "1"}, node.Fields[0]) + require.Equal(t, &ent.Edge{Type: "User", Name: "Owner", IDs: []int{u.ID}}, node.Edges[0]) + + node, err = client.Node(ctx, u.ID) + require.Equal(t, u.ID, node.ID) + require.Equal(t, &ent.Field{Type: "string", Name: "Name", Value: "\"a8m\""}, node.Fields[0]) + require.Equal(t, &ent.Edge{Type: "Pet", Name: "Pets", IDs: []int{p.ID}}, node.Edges[0]) + + node, err = client.Node(ctx, g.ID) + require.Equal(t, g.ID, node.ID) + require.Equal(t, &ent.Field{Type: "int", Name: "MaxUsers", Value: "10"}, node.Fields[0]) +}