entc: global id feature (#4293)

Feature to add annotations to ent schema to ensure sql tables have sequential auto-increment id columns. Meant to be a better alternative for entgql / gqlgen globally unique id feature.
This commit is contained in:
Jannik Clausen
2025-01-13 20:49:28 +01:00
committed by GitHub
parent adfd86c303
commit 6cfa2288bb
15 changed files with 351 additions and 19 deletions

View File

@@ -144,6 +144,21 @@ var (
Description: "Allows users to work with versioned migrations / migration files",
}
FeatureGlobalID = Feature{
Name: "sql/globalid",
Stage: Experimental,
Default: false,
Description: "Ensures all nodes have a unique global identifier", GraphTemplates: []GraphTemplate{
{
Name: "internal/globalid",
Format: "internal/globalid.go",
},
},
cleanup: func(c *Config) error {
return remove(filepath.Join(c.Target, "internal"), "globalid.go")
},
}
// AllFeatures holds a list of all feature-flags.
AllFeatures = []Feature{
FeaturePrivacy,
@@ -158,6 +173,7 @@ var (
FeatureExecQuery,
FeatureUpsert,
FeatureVersionedMigration,
FeatureGlobalID,
}
// allFeatures includes all public and private features.
allFeatures = append(AllFeatures, featureMultiSchema)

View File

@@ -6,6 +6,7 @@ package gen
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"go/token"
@@ -47,6 +48,7 @@ var (
"indexOf": indexOf,
"join": join,
"joinWords": joinWords,
"json": jsonString,
"isNil": isNil,
"lower": strings.ToLower,
"upper": strings.ToUpper,
@@ -529,3 +531,12 @@ func list[T any](v ...T) []T {
func fail(msg string) (string, error) {
return "", errors.New(msg)
}
// jsonString returns a json encoded value as string.
func jsonString(v any) (string, error) {
b, err := json.Marshal(v)
if err != nil {
return "", nil
}
return string(b), nil
}

126
entc/gen/globalid.go Normal file
View File

@@ -0,0 +1,126 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package gen
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
"entgo.io/ent/dialect/entsql"
)
const incrementIdent = "const IncrementStarts"
// IncrementStarts holds the autoincrement start value for each type.
type IncrementStarts map[string]int64
// IncrementStartAnnotation assigns a unique range to each node in the graph.
func IncrementStartAnnotation(g *Graph) error {
// To ensure we keep the existing type ranges, load the current global id sequence, if there is any.
var (
r = make(IncrementStarts)
path = rangesFilePath(g.Target)
buf, err = os.ReadFile(path)
)
switch {
case os.IsNotExist(err): // first time generation
case err != nil:
return err
default:
var (
matches = make([][]byte, 0, 2)
lines = bytes.Split(buf, []byte("\n"))
)
for i := 0; i < len(lines); i++ {
if l := lines[i]; bytes.HasPrefix(l, []byte(incrementIdent)) {
matches = append(matches, l)
}
}
if len(matches) != 1 {
return fmt.Errorf("expect to have exactly 1 ranges in %s", path)
}
var (
line = matches[0]
start = bytes.IndexByte(line, '"')
end = bytes.LastIndexByte(line, '"')
)
if start == -1 || start >= end {
return fmt.Errorf("unexpected line %s", line)
}
l, err := strconv.Unquote(string(line[start : end+1]))
if err != nil {
return err
}
if err := json.Unmarshal([]byte(l), &r); err != nil {
return fmt.Errorf("unmarshal ranges: %w", err)
}
}
// Range over all nodes and assign the increment starting value.
var (
need []*Type
last int64
)
for _, n := range g.Nodes {
if n.Annotations == nil {
n.Annotations = make(Annotations)
}
a := n.EntSQL()
if a == nil {
a = &entsql.Annotation{}
n.Annotations[a.Name()] = a
}
switch v, ok := r[n.Table()]; {
case a.IncrementStart != nil:
// In case the start value is defined by an annotation already, it has precedence.
r[n.Table()] = *a.IncrementStart
case ok:
// In case this node has no annotation but an existing entry in the increments file.
a.IncrementStart = &v
default:
// In case this is a new node, it gets the next free increment range (highest value << 32).
need = append(need, n)
}
last = max(last, r[n.Table()])
}
// Compute new ranges and write them back to the file.
s := len(g.Nodes) - len(need) // number of nodes with existing increment values
for i, n := range need {
r[n.Table()] = last + int64(s+i)<<32
a := n.EntSQL()
a.IncrementStart = func(i int64) *int64 { return &i }(r[n.Table()]) // copy to not override previous values
n.Annotations[a.Name()] = a
}
// Ensure increment ranges are exactly of size 1<<32 with no overlaps.
d := make(map[int64]string)
for t, s := range r {
switch t1, ok := d[s]; {
case ok:
return fmt.Errorf("duplicated increment start value %d for types %s and %s", s, t1, t)
case s%(1<<32) != 0:
return fmt.Errorf(
"unexpected increment start value %d for type %s, expected multiple of %d (1<<32)", s, t, 1<<32,
)
}
d[s] = t
}
if g.Annotations == nil {
g.Annotations = make(Annotations)
}
g.Annotations[r.Name()] = r
return nil
}
// Name implements Annotation interface.
func (IncrementStarts) Name() string {
return "IncrementStarts"
}
func rangesFilePath(dir string) string {
return filepath.Join(dir, "internal", "globalid.go")
}

55
entc/gen/globalid_test.go Normal file
View File

@@ -0,0 +1,55 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package gen_test
import (
"testing"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/entc/gen"
"entgo.io/ent/entc/load"
"github.com/stretchr/testify/require"
)
func TestIncrementStartAnnotation(t *testing.T) {
var (
p = func(i int64) *int64 { return &i }
a = &entsql.Annotation{IncrementStart: p(100)}
s = []*load.Schema{
{
Name: "T1",
Annotations: gen.Annotations{a.Name(): a},
},
{Name: "T2"},
}
c = &gen.Config{
Package: "entc/gen",
Target: t.TempDir(),
}
)
// Arbitrary increment starting values allowed if feature is not enabled.
g, err := gen.NewGraph(c, s...)
require.NoError(t, err)
require.NotNil(t, g)
// Increments must be a multiple of 1<<32.
c.Features = []gen.Feature{gen.FeatureGlobalID}
g, err = gen.NewGraph(c, s...)
require.EqualError(t, err, "unexpected increment start value 100 for type t1s, expected multiple of 4294967296 (1<<32)")
require.Nil(t, g)
a.IncrementStart = p(1 << 32)
g, err = gen.NewGraph(c, s...)
require.NoError(t, err)
require.NotNil(t, g)
// Duplicated increment starting values are not allowed.
s = append(s, &load.Schema{
Name: "T3",
Annotations: gen.Annotations{a.Name(): &entsql.Annotation{IncrementStart: p(1 << 32)}},
})
g, err = gen.NewGraph(c, s...)
require.ErrorContains(t, err, "duplicated increment start value 4294967296 for types")
require.Nil(t, g)
}

View File

@@ -181,6 +181,11 @@ func NewGraph(c *Config, schemas ...*load.Schema) (g *Graph, err error) {
if c.Storage != nil && c.Storage.Init != nil {
check(c.Storage.Init(g), "storage driver init")
}
if enabled, _ := g.Config.FeatureEnabled(FeatureGlobalID.Name); enabled {
if err := IncrementStartAnnotation(g); err != nil {
return nil, err
}
}
return
}

View File

@@ -11,6 +11,7 @@ import (
"reflect"
"testing"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/entc/load"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
@@ -424,9 +425,8 @@ func TestGraph_Gen(t *testing.T) {
{Name: "t1", Type: "T1", Unique: true},
},
},
{
Name: "T2",
},
{Name: "T2"},
{Name: "T3"},
}
graph, err := NewGraph(&Config{
Package: "entc/gen",
@@ -439,6 +439,13 @@ func TestGraph_Gen(t *testing.T) {
require.NoError(err)
require.NotNil(graph)
require.NoError(graph.Gen())
// Ensure globalid feature added annotations.
a := IncrementStarts{"t1s": 0, "t2s": 1 << 32, "t3s": 2 << 32}
require.Equal(a, graph.Annotations[a.Name()])
ant := &entsql.Annotation{}
for i, n := range graph.Nodes {
require.Equal(int64(i)<<32, *n.Annotations[ant.Name()].(*entsql.Annotation).IncrementStart)
}
// Ensure graph files were generated.
for _, name := range []string{"ent", "client"} {
_, err := os.Stat(fmt.Sprintf("%s/%s.go", target, name))
@@ -461,6 +468,9 @@ func TestGraph_Gen(t *testing.T) {
require.NoError(err)
_, err = os.Stat(filepath.Join(target, "internal", "schemaconfig.go"))
require.NoError(err)
c, err := os.ReadFile(filepath.Join(target, "internal", "globalid.go"))
require.NoError(err)
require.Contains(string(c), fmt.Sprintf(`"{\"t1s\":0,\"t2s\":%d,\"t3s\":%d}"`, 1<<32, 2<<32))
// Rerun codegen with only one feature-flag.
graph.Features = []Feature{FeatureSnapshot}
require.NoError(graph.Gen())
@@ -469,6 +479,8 @@ func TestGraph_Gen(t *testing.T) {
require.NoError(err)
_, err = os.Stat(filepath.Join(target, "internal", "schemaconfig.go"))
require.True(os.IsNotExist(err))
_, err = os.Stat(filepath.Join(target, "internal", "globalid.go"))
require.True(os.IsNotExist(err))
// Rerun codegen without any feature-flags.
graph.Features = nil
require.NoError(graph.Gen())

View File

@@ -17,3 +17,14 @@ package internal
const Schema = {{ .SchemaSnapshot }}
{{ end }}
{{ define "internal/globalid" }}
{{ with $.Header }}{{ . }}{{ else }}// Code generated by ent, DO NOT EDIT.{{ end }}
// +build tools
package internal
const IncrementStarts = {{ .Annotations.IncrementStarts | json | quote }}
{{ end }}

View File

@@ -221,6 +221,9 @@ func init() {
{{- with $ant.Check }}
Check: {{ quote . }},
{{- end }}
{{- with $ant.IncrementStart }}
IncrementStart: func(i int64) *int64 { return &i }({{ . }}),
{{- end }}
}
{{- with $ant.Incremental }}
{{ $table }}.Annotation.Incremental = new(bool)