entc/gen/sql: support custom-id on create

Summary: Pull Request resolved: https://github.com/facebookincubator/ent/pull/166

Reviewed By: alexsn

Differential Revision: D18514295

fbshipit-source-id: e5988552c5611cbad18476ab2d9c2155df1e6e0c
This commit is contained in:
Ariel Mashraki
2019-11-14 14:37:36 -08:00
committed by Facebook Github Bot
parent 38002e6d2e
commit 2b2e056f05
16 changed files with 116 additions and 105 deletions

File diff suppressed because one or more lines are too long

View File

@@ -7,6 +7,10 @@ in the LICENSE file in the root directory of this source tree.
{{ define "dialect/sql/create" }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
{{ $fields := $.Fields }}
{{ if $.ID.UserDefined }}
{{ $fields = append $fields $.ID }}
{{ end }}
func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (*{{ $.Name }}, error) {
var (
@@ -19,7 +23,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (*{{ $.Name
return nil, err
}
insert := builder.Insert({{ $.Package }}.Table).Default()
{{- range $_, $f := $.Fields }}
{{- range $_, $f := $fields }}
if value := {{ $receiver }}.{{- $f.BuilderField }}; value != nil {
{{- if $f.IsJSON }}
buf, err := json.Marshal(*value)
@@ -37,7 +41,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (*{{ $.Name
if err != nil {
return nil, rollback(tx, err)
}
{{ $.Receiver }}.ID = {{ if $.ID.IsString }}strconv.FormatInt(id, 10){{ else }}{{ $.ID.Type }}(id){{ end }}
{{ $.Receiver }}.ID = {{ if and $.ID.IsString (not $.ID.UserDefined) }}strconv.FormatInt(id, 10){{ else }}{{ $.ID.Type }}(id){{ end }}
{{- range $_, $e := $.Edges }}
if len({{ $receiver }}.{{ $e.BuilderField }}) > 0 {
{{- if and $e.Unique $e.SelfRef }}{{/* O2O with self reference */}}

View File

@@ -0,0 +1,34 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package customid
import (
"context"
"testing"
"github.com/facebookincubator/ent/entc/integration/customid/ent"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
)
func TestCustomID(t *testing.T) {
client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
require.NoError(t, client.Schema.Create(ctx))
nat := client.User.Create().SaveX(ctx)
require.Equal(t, 1, nat.ID)
_, err = client.User.Create().SetID(1).Save(ctx)
require.True(t, ent.IsConstraintFailure(err), "duplicate id")
a8m := client.User.Create().SetID(5).SaveX(ctx)
require.Equal(t, 5, a8m.ID)
hub := client.Group.Create().SetID(3).AddUsers(a8m, nat).SaveX(ctx)
require.Equal(t, 3, hub.ID)
require.Equal(t, []int{1, 5}, hub.QueryUsers().IDsX(ctx))
}

View File

@@ -124,7 +124,7 @@ func (c *GroupClient) UpdateOne(gr *Group) *GroupUpdateOne {
}
// UpdateOneID returns an update builder for the given id.
func (c *GroupClient) UpdateOneID(id string) *GroupUpdateOne {
func (c *GroupClient) UpdateOneID(id int) *GroupUpdateOne {
return &GroupUpdateOne{config: c.config, id: id}
}
@@ -139,7 +139,7 @@ func (c *GroupClient) DeleteOne(gr *Group) *GroupDeleteOne {
}
// DeleteOneID returns a delete builder for the given id.
func (c *GroupClient) DeleteOneID(id string) *GroupDeleteOne {
func (c *GroupClient) DeleteOneID(id int) *GroupDeleteOne {
return &GroupDeleteOne{c.Delete().Where(group.ID(id))}
}
@@ -149,12 +149,12 @@ func (c *GroupClient) Query() *GroupQuery {
}
// Get returns a Group entity by its id.
func (c *GroupClient) Get(ctx context.Context, id string) (*Group, error) {
func (c *GroupClient) Get(ctx context.Context, id int) (*Group, error) {
return c.Query().Where(group.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *GroupClient) GetX(ctx context.Context, id string) *Group {
func (c *GroupClient) GetX(ctx context.Context, id int) *Group {
gr, err := c.Get(ctx, id)
if err != nil {
panic(err)
@@ -165,7 +165,7 @@ func (c *GroupClient) GetX(ctx context.Context, id string) *Group {
// QueryUsers queries the users edge of a Group.
func (c *GroupClient) QueryUsers(gr *Group) *UserQuery {
query := &UserQuery{config: c.config}
id := gr.id()
id := gr.ID
step := &sql.Step{}
step.From.V = id
step.From.Table = group.Table

View File

@@ -226,8 +226,8 @@ func insertLastID(ctx context.Context, tx dialect.Tx, insert *sql.InsertBuilder)
}
// keys returns the keys/ids from the edge map.
func keys(m map[string]struct{}) []string {
s := make([]string, 0, len(m))
func keys(m map[int]struct{}) []int {
s := make([]int, 0, len(m))
for id := range m {
s = append(s, id)
}

View File

@@ -4,7 +4,6 @@ package ent
import (
"fmt"
"strconv"
"strings"
"github.com/facebookincubator/ent/dialect/sql"
@@ -14,7 +13,7 @@ import (
type Group struct {
config
// ID of the ent.
ID string `json:"id,omitempty"`
ID int `json:"id,omitempty"`
}
// FromRows scans the sql response data into Group.
@@ -28,7 +27,7 @@ func (gr *Group) FromRows(rows *sql.Rows) error {
); err != nil {
return err
}
gr.ID = strconv.Itoa(scangr.ID)
gr.ID = scangr.ID
return nil
}
@@ -64,12 +63,6 @@ func (gr *Group) String() string {
return builder.String()
}
// id returns the int representation of the ID field.
func (gr *Group) id() int {
id, _ := strconv.Atoi(gr.ID)
return id
}
// Groups is a parsable slice of Group.
type Groups []*Group

View File

@@ -3,44 +3,39 @@
package group
import (
"strconv"
"github.com/facebookincubator/ent/dialect/sql"
"github.com/facebookincubator/ent/entc/integration/customid/ent/predicate"
)
// ID filters vertices based on their identifier.
func ID(id string) predicate.Group {
func ID(id int) predicate.Group {
return predicate.Group(
func(s *sql.Selector) {
id, _ := strconv.Atoi(id)
s.Where(sql.EQ(s.C(FieldID), id))
},
)
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id string) predicate.Group {
func IDEQ(id int) predicate.Group {
return predicate.Group(
func(s *sql.Selector) {
id, _ := strconv.Atoi(id)
s.Where(sql.EQ(s.C(FieldID), id))
},
)
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id string) predicate.Group {
func IDNEQ(id int) predicate.Group {
return predicate.Group(
func(s *sql.Selector) {
id, _ := strconv.Atoi(id)
s.Where(sql.NEQ(s.C(FieldID), id))
},
)
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...string) predicate.Group {
func IDIn(ids ...int) predicate.Group {
return predicate.Group(
func(s *sql.Selector) {
// if not arguments were provided, append the FALSE constants,
@@ -51,7 +46,7 @@ func IDIn(ids ...string) predicate.Group {
}
v := make([]interface{}, len(ids))
for i := range v {
v[i], _ = strconv.Atoi(ids[i])
v[i] = ids[i]
}
s.Where(sql.In(s.C(FieldID), v...))
},
@@ -59,7 +54,7 @@ func IDIn(ids ...string) predicate.Group {
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...string) predicate.Group {
func IDNotIn(ids ...int) predicate.Group {
return predicate.Group(
func(s *sql.Selector) {
// if not arguments were provided, append the FALSE constants,
@@ -70,7 +65,7 @@ func IDNotIn(ids ...string) predicate.Group {
}
v := make([]interface{}, len(ids))
for i := range v {
v[i], _ = strconv.Atoi(ids[i])
v[i] = ids[i]
}
s.Where(sql.NotIn(s.C(FieldID), v...))
},
@@ -78,40 +73,36 @@ func IDNotIn(ids ...string) predicate.Group {
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id string) predicate.Group {
func IDGT(id int) predicate.Group {
return predicate.Group(
func(s *sql.Selector) {
id, _ := strconv.Atoi(id)
s.Where(sql.GT(s.C(FieldID), id))
},
)
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id string) predicate.Group {
func IDGTE(id int) predicate.Group {
return predicate.Group(
func(s *sql.Selector) {
id, _ := strconv.Atoi(id)
s.Where(sql.GTE(s.C(FieldID), id))
},
)
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id string) predicate.Group {
func IDLT(id int) predicate.Group {
return predicate.Group(
func(s *sql.Selector) {
id, _ := strconv.Atoi(id)
s.Where(sql.LT(s.C(FieldID), id))
},
)
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id string) predicate.Group {
func IDLTE(id int) predicate.Group {
return predicate.Group(
func(s *sql.Selector) {
id, _ := strconv.Atoi(id)
s.Where(sql.LTE(s.C(FieldID), id))
},
)

View File

@@ -4,7 +4,6 @@ package ent
import (
"context"
"strconv"
"github.com/facebookincubator/ent/dialect/sql"
"github.com/facebookincubator/ent/entc/integration/customid/ent/group"
@@ -13,13 +12,13 @@ import (
// GroupCreate is the builder for creating a Group entity.
type GroupCreate struct {
config
id *string
id *int
users map[int]struct{}
}
// SetID sets the id field.
func (gc *GroupCreate) SetID(s string) *GroupCreate {
gc.id = &s
func (gc *GroupCreate) SetID(i int) *GroupCreate {
gc.id = &i
return gc
}
@@ -68,11 +67,15 @@ func (gc *GroupCreate) sqlSave(ctx context.Context) (*Group, error) {
return nil, err
}
insert := builder.Insert(group.Table).Default()
if value := gc.id; value != nil {
insert.Set(group.FieldID, *value)
gr.ID = *value
}
id, err := insertLastID(ctx, tx, insert.Returning(group.FieldID))
if err != nil {
return nil, rollback(tx, err)
}
gr.ID = strconv.FormatInt(id, 10)
gr.ID = int(id)
if len(gc.users) > 0 {
for eid := range gc.users {

View File

@@ -92,8 +92,8 @@ func (gq *GroupQuery) FirstX(ctx context.Context) *Group {
}
// FirstID returns the first Group id in the query. Returns *ErrNotFound when no id was found.
func (gq *GroupQuery) FirstID(ctx context.Context) (id string, err error) {
var ids []string
func (gq *GroupQuery) FirstID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = gq.Limit(1).IDs(ctx); err != nil {
return
}
@@ -105,7 +105,7 @@ func (gq *GroupQuery) FirstID(ctx context.Context) (id string, err error) {
}
// FirstXID is like FirstID, but panics if an error occurs.
func (gq *GroupQuery) FirstXID(ctx context.Context) string {
func (gq *GroupQuery) FirstXID(ctx context.Context) int {
id, err := gq.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
@@ -139,8 +139,8 @@ func (gq *GroupQuery) OnlyX(ctx context.Context) *Group {
}
// OnlyID returns the only Group id in the query, returns an error if not exactly one id was returned.
func (gq *GroupQuery) OnlyID(ctx context.Context) (id string, err error) {
var ids []string
func (gq *GroupQuery) OnlyID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = gq.Limit(2).IDs(ctx); err != nil {
return
}
@@ -156,7 +156,7 @@ func (gq *GroupQuery) OnlyID(ctx context.Context) (id string, err error) {
}
// OnlyXID is like OnlyID, but panics if an error occurs.
func (gq *GroupQuery) OnlyXID(ctx context.Context) string {
func (gq *GroupQuery) OnlyXID(ctx context.Context) int {
id, err := gq.OnlyID(ctx)
if err != nil {
panic(err)
@@ -179,8 +179,8 @@ func (gq *GroupQuery) AllX(ctx context.Context) []*Group {
}
// IDs executes the query and returns a list of Group ids.
func (gq *GroupQuery) IDs(ctx context.Context) ([]string, error) {
var ids []string
func (gq *GroupQuery) IDs(ctx context.Context) ([]int, error) {
var ids []int
if err := gq.Select(group.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
@@ -188,7 +188,7 @@ func (gq *GroupQuery) IDs(ctx context.Context) ([]string, error) {
}
// IDsX is like IDs, but panics if an error occurs.
func (gq *GroupQuery) IDsX(ctx context.Context) []string {
func (gq *GroupQuery) IDsX(ctx context.Context) []int {
ids, err := gq.IDs(ctx)
if err != nil {
panic(err)

View File

@@ -162,7 +162,7 @@ func (gu *GroupUpdate) sqlSave(ctx context.Context) (n int, err error) {
// GroupUpdateOne is the builder for updating a single Group entity.
type GroupUpdateOne struct {
config
id string
id int
users map[int]struct{}
removedUsers map[int]struct{}
}
@@ -253,7 +253,7 @@ func (guo *GroupUpdateOne) sqlSave(ctx context.Context) (gr *Group, err error) {
if err := gr.FromRows(rows); err != nil {
return nil, fmt.Errorf("ent: failed scanning row into Group: %v", err)
}
id = gr.id()
id = gr.ID
ids = append(ids, id)
}
switch n := len(ids); {

View File

@@ -32,7 +32,7 @@ var (
}
// GroupUsersColumns holds the columns for the "group_users" table.
GroupUsersColumns = []*schema.Column{
{Name: "group_id", Type: field.TypeString},
{Name: "group_id", Type: field.TypeInt},
{Name: "user_id", Type: field.TypeInt},
}
// GroupUsersTable holds the schema information for the "group_users" table.

View File

@@ -1,3 +1,7 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package schema
import (
@@ -14,7 +18,7 @@ type Group struct {
// Fields of the Group.
func (Group) Fields() []ent.Field {
return []ent.Field{
field.String("id"),
field.Int("id"),
}
}

View File

@@ -1,3 +1,7 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package schema
import (

View File

@@ -4,7 +4,6 @@ package ent
import (
"context"
"strconv"
"github.com/facebookincubator/ent/dialect/sql"
"github.com/facebookincubator/ent/entc/integration/customid/ent/user"
@@ -14,7 +13,7 @@ import (
type UserCreate struct {
config
id *int
groups map[string]struct{}
groups map[int]struct{}
}
// SetID sets the id field.
@@ -24,9 +23,9 @@ func (uc *UserCreate) SetID(i int) *UserCreate {
}
// AddGroupIDs adds the groups edge to Group by ids.
func (uc *UserCreate) AddGroupIDs(ids ...string) *UserCreate {
func (uc *UserCreate) AddGroupIDs(ids ...int) *UserCreate {
if uc.groups == nil {
uc.groups = make(map[string]struct{})
uc.groups = make(map[int]struct{})
}
for i := range ids {
uc.groups[ids[i]] = struct{}{}
@@ -36,7 +35,7 @@ func (uc *UserCreate) AddGroupIDs(ids ...string) *UserCreate {
// AddGroups adds the groups edges to Group.
func (uc *UserCreate) AddGroups(g ...*Group) *UserCreate {
ids := make([]string, len(g))
ids := make([]int, len(g))
for i := range g {
ids[i] = g[i].ID
}
@@ -68,6 +67,10 @@ func (uc *UserCreate) sqlSave(ctx context.Context) (*User, error) {
return nil, err
}
insert := builder.Insert(user.Table).Default()
if value := uc.id; value != nil {
insert.Set(user.FieldID, *value)
u.ID = *value
}
id, err := insertLastID(ctx, tx, insert.Returning(user.FieldID))
if err != nil {
return nil, rollback(tx, err)
@@ -75,10 +78,6 @@ func (uc *UserCreate) sqlSave(ctx context.Context) (*User, error) {
u.ID = int(id)
if len(uc.groups) > 0 {
for eid := range uc.groups {
eid, err := strconv.Atoi(eid)
if err != nil {
return nil, rollback(tx, err)
}
query, args := builder.Insert(user.GroupsTable).
Columns(user.GroupsPrimaryKey[1], user.GroupsPrimaryKey[0]).

View File

@@ -5,7 +5,6 @@ package ent
import (
"context"
"fmt"
"strconv"
"github.com/facebookincubator/ent/dialect/sql"
"github.com/facebookincubator/ent/entc/integration/customid/ent/predicate"
@@ -15,8 +14,8 @@ import (
// UserUpdate is the builder for updating User entities.
type UserUpdate struct {
config
groups map[string]struct{}
removedGroups map[string]struct{}
groups map[int]struct{}
removedGroups map[int]struct{}
predicates []predicate.User
}
@@ -27,9 +26,9 @@ func (uu *UserUpdate) Where(ps ...predicate.User) *UserUpdate {
}
// AddGroupIDs adds the groups edge to Group by ids.
func (uu *UserUpdate) AddGroupIDs(ids ...string) *UserUpdate {
func (uu *UserUpdate) AddGroupIDs(ids ...int) *UserUpdate {
if uu.groups == nil {
uu.groups = make(map[string]struct{})
uu.groups = make(map[int]struct{})
}
for i := range ids {
uu.groups[ids[i]] = struct{}{}
@@ -39,7 +38,7 @@ func (uu *UserUpdate) AddGroupIDs(ids ...string) *UserUpdate {
// AddGroups adds the groups edges to Group.
func (uu *UserUpdate) AddGroups(g ...*Group) *UserUpdate {
ids := make([]string, len(g))
ids := make([]int, len(g))
for i := range g {
ids[i] = g[i].ID
}
@@ -47,9 +46,9 @@ func (uu *UserUpdate) AddGroups(g ...*Group) *UserUpdate {
}
// RemoveGroupIDs removes the groups edge to Group by ids.
func (uu *UserUpdate) RemoveGroupIDs(ids ...string) *UserUpdate {
func (uu *UserUpdate) RemoveGroupIDs(ids ...int) *UserUpdate {
if uu.removedGroups == nil {
uu.removedGroups = make(map[string]struct{})
uu.removedGroups = make(map[int]struct{})
}
for i := range ids {
uu.removedGroups[ids[i]] = struct{}{}
@@ -59,7 +58,7 @@ func (uu *UserUpdate) RemoveGroupIDs(ids ...string) *UserUpdate {
// RemoveGroups removes groups edges to Group.
func (uu *UserUpdate) RemoveGroups(g ...*Group) *UserUpdate {
ids := make([]string, len(g))
ids := make([]int, len(g))
for i := range g {
ids[i] = g[i].ID
}
@@ -127,11 +126,6 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
if len(uu.removedGroups) > 0 {
eids := make([]int, len(uu.removedGroups))
for eid := range uu.removedGroups {
eid, serr := strconv.Atoi(eid)
if serr != nil {
err = rollback(tx, serr)
return
}
eids = append(eids, eid)
}
query, args := builder.Delete(user.GroupsTable).
@@ -146,11 +140,6 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
values := make([][]int, 0, len(ids))
for _, id := range ids {
for eid := range uu.groups {
eid, serr := strconv.Atoi(eid)
if serr != nil {
err = rollback(tx, serr)
return
}
values = append(values, []int{id, eid})
}
}
@@ -174,14 +163,14 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
type UserUpdateOne struct {
config
id int
groups map[string]struct{}
removedGroups map[string]struct{}
groups map[int]struct{}
removedGroups map[int]struct{}
}
// AddGroupIDs adds the groups edge to Group by ids.
func (uuo *UserUpdateOne) AddGroupIDs(ids ...string) *UserUpdateOne {
func (uuo *UserUpdateOne) AddGroupIDs(ids ...int) *UserUpdateOne {
if uuo.groups == nil {
uuo.groups = make(map[string]struct{})
uuo.groups = make(map[int]struct{})
}
for i := range ids {
uuo.groups[ids[i]] = struct{}{}
@@ -191,7 +180,7 @@ func (uuo *UserUpdateOne) AddGroupIDs(ids ...string) *UserUpdateOne {
// AddGroups adds the groups edges to Group.
func (uuo *UserUpdateOne) AddGroups(g ...*Group) *UserUpdateOne {
ids := make([]string, len(g))
ids := make([]int, len(g))
for i := range g {
ids[i] = g[i].ID
}
@@ -199,9 +188,9 @@ func (uuo *UserUpdateOne) AddGroups(g ...*Group) *UserUpdateOne {
}
// RemoveGroupIDs removes the groups edge to Group by ids.
func (uuo *UserUpdateOne) RemoveGroupIDs(ids ...string) *UserUpdateOne {
func (uuo *UserUpdateOne) RemoveGroupIDs(ids ...int) *UserUpdateOne {
if uuo.removedGroups == nil {
uuo.removedGroups = make(map[string]struct{})
uuo.removedGroups = make(map[int]struct{})
}
for i := range ids {
uuo.removedGroups[ids[i]] = struct{}{}
@@ -211,7 +200,7 @@ func (uuo *UserUpdateOne) RemoveGroupIDs(ids ...string) *UserUpdateOne {
// RemoveGroups removes groups edges to Group.
func (uuo *UserUpdateOne) RemoveGroups(g ...*Group) *UserUpdateOne {
ids := make([]string, len(g))
ids := make([]int, len(g))
for i := range g {
ids[i] = g[i].ID
}
@@ -282,11 +271,6 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) {
if len(uuo.removedGroups) > 0 {
eids := make([]int, len(uuo.removedGroups))
for eid := range uuo.removedGroups {
eid, serr := strconv.Atoi(eid)
if serr != nil {
err = rollback(tx, serr)
return
}
eids = append(eids, eid)
}
query, args := builder.Delete(user.GroupsTable).
@@ -301,11 +285,6 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (u *User, err error) {
values := make([][]int, 0, len(ids))
for _, id := range ids {
for eid := range uuo.groups {
eid, serr := strconv.Atoi(eid)
if serr != nil {
err = rollback(tx, serr)
return
}
values = append(values, []int{id, eid})
}
}

View File

@@ -2,7 +2,7 @@
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package template
package idtype
import (
"context"