From 2d33420c0c7a6b7d1b28e0cc0cd14e425b080fbe Mon Sep 17 00:00:00 2001 From: "Giau. Tran Minh" Date: Mon, 18 May 2026 17:07:16 +0000 Subject: [PATCH] entc: blob storage support --- .gitignore | 1 + blob.go | 354 ++++++++++++++++++ dialect/sql/schema/mysql.go | 2 + dialect/sql/schema/postgres.go | 2 + dialect/sql/schema/sqlite.go | 2 + dialect/sql/sqlgraph/blob.go | 74 ++++ entc/gen/graph.go | 32 ++ entc/gen/graph_test.go | 39 ++ entc/gen/template/builder/create.tmpl | 1 + entc/gen/template/builder/mutation_graph.tmpl | 2 + entc/gen/template/builder/setter.tmpl | 28 ++ entc/gen/template/client.tmpl | 41 ++ entc/gen/template/dialect/gremlin/create.tmpl | 1 + entc/gen/template/dialect/gremlin/decode.tmpl | 4 + entc/gen/template/dialect/gremlin/update.tmpl | 2 + entc/gen/template/dialect/sql/create.tmpl | 171 +++++++++ entc/gen/template/dialect/sql/decode.tmpl | 19 + entc/gen/template/dialect/sql/delete.tmpl | 34 ++ entc/gen/template/dialect/sql/ent.tmpl | 3 + entc/gen/template/dialect/sql/entql.tmpl | 3 + .../template/dialect/sql/feature/upsert.tmpl | 46 ++- entc/gen/template/dialect/sql/meta.tmpl | 21 +- entc/gen/template/dialect/sql/query.tmpl | 66 ++++ entc/gen/template/dialect/sql/update.tmpl | 152 ++++++++ entc/gen/template/ent.tmpl | 23 ++ entc/gen/template/import.tmpl | 2 + entc/gen/template/meta.tmpl | 11 +- entc/gen/template/mutation.tmpl | 74 +++- entc/gen/template/runtime.tmpl | 19 +- entc/gen/template/where.tmpl | 2 + entc/gen/type.go | 144 ++++++- entc/gen/type_test.go | 97 +++++ entc/load/schema.go | 8 + entc/load/schema_test.go | 46 +++ schema/field/field.go | 216 ++++++++++- schema/field/field_test.go | 6 +- schema/field/type.go | 5 +- 37 files changed, 1711 insertions(+), 42 deletions(-) create mode 100644 .gitignore create mode 100644 blob.go create mode 100644 dialect/sql/sqlgraph/blob.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..e43b0f988 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.DS_Store diff --git a/blob.go b/blob.go new file mode 100644 index 000000000..c95e857bc --- /dev/null +++ b/blob.go @@ -0,0 +1,354 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package ent + +import ( + "context" + "errors" + "fmt" + "io" + "io/fs" +) + +// Blob defines the interface for blob storage operations. +// Implementations should return [io/fs.ErrNotExist] (or an error wrapping it) +// from NewReader when the requested key does not exist. +// +// Single-row SQL create builders write blob data to external storage before +// inserting the database row. If the row insertion fails (for example, due to +// a constraint violation), generated code attempts to delete the just-written blobs. +type Blob interface { + // NewReader opens a reader for the given key. + NewReader(ctx context.Context, key string) (io.ReadCloser, error) + // NewWriter opens a writer for the given key. + NewWriter(ctx context.Context, key string) (io.WriteCloser, error) + // Delete removes the blob at the given key. + // Implementations should return nil (not an error) if the key does not exist. + Delete(ctx context.Context, key string) error + // Close releases any resources held by the bucket. + Close() error +} + +// BlobOpener is a function that opens a [Blob] bucket for the given field name. +type BlobOpener func(context.Context, string) (Blob, error) + +// BlobKey identifies a blob in storage by field name and key. +type BlobKey struct { + Field string + Key string +} + +// BlobQuerier queries existing blob keys from the database. +// [Blobs.Update] passes the mutated field names; [Blobs.Delete] passes nil +// to indicate all fields should be queried. +type BlobQuerier interface { + QueryBlobKeys(ctx context.Context, fields []string) ([]BlobKey, error) +} + +// BlobUpdateResult holds post-update blob operations. +type BlobUpdateResult struct { + Rollback BlobOp // Deletes newly-written blobs. Call on SQL failure. + Commit BlobOp // Deletes old replaced blobs. Call after successful SQL commit. +} + +// BlobOp is a deferred blob storage operation (e.g. rollback or commit). +type BlobOp func(context.Context) error + +// BlobKeyFunc generates a storage key for a blob from its content. +type BlobKeyFunc func(context.Context, []byte) (string, error) + +// Blobs orchestrates blob storage operations for a single mutation. +// Use [NewBlobs] to create, then call [Blobs.Set] or [Blobs.SetCleared] +// for each blob field, then [Blobs.Create] or [Blobs.Update]. +type Blobs struct { + opener BlobOpener + inputs []blobInput +} + +type blobInput struct { + field string + data []byte + newKey BlobKeyFunc + apply func(string) + cleared bool + clear func() +} + +// NewBlobs creates a blob orchestrator for the given opener. +func NewBlobs(opener BlobOpener) *Blobs { + return &Blobs{opener: opener} +} + +// Set adds a blob field to be written. The apply callback is called with +// the generated key to set it on the SQL spec and node. +func (b *Blobs) Set(f string, data []byte, key BlobKeyFunc, apply func(string)) { + b.inputs = append(b.inputs, blobInput{field: f, data: data, newKey: key, apply: apply}) +} + +// SetCleared marks a blob field as cleared. The clear callback should +// remove the key column from the SQL spec. +func (b *Blobs) SetCleared(f string, clear func()) { + b.inputs = append(b.inputs, blobInput{field: f, cleared: true, clear: clear}) +} + +// Create prepares inputs, writes blobs, and returns a rollback [BlobOp]. +func (b *Blobs) Create(ctx context.Context) (BlobOp, error) { + writes, err := b.prepare(ctx) + if err != nil { + return nil, err + } + return b.write(ctx, writes) +} + +// Update prepares inputs, queries old keys, writes new blobs, and returns +// a [BlobUpdateResult] for post-SQL handling. +func (b *Blobs) Update(ctx context.Context, q BlobQuerier) (*BlobUpdateResult, error) { + if len(b.inputs) == 0 { + return noopBlobResult, nil + } + writes, err := b.prepare(ctx) + if err != nil { + return nil, err + } + mutated := make([]string, len(b.inputs)) + cleared := make(map[string]bool) + for i := range b.inputs { + mutated[i] = b.inputs[i].field + if b.inputs[i].cleared { + cleared[b.inputs[i].field] = true + } + } + keys, err := q.QueryBlobKeys(ctx, mutated) + if err != nil { + return nil, fmt.Errorf("querying old blob keys: %w", err) + } + // Build a set of old keys per field to detect unchanged blobs. + oldKeys := make(map[string]string, len(keys)) + for _, k := range keys { + oldKeys[k.Field] = k.Key + } + // Filter out writes where the key is unchanged (same content). + filtered := writes[:0] + for _, wr := range writes { + if oldKeys[wr.Field] == wr.Key { + continue + } + filtered = append(filtered, wr) + } + rollback, err := b.write(ctx, filtered) + if err != nil { + return nil, err + } + // Collect orphaned blobs: old keys for fields that changed or were cleared. + var orphaned []BlobKey + for _, k := range keys { + if cleared[k.Field] { + orphaned = append(orphaned, k) + continue + } + for _, wr := range writes { + if wr.Field == k.Field && wr.Key != k.Key { + orphaned = append(orphaned, k) + break + } + } + } + return &BlobUpdateResult{ + Rollback: rollback, + Commit: b.deleteOp(orphaned), + }, nil +} + +// Delete queries existing blob keys and returns a [BlobOp] that removes +// them from storage. Use for delete mutations. +func (b *Blobs) Delete(ctx context.Context, q BlobQuerier) (BlobOp, error) { + keys, err := q.QueryBlobKeys(ctx, nil) + if err != nil { + return nil, err + } + return b.deleteOp(keys), nil +} + +type blobWrite struct { + BlobKey + data []byte +} + +func (b *Blobs) prepare(ctx context.Context) ([]blobWrite, error) { + var writes []blobWrite + for _, inp := range b.inputs { + if inp.cleared { + inp.clear() + continue + } + k, err := inp.newKey(ctx, inp.data) + if err != nil { + return nil, fmt.Errorf("generating blob key for %s: %w", inp.field, err) + } + if inp.apply != nil { + inp.apply(k) + } + writes = append(writes, blobWrite{ + BlobKey: BlobKey{Field: inp.field, Key: k}, + data: inp.data, + }) + } + return writes, nil +} + +func (b *Blobs) write(ctx context.Context, writes []blobWrite) (BlobOp, error) { + if len(writes) == 0 { + return noOp, nil + } + w := NewBlobStore(b.opener) + var written []BlobKey + for _, wr := range writes { + if err := w.write(ctx, wr.Field, wr.Key, wr.data); err != nil { + var errs []error + errs = append(errs, fmt.Errorf("writing blob for %s: %w", wr.Field, err)) + for _, k := range written { + if derr := w.delete(ctx, k.Field, k.Key); derr != nil { + errs = append(errs, derr) + } + } + errs = append(errs, w.Close()) + return nil, errors.Join(errs...) + } + written = append(written, wr.BlobKey) + } + if err := w.Close(); err != nil { + return nil, err + } + return b.deleteOp(written), nil +} + +func (b *Blobs) deleteOp(keys []BlobKey) BlobOp { + if len(keys) == 0 { + return noOp + } + return func(ctx context.Context) error { + s := NewBlobStore(b.opener) + var errs []error + for _, k := range keys { + if err := s.delete(ctx, k.Field, k.Key); err != nil { + errs = append(errs, err) + } + } + if err := s.Close(); err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) + } +} + +var ( + noOp = func(context.Context) error { return nil } + noopBlobResult = &BlobUpdateResult{Rollback: noOp, Commit: noOp} +) + +// BlobReader returns a reader for the given key from the blob bucket. +// The returned reader closes both the underlying reader and the bucket. +// Returns nil, nil if the blob does not exist (fs.ErrNotExist). +func BlobReader(ctx context.Context, b Blob, key string) (io.ReadCloser, error) { + switch r, err := b.NewReader(ctx, key); { + case errors.Is(err, fs.ErrNotExist): + return nil, b.Close() + case err != nil: + return nil, errors.Join(err, b.Close()) + default: + return &blobReadCloser{ReadCloser: r, bucket: b}, nil + } +} + +type blobReadCloser struct { + io.ReadCloser + bucket Blob +} + +func (r *blobReadCloser) Close() error { + return errors.Join(r.ReadCloser.Close(), r.bucket.Close()) +} + +// BlobStore manages blob bucket lifecycles for read, write, and delete operations. +// It lazily opens buckets per field and reuses them for subsequent operations. +type BlobStore struct { + opener BlobOpener + buckets map[string]Blob +} + +// NewBlobStore creates a store that uses opener to lazily open buckets. +func NewBlobStore(opener BlobOpener) *BlobStore { + return &BlobStore{buckets: make(map[string]Blob), opener: opener} +} + +// Close closes all open buckets. +func (s *BlobStore) Close() error { + var errs []error + for _, b := range s.buckets { + errs = append(errs, b.Close()) + } + return errors.Join(errs...) +} + +// write writes data to the blob at key for the given field. +func (s *BlobStore) write(ctx context.Context, field, key string, data []byte) error { + b, err := s.bucket(ctx, field) + if err != nil { + return err + } + wr, err := b.NewWriter(ctx, key) + if err != nil { + return err + } + if _, err := wr.Write(data); err != nil { + return errors.Join(err, wr.Close(), b.Delete(ctx, key)) + } + return wr.Close() +} + +// delete removes the blob at key for the given field. +func (s *BlobStore) delete(ctx context.Context, field, key string) error { + b, err := s.bucket(ctx, field) + if err != nil { + return err + } + return b.Delete(ctx, key) +} + +// Read reads the blob at key for the given field. +// Returns nil, nil if the blob does not exist (fs.ErrNotExist). +func (s *BlobStore) Read(ctx context.Context, field, key string) ([]byte, error) { + b, err := s.bucket(ctx, field) + if err != nil { + return nil, err + } + rc, err := b.NewReader(ctx, key) + if errors.Is(err, fs.ErrNotExist) { + return nil, nil + } + if err != nil { + return nil, err + } + data, err := io.ReadAll(rc) + if closeErr := rc.Close(); closeErr != nil && err == nil { + err = closeErr + } + return data, err +} + +func (s *BlobStore) bucket(ctx context.Context, field string) (Blob, error) { + if b, ok := s.buckets[field]; ok { + return b, nil + } + if s.opener == nil { + return nil, errors.New("ent: blob storage not configured (missing WithBlobOpeners)") + } + b, err := s.opener(ctx, field) + if err != nil { + return nil, err + } + s.buckets[field] = b + return b, nil +} diff --git a/dialect/sql/schema/mysql.go b/dialect/sql/schema/mysql.go index 7f30af792..47aff25d5 100644 --- a/dialect/sql/schema/mysql.go +++ b/dialect/sql/schema/mysql.go @@ -135,6 +135,8 @@ func (d *MySQL) atTypeC(c1 *Column, c2 *schema.Column) error { switch c1.Type { case field.TypeBool: t = &schema.BoolType{T: "boolean"} + case field.TypeBlob: + return fmt.Errorf("blob fields are not stored in the database") case field.TypeInt8: t = &schema.IntegerType{T: mysql.TypeTinyInt} case field.TypeUint8: diff --git a/dialect/sql/schema/postgres.go b/dialect/sql/schema/postgres.go index 364b081c3..77ea1c57a 100644 --- a/dialect/sql/schema/postgres.go +++ b/dialect/sql/schema/postgres.go @@ -113,6 +113,8 @@ func (d *Postgres) atTypeC(c1 *Column, c2 *schema.Column) error { } var t schema.Type switch c1.Type { + case field.TypeBlob: + return fmt.Errorf("blob fields are not stored in the database") case field.TypeBool: t = &schema.BoolType{T: postgres.TypeBoolean} case field.TypeUint8, field.TypeInt8, field.TypeInt16: diff --git a/dialect/sql/schema/sqlite.go b/dialect/sql/schema/sqlite.go index 33708e881..944c0f80e 100644 --- a/dialect/sql/schema/sqlite.go +++ b/dialect/sql/schema/sqlite.go @@ -114,6 +114,8 @@ func (d *SQLite) atTypeC(c1 *Column, c2 *schema.Column) error { } var t schema.Type switch c1.Type { + case field.TypeBlob: + return fmt.Errorf("blob fields are not stored in the database") case field.TypeBool: t = &schema.BoolType{T: "bool"} case field.TypeInt8, field.TypeUint8, field.TypeInt16, field.TypeUint16, field.TypeInt32, diff --git a/dialect/sql/sqlgraph/blob.go b/dialect/sql/sqlgraph/blob.go new file mode 100644 index 000000000..c974ed70a --- /dev/null +++ b/dialect/sql/sqlgraph/blob.go @@ -0,0 +1,74 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package sqlgraph + +import ( + "context" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" +) + +// BlobSpec configures SQL-level blob key queries and implements [ent.BlobQuerier]. +type BlobSpec struct { + Driver dialect.Driver + Table string + Columns map[string]string // field name -> key column name + Predicate func(*sql.Selector) +} + +// QueryBlobKeys implements [ent.BlobQuerier]. +// If fields is nil, all columns are queried (for deletes); +// otherwise only the named fields are queried. +func (s *BlobSpec) QueryBlobKeys(ctx context.Context, fields []string) ([]ent.BlobKey, error) { + cols := s.Columns + if len(fields) > 0 { + cols = make(map[string]string, len(fields)) + for _, f := range fields { + if c, ok := s.Columns[f]; ok { + cols[f] = c + } + } + } + if len(cols) == 0 { + return nil, nil + } + names := make([]string, 0, len(cols)) + colNames := make([]string, 0, len(cols)) + for field, col := range cols { + names = append(names, field) + colNames = append(colNames, col) + } + selector := sql.Dialect(s.Driver.Dialect()). + Select(colNames...). + From(sql.Table(s.Table)) + if s.Predicate != nil { + s.Predicate(selector) + } + query, args := selector.Query() + rows := &sql.Rows{} + if err := s.Driver.Query(ctx, query, args, rows); err != nil { + return nil, err + } + defer rows.Close() + var keys []ent.BlobKey + for rows.Next() { + vals := make([]*string, len(colNames)) + ptrs := make([]any, len(colNames)) + for i := range vals { + ptrs[i] = &vals[i] + } + if err := rows.Scan(ptrs...); err != nil { + return nil, err + } + for i, v := range vals { + if v != nil && *v != "" { + keys = append(keys, ent.BlobKey{Field: names[i], Key: *v}) + } + } + } + return keys, rows.Err() +} diff --git a/entc/gen/graph.go b/entc/gen/graph.go index 7228516b1..786f712df 100644 --- a/entc/gen/graph.go +++ b/entc/gen/graph.go @@ -172,6 +172,32 @@ func NewGraph(c *Config, schemas ...*load.Schema) (g *Graph, err error) { for _, t := range g.Nodes { check(t.setupFKs(), "set %q foreign-keys", t.Name) } + for _, t := range g.Nodes { + t.setupBlobKeys() + } + // Non-lazy blob fields act as regular TypeBytes in the mutation and entity layers. + // Their Type is changed here so they flow through all normal field paths (struct, mutation, hooks). + for _, t := range g.Nodes { + for _, f := range t.Fields { + if f.IsBlob() && !f.IsBlobLazy() { + ti := &field.TypeInfo{Type: field.TypeBytes} + // Preserve custom GoType information if set. + if f.Type != nil && f.Type.RType != nil { + ti.RType = f.Type.RType + ti.Ident = f.Type.Ident + ti.PkgPath = f.Type.PkgPath + ti.PkgName = f.Type.PkgName + ti.Nillable = f.Type.Nillable + } + f.Type = ti + // Copy BlobDWSchemaType to SchemaType so it flows + // into the column definition for migration. + if f.def != nil && len(f.def.BlobDWSchemaType) > 0 { + f.def.SchemaType = f.def.BlobDWSchemaType + } + } + } + } for i := range schemas { g.addIndexes(schemas[i]) } @@ -654,10 +680,16 @@ func (g *Graph) Tables() (all []*schema.Table, err error) { if a := f.EntSQL(); a != nil && a.Skip { continue } + if f.IsBlobNoColumn() || f.IsBlobLazy() { + continue + } if !f.IsEdgeField() { table.AddColumn(f.Column()) } } + for _, bk := range n.BlobKeys { + table.AddColumn(bk.Field.Column()) + } switch { case tables[table.Name] == nil: tables[table.Name] = table diff --git a/entc/gen/graph_test.go b/entc/gen/graph_test.go index 018f5766c..029890304 100644 --- a/entc/gen/graph_test.go +++ b/entc/gen/graph_test.go @@ -347,6 +347,45 @@ func TestFKColumns(t *testing.T) { } } +func TestBlobDualWriteGoType(t *testing.T) { + require := require.New(t) + doc := &load.Schema{ + Name: "Doc", + Fields: []*load.Field{ + { + Name: "config", + Info: &field.TypeInfo{ + Type: field.TypeBlob, + Ident: "*Config", + PkgPath: "example.com/app", + PkgName: "app", + RType: &field.RType{ + Name: "Config", + Ident: "Config", + Kind: reflect.Struct, + PkgPath: "example.com/app", + }, + }, + ValueScanner: true, + BlobDualWrite: true, + }, + }, + } + g, err := NewGraph(&Config{Package: "entc/gen", Storage: drivers[0]}, doc) + require.NoError(err) + require.Len(g.Nodes, 1) + f := g.Nodes[0].Fields[0] + require.True(f.IsBlob()) + require.False(f.IsBlobNoColumn()) + // After graph initialization, the blob field should be TypeBytes but preserve GoType info. + require.Equal(field.TypeBytes, f.Type.Type) + require.True(f.HasGoType(), "GoType should be preserved for DualWrite blob fields") + require.Equal("*Config", f.Type.String(), "Type.String() should return the custom GoType") + require.Equal("example.com/app", f.Type.PkgPath) + require.NotNil(f.Type.RType) + require.Equal("Config", f.Type.RType.Name) +} + func TestAbortDuplicateFK(t *testing.T) { var ( user = &load.Schema{ diff --git a/entc/gen/template/builder/create.tmpl b/entc/gen/template/builder/create.tmpl index df1be57b5..4986dc686 100644 --- a/entc/gen/template/builder/create.tmpl +++ b/entc/gen/template/builder/create.tmpl @@ -98,6 +98,7 @@ func ({{ $receiver }} *{{ $builder }}) ExecX(ctx context.Context) { // check runs all checks and user-defined validators on the builder. func ({{ $receiver }} *{{ $builder }}) check() error { {{- range $f := $fields }} + {{- if and $f.IsBlobLazy (not (hasTemplate (printf "dialect/%s/model/fields" $.Storage))) }}{{ continue }}{{ end }} {{- $skip := false }}{{ if $.HasOneFieldID }}{{ if eq $f.Name $.ID.Name }}{{ $skip = true }}{{ end }}{{ end }} {{- if and (not $f.Optional) (not $skip) }} {{- $dialects := $f.RequiredFor }} diff --git a/entc/gen/template/builder/mutation_graph.tmpl b/entc/gen/template/builder/mutation_graph.tmpl index d688ea1ec..9c8f20e1e 100644 --- a/entc/gen/template/builder/mutation_graph.tmpl +++ b/entc/gen/template/builder/mutation_graph.tmpl @@ -166,6 +166,7 @@ func (m {{ $mutation }}) Tx() (*Tx, error) { {{ end }} {{ range $f := $n.Fields }} + {{- if $f.IsBlobLazy }}{{ continue }}{{ end }} {{ if $n.HasOneFieldID }} {{ $const := print $n.Package "." $f.Constant }} // {{ $f.MutationGetOld }} returns the old "{{ $f.Name }}" field's value of the {{ $n.Name }} entity. @@ -197,6 +198,7 @@ func (m *{{ $mutation }}) OldField(ctx context.Context, name string) (ent.Value, {{- with $n.Fields }} switch name { {{- range $f := . }} + {{- if $f.IsBlobLazy }}{{ continue }}{{ end }} {{- $const := print $n.Package "." $f.Constant }} case {{ $const }}: return m.{{ $f.MutationGetOld }}(ctx) diff --git a/entc/gen/template/builder/setter.tmpl b/entc/gen/template/builder/setter.tmpl index d5a03a69f..60e5ec4fe 100644 --- a/entc/gen/template/builder/setter.tmpl +++ b/entc/gen/template/builder/setter.tmpl @@ -22,6 +22,9 @@ in the LICENSE file in the root directory of this source tree. {{- end }} {{ range $f := $fields }} + {{- if $f.IsBlobLazy }}{{ continue }}{{ end }} + {{- /* Skip all blob fields on multi-row Update (blob key updates require UpdateOne). */}} + {{- if and $f.IsBlob $updater (not (hasSuffix $builder "UpdateOne")) }}{{ continue }}{{ end }} {{ $func := print "Set" $f.StructField }} // {{ $func }} sets the "{{ $f.Name }}" field. func ({{ $receiver }} *{{ $builder }}) {{ $func }}(v {{ $f.Type }}) *{{ $builder }} { @@ -74,6 +77,31 @@ in the LICENSE file in the root directory of this source tree. {{ end }} {{ end }} +{{- /* Lazy blob setters only on Create and UpdateOne (blob writes require a single-row context). */}} +{{- $updateOne := hasSuffix $builder "UpdateOne" }} +{{- if or $creator $updateOne }} +{{- if hasTemplate (printf "dialect/%s/model/fields" $.Storage) }} +{{- range $f := $.BlobFields }} + {{- if not $f.IsBlobLazy }}{{ continue }}{{ end }} + {{ $func := print "Set" $f.StructField }} + // {{ $func }} sets the "{{ $f.Name }}" field. + func ({{ $receiver }} *{{ $builder }}) {{ $func }}(v io.Reader) *{{ $builder }} { + {{ $receiver }}.mutation.Set{{ $f.StructField }}(v) + return {{ $receiver }} + } + + {{- if and $f.Optional $updateOne }} + {{ $clearFunc := print "Clear" $f.StructField }} + // {{ $clearFunc }} clears the value of the "{{ $f.Name }}" field. + func ({{ $receiver }} *{{ $builder }}) {{ $clearFunc }}() *{{ $builder }} { + {{ $receiver }}.mutation.{{ $clearFunc }}() + return {{ $receiver }} + } + {{- end }} +{{- end }} +{{- end }} +{{- end }} + {{ range $e := $.EdgesWithID }} {{ if and $updater $e.Immutable }} {{/* Skip to the next one as immutable edges cannot be updated. */}} diff --git a/entc/gen/template/client.tmpl b/entc/gen/template/client.tmpl index 31023d444..bf8e54525 100644 --- a/entc/gen/template/client.tmpl +++ b/entc/gen/template/client.tmpl @@ -63,6 +63,7 @@ import ( {{ $dep.Type.PkgName }} "{{ $dep.Type.PkgPath }}" {{- end }} "entgo.io/ent/dialect" + {{- $hasBlobNodes := false }}{{ $blobSupported := hasTemplate (printf "dialect/%s/model/fields" $.Storage) }}{{ if $blobSupported }}{{ range $n := $.Nodes }}{{ if $n.HasBlobFields }}{{ $hasBlobNodes = true }}{{ end }}{{ end }}{{ end }} {{ range $import := $.Storage.Imports -}} "{{ $import }}" {{ end -}} @@ -84,6 +85,10 @@ type ( hooks *hooks // interceptors to execute on queries. inters *inters + {{- if $hasBlobNodes }} + // blobOpeners configures how blob buckets are opened for each entity type. + blobOpeners BlobOpeners + {{- end }} {{- /* Additional dependency fields. */}} {{- range $dep := $deps }} {{ $dep.Field }} {{ $dep.Type }} @@ -154,6 +159,42 @@ func Driver(driver dialect.Driver) Option { } {{- end }} +{{- if $hasBlobNodes }} +// Blob is an alias for the [ent.Blob] interface defined in the entgo.io/ent package. +type Blob = ent.Blob + +// BlobOpeners configures how blob buckets are opened for each entity type. +// Each field is a function that opens a blob bucket for the given field name. +type BlobOpeners struct { + {{- range $n := $.Nodes }} + {{- if $n.HasBlobFields }} + {{ $n.Name }} ent.BlobOpener + {{- end }} + {{- end }} +} + +// WithBlobOpeners configures the blob bucket openers. +func WithBlobOpeners(openers BlobOpeners) Option { + return func(c *config) { + c.blobOpeners = openers + } +} + +{{- $needsDefaultBlobKey := false }} +{{- range $n := $.Nodes }} + {{- range $f := $n.BlobFields }} + {{- if not $f.HasBlobKey }}{{ $needsDefaultBlobKey = true }}{{ end }} + {{- end }} +{{- end }} +{{- if $needsDefaultBlobKey }} + +func defaultBlobKey(_ context.Context, data []byte) (string, error) { + h := sha256.Sum256(data) + return hex.EncodeToString(h[:]), nil +} +{{- end }} +{{- end }} + // Open opens a database/sql.DB specified by the driver name and // the data source name, and returns a new client attached to it. // Optional parameters can be added for configuring the client. diff --git a/entc/gen/template/dialect/gremlin/create.tmpl b/entc/gen/template/dialect/gremlin/create.tmpl index fb829f0ac..8fdef85a8 100644 --- a/entc/gen/template/dialect/gremlin/create.tmpl +++ b/entc/gen/template/dialect/gremlin/create.tmpl @@ -49,6 +49,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlin() *dsl.Traversal { } {{- end }} {{- range $f := $.MutationFields }} + {{- if $f.IsBlob }}{{ continue }}{{ end }} if value, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok { {{- if $f.Unique }} constraints = append(constraints, &constraint{ diff --git a/entc/gen/template/dialect/gremlin/decode.tmpl b/entc/gen/template/dialect/gremlin/decode.tmpl index f5cc43e5d..40ddc3035 100644 --- a/entc/gen/template/dialect/gremlin/decode.tmpl +++ b/entc/gen/template/dialect/gremlin/decode.tmpl @@ -19,6 +19,7 @@ func ({{ $receiver }} *{{ $.Name }}) FromResponse(res *gremlin.Response) error { var {{ $scan }} struct { ID {{ $.ID.Type }} `json:"id,omitempty"` {{ range $f := $.Fields }} + {{- if $f.IsBlob }}{{ continue }}{{ end }} {{- $f.StructField }} {{ if and $f.IsTime (not $f.HasGoType) }}int64{{ else }}{{ if $f.NillableValue }}*{{ end }}{{ $f.Type }}{{ end }} `json:"{{ $f.StorageKey }},omitempty"` {{ end }} } @@ -27,6 +28,7 @@ func ({{ $receiver }} *{{ $.Name }}) FromResponse(res *gremlin.Response) error { } {{ $receiver }}.ID = {{ $scan }}.ID {{- range $i, $f := $.Fields }} + {{- if $f.IsBlob }}{{ continue }}{{ end }} {{- if and $f.IsTime (not $f.HasGoType) }} {{- if $f.Nillable }} v{{ $i }} := time.Unix(0, {{ $scan }}.{{ $f.StructField }}) @@ -56,6 +58,7 @@ func ({{ $receiver }} *{{ $slice }}) FromResponse(res *gremlin.Response) error { var {{ $scan }} []struct { ID {{ $.ID.Type }} `json:"id,omitempty"` {{ range $f := $.Fields }} + {{- if $f.IsBlob }}{{ continue }}{{ end }} {{- $f.StructField }} {{ if and $f.IsTime (not $f.HasGoType) }}int64{{ else }}{{ if $f.NillableValue }}*{{ end }}{{ $f.Type }}{{ end }} `json:"{{ $f.StorageKey }},omitempty"` {{ end }} } @@ -65,6 +68,7 @@ func ({{ $receiver }} *{{ $slice }}) FromResponse(res *gremlin.Response) error { for _, v := range {{ $scan }} { node := &{{ $.Name }}{ID: v.ID} {{- range $i, $f := $.Fields }} + {{- if $f.IsBlob }}{{ continue }}{{ end }} {{- if and $f.IsTime (not $f.HasGoType) }} {{- if $f.Nillable }} v{{ $i }} := time.Unix(0, v.{{ $f.StructField }}) diff --git a/entc/gen/template/dialect/gremlin/update.tmpl b/entc/gen/template/dialect/gremlin/update.tmpl index 6bb1c0e4c..6438fd585 100644 --- a/entc/gen/template/dialect/gremlin/update.tmpl +++ b/entc/gen/template/dialect/gremlin/update.tmpl @@ -75,6 +75,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlin({{ if $one }}id {{ $.ID.Type }}{{ trs []*dsl.Traversal ) {{- range $f := $.MutationFields }} + {{- if $f.IsBlob }}{{ continue }}{{ end }} {{- if or (not $f.Immutable) $f.UpdateDefault }} if value, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok { {{- if $f.Unique }} @@ -103,6 +104,7 @@ func ({{ $receiver }} *{{ $builder }}) gremlin({{ if $one }}id {{ $.ID.Type }}{{ {{- with $.HasOptional }} var properties []any {{- range $f := $.MutationFields }} + {{- if $f.IsBlob }}{{ continue }}{{ end }} {{- if $f.Optional }} if {{ $mutation }}.{{ $f.StructField }}Cleared() { properties = append(properties, {{ $.Package }}.{{ $f.Constant }}) diff --git a/entc/gen/template/dialect/sql/create.tmpl b/entc/gen/template/dialect/sql/create.tmpl index 32b037994..1bb4fae5f 100644 --- a/entc/gen/template/dialect/sql/create.tmpl +++ b/entc/gen/template/dialect/sql/create.tmpl @@ -10,6 +10,7 @@ in the LICENSE file in the root directory of this source tree. {{ $builder := pascal $.Scope.Builder }} {{ $receiver := $.Scope.Receiver }} {{ $mutation := print $receiver ".mutation" }} +{{ $pkg := base $.Config.Package }} func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (*{{ $.Name }}, error) { if err := {{ $receiver }}.check(); err != nil { @@ -21,12 +22,94 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (*{{ $.Name return nil, err } {{- end }} + {{- if $.HasBlobFields }} + _blobs := ent.NewBlobs({{ $mutation }}.blobOpeners.{{ $.Name }}) + {{- range $f := $.BlobFields }} + {{- if not $f.IsBlobLazy }} + if value, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok { + {{- if $f.HasValueScanner }} + _blobDV, err := {{ $f.ValueFunc }}(value) + if err != nil { + return nil, fmt.Errorf("{{ $pkg }}: encoding {{ $f.Name }}: %w", err) + } + var _blobData []byte + switch v := _blobDV.(type) { + case []byte: + _blobData = v + case string: + _blobData = []byte(v) + default: + return nil, fmt.Errorf("{{ $pkg }}: encoding {{ $f.Name }}: expected []byte or string, got %T", _blobDV) + } + _blobs.Set({{ $.Package }}.{{ $f.Constant }}, _blobData, + {{ if $f.HasBlobKey }}{{ $.Package }}.{{ $f.BlobKeyName }}{{ else }}defaultBlobKey{{ end }}, + func(k string) { + _node.{{ $f.BlobKeyColumn }} = &k + _spec.SetField("{{ $f.BlobKeyColumn }}", field.TypeString, k) + }, + ) + {{- else if $f.IsBlobGoString }} + _blobs.Set({{ $.Package }}.{{ $f.Constant }}, []byte(value), + {{ if $f.HasBlobKey }}{{ $.Package }}.{{ $f.BlobKeyName }}{{ else }}defaultBlobKey{{ end }}, + func(k string) { + _node.{{ $f.BlobKeyColumn }} = &k + _spec.SetField("{{ $f.BlobKeyColumn }}", field.TypeString, k) + }, + ) + {{- else }} + _blobs.Set({{ $.Package }}.{{ $f.Constant }}, value, + {{ if $f.HasBlobKey }}{{ $.Package }}.{{ $f.BlobKeyName }}{{ else }}defaultBlobKey{{ end }}, + func(k string) { + _node.{{ $f.BlobKeyColumn }} = &k + _spec.SetField("{{ $f.BlobKeyColumn }}", field.TypeString, k) + }, + ) + {{- end }} + _node.{{ $f.StructField }} = value + } + {{- else }} + if r, ok := {{ $mutation }}.{{ $f.StructField }}(); ok { + _blobData, err := io.ReadAll(r) + if err != nil { + return nil, fmt.Errorf("{{ $pkg }}: reading {{ $f.Name }}: %w", err) + } + _blobs.Set({{ $.Package }}.{{ $f.Constant }}, _blobData, + {{ if $f.HasBlobKey }}{{ $.Package }}.{{ $f.BlobKeyName }}{{ else }}defaultBlobKey{{ end }}, + func(k string) { + _node.{{ $f.BlobKeyColumn }} = &k + _spec.SetField("{{ $f.BlobKeyColumn }}", field.TypeString, k) + }, + ) + } + {{- end }} + {{- end }} + _blobCleanup, err := _blobs.Create(ctx) + if err != nil { + return nil, err + } + {{- end }} if err := sqlgraph.CreateNode(ctx, {{ $receiver }}.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + {{- if $.HasBlobFields }} + return nil, errors.Join(err, _blobCleanup(ctx)) + {{- else }} return nil, err + {{- end }} } + {{- if $.HasBlobFields }} + if txd, ok := {{ $receiver }}.driver.(*txDriver); ok { + txd.mu.Lock() + txd.onRollback = append(txd.onRollback, func(next Rollbacker) Rollbacker { + return RollbackFunc(func(ctx context.Context, tx *Tx) error { + err := next.Rollback(ctx, tx) + return errors.Join(err, _blobCleanup(ctx)) + }) + }) + txd.mu.Unlock() + } + {{- end }} {{- if $.HasCompositeID }} {{- else if or $.ID.HasValueScanner $.ID.Type.ValueScanner (not $.ID.Type.Numeric) }} if _spec.ID.Value != nil { @@ -105,6 +188,7 @@ func ({{ $receiver }} *{{ $builder }}) createSpec() (*{{ $.Name }}, *sqlgraph.Cr } {{- end }} {{- range $f := $.MutationFields }} + {{- if $f.IsBlobNoColumn }}{{ continue }}{{ end }} if value, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok { {{- if $f.HasValueScanner }} vv, err := {{ $f.ValueFunc }}(value) @@ -163,6 +247,7 @@ func ({{ $receiver }} *{{ $builder }}) createSpec() (*{{ $.Name }}, *sqlgraph.Cr {{ define "dialect/sql/create_bulk" }} {{ $builder := pascal $.Scope.Builder }} {{ $receiver := $.Scope.Receiver }} +{{ $pkg := base $.Config.Package }} // Save creates the {{ $.Name }} entities in the database. func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) ([]*{{ $.Name }}, error) { @@ -173,6 +258,9 @@ func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) ([]*{{ $.Name } specs := make([]*sqlgraph.CreateSpec, len({{ $receiver }}.builders)) nodes := make([]*{{ $.Name }}, len({{ $receiver }}.builders)) mutators := make([]Mutator, len({{ $receiver }}.builders)) + {{- if $.HasBlobFields }} + _blobs := ent.NewBlobs({{ $receiver }}.blobOpeners.{{ $.Name }}) + {{- end }} for i := range {{ $receiver }}.builders { func(i int, root context.Context) { builder := {{ $receiver }}.builders[i] @@ -195,9 +283,77 @@ func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) ([]*{{ $.Name } return nil, err } {{- end }} + {{- if $.HasBlobFields }} + {{- range $f := $.BlobFields }} + {{- if not $f.IsBlobLazy }} + if value, ok := mutation.{{ $f.MutationGet }}(); ok { + {{- if $f.HasValueScanner }} + _blobDV, err := {{ $f.ValueFunc }}(value) + if err != nil { + return nil, fmt.Errorf("{{ $pkg }}: encoding {{ $f.Name }}: %w", err) + } + var _blobData []byte + switch v := _blobDV.(type) { + case []byte: + _blobData = v + case string: + _blobData = []byte(v) + default: + return nil, fmt.Errorf("{{ $pkg }}: encoding {{ $f.Name }}: expected []byte or string, got %T", _blobDV) + } + _blobs.Set({{ $.Package }}.{{ $f.Constant }}, _blobData, + {{ if $f.HasBlobKey }}{{ $.Package }}.{{ $f.BlobKeyName }}{{ else }}defaultBlobKey{{ end }}, + func(k string) { + nodes[i].{{ $f.BlobKeyColumn }} = &k + specs[i].SetField("{{ $f.BlobKeyColumn }}", field.TypeString, k) + }, + ) + {{- else if $f.IsBlobGoString }} + _blobs.Set({{ $.Package }}.{{ $f.Constant }}, []byte(value), + {{ if $f.HasBlobKey }}{{ $.Package }}.{{ $f.BlobKeyName }}{{ else }}defaultBlobKey{{ end }}, + func(k string) { + nodes[i].{{ $f.BlobKeyColumn }} = &k + specs[i].SetField("{{ $f.BlobKeyColumn }}", field.TypeString, k) + }, + ) + {{- else }} + _blobs.Set({{ $.Package }}.{{ $f.Constant }}, value, + {{ if $f.HasBlobKey }}{{ $.Package }}.{{ $f.BlobKeyName }}{{ else }}defaultBlobKey{{ end }}, + func(k string) { + nodes[i].{{ $f.BlobKeyColumn }} = &k + specs[i].SetField("{{ $f.BlobKeyColumn }}", field.TypeString, k) + }, + ) + {{- end }} + nodes[i].{{ $f.StructField }} = value + } + {{- else }} + if r, ok := mutation.{{ $f.StructField }}(); ok { + _blobData, err := io.ReadAll(r) + if err != nil { + return nil, fmt.Errorf("{{ $pkg }}: reading {{ $f.Name }}: %w", err) + } + _blobs.Set({{ $.Package }}.{{ $f.Constant }}, _blobData, + {{ if $f.HasBlobKey }}{{ $.Package }}.{{ $f.BlobKeyName }}{{ else }}defaultBlobKey{{ end }}, + func(k string) { + nodes[i].{{ $f.BlobKeyColumn }} = &k + specs[i].SetField("{{ $f.BlobKeyColumn }}", field.TypeString, k) + }, + ) + } + {{- end }} + {{- end }} + {{- end }} if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, {{ $receiver }}.builders[i+1].mutation) } else { + {{- if $.HasBlobFields }} + // Write blobs before creating SQL rows so insert failures can clean up written objects. + _blobCleanup, err := _blobs.Create(ctx) + if err != nil { + return nil, err + } + {{- end }} spec := &sqlgraph.BatchCreateSpec{Nodes: specs} {{- /* Allow mutating the sqlgraph.BatchCreateSpec by ent extensions or user templates.*/}} {{- with $tmpls := matchTemplate "dialect/sql/create_bulk/spec/*" }} @@ -210,7 +366,22 @@ func ({{ $receiver }} *{{ $builder }}) Save(ctx context.Context) ([]*{{ $.Name } if sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + {{- if $.HasBlobFields }} + return nil, errors.Join(err, _blobCleanup(ctx)) + {{- end }} } + {{- if $.HasBlobFields }} + if txd, ok := {{ $receiver }}.driver.(*txDriver); ok { + txd.mu.Lock() + txd.onRollback = append(txd.onRollback, func(next Rollbacker) Rollbacker { + return RollbackFunc(func(ctx context.Context, tx *Tx) error { + err := next.Rollback(ctx, tx) + return errors.Join(err, _blobCleanup(ctx)) + }) + }) + txd.mu.Unlock() + } + {{- end }} } if err != nil { return nil, err diff --git a/entc/gen/template/dialect/sql/decode.tmpl b/entc/gen/template/dialect/sql/decode.tmpl index f64e9054f..f94a67093 100644 --- a/entc/gen/template/dialect/sql/decode.tmpl +++ b/entc/gen/template/dialect/sql/decode.tmpl @@ -20,6 +20,9 @@ in the LICENSE file in the root directory of this source tree. {{- if $f.HasValueScanner }} {{- continue }} {{- end }} + {{- if or $f.IsBlobNoColumn $f.IsBlobLazy }} + {{- continue }} + {{- end }} {{ $names := list }} {{ if hasKey $ctypes $f.NewScanType }} {{ $names = get $ctypes $f.NewScanType }} @@ -42,6 +45,7 @@ func (*{{ $.Name }}) scanValues(columns []string) ([]any, error) { values[i] = {{ $type }} {{- end }} {{- range $f := $.Fields }} + {{- if $f.IsBlobNoColumn }}{{ continue }}{{ end }} {{- if $f.HasValueScanner }} case {{ $.Package }}.{{ $f.Constant }}: values[i] = {{ $f.ScanValueFunc }}() @@ -52,6 +56,10 @@ func (*{{ $.Name }}) scanValues(columns []string) ([]any, error) { case {{ $.Package }}.ForeignKeys[{{ $i }}]: // {{ $f.Name }} values[i] = {{ if not $f.UserDefined }}new(sql.NullInt64){{ else }}{{ $f.NewScanType }}{{ end }} {{- end }} + {{- range $i, $bk := $.BlobKeys }} + case {{ $.Package }}.BlobKeys[{{ $i }}]: // {{ $bk.Field.Name }} + values[i] = new(sql.NullString) + {{- end }} default: {{- /* In case of unknown column that was added by a modifier, predicate, etc., fallback to any. */}} values[i] = new(sql.UnknownType) @@ -83,6 +91,9 @@ func ({{ $receiver }} *{{ $.Name }}) assignValues(columns []string, values []any {{- end }} {{- end }} {{- range $f := $.Fields }} + {{- if or $f.IsBlobNoColumn $f.IsBlobLazy }} + {{- continue }} + {{- end }} case {{ $.Package }}.{{ $f.Constant }}: {{- with extend $ "Idx" "i" "Field" $f "Rec" $receiver }} {{ template "dialect/sql/decode/field" . }} @@ -104,6 +115,14 @@ func ({{ $receiver }} *{{ $.Name }}) assignValues(columns []string, values []any } {{- end }} {{- end }} + {{- range $i, $bk := $.BlobKeys }} + case {{ $.Package }}.BlobKeys[{{ $i }}]: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field {{ $bk.Field.Name }}", values[i]) + } else if value.Valid { + {{ $receiver }}.{{ $bk.StructField }} = &value.String + } + {{- end }} default: {{- /* In case of no match, allow getting this value by its name. */}} {{ $receiver }}.selectValues.Set(columns[i], values[i]) diff --git a/entc/gen/template/dialect/sql/delete.tmpl b/entc/gen/template/dialect/sql/delete.tmpl index 05cef7415..1758c5a96 100644 --- a/entc/gen/template/dialect/sql/delete.tmpl +++ b/entc/gen/template/dialect/sql/delete.tmpl @@ -26,10 +26,44 @@ func ({{ $receiver}} *{{ $builder }}) sqlExec(ctx context.Context) (int, error) } } } + {{- if $.HasBlobFields }} + // Collect blob keys before deleting rows so we can remove blobs from storage afterward. + _blobCleanup, _blobErr := ent.NewBlobs({{ $mutation }}.blobOpeners.{{ $.Name }}).Delete(ctx, &sqlgraph.BlobSpec{ + Driver: {{ $receiver }}.driver, + Predicate: _spec.Predicate, + Table: {{ $.Package }}.Table, + Columns: map[string]string{ + {{- range $f := $.BlobFields }} + {{ $.Package }}.{{ $f.Constant }}: "{{ $f.BlobKeyColumn }}", + {{- end }} + }, + }) + if _blobErr != nil { + return 0, _blobErr + } + {{- end }} affected, err := sqlgraph.DeleteNodes(ctx, {{ $receiver}}.driver, _spec) if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + {{- if $.HasBlobFields }} + if err == nil { + if txd, ok := {{ $receiver }}.driver.(*txDriver); ok { + txd.mu.Lock() + txd.onCommit = append(txd.onCommit, func(next Committer) Committer { + return CommitFunc(func(ctx context.Context, tx *Tx) error { + if err := next.Commit(ctx, tx); err != nil { + return err + } + return _blobCleanup(ctx) + }) + }) + txd.mu.Unlock() + } else { + err = _blobCleanup(ctx) + } + } + {{- end }} {{ $mutation }}.done = true return affected, err } diff --git a/entc/gen/template/dialect/sql/ent.tmpl b/entc/gen/template/dialect/sql/ent.tmpl index 030551c0d..e38ff7c19 100644 --- a/entc/gen/template/dialect/sql/ent.tmpl +++ b/entc/gen/template/dialect/sql/ent.tmpl @@ -12,6 +12,9 @@ in the LICENSE file in the root directory of this source tree. {{- $f := $fk.Field }} {{ $fk.StructField }} {{ if $f.Nillable }}*{{ end }}{{ $f.Type }} {{- end }} + {{- range $bk := $.BlobKeys }} + {{ $bk.StructField }} *string + {{- end }} selectValues sql.SelectValues {{- /* Allow adding struct fields by ent extensions or user templates.*/}} {{- with $tmpls := matchTemplate "dialect/sql/model/fields/*" }} diff --git a/entc/gen/template/dialect/sql/entql.tmpl b/entc/gen/template/dialect/sql/entql.tmpl index ed28bad05..1e424fe18 100644 --- a/entc/gen/template/dialect/sql/entql.tmpl +++ b/entc/gen/template/dialect/sql/entql.tmpl @@ -53,7 +53,9 @@ var schemaGraph = func() *sqlgraph.Schema { Type: "{{ $n.Name }}", Fields: map[string]*sqlgraph.FieldSpec{ {{- range $f := $n.Fields }} + {{- if not (or $f.IsBlobNoColumn $f.IsBlobLazy) }} {{ $n.Package }}.{{ $f.Constant }}: {Type: field.{{ $f.Type.ConstName }}, Column: {{ $n.Package }}.{{ $f.Constant }}}, + {{- end }} {{- end }} }, } @@ -140,6 +142,7 @@ type predicateAdder interface { {{- end }} {{ range $f := $n.Fields }} + {{- if or $f.IsBlobNoColumn $f.IsBlobLazy }}{{ continue }}{{ end }} {{ $type := $f.Type.Type.String }} {{ $iface := print (pascal $type) "P" }} {{- if $f.IsTime }}{{ $iface = "TimeP" }} diff --git a/entc/gen/template/dialect/sql/feature/upsert.tmpl b/entc/gen/template/dialect/sql/feature/upsert.tmpl index 896080738..93d2d233f 100644 --- a/entc/gen/template/dialect/sql/feature/upsert.tmpl +++ b/entc/gen/template/dialect/sql/feature/upsert.tmpl @@ -108,7 +108,30 @@ type ( } ) -{{ range $f := $.MutableFields }} +{{ range $f := $.UpsertFields }} + {{- if $f.IsBlob }} + {{ $func := print "Update" $f.StructField }} + // {{ $func }} sets the "{{ $f.Name }}" field to the value that was provided on create. + func (u *{{ $upsertSet }}) {{ $func }}() *{{ $upsertSet }} { + {{- if not $f.IsBlobNoColumn }} + u.SetExcluded({{ $.Package }}.{{ $f.Constant }}) + {{- end }} + u.SetExcluded("{{ $f.BlobKeyColumn }}") + return u + } + + {{ if $f.Optional }} + {{ $func := print "Clear" $f.StructField }} + // {{ $func }} clears the value of the "{{ $f.Name }}" field. + func (u *{{ $upsertSet }}) {{ $func }}() *{{ $upsertSet }} { + {{- if not $f.IsBlobNoColumn }} + u.SetNull({{ $.Package }}.{{ $f.Constant }}) + {{- end }} + u.SetNull("{{ $f.BlobKeyColumn }}") + return u + } + {{ end }} + {{- continue }}{{ end }} {{ $func := print "Set" $f.StructField }} // {{ $func }} sets the "{{ $f.Name }}" field. func (u *{{ $upsertSet }}) {{ $func }}(v {{ $f.Type }}) *{{ $upsertSet }} { @@ -411,7 +434,26 @@ func (u *{{ $upsertBulk }}) ExecX(ctx context.Context) { {{ $upsert := $.Scope.Upsert }} {{ $upsertSet := $.Scope.UpsertSet }} -{{ range $f := $.MutableFields }} +{{ range $f := $.UpsertFields }} + {{- if $f.IsBlob }} + {{ $func := print "Update" $f.StructField }} + // {{ $func }} sets the "{{ $f.Name }}" field to the value that was provided on create. + func (u *{{ $upsert }}) {{ $func }}() *{{ $upsert }} { + return u.Update(func(s *{{ $upsertSet }}) { + s.{{ $func }}() + }) + } + + {{ if $f.Optional }} + {{ $func := print "Clear" $f.StructField }} + // {{ $func }} clears the value of the "{{ $f.Name }}" field. + func (u *{{ $upsert }}) {{ $func }}() *{{ $upsert }} { + return u.Update(func(s *{{ $upsertSet }}) { + s.{{ $func }}() + }) + } + {{ end }} + {{- continue }}{{ end }} {{ $func := print "Set" $f.StructField }} // {{ $func }} sets the "{{ $f.Name }}" field. func (u *{{ $upsert }}) {{ $func }}(v {{ $f.Type }}) *{{ $upsert }} { diff --git a/entc/gen/template/dialect/sql/meta.tmpl b/entc/gen/template/dialect/sql/meta.tmpl index 06cfe2919..9a240fbca 100644 --- a/entc/gen/template/dialect/sql/meta.tmpl +++ b/entc/gen/template/dialect/sql/meta.tmpl @@ -49,10 +49,13 @@ in the LICENSE file in the root directory of this source tree. {{ $.ID.Constant }}, {{- end }} {{- range $f := $.Fields }} - {{- if not $f.IsDeprecated }} + {{- if and (not $f.IsDeprecated) (not $f.IsBlobNoColumn) (not $f.IsBlobLazy) }} {{ $f.Constant }}, {{- end }} {{- end }} + {{- range $k := $.BlobKeys }} + "{{ $k.Field.Name }}", + {{- end }} } {{/* If any of the edges owns a foreign-key */}} {{ with $.UnexportedForeignKeys }} @@ -65,6 +68,15 @@ in the LICENSE file in the root directory of this source tree. } {{ end }} + {{ with $.BlobKeys }} + // BlobKeys holds the SQL columns for blob storage keys. + var BlobKeys = []string{ + {{- range $bk := . }} + "{{ $bk.Field.Name }}", + {{- end }} + } + {{ end }} + {{ with $.NumM2M }} var ( {{- range $e := $.Edges }} @@ -94,6 +106,13 @@ func ValidColumn(column string) bool { } } {{- end }} + {{- with $.BlobKeys }} + for i := range BlobKeys { + if column == BlobKeys[i] { + return true + } + } + {{- end }} {{- with $.DeprecatedFields }} for _, f := range [...]string{ {{- range . }}{{ .Constant }},{{ end }} } { if column == f { diff --git a/entc/gen/template/dialect/sql/query.tmpl b/entc/gen/template/dialect/sql/query.tmpl index b5dc722e0..34b5b63a9 100644 --- a/entc/gen/template/dialect/sql/query.tmpl +++ b/entc/gen/template/dialect/sql/query.tmpl @@ -105,9 +105,75 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer {{- xtemplate $tmpl $ }} {{- end }} {{- end }} + {{- if $.HasLoadOnScanFields }} + if err := {{ $receiver }}.loadBlobFields(ctx, nodes); err != nil { + return nil, err + } + {{- end }} return nodes, nil } +{{- if $.HasLoadOnScanFields }} + +func ({{ $receiver }} *{{ $builder }}) loadBlobFields(ctx context.Context, nodes []*{{ $.Name }}) error { + _blobs := ent.NewBlobStore({{ $receiver }}.blobOpeners.{{ $.Name }}) + for _, n := range nodes { + {{- range $f := $.LoadOnScanFields }} + if n.{{ $f.BlobKeyColumn }} != nil && *n.{{ $f.BlobKeyColumn }} != "" { + {{- if $f.IsBlobNoColumn }} + {{- /* Non-DualWrite: blob is the only source of data. */}} + data, err := _blobs.Read(ctx, {{ $.Package }}.{{ $f.Constant }}, *n.{{ $f.BlobKeyColumn }}) + if err != nil { + return errors.Join(fmt.Errorf("loading {{ $f.Name }}: %w", err), _blobs.Close()) + } + if data == nil { + return errors.Join(fmt.Errorf("loading {{ $f.Name }}: object %q not found in blob storage", *n.{{ $f.BlobKeyColumn }}), _blobs.Close()) + } + {{- if $f.HasValueScanner }} + sv := {{ $f.ScanValueFunc }}() + if err := sv.Scan(data); err != nil { + return errors.Join(fmt.Errorf("scanning {{ $f.Name }}: %w", err), _blobs.Close()) + } + v, err := {{ $f.FromValueFunc }}(sv) + if err != nil { + return errors.Join(fmt.Errorf("scanning {{ $f.Name }}: %w", err), _blobs.Close()) + } + n.{{ $f.StructField }} = v + {{- else if $f.IsBlobGoString }} + n.{{ $f.StructField }} = {{ $f.Type }}(data) + {{- else }} + n.{{ $f.StructField }} = data + {{- end }} + {{- else }} + {{- /* DualWrite: prefer blob value, but preserve the column value when blob is missing. */}} + switch data, err := _blobs.Read(ctx, {{ $.Package }}.{{ $f.Constant }}, *n.{{ $f.BlobKeyColumn }}); { + case err != nil: + return errors.Join(fmt.Errorf("loading {{ $f.Name }}: %w", err), _blobs.Close()) + case data != nil: + {{- if $f.HasValueScanner }} + sv := {{ $f.ScanValueFunc }}() + if err := sv.Scan(data); err != nil { + return errors.Join(fmt.Errorf("scanning {{ $f.Name }}: %w", err), _blobs.Close()) + } + v, err := {{ $f.FromValueFunc }}(sv) + if err != nil { + return errors.Join(fmt.Errorf("scanning {{ $f.Name }}: %w", err), _blobs.Close()) + } + n.{{ $f.StructField }} = v + {{- else if $f.IsBlobGoString }} + n.{{ $f.StructField }} = {{ $f.Type }}(data) + {{- else }} + n.{{ $f.StructField }} = data + {{- end }} + } + {{- end }} + } + {{- end }} + } + return _blobs.Close() +} +{{- end }} + {{/* Generate a method to eager-load each edge. */}} {{- range $e := $.Edges }} func ({{ $receiver }} *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error { diff --git a/entc/gen/template/dialect/sql/update.tmpl b/entc/gen/template/dialect/sql/update.tmpl index f98f19a33..b6b5032ae 100644 --- a/entc/gen/template/dialect/sql/update.tmpl +++ b/entc/gen/template/dialect/sql/update.tmpl @@ -98,6 +98,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (_node {{ if } } {{- range $f := $.MutationFields }} + {{- if $f.IsBlobNoColumn }}{{ continue }}{{ end }} {{- if or (not $f.Immutable) $f.UpdateDefault }} if value, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok { {{- if $f.HasValueScanner }} @@ -169,6 +170,92 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (_node {{ if {{- xtemplate $tmpl $ }} {{- end }} {{- end }} + {{- if and $one $.HasBlobFields }} + _blobs := ent.NewBlobs({{ $mutation }}.blobOpeners.{{ $.Name }}) + {{- range $f := $.BlobFields }} + {{- if not $f.IsBlobLazy }} + if value, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok { + {{- if $f.HasValueScanner }} + _blobDV, err := {{ $f.ValueFunc }}(value) + if err != nil { + return {{ $zero }}, fmt.Errorf("{{ $pkg }}: encoding {{ $f.Name }}: %w", err) + } + var _blobData []byte + switch v := _blobDV.(type) { + case []byte: + _blobData = v + case string: + _blobData = []byte(v) + default: + return {{ $zero }}, fmt.Errorf("{{ $pkg }}: encoding {{ $f.Name }}: expected []byte or string, got %T", _blobDV) + } + _blobs.Set({{ $.Package }}.{{ $f.Constant }}, _blobData, + {{ if $f.HasBlobKey }}{{ $.Package }}.{{ $f.BlobKeyName }}{{ else }}defaultBlobKey{{ end }}, + func(k string) { + _spec.SetField("{{ $f.BlobKeyColumn }}", field.TypeString, k) + }, + ) + {{- else if $f.IsBlobGoString }} + _blobs.Set({{ $.Package }}.{{ $f.Constant }}, []byte(value), + {{ if $f.HasBlobKey }}{{ $.Package }}.{{ $f.BlobKeyName }}{{ else }}defaultBlobKey{{ end }}, + func(k string) { + _spec.SetField("{{ $f.BlobKeyColumn }}", field.TypeString, k) + }, + ) + {{- else }} + _blobs.Set({{ $.Package }}.{{ $f.Constant }}, value, + {{ if $f.HasBlobKey }}{{ $.Package }}.{{ $f.BlobKeyName }}{{ else }}defaultBlobKey{{ end }}, + func(k string) { + _spec.SetField("{{ $f.BlobKeyColumn }}", field.TypeString, k) + }, + ) + {{- end }} + } + {{- if $f.Optional }} + if {{ $mutation }}.{{ $f.StructField }}Cleared() { + _blobs.SetCleared({{ $.Package }}.{{ $f.Constant }}, func() { + _spec.ClearField("{{ $f.BlobKeyColumn }}", field.TypeString) + }) + } + {{- end }} + {{- else }} + if r, ok := {{ $mutation }}.{{ $f.StructField }}(); ok { + _blobData, err := io.ReadAll(r) + if err != nil { + return {{ $zero }}, fmt.Errorf("{{ $pkg }}: reading {{ $f.Name }}: %w", err) + } + _blobs.Set({{ $.Package }}.{{ $f.Constant }}, _blobData, + {{ if $f.HasBlobKey }}{{ $.Package }}.{{ $f.BlobKeyName }}{{ else }}defaultBlobKey{{ end }}, + func(k string) { + _spec.SetField("{{ $f.BlobKeyColumn }}", field.TypeString, k) + }, + ) + } + {{- if $f.Optional }} + if {{ $mutation }}.{{ $f.MutationCleared }}() { + _blobs.SetCleared({{ $.Package }}.{{ $f.Constant }}, func() { + _spec.ClearField("{{ $f.BlobKeyColumn }}", field.TypeString) + }) + } + {{- end }} + {{- end }} + {{- end }} + _blobResult, err := _blobs.Update(ctx, &sqlgraph.BlobSpec{ + Driver: {{ $receiver }}.driver, + Table: {{ $.Package }}.Table, + Columns: map[string]string{ + {{- range $f := $.BlobFields }} + {{ $.Package }}.{{ $f.Constant }}: "{{ $f.BlobKeyColumn }}", + {{- end }} + }, + Predicate: func(s *sql.Selector) { + s.Where(sql.EQ({{ $.Package }}.{{ $.ID.Constant }}, _spec.Node.ID.Value)) + }, + }) + if err != nil { + return {{ $zero }}, err + } + {{- end }} {{- if $one }} _node = &{{ $.Name }}{config: {{ $receiver }}.config} _spec.Assign = _node.assignValues @@ -184,9 +271,74 @@ func ({{ $receiver }} *{{ $builder }}) sqlSave(ctx context.Context) (_node {{ if } else if sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + {{- if and $one $.HasBlobFields }} + return {{ $zero }}, errors.Join(err, _blobResult.Rollback(ctx)) + {{- else }} return {{ $zero }}, err + {{- end }} } {{ $mutation }}.done = true + {{- if and $one $.HasBlobFields }} + if txd, ok := {{ $receiver }}.driver.(*txDriver); ok { + txd.mu.Lock() + txd.onCommit = append(txd.onCommit, func(next Committer) Committer { + return CommitFunc(func(ctx context.Context, tx *Tx) error { + if err := next.Commit(ctx, tx); err != nil { + return err + } + return _blobResult.Commit(ctx) + }) + }) + txd.onRollback = append(txd.onRollback, func(next Rollbacker) Rollbacker { + return RollbackFunc(func(ctx context.Context, tx *Tx) error { + err := next.Rollback(ctx, tx) + return errors.Join(err, _blobResult.Rollback(ctx)) + }) + }) + txd.mu.Unlock() + } else { + if err := _blobResult.Commit(ctx); err != nil { + return {{ $zero }}, err + } + } + {{- end }} + {{- if and $one $.HasLoadOnScanFields }} + {{- /* Blob values are consistent after write — use mutation value for mutated fields + instead of re-reading from storage. For DualWrite fields, assignValues already + populated the value from the SQL column. For non-DualWrite fields not mutated, + read from blob storage. */}} + _blobReader := ent.NewBlobStore({{ $mutation }}.blobOpeners.{{ $.Name }}) + {{- range $f := $.LoadOnScanFields }} + if value, ok := {{ $mutation }}.{{ $f.MutationGet }}(); ok { + _node.{{ $f.StructField }} = value + } + {{- if $f.IsBlobNoColumn }} else if _node.{{ $f.BlobKeyColumn }} != nil && *_node.{{ $f.BlobKeyColumn }} != "" { + data, err := _blobReader.Read(ctx, {{ $.Package }}.{{ $f.Constant }}, *_node.{{ $f.BlobKeyColumn }}) + if err != nil { + return nil, errors.Join(fmt.Errorf("loading {{ $f.Name }} after update: %w", err), _blobReader.Close()) + } + {{- if $f.HasValueScanner }} + sv := {{ $f.ScanValueFunc }}() + if err := sv.Scan(data); err != nil { + return nil, errors.Join(fmt.Errorf("scanning {{ $f.Name }} after update: %w", err), _blobReader.Close()) + } + v, err := {{ $f.FromValueFunc }}(sv) + if err != nil { + return nil, errors.Join(fmt.Errorf("scanning {{ $f.Name }} after update: %w", err), _blobReader.Close()) + } + _node.{{ $f.StructField }} = v + {{- else if $f.IsBlobGoString }} + _node.{{ $f.StructField }} = {{ $f.Type }}(data) + {{- else }} + _node.{{ $f.StructField }} = data + {{- end }} + } + {{- end }} + {{- end }} + if err := _blobReader.Close(); err != nil { + return nil, err + } + {{- end }} return _node, nil } {{ end }} diff --git a/entc/gen/template/ent.tmpl b/entc/gen/template/ent.tmpl index 30559b64c..c0c220e1d 100644 --- a/entc/gen/template/ent.tmpl +++ b/entc/gen/template/ent.tmpl @@ -33,6 +33,7 @@ type {{ $.Name }} struct { ID {{ $.ID.Type }} {{ with $.Annotations.Fields.StructTag.id }}`{{ . }}`{{ else }}`{{ $.ID.StructTag }}`{{ end }} {{- end }} {{- range $f := $.Fields }} + {{- if $f.IsBlobLazy }}{{ continue }}{{ end }} {{- $tag := $f.StructTag }}{{ with $tags := $.Annotations.Fields.StructTag }}{{ with index $tags $f.Name }}{{ $tag = . }}{{ end }}{{ end }} {{- template "model/fieldcomment" $f }} {{ $f.StructField }} {{ if $f.NillableValue }}*{{ end }}{{ $f.Type }} {{ if not $f.Sensitive }}`{{ $tag }}`{{ else }}{{ template "model/omittags" $ }}{{ end }} @@ -100,6 +101,27 @@ type {{ $edgesType }} struct { } {{ end }} +{{- $blobModelTmpl := printf "dialect/%s/model/fields" $.Storage }} +{{- if hasTemplate $blobModelTmpl }} +{{- range $f := $.BlobFields }} + // {{ $f.StructField }}Reader opens a reader for the "{{ $f.Name }}" field from blob storage. + // The caller must close the returned reader when done. + func ({{ $receiver }} *{{ $.Name }}) {{ $f.StructField }}Reader(ctx context.Context) (io.ReadCloser, error) { + if {{ $receiver }}.{{ $f.BlobKeyColumn }} == nil || *{{ $receiver }}.{{ $f.BlobKeyColumn }} == "" { + return nil, fmt.Errorf("{{ $pkg }}: {{ $.Name }}.{{ $f.BlobKeyColumn }} is nil or empty") + } + if {{ $receiver }}.blobOpeners.{{ $.Name }} == nil { + return nil, fmt.Errorf("{{ $pkg }}: blob storage not configured (missing WithBlobOpeners)") + } + b, err := {{ $receiver }}.blobOpeners.{{ $.Name }}(ctx, {{ $.Package }}.{{ $f.Constant }}) + if err != nil { + return nil, err + } + return ent.BlobReader(ctx, b, *{{ $receiver }}.{{ $f.BlobKeyColumn }}) + } +{{- end }} +{{- end }} + {{- if not $.IsView }} // Update returns a builder for updating this {{ $.Name }}. // Note that you need to call {{ $.Name }}.Unwrap() before calling this method if this {{ $.Name }} @@ -146,6 +168,7 @@ type {{ $slice }} []*{{ $.Name }} builder.WriteString(fmt.Sprintf("id=%v{{ if $.Fields }}, {{ end }}", {{ $receiver }}.ID)) {{- end }} {{- range $i, $f := $.Fields }} + {{- if $f.IsBlobLazy }}{{ continue }}{{ end }} {{- if ne $i 0 }} builder.WriteString(", ") {{- end }} diff --git a/entc/gen/template/import.tmpl b/entc/gen/template/import.tmpl index 77e359d18..788462833 100644 --- a/entc/gen/template/import.tmpl +++ b/entc/gen/template/import.tmpl @@ -8,9 +8,11 @@ in the LICENSE file in the root directory of this source tree. {{ define "import" }} import ( + "bytes" "context" "errors" "fmt" + "io" "math" "strings" "sync" diff --git a/entc/gen/template/meta.tmpl b/entc/gen/template/meta.tmpl index aa8fef887..6ded4bfed 100644 --- a/entc/gen/template/meta.tmpl +++ b/entc/gen/template/meta.tmpl @@ -50,7 +50,7 @@ const ( {{ $hasDefault := false }}{{ range $f := $fields }}{{ if and $f.Default (not $f.IsEnum) }}{{ $hasDefault = true }}{{ end }}{{ end }} {{/* Generate global variables for hooks, validators and policy checkers */}} -{{ if or $hasDefault $.HasUpdateDefault $.HasValidators $.NumHooks $.NumPolicy $.NumInterceptors $.HasValueScanner }} +{{ if or $hasDefault $.HasUpdateDefault $.HasValidators $.NumHooks $.NumPolicy $.NumInterceptors $.HasValueScanner $.HasBlobFields }} {{- $numHooks := $.NumHooks }} {{- if $.NumPolicy }} {{- $numHooks = add $numHooks 1 }} @@ -75,6 +75,7 @@ const ( {{- end }} {{- $fields := $.Fields }}{{ if $.HasOneFieldID }}{{ if $.ID.UserDefined }}{{ $fields = append $fields $.ID }}{{ end }}{{ end }} {{- range $f := $fields }} + {{- if $f.IsBlobLazy }}{{ continue }}{{ end }} {{- if and $f.Default (not $f.IsEnum) }} {{- $default := $f.DefaultName }} // {{ $default }} holds the default value on creation for the "{{ $f.Name }}" field. @@ -95,6 +96,14 @@ const ( {{ $name }} {{ printf "func (%s) error" $type }} {{- end }} {{- end }} + {{- if hasTemplate (printf "dialect/%s/model/fields" $.Storage) }} + {{- range $f := $.BlobFields }} + {{- if $f.HasBlobKey }} + // {{ $f.BlobKeyName }} generates the blob storage key for the "{{ $f.Name }}" field. + {{ $f.BlobKeyName }} ent.BlobKeyFunc + {{- end }} + {{- end }} + {{- end }} {{- if $.HasValueScanner }} // ValueScanner of all {{ $.Name }} fields. ValueScanner struct { diff --git a/entc/gen/template/mutation.tmpl b/entc/gen/template/mutation.tmpl index 9e311c1a5..5e2330538 100644 --- a/entc/gen/template/mutation.tmpl +++ b/entc/gen/template/mutation.tmpl @@ -18,6 +18,7 @@ in the LICENSE file in the root directory of this source tree. types defined in this package (e.g. enums) are referenced without the package qualifier. */}} {{ $pkgPrefix := print $.Package "." }} +{{ $blobSupported := hasTemplate (printf "dialect/%s/model/fields" $.Storage) }} // Mutation represents an operation that mutates the {{ $.Name }} nodes in the graph. type Mutation struct { @@ -32,6 +33,12 @@ type Mutation struct { append{{ $f.BuilderField }} {{ replace $f.Type.String $pkgPrefix "" }} {{- end }} {{- end }} + {{- if hasTemplate (printf "dialect/%s/model/fields" $.Storage) }} + {{- range $f := $.BlobFields }} + {{- if not $f.IsBlobLazy }}{{ continue }}{{ end }} + {{ $f.BuilderField }} io.Reader + {{- end }} + {{- end }} clearedFields map[string]struct{} {{- range $e := $.EdgesWithID }} {{- if $e.Unique }} @@ -66,9 +73,11 @@ func (m *Mutation) Predicates() []predicate.{{ $.Name }} { } {{ range $f := $.Fields }} + {{- if $f.IsBlobLazy }}{{ continue }}{{ end }} {{ $const := $f.Constant }} {{ $type := replace $f.Type.String $pkgPrefix "" }} {{ $p := receiver $f.Type.String }}{{ if eq $p "m" }} {{ $p = "value" }} {{ end }} + {{ $func := $f.MutationSet }} // {{ $func }} sets the "{{ $f.Name }}" field. func (m *Mutation) {{ $func }}({{ $p }} {{ $type }}) { @@ -131,9 +140,8 @@ func (m *Mutation) Predicates() []predicate.{{ $.Name }} { {{ end }} {{ if $f.Optional }} - {{ $func := $f.MutationClear }} - // {{ $func }} clears the value of the "{{ $f.Name }}" field. - func (m *Mutation) {{ $func }}() { + // {{ $f.MutationClear }} clears the value of the "{{ $f.Name }}" field. + func (m *Mutation) {{ $f.MutationClear }}() { m.{{ $f.BuilderField }} = nil {{- if $f.SupportsMutationAdd }} m.add{{ $f.BuilderField }} = nil @@ -144,17 +152,15 @@ func (m *Mutation) Predicates() []predicate.{{ $.Name }} { m.clearedFields[{{ $const }}] = struct{}{} } - {{ $func = $f.MutationCleared }} - // {{ $func }} returns if the "{{ $f.Name }}" field was cleared in this mutation. - func (m *Mutation) {{ $func }}() bool { + // {{ $f.MutationCleared }} returns if the "{{ $f.Name }}" field was cleared in this mutation. + func (m *Mutation) {{ $f.MutationCleared }}() bool { _, ok := m.clearedFields[{{ $const }}] return ok } {{ end }} - {{ $func = $f.MutationReset }} - // {{ $func }} resets all changes to the "{{ $f.Name }}" field. - func (m *Mutation) {{ $func }}() { + // {{ $f.MutationReset }} resets all changes to the "{{ $f.Name }}" field. + func (m *Mutation) {{ $f.MutationReset }}() { m.{{ $f.BuilderField }} = nil {{- if $f.SupportsMutationAdd }} m.add{{ $f.BuilderField }} = nil @@ -168,6 +174,48 @@ func (m *Mutation) Predicates() []predicate.{{ $.Name }} { } {{ end }} +{{- if hasTemplate (printf "dialect/%s/model/fields" $.Storage) }} +{{- range $f := $.BlobFields }} + {{- if not $f.IsBlobLazy }}{{ continue }}{{ end }} + {{ $const := $f.Constant }} + {{ $func := print "Set" $f.StructField }} + // {{ $func }} sets the "{{ $f.Name }}" field. + func (m *Mutation) {{ $func }}(r io.Reader) { + m.{{ $f.BuilderField }} = r + } + + // {{ $f.StructField }} returns the value of the "{{ $f.Name }}" field in the mutation. + func (m *Mutation) {{ $f.StructField }}() (r io.Reader, exists bool) { + v := m.{{ $f.BuilderField }} + if v == nil { + return + } + return v, true + } + + {{- if $f.Optional }} + // {{ $f.MutationClear }} clears the value of the "{{ $f.Name }}" field. + func (m *Mutation) {{ $f.MutationClear }}() { + m.{{ $f.BuilderField }} = nil + m.clearedFields[{{ $const }}] = struct{}{} + } + + // {{ $f.MutationCleared }} returns if the "{{ $f.Name }}" field was cleared in this mutation. + func (m *Mutation) {{ $f.MutationCleared }}() bool { + _, ok := m.clearedFields[{{ $const }}] + return ok + } + {{- end }} + + // {{ $f.MutationReset }} resets all changes to the "{{ $f.Name }}" field. + func (m *Mutation) {{ $f.MutationReset }}() { + m.{{ $f.BuilderField }} = nil + {{- if $f.Optional }} + delete(m.clearedFields, {{ $const }}) + {{- end }} + } +{{- end }} +{{- end }} {{ range $e := $.EdgesWithID }} {{ $op := "add" }}{{ $idsFunc := $e.MutationAdd }}{{ if $e.Unique }}{{ $op = "set" }}{{ $idsFunc = $e.MutationSet }}{{ end }} @@ -303,9 +351,11 @@ func (m *Mutation) Type() string { func (m *Mutation) Fields() []string { fields := make([]string, 0, {{ len $.Fields }}) {{- range $f := $.Fields }} + {{- if not $f.IsBlobLazy }} if m.{{ $f.BuilderField }} != nil { fields = append(fields, {{ $f.Constant }}) } + {{- end }} {{- end }} return fields } @@ -317,8 +367,10 @@ func (m *Mutation) Field(name string) (ent.Value, bool) { {{- with $.Fields }} switch name { {{- range $f := $.Fields }} + {{- if not $f.IsBlobLazy }} case {{ $f.Constant }}: return m.{{ $f.MutationGet }}() + {{- end }} {{- end }} } {{- end }} @@ -342,6 +394,7 @@ func (m *Mutation) OldField(ctx context.Context, name string) (ent.Value, error) func (m *Mutation) SetField(name string, value ent.Value) error { switch name { {{- range $f := $.Fields }} + {{- if not $f.IsBlobLazy }} {{- $type := replace $f.Type.String $pkgPrefix "" }} case {{ $f.Constant }}: v, ok := value.({{ $type }}) @@ -350,6 +403,7 @@ func (m *Mutation) SetField(name string, value ent.Value) error { } m.{{ $f.MutationSet }}(v) return nil + {{- end }} {{- end }} } return fmt.Errorf("unknown {{ $.Name }} field %s", name) @@ -444,6 +498,7 @@ func (m *Mutation) ClearField(name string) error { {{- if $.HasOptional }} switch name { {{- range $f := $.Fields }} + {{- if and $f.IsBlobLazy (not $blobSupported) }}{{ continue }}{{ end }} {{- if $f.Optional }} case {{ $f.Constant }}: m.Clear{{ $f.StructField }}() @@ -462,6 +517,7 @@ func (m *Mutation) ResetField(name string) error { {{- with $.Fields }} switch name { {{- range $f := $.Fields }} + {{- if and $f.IsBlobLazy (not $blobSupported) }}{{ continue }}{{ end }} case {{ $f.Constant }}: m.{{ $f.MutationReset }}() return nil diff --git a/entc/gen/template/runtime.tmpl b/entc/gen/template/runtime.tmpl index da84b3ac7..2fa3d4dbb 100644 --- a/entc/gen/template/runtime.tmpl +++ b/entc/gen/template/runtime.tmpl @@ -151,7 +151,7 @@ func init() { {{- end }} {{- end }} {{- end }} - {{- if or $n.HasDefault $n.HasUpdateDefault $n.HasValidators $n.HasValueScanner }} + {{- if or $n.HasDefault $n.HasUpdateDefault $n.HasValidators $n.HasValueScanner $n.HasBlobFields }} {{- with $idx := $n.MixedInFields }} {{- range $i := $idx }} {{ print $pkg "MixinFields" $i }} := {{ $pkg }}Mixin[{{ $i }}].Fields() @@ -164,6 +164,7 @@ func init() { _ = {{ $pkg }}Fields {{- end }} {{- range $i, $f := $fields }} + {{- if $f.IsBlobLazy }}{{ continue }}{{ end }} {{- $desc := print $pkg "Desc" $f.StructField }} {{- $idscan := and $n.HasOneFieldID (eq $f.Name $n.ID.Name) }} {{- /* enum default values handled near their declarations (in type package). */}} @@ -227,6 +228,22 @@ func init() { {{- end }} {{- end }} {{- end }} + {{- if hasTemplate (printf "dialect/%s/model/fields" $.Storage) }} + {{- range $i, $f := $n.BlobFields }} + {{- if $f.HasBlobKey }} + {{- $bkName := print $pkg "." $f.BlobKeyName }} + {{- $desc := print $pkg "BlobDesc" $f.StructField }} + // {{ $desc }} is the schema descriptor for {{ $f.Name }} blob field. + {{- if $f.Position.MixedIn }} + {{ $desc }} := {{ print $pkg "MixinFields" $f.Position.MixinIndex }}[{{ $f.Position.Index }}].Descriptor() + {{- else }} + {{ $desc }} := {{ $pkg }}Fields[{{ $f.Position.Index }}].Descriptor() + {{- end }} + // {{ $bkName }} generates the blob storage key for the {{ $f.Name }} field. + {{ $bkName }} = ent.BlobKeyFunc({{ $desc }}.BlobKey) + {{- end }} + {{- end }} + {{- end }} {{- end }} {{- end }} } diff --git a/entc/gen/template/where.tmpl b/entc/gen/template/where.tmpl index 866b60303..02e1ab58d 100644 --- a/entc/gen/template/where.tmpl +++ b/entc/gen/template/where.tmpl @@ -81,6 +81,7 @@ in the LICENSE file in the root directory of this source tree. {{ end }} {{ range $f := $.Fields }} + {{ if or $f.IsBlobNoColumn $f.IsBlobLazy }}{{ continue }}{{ end }} {{ $func := $f.StructField }} {{/* JSON cannot be compared using "=" and Enum has a type defined with the field name */}} {{ $hasP := not (or $f.IsJSON $f.IsEnum) }} @@ -115,6 +116,7 @@ in the LICENSE file in the root directory of this source tree. {{ end }} {{ range $f := $.Fields }} + {{ if or $f.IsBlobNoColumn $f.IsBlobLazy }}{{ continue }}{{ end }} {{ range $op := $f.Ops }} {{ $arg := "v" }}{{ if $op.Variadic }}{{ $arg = "vs" }}{{ end }} {{ $stringOp := eq $op.Name "EqualFold" "Contains" "ContainsFold" "HasPrefix" "HasSuffix" }} diff --git a/entc/gen/type.go b/entc/gen/type.go index 62d8c877c..04233d6ca 100644 --- a/entc/gen/type.go +++ b/entc/gen/type.go @@ -53,6 +53,8 @@ type ( // ForeignKeys are the foreign-keys that resides in the type table. ForeignKeys []*ForeignKey foreignKeys map[string]struct{} + // BlobKeys are the implicit key columns for blob-stored fields. + BlobKeys []*BlobKey // Annotations that were defined for the field in the schema. // The mapping is from the Annotation.Name() to a JSON decoded object. Annotations Annotations @@ -197,6 +199,14 @@ type ( // UserDefined bool } + // BlobKey holds the information for blob key columns. Similar to a foreign-key, + // it is an implicit string column that stores the storage key for a blob field. + BlobKey struct { + // Field is the implicit string column that stores the blob key. + Field *Field + // BlobField is the blob field this key belongs to. + BlobField *Field + } // Enum holds the enum information for schema enums in codegen. Enum struct { // Name is the Go name of the enum. @@ -551,7 +561,23 @@ func (t Type) NumConstraint() int { func (t Type) MutableFields() []*Field { fields := make([]*Field, 0, len(t.Fields)) for _, f := range t.Fields { - if f.Immutable { + if f.Immutable || f.IsBlobLazy() { + continue + } + if e, err := f.Edge(); err == nil && e.Immutable { + continue + } + fields = append(fields, f) + } + return fields +} + +// UpsertFields returns all mutable fields that have a corresponding SQL column. +// This excludes lazy blob fields that have no struct field or settable column. +func (t Type) UpsertFields() []*Field { + fields := make([]*Field, 0, len(t.Fields)) + for _, f := range t.Fields { + if f.Immutable || f.IsBlobLazy() { continue } if e, err := f.Edge(); err == nil && e.Immutable { @@ -566,7 +592,7 @@ func (t Type) MutableFields() []*Field { func (t Type) ImmutableFields() []*Field { fields := make([]*Field, 0, len(t.Fields)) for _, f := range t.Fields { - if f.Immutable { + if f.Immutable && !f.IsBlob() { fields = append(fields, f) } } @@ -574,10 +600,11 @@ func (t Type) ImmutableFields() []*Field { } // MutationFields returns all the fields that are available on the typed-mutation. +// Lazy blob fields are excluded since they use streaming only. func (t Type) MutationFields() []*Field { fields := make([]*Field, 0, len(t.Fields)) for _, f := range t.Fields { - if !f.IsEdgeField() { + if !f.IsEdgeField() && !f.IsBlobLazy() { fields = append(fields, f) } } @@ -665,7 +692,14 @@ func (t *Type) AddIndex(idx *load.Index) error { } else if f = t.fields[name]; f == nil { return fmt.Errorf("unknown index field %q", name) } - index.Columns = append(index.Columns, f.StorageKey()) + switch { + case f.IsBlobNoColumn() || f.IsBlobLazy(): + // Non-dual-write blob fields have no data column; translate + // the index to the generated blob key column instead. + index.Columns = append(index.Columns, f.BlobKeyColumn()) + default: + index.Columns = append(index.Columns, f.StorageKey()) + } } for _, name := range idx.Edges { var ed *Edge @@ -759,6 +793,25 @@ func (t *Type) setupFKs() error { return nil } +// setupBlobKeys creates implicit key columns for all blob-stored fields. +func (t *Type) setupBlobKeys() { + for _, f := range t.Fields { + if !f.IsBlob() { + continue + } + t.BlobKeys = append(t.BlobKeys, &BlobKey{ + Field: &Field{ + typ: t, + Name: f.BlobKeyColumn(), + Type: &field.TypeInfo{Type: field.TypeString}, + Nillable: true, + Optional: f.Optional, + }, + BlobField: f, + }) + } +} + // setupFieldEdge check the field-edge validity and configures it and its foreign-key. func (t *Type) setupFieldEdge(fk *ForeignKey, fkOwner *Edge, fkName string) error { tf, ok := t.fields[fkName] @@ -1353,6 +1406,41 @@ func (f Field) IsInt64() bool { return f.Type != nil && f.Type.Type == field.Typ // IsEnum returns true if the field is an enum field. func (f Field) IsEnum() bool { return f.Type != nil && f.Type.Type == field.TypeEnum } +// IsBlob reports whether this field is stored in external blob storage. +func (f Field) IsBlob() bool { + return f.def != nil && f.def.Info != nil && f.def.Info.Type == field.TypeBlob +} + +// IsBlobLazy reports whether this is a lazy blob field (streaming-only, no struct field). +func (f Field) IsBlobLazy() bool { + return f.IsBlob() && f.def.BlobLazy +} + +// IsBlobNoColumn reports whether this blob field has no SQL data column. +// True for all blob fields except DualWrite. +func (f Field) IsBlobNoColumn() bool { + return f.IsBlob() && !f.def.BlobDualWrite +} + +// HasBlobKey reports whether this blob field has a user-defined key function. +func (f Field) HasBlobKey() bool { + return f.def != nil && f.def.BlobKey +} + +// BlobKeyName returns the variable name of the key generator for this blob field. +func (f Field) BlobKeyName() string { return "New" + pascal(f.Name) + "Key" } + +// BlobKeyColumn returns the storage column name for the implicit blob key field. +// It is derived from the field's StorageKey (honoring any custom StorageKey setting). +func (f Field) BlobKeyColumn() string { return f.StorageKey() + "_key" } + +// IsBlobGoString reports whether this blob field uses string as its Go type. +// When true, the codegen converts between string and []byte without requiring +// a custom ValueScanner. +func (f Field) IsBlobGoString() bool { + return f.IsBlob() && f.HasGoType() && f.Type.RType.Kind == reflect.String +} + // IsEdgeField reports if the given field is an edge-field (i.e. a foreign-key) // that was referenced by one of the edges. func (f Field) IsEdgeField() bool { return f.fk != nil } @@ -1447,6 +1535,49 @@ func (t Type) DeprecatedFields() []*Field { return fs } +// BlobFields returns all blob-stored fields of the type. +func (t Type) BlobFields() []*Field { + var fs []*Field + for _, f := range t.Fields { + if f.IsBlob() { + fs = append(fs, f) + } + } + return fs +} + +// HasBlobFields reports whether the type has any blob-stored fields. +func (t Type) HasBlobFields() bool { + for _, f := range t.Fields { + if f.IsBlob() { + return true + } + } + return false +} + +// LoadOnScanFields returns non-lazy blob fields that auto-load from blob storage on scan. +// For DualWrite fields, the blob value is preferred when a key exists; otherwise the column value is used. +func (t Type) LoadOnScanFields() []*Field { + var fs []*Field + for _, f := range t.Fields { + if f.IsBlob() && !f.IsBlobLazy() { + fs = append(fs, f) + } + } + return fs +} + +// HasLoadOnScanFields reports whether the type has any blob fields that auto-load on scan. +func (t Type) HasLoadOnScanFields() bool { + for _, f := range t.Fields { + if f.IsBlob() && !f.IsBlobLazy() { + return true + } + } + return false +} + // HasValueScanner indicates if the field has (an external) ValueScanner. func (f Field) HasValueScanner() bool { return f.def != nil && f.def.ValueScanner @@ -2183,6 +2314,11 @@ func (f ForeignKey) StructField() string { return f.Field.Name } +// StructField returns the struct field name of the blob key. +func (bk BlobKey) StructField() string { + return bk.Field.Name +} + // Rel is a relation type of an edge. type Rel int diff --git a/entc/gen/type_test.go b/entc/gen/type_test.go index 86aad6b70..19225b72b 100644 --- a/entc/gen/type_test.go +++ b/entc/gen/type_test.go @@ -337,3 +337,100 @@ func TestValidSchemaName(t *testing.T) { err = ValidSchemaName("Order") require.NoError(t, err) } + +func TestField_Blob(t *testing.T) { + require := require.New(t) + + // Test creating a type with blob-stored fields. + typ, err := NewType(&Config{Package: "entc/gen"}, &load.Schema{ + Name: "Document", + Fields: []*load.Field{ + { + Name: "content", + Info: &field.TypeInfo{Type: field.TypeBlob}, + Optional: true, + Comment: "blob content", + }, + { + Name: "thumbnail", + Info: &field.TypeInfo{Type: field.TypeBlob}, + }, + { + Name: "title", + Info: &field.TypeInfo{Type: field.TypeString}, + }, + }, + }) + require.NoError(err) + require.NotNil(typ) + require.Equal("Document", typ.Name) + + // Find blob fields. + require.True(typ.HasBlobFields()) + blobFields := typ.BlobFields() + require.Len(blobFields, 2) + + // First blob field: content. + f0 := blobFields[0] + require.Equal("content", f0.Name) + require.True(f0.IsBlob()) + require.True(f0.Optional) + + // Second blob field: thumbnail. + f1 := blobFields[1] + require.Equal("thumbnail", f1.Name) + require.True(f1.IsBlob()) + + // Non-blob field should not be blob-stored. + titleField := typ.Fields[2] + require.Equal("title", titleField.Name) + require.False(titleField.IsBlob()) + + // MutationFields should exclude only lazy blob fields. + mutFields := typ.MutationFields() + for _, mf := range mutFields { + require.False(mf.IsBlobLazy(), "lazy blob field %s should not be in MutationFields", mf.Name) + } + // Non-lazy blob fields (content, thumbnail) should be in MutationFields + // because their mutation struct fields are used by the blob hook. + var blobInMut int + for _, mf := range mutFields { + if mf.IsBlob() { + blobInMut++ + } + } + require.Equal(2, blobInMut) + + // Type without blob fields. + typ2, err := NewType(&Config{Package: "entc/gen"}, &load.Schema{ + Name: "Simple", + Fields: []*load.Field{ + {Name: "name", Info: &field.TypeInfo{Type: field.TypeString}}, + }, + }) + require.NoError(err) + require.False(typ2.HasBlobFields()) + require.Empty(typ2.BlobFields()) +} + +func TestField_BlobScanType(t *testing.T) { + require := require.New(t) + + typ, err := NewType(&Config{Package: "entc/gen"}, &load.Schema{ + Name: "Doc", + Fields: []*load.Field{ + { + Name: "data", + Info: &field.TypeInfo{Type: field.TypeBlob}, + }, + }, + }) + require.NoError(err) + f := typ.Fields[0] + require.True(f.IsBlob()) + require.True(f.IsBlobNoColumn()) + require.False(f.IsBlobLazy()) + // Non-lazy blob fields appear in MutationFields (for blob hook usage), + // but are excluded from SQL columns by IsBlobNoColumn. + require.NotEmpty(typ.MutationFields()) +} diff --git a/entc/load/schema.go b/entc/load/schema.go index e11a33a1a..68b366a6a 100644 --- a/entc/load/schema.go +++ b/entc/load/schema.go @@ -63,6 +63,10 @@ type Field struct { Comment string `json:"comment,omitempty"` Deprecated bool `json:"deprecated,omitempty"` DeprecatedReason string `json:"deprecated_reason,omitempty"` + BlobKey bool `json:"blob_key,omitempty"` + BlobDualWrite bool `json:"blob_dual_write,omitempty"` + BlobDWSchemaType map[string]string `json:"blob_dw_schema_type,omitempty"` + BlobLazy bool `json:"blob_lazy,omitempty"` } // Edge represents an ent.Edge that was loaded from a complied user package. @@ -144,6 +148,10 @@ func NewField(fd *field.Descriptor) (*Field, error) { Comment: fd.Comment, Deprecated: fd.Deprecated, DeprecatedReason: fd.DeprecatedReason, + BlobKey: fd.BlobKey != nil, + BlobDualWrite: fd.BlobDualWrite, + BlobLazy: fd.BlobLazy, + BlobDWSchemaType: fd.BlobDWSchemaType, } for _, at := range fd.Annotations { sf.addAnnotation(at) diff --git a/entc/load/schema_test.go b/entc/load/schema_test.go index c297e100c..97b19c882 100644 --- a/entc/load/schema_test.go +++ b/entc/load/schema_test.go @@ -344,6 +344,52 @@ func TestMarshalDefaults(t *testing.T) { require.Equal(t, schema.Fields[8].DefaultKind, reflect.Func) } +// BlobDoc is a test schema with blob-stored fields. +type BlobDoc struct { + ent.Schema +} + +func (BlobDoc) Fields() []ent.Field { + return []ent.Field{ + field.Blob("content"). + Comment("blob content"), + field.Blob("thumbnail"), + } +} + +func (BlobDoc) Edges() []ent.Edge { return nil } + +func TestMarshalBlobSchema(t *testing.T) { + d := BlobDoc{} + buf, err := MarshalSchema(d) + require.NoError(t, err) + + s := &Schema{} + err = json.Unmarshal(buf, s) + require.NoError(t, err) + + require.Equal(t, "BlobDoc", s.Name) + require.Len(t, s.Fields, 2) + + // First blob field: content. + f0 := s.Fields[0] + require.Equal(t, "content", f0.Name) + require.Equal(t, field.TypeBlob, f0.Info.Type) + require.Equal(t, "blob content", f0.Comment) + + // Second blob field: thumbnail. + f1 := s.Fields[1] + require.Equal(t, "thumbnail", f1.Name) + require.Equal(t, field.TypeBlob, f1.Info.Type) + + // Verify JSON roundtrip preserves type. + buf2, err := json.Marshal(s) + require.NoError(t, err) + s2 := &Schema{} + require.NoError(t, json.Unmarshal(buf2, s2)) + require.Equal(t, field.TypeBlob, s2.Fields[0].Info.Type) +} + type TimeMixin struct { mixin.Schema } diff --git a/schema/field/field.go b/schema/field/field.go index 541faf141..642e4fe4b 100644 --- a/schema/field/field.go +++ b/schema/field/field.go @@ -5,11 +5,14 @@ package field import ( + "context" "database/sql" "database/sql/driver" "encoding" + "encoding/hex" "errors" "fmt" + "hash" "math" "reflect" "regexp" @@ -18,6 +21,8 @@ import ( "unicode/utf8" "entgo.io/ent/schema" + + "github.com/google/uuid" ) // String returns a new Field with type string. @@ -183,6 +188,24 @@ func Other(name string, typ driver.Valuer) *otherBuilder { return ob } +// Blob returns a new Field with type blob. Blob fields store their data in +// external blob storage rather than in the database. By default, the mutation +// accepts []byte (or a custom GoType) and the entity struct holds the loaded +// value. Use Lazy() to accept an io.Reader in the mutation and omit the struct +// field from the entity; reading requires explicit use of the Reader method. +// +// field.Blob("content"). +// Lazy() +// +// field.Blob("avatar"). +// Optional() +func Blob(name string) *blobBuilder { + return &blobBuilder{&Descriptor{ + Name: name, + Info: &TypeInfo{Type: TypeBlob}, + }} +} + // stringBuilder is the builder for string fields. type stringBuilder struct { desc *Descriptor @@ -1417,28 +1440,181 @@ func (b *otherBuilder) Descriptor() *Descriptor { return b.desc } +// blobBuilder is the builder for blob fields. +type blobBuilder struct { + desc *Descriptor +} + +// Optional indicates that this field is optional on create. +// Unlike edges, fields are required by default. +func (b *blobBuilder) Optional() *blobBuilder { + b.desc.Optional = true + return b +} + +// Immutable indicates that this field cannot be updated. +func (b *blobBuilder) Immutable() *blobBuilder { + b.desc.Immutable = true + return b +} + +// Comment sets the comment of the field. +func (b *blobBuilder) Comment(c string) *blobBuilder { + b.desc.Comment = c + return b +} + +// StructTag sets the struct tag of the field. +func (b *blobBuilder) StructTag(s string) *blobBuilder { + b.desc.Tag = s + return b +} + +// StorageKey sets the storage key of the field. +func (b *blobBuilder) StorageKey(key string) *blobBuilder { + b.desc.StorageKey = key + return b +} + +// Annotations adds a list of annotations to the field object to be used by +// codegen extensions. +func (b *blobBuilder) Annotations(annotations ...schema.Annotation) *blobBuilder { + b.desc.Annotations = append(b.desc.Annotations, annotations...) + return b +} + +// Deprecated marks the field as deprecated. +func (b *blobBuilder) Deprecated(reason ...string) *blobBuilder { + b.desc.Deprecated = true + if len(reason) > 0 { + b.desc.DeprecatedReason = strings.Join(reason, " ") + } + return b +} + +// UUIDKey configures the blob field to use random UUID keys. +// Each write generates a new random UUID as the storage key. +// +// field.Blob("content").UUIDKey() +func (b *blobBuilder) UUIDKey() *blobBuilder { + b.desc.BlobKey = func(context.Context, []byte) (string, error) { + i, err := uuid.NewV7() + if err != nil { + return "", fmt.Errorf("generating uuid key: %w", err) + } + return i.String(), nil + } + return b +} + +// HashKey configures the blob field to use content-addressable keys. +// The blob data is hashed with the given hash function to produce the storage key. +// This enables deduplication: identical content always maps to the same key. +// The default key strategy (when neither UUIDKey nor HashKey is called) is HashKey(crypto.SHA256). +// +// field.Blob("content").HashKey(crypto.SHA256) +func (b *blobBuilder) HashKey(c interface{ New() hash.Hash }) *blobBuilder { + b.desc.BlobKey = func(_ context.Context, data []byte) (string, error) { + h := c.New() + h.Write(data) + return hex.EncodeToString(h.Sum(nil)), nil + } + return b +} + +// Lazy disables automatic loading of blob data into the entity struct field +// after scanning from the database. By default, blob fields auto-load their data +// from storage on scan. When Lazy is set, the field accepts an io.Reader in the +// mutation builder (which is fully read before writing to storage), and the field +// does not appear as a struct field on the entity. Reading requires explicit use +// of the generated Reader method (e.g., ContentReader). +func (b *blobBuilder) Lazy() *blobBuilder { + b.desc.BlobLazy = true + return b +} + +// DualWrite enables migration mode for the blob field. In this mode, the +// original bytes column is preserved alongside the blob key column. Writes +// go to both blob storage and the bytes column, while reads prefer blob +// storage (if a key exists) and fall back to the bytes column. +// +// The optional columnType argument overrides the default database column type +// (per dialect) to avoid schema drift when migrating from an existing column. +// For example, when migrating a JSON column to blob storage: +// +// field.Blob("payload"). +// DualWrite(map[string]string{ +// dialect.MySQL: "json", +// dialect.Postgres: "jsonb", +// dialect.SQLite: "json", +// }) +func (b *blobBuilder) DualWrite(columnType ...map[string]string) *blobBuilder { + b.desc.BlobDualWrite = true + if len(columnType) > 0 { + b.desc.BlobDWSchemaType = columnType[0] + } + return b +} + +// GoType overrides the default Go type ([]byte) with a custom one. +// For string, the conversion is handled automatically: +// +// field.Blob("description").GoType("") +// +// For other types, a ValueScanner is required: +// +// field.Blob("config").GoType(&MyConfig{}).ValueScanner(configScanner{}) +func (b *blobBuilder) GoType(typ any) *blobBuilder { + b.desc.goType(typ) + return b +} + +// ValueScanner provides a custom codec for the blob field data. +// The scanner converts between the Go type and the raw bytes stored in blob storage. +// This is required when GoType is set to a type other than []byte or string. +func (b *blobBuilder) ValueScanner(vs any) *blobBuilder { + b.desc.ValueScanner = vs + return b +} + +// Nillable indicates that this field is a nillable. +// Unlike "Optional" only fields, "Nillable" fields are pointers in the generated struct. +func (b *blobBuilder) Nillable() *blobBuilder { + b.desc.Nillable = true + return b +} + +// Descriptor implements the ent.Field interface by returning its descriptor. +func (b *blobBuilder) Descriptor() *Descriptor { + return b.desc +} + // A Descriptor for field configuration. type Descriptor struct { - Tag string // struct tag. - Size int // varchar size. - Name string // field name. - Info *TypeInfo // field type info. - ValueScanner any // custom field codec. - Unique bool // unique index of field. - Nillable bool // nillable struct field. - Optional bool // nullable field in database. - Immutable bool // create only field. - Default any // default value on create. - UpdateDefault any // default value on update. - Validators []any // validator functions. - StorageKey string // sql column or gremlin property. - Enums []struct{ N, V string } // enum values. - Sensitive bool // sensitive info string field. - SchemaType map[string]string // override the schema type. - Annotations []schema.Annotation // field annotations. - Comment string // field comment. - Deprecated bool // mark the field as deprecated. - DeprecatedReason string // deprecation reason. + Tag string // struct tag. + Size int // varchar size. + Name string // field name. + Info *TypeInfo // field type info. + ValueScanner any // custom field codec. + Unique bool // unique index of field. + Nillable bool // nillable struct field. + Optional bool // nullable field in database. + Immutable bool // create only field. + Default any // default value on create. + UpdateDefault any // default value on update. + Validators []any // validator functions. + StorageKey string // sql column or gremlin property. + Enums []struct{ N, V string } // enum values. + Sensitive bool // sensitive info string field. + SchemaType map[string]string // override the schema type. + Annotations []schema.Annotation // field annotations. + Comment string // field comment. + Deprecated bool // mark the field as deprecated. + DeprecatedReason string // deprecation reason. + BlobKey func(context.Context, []byte) (string, error) // blob key generation function. + BlobDualWrite bool // dual-write mode: write to both blob storage and bytes column. + BlobLazy bool // lazy loading: don't auto-load blob data on scan. + BlobDWSchemaType map[string]string // override the schema type for the dual-write column. Err error } diff --git a/schema/field/field_test.go b/schema/field/field_test.go index 472a894f7..12ae1f1c0 100644 --- a/schema/field/field_test.go +++ b/schema/field/field_test.go @@ -943,7 +943,7 @@ func TestTypeString(t *testing.T) { assert.Equal(t, "bool", typ.String()) typ = field.TypeInvalid assert.Equal(t, "invalid", typ.String()) - typ = 21 + typ = 22 assert.Equal(t, "invalid", typ.String()) } @@ -959,7 +959,7 @@ func TestTypeValid(t *testing.T) { assert.True(t, typ.Valid()) typ = 0 assert.False(t, typ.Valid()) - typ = 21 + typ = 22 assert.False(t, typ.Valid()) } @@ -972,7 +972,7 @@ func TestTypeConstName(t *testing.T) { assert.Equal(t, "TypeInt64", typ.ConstName()) typ = field.TypeOther assert.Equal(t, "TypeOther", typ.ConstName()) - typ = 21 + typ = 22 assert.Equal(t, "invalid", typ.ConstName()) } diff --git a/schema/field/type.go b/schema/field/type.go index 92ab6975d..e9740aa23 100644 --- a/schema/field/type.go +++ b/schema/field/type.go @@ -36,6 +36,7 @@ const ( TypeUint64 TypeFloat32 TypeFloat64 + TypeBlob endTypes ) @@ -49,7 +50,7 @@ func (t Type) String() string { // Numeric reports if the given type is a numeric type. func (t Type) Numeric() bool { - return t >= TypeInt8 && t < endTypes + return t >= TypeInt8 && t < TypeBlob } // Float reports if the given type is a float type. @@ -166,6 +167,7 @@ var ( TypeEnum: "string", TypeString: "string", TypeOther: "other", + TypeBlob: "blob", TypeInt: "int", TypeInt8: "int8", TypeInt16: "int16", @@ -186,6 +188,7 @@ var ( TypeEnum: "TypeEnum", TypeBytes: "TypeBytes", TypeOther: "TypeOther", + TypeBlob: "TypeBlob", } )