entc/gen: resolve merge conflicts on global id file (#4312)

This commit is contained in:
Jannik Clausen
2025-01-23 13:59:47 +01:00
committed by GitHub
parent ec5bfadcab
commit 0edce5f3d6
9 changed files with 137 additions and 41 deletions

View File

@@ -18,14 +18,14 @@ import (
const incrementIdent = "const IncrementStarts"
// IncrementStarts holds the autoincrement start value for each type.
type IncrementStarts map[string]int64
type IncrementStarts map[string]int
// IncrementStartAnnotation assigns a unique range to each node in the graph.
func IncrementStartAnnotation(g *Graph) error {
// To ensure we keep the existing type ranges, load the current global id sequence, if there is any.
var (
r = make(IncrementStarts)
path = rangesFilePath(g.Target)
path = IncrementStartsFilePath(g.Target)
buf, err = os.ReadFile(path)
)
switch {
@@ -63,8 +63,8 @@ func IncrementStartAnnotation(g *Graph) error {
}
// Range over all nodes and assign the increment starting value.
var (
need []*Type
last int64
need []*Type
lastIdx = -1
)
for _, n := range g.Nodes {
if n.Annotations == nil {
@@ -86,18 +86,19 @@ func IncrementStartAnnotation(g *Graph) error {
// In case this is a new node, it gets the next free increment range (highest value << 32).
need = append(need, n)
}
last = max(last, r[n.Table()])
if v, ok := r[n.Table()]; ok {
lastIdx = max(lastIdx, v/(1<<32-1))
}
}
// Compute new ranges and write them back to the file.
s := len(g.Nodes) - len(need) // number of nodes with existing increment values
for i, n := range need {
r[n.Table()] = last + int64(s+i)<<32
r[n.Table()] = (lastIdx + i + 1) << 32
a := n.EntSQL()
a.IncrementStart = func(i int64) *int64 { return &i }(r[n.Table()]) // copy to not override previous values
a.IncrementStart = func(i int) *int { return &i }(r[n.Table()]) // copy to not override previous values
n.Annotations[a.Name()] = a
}
// Ensure increment ranges are exactly of size 1<<32 with no overlaps.
d := make(map[int64]string)
d := make(map[int]string)
for t, s := range r {
switch t1, ok := d[s]; {
case ok:
@@ -124,7 +125,7 @@ func (IncrementStarts) Name() string {
// WriteToDisk writes the increment starts to the disk.
func (i IncrementStarts) WriteToDisk(target string) error {
initTemplates()
p := rangesFilePath(target)
p := IncrementStartsFilePath(target)
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
return err
}
@@ -140,6 +141,37 @@ func (i IncrementStarts) WriteToDisk(target string) error {
})
}
func rangesFilePath(dir string) string {
func IncrementStartsFilePath(dir string) string {
return filepath.Join(dir, "internal", "globalid.go")
}
// ResolveIncrementStartsConflict resolves git/mercurial conflicts by "accepting theirs".
func ResolveIncrementStartsConflict(dir string) error {
// Expect 2 ranges in the file, accept the second one, since this is the remote content.
p := IncrementStartsFilePath(dir)
fi, err := os.Stat(p)
if err != nil {
return err
}
c, err := os.ReadFile(p)
if err != nil {
return err
}
var (
fixed [][]byte
skipped bool
)
lines := bytes.Split(c, []byte("\n"))
for _, l := range lines {
switch {
case bytes.HasPrefix(l, []byte("<<<<<<<")),
bytes.HasPrefix(l, []byte("=======")),
bytes.HasPrefix(l, []byte(">>>>>>>")):
case bytes.HasPrefix(l, []byte(incrementIdent)) && !skipped:
skipped = true
default:
fixed = append(fixed, l)
}
}
return os.WriteFile(p, bytes.Join(fixed, []byte("\n")), fi.Mode())
}

View File

@@ -5,17 +5,23 @@
package gen_test
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
"testing"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/entc/gen"
"entgo.io/ent/entc/internal"
"entgo.io/ent/entc/load"
"github.com/stretchr/testify/require"
)
func TestIncrementStartAnnotation(t *testing.T) {
var (
p = func(i int64) *int64 { return &i }
p = func(i int) *int { return &i }
a = &entsql.Annotation{IncrementStart: p(100)}
s = []*load.Schema{
{
@@ -53,10 +59,62 @@ func TestIncrementStartAnnotation(t *testing.T) {
require.Nil(t, g)
// Respects existing increment starting values loaded from file.
s = []*load.Schema{{Name: "A"}, {Name: "B"}, {Name: "C"}}
c.Target = t.TempDir()
is := gen.IncrementStarts{"bs": 0, "as": 1 << 32, "cs": 2 << 32}
require.NoError(t, is.WriteToDisk(c.Target))
g, err = gen.NewGraph(c, &load.Schema{Name: "A"}, &load.Schema{Name: "B"}, &load.Schema{Name: "C"})
g, err = gen.NewGraph(c, s...)
require.NoError(t, err)
require.Equal(t, is, g.Annotations[is.Name()])
}
func TestResolveConflicts(t *testing.T) {
var (
c = &gen.Config{
Package: "entc/gen",
Target: t.TempDir(),
Features: []gen.Feature{gen.FeatureGlobalID},
}
s = []*load.Schema{{Name: "A"}, {Name: "B"}, {Name: "C"}, {Name: "D"}}
cflct = fmt.Sprintf(`
package internal
<<<<<<< HEAD:globalid.go
const IncrementStarts = %s
=======
const IncrementStarts = %s
>>>>>>> 1234567:globalid.go
`,
// We added the c table which got the range 2<<32
marshal(t, gen.IncrementStarts{"bs": 0, "as": 1 << 32, "cs": 2 << 32}),
// In the meantime someone else added the d table which got the
// range 2<<32 as well, and they merged before we did.
marshal(t, gen.IncrementStarts{"bs": 0, "as": 1 << 32, "ds": 2 << 32}),
)
p = filepath.Join(c.Target, "internal", "globalid.go")
)
require.NoError(t, os.MkdirAll(filepath.Dir(p), 0755))
require.NoError(t, os.WriteFile(p, []byte(cflct), 0644))
// Expect an error when there is a file conflict.
// g, err := gen.NewGraph(c, s...)
// require.Error(t, err)
// Conflict is resolved to "accept theirs".
require.NoError(t, gen.ResolveIncrementStartsConflict(c.Target))
require.NoError(t, internal.CheckDir(filepath.Dir(p)))
g, err := gen.NewGraph(c, s...)
require.NoError(t, err)
// Expect the conflict to be resolved with the remote table d keeping
// its range and our newly added table c gets the next one (3<<32).
require.Equal(t,
gen.IncrementStarts{"bs": 0, "as": 1 << 32, "cs": 3 << 32, "ds": 2 << 32},
g.Annotations[(&gen.IncrementStarts{}).Name()],
)
}
func marshal(t *testing.T, v any) string {
t.Helper()
b, err := json.Marshal(v)
require.NoError(t, err)
return strconv.Quote(string(b))
}

View File

@@ -408,9 +408,7 @@ func TestEnsureCorrectFK(t *testing.T) {
func TestGraph_Gen(t *testing.T) {
require := require.New(t)
target := filepath.Join(os.TempDir(), "ent")
require.NoError(os.MkdirAll(target, os.ModePerm), "creating tmpdir")
defer os.RemoveAll(target)
target := filepath.Join(t.TempDir(), "ent")
external := MustParse(NewTemplate("external").Parse("package external"))
skipped := MustParse(NewTemplate("skipped").SkipIf(func(*Graph) bool { return true }).Parse("package external"))
schemas := []*load.Schema{
@@ -444,7 +442,7 @@ func TestGraph_Gen(t *testing.T) {
require.Equal(a, graph.Annotations[a.Name()])
ant := &entsql.Annotation{}
for i, n := range graph.Nodes {
require.Equal(int64(i)<<32, *n.Annotations[ant.Name()].(*entsql.Annotation).IncrementStart)
require.Equal(i<<32, *n.Annotations[ant.Name()].(*entsql.Annotation).IncrementStart)
}
// Ensure graph files were generated.
for _, name := range []string{"ent", "client"} {

View File

@@ -222,7 +222,7 @@ func init() {
Check: {{ quote . }},
{{- end }}
{{- with $ant.IncrementStart }}
IncrementStart: func(i int64) *int64 { return &i }({{ . }}),
IncrementStart: func(i int) *int { return &i }({{ . }}),
{{- end }}
}
{{- with $ant.Incremental }}