mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
entc/gen: support external ValueScanner for id field (#4487)
This commit is contained in:
@@ -116,7 +116,15 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
byID := make(map[{{ $.ID.Type }}]*{{ $.Name }})
|
||||
nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeIDs[i] = node.ID
|
||||
{{- if $.ID.HasValueScanner }}
|
||||
vv, err := {{ $.ID.ValueFunc }}(node.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
edgeIDs[i] = vv
|
||||
{{- else }}
|
||||
edgeIDs[i] = node.ID
|
||||
{{- end }}
|
||||
byID[node.ID] = node
|
||||
if init != nil {
|
||||
init(node)
|
||||
@@ -154,11 +162,29 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append([]any{new({{ $out }})}, values...), nil
|
||||
{{- if $.ID.HasValueScanner }}
|
||||
return append([]any{ {{ $.ID.ScanValueFunc }}() }, values...), nil
|
||||
{{- else }}
|
||||
return append([]any{new({{ $out }})}, values...), nil
|
||||
{{- end }}
|
||||
}
|
||||
spec.Assign = func(columns []string, values []any) error {
|
||||
outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
|
||||
{{- if $.ID.HasValueScanner }}
|
||||
var outValue {{ $.ID.Type }}
|
||||
{{- with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out "Ident" "outValue" }}
|
||||
{{ template "dialect/sql/query/eagerloading/m2massign" . }}
|
||||
{{- end }}
|
||||
{{- else }}
|
||||
outValue := {{- with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massignexpr" . }}{{- end }}
|
||||
{{- end }}
|
||||
{{- if $e.Type.ID.HasValueScanner }}
|
||||
var inValue {{ $e.Type.ID.Type }}
|
||||
{{- with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in "Ident" "inValue" }}
|
||||
{{ template "dialect/sql/query/eagerloading/m2massign" . }}
|
||||
{{- end }}
|
||||
{{- else }}
|
||||
inValue := {{- with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massignexpr" . }}{{- end }}
|
||||
{{- end }}
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*{{ $.Name }}]struct{}{byID[outValue]: {}}
|
||||
return assign(columns[1:], values[1:])
|
||||
@@ -218,7 +244,15 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }})
|
||||
for i := range nodes {
|
||||
fks = append(fks, nodes[i].ID)
|
||||
{{- if $.ID.HasValueScanner }}
|
||||
vv, err := {{ $.ID.ValueFunc }}(nodes[i].ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fks = append(fks, vv)
|
||||
{{- else }}
|
||||
fks = append(fks, nodes[i].ID)
|
||||
{{- end }}
|
||||
nodeids[nodes[i].ID] = nodes[i]
|
||||
{{- if $e.O2M }}
|
||||
if init != nil {
|
||||
@@ -431,7 +465,16 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
{{- $e := $.Scope.Edge }} {{/* the edge we need to genegrate the path to. */}}
|
||||
{{- $ident := $.Scope.Ident -}}
|
||||
{{- $receiver := $.Scope.Receiver -}}
|
||||
id := {{ $receiver }}.ID
|
||||
{{- if $n.ID.HasValueScanner -}}
|
||||
id := any({{ $receiver }}.ID)
|
||||
vv, err := {{ $n.ID.ValueFunc }}({{ $receiver }}.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id = vv
|
||||
{{- else -}}
|
||||
id := {{ $receiver }}.ID
|
||||
{{- end }}
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, id),
|
||||
sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ if $e.Type.HasCompositeID }}{{ $e.Ref.ColumnConstant }}{{ else }}{{ $e.Type.ID.Constant }}{{ end }}),
|
||||
@@ -452,10 +495,28 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
|
||||
{{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
|
||||
{{ end }}
|
||||
|
||||
{{ define "dialect/sql/query/eagerloading/m2massign" }}
|
||||
{{- $arg := $.Scope.Arg }}
|
||||
{{- $field := $.Scope.Field }}
|
||||
{{- $scantype := $.Scope.ScanType }}
|
||||
{{ define "dialect/sql/query/eagerloading/m2massign" -}}
|
||||
{{- $arg := $.Scope.Arg -}}
|
||||
{{- $field := $.Scope.Field -}}
|
||||
{{- $scantype := $.Scope.ScanType -}}
|
||||
{{- $ident := $.Scope.Ident -}}
|
||||
{{- if $field.HasValueScanner }}
|
||||
if value, err := {{ $field.FromValueFunc }}({{ $arg }}); err != nil {
|
||||
return err
|
||||
} else {
|
||||
{{ $ident }} = value
|
||||
}
|
||||
{{- else if hasPrefix $scantype "sql" -}}
|
||||
{{ $ident }} = {{ printf "%s.(*%s)" $arg $scantype | $field.ScanTypeField }}
|
||||
{{- else -}}
|
||||
{{ $ident }} = {{ if not $field.Nillable }}*{{ end }}{{ printf "%s.(*%s)" $arg $scantype }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
{{ define "dialect/sql/query/eagerloading/m2massignexpr" -}}
|
||||
{{- $arg := $.Scope.Arg -}}
|
||||
{{- $field := $.Scope.Field -}}
|
||||
{{- $scantype := $.Scope.ScanType -}}
|
||||
{{- if hasPrefix $scantype "sql" -}}
|
||||
{{ printf "%s.(*%s)" $arg $scantype | $field.ScanTypeField -}}
|
||||
{{- else -}}
|
||||
|
||||
Reference in New Issue
Block a user