entc/gen: support external ValueScanner for id field (#4487)

This commit is contained in:
Jannik Clausen
2026-02-18 07:41:35 +01:00
committed by GitHub
parent d056659140
commit ab0540611e
30 changed files with 2650 additions and 133 deletions

View File

@@ -28,24 +28,39 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (*{{ $.Name
return nil, err
}
{{- if $.HasCompositeID }}
{{- else if or $.ID.Type.ValueScanner (not $.ID.Type.Numeric) }}
{{- else if or $.ID.HasValueScanner $.ID.Type.ValueScanner (not $.ID.Type.Numeric) }}
if _spec.ID.Value != nil {
{{- /* If the ID type is not a pointer, but implements the ValueScanner interface (e.g. UUID fields). */}}
{{- if and $.ID.Type.ValueScanner (not $.ID.Type.RType.IsPtr) }}
if id, ok := _spec.ID.Value.(*{{ $.ID.Type }}); ok {
_node.ID = *id
{{- else }}
if id, ok := _spec.ID.Value.({{ $.ID.Type }}); ok {
_node.ID = id
{{- end }}
{{- if $.ID.Type.ValueScanner }}
} else if err := _node.ID.Scan(_spec.ID.Value); err != nil {
return nil, err
{{- else }}
} else {
return nil, fmt.Errorf("unexpected {{ $.Name }}.ID type: %T", _spec.ID.Value)
{{- end }}
{{- if $.ID.HasValueScanner }}
sv, ok := _spec.ID.Value.(field.ValueScanner)
if !ok {
sv = {{ $.ID.ScanValueFunc }}()
if err := sv.Scan(_spec.ID.Value); err != nil {
return nil, err
}
}
if value, err := {{ $.ID.FromValueFunc }}(sv); err != nil {
return nil, err
} else {
_node.ID = value
}
{{- else }}
{{- /* If the ID type is not a pointer, but implements the ValueScanner interface (e.g. UUID fields). */}}
{{- if and $.ID.Type.ValueScanner (not $.ID.Type.RType.IsPtr) }}
if id, ok := _spec.ID.Value.(*{{ $.ID.Type }}); ok {
_node.ID = *id
{{- else }}
if id, ok := _spec.ID.Value.({{ $.ID.Type }}); ok {
_node.ID = id
{{- end }}
{{- if $.ID.Type.ValueScanner }}
} else if err := _node.ID.Scan(_spec.ID.Value); err != nil {
return nil, err
{{- else }}
} else {
return nil, fmt.Errorf("unexpected {{ $.Name }}.ID type: %T", _spec.ID.Value)
{{- end }}
}
{{- end }}
}
{{- else }}
{{- if $.ID.UserDefined }}
@@ -78,7 +93,15 @@ func ({{ $receiver }} *{{ $builder }}) createSpec() (*{{ $.Name }}, *sqlgraph.Cr
{{- if and (not $.HasCompositeID) $.ID.UserDefined }}
if id, ok := {{ $mutation }}.{{ $.ID.MutationGet }}(); ok {
_node.ID = id
_spec.ID.Value = {{ if and $.ID.Type.ValueScanner (not $.ID.Type.RType.IsPtr) }}&{{ end }}id
{{- if $.ID.HasValueScanner }}
vv, err := {{ $.ID.ValueFunc }}(id)
if err != nil {
return nil, nil, err
}
_spec.ID.Value = vv
{{- else }}
_spec.ID.Value = {{ if and $.ID.Type.ValueScanner (not $.ID.Type.RType.IsPtr) }}&{{ end }}id
{{- end }}
}
{{- end }}
{{- range $f := $.MutationFields }}
@@ -194,8 +217,23 @@ func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) ([]*{{ $.Name }
}
{{- if $.HasOneFieldID }}
mutation.{{ $.ID.BuilderField }} = &nodes[i].{{ $.ID.StructField }}
{{- if or $.ID.IsString $.ID.IsUUID $.ID.IsBytes $.ID.IsOther }}
{{- if and (or $.ID.IsString $.ID.IsUUID $.ID.IsBytes $.ID.IsOther) (not $.ID.HasValueScanner) }}
{{- /* Do nothing, because these 4 types must be supplied by the user. */ -}}
{{- else if $.ID.HasValueScanner }}
if specs[i].ID.Value != nil {
sv, ok := specs[i].ID.Value.(field.ValueScanner)
if !ok {
sv = {{ $.ID.ScanValueFunc }}()
if err := sv.Scan(specs[i].ID.Value); err != nil {
return nil, err
}
}
if id, err := {{ $.ID.FromValueFunc }}(sv); err != nil {
return nil, err
} else {
nodes[i].ID = id
}
}
{{- else if or $.ID.Type.ValueScanner }}
if specs[i].ID.Value != nil {
if err := nodes[i].ID.Scan(specs[i].ID.Value); err != nil {

View File

@@ -11,8 +11,10 @@ in the LICENSE file in the root directory of this source tree.
{{ $ctypes := dict }}
{{ if $.HasOneFieldID }}
{{ $idscantype := $.ID.NewScanType }}
{{ $ctypes = set $ctypes $idscantype (list $.ID.Constant) }}
{{- if not $.ID.HasValueScanner }}
{{ $idscantype := $.ID.NewScanType }}
{{ $ctypes = set $ctypes $idscantype (list $.ID.Constant) }}
{{- end }}
{{ end }}
{{ range $f := $.Fields }}
{{- if $f.HasValueScanner }}
@@ -31,6 +33,10 @@ func (*{{ $.Name }}) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
{{- if and $.HasOneFieldID $.ID.HasValueScanner }}
case {{ $.Package }}.{{ $.ID.Constant }}:
values[i] = {{ $.ID.ScanValueFunc }}()
{{- end }}
{{- range $type, $columns := $ctypes }}
case {{ range $i, $c := $columns }}{{ if ne $i 0 }},{{ end }}{{ $.Package }}.{{ $c }}{{ end }}:
values[i] = {{ $type }}
@@ -64,7 +70,7 @@ func ({{ $receiver }} *{{ $.Name }}) assignValues(columns []string, values []any
switch columns[i] {
{{- if $.HasOneFieldID }}
case {{ $.Package }}.{{ $.ID.Constant }}:
{{- if or $.ID.IsString $.ID.IsBytes $.ID.HasGoType }}
{{- if or $.ID.IsString $.ID.IsBytes $.ID.HasGoType $.ID.HasValueScanner }}
{{- with extend $ "Idx" "i" "Field" $.ID "Rec" $receiver }}
{{ template "dialect/sql/decode/field" . }}
{{- end }}

View File

@@ -7,13 +7,22 @@ in the LICENSE file in the root directory of this source tree.
{{/* gotype: entgo.io/ent/entc/gen.typeScope */}}
{{ define "dialect/sql/predicate/id" -}}
sql.FieldEQ({{ $.ID.Constant }}, id)
{{- $arg := "id" }}
{{- with $.Scope.Arg }}
{{- $arg = . }}
{{- end -}}
sql.FieldEQ({{ $.ID.Constant }}, {{ $arg }})
{{- end }}
{{ define "dialect/sql/predicate/id/ops" -}}
{{- $op := $.Scope.Op -}}
{{- $storage := $.Scope.Storage -}}
sql.Field{{ call $storage.OpCode $op }}({{ $.ID.Constant }}{{ if not $op.Niladic }},{{ if $op.Variadic }}ids...{{ else }}id{{ end }}{{ end }})
{{- $arg := "id" -}}
{{- if $op.Variadic }}{{ $arg = "ids" }}{{ end -}}
{{- with $.Scope.Arg -}}
{{- $arg = . -}}
{{- end -}}
sql.Field{{ call $storage.OpCode $op }}({{ $.ID.Constant }}{{ if not $op.Niladic }}, {{ $arg }}{{ if $op.Variadic }}...{{ end }}{{ end }})
{{- end }}
{{ define "dialect/sql/predicate/field" -}}

View File

@@ -116,7 +116,15 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
byID := make(map[{{ $.ID.Type }}]*{{ $.Name }})
nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{})
for i, node := range nodes {
edgeIDs[i] = node.ID
{{- if $.ID.HasValueScanner }}
vv, err := {{ $.ID.ValueFunc }}(node.ID)
if err != nil {
return err
}
edgeIDs[i] = vv
{{- else }}
edgeIDs[i] = node.ID
{{- end }}
byID[node.ID] = node
if init != nil {
init(node)
@@ -154,11 +162,29 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
if err != nil {
return nil, err
}
return append([]any{new({{ $out }})}, values...), nil
{{- if $.ID.HasValueScanner }}
return append([]any{ {{ $.ID.ScanValueFunc }}() }, values...), nil
{{- else }}
return append([]any{new({{ $out }})}, values...), nil
{{- end }}
}
spec.Assign = func(columns []string, values []any) error {
outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
{{- if $.ID.HasValueScanner }}
var outValue {{ $.ID.Type }}
{{- with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out "Ident" "outValue" }}
{{ template "dialect/sql/query/eagerloading/m2massign" . }}
{{- end }}
{{- else }}
outValue := {{- with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massignexpr" . }}{{- end }}
{{- end }}
{{- if $e.Type.ID.HasValueScanner }}
var inValue {{ $e.Type.ID.Type }}
{{- with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in "Ident" "inValue" }}
{{ template "dialect/sql/query/eagerloading/m2massign" . }}
{{- end }}
{{- else }}
inValue := {{- with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massignexpr" . }}{{- end }}
{{- end }}
if nids[inValue] == nil {
nids[inValue] = map[*{{ $.Name }}]struct{}{byID[outValue]: {}}
return assign(columns[1:], values[1:])
@@ -218,7 +244,15 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }})
for i := range nodes {
fks = append(fks, nodes[i].ID)
{{- if $.ID.HasValueScanner }}
vv, err := {{ $.ID.ValueFunc }}(nodes[i].ID)
if err != nil {
return err
}
fks = append(fks, vv)
{{- else }}
fks = append(fks, nodes[i].ID)
{{- end }}
nodeids[nodes[i].ID] = nodes[i]
{{- if $e.O2M }}
if init != nil {
@@ -431,7 +465,16 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- $e := $.Scope.Edge }} {{/* the edge we need to genegrate the path to. */}}
{{- $ident := $.Scope.Ident -}}
{{- $receiver := $.Scope.Receiver -}}
id := {{ $receiver }}.ID
{{- if $n.ID.HasValueScanner -}}
id := any({{ $receiver }}.ID)
vv, err := {{ $n.ID.ValueFunc }}({{ $receiver }}.ID)
if err != nil {
return nil, err
}
id = vv
{{- else -}}
id := {{ $receiver }}.ID
{{- end }}
step := sqlgraph.NewStep(
sqlgraph.From({{ $n.Package }}.Table, {{ $n.Package }}.{{ $n.ID.Constant }}, id),
sqlgraph.To({{ $e.Type.Package }}.Table, {{ $e.Type.Package }}.{{ if $e.Type.HasCompositeID }}{{ $e.Ref.ColumnConstant }}{{ else }}{{ $e.Type.ID.Constant }}{{ end }}),
@@ -452,10 +495,28 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}
{{ define "dialect/sql/query/eagerloading/m2massign" }}
{{- $arg := $.Scope.Arg }}
{{- $field := $.Scope.Field }}
{{- $scantype := $.Scope.ScanType }}
{{ define "dialect/sql/query/eagerloading/m2massign" -}}
{{- $arg := $.Scope.Arg -}}
{{- $field := $.Scope.Field -}}
{{- $scantype := $.Scope.ScanType -}}
{{- $ident := $.Scope.Ident -}}
{{- if $field.HasValueScanner }}
if value, err := {{ $field.FromValueFunc }}({{ $arg }}); err != nil {
return err
} else {
{{ $ident }} = value
}
{{- else if hasPrefix $scantype "sql" -}}
{{ $ident }} = {{ printf "%s.(*%s)" $arg $scantype | $field.ScanTypeField }}
{{- else -}}
{{ $ident }} = {{ if not $field.Nillable }}*{{ end }}{{ printf "%s.(*%s)" $arg $scantype }}
{{- end }}
{{- end }}
{{ define "dialect/sql/query/eagerloading/m2massignexpr" -}}
{{- $arg := $.Scope.Arg -}}
{{- $field := $.Scope.Field -}}
{{- $scantype := $.Scope.ScanType -}}
{{- if hasPrefix $scantype "sql" -}}
{{ printf "%s.(*%s)" $arg $scantype | $field.ScanTypeField -}}
{{- else -}}

View File

@@ -50,7 +50,15 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (_node {{ if
if !ok {
return {{ $zero }}, &ValidationError{Name: "{{ $.ID.Name }}", err: errors.New(`{{ $pkg }}: missing "{{ $.Name }}.{{ $.ID.Name }}" for update`)}
}
_spec.Node.ID.Value = id
{{- if $.ID.HasValueScanner }}
vv, err := {{ $.ID.ValueFunc }}(id)
if err != nil {
return {{ $zero }}, err
}
_spec.Node.ID.Value = vv
{{- else }}
_spec.Node.ID.Value = id
{{- end }}
if fields := {{ $receiver }}.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, {{ $.Package }}.{{ $.ID.Constant }})