mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
entc/gen: support external ValueScanner for id field (#4487)
This commit is contained in:
@@ -28,24 +28,39 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (*{{ $.Name
|
||||
return nil, err
|
||||
}
|
||||
{{- if $.HasCompositeID }}
|
||||
{{- else if or $.ID.Type.ValueScanner (not $.ID.Type.Numeric) }}
|
||||
{{- else if or $.ID.HasValueScanner $.ID.Type.ValueScanner (not $.ID.Type.Numeric) }}
|
||||
if _spec.ID.Value != nil {
|
||||
{{- /* If the ID type is not a pointer, but implements the ValueScanner interface (e.g. UUID fields). */}}
|
||||
{{- if and $.ID.Type.ValueScanner (not $.ID.Type.RType.IsPtr) }}
|
||||
if id, ok := _spec.ID.Value.(*{{ $.ID.Type }}); ok {
|
||||
_node.ID = *id
|
||||
{{- else }}
|
||||
if id, ok := _spec.ID.Value.({{ $.ID.Type }}); ok {
|
||||
_node.ID = id
|
||||
{{- end }}
|
||||
{{- if $.ID.Type.ValueScanner }}
|
||||
} else if err := _node.ID.Scan(_spec.ID.Value); err != nil {
|
||||
return nil, err
|
||||
{{- else }}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected {{ $.Name }}.ID type: %T", _spec.ID.Value)
|
||||
{{- end }}
|
||||
{{- if $.ID.HasValueScanner }}
|
||||
sv, ok := _spec.ID.Value.(field.ValueScanner)
|
||||
if !ok {
|
||||
sv = {{ $.ID.ScanValueFunc }}()
|
||||
if err := sv.Scan(_spec.ID.Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if value, err := {{ $.ID.FromValueFunc }}(sv); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
_node.ID = value
|
||||
}
|
||||
{{- else }}
|
||||
{{- /* If the ID type is not a pointer, but implements the ValueScanner interface (e.g. UUID fields). */}}
|
||||
{{- if and $.ID.Type.ValueScanner (not $.ID.Type.RType.IsPtr) }}
|
||||
if id, ok := _spec.ID.Value.(*{{ $.ID.Type }}); ok {
|
||||
_node.ID = *id
|
||||
{{- else }}
|
||||
if id, ok := _spec.ID.Value.({{ $.ID.Type }}); ok {
|
||||
_node.ID = id
|
||||
{{- end }}
|
||||
{{- if $.ID.Type.ValueScanner }}
|
||||
} else if err := _node.ID.Scan(_spec.ID.Value); err != nil {
|
||||
return nil, err
|
||||
{{- else }}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected {{ $.Name }}.ID type: %T", _spec.ID.Value)
|
||||
{{- end }}
|
||||
}
|
||||
{{- end }}
|
||||
}
|
||||
{{- else }}
|
||||
{{- if $.ID.UserDefined }}
|
||||
@@ -78,7 +93,15 @@ func ({{ $receiver }} *{{ $builder }}) createSpec() (*{{ $.Name }}, *sqlgraph.Cr
|
||||
{{- if and (not $.HasCompositeID) $.ID.UserDefined }}
|
||||
if id, ok := {{ $mutation }}.{{ $.ID.MutationGet }}(); ok {
|
||||
_node.ID = id
|
||||
_spec.ID.Value = {{ if and $.ID.Type.ValueScanner (not $.ID.Type.RType.IsPtr) }}&{{ end }}id
|
||||
{{- if $.ID.HasValueScanner }}
|
||||
vv, err := {{ $.ID.ValueFunc }}(id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
_spec.ID.Value = vv
|
||||
{{- else }}
|
||||
_spec.ID.Value = {{ if and $.ID.Type.ValueScanner (not $.ID.Type.RType.IsPtr) }}&{{ end }}id
|
||||
{{- end }}
|
||||
}
|
||||
{{- end }}
|
||||
{{- range $f := $.MutationFields }}
|
||||
@@ -194,8 +217,23 @@ func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) ([]*{{ $.Name }
|
||||
}
|
||||
{{- if $.HasOneFieldID }}
|
||||
mutation.{{ $.ID.BuilderField }} = &nodes[i].{{ $.ID.StructField }}
|
||||
{{- if or $.ID.IsString $.ID.IsUUID $.ID.IsBytes $.ID.IsOther }}
|
||||
{{- if and (or $.ID.IsString $.ID.IsUUID $.ID.IsBytes $.ID.IsOther) (not $.ID.HasValueScanner) }}
|
||||
{{- /* Do nothing, because these 4 types must be supplied by the user. */ -}}
|
||||
{{- else if $.ID.HasValueScanner }}
|
||||
if specs[i].ID.Value != nil {
|
||||
sv, ok := specs[i].ID.Value.(field.ValueScanner)
|
||||
if !ok {
|
||||
sv = {{ $.ID.ScanValueFunc }}()
|
||||
if err := sv.Scan(specs[i].ID.Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if id, err := {{ $.ID.FromValueFunc }}(sv); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
nodes[i].ID = id
|
||||
}
|
||||
}
|
||||
{{- else if or $.ID.Type.ValueScanner }}
|
||||
if specs[i].ID.Value != nil {
|
||||
if err := nodes[i].ID.Scan(specs[i].ID.Value); err != nil {
|
||||
|
||||
@@ -11,8 +11,10 @@ in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
{{ $ctypes := dict }}
|
||||
{{ if $.HasOneFieldID }}
|
||||
{{ $idscantype := $.ID.NewScanType }}
|
||||
{{ $ctypes = set $ctypes $idscantype (list $.ID.Constant) }}
|
||||
{{- if not $.ID.HasValueScanner }}
|
||||
{{ $idscantype := $.ID.NewScanType }}
|
||||
{{ $ctypes = set $ctypes $idscantype (list $.ID.Constant) }}
|
||||
{{- end }}
|
||||
{{ end }}
|
||||
{{ range $f := $.Fields }}
|
||||
{{- if $f.HasValueScanner }}
|
||||
@@ -31,6 +33,10 @@ func (*{{ $.Name }}) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
{{- if and $.HasOneFieldID $.ID.HasValueScanner }}
|
||||
case {{ $.Package }}.{{ $.ID.Constant }}:
|
||||
values[i] = {{ $.ID.ScanValueFunc }}()
|
||||
{{- end }}
|
||||
{{- range $type, $columns := $ctypes }}
|
||||
case {{ range $i, $c := $columns }}{{ if ne $i 0 }},{{ end }}{{ $.Package }}.{{ $c }}{{ end }}:
|
||||
values[i] = {{ $type }}
|
||||
@@ -64,7 +70,7 @@ func ({{ $receiver }} *{{ $.Name }}) assignValues(columns []string, values []any
|
||||
switch columns[i] {
|
||||
{{- if $.HasOneFieldID }}
|
||||
case {{ $.Package }}.{{ $.ID.Constant }}:
|
||||
{{- if or $.ID.IsString $.ID.IsBytes $.ID.HasGoType }}
|
||||
{{- if or $.ID.IsString $.ID.IsBytes $.ID.HasGoType $.ID.HasValueScanner }}
|
||||
{{- with extend $ "Idx" "i" "Field" $.ID "Rec" $receiver }}
|
||||
{{ template "dialect/sql/decode/field" . }}
|
||||
{{- end }}
|
||||
|
||||
@@ -7,13 +7,22 @@ in the LICENSE file in the root directory of this source tree.
|
||||
{{/* gotype: entgo.io/ent/entc/gen.typeScope */}}
|
||||
|
||||
{{ define "dialect/sql/predicate/id" -}}
|
||||
sql.FieldEQ({{ $.ID.Constant }}, id)
|
||||
{{- $arg := "id" }}
|
||||
{{- with $.Scope.Arg }}
|
||||
{{- $arg = . }}
|
||||
{{- end -}}
|
||||
sql.FieldEQ({{ $.ID.Constant }}, {{ $arg }})
|
||||
{{- end }}
|
||||
|
||||
{{ define "dialect/sql/predicate/id/ops" -}}
|
||||
{{- $op := $.Scope.Op -}}
|
||||
{{- $storage := $.Scope.Storage -}}
|
||||
sql.Field{{ call $storage.OpCode $op }}({{ $.ID.Constant }}{{ if not $op.Niladic }},{{ if $op.Variadic }}ids...{{ else }}id{{ end }}{{ end }})
|
||||
{{- $arg := "id" -}}
|
||||
{{- if $op.Variadic }}{{ $arg = "ids" }}{{ end -}}
|
||||
{{- with $.Scope.Arg -}}
|
||||
{{- $arg = . -}}
|
||||
{{- end -}}
|
||||
sql.Field{{ call $storage.OpCode $op }}({{ $.ID.Constant }}{{ if not $op.Niladic }}, {{ $arg }}{{ if $op.Variadic }}...{{ end }}{{ end }})
|
||||
{{- end }}
|
||||
|
||||
{{ define "dialect/sql/predicate/field" -}}
|
||||
|
||||
@@ -116,7 +116,15 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
byID := make(map[{{ $.ID.Type }}]*{{ $.Name }})
|
||||
nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeIDs[i] = node.ID
|
||||
{{- if $.ID.HasValueScanner }}
|
||||
vv, err := {{ $.ID.ValueFunc }}(node.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
edgeIDs[i] = vv
|
||||
{{- else }}
|
||||
edgeIDs[i] = node.ID
|
||||
{{- end }}
|
||||
byID[node.ID] = node
|
||||
if init != nil {
|
||||
init(node)
|
||||
@@ -154,11 +162,29 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append([]any{new({{ $out }})}, values...), nil
|
||||
{{- if $.ID.HasValueScanner }}
|
||||
return append([]any{ {{ $.ID.ScanValueFunc }}() }, values...), nil
|
||||
{{- else }}
|
||||
return append([]any{new({{ $out }})}, values...), nil
|
||||
{{- end }}
|
||||
}
|
||||
spec.Assign = func(columns []string, values []any) error {
|
||||
outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
{{- if $.ID.HasValueScanner }}
|
||||
var outValue {{ $.ID.Type }}
|
||||
{{- with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out "Ident" "outValue" }}
|
||||
{{ template "dialect/sql/query/eagerloading/m2massign" . }}
|
||||
{{- end }}
|
||||
{{- else }}
|
||||
outValue := {{- with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massignexpr" . }}{{- end }}
|
||||
{{- end }}
|
||||
{{- if $e.Type.ID.HasValueScanner }}
|
||||
var inValue {{ $e.Type.ID.Type }}
|
||||
{{- with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in "Ident" "inValue" }}
|
||||
{{ template "dialect/sql/query/eagerloading/m2massign" . }}
|
||||
{{- end }}
|
||||
{{- else }}
|
||||
inValue := {{- with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massignexpr" . }}{{- end }}
|
||||
{{- end }}
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*{{ $.Name }}]struct{}{byID[outValue]: {}}
|
||||
return assign(columns[1:], values[1:])
|
||||
@@ -218,7 +244,15 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }})
|
||||
for i := range nodes {
|
||||
fks = append(fks, nodes[i].ID)
|
||||
{{- if $.ID.HasValueScanner }}
|
||||
vv, err := {{ $.ID.ValueFunc }}(nodes[i].ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fks = append(fks, vv)
|
||||
{{- else }}
|
||||
fks = append(fks, nodes[i].ID)
|
||||
{{- end }}
|
||||
nodeids[nodes[i].ID] = nodes[i]
|
||||
{{- if $e.O2M }}
|
||||
if init != nil {
|
||||
@@ -431,7 +465,16 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
{{- $e := $.Scope.Edge }} {{/* the edge we need to genegrate the path to. */}}
|
||||
{{- $ident := $.Scope.Ident -}}
|
||||
{{- $receiver := $.Scope.Receiver -}}
|
||||
id := {{ $receiver }}.ID
|
||||
{{- if $n.ID.HasValueScanner -}}
|
||||
id := any({{ $receiver }}.ID)
|
||||
vv, err := {{ $n.ID.ValueFunc }}({{ $receiver }}.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id = vv
|
||||
{{- else -}}
|
||||
id := {{ $receiver }}.ID
|
||||
{{- end }}
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, id),
|
||||
sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ if $e.Type.HasCompositeID }}{{ $e.Ref.ColumnConstant }}{{ else }}{{ $e.Type.ID.Constant }}{{ end }}),
|
||||
@@ -452,10 +495,28 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
{{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
|
||||
{{ end }}
|
||||
|
||||
{{ define "dialect/sql/query/eagerloading/m2massign" }}
|
||||
{{- $arg := $.Scope.Arg }}
|
||||
{{- $field := $.Scope.Field }}
|
||||
{{- $scantype := $.Scope.ScanType }}
|
||||
{{ define "dialect/sql/query/eagerloading/m2massign" -}}
|
||||
{{- $arg := $.Scope.Arg -}}
|
||||
{{- $field := $.Scope.Field -}}
|
||||
{{- $scantype := $.Scope.ScanType -}}
|
||||
{{- $ident := $.Scope.Ident -}}
|
||||
{{- if $field.HasValueScanner }}
|
||||
if value, err := {{ $field.FromValueFunc }}({{ $arg }}); err != nil {
|
||||
return err
|
||||
} else {
|
||||
{{ $ident }} = value
|
||||
}
|
||||
{{- else if hasPrefix $scantype "sql" -}}
|
||||
{{ $ident }} = {{ printf "%s.(*%s)" $arg $scantype | $field.ScanTypeField }}
|
||||
{{- else -}}
|
||||
{{ $ident }} = {{ if not $field.Nillable }}*{{ end }}{{ printf "%s.(*%s)" $arg $scantype }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
{{ define "dialect/sql/query/eagerloading/m2massignexpr" -}}
|
||||
{{- $arg := $.Scope.Arg -}}
|
||||
{{- $field := $.Scope.Field -}}
|
||||
{{- $scantype := $.Scope.ScanType -}}
|
||||
{{- if hasPrefix $scantype "sql" -}}
|
||||
{{ printf "%s.(*%s)" $arg $scantype | $field.ScanTypeField -}}
|
||||
{{- else -}}
|
||||
|
||||
@@ -50,7 +50,15 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (_node {{ if
|
||||
if !ok {
|
||||
return {{ $zero }}, &ValidationError{Name: "{{ $.ID.Name }}", err: errors.New(`{{ $pkg }}: missing "{{ $.Name }}.{{ $.ID.Name }}" for update`)}
|
||||
}
|
||||
_spec.Node.ID.Value = id
|
||||
{{- if $.ID.HasValueScanner }}
|
||||
vv, err := {{ $.ID.ValueFunc }}(id)
|
||||
if err != nil {
|
||||
return {{ $zero }}, err
|
||||
}
|
||||
_spec.Node.ID.Value = vv
|
||||
{{- else }}
|
||||
_spec.Node.ID.Value = id
|
||||
{{- end }}
|
||||
if fields := {{ $receiver }}.fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.{{ $.ID.Constant }})
|
||||
|
||||
Reference in New Issue
Block a user