From 598ae18ac1ed4aa1fc009bd80b25c651da75169d Mon Sep 17 00:00:00 2001 From: Jannik Clausen <12862103+masseelch@users.noreply.github.com> Date: Mon, 27 Jan 2025 19:40:05 +0100 Subject: [PATCH] entc/gen: fix globalid annotation encoding (#4314) --- entc/gen/globalid.go | 32 ++++++++++++++++++++++---- entc/gen/globalid_test.go | 14 ++++++++--- entc/gen/graph_test.go | 4 +--- entc/integration/ent/migrate/schema.go | 3 ++- 4 files changed, 41 insertions(+), 12 deletions(-) diff --git a/entc/gen/globalid.go b/entc/gen/globalid.go index 1aae7c8f3..9d86218c4 100644 --- a/entc/gen/globalid.go +++ b/entc/gen/globalid.go @@ -76,13 +76,9 @@ func IncrementStartAnnotation(g *Graph) error { lastIdx = -1 ) for _, n := range g.Nodes { - if n.Annotations == nil { - n.Annotations = make(Annotations) - } a := n.EntSQL() if a == nil { a = &entsql.Annotation{} - n.Annotations[a.Name()] = a } switch v, ok := r[n.Table()]; { case a.IncrementStart != nil: @@ -98,13 +94,18 @@ func IncrementStartAnnotation(g *Graph) error { if v, ok := r[n.Table()]; ok { lastIdx = max(lastIdx, v/(1<<32-1)) } + if err := setAnnotation(n, a); err != nil { + return err + } } // Compute new ranges and write them back to the file. for i, n := range need { r[n.Table()] = (lastIdx + i + 1) << 32 a := n.EntSQL() a.IncrementStart = func(i int) *int { return &i }(r[n.Table()]) // copy to not override previous values - n.Annotations[a.Name()] = a + if err := setAnnotation(n, a); err != nil { + return err + } } // Ensure increment ranges are exactly of size 1<<32 with no overlaps. d := make(map[int]string) @@ -184,3 +185,24 @@ func ResolveIncrementStartsConflict(dir string) error { } return os.WriteFile(p, bytes.Join(fixed, []byte("\n")), fi.Mode()) } + +func ToMap(a *entsql.Annotation) (map[string]any, error) { + buf, err := json.Marshal(a) + if err != nil { + return nil, err + } + m := make(map[string]any) + if err = json.Unmarshal(buf, &m); err != nil { + return nil, err + } + return m, nil +} + +func setAnnotation(n *Type, a *entsql.Annotation) error { + m, err := ToMap(a) + if err != nil { + return err + } + n.Annotations.Set(a.Name(), m) + return nil +} diff --git a/entc/gen/globalid_test.go b/entc/gen/globalid_test.go index f2f413696..b5c946426 100644 --- a/entc/gen/globalid_test.go +++ b/entc/gen/globalid_test.go @@ -26,7 +26,7 @@ func TestIncrementStartAnnotation(t *testing.T) { s = []*load.Schema{ { Name: "T1", - Annotations: gen.Annotations{a.Name(): a}, + Annotations: map[string]any{a.Name(): must(gen.ToMap(a))}, }, } c = &gen.Config{ @@ -44,7 +44,8 @@ func TestIncrementStartAnnotation(t *testing.T) { g, err = gen.NewGraph(c, s...) require.EqualError(t, err, "unexpected increment start value 100 for type t1s, expected multiple of 4294967296 (1<<32)") require.Nil(t, g) - a.IncrementStart = p(1 << 32) + a = &entsql.Annotation{IncrementStart: p(1 << 32)} + s[0].Annotations[a.Name()] = must(gen.ToMap(a)) g, err = gen.NewGraph(c, s...) require.NoError(t, err) require.NotNil(t, g) @@ -52,7 +53,7 @@ func TestIncrementStartAnnotation(t *testing.T) { // Duplicated increment starting values are not allowed. s = append(s, &load.Schema{Name: "T2"}, &load.Schema{ Name: "T3", - Annotations: gen.Annotations{a.Name(): &entsql.Annotation{IncrementStart: p(1 << 32)}}, + Annotations: map[string]any{a.Name(): must(gen.ToMap(a))}, }) g, err = gen.NewGraph(c, s...) require.ErrorContains(t, err, "duplicated increment start value 4294967296 for types") @@ -112,6 +113,13 @@ const IncrementStarts = %s ) } +func must[T any](t T, err error) T { + if err != nil { + panic(err) + } + return t +} + func marshal(t *testing.T, v any) string { t.Helper() b, err := json.Marshal(v) diff --git a/entc/gen/graph_test.go b/entc/gen/graph_test.go index 17848f66a..57f3b7e62 100644 --- a/entc/gen/graph_test.go +++ b/entc/gen/graph_test.go @@ -11,7 +11,6 @@ import ( "reflect" "testing" - "entgo.io/ent/dialect/entsql" "entgo.io/ent/entc/load" "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" @@ -440,9 +439,8 @@ func TestGraph_Gen(t *testing.T) { // Ensure globalid feature added annotations. a := IncrementStarts{"t1s": 0, "t2s": 1 << 32, "t3s": 2 << 32} require.Equal(a, graph.Annotations[a.Name()]) - ant := &entsql.Annotation{} for i, n := range graph.Nodes { - require.Equal(i<<32, *n.Annotations[ant.Name()].(*entsql.Annotation).IncrementStart) + require.Equal(i<<32, *n.EntSQL().IncrementStart) } // Ensure graph files were generated. for _, name := range []string{"ent", "client"} { diff --git a/entc/integration/ent/migrate/schema.go b/entc/integration/ent/migrate/schema.go index b27f472ce..3704811a3 100644 --- a/entc/integration/ent/migrate/schema.go +++ b/entc/integration/ent/migrate/schema.go @@ -681,7 +681,8 @@ func init() { PetTable.ForeignKeys[0].RefTable = UsersTable PetTable.ForeignKeys[1].RefTable = UsersTable PetTable.Annotation = &entsql.Annotation{ - Table: "pet", + Table: "pet", + IncrementStart: func(i int) *int { return &i }(77309411328), } SpecsTable.Annotation = &entsql.Annotation{ IncrementStart: func(i int) *int { return &i }(81604378624),