{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}

{{/*
"update" generates two builders for the type in scope:
  - <Name>Update:    updates every entity matched by its predicates; Save returns an int.
  - <Name>UpdateOne: carries a single id; Save returns a *<Name>.
Both builders share their struct fields ("update/fields"), their edge removal
methods ("update/edges") and the pre-save checks ("update/save"). The storage
specific save logic is delegated to the "dialect/<storage>/update" templates.
*/}}
{{ define "update" }}
{{ $pkg := base $.Config.Package }}

{{ template "header" $ }}

{{ template "import" $ }}

{{ $builder := print (pascal $.Name) "Update" }}
{{ $receiver := receiver $builder }}

// {{ $builder }} is the builder for updating {{ $.Name }} entities.
type {{ $builder }} struct {
	config
	{{- template "update/fields" $ -}}
	predicates []predicate.{{ $.Name }}
}

// Where adds a new predicate for the builder.
func ({{ $receiver }} *{{ $builder }}) Where(ps ...predicate.{{ $.Name }}) *{{ $builder }} {
	{{ $receiver }}.predicates = append({{ $receiver }}.predicates, ps...)
	return {{ $receiver }}
}

{{ with extend $ "Builder" $builder }}
	{{ template "setter" . }}
{{ end }}

{{ with extend $ "Builder" $builder }}
	{{ template "update/edges" . }}
{{ end }}

// Save executes the query and returns the number of rows/vertices matched by this operation.
func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) (int, error) {
	{{ with extend $ "Receiver" $receiver "Package" $pkg "ZeroValue" 0 -}}
		{{ template "update/save" . }}
	{{- end -}}
	{{- if $.MultiStorage -}}
		switch {{ $receiver }}.driver.Dialect() {
		{{- range $_, $storage := $.Storage }}
		case {{ join $storage.Dialects ", " }}:
			return {{ $receiver }}.{{ $storage }}Save(ctx)
		{{- end }}
		default:
			return 0, errors.New("{{ $pkg }}: unsupported dialect")
		}
	{{- else -}}
		return {{ $receiver }}.{{ index $.Storage 0 }}Save(ctx)
	{{- end }}
}

// SaveX is like Save, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) SaveX(ctx context.Context) int {
	affected, err := {{ $receiver }}.Save(ctx)
	if err != nil {
		panic(err)
	}
	return affected
}

// Exec executes the query.
func ({{ $receiver }} *{{ $builder }}) Exec(ctx context.Context) error {
	_, err := {{ $receiver }}.Save(ctx)
	return err
}

// ExecX is like Exec, but panics if an error occurs.
func ({{ $receiver }} *{{ $builder }}) ExecX(ctx context.Context) {
	if err := {{ $receiver }}.Exec(ctx); err != nil {
		panic(err)
	}
}

{{/* one storage-specific save implementation per configured storage driver */}}
{{- range $_, $storage := $.Storage }}
	{{ with extend $ "Builder" $builder "Package" $pkg }}
		{{ $tmpl := printf "dialect/%s/update" $storage }}
		{{ xtemplate $tmpl . }}
	{{ end }}
{{ end }}

{{/* from here on $receiver refers to the single-entity builder */}}
{{ $onebuilder := printf "%sOne" $builder }}
{{ $receiver = receiver $onebuilder }}

// {{ $onebuilder }} is the builder for updating a single {{ $.Name }} entity.
type {{ $onebuilder }} struct {
	config
	id {{ $.ID.Type }}
	{{- template "update/fields" $ }}
}

{{ with extend $ "Builder" $onebuilder }}
	{{ template "setter" . }}
{{ end }}

{{ with extend $ "Builder" $onebuilder }}
	{{ template "update/edges" . }}
{{ end }}

// Save executes the query and returns the updated entity.
func ({{ $receiver }} *{{ $onebuilder }}) Save(ctx context.Context) (*{{ $.Name }}, error) {
	{{ with extend $ "Receiver" $receiver "Package" $pkg "ZeroValue" "nil" -}}
		{{ template "update/save" . }}
	{{- end -}}
	{{- if $.MultiStorage -}}
		switch {{ $receiver }}.driver.Dialect() {
		{{- range $_, $storage := $.Storage }}
		case {{ join $storage.Dialects ", " }}:
			return {{ $receiver }}.{{ $storage }}Save(ctx)
		{{- end }}
		default:
			return nil, errors.New("{{ $pkg }}: unsupported dialect")
		}
	{{- else -}}
		return {{ $receiver }}.{{ index $.Storage 0 }}Save(ctx)
	{{- end }}
}

