Files
ent/entc/gen/globalid.go
2025-01-27 19:40:05 +01:00

209 lines
5.4 KiB
Go

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package gen
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
"entgo.io/ent/dialect/entsql"
)
// incrementIdent is the prefix of the line in the generated global id file that
// declares the IncrementStarts constant. It is used to locate that line when the
// file is parsed and when merge conflicts in it are resolved.
const incrementIdent = "const IncrementStarts"
// IncrementStarts holds the autoincrement start value for each type.
// It is keyed by the table name of the type; the value is the first id of the
// range assigned to that table.
type IncrementStarts map[string]int
// IncrementStartAnnotation assigns a unique range to each node in the graph.
//
// Ranges recorded in the file at IncrementStartsFilePath(g.Target) are loaded
// first, so that types which already have a range keep it. For every node, an
// IncrementStart set on its entsql annotation takes precedence over the recorded
// value; a node with neither gets the next free range, which is one past the
// highest range index found among all recorded and annotated values.
//
// The resulting table-to-start mapping must contain no duplicates and only
// multiples of 1<<32, otherwise an error is returned. On success the mapping is
// stored in g.Annotations under IncrementStarts.Name().
func IncrementStartAnnotation(g *Graph) error {
	// To ensure we keep the existing type ranges, load the current global id sequence, if there is any.
	var (
		r        = make(IncrementStarts)
		path     = IncrementStartsFilePath(g.Target)
		buf, err = os.ReadFile(path)
	)
	switch {
	case os.IsNotExist(err): // first time generation
	case err != nil:
		return err
	default:
		// NOTE(review): the error returned by FeatureEnabled is discarded here — confirm that is intended.
		if ok, _ := g.FeatureEnabled(FeatureSnapshot.Name); ok {
			// Strip version control conflict markers from the file before parsing it, then re-read it.
			if err = ResolveIncrementStartsConflict(g.Target); err != nil {
				return err
			}
			buf, err = os.ReadFile(path)
			if err != nil {
				return err
			}
		}
		// Locate the single line that declares the IncrementStarts constant.
		var matches [][]byte
		for _, l := range bytes.Split(buf, []byte("\n")) {
			if bytes.HasPrefix(l, []byte(incrementIdent)) {
				matches = append(matches, l)
			}
		}
		if len(matches) != 1 {
			return fmt.Errorf("expect to have exactly 1 ranges in %s", path)
		}
		// The value of the constant is a quoted Go string that holds the JSON encoded ranges.
		var (
			line  = matches[0]
			start = bytes.IndexByte(line, '"')
			end   = bytes.LastIndexByte(line, '"')
		)
		if start == -1 || start >= end {
			return fmt.Errorf("unexpected line %s", line)
		}
		l, err := strconv.Unquote(string(line[start : end+1]))
		if err != nil {
			return err
		}
		if err := json.Unmarshal([]byte(l), &r); err != nil {
			return fmt.Errorf("unmarshal ranges: %w", err)
		}
		if r == nil {
			// A JSON "null" payload resets the map to nil; writing to it below would panic.
			r = make(IncrementStarts)
		}
	}
	// Range over all nodes and reconcile them with the recorded increment starting values.
	var need []*Type
	for _, n := range g.Nodes {
		a := n.EntSQL()
		if a == nil {
			a = &entsql.Annotation{}
		}
		switch v, ok := r[n.Table()]; {
		case a.IncrementStart != nil:
			// In case the start value is defined by an annotation already, it has precedence.
			r[n.Table()] = *a.IncrementStart
		case ok:
			// In case this node has no annotation but an existing entry in the increments file.
			a.IncrementStart = &v
		default:
			// In case this is a new node, it gets the next free increment range (see below).
			need = append(need, n)
		}
		if err := setAnnotation(n, a); err != nil {
			return err
		}
	}
	// Find the highest range index in use. Every entry of r is considered, including
	// entries for tables that are no longer part of the graph: they remain in r, so a
	// new node must not be handed one of their ranges.
	lastIdx := -1
	for _, v := range r {
		lastIdx = max(lastIdx, v/(1<<32))
	}
	// Assign the next free ranges (index << 32) to the new nodes.
	for i, n := range need {
		start := (lastIdx + i + 1) << 32 // declared per iteration, so each node gets its own pointer target
		r[n.Table()] = start
		a := n.EntSQL()
		if a == nil {
			a = &entsql.Annotation{}
		}
		a.IncrementStart = &start
		if err := setAnnotation(n, a); err != nil {
			return err
		}
	}
	// Ensure increment ranges are exactly of size 1<<32 with no overlaps.
	d := make(map[int]string, len(r))
	for t, s := range r {
		switch t1, ok := d[s]; {
		case ok:
			return fmt.Errorf("duplicated increment start value %d for types %s and %s", s, t1, t)
		case s%(1<<32) != 0:
			return fmt.Errorf(
				"unexpected increment start value %d for type %s, expected multiple of %d (1<<32)", s, t, 1<<32,
			)
		}
		d[s] = t
	}
	if g.Annotations == nil {
		g.Annotations = make(Annotations)
	}
	g.Annotations[r.Name()] = r
	return nil
}
// Name implements the Annotation interface. The returned value is the key under
// which the increment starts are stored in the annotations map.
func (IncrementStarts) Name() string {
	const name = "IncrementStarts"
	return name
}
// WriteToDisk writes the increment starts to the disk.
//
// The "internal/globalid" template is rendered into the file returned by
// IncrementStartsFilePath(target); the parent directory is created if it does
// not exist. The returned error is the first failure among creating the
// directory, creating the file, rendering the template and closing the file.
func (i IncrementStarts) WriteToDisk(target string) (err error) {
	initTemplates()
	p := IncrementStartsFilePath(target)
	if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
		return err
	}
	f, err := os.Create(p)
	if err != nil {
		return err
	}
	// The file is written to, so a failing Close can mean the content did not reach
	// the disk. Report it unless an earlier error is already being returned.
	defer func() {
		if cerr := f.Close(); err == nil {
			err = cerr
		}
	}()
	return templates.Lookup("internal/globalid").
		Execute(f, &Config{
			Target:      target,
			Annotations: Annotations{"IncrementStarts": i},
		})
}
// IncrementStartsFilePath returns the path of the file holding the increment
// starts, which is "internal/globalid.go" below the given directory.
func IncrementStartsFilePath(dir string) string {
	const (
		pkgDir   = "internal"
		fileName = "globalid.go"
	)
	return filepath.Join(dir, pkgDir, fileName)
}
// ResolveIncrementStartsConflict resolves git/mercurial conflicts by "accepting theirs".
//
// The increment starts file below dir is rewritten in place with its original
// file mode. All lines starting with a conflict marker ("<<<<<<<", "=======",
// ">>>>>>>") are removed, and so is the first IncrementStarts declaration that
// follows an opening marker. The second declaration, which is the remote
// content, is kept along with every other line.
func ResolveIncrementStartsConflict(dir string) error {
	target := IncrementStartsFilePath(dir)
	info, err := os.Stat(target)
	if err != nil {
		return err
	}
	content, err := os.ReadFile(target)
	if err != nil {
		return err
	}
	var (
		kept                [][]byte
		inConflict, dropped bool
		newline             = []byte("\n")
	)
	startsWith := func(line []byte, prefix string) bool {
		return bytes.HasPrefix(line, []byte(prefix))
	}
	for _, line := range bytes.Split(content, newline) {
		if startsWith(line, "<<<<<<<") {
			// Opening marker: from here on the local declaration is expected.
			inConflict = true
			continue
		}
		if startsWith(line, "=======") || startsWith(line, ">>>>>>>") {
			continue
		}
		if inConflict && !dropped && startsWith(line, incrementIdent) {
			// Local side of the conflict; drop it once and keep what follows.
			dropped = true
			continue
		}
		kept = append(kept, line)
	}
	return os.WriteFile(target, bytes.Join(kept, newline), info.Mode())
}
// ToMap converts the annotation into a generic map by encoding it as JSON and
// decoding the result into a map[string]any.
func ToMap(a *entsql.Annotation) (map[string]any, error) {
	raw, err := json.Marshal(a)
	if err != nil {
		return nil, err
	}
	out := map[string]any{}
	if err := json.Unmarshal(raw, &out); err != nil {
		return nil, err
	}
	return out, nil
}
// setAnnotation converts a with ToMap and stores the resulting map in the
// annotations of n under a.Name(). The annotations are left untouched if the
// conversion fails.
func setAnnotation(n *Type, a *entsql.Annotation) error {
	fields, err := ToMap(a)
	if err == nil {
		n.Annotations.Set(a.Name(), fields)
	}
	return err
}