{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}

{{/* gotype: entgo.io/ent/entc/gen.Type */}}

{{/*
"update" renders one Go source file containing the two update builders of a type:

  - the builder named by $.UpdateName, whose Save returns the number of affected nodes;
  - the builder named by $.UpdateOneName, whose Save returns the updated entity.

Both builders get Where/Select, setters, edge helpers, Save/SaveX/Exec/ExecX,
the optional defaults()/check() helpers (see "update/checks"), and whatever the
storage-specific "dialect/<storage>/update" template renders.
The root context ($) is the type being generated.
*/}}
{{ define "update" }}

{{ $pkg := base $.Config.Package }}

{{ template "header" $ }}

{{ template "import" $ }}

import (
	{{- range $import := $.SiblingImports }}
		{{ $import.Alias }} "{{ $import.Path }}"
	{{- end }}
)

{{/*
Locals for the first builder: its type name, its method receiver name, and the
Go expression that reaches its mutation. $receiver and $mutation are re-assigned
further down for the single-entity builder.
$runtimeRequired is truthy when the type has hooks or policies; in that case
defaults() is generated with an error result and Save propagates it.
*/}}
{{ $builder := $.UpdateName }}
{{ $receiver := receiver $builder }}
{{ $mutation := print $receiver ".mutation" }}
{{ $runtimeRequired := or $.NumHooks $.NumPolicy }}

// {{ $builder }} is the builder for updating {{ $.Name }} entities.
type {{ $builder }} struct {
	config
	{{- template "update/fields" $ -}}
}

// Where appends a list predicates to the {{ $builder }} builder.
func ({{ $receiver}} *{{ $builder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $builder }} {
	{{ $mutation }}.Where(ps...)
	return {{ $receiver }}
}

{{/* Field setters; the "setter" template is defined outside this file and receives the builder name via .Scope.Builder. */}}
{{ with extend $ "Builder" $builder }}
	{{ template "setter" . }}
{{ end }}

{{/* Edge clear/remove helpers shared with the single-entity builder. */}}
{{ with extend $ "Builder" $builder }}
	{{ template "update/edges" . }}
{{ end }}

// Save executes the query and returns the number of nodes affected by the update operation.
func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) (int, error) {
	var (
		err error
		affected int
	)
	{{- /* Apply update-default values first; defaults() only returns an error when $runtimeRequired is truthy. */}}
	{{- if $.HasUpdateDefault }}
		{{- if $runtimeRequired }}
			if err := {{ $receiver }}.defaults(); err != nil {
				return 0, err
			}
		{{- else }}
			{{ $receiver }}.defaults()
		{{- end }}
	{{- end }}
	if len({{ $receiver }}.hooks) == 0 {
		{{- /* No hooks: validate (when checkers exist) and call the storage save directly. */}}
		{{- if $.HasUpdateCheckers }}
			if err = {{ $receiver }}.check(); err != nil {
				return 0, err
			}
		{{- end }}
		affected, err = {{ $receiver }}.{{ $.Storage }}Save(ctx)
	} else {
		var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
			mutation, ok := m.(*{{ $.MutationName }})
			if !ok {
				return nil, fmt.Errorf("unexpected mutation type %T", m)
			}
			{{- /*
				NOTE(review): check() reads the builder's current mutation and runs before the
				mutation received from the hook chain is assigned to the builder - confirm this
				ordering is intended.
			*/}}
			{{- if $.HasUpdateCheckers }}
				if err = {{ $receiver }}.check(); err != nil {
					return 0, err
				}
			{{- end }}
			{{ $mutation }} = mutation
			affected, err = {{ $receiver }}.{{ $.Storage }}Save(ctx)
			mutation.done = true
			return affected, err
		})
		{{- /* Wrap from the last hook to the first, so hooks[0] ends up as the outermost mutator. */}}
		for i := len({{ $receiver }}.hooks) - 1; i >= 0; i-- {
			if {{ $receiver }}.hooks[i] == nil {
				return 0, fmt.Errorf("{{ $pkg }}: uninitialized hook (forgotten import {{ $pkg }}/runtime?)")
			}
			mut = {{ $receiver }}.hooks[i](mut)
		}
		if _, err := mut.Mutate(ctx, {{ $mutation }}); err != nil {
			return 0, err
		}
	}
	return affected, err
}

