{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}

{{/*
"dialect/sql/update" generates the sqlSave method of an update builder.

The generated method:
  1. selects the rows to update (by id when the builder name ends with "One",
     otherwise by the builder predicates) and collects their ids;
  2. opens a transaction and applies the field changes (set / add / clear);
  3. removes, clears and adds edges, one branch per edge kind;
  4. commits, returning the updated entity ("One" builders) or the number of
     selected ids (multi-row builders).

Every failure after the transaction has been opened is returned through
rollback(tx, err).
*/}}
{{ define "dialect/sql/update" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
{{ $one := hasSuffix $builder "One" }}
{{- $zero := 0 }}{{ if $one }}{{ $zero = "nil" }}{{ end }}
{{- $ret := "n" }}{{ if $one }}{{ $ret = $.Receiver }}{{ end }}

func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }} {{ if $one }}*{{ $.Name }}{{ else }}int{{ end }}, err error) {
	var (
		builder = sql.Dialect({{ $receiver }}.driver.Dialect())
		selector = builder.Select({{ $.Package }}.{{ if $one }}Columns...{{ else }}{{ $.ID.Constant }}{{ end }}).From(builder.Table({{ $.Package }}.Table))
	)
	{{- if $one }}
		{{ $.Package }}.ID({{ $receiver }}.id)(selector)
	{{- else }}
		for _, p := range {{ $receiver }}.predicates {
			p(selector)
		}
	{{- end }}
	rows := &sql.Rows{}
	query, args := selector.Query()
	if err = {{ $receiver }}.driver.Query(ctx, query, args, rows); err != nil {
		return {{ $zero }}, err
	}
	defer rows.Close()
	var ids []int
	for rows.Next() {
		var id int
		{{- if $one }}
			{{ $.Receiver }} = &{{ $.Name }}{config: {{ $receiver }}.config}
			if err := {{ $.Receiver }}.FromRows(rows); err != nil {
				return {{ $zero }}, fmt.Errorf("{{ $pkg }}: failed scanning row into {{ $.Name }}: %v", err)
			}
			id = {{ if $.ID.IsString }}{{ $.Receiver }}.id(){{ else if not $.ID.IsInt }}int({{ $.Receiver }}.ID){{ else }}{{ $.Receiver }}.ID{{ end }}
		{{- else }}
			if err := rows.Scan(&id); err != nil {
				return {{ $zero }}, fmt.Errorf("{{ $pkg }}: failed reading id: %v", err)
			}
		{{- end }}
		ids = append(ids, id)
	}
	{{- if $one }}
		switch n := len(ids); {
		case n == 0:
			return {{ $zero }}, &ErrNotFound{fmt.Sprintf("{{ $.Name }} with id: %v", {{ $receiver }}.id)}
		case n > 1:
			return {{ $zero }}, fmt.Errorf("{{ $pkg }}: more than one {{ $.Name }} with the same id: %v", {{ $receiver }}.id)
		}
	{{- else }}
		if len(ids) == 0 {
			return {{ $zero }}, nil
		}
	{{- end }}
	{{/* if there's something to update, start a transaction. */}}
	tx, err := {{ $receiver }}.driver.Tx(ctx)
	if err != nil {
		return {{ $zero }}, err
	}
	{{- if $.Fields }}
		var (
			res sql.Result
			updater = builder.Update({{ $.Package }}.Table).Where(sql.InInts({{ $.Package }}.{{ $.ID.Constant }}, ids...))
		)
		{{- range $_, $f := $.Fields }}
			{{- if or (not $f.Immutable) $f.UpdateDefault }}
				if value := {{ $receiver }}.{{ $f.BuilderField }}; value != nil {
					{{- if $f.IsJSON }}
						buf, err := json.Marshal(*value)
						if err != nil {
							{{/* the transaction is already open here, so it must be rolled back. */}}
							return {{ $zero }}, rollback(tx, err)
						}
						updater.Set({{ $.Package }}.{{ $f.Constant }}, buf)
					{{- else }}
						updater.Set({{ $.Package }}.{{ $f.Constant }}, *value)
					{{- end }}
					{{- if $one }}
						{{ $.Receiver }}.{{ $f.StructField }} = {{ if not $f.Nillable }}*{{ end }}value
					{{- end }}
				}
				{{- if $f.Type.Numeric }}
					if value := {{ $receiver }}.add{{ $f.BuilderField }}; value != nil {
						updater.Add({{ $.Package }}.{{ $f.Constant }}, *value)
						{{- if $one }}
							{{- if $f.Nillable }}
								if {{ $.Receiver }}.{{ $f.StructField }} != nil {
									*{{ $.Receiver }}.{{ $f.StructField }} += *value
								} else {
									{{ $.Receiver }}.{{ $f.StructField }} = value
								}
							{{- else }}
								{{ $.Receiver }}.{{ $f.StructField }} += *value
							{{- end }}
						{{- end }}
					}
				{{- end }}
			{{- end }}
			{{- if $f.Optional }}
				if {{ $receiver }}.clear{{ $f.BuilderField }} {
					{{- if $one }}
						{{- if $f.Nillable }}
							{{ $.Receiver }}.{{ $f.StructField }} = nil
						{{- else }}
							var value {{ $f.Type }}
							{{ $.Receiver }}.{{ $f.StructField }} = value
						{{- end }}
					{{- end }}
					updater.SetNull({{ $.Package }}.{{ $f.Constant }})
				}
			{{- end }}
		{{- end }}
		if !updater.Empty() {
			query, args := updater.Query()
			if err := tx.Exec(ctx, query, args, &res); err != nil {
				return {{ $zero }}, rollback(tx, err)
			}
		}
	{{- else if $.Edges }}{{/* ent without fields, but with edges */}}
		var res sql.Result
	{{- end }}
	{{- range $_, $e := $.Edges }}
		{{- if $e.M2M }}
			if len({{ $receiver }}.removed{{ $e.StructField }}) > 0 {
				{{- $a := 0 }}{{ $b := 1 }}{{ if $e.IsInverse }}{{ $a = 1 }}{{ $b = 0 }}{{ end }}
				{{/* zero length with reserved capacity: the ids are appended below. */}}
				eids := make([]int, 0, len({{ $receiver }}.removed{{ $e.StructField }}))
				for eid := range {{ $receiver }}.removed{{ $e.StructField }} {
					{{- template "dialect/sql/update/convertid" $e -}}
					eids = append(eids, eid)
				}
				query, args := builder.Delete({{ $.Package }}.{{ $e.TableConstant }}).
					Where(sql.InInts({{ $.Package }}.{{ $e.PKConstant }}[{{ $a }}], ids...)).
					Where(sql.InInts({{ $.Package }}.{{ $e.PKConstant }}[{{ $b }}], eids...)).
					Query()
				if err := tx.Exec(ctx, query, args, &res); err != nil {
					return {{ $zero }}, rollback(tx, err)
				}
				{{- if $e.SelfRef }}{{/* M2M with self reference */}}{{/* TODO: use OR in the case above. */}}
					query, args = builder.Delete({{ $.Package }}.{{ $e.TableConstant }}).
						Where(sql.InInts({{ $.Package }}.{{ $e.PKConstant }}[{{ $b }}], ids...)).
						Where(sql.InInts({{ $.Package }}.{{ $e.PKConstant }}[{{ $a }}], eids...)).
						Query()
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
				{{- end }}
			}
		{{- else if $e.O2M }}
			if len({{ $receiver }}.removed{{ $e.StructField }}) > 0 {
				{{/* zero length with reserved capacity: the ids are appended below. */}}
				eids := make([]int, 0, len({{ $receiver }}.removed{{ $e.StructField }}))
				for eid := range {{ $receiver }}.removed{{ $e.StructField }} {
					{{- template "dialect/sql/update/convertid" $e -}}
					eids = append(eids, eid)
				}
				query, args := builder.Update({{ $.Package }}.{{ $e.TableConstant }}).
					SetNull({{ $.Package }}.{{ $e.ColumnConstant }}).
					Where(sql.InInts({{ $.Package }}.{{ $e.ColumnConstant }}, ids...)).
					Where(sql.InInts({{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}, eids...)).
					Query()
				if err := tx.Exec(ctx, query, args, &res); err != nil {
					return {{ $zero }}, rollback(tx, err)
				}
			}
		{{- else }}{{/* O2O or M2O */}}
			if {{ $receiver }}.cleared{{ $e.StructField }} {
				query, args := builder.Update({{ $.Package }}.{{ $e.TableConstant }}).
					SetNull({{ $.Package }}.{{ $e.ColumnConstant }}).
					Where(sql.InInts({{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}, ids...)).
					Query()
				if err := tx.Exec(ctx, query, args, &res); err != nil {
					return {{ $zero }}, rollback(tx, err)
				}
				{{- if $e.SelfRef }}{{/* O2O with self reference */}}
					query, args = builder.Update({{ $.Package }}.{{ $e.TableConstant }}).
						SetNull({{ $.Package }}.{{ $e.ColumnConstant }}).
						Where(sql.InInts({{ $.Package }}.{{ $e.ColumnConstant }}, ids...)).
						Query()
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
				{{- end }}
			}
		{{- end }}
		if len({{ $receiver }}.{{ $e.BuilderField }}) > 0 {
			{{- if and $e.Unique $e.SelfRef }}{{/* O2O with self reference */}}
				if n := len(ids); n > 1 {
					return {{ $zero }}, rollback(tx, fmt.Errorf("{{ $pkg }}: can't link O2O edge \"{{ $e.Name }}\" to %d vertices (> 1)", n))
				}
				for eid := range {{ $receiver }}.{{ $e.BuilderField }} {
					{{- template "dialect/sql/update/convertid" $e -}}
					query, args := builder.Update({{ $.Package }}.{{ $e.TableConstant }}).
						Set({{ $.Package }}.{{ $e.ColumnConstant }}, eid).
						Where(sql.EQ({{ $.Package }}.{{ $.ID.Constant }}, ids[0])).Query()
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					query, args = builder.Update({{ $.Package }}.{{ $e.TableConstant }}).
						Set({{ $.Package }}.{{ $e.ColumnConstant }}, ids[0]).
						Where(sql.EQ({{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}, eid).And().IsNull({{ $.Package }}.{{ $e.ColumnConstant }})).Query()
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					affected, err := res.RowsAffected()
					if err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					if int(affected) < len({{ $receiver }}.{{ $e.BuilderField }}) {
						return {{ $zero }}, rollback(tx, &ErrConstraintFailed{msg: fmt.Sprintf("\"{{ $e.Name }}\" (%v) already connected to a different \"{{ $.Name }}\"", eid)})
					}
				}
			{{- else if $e.M2M }}
				values := make([][]int, 0, len(ids))
				for _, id := range ids {
					for eid := range {{ $receiver }}.{{ $e.BuilderField }} {
						{{- template "dialect/sql/update/convertid" $e -}}
						values = append(values, []int{id, eid}, {{- if $e.SelfRef }}[]int{eid, id}{{ end }}){{/* self-ref creates the edges in both ways. */}}
					}
				}
				{{- $a := 0 }}{{ $b := 1 }}{{ if $e.IsInverse }}{{ $a = 1 }}{{ $b = 0 }}{{ end }}
				builder := builder.Insert({{ $.Package }}.{{ $e.TableConstant }}).
					Columns({{ $.Package }}.{{ $e.PKConstant }}[{{ $a }}], {{ $.Package }}.{{ $e.PKConstant }}[{{ $b }}])
				for _, v := range values {
					builder.Values(v[0], v[1])
				}
				query, args := builder.Query()
				if err := tx.Exec(ctx, query, args, &res); err != nil {
					return {{ $zero }}, rollback(tx, err)
				}
			{{- else if $e.M2O }}
				for eid := range {{ $receiver }}.{{ $e.BuilderField }} {
					{{- template "dialect/sql/update/convertid" $e -}}
					query, args := builder.Update({{ $.Package }}.{{ $e.TableConstant }}).
						Set({{ $.Package }}.{{ $e.ColumnConstant }}, eid).
						Where(sql.InInts({{ $.Package }}.{{ $.ID.Constant }}, ids...)).
						Query()
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
				}
			{{- else if $e.O2M }}
				for _, id := range ids {
					p := sql.P()
					for eid := range {{ $receiver }}.{{ $e.BuilderField }} {
						{{- template "dialect/sql/update/convertid" $e -}}
						p.Or().EQ({{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}, eid)
					}
					query, args := builder.Update({{ $.Package }}.{{ $e.TableConstant }}).
						Set({{ $.Package }}.{{ $e.ColumnConstant }}, id).
						Where(sql.And(p, sql.IsNull({{ $.Package }}.{{ $e.ColumnConstant }}))).
						Query()
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					affected, err := res.RowsAffected()
					if err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					if int(affected) < len({{ $receiver }}.{{ $e.BuilderField }}) {
						return {{ $zero }}, rollback(tx, &ErrConstraintFailed{msg: fmt.Sprintf("one of \"{{ $e.Name }}\" %v already connected to a different \"{{ $.Name }}\"", keys({{ $receiver }}.{{ $e.BuilderField }}))})
					}
				}
			{{- else }}{{/* O2O */}}
				for _, id := range ids {
					{{/* NOTE(review): this tests $.Type, while the convertid template tests the edge type ($e.Type); confirm both refer to the intended id type. */}}
					{{- if $.Type.ID.IsString }}
						eid, serr := strconv.Atoi(keys({{ $receiver }}.{{ $e.BuilderField }})[0])
						if serr != nil {
							{{/* report the conversion error itself (serr), not the outer err. */}}
							return {{ $zero }}, rollback(tx, serr)
						}
					{{- else }}
						eid := keys({{ $receiver }}.{{ $e.BuilderField }})[0]
					{{- end }}
					{{- if $e.IsInverse }}
						query, args := builder.Update({{ $.Package }}.{{ $e.TableConstant }}).
							Set({{ $.Package }}.{{ $e.ColumnConstant }}, eid).
							Where(sql.EQ({{ $.Package }}.{{ $.ID.Constant }}, id).And().IsNull({{ $.Package }}.{{ $e.ColumnConstant }})).
							Query()
					{{- else }}
						query, args := builder.Update({{ $.Package }}.{{ $e.TableConstant }}).
							Set({{ $.Package }}.{{ $e.ColumnConstant }}, id).
							Where(sql.EQ({{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}, eid).And().IsNull({{ $.Package }}.{{ $e.ColumnConstant }})).
							Query()
					{{- end }}
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					affected, err := res.RowsAffected()
					if err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					if int(affected) < len({{ $receiver }}.{{ $e.BuilderField }}) {
						return {{ $zero }}, rollback(tx, &ErrConstraintFailed{msg: fmt.Sprintf("one of \"{{ $e.Name }}\" %v already connected to a different \"{{ $.Name }}\"", keys({{ $receiver }}.{{ $e.BuilderField }}))})
					}
				}
			{{- end }}
		}
	{{- end }}
	if err = tx.Commit(); err != nil {
		return {{ $zero }}, err
	}
	return {{ if $one }}{{ $.Receiver }}{{ else }}len(ids){{ end }}, nil
}
{{ end }}

{{/*
"dialect/sql/update/convertid" is invoked with an edge as its argument, inside
a loop whose variable is named "eid". It emits code that shadows "eid" with an
int: via strconv.Atoi when the edge type's id is a string, via an int(...)
conversion when it is neither a string nor an int, and nothing when it is
already an int. The emitted code relies on "tx" and the named result "err"
being in scope at the call site.
*/}}
{{ define "dialect/sql/update/convertid" }}
{{- if $.Type.ID.IsString }}
	eid, serr := strconv.Atoi(eid)
	if serr != nil {
		err = rollback(tx, serr)
		return {{/* the caller's zero return value is not known at this point, hence the naked return. */}}
	}
{{- else if not $.Type.ID.IsInt }}
	eid := int(eid)
{{- end }}
{{ end }}