// SaveX is like Save, but panics if an error occurs.
func ({{ $receiver }} *{{ $onebuilder }}) SaveX(ctx context.Context) *{{ $.Name }} {
	{{ $.Receiver }}, err := {{ $receiver }}.Save(ctx)
	if err != nil {
		panic(err)
	}
	return {{ $.Receiver }}
}

// Exec executes the query on the entity.
func ({{ $receiver }} *{{ $onebuilder }}) Exec(ctx context.Context) error {
	_, err := {{ $receiver }}.Save(ctx)
	return err
}

// ExecX is like Exec, but panics if an error occurs.
func ({{ $receiver }} *{{ $onebuilder }}) ExecX(ctx context.Context) {
	if err := {{ $receiver }}.Exec(ctx); err != nil {
		panic(err)
	}
}

{{/* one storage-specific save implementation per configured storage driver */}}
{{- range $_, $storage := $.Storage }}
	{{ with extend $ "Builder" $onebuilder "Package" $pkg }}
		{{ $tmpl := printf "dialect/%s/update" $storage }}
		{{ xtemplate $tmpl . }}
	{{ end }}
{{ end }}

{{ end }}

{{/*
shared struct fields between the two updaters.

Emitted per field:
  - a pointer field for every field that is mutable or has an update-default,
    plus an "add<Field>" pointer when the field type is numeric.
  - a "clear<Field>" flag for every optional field.
Emitted per edge:
  - an id-set holding the edge ids kept on the builder.
  - "cleared<Edge>" (bool) for unique edges, "removed<Edge>" (id-set) otherwise.

The trim markers are significant: the output must start and end on a fresh
line because it is spliced directly into a struct body by the callers.
*/}}
{{ define "update/fields"}}
	{{ range $_, $f := $.Fields }}
		{{- if or (not $f.Immutable) $f.UpdateDefault }}
			{{- $f.BuilderField }} *{{ $f.Type }}
			{{- if $f.Type.Numeric }}
				add{{ $f.BuilderField }} *{{ $f.Type }}
			{{- end }}
		{{- end }}
		{{- if $f.Optional }}
			clear{{ $f.BuilderField }} bool
		{{- end }}
	{{ end }}
	{{- range $_, $e := $.Edges }}
		{{- $e.BuilderField }} map[{{ $.ID.Type }}]struct{}
	{{ end }}
	{{- range $_, $e := $.Edges }}
		{{- $p := "removed" }}{{ if $e.Unique }}{{ $p = "cleared" }}{{ end }}
		{{- print $p $e.StructField }} {{ if $e.Unique }}bool{{ else }}map[{{ $.ID.Type }}]struct{}{{ end }}
	{{ end -}}
{{ end }}

{{/*
shared edges removal between the two updaters.

Expects the builder type name in .Scope.Builder. Unique edges get a
"Clear<Edge>" method that sets the cleared flag; non-unique edges get a
"Remove<Edge>IDs" method that records ids in the removed set and a
"Remove<Edge>" method that collects the ids of the given entities and
delegates to it.
*/}}
{{ define "update/edges" }}
{{ $builder := pascal .Scope.Builder }}
{{ $receiver := receiver $builder }}

