Files
ent/entc/gen/template/dialect/gremlin/update.tmpl
2021-09-07 18:33:32 +03:00

191 lines
6.6 KiB
Cheetah

{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}
{{/* gotype: entgo.io/ent/entc/gen.typeScope */}}
{{ define "dialect/gremlin/update" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
{{ $mutation := print $receiver ".mutation" }}
{{ $one := hasSuffix $builder "One" }}
{{ $zero := 0 }}{{ if $one }}{{ $zero = "nil" }}{{ end }}
{{- /*
gremlinSave is rendered once per update builder. The generated method turns the
traversal produced by the builder's gremlin method into a query string plus
bindings, executes it through the builder's driver, and decodes the response.

Two shapes are emitted, selected by $one (true when the builder name ends in
"One"):
  - "One" builders return a pointer to the updated entity, and $zero is nil.
  - bulk builders return an int read from the response, and $zero is 0.
$zero is the value returned next to every non-nil error.
*/}}
func ({{ $receiver }} *{{ $builder }}) gremlinSave(ctx context.Context) ({{- if $one }}*{{ $.Name }}{{ else }}int{{ end }}, error) {
res := &gremlin.Response{}
{{- if $one }}
{{- /* updating a single vertex requires its id to be present on the mutation; fail with a ValidationError otherwise. */}}
id, ok := {{ $mutation }}.{{ $.ID.MutationGet }}()
if !ok {
return {{ $zero }}, &ValidationError{Name: "{{ $.ID.Name }}", err: errors.New(`{{ $pkg }}: missing "{{ $.Name }}.{{ $.ID.Name }}" for update`)}
}
query, bindings := {{ $receiver }}.gremlin(id).Query()
{{- else }}
query, bindings := {{ $receiver }}.gremlin().Query()
{{- end }}
if err := {{ $receiver }}.driver.Exec(ctx, query, bindings, res); err != nil {
return {{ $zero }}, err
}
{{- /*
A successful Exec may still carry an error in the response: the constraint
tests assembled in the gremlin method emit error values through Constant(...).
NOTE(review): isConstantError is defined outside this template; presumably it
decodes such a constant back into a Go error - confirm against its definition.
*/}}
if err, ok := isConstantError(res); ok {
return {{ $zero }}, err
}
{{- if $one }}
{{- /* decode the response into a fresh entity that shares the builder's config. */}}
{{- $r := $.Receiver }}
{{ $r }} := &{{ $.Name }}{config: {{ $receiver }}.config}
if err := {{ $r }}.FromResponse(res); err != nil {
return nil, err
}
return {{ $r }}, nil
{{- else }}
{{- /* the bulk traversal ends with a Count step (see the gremlin method), so the response is read as an int. */}}
return res.ReadInt()
{{- end }}
}
{{- /*
gremlin is rendered once per update builder and builds the update traversal.
"One" builders ($one) receive the vertex id as an argument; bulk builders take
no argument and select vertices by label and by the mutation predicates.

The generated body works in this order:
  1. choose the base traversal v (and rv, a clone taken before any mutation step),
  2. append property updates, numeric additions and property drops to v,
  3. collect edge removals as separate traversals in trs and append edge
     additions to v,
  4. end v with ValueMap ("One") or Count (bulk),
  5. wrap v in the collected uniqueness constraints,
  6. return dsl.Join of the removals followed by v.
*/}}
func ({{ $receiver }} *{{ $builder }}) gremlin({{ if $one }}id {{ $.ID.Type }}{{ end }}) *dsl.Traversal {
{{- /* the constraint type and slice are emitted only when .NumConstraint is non-zero; inside the "with", dot is that number and is used as the slice capacity. */}}
{{- with .NumConstraint }}
type constraint struct {
pred *dsl.Traversal // constraint predicate.
test *dsl.Traversal // test matches and its constant.
}
constraints := make([]*constraint, 0, {{ . }})
{{- end }}
{{- /* case of update specific vertex */}}
{{- if $one }}
v := g.V(id)
{{- /* general update for N vertices */}}
{{- else }}
v := g.V().HasLabel({{ $.Package }}.Label)
for _, p := range {{ $mutation }}.predicates {
p(v)
}
{{- end }}
{{- /*
rv is a clone of the base traversal taken before any mutation step is appended
to v. It is used below for constraint predicates and edge removals. It is only
declared when the type has constraints or edges, and "_ = rv" keeps the
generated code compiling when no rendered branch ends up using it.
trs collects the traversals that are joined into the final result.
*/}}
var (
{{ if or .NumConstraint (len $.Edges) }}
rv = v.Clone()
_ = rv
{{ end }}
trs []*dsl.Traversal
)
{{- /* set fields: immutable fields are skipped unless they define an update default. */}}
{{- range $f := $.MutationFields }}
{{- if or (not $f.Immutable) $f.UpdateDefault }}
if value, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok {
{{- /* unique field: count the vertices that already hold this value; the test yields a unique-field error when the count is not 0. */}}
{{- if $f.Unique }}
constraints = append(constraints, &constraint{
pred: g.V().Has({{ $.Package }}.Label, {{ $.Package }}.{{ $f.Constant }}, value).Count(),
test: __.Is(p.NEQ(0)).Constant(NewErrUniqueField({{ $.Package }}.Label, {{ $.Package }}.{{ $f.Constant }}, value)),
})
{{- end }}
v.Property(dsl.Single, {{ $.Package }}.{{ $f.Constant }}, value)
}
{{- /* fields that support "add": the new property value is the sum of the stored value and the added value, computed by the traversal itself. */}}
{{- if $f.SupportsMutationAdd }}
if value, ok := {{ $mutation }}.Added{{ $f.StructField }}(); ok {
{{- /*
unique field: the constraint looks up the value the field will hold after the addition.
NOTE(review): addValue is a traversal expression built from rv (stored value + delta),
not a Go number - confirm how the dsl serializes it inside Has(...).
*/}}
{{- if $f.Unique }}
addValue := rv.Clone().Union(__.Values({{ $.Package }}.{{ $f.Constant }}), __.Constant(value)).Sum().Next()
constraints = append(constraints, &constraint{
pred: g.V().Has({{ $.Package }}.Label, {{ $.Package }}.{{ $f.Constant }}, addValue).Count(),
test: __.Is(p.NEQ(0)).Constant(NewErrUniqueField({{ $.Package }}.Label, {{ $.Package }}.{{ $f.Constant }}, fmt.Sprintf("+= %v", value))),
})
{{- end }}
v.Property(dsl.Single, {{ $.Package }}.{{ $f.Constant }}, __.Union(__.Values({{ $.Package }}.{{ $f.Constant }}), __.Constant(value)).Sum())
}
{{- end }}
{{- end }}
{{- end }}
{{- /* clear optional fields: gather the names of all cleared fields and drop them with a single side-effect step. The "with" only guards emission; the body keeps using $ for the type. */}}
{{- with $.HasOptional }}
var properties []interface{}
{{- range $f := $.MutationFields }}
{{- if $f.Optional }}
if {{ $mutation }}.{{ $f.StructField }}Cleared() {
properties = append(properties, {{ $.Package }}.{{ $f.Constant }})
}
{{- end }}
{{- end }}
if len(properties) > 0 {
v.SideEffect(__.Properties(properties...).Drop())
}
{{- end }}
{{- range $e := $.Edges }}
{{- /*
$name is the edge label constant: taken from this type's package by default and
from the edge type's package for inverse edges. $direction ("In" by default,
"Out" for inverse edges) is only used to pick the InV/OutV step of the
non-bidirectional uniqueness constraint below.
*/}}
{{- $direction := "In" }}
{{- $name := printf "%s.%s" $.Package $e.LabelConstant }}
{{- if $e.IsInverse }}
{{- $direction = "Out" }}
{{- $name = printf "%s.%s" $e.Type.Package $e.LabelConstant }}
{{- end }}
{{- /* remove edges: a unique edge is dropped as a whole when cleared; otherwise one removal is emitted per removed id. */}}
{{- if $e.Unique }}
if {{ $mutation }}.{{ $e.StructField }}Cleared() {
{{- else }}
for _, id := range {{ $mutation }}.Removed{{ $e.StructField }}IDs() {
{{- end }}
{{- /* each removal starts from a clone of rv and becomes its own traversal in trs. Bidirectional edges match either endpoint, inverse edges are walked as incoming edges, the rest as outgoing. */}}
{{- if $e.Bidi }}
tr := rv.Clone().BothE({{ $name }}){{ if not $e.Unique }}.Where(__.Or(__.InV().HasID(id), __.OutV().HasID(id))){{ end }}.Drop().Iterate()
{{- else if $e.IsInverse }}
tr := rv.Clone().InE({{ $name }}){{ if not $e.Unique }}.Where(__.OtherV().HasID(id)){{ end }}.Drop().Iterate()
{{- else }}
tr := rv.Clone().OutE({{ $name }}){{ if not $e.Unique }}.Where(__.OtherV().HasID(id)){{ end }}.Drop().Iterate()
{{- end }}
trs = append(trs, tr)
}
{{- /* update edges: edge additions are appended to v itself; the trailing InV/OutV step moves the traversal back to the updated vertex so later steps keep operating on it. */}}
for _, id := range {{ $mutation }}.{{ $e.StructField }}IDs() {
{{- if $e.IsInverse }}
v.AddE({{ $name }}).From(g.V(id)).InV()
{{- else }}
v.AddE({{ $name }}).To(g.V(id)).OutV()
{{- end }}
{{- if $e.HasConstraint }}
{{- /* bidirectional edge: two constraints - the updated vertex must have no neighbor over this label, and the target id must not be an endpoint of any edge with this label. */}}
{{- if $e.Bidi }}
constraints = append(constraints, &constraint{
pred: rv.Clone().Both({{ $name }}).Count(),
test: __.Is(p.NEQ(0)).Constant(NewErrUniqueEdge({{ $.Package }}.Label, {{ $name }}, id)),
})
constraints = append(constraints, &constraint{
pred: g.E().HasLabel({{ $name }}).Where(__.Or(__.InV().HasID(id), __.OutV().HasID(id))).Count(),
test: __.Is(p.NEQ(0)).Constant(NewErrUniqueEdge({{ $.Package }}.Label, {{ $name }}, id)),
})
{{- else }}
{{- /* directed edge: the target id must not already sit on the $direction side of an edge with this label. */}}
constraints = append(constraints, &constraint{
pred: g.E().HasLabel({{ $name }}).{{ $direction }}V().HasID(id).Count(),
test: __.Is(p.NEQ(0)).Constant(NewErrUniqueEdge({{ $.Package }}.Label, {{ $name }}, id)),
})
{{- end }}
{{- end }}
}
{{- end }}
{{- /*
final step of the main traversal: "One" builders read the vertex back with
ValueMap, limited to the builder's selected fields when any are set; bulk
builders end with Count.
NOTE(review): the leading "true" argument of ValueMap presumably asks Gremlin to
include the id/label tokens - confirm against the dsl package.
*/}}
{{- if $one }}
if len({{ $receiver }}.fields) > 0 {
fields := make([]interface{}, 0, len({{ $receiver }}.fields)+1)
fields = append(fields, true)
for _, f := range {{ $receiver }}.fields {
fields = append(fields, f)
}
v.ValueMap(fields...)
} else {
v.ValueMap(true)
}
{{- else }}
v.Count()
{{- end }}
{{- /*
wrap the update in its constraints: each iteration rebinds v to
pred.Coalesce(test, v), so constraints nest with the last one outermost, and the
update itself is the innermost fallback reached when no test emits its error constant.
*/}}
{{- with .NumConstraint }}
if len(constraints) > 0 {
{{- /* make sure the traversal does not contain more than one vertex if we have constraint */}}
{{- if not $one }}
constraints = append(constraints, &constraint{
pred: rv.Count(),
test: __.Is(p.GT(1)).Constant(&ConstraintError{msg: "update traversal contains more than one vertex"}),
})
{{- end }}
v = constraints[0].pred.Coalesce(constraints[0].test, v)
for _, cr := range constraints[1:] {
v = cr.pred.Coalesce(cr.test, v)
}
}
{{- end }}
{{- /* the main traversal goes last so the edge removals collected above come first in the joined query. */}}
trs = append(trs, v)
return dsl.Join(trs...)
}
{{ end }}