{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}

{{/* gotype: entgo.io/ent/entc/gen.Graph */}}

{{/*
enttest renders the "enttest" helper package: constructors that build a client
of the generated package for tests and, when $.SupportMigrate is true, run
schema creation right after the client is built. Failures are reported through
the TestingT value instead of being returned.
*/}}
{{ define "enttest" }}

{{ $pkg := base $.Config.Package }}

{{ with extend $ "Package" "enttest" -}}
	{{ template "header" . }}
{{ end }}

{{/* "context" is only referenced by migrateSchema, which is emitted only when migration is supported. */}}
import (
	{{- if $.SupportMigrate }}
		"context"
	{{- end }}

	"{{ $.Config.Package }}"
	// required by schema hooks.
	_ "{{ $.Config.Package }}/runtime"

	{{ if $.SupportMigrate }}
		"{{ $.Config.Package }}/migrate"
		"entgo.io/ent/dialect/sql/schema"
	{{ end }}
)

type (
	// TestingT is the interface that is shared between
	// testing.T and testing.B and used by enttest.
	TestingT interface {
		FailNow()
		Error(...any)
	}

	// Option configures client creation.
	Option func(*options)

	options struct {
		opts []{{ $pkg }}.Option
		{{- if $.SupportMigrate }}
			migrateOpts []schema.MigrateOption
		{{- end }}
	}
)

// WithOptions forwards options to client creation.
func WithOptions(opts ...{{ $pkg }}.Option) Option {
	return func(o *options) {
		o.opts = append(o.opts, opts...)
	}
}

{{- if $.SupportMigrate }}
	// WithMigrateOptions forwards options to auto migration.
	func WithMigrateOptions(opts ...schema.MigrateOption) Option {
		return func(o *options) {
			o.migrateOpts = append(o.migrateOpts, opts...)
		}
	}
{{- end }}

// newOptions applies the given options, in order, to a fresh options value.
func newOptions(opts []Option) *options {
	o := &options{}
	for _, opt := range opts {
		opt(o)
	}
	return o
}

// Open calls {{ $pkg }}.Open and auto-run migration.
func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *{{ $pkg }}.Client {
	o := newOptions(opts)
	c, err := {{ $pkg }}.Open(driverName, dataSourceName, o.opts...)
	if err != nil {
		t.Error(err)
		t.FailNow()
	}
	{{- if $.SupportMigrate }}
		migrateSchema(t, c, o)
	{{- end }}
	return c
}

// NewClient calls {{ $pkg }}.NewClient and auto-run migration.
func NewClient(t TestingT, opts ...Option) *{{ $pkg }}.Client {
	o := newOptions(opts)
	c := {{ $pkg }}.NewClient(o.opts...)
	{{- if $.SupportMigrate }}
		migrateSchema(t, c, o)
	{{- end }}
	return c
}

{{- if $.SupportMigrate }}
	// migrateSchema calls migrate.Create on c.Schema with a copy of migrate.Tables
	// and the configured migrate options. Any error is reported with t.Error
	// followed by t.FailNow.
	func migrateSchema(t TestingT, c *{{ $pkg }}.Client, o *options) {
		// Work on a copy so the package-level migrate.Tables value is not passed to Create directly.
		tables, err := schema.CopyTables(migrate.Tables)
		if err != nil {
			t.Error(err)
			t.FailNow()
		}
		if err := migrate.Create(context.Background(), c.Schema, tables, o.migrateOpts...); err != nil {
			t.Error(err)
			t.FailNow()
		}
	}
{{- end }}

{{ end }}