// SaveX is like Save, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) SaveX(ctx context.Context) int {
	affected, err := {{ $receiver }}.Save(ctx)
	if err != nil {
		panic(err)
	}
	return affected
}

// Exec executes the query.
func ({{ $receiver }} *{{ $builder }}) Exec(ctx context.Context) error {
	_, err := {{ $receiver }}.Save(ctx)
	return err
}

// ExecX is like Exec, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) ExecX(ctx context.Context) {
	if err := {{ $receiver }}.Exec(ctx); err != nil {
		panic(err)
	}
}

{{/* defaults() and check() for this builder; rendered only when the type needs them (see "update/checks"). */}}
{{ with extend $ "Receiver" $receiver "Package" $pkg "Builder" $builder }}
	{{ template "update/checks" . }}
{{ end }}

{{/*
Storage-specific part of the builder. The template name is derived from $.Storage;
presumably it defines the <storage>Save method called by Save above - confirm
against the dialect templates.
*/}}
{{ with extend $ "Builder" $builder "Package" $pkg }}
	{{ $tmpl := printf "dialect/%s/update" $.Storage }}
	{{ xtemplate $tmpl . }}
{{ end }}

{{/* Second builder: re-bind the receiver and mutation locals to the single-entity builder. */}}
{{ $onebuilder := $.UpdateOneName }}
{{ $receiver = receiver $onebuilder }}
{{ $mutation = print $receiver ".mutation" }}

// {{ $onebuilder }} is the builder for updating a single {{ $.Name }} entity.
type {{ $onebuilder }} struct {
	config
	fields []string
	{{- template "update/fields" $ }}
}

{{/* Field setters for the single-entity builder. */}}
{{ with extend $ "Builder" $onebuilder }}
	{{ template "setter" . }}
{{ end }}

{{/* Edge clear/remove helpers for the single-entity builder. */}}
{{ with extend $ "Builder" $onebuilder }}
	{{ template "update/edges" . }}
{{ end }}

// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func ({{ $receiver }} *{{ $onebuilder }}) Select(field string, fields ...string) *{{ $onebuilder }} {
	{{ $receiver }}.fields = append([]string{field}, fields...)
	return {{ $receiver }}
}

// Save executes the query and returns the updated {{ $.Name }} entity.
func ({{ $receiver }} *{{ $onebuilder }} ) Save(ctx context.Context) (*{{ $.Name }}, error) {
	var (
		err error
		node *{{ $.Name }}
	)
	{{- /* Same flow as the first builder's Save, but the result is the updated node instead of a count. */}}
	{{- if $.HasUpdateDefault }}
		{{- if $runtimeRequired }}
			if err := {{ $receiver }}.defaults(); err != nil {
				return nil, err
			}
		{{- else }}
			{{ $receiver }}.defaults()
		{{- end }}
	{{- end }}
	if len({{ $receiver }}.hooks) == 0 {
		{{- if $.HasUpdateCheckers }}
			if err = {{ $receiver }}.check(); err != nil {
				return nil, err
			}
		{{- end }}
		node, err = {{ $receiver }}.{{ $.Storage }}Save(ctx)
	} else {
		var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
			mutation, ok := m.(*{{ $.MutationName }})
			if !ok {
				return nil, fmt.Errorf("unexpected mutation type %T", m)
			}
			{{- /*
				NOTE(review): as in the first builder, check() runs against the builder's current
				mutation before the mutation from the hook chain is assigned - confirm intended.
			*/}}
			{{- if $.HasUpdateCheckers }}
				if err = {{ $receiver }}.check(); err != nil {
					return nil, err
				}
			{{- end }}
			{{ $mutation }} = mutation
			node, err = {{ $receiver }}.{{ $.Storage }}Save(ctx)
			mutation.done = true
			return node, err
		})
		{{- /* Wrap from the last hook to the first, so hooks[0] ends up as the outermost mutator. */}}
		for i := len({{ $receiver }}.hooks) - 1; i >= 0; i-- {
			if {{ $receiver }}.hooks[i] == nil {
				return nil, fmt.Errorf("{{ $pkg }}: uninitialized hook (forgotten import {{ $pkg }}/runtime?)")
			}
			mut = {{ $receiver }}.hooks[i](mut)
		}
		v, err := mut.Mutate(ctx, {{ $mutation }})
		if err != nil {
			return nil, err
		}
		{{- /* The value returned by the outermost mutator is used as the result, so it must be a node of this type. */}}
		nv, ok := v.(*{{ $.Name }})
		if !ok {
			return nil, fmt.Errorf("unexpected node type %T returned from {{ $.MutationName }}", v)
		}
		node = nv
	}
	return node, err
}

