// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package internal import ( "bytes" "encoding/json" "fmt" "io/ioutil" "strings" "github.com/facebook/ent/entc/gen" "github.com/facebook/ent/entc/load" ) // Snapshot describes the schema snapshot restore. type Snapshot struct { Path string // Path to snapshot. Config *gen.Config // Config of codegen. } // Restore restores the generated package from the latest schema snapshot. // If there is a conflict between upstream and local snapshots, it is merged // before running the code generation. func (s *Snapshot) Restore() error { buf, err := ioutil.ReadFile(s.Path) if err != nil { return fmt.Errorf("unable to read snapshot schema %w", err) } snap, err := s.parseSnapshot(buf) if err != nil { return err } s.Config.Schema = snap.Schema s.Config.Package = snap.Package s.addFeatures(snap) graph, err := gen.NewGraph(s.Config, snap.Schemas...) if err != nil { return err } return graph.Gen() } // schemaIdent holds the schema identifier in snapshot file. const schemaIdent = "const Schema" // parseSnapshot parses the given buffer and extract the generated snapshot. // If it encounters a merge-conflict, it will resolve it by merging the relevant // parts for the codegen. func (s *Snapshot) parseSnapshot(buf []byte) (*gen.Snapshot, error) { var ( conflict bool matches = make([][]byte, 0, 2) lines = bytes.Split(buf, []byte("\n")) ) for i := 0; i < len(lines); i++ { switch line := lines[i]; { case bytes.HasPrefix(line, []byte(schemaIdent)): matches = append(matches, line) case bytes.HasPrefix(line, []byte(conflictMarker)): conflict = true } } switch n := len(matches); { case n == 0: return nil, fmt.Errorf("schema snapshot was not found in %s", s.Path) case n > 1 && !conflict: return nil, fmt.Errorf("expect to have exactly 1 snapshot in %s", s.Path) } line, err := trim(matches[0]) if err != nil { return nil, err } local := &gen.Snapshot{} if err := json.Unmarshal(line, &local); err != nil { return nil, fmt.Errorf("unmarshal snapshot %v: %w", local, err) } if !conflict || len(matches) == 1 { return local, nil } // In case of merge-conflict, we merge the 2 schemas. line, err = trim(matches[1]) if err != nil { return nil, err } other := &gen.Snapshot{} if err := json.Unmarshal(line, &other); err != nil { return nil, fmt.Errorf("unmarshal snapshot %v: %w", local, err) } merge(local, other) return local, nil } // addFeatures adds the features in the snapshot to the codegen config. func (s *Snapshot) addFeatures(snap *gen.Snapshot) { add := make(map[string]gen.Feature) for _, name := range snap.Features { for _, feat := range gen.AllFeatures { if name == feat.Name { add[name] = feat } } } for _, feat := range s.Config.Features { delete(add, feat.Name) } for _, feat := range add { s.Config.Features = append(s.Config.Features, feat) } } // merge the "other"/"upstream" snapshot to the "local" version. func merge(local, other *gen.Snapshot) { if local.Schema == "" { local.Schema = other.Schema } if local.Package == "" { local.Package = other.Package } locals := make(map[string]*load.Schema, len(local.Schemas)) for _, schema := range local.Schemas { locals[schema.Name] = schema } // Merge "other" schemas. for _, schema := range other.Schemas { switch match, ok := locals[schema.Name]; { case !ok: local.Schemas = append(local.Schemas, schema) case ok: mergeSchema(match, schema) } } // Merge codegen features. features := make(map[string]struct{}, len(local.Features)) for _, feat := range local.Features { features[feat] = struct{}{} } for _, feat := range other.Features { if _, ok := features[feat]; !ok { local.Features = append(local.Features, feat) } } } // mergeSchema merges to "local" the additional information in // the "other" schema, that may be necessary for code-generation. func mergeSchema(local, other *load.Schema) { if local.Config.Table == "" { local.Config.Table = other.Config.Table } if local.Annotations == nil && other.Annotations != nil { local.Annotations = make(map[string]interface{}) } for ant := range other.Annotations { if _, ok := local.Annotations[ant]; !ok { local.Annotations[ant] = other.Annotations[ant] } } fields := make(map[string]*load.Field, len(local.Fields)) for _, f := range local.Fields { fields[f.Name] = f } for _, f := range other.Fields { switch match, ok := fields[f.Name]; { case !ok: local.Fields = append(local.Fields, f) case ok: mergeField(match, f) } } edges := make(map[string]*load.Edge, len(local.Edges)) for _, e := range local.Edges { edges[e.Name] = e } for _, e := range other.Edges { switch match, ok := edges[e.Name]; { case !ok: local.Edges = append(local.Edges, e) case ok: mergeEdge(match, e) } } } // mergeField merges to "local" the additional information in // the "other" field, that may be necessary for code-generation. func mergeField(local, other *load.Field) { if local.Annotations == nil && other.Annotations != nil { local.Annotations = make(map[string]interface{}) } for ant := range other.Annotations { if _, ok := local.Annotations[ant]; !ok { local.Annotations[ant] = other.Annotations[ant] } } if !local.Immutable && other.Immutable { local.Immutable = other.Immutable } } // mergeEdge merges to "local" the additional information in // the "other" edge, that may be necessary for code-generation. func mergeEdge(local, other *load.Edge) { if local.Annotations == nil && other.Annotations != nil { local.Annotations = make(map[string]interface{}) } for ant := range other.Annotations { if _, ok := local.Annotations[ant]; !ok { local.Annotations[ant] = other.Annotations[ant] } } } // IsBuildError reports if the given error is an error from the Go command (e.g. syntax error). func IsBuildError(err error) bool { if strings.HasPrefix(err.Error(), "entc/load: #") { return true } for _, s := range []string{"syntax error", "previous declaration", "invalid character"} { if strings.Contains(err.Error(), s) { return true } } return false } func trim(line []byte) ([]byte, error) { start := bytes.IndexByte(line, '`') end := bytes.LastIndexByte(line, '`') if start == -1 || start >= end { return nil, fmt.Errorf("unexpected snapshot line %s", line) } return line[start+1 : end], nil }