entc/gen: resolve merge conflicts on global id file (#4312)

This commit is contained in:
Jannik Clausen
2025-01-23 13:59:47 +01:00
committed by GitHub
parent ec5bfadcab
commit 0edce5f3d6
9 changed files with 137 additions and 41 deletions

View File

@@ -135,7 +135,7 @@ type Annotation struct {
//
// By default, this value is nil defaulting to whatever the database settings are.
//
IncrementStart *int64 `json:"increment_start,omitempty"`
IncrementStart *int `json:"increment_start,omitempty"`
// OnDelete specifies a custom referential action for DELETE operations on parent
// table that has matching rows in the child table.
@@ -384,7 +384,7 @@ func OnDelete(opt ReferenceOption) *Annotation {
// entsql.IncrementStart(100),
// }
// }
func IncrementStart(i int64) *Annotation {
func IncrementStart(i int) *Annotation {
return &Annotation{
IncrementStart: &i,
}

View File

@@ -860,7 +860,7 @@ func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) {
}
a.sqlDialect.atIncrementT(at, r)
case et.Annotation != nil && et.Annotation.IncrementStart != nil:
a.sqlDialect.atIncrementT(at, *et.Annotation.IncrementStart)
a.sqlDialect.atIncrementT(at, int64(*et.Annotation.IncrementStart))
}
if err := a.aColumns(et, at); err != nil {
return nil, err

View File

@@ -376,7 +376,7 @@ func TestAtlas_StateReader(t *testing.T) {
{Name: "active", Type: field.TypeBool},
},
Annotation: &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(100),
IncrementStart: func(i int) *int { return &i }(100),
},
}).ReadState(context.Background())
require.NoError(t, err)

View File

@@ -383,7 +383,7 @@ func generate(schemaPath string, cfg *gen.Config) error {
}
func mayRecover(err error, schemaPath string, cfg *gen.Config) error {
if enabled, _ := cfg.FeatureEnabled(gen.FeatureSnapshot.Name); !enabled {
if ok, _ := cfg.FeatureEnabled(gen.FeatureSnapshot.Name); !ok {
return err
}
if !errors.As(err, &packages.Error{}) && !internal.IsBuildError(err) {
@@ -394,6 +394,14 @@ func mayRecover(err error, schemaPath string, cfg *gen.Config) error {
return fmt.Errorf("schema failure: %w", err)
}
target := filepath.Join(cfg.Target, "internal/schema.go")
if ok, _ := cfg.FeatureEnabled(gen.FeatureGlobalID.Name); ok {
if internal.CheckDir(gen.IncrementStartsFilePath(target)) != nil {
// Resolve the conflict by accepting the remote version of the file.
if err := gen.ResolveIncrementStartsConflict(cfg.Target); err != nil {
return err
}
}
}
return (&internal.Snapshot{Path: target, Config: cfg}).Restore()
}

View File

@@ -18,14 +18,14 @@ import (
const incrementIdent = "const IncrementStarts"
// IncrementStarts holds the autoincrement start value for each type.
type IncrementStarts map[string]int64
type IncrementStarts map[string]int
// IncrementStartAnnotation assigns a unique range to each node in the graph.
func IncrementStartAnnotation(g *Graph) error {
// To ensure we keep the existing type ranges, load the current global id sequence, if there is any.
var (
r = make(IncrementStarts)
path = rangesFilePath(g.Target)
path = IncrementStartsFilePath(g.Target)
buf, err = os.ReadFile(path)
)
switch {
@@ -63,8 +63,8 @@ func IncrementStartAnnotation(g *Graph) error {
}
// Range over all nodes and assign the increment starting value.
var (
need []*Type
last int64
need []*Type
lastIdx = -1
)
for _, n := range g.Nodes {
if n.Annotations == nil {
@@ -86,18 +86,19 @@ func IncrementStartAnnotation(g *Graph) error {
// In case this is a new node, it gets the next free increment range (highest value << 32).
need = append(need, n)
}
last = max(last, r[n.Table()])
if v, ok := r[n.Table()]; ok {
lastIdx = max(lastIdx, v/(1<<32-1))
}
}
// Compute new ranges and write them back to the file.
s := len(g.Nodes) - len(need) // number of nodes with existing increment values
for i, n := range need {
r[n.Table()] = last + int64(s+i)<<32
r[n.Table()] = (lastIdx + i + 1) << 32
a := n.EntSQL()
a.IncrementStart = func(i int64) *int64 { return &i }(r[n.Table()]) // copy to not override previous values
a.IncrementStart = func(i int) *int { return &i }(r[n.Table()]) // copy to not override previous values
n.Annotations[a.Name()] = a
}
// Ensure increment ranges are exactly of size 1<<32 with no overlaps.
d := make(map[int64]string)
d := make(map[int]string)
for t, s := range r {
switch t1, ok := d[s]; {
case ok:
@@ -124,7 +125,7 @@ func (IncrementStarts) Name() string {
// WriteToDisk writes the increment starts to the disk.
func (i IncrementStarts) WriteToDisk(target string) error {
initTemplates()
p := rangesFilePath(target)
p := IncrementStartsFilePath(target)
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
return err
}
@@ -140,6 +141,37 @@ func (i IncrementStarts) WriteToDisk(target string) error {
})
}
func rangesFilePath(dir string) string {
func IncrementStartsFilePath(dir string) string {
return filepath.Join(dir, "internal", "globalid.go")
}
// ResolveIncrementStartsConflict resolves git/mercurial conflicts by "accepting theirs".
func ResolveIncrementStartsConflict(dir string) error {
// Expect 2 ranges in the file, accept the second one, since this is the remote content.
p := IncrementStartsFilePath(dir)
fi, err := os.Stat(p)
if err != nil {
return err
}
c, err := os.ReadFile(p)
if err != nil {
return err
}
var (
fixed [][]byte
skipped bool
)
lines := bytes.Split(c, []byte("\n"))
for _, l := range lines {
switch {
case bytes.HasPrefix(l, []byte("<<<<<<<")),
bytes.HasPrefix(l, []byte("=======")),
bytes.HasPrefix(l, []byte(">>>>>>>")):
case bytes.HasPrefix(l, []byte(incrementIdent)) && !skipped:
skipped = true
default:
fixed = append(fixed, l)
}
}
return os.WriteFile(p, bytes.Join(fixed, []byte("\n")), fi.Mode())
}

View File

@@ -5,17 +5,23 @@
package gen_test
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
"testing"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/entc/gen"
"entgo.io/ent/entc/internal"
"entgo.io/ent/entc/load"
"github.com/stretchr/testify/require"
)
func TestIncrementStartAnnotation(t *testing.T) {
var (
p = func(i int64) *int64 { return &i }
p = func(i int) *int { return &i }
a = &entsql.Annotation{IncrementStart: p(100)}
s = []*load.Schema{
{
@@ -53,10 +59,62 @@ func TestIncrementStartAnnotation(t *testing.T) {
require.Nil(t, g)
// Respects existing increment starting values loaded from file.
s = []*load.Schema{{Name: "A"}, {Name: "B"}, {Name: "C"}}
c.Target = t.TempDir()
is := gen.IncrementStarts{"bs": 0, "as": 1 << 32, "cs": 2 << 32}
require.NoError(t, is.WriteToDisk(c.Target))
g, err = gen.NewGraph(c, &load.Schema{Name: "A"}, &load.Schema{Name: "B"}, &load.Schema{Name: "C"})
g, err = gen.NewGraph(c, s...)
require.NoError(t, err)
require.Equal(t, is, g.Annotations[is.Name()])
}
func TestResolveConflicts(t *testing.T) {
var (
c = &gen.Config{
Package: "entc/gen",
Target: t.TempDir(),
Features: []gen.Feature{gen.FeatureGlobalID},
}
s = []*load.Schema{{Name: "A"}, {Name: "B"}, {Name: "C"}, {Name: "D"}}
cflct = fmt.Sprintf(`
package internal
<<<<<<< HEAD:globalid.go
const IncrementStarts = %s
=======
const IncrementStarts = %s
>>>>>>> 1234567:globalid.go
`,
// We added the c table which got the range 2<<32
marshal(t, gen.IncrementStarts{"bs": 0, "as": 1 << 32, "cs": 2 << 32}),
// In the meantime someone else added the d table which got the
// range 2<<32 as well, and they merged before we did.
marshal(t, gen.IncrementStarts{"bs": 0, "as": 1 << 32, "ds": 2 << 32}),
)
p = filepath.Join(c.Target, "internal", "globalid.go")
)
require.NoError(t, os.MkdirAll(filepath.Dir(p), 0755))
require.NoError(t, os.WriteFile(p, []byte(cflct), 0644))
// Expect an error when there is a file conflict.
// g, err := gen.NewGraph(c, s...)
// require.Error(t, err)
// Conflict is resolved to "accept theirs".
require.NoError(t, gen.ResolveIncrementStartsConflict(c.Target))
require.NoError(t, internal.CheckDir(filepath.Dir(p)))
g, err := gen.NewGraph(c, s...)
require.NoError(t, err)
// Expect the conflict to be resolved with the remote table d keeping
// its range and our newly added table c gets the next one (3<<32).
require.Equal(t,
gen.IncrementStarts{"bs": 0, "as": 1 << 32, "cs": 3 << 32, "ds": 2 << 32},
g.Annotations[(&gen.IncrementStarts{}).Name()],
)
}
func marshal(t *testing.T, v any) string {
t.Helper()
b, err := json.Marshal(v)
require.NoError(t, err)
return strconv.Quote(string(b))
}

View File

@@ -408,9 +408,7 @@ func TestEnsureCorrectFK(t *testing.T) {
func TestGraph_Gen(t *testing.T) {
require := require.New(t)
target := filepath.Join(os.TempDir(), "ent")
require.NoError(os.MkdirAll(target, os.ModePerm), "creating tmpdir")
defer os.RemoveAll(target)
target := filepath.Join(t.TempDir(), "ent")
external := MustParse(NewTemplate("external").Parse("package external"))
skipped := MustParse(NewTemplate("skipped").SkipIf(func(*Graph) bool { return true }).Parse("package external"))
schemas := []*load.Schema{
@@ -444,7 +442,7 @@ func TestGraph_Gen(t *testing.T) {
require.Equal(a, graph.Annotations[a.Name()])
ant := &entsql.Annotation{}
for i, n := range graph.Nodes {
require.Equal(int64(i)<<32, *n.Annotations[ant.Name()].(*entsql.Annotation).IncrementStart)
require.Equal(i<<32, *n.Annotations[ant.Name()].(*entsql.Annotation).IncrementStart)
}
// Ensure graph files were generated.
for _, name := range []string{"ent", "client"} {

View File

@@ -222,7 +222,7 @@ func init() {
Check: {{ quote . }},
{{- end }}
{{- with $ant.IncrementStart }}
IncrementStart: func(i int64) *int64 { return &i }({{ . }}),
IncrementStart: func(i int) *int { return &i }({{ . }}),
{{- end }}
}
{{- with $ant.Incremental }}

View File

@@ -627,56 +627,56 @@ var (
func init() {
ApisTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(12884901888),
IncrementStart: func(i int) *int { return &i }(12884901888),
}
BuildersTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(17179869184),
IncrementStart: func(i int) *int { return &i }(17179869184),
}
CardsTable.ForeignKeys[0].RefTable = UsersTable
CardsTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(21474836480),
IncrementStart: func(i int) *int { return &i }(21474836480),
}
CommentsTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(25769803776),
IncrementStart: func(i int) *int { return &i }(25769803776),
}
ExValueScansTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(30064771072),
IncrementStart: func(i int) *int { return &i }(30064771072),
}
FieldTypesTable.ForeignKeys[0].RefTable = FilesTable
FieldTypesTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(34359738368),
IncrementStart: func(i int) *int { return &i }(34359738368),
}
FilesTable.ForeignKeys[0].RefTable = FileTypesTable
FilesTable.ForeignKeys[1].RefTable = GroupsTable
FilesTable.ForeignKeys[2].RefTable = UsersTable
FilesTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(38654705664),
IncrementStart: func(i int) *int { return &i }(38654705664),
}
FileTypesTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(42949672960),
IncrementStart: func(i int) *int { return &i }(42949672960),
}
GoodsTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(47244640256),
IncrementStart: func(i int) *int { return &i }(47244640256),
}
GroupsTable.ForeignKeys[0].RefTable = GroupInfosTable
GroupsTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(51539607552),
IncrementStart: func(i int) *int { return &i }(51539607552),
}
GroupInfosTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(55834574848),
IncrementStart: func(i int) *int { return &i }(55834574848),
}
ItemsTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(60129542144),
IncrementStart: func(i int) *int { return &i }(60129542144),
}
LicensesTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(64424509440),
IncrementStart: func(i int) *int { return &i }(64424509440),
}
NodesTable.ForeignKeys[0].RefTable = NodesTable
NodesTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(68719476736),
IncrementStart: func(i int) *int { return &i }(68719476736),
}
PcsTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(73014444032),
IncrementStart: func(i int) *int { return &i }(73014444032),
}
PetTable.ForeignKeys[0].RefTable = UsersTable
PetTable.ForeignKeys[1].RefTable = UsersTable
@@ -684,16 +684,16 @@ func init() {
Table: "pet",
}
SpecsTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(81604378624),
IncrementStart: func(i int) *int { return &i }(81604378624),
}
TasksTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(85899345920),
IncrementStart: func(i int) *int { return &i }(85899345920),
}
UsersTable.ForeignKeys[0].RefTable = GroupsTable
UsersTable.ForeignKeys[1].RefTable = UsersTable
UsersTable.ForeignKeys[2].RefTable = UsersTable
UsersTable.Annotation = &entsql.Annotation{
IncrementStart: func(i int64) *int64 { return &i }(8589934592),
IncrementStart: func(i int) *int { return &i }(8589934592),
}
SpecCardTable.ForeignKeys[0].RefTable = SpecsTable
SpecCardTable.ForeignKeys[1].RefTable = CardsTable