mirror of
https://github.com/ent/ent.git
synced 2026-04-28 05:30:56 +03:00
entc: global id feature (#4293)
Feature to add annotations to ent schema to ensure sql tables have sequential auto-increment id columns. Meant to be a better alternative for entgql / gqlgen globally unique id feature.
This commit is contained in:
@@ -144,6 +144,21 @@ var (
|
||||
Description: "Allows users to work with versioned migrations / migration files",
|
||||
}
|
||||
|
||||
FeatureGlobalID = Feature{
|
||||
Name: "sql/globalid",
|
||||
Stage: Experimental,
|
||||
Default: false,
|
||||
Description: "Ensures all nodes have a unique global identifier", GraphTemplates: []GraphTemplate{
|
||||
{
|
||||
Name: "internal/globalid",
|
||||
Format: "internal/globalid.go",
|
||||
},
|
||||
},
|
||||
cleanup: func(c *Config) error {
|
||||
return remove(filepath.Join(c.Target, "internal"), "globalid.go")
|
||||
},
|
||||
}
|
||||
|
||||
// AllFeatures holds a list of all feature-flags.
|
||||
AllFeatures = []Feature{
|
||||
FeaturePrivacy,
|
||||
@@ -158,6 +173,7 @@ var (
|
||||
FeatureExecQuery,
|
||||
FeatureUpsert,
|
||||
FeatureVersionedMigration,
|
||||
FeatureGlobalID,
|
||||
}
|
||||
// allFeatures includes all public and private features.
|
||||
allFeatures = append(AllFeatures, featureMultiSchema)
|
||||
|
||||
@@ -6,6 +6,7 @@ package gen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/token"
|
||||
@@ -47,6 +48,7 @@ var (
|
||||
"indexOf": indexOf,
|
||||
"join": join,
|
||||
"joinWords": joinWords,
|
||||
"json": jsonString,
|
||||
"isNil": isNil,
|
||||
"lower": strings.ToLower,
|
||||
"upper": strings.ToUpper,
|
||||
@@ -529,3 +531,12 @@ func list[T any](v ...T) []T {
|
||||
func fail(msg string) (string, error) {
|
||||
return "", errors.New(msg)
|
||||
}
|
||||
|
||||
// jsonString returns a json encoded value as string.
|
||||
func jsonString(v any) (string, error) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
126
entc/gen/globalid.go
Normal file
126
entc/gen/globalid.go
Normal file
@@ -0,0 +1,126 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package gen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"entgo.io/ent/dialect/entsql"
|
||||
)
|
||||
|
||||
const incrementIdent = "const IncrementStarts"
|
||||
|
||||
// IncrementStarts holds the autoincrement start value for each type.
|
||||
type IncrementStarts map[string]int64
|
||||
|
||||
// IncrementStartAnnotation assigns a unique range to each node in the graph.
|
||||
func IncrementStartAnnotation(g *Graph) error {
|
||||
// To ensure we keep the existing type ranges, load the current global id sequence, if there is any.
|
||||
var (
|
||||
r = make(IncrementStarts)
|
||||
path = rangesFilePath(g.Target)
|
||||
buf, err = os.ReadFile(path)
|
||||
)
|
||||
switch {
|
||||
case os.IsNotExist(err): // first time generation
|
||||
case err != nil:
|
||||
return err
|
||||
default:
|
||||
var (
|
||||
matches = make([][]byte, 0, 2)
|
||||
lines = bytes.Split(buf, []byte("\n"))
|
||||
)
|
||||
for i := 0; i < len(lines); i++ {
|
||||
if l := lines[i]; bytes.HasPrefix(l, []byte(incrementIdent)) {
|
||||
matches = append(matches, l)
|
||||
}
|
||||
}
|
||||
if len(matches) != 1 {
|
||||
return fmt.Errorf("expect to have exactly 1 ranges in %s", path)
|
||||
}
|
||||
var (
|
||||
line = matches[0]
|
||||
start = bytes.IndexByte(line, '"')
|
||||
end = bytes.LastIndexByte(line, '"')
|
||||
)
|
||||
if start == -1 || start >= end {
|
||||
return fmt.Errorf("unexpected line %s", line)
|
||||
}
|
||||
l, err := strconv.Unquote(string(line[start : end+1]))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := json.Unmarshal([]byte(l), &r); err != nil {
|
||||
return fmt.Errorf("unmarshal ranges: %w", err)
|
||||
}
|
||||
}
|
||||
// Range over all nodes and assign the increment starting value.
|
||||
var (
|
||||
need []*Type
|
||||
last int64
|
||||
)
|
||||
for _, n := range g.Nodes {
|
||||
if n.Annotations == nil {
|
||||
n.Annotations = make(Annotations)
|
||||
}
|
||||
a := n.EntSQL()
|
||||
if a == nil {
|
||||
a = &entsql.Annotation{}
|
||||
n.Annotations[a.Name()] = a
|
||||
}
|
||||
switch v, ok := r[n.Table()]; {
|
||||
case a.IncrementStart != nil:
|
||||
// In case the start value is defined by an annotation already, it has precedence.
|
||||
r[n.Table()] = *a.IncrementStart
|
||||
case ok:
|
||||
// In case this node has no annotation but an existing entry in the increments file.
|
||||
a.IncrementStart = &v
|
||||
default:
|
||||
// In case this is a new node, it gets the next free increment range (highest value << 32).
|
||||
need = append(need, n)
|
||||
}
|
||||
last = max(last, r[n.Table()])
|
||||
}
|
||||
// Compute new ranges and write them back to the file.
|
||||
s := len(g.Nodes) - len(need) // number of nodes with existing increment values
|
||||
for i, n := range need {
|
||||
r[n.Table()] = last + int64(s+i)<<32
|
||||
a := n.EntSQL()
|
||||
a.IncrementStart = func(i int64) *int64 { return &i }(r[n.Table()]) // copy to not override previous values
|
||||
n.Annotations[a.Name()] = a
|
||||
}
|
||||
// Ensure increment ranges are exactly of size 1<<32 with no overlaps.
|
||||
d := make(map[int64]string)
|
||||
for t, s := range r {
|
||||
switch t1, ok := d[s]; {
|
||||
case ok:
|
||||
return fmt.Errorf("duplicated increment start value %d for types %s and %s", s, t1, t)
|
||||
case s%(1<<32) != 0:
|
||||
return fmt.Errorf(
|
||||
"unexpected increment start value %d for type %s, expected multiple of %d (1<<32)", s, t, 1<<32,
|
||||
)
|
||||
}
|
||||
d[s] = t
|
||||
}
|
||||
if g.Annotations == nil {
|
||||
g.Annotations = make(Annotations)
|
||||
}
|
||||
g.Annotations[r.Name()] = r
|
||||
return nil
|
||||
}
|
||||
|
||||
// Name implements Annotation interface.
|
||||
func (IncrementStarts) Name() string {
|
||||
return "IncrementStarts"
|
||||
}
|
||||
|
||||
func rangesFilePath(dir string) string {
|
||||
return filepath.Join(dir, "internal", "globalid.go")
|
||||
}
|
||||
55
entc/gen/globalid_test.go
Normal file
55
entc/gen/globalid_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright 2019-present Facebook Inc. All rights reserved.
|
||||
// This source code is licensed under the Apache 2.0 license found
|
||||
// in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
package gen_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"entgo.io/ent/dialect/entsql"
|
||||
"entgo.io/ent/entc/gen"
|
||||
"entgo.io/ent/entc/load"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIncrementStartAnnotation(t *testing.T) {
|
||||
var (
|
||||
p = func(i int64) *int64 { return &i }
|
||||
a = &entsql.Annotation{IncrementStart: p(100)}
|
||||
s = []*load.Schema{
|
||||
{
|
||||
Name: "T1",
|
||||
Annotations: gen.Annotations{a.Name(): a},
|
||||
},
|
||||
{Name: "T2"},
|
||||
}
|
||||
c = &gen.Config{
|
||||
Package: "entc/gen",
|
||||
Target: t.TempDir(),
|
||||
}
|
||||
)
|
||||
// Arbitrary increment starting values allowed if feature is not enabled.
|
||||
g, err := gen.NewGraph(c, s...)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, g)
|
||||
|
||||
// Increments must be a multiple of 1<<32.
|
||||
c.Features = []gen.Feature{gen.FeatureGlobalID}
|
||||
g, err = gen.NewGraph(c, s...)
|
||||
require.EqualError(t, err, "unexpected increment start value 100 for type t1s, expected multiple of 4294967296 (1<<32)")
|
||||
require.Nil(t, g)
|
||||
a.IncrementStart = p(1 << 32)
|
||||
g, err = gen.NewGraph(c, s...)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, g)
|
||||
|
||||
// Duplicated increment starting values are not allowed.
|
||||
s = append(s, &load.Schema{
|
||||
Name: "T3",
|
||||
Annotations: gen.Annotations{a.Name(): &entsql.Annotation{IncrementStart: p(1 << 32)}},
|
||||
})
|
||||
g, err = gen.NewGraph(c, s...)
|
||||
require.ErrorContains(t, err, "duplicated increment start value 4294967296 for types")
|
||||
require.Nil(t, g)
|
||||
}
|
||||
@@ -181,6 +181,11 @@ func NewGraph(c *Config, schemas ...*load.Schema) (g *Graph, err error) {
|
||||
if c.Storage != nil && c.Storage.Init != nil {
|
||||
check(c.Storage.Init(g), "storage driver init")
|
||||
}
|
||||
if enabled, _ := g.Config.FeatureEnabled(FeatureGlobalID.Name); enabled {
|
||||
if err := IncrementStartAnnotation(g); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"entgo.io/ent/dialect/entsql"
|
||||
"entgo.io/ent/entc/load"
|
||||
"entgo.io/ent/schema/edge"
|
||||
"entgo.io/ent/schema/field"
|
||||
@@ -424,9 +425,8 @@ func TestGraph_Gen(t *testing.T) {
|
||||
{Name: "t1", Type: "T1", Unique: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "T2",
|
||||
},
|
||||
{Name: "T2"},
|
||||
{Name: "T3"},
|
||||
}
|
||||
graph, err := NewGraph(&Config{
|
||||
Package: "entc/gen",
|
||||
@@ -439,6 +439,13 @@ func TestGraph_Gen(t *testing.T) {
|
||||
require.NoError(err)
|
||||
require.NotNil(graph)
|
||||
require.NoError(graph.Gen())
|
||||
// Ensure globalid feature added annotations.
|
||||
a := IncrementStarts{"t1s": 0, "t2s": 1 << 32, "t3s": 2 << 32}
|
||||
require.Equal(a, graph.Annotations[a.Name()])
|
||||
ant := &entsql.Annotation{}
|
||||
for i, n := range graph.Nodes {
|
||||
require.Equal(int64(i)<<32, *n.Annotations[ant.Name()].(*entsql.Annotation).IncrementStart)
|
||||
}
|
||||
// Ensure graph files were generated.
|
||||
for _, name := range []string{"ent", "client"} {
|
||||
_, err := os.Stat(fmt.Sprintf("%s/%s.go", target, name))
|
||||
@@ -461,6 +468,9 @@ func TestGraph_Gen(t *testing.T) {
|
||||
require.NoError(err)
|
||||
_, err = os.Stat(filepath.Join(target, "internal", "schemaconfig.go"))
|
||||
require.NoError(err)
|
||||
c, err := os.ReadFile(filepath.Join(target, "internal", "globalid.go"))
|
||||
require.NoError(err)
|
||||
require.Contains(string(c), fmt.Sprintf(`"{\"t1s\":0,\"t2s\":%d,\"t3s\":%d}"`, 1<<32, 2<<32))
|
||||
// Rerun codegen with only one feature-flag.
|
||||
graph.Features = []Feature{FeatureSnapshot}
|
||||
require.NoError(graph.Gen())
|
||||
@@ -469,6 +479,8 @@ func TestGraph_Gen(t *testing.T) {
|
||||
require.NoError(err)
|
||||
_, err = os.Stat(filepath.Join(target, "internal", "schemaconfig.go"))
|
||||
require.True(os.IsNotExist(err))
|
||||
_, err = os.Stat(filepath.Join(target, "internal", "globalid.go"))
|
||||
require.True(os.IsNotExist(err))
|
||||
// Rerun codegen without any feature-flags.
|
||||
graph.Features = nil
|
||||
require.NoError(graph.Gen())
|
||||
|
||||
@@ -17,3 +17,14 @@ package internal
|
||||
|
||||
const Schema = {{ .SchemaSnapshot }}
|
||||
{{ end }}
|
||||
|
||||
{{ define "internal/globalid" }}
|
||||
|
||||
{{ with $.Header }}{{ . }}{{ else }}// Code generated by ent, DO NOT EDIT.{{ end }}
|
||||
|
||||
// +build tools
|
||||
|
||||
package internal
|
||||
|
||||
const IncrementStarts = {{ .Annotations.IncrementStarts | json | quote }}
|
||||
{{ end }}
|
||||
|
||||
@@ -221,6 +221,9 @@ func init() {
|
||||
{{- with $ant.Check }}
|
||||
Check: {{ quote . }},
|
||||
{{- end }}
|
||||
{{- with $ant.IncrementStart }}
|
||||
IncrementStart: func(i int64) *int64 { return &i }({{ . }}),
|
||||
{{- end }}
|
||||
}
|
||||
{{- with $ant.Incremental }}
|
||||
{{ $table }}.Annotation.Incremental = new(bool)
|
||||
|
||||
Reference in New Issue
Block a user