// SaveX is like Save, but panics if an error occurs.
func ({{ $receiver }} *{{ $onebuilder }}) SaveX(ctx context.Context) *{{ $.Name }} {
	node, err := {{ $receiver }}.Save(ctx)
	if err != nil {
		panic(err)
	}
	return node
}

// Exec executes the query on the entity.
func ({{ $receiver }} *{{ $onebuilder }}) Exec(ctx context.Context) error {
	_, err := {{ $receiver }}.Save(ctx)
	return err
}

// ExecX is like Exec, but panics if an error occurs.
func ({{ $receiver }} *{{ $onebuilder }}) ExecX(ctx context.Context) {
	if err := {{ $receiver }}.Exec(ctx); err != nil {
		panic(err)
	}
}

{{/* defaults() and check() for the single-entity builder. */}}
{{ with extend $ "Receiver" $receiver "Package" $pkg "Builder" $onebuilder }}
	{{ template "update/checks" . }}
{{ end }}

{{/* Storage-specific part of the single-entity builder (same dialect template, different .Scope.Builder). */}}
{{ with extend $ "Builder" $onebuilder "Package" $pkg }}
	{{ $tmpl := printf "dialect/%s/update" $.Storage }}
	{{ xtemplate $tmpl . }}
{{ end }}

{{- /* Support adding update methods by global templates. */}}
{{- with $tmpls := matchTemplate "update/additional/*" }}
	{{- range $tmpl := $tmpls }}
		{{ xtemplate $tmpl $ }}
	{{- end }}
{{- end }}

{{ end }}

{{/*
"update/fields" renders the struct fields shared by the two updaters: the hooks
slice and the mutation pointer, followed by any extra fields contributed by an
optional "dialect/<storage>/update/fields" template.
*/}}
{{ define "update/fields"}}
	hooks []Hook
	mutation *{{ $.MutationName }}
	{{- /* Additional fields to add to the builder. */}}
	{{- $tmpl := printf "dialect/%s/update/fields" $.Storage }}
	{{- if hasTemplate $tmpl }}
		{{- xtemplate $tmpl . }}
	{{- end }}
{{ end }}

{{/*
"update/edges" renders the edge-clearing and edge-removal methods shared by the
two updaters. It expects the builder type name in .Scope.Builder and derives the
receiver name from it. For every mutable edge in $.EdgesWithID it emits a clear
method; for non-unique edges it additionally emits Remove<Edge>IDs and
Remove<Edges> methods.
*/}}
{{ define "update/edges" }}
{{ $builder := pascal .Scope.Builder }}
{{ $receiver := receiver $builder }}
{{ $mutation := print $receiver ".mutation" }}

{{ range $e := $.EdgesWithID }}
	{{ if $e.Immutable }}
		{{/* Skip to the next one as immutable edges cannot be updated. */}}
		{{continue}}
	{{ end }}
	{{ $func := $e.MutationClear }}
	// {{ $func }} clears {{ if $e.Unique }}the "{{ $e.Name }}" edge{{ else }}all "{{ $e.Name }}" edges{{ end }} to the {{ $e.Type.Name }} entity.
	func ({{ $receiver }} *{{ $builder }}) {{ $func }}() *{{ $builder }} {
		{{ $mutation }}.{{ $func }}()
		return {{ $receiver }}
	}
	{{ if not $e.Unique }}
		{{/* The variadic parameter is named after the lower-cased first letter of the edge's type name. */}}
		{{ $p := lower (printf "%.1s" $e.Type.Name) }}
		{{/* if the name of the parameter conflicts with the receiver name */}}
		{{ if eq $p $receiver }} {{ $p = "v" }} {{ end }}
		{{ $idsFunc := print "Remove" (singular $e.Name | pascal) "IDs" }}
		// {{ $idsFunc }} removes the "{{ $e.Name }}" edge to {{ $e.Type.Name }} entities by IDs.
		func ({{ $receiver }} *{{ $builder }}) {{ $idsFunc }}(ids ...{{ $e.Type.ID.Type }}) *{{ $builder }} {
			{{ $mutation }}.{{ $idsFunc }}(ids...)
			return {{ $receiver }}
		}
		{{ $func := print "Remove" $e.StructField }}
		// {{ $func }} removes "{{ $e.Name }}" edges to {{ $e.Type.Name }} entities.
		func ({{ $receiver }} *{{ $builder }}) {{ $func }}({{ $p }} ...*{{ $e.Type.Name }}) *{{ $builder }} {
			ids := make([]{{ $e.Type.ID.Type }}, len({{ $p }}))
			{{- /* The loop index is "i", or "j" when the parameter itself is named "i". */}}
			{{ $i := "i" }}{{ if eq $i $p }}{{ $i = "j" }}{{ end -}}
			for {{ $i }} := range {{ $p }} {
				ids[{{ $i }}] = {{ $p }}[{{ $i }}].ID
			}
			return {{ $receiver }}.{{ $idsFunc }}(ids...)
		}
	{{ end }}
{{ end }}
{{ end }}

