Files
ent/entc/gen/template/dialect/sql/update.tmpl
Ariel Mashraki 4323141fe2 ent: add license and copyright to template files
Reviewed By: alexsn

Differential Revision: D17149292

fbshipit-source-id: 837de5fad988de1e54438b47584701f2fc35326d
2019-09-01 03:03:15 -07:00

285 lines
11 KiB
Cheetah

{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}
{{/*
"dialect/sql/update" generates the sqlSave method of an update builder.

The generated method:
  1. Selects the rows to update: by id for builders whose name ends with "One",
     or by the builder predicates otherwise.
  2. Returns early when nothing matched (an error for "One" builders, 0 otherwise).
  3. Opens a transaction and applies the field updates, then the edge removals/clears,
     and finally the edge additions. Every failure rolls the transaction back.
  4. Commits and returns the updated entity ("One" builders) or the number of matched ids.
*/}}
{{ define "dialect/sql/update" }}
{{ $pkg := $.Scope.Package }}
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}
{{ $one := hasSuffix $builder "One" }}
{{- $zero := 0 }}{{ if $one }}{{ $zero = "nil" }}{{ end }}
{{- $ret := "n" }}{{ if $one }}{{ $ret = $.Receiver }}{{ end }}
func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) ({{ $ret }} {{ if $one }}*{{ $.Name }}{{ else }}int{{ end }}, err error) {
	selector := sql.Select({{ $.Package }}.{{ if $one }}Columns...{{ else }}{{ $.ID.Constant }}{{ end }}).From(sql.Table({{ $.Package }}.Table))
	{{- if $one }}
		{{ $.Package }}.ID({{ $receiver }}.id)(selector)
	{{- else }}
		for _, p := range {{ $receiver }}.predicates {
			p(selector)
		}
	{{- end }}
	rows := &sql.Rows{}
	query, args := selector.Query()
	if err = {{ $receiver }}.driver.Query(ctx, query, args, rows); err != nil {
		return {{ $zero }}, err
	}
	defer rows.Close()
	var ids []int
	for rows.Next() {
		var id int
		{{- if $one }}
			{{ $.Receiver }} = &{{ $.Name }}{config: {{ $receiver }}.config}
			if err := {{ $.Receiver }}.FromRows(rows); err != nil {
				return {{ $zero }}, fmt.Errorf("{{ $pkg }}: failed scanning row into {{ $.Name }}: %v", err)
			}
			id = {{ $.Receiver }}.{{- if $.ID.IsString }}id(){{ else }}ID{{ end }}
		{{- else }}
			if err := rows.Scan(&id); err != nil {
				return {{ $zero }}, fmt.Errorf("{{ $pkg }}: failed reading id: %v", err)
			}
		{{- end }}
		ids = append(ids, id)
	}
	{{- if $one }}
		switch n := len(ids); {
		case n == 0:
			return {{ $zero }}, fmt.Errorf("{{ $pkg }}: {{ $.Name }} not found with id: %v", {{ $receiver }}.id)
		case n > 1:
			return {{ $zero }}, fmt.Errorf("{{ $pkg }}: more than one {{ $.Name }} with the same id: %v", {{ $receiver }}.id)
		}
	{{- else }}
		if len(ids) == 0 {
			return {{ $zero }}, nil
		}
	{{- end }}
	{{/* if there's something to update, start a transaction. */}}
	tx, err := {{ $receiver }}.driver.Tx(ctx)
	if err != nil {
		return {{ $zero }}, err
	}
	{{- if $.Fields }}
		var (
			update bool
			res sql.Result
			builder = sql.Update({{ $.Package }}.Table).Where(sql.InInts({{ $.Package }}.{{ $.ID.Constant }}, ids...))
		)
		{{- range $_, $f := $.Fields }}
			{{- if or (not $f.Immutable) $f.UpdateDefault }}
				if {{ $receiver }}.{{ $f.StructField }} != nil {
					update = true
					builder.Set({{ $.Package }}.{{ $f.Constant }}, *{{ $receiver }}.{{ $f.StructField }})
					{{- if $one }}
						{{- if $f.Nillable }}
							{{ $.Receiver }}.{{ pascal $f.Name }} = {{ $receiver }}.{{ $f.StructField }}
						{{- else }}
							{{ $.Receiver }}.{{ pascal $f.Name }} = *{{ $receiver }}.{{ $f.StructField }}
						{{- end }}
					{{- end }}
				}
			{{- end }}
		{{- end }}
		if update {
			query, args := builder.Query()
			if err := tx.Exec(ctx, query, args, &res); err != nil {
				return {{ $zero }}, rollback(tx, err)
			}
		}
	{{- else if $.Edges }}{{/* ent without fields, but with edges */}}
		var res sql.Result
	{{- end }}
	{{- range $_, $e := $.Edges }}
		{{- if $e.M2M }}
			if len({{ $receiver }}.removed{{ pascal $e.Name }}) > 0 {
				{{- $a := 0 }}{{ $b := 1 }}{{ if $e.IsInverse }}{{ $a = 1 }}{{ $b = 0 }}{{ end }}
				{{/* zero length with preallocated capacity: the ids are appended below. */}}
				eids := make([]int, 0, len({{ $receiver }}.removed{{ pascal $e.Name }}))
				for eid := range {{ $receiver }}.removed{{ pascal $e.Name }} {
					{{- template "dialect/sql/update/convertid" $e -}}
					eids = append(eids, eid)
				}
				query, args := sql.Delete({{ $.Package }}.{{ $e.TableConstant }}).
					Where(sql.InInts({{ $.Package }}.{{ $e.PKConstant }}[{{ $a }}], ids...)).
					Where(sql.InInts({{ $.Package }}.{{ $e.PKConstant }}[{{ $b }}], eids...)).
					Query()
				if err := tx.Exec(ctx, query, args, &res); err != nil {
					return {{ $zero }}, rollback(tx, err)
				}
				{{- if $e.SelfRef }}{{/* M2M with self reference */}}{{/* TODO: use OR in the case above. */}}
					query, args = sql.Delete({{ $.Package }}.{{ $e.TableConstant }}).
						Where(sql.InInts({{ $.Package }}.{{ $e.PKConstant }}[{{ $b }}], ids...)).
						Where(sql.InInts({{ $.Package }}.{{ $e.PKConstant }}[{{ $a }}], eids...)).
						Query()
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
				{{- end }}
			}
		{{- else if $e.O2M }}
			if len({{ $receiver }}.removed{{ pascal $e.Name }}) > 0 {
				{{/* zero length with preallocated capacity: the ids are appended below. */}}
				eids := make([]int, 0, len({{ $receiver }}.removed{{ pascal $e.Name }}))
				for eid := range {{ $receiver }}.removed{{ pascal $e.Name }} {
					{{- template "dialect/sql/update/convertid" $e -}}
					eids = append(eids, eid)
				}
				query, args := sql.Update({{ $.Package }}.{{ $e.TableConstant }}).
					SetNull({{ $.Package }}.{{ $e.ColumnConstant }}).
					Where(sql.InInts({{ $.Package }}.{{ $e.ColumnConstant }}, ids...)).
					Where(sql.InInts({{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}, eids...)).
					Query()
				if err := tx.Exec(ctx, query, args, &res); err != nil {
					return {{ $zero }}, rollback(tx, err)
				}
			}
		{{- else }}{{/* O2O or M2O */}}
			if {{ $receiver }}.cleared{{ pascal $e.Name }} {
				query, args := sql.Update({{ $.Package }}.{{ $e.TableConstant }}).
					SetNull({{ $.Package }}.{{ $e.ColumnConstant }}).
					Where(sql.InInts({{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}, ids...)).
					Query()
				if err := tx.Exec(ctx, query, args, &res); err != nil {
					return {{ $zero }}, rollback(tx, err)
				}
				{{- if $e.SelfRef }}{{/* O2O with self reference */}}
					query, args = sql.Update({{ $.Package }}.{{ $e.TableConstant }}).
						SetNull({{ $.Package }}.{{ $e.ColumnConstant }}).
						Where(sql.InInts({{ $.Package }}.{{ $e.ColumnConstant }}, ids...)).
						Query()
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
				{{- end }}
			}
		{{- end }}
		if len({{ $receiver }}.{{ $e.StructField }}) > 0 {
			{{- if and $e.Unique $e.SelfRef }}{{/* O2O with self reference */}}
				if n := len(ids); n > 1 {
					return {{ $zero }}, rollback(tx, fmt.Errorf("{{ $pkg }}: can't link O2O edge \"{{ $e.Name }}\" to %d vertices (> 1)", n))
				}
				for eid := range {{ $receiver }}.{{ $e.StructField }} {
					{{- template "dialect/sql/update/convertid" $e -}}
					query, args := sql.Update({{ $.Package }}.{{ $e.TableConstant }}).
						Set({{ $.Package }}.{{ $e.ColumnConstant }}, eid).
						Where(sql.EQ({{ $.Package }}.{{ $.ID.Constant }}, ids[0])).Query()
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					query, args = sql.Update({{ $.Package }}.{{ $e.TableConstant }}).
						Set({{ $.Package }}.{{ $e.ColumnConstant }}, ids[0]).
						Where(sql.EQ({{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}, eid).And().IsNull({{ $.Package }}.{{ $e.ColumnConstant }})).Query()
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					affected, err := res.RowsAffected()
					if err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					if int(affected) < len({{ $receiver }}.{{ $e.StructField }}) {
						return {{ $zero }}, rollback(tx, &ErrConstraintFailed{msg: fmt.Sprintf("\"{{ $e.Name }}\" (%v) already connected to a different \"{{ $.Name }}\"", eid)})
					}
				}
			{{- else if $e.M2M }}
				values := make([][]int, 0, len(ids))
				for _, id := range ids {
					for eid := range {{ $receiver }}.{{ $e.StructField }} {
						{{- template "dialect/sql/update/convertid" $e -}}
						values = append(values, []int{id, eid}, {{- if $e.SelfRef }}[]int{eid, id}{{ end }}){{/* self-ref creates the edges in both ways. */}}
					}
				}
				{{- $a := 0 }}{{ $b := 1 }}{{ if $e.IsInverse }}{{ $a = 1 }}{{ $b = 0 }}{{ end }}
				builder := sql.Insert({{ $.Package }}.{{ $e.TableConstant }}).
					Columns({{ $.Package }}.{{ $e.PKConstant }}[{{ $a }}], {{ $.Package }}.{{ $e.PKConstant }}[{{ $b }}])
				for _, v := range values {
					builder.Values(v[0], v[1])
				}
				query, args := builder.Query()
				if err := tx.Exec(ctx, query, args, &res); err != nil {
					return {{ $zero }}, rollback(tx, err)
				}
			{{- else if $e.M2O }}
				for eid := range {{ $receiver }}.{{ $e.StructField }} {
					{{- template "dialect/sql/update/convertid" $e -}}
					query, args := sql.Update({{ $.Package }}.{{ $e.TableConstant }}).
						Set({{ $.Package }}.{{ $e.ColumnConstant }}, eid).
						Where(sql.InInts({{ $.Package }}.{{ $.ID.Constant }}, ids...)).
						Query()
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
				}
			{{- else if $e.O2M }}
				for _, id := range ids {
					p := sql.P()
					for eid := range {{ $receiver }}.{{ $e.StructField }} {
						{{- template "dialect/sql/update/convertid" $e -}}
						p.Or().EQ({{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}, eid)
					}
					query, args := sql.Update({{ $.Package }}.{{ $e.TableConstant }}).
						Set({{ $.Package }}.{{ $e.ColumnConstant }}, id).
						Where(sql.And(p, sql.IsNull({{ $.Package }}.{{ $e.ColumnConstant }}))).
						Query()
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					affected, err := res.RowsAffected()
					if err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					if int(affected) < len({{ $receiver }}.{{ $e.StructField }}) {
						return {{ $zero }}, rollback(tx, &ErrConstraintFailed{msg: fmt.Sprintf("one of \"{{ $e.Name }}\" %v already connected to a different \"{{ $.Name }}\"", keys({{ $receiver }}.{{ $e.StructField }}))})
					}
				}
			{{- else }}{{/* O2O */}}
				for _, id := range ids {
					{{/* the edge ids belong to the edge type, so its id kind decides whether a conversion is needed (same check as "convertid"). */}}
					{{- if $e.Type.ID.IsString }}
						eid, serr := strconv.Atoi(keys({{ $receiver }}.{{ $e.StructField }})[0])
						if serr != nil {
							return {{ $zero }}, rollback(tx, serr)
						}
					{{- else }}
						eid := keys({{ $receiver }}.{{ $e.StructField }})[0]
					{{- end }}
					{{- if $e.IsInverse }}
						query, args := sql.Update({{ $.Package }}.{{ $e.TableConstant }}).
							Set({{ $.Package }}.{{ $e.ColumnConstant }}, eid).
							Where(sql.EQ({{ $.Package }}.{{ $.ID.Constant }}, id).And().IsNull({{ $.Package }}.{{ $e.ColumnConstant }})).
							Query()
					{{- else }}
						query, args := sql.Update({{ $.Package }}.{{ $e.TableConstant }}).
							Set({{ $.Package }}.{{ $e.ColumnConstant }}, id).
							Where(sql.EQ({{ $e.Type.Package }}.{{ $e.Type.ID.Constant }}, eid).And().IsNull({{ $.Package }}.{{ $e.ColumnConstant }})).
							Query()
					{{- end }}
					if err := tx.Exec(ctx, query, args, &res); err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					affected, err := res.RowsAffected()
					if err != nil {
						return {{ $zero }}, rollback(tx, err)
					}
					if int(affected) < len({{ $receiver }}.{{ $e.StructField }}) {
						return {{ $zero }}, rollback(tx, &ErrConstraintFailed{msg: fmt.Sprintf("one of \"{{ $e.Name }}\" %v already connected to a different \"{{ $.Name }}\"", keys({{ $receiver }}.{{ $e.StructField }}))})
					}
				}
			{{- end }}
		}
	{{- end }}
	if err = tx.Commit(); err != nil {
		return {{ $zero }}, err
	}
	return {{ if $one }}{{ $.Receiver }}{{ else }}len(ids){{ end }}, nil
}
{{ end }}
{{/*
"dialect/sql/update/convertid" is invoked with an edge. When the id of the edge type
is a string, it emits code that converts the in-scope string variable "eid" to an int
with strconv.Atoi (shadowing it). On conversion failure the emitted code rolls back "tx",
stores the result in the named "err" result and returns.
The call site must therefore have "eid", "tx" and a named "err" result in scope.
*/}}
{{ define "dialect/sql/update/convertid" }}
{{- if $.Type.ID.IsString }}
eid, serr := strconv.Atoi(eid)
if serr != nil {
err = rollback(tx, serr)
return {{/* naked return: the caller's zero return value is not known at this point. */}}
}
{{- end }}
{{ end }}