{{ range $_, $e := $.Edges }}
	{{ if $e.Unique }}
		{{ $func := print "Clear" $e.StructField }}
		// {{ $func }} clears the {{ $e.Name }} edge to {{ $e.Type.Name }}.
		func ({{ $receiver }} *{{ $builder }}) {{ $func }}() *{{ $builder }} {
			{{ $receiver }}.cleared{{ $e.StructField }} = true
			return {{ $receiver }}
		}
	{{ else }}
		{{ $p := lower (printf "%.1s" $e.Type.Name) }}
		{{/* if the name of the parameter conflicts with the receiver name */}}
		{{ if eq $p $receiver }} {{ $p = "v" }} {{ end }}
		{{ $idsFunc := print "Remove" (singular $e.Name | pascal) "IDs" }}
		// {{ $idsFunc }} removes the {{ $e.Name }} edge to {{ $e.Type.Name }} by ids.
		func ({{ $receiver }} *{{ $builder }}) {{ $idsFunc }}(ids ...{{ $.ID.Type }}) *{{ $builder }} {
			if {{ $receiver }}.removed{{ $e.StructField }} == nil {
				{{ $receiver }}.removed{{ $e.StructField }} = make(map[{{ $.ID.Type }}]struct{})
			}
			for i := range ids {
				{{ $receiver }}.removed{{ $e.StructField }}[ids[i]] = struct{}{}
			}
			return {{ $receiver }}
		}
		{{ $func := print "Remove" $e.StructField }}
		// {{ $func }} removes {{ $e.Name }} edges to {{ $e.Type.Name }}.
		func ({{ $receiver }} *{{ $builder }}) {{ $func }}({{ $p }} ...*{{ $e.Type.Name }}) *{{ $builder }} {
			ids := make([]{{ $.ID.Type }}, len({{ $p }}))
			{{ $i := "i" }}{{ if eq $i $p }}{{ $i = "j" }}{{ end -}}
			for {{ $i }} := range {{ $p }} {
				ids[{{ $i }}] = {{ $p }}[{{ $i }}].ID
			}
			return {{ $receiver }}.{{ $idsFunc }}(ids...)
		}
	{{ end }}
{{ end }}
{{ end }}

{{/*
shared template for the save method of the 2 builders.

Expects .Scope.Receiver (receiver variable name), .Scope.Package (prefix for
error messages) and .Scope.ZeroValue (first return value on failure).
The emitted statements, in order:
  1. for fields with an update-default: assign the default when the field was
     not set (and, for optional fields, not cleared).
  2. for fields with validators or enum fields: run the validator on the set
     value and return a wrapped error on failure.
  3. for unique edges: reject more than one assigned id, and reject clearing
     a non-optional edge without assigning a replacement.
The output ends on a fresh line (or is empty) so the callers can append their
return statement directly after it.
*/}}
{{ define "update/save" }}
	{{- $pkg := .Scope.Package -}}
	{{- $zero := .Scope.ZeroValue }}
	{{- $receiver := .Scope.Receiver -}}
	{{- range $_, $f := $.Fields -}}
		{{- if $f.UpdateDefault -}}
			if {{ $receiver }}.{{ $f.BuilderField }} == nil {{ if $f.Optional }} && !{{ $receiver }}.clear{{ $f.BuilderField }} {{ end }} {
				v := {{ $.Package }}.{{ $f.UpdateDefaultName }}{{ if $f.IsTime }}(){{ end }}
				{{ $receiver }}.{{ $f.BuilderField }} = &v
			}
		{{ end -}}
		{{ with or $f.Validators $f.IsEnum -}}
			if {{ $receiver }}.{{ $f.BuilderField }} != nil {
				if err := {{ $.Package }}.{{ $f.Validator }}(*{{ $receiver }}.{{ $f.BuilderField }}); err != nil {
					return {{ $zero }}, fmt.Errorf("{{ $pkg }}: validator failed for field \"{{ $f.Name }}\": %v", err)
				}
			}
		{{ end -}}
	{{ end -}}
	{{- range $_, $e := $.Edges }}
		{{- if $e.Unique -}}
			if len({{ $receiver }}.{{ $e.BuilderField }}) > 1 {
				return {{ $zero }}, errors.New("{{ $pkg }}: multiple assignments on a unique edge \"{{ $e.Name }}\"")
			}
			{{ if not $e.Optional -}}
				if {{ $receiver }}.cleared{{ $e.StructField }} && {{ $receiver }}.{{ $e.BuilderField }} == nil {
					return {{ $zero }}, errors.New("{{ $pkg }}: clearing a required unique edge \"{{ $e.Name }}\"")
				}
			{{ end -}}
		{{ end -}}
	{{ end -}}
{{ end }}