{{/*
"update/checks" renders the optional defaults() and check() methods shared by
the two update builders. It expects .Scope.Package, .Scope.Receiver and
.Scope.Builder to be set by the caller.

  - defaults() is rendered when $.HasUpdateDefault is set. It returns an error
    only when the type has hooks or policies ($runtimeRequired).
  - check() is rendered when $.HasUpdateCheckers is set.
*/}}
{{ define "update/checks" }}
	{{ $pkg := .Scope.Package }}
	{{ $receiver := .Scope.Receiver }}
	{{ $builder := pascal .Scope.Builder }}
	{{ $mutation := print $receiver ".mutation" }}
	{{ $runtimeRequired := or $.NumHooks $.NumPolicy }}

	{{ if $.HasUpdateDefault }}
		// defaults sets the default values of the builder before save.
		func ({{ $receiver }} *{{ $builder }}) defaults() {{ if $runtimeRequired }}error{{ end }}{
			{{- range $f := $.Fields }}
				{{- /* A default is applied only when the field was not set and, for optional fields, not cleared. */}}
				{{- if $f.UpdateDefault }}
					if _, ok := {{ $mutation }}.{{ $f.MutationGet }}(); !ok {{ if $f.Optional }} && !{{ $mutation }}.{{ $f.StructField }}Cleared() {{ end }} {
						{{- /* With hooks/policies the default function may still be nil; report that instead of calling it. */}}
						{{- if $runtimeRequired }}
							if {{ $.Package }}.{{ $f.UpdateDefaultName }} == nil {
								return fmt.Errorf("{{ $pkg }}: uninitialized {{ $.Package }}.{{ $f.UpdateDefaultName }} (forgotten import {{ $pkg }}/runtime?)")
							}
						{{- end }}
						v := {{ $.Package }}.{{ $f.UpdateDefaultName }}()
						{{ $mutation }}.Set{{ $f.StructField }}(v)
					}
				{{- end }}
			{{- end }}
			{{- if $runtimeRequired }}
				return nil
			{{- end }}
		}
	{{ end }}

	{{ if $.HasUpdateCheckers }}
		// check runs all checks and user-defined validators on the builder.
		func ({{ $receiver }} *{{ $builder }}) check() error {
			{{- range $f := $.Fields }}
				{{- /*
					Validate only mutable fields that have validators or are enums, and only when
					a value was set. "with" rebinds dot here, so the body uses $ and $f.
				*/}}
				{{- with and (or $f.Validators $f.IsEnum) (not $f.Immutable) }}
					if v, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok {
						{{- $basic := $f.BasicType "v" }}
						if err := {{ $.Package }}.{{ $f.Validator }}({{ $basic }}); err != nil {
							return &ValidationError{Name: "{{ $f.Name }}", err: fmt.Errorf(`{{ $pkg }}: validator failed for field "{{ $.Name }}.{{ $f.Name }}": %w`, err)}
						}
					}
				{{- end }}
			{{- end }}
			{{- range $e := $.Edges }}
				{{- /* A required unique edge may be cleared only if a new ID is set in the same mutation. */}}
				{{- if and $e.Unique (not $e.Optional) }}
					if _, ok := {{ $mutation }}.{{ $e.StructField }}ID(); {{ $mutation }}.{{ $e.StructField }}Cleared() && !ok {
						return errors.New(`{{ $pkg }}: clearing a required unique edge "{{ $.Name }}.{{ $e.Name }}"`)
					}
				{{- end }}
			{{- end }}
			return nil
		}
	{{ end }}
{{ end }}