mirror of
https://github.com/ent/ent.git
synced 2026-04-28 05:30:56 +03:00
entc/gen: resolve merge conflicts on global id file (#4312)
This commit is contained in:
@@ -18,14 +18,14 @@ import (
|
||||
const incrementIdent = "const IncrementStarts"
|
||||
|
||||
// IncrementStarts holds the autoincrement start value for each type.
|
||||
type IncrementStarts map[string]int64
|
||||
type IncrementStarts map[string]int
|
||||
|
||||
// IncrementStartAnnotation assigns a unique range to each node in the graph.
|
||||
func IncrementStartAnnotation(g *Graph) error {
|
||||
// To ensure we keep the existing type ranges, load the current global id sequence, if there is any.
|
||||
var (
|
||||
r = make(IncrementStarts)
|
||||
path = rangesFilePath(g.Target)
|
||||
path = IncrementStartsFilePath(g.Target)
|
||||
buf, err = os.ReadFile(path)
|
||||
)
|
||||
switch {
|
||||
@@ -63,8 +63,8 @@ func IncrementStartAnnotation(g *Graph) error {
|
||||
}
|
||||
// Range over all nodes and assign the increment starting value.
|
||||
var (
|
||||
need []*Type
|
||||
last int64
|
||||
need []*Type
|
||||
lastIdx = -1
|
||||
)
|
||||
for _, n := range g.Nodes {
|
||||
if n.Annotations == nil {
|
||||
@@ -86,18 +86,19 @@ func IncrementStartAnnotation(g *Graph) error {
|
||||
// In case this is a new node, it gets the next free increment range (highest value << 32).
|
||||
need = append(need, n)
|
||||
}
|
||||
last = max(last, r[n.Table()])
|
||||
if v, ok := r[n.Table()]; ok {
|
||||
lastIdx = max(lastIdx, v/(1<<32-1))
|
||||
}
|
||||
}
|
||||
// Compute new ranges and write them back to the file.
|
||||
s := len(g.Nodes) - len(need) // number of nodes with existing increment values
|
||||
for i, n := range need {
|
||||
r[n.Table()] = last + int64(s+i)<<32
|
||||
r[n.Table()] = (lastIdx + i + 1) << 32
|
||||
a := n.EntSQL()
|
||||
a.IncrementStart = func(i int64) *int64 { return &i }(r[n.Table()]) // copy to not override previous values
|
||||
a.IncrementStart = func(i int) *int { return &i }(r[n.Table()]) // copy to not override previous values
|
||||
n.Annotations[a.Name()] = a
|
||||
}
|
||||
// Ensure increment ranges are exactly of size 1<<32 with no overlaps.
|
||||
d := make(map[int64]string)
|
||||
d := make(map[int]string)
|
||||
for t, s := range r {
|
||||
switch t1, ok := d[s]; {
|
||||
case ok:
|
||||
@@ -124,7 +125,7 @@ func (IncrementStarts) Name() string {
|
||||
// WriteToDisk writes the increment starts to the disk.
|
||||
func (i IncrementStarts) WriteToDisk(target string) error {
|
||||
initTemplates()
|
||||
p := rangesFilePath(target)
|
||||
p := IncrementStartsFilePath(target)
|
||||
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -140,6 +141,37 @@ func (i IncrementStarts) WriteToDisk(target string) error {
|
||||
})
|
||||
}
|
||||
|
||||
func rangesFilePath(dir string) string {
|
||||
func IncrementStartsFilePath(dir string) string {
|
||||
return filepath.Join(dir, "internal", "globalid.go")
|
||||
}
|
||||
|
||||
// ResolveIncrementStartsConflict resolves git/mercurial conflicts by "accepting theirs".
|
||||
func ResolveIncrementStartsConflict(dir string) error {
|
||||
// Expect 2 ranges in the file, accept the second one, since this is the remote content.
|
||||
p := IncrementStartsFilePath(dir)
|
||||
fi, err := os.Stat(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c, err := os.ReadFile(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
fixed [][]byte
|
||||
skipped bool
|
||||
)
|
||||
lines := bytes.Split(c, []byte("\n"))
|
||||
for _, l := range lines {
|
||||
switch {
|
||||
case bytes.HasPrefix(l, []byte("<<<<<<<")),
|
||||
bytes.HasPrefix(l, []byte("=======")),
|
||||
bytes.HasPrefix(l, []byte(">>>>>>>")):
|
||||
case bytes.HasPrefix(l, []byte(incrementIdent)) && !skipped:
|
||||
skipped = true
|
||||
default:
|
||||
fixed = append(fixed, l)
|
||||
}
|
||||
}
|
||||
return os.WriteFile(p, bytes.Join(fixed, []byte("\n")), fi.Mode())
|
||||
}
|
||||
|
||||
@@ -5,17 +5,23 @@
|
||||
package gen_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"entgo.io/ent/dialect/entsql"
|
||||
"entgo.io/ent/entc/gen"
|
||||
"entgo.io/ent/entc/internal"
|
||||
"entgo.io/ent/entc/load"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIncrementStartAnnotation(t *testing.T) {
|
||||
var (
|
||||
p = func(i int64) *int64 { return &i }
|
||||
p = func(i int) *int { return &i }
|
||||
a = &entsql.Annotation{IncrementStart: p(100)}
|
||||
s = []*load.Schema{
|
||||
{
|
||||
@@ -53,10 +59,62 @@ func TestIncrementStartAnnotation(t *testing.T) {
|
||||
require.Nil(t, g)
|
||||
|
||||
// Respects existing increment starting values loaded from file.
|
||||
s = []*load.Schema{{Name: "A"}, {Name: "B"}, {Name: "C"}}
|
||||
c.Target = t.TempDir()
|
||||
is := gen.IncrementStarts{"bs": 0, "as": 1 << 32, "cs": 2 << 32}
|
||||
require.NoError(t, is.WriteToDisk(c.Target))
|
||||
g, err = gen.NewGraph(c, &load.Schema{Name: "A"}, &load.Schema{Name: "B"}, &load.Schema{Name: "C"})
|
||||
g, err = gen.NewGraph(c, s...)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, is, g.Annotations[is.Name()])
|
||||
}
|
||||
|
||||
func TestResolveConflicts(t *testing.T) {
|
||||
var (
|
||||
c = &gen.Config{
|
||||
Package: "entc/gen",
|
||||
Target: t.TempDir(),
|
||||
Features: []gen.Feature{gen.FeatureGlobalID},
|
||||
}
|
||||
s = []*load.Schema{{Name: "A"}, {Name: "B"}, {Name: "C"}, {Name: "D"}}
|
||||
cflct = fmt.Sprintf(`
|
||||
package internal
|
||||
|
||||
<<<<<<< HEAD:globalid.go
|
||||
const IncrementStarts = %s
|
||||
=======
|
||||
const IncrementStarts = %s
|
||||
>>>>>>> 1234567:globalid.go
|
||||
`,
|
||||
// We added the c table which got the range 2<<32
|
||||
marshal(t, gen.IncrementStarts{"bs": 0, "as": 1 << 32, "cs": 2 << 32}),
|
||||
// In the meantime someone else added the d table which got the
|
||||
// range 2<<32 as well, and they merged before we did.
|
||||
marshal(t, gen.IncrementStarts{"bs": 0, "as": 1 << 32, "ds": 2 << 32}),
|
||||
)
|
||||
p = filepath.Join(c.Target, "internal", "globalid.go")
|
||||
)
|
||||
require.NoError(t, os.MkdirAll(filepath.Dir(p), 0755))
|
||||
require.NoError(t, os.WriteFile(p, []byte(cflct), 0644))
|
||||
|
||||
// Expect an error when there is a file conflict.
|
||||
// g, err := gen.NewGraph(c, s...)
|
||||
// require.Error(t, err)
|
||||
// Conflict is resolved to "accept theirs".
|
||||
require.NoError(t, gen.ResolveIncrementStartsConflict(c.Target))
|
||||
require.NoError(t, internal.CheckDir(filepath.Dir(p)))
|
||||
g, err := gen.NewGraph(c, s...)
|
||||
require.NoError(t, err)
|
||||
// Expect the conflict to be resolved with the remote table d keeping
|
||||
// its range and our newly added table c gets the next one (3<<32).
|
||||
require.Equal(t,
|
||||
gen.IncrementStarts{"bs": 0, "as": 1 << 32, "cs": 3 << 32, "ds": 2 << 32},
|
||||
g.Annotations[(&gen.IncrementStarts{}).Name()],
|
||||
)
|
||||
}
|
||||
|
||||
func marshal(t *testing.T, v any) string {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(v)
|
||||
require.NoError(t, err)
|
||||
return strconv.Quote(string(b))
|
||||
}
|
||||
|
||||
@@ -408,9 +408,7 @@ func TestEnsureCorrectFK(t *testing.T) {
|
||||
|
||||
func TestGraph_Gen(t *testing.T) {
|
||||
require := require.New(t)
|
||||
target := filepath.Join(os.TempDir(), "ent")
|
||||
require.NoError(os.MkdirAll(target, os.ModePerm), "creating tmpdir")
|
||||
defer os.RemoveAll(target)
|
||||
target := filepath.Join(t.TempDir(), "ent")
|
||||
external := MustParse(NewTemplate("external").Parse("package external"))
|
||||
skipped := MustParse(NewTemplate("skipped").SkipIf(func(*Graph) bool { return true }).Parse("package external"))
|
||||
schemas := []*load.Schema{
|
||||
@@ -444,7 +442,7 @@ func TestGraph_Gen(t *testing.T) {
|
||||
require.Equal(a, graph.Annotations[a.Name()])
|
||||
ant := &entsql.Annotation{}
|
||||
for i, n := range graph.Nodes {
|
||||
require.Equal(int64(i)<<32, *n.Annotations[ant.Name()].(*entsql.Annotation).IncrementStart)
|
||||
require.Equal(i<<32, *n.Annotations[ant.Name()].(*entsql.Annotation).IncrementStart)
|
||||
}
|
||||
// Ensure graph files were generated.
|
||||
for _, name := range []string{"ent", "client"} {
|
||||
|
||||
@@ -222,7 +222,7 @@ func init() {
|
||||
Check: {{ quote . }},
|
||||
{{- end }}
|
||||
{{- with $ant.IncrementStart }}
|
||||
IncrementStart: func(i int64) *int64 { return &i }({{ . }}),
|
||||
IncrementStart: func(i int) *int { return &i }({{ . }}),
|
||||
{{- end }}
|
||||
}
|
||||
{{- with $ant.Incremental }}
|
||||
|
||||
Reference in New Issue
Block a user