From 8b85c83e00b9152687d25b1f3590b0f3489e3b37 Mon Sep 17 00:00:00 2001 From: Jannik Clausen <12862103+masseelch@users.noreply.github.com> Date: Fri, 21 Feb 2025 15:16:17 +0100 Subject: [PATCH] dialect/sql/schema: add multi schema and view support for schema dump (#4335) --- cmd/internal/base/base.go | 6 +- dialect/sql/builder.go | 102 +++- dialect/sql/builder_test.go | 25 + dialect/sql/schema/atlas.go | 106 ++-- dialect/sql/schema/schema.go | 24 + dialect/sql/schema/schema_test.go | 107 +++- entc/gen/func.go | 11 + entc/gen/graph_test.go | 41 ++ entc/gen/storage.go | 10 +- entc/gen/template/migrate/schema.tmpl | 58 ++- entc/integration/multischema/ent/cleanuser.go | 89 ++++ .../multischema/ent/cleanuser/cleanuser.go | 43 ++ .../multischema/ent/cleanuser/where.go | 97 ++++ .../multischema/ent/cleanuser_query.go | 486 ++++++++++++++++++ entc/integration/multischema/ent/client.go | 54 +- entc/integration/multischema/ent/ent.go | 2 + .../multischema/ent/internal/schemaconfig.go | 1 + entc/integration/multischema/ent/mutation.go | 1 + .../multischema/ent/predicate/predicate.go | 3 + .../multischema/ent/schema/base.go | 19 + .../multischema/ent/schema/friendship.go | 9 + .../multischema/ent/schema/group.go | 9 + .../integration/multischema/ent/schema/pet.go | 2 +- .../multischema/ent/schema/user.go | 24 +- entc/integration/multischema/ent/tx.go | 5 +- .../multischema/multischema_test.go | 48 +- .../multischema/versioned/migrate/schema.go | 6 - 27 files changed, 1254 insertions(+), 134 deletions(-) create mode 100644 entc/integration/multischema/ent/cleanuser.go create mode 100644 entc/integration/multischema/ent/cleanuser/cleanuser.go create mode 100644 entc/integration/multischema/ent/cleanuser/where.go create mode 100644 entc/integration/multischema/ent/cleanuser_query.go create mode 100644 entc/integration/multischema/ent/schema/base.go diff --git a/cmd/internal/base/base.go b/cmd/internal/base/base.go index 9609cbd2f..19ce7e8f5 100644 --- a/cmd/internal/base/base.go +++ b/cmd/internal/base/base.go @@ -246,7 +246,11 @@ func SchemaCmd() *cobra.Command { if err != nil { log.Fatalln(err) } - ddl, err := schema.Dump(cmd.Context(), dlct, version, t) + v, err := g.Views() + if err != nil { + log.Fatalln(err) + } + ddl, err := schema.Dump(cmd.Context(), dlct, version, append(t, v...)) if err != nil { log.Fatalln(err) } diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 60241fb90..30503dbbe 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -240,6 +240,80 @@ func (t *TableBuilder) Query() (string, []any) { return t.String(), t.args } +// ViewBuilder is a query builder for `CREATE VIEW` statement. +type ViewBuilder struct { + Builder + schema string // view schema. + name string // view name. + exists bool // check existence. + columns []Querier // table columns. + as Querier // view query. +} + +// CreateView returns a query builder for the `CREATE VIEW` statement. +// +// t := Table("users") +// CreateView("clean_users"). +// Columns( +// Column("id").Type("int").Attr("auto_increment"), +// Column("name").Type("varchar(255)"), +// ). +// As(Select(t.C("id"), t.C("name")).From(t)) +func CreateView(name string) *ViewBuilder { return &ViewBuilder{name: name} } + +// Schema sets the database name for the view. +func (v *ViewBuilder) Schema(name string) *ViewBuilder { + v.schema = name + return v +} + +// IfNotExists appends the `IF NOT EXISTS` clause to the `CREATE VIEW` statement. +func (v *ViewBuilder) IfNotExists() *ViewBuilder { + v.exists = true + return v +} + +// Column appends the given column to the `CREATE VIEW` statement. +func (v *ViewBuilder) Column(c *ColumnBuilder) *ViewBuilder { + v.columns = append(v.columns, c) + return v +} + +// Columns appends a list of columns to the builder. +func (v *ViewBuilder) Columns(columns ...*ColumnBuilder) *ViewBuilder { + v.columns = make([]Querier, 0, len(columns)) + for i := range columns { + v.columns = append(v.columns, columns[i]) + } + return v +} + +// As sets the view definition to the builder. +func (v *ViewBuilder) As(as Querier) *ViewBuilder { + v.as = as + return v +} + +// Query returns query representation of a `CREATE VIEW` statement. +// +// CREATE VIEW [IF NOT EXISTS] name AS +// +// (view definition) +func (v *ViewBuilder) Query() (string, []any) { + v.WriteString("CREATE VIEW ") + if v.exists { + v.WriteString("IF NOT EXISTS ") + } + v.writeSchema(v.schema) + v.Ident(v.name) + if len(v.columns) > 0 { + v.Pad().Wrap(func(b *Builder) { b.JoinComma(v.columns...) }) + } + v.WriteString(" AS ") + v.Join(v.as) + return v.String(), v.args +} + // DescribeBuilder is a query builder for `DESCRIBE` statement. type DescribeBuilder struct { Builder @@ -3884,7 +3958,7 @@ func (b *Builder) isQualified(s string) bool { ident && !pg && strings.Contains(s, "`.`") // `qualifier`.`column` } -// state wraps the all methods for setting and getting +// state wraps all methods for setting and getting // update state between all queries in the query tree. type state interface { Dialect() string @@ -3930,17 +4004,33 @@ func (d *DialectBuilder) Describe(name string) *DescribeBuilder { // // Dialect(dialect.Postgres). // CreateTable("users"). -// Columns( -// Column("id").Type("int").Attr("auto_increment"), -// Column("name").Type("varchar(255)"), -// ). -// PrimaryKey("id") +// Columns( +// Column("id").Type("int").Attr("auto_increment"), +// Column("name").Type("varchar(255)"), +// ). +// PrimaryKey("id") func (d *DialectBuilder) CreateTable(name string) *TableBuilder { b := CreateTable(name) b.SetDialect(d.dialect) return b } +// CreateView creates a ViewBuilder for the configured dialect. +// +// t := Table("users") +// Dialect(dialect.Postgres). +// CreateView("users"). +// Columns( +// Column("id").Type("int"), +// Column("name").Type("varchar(255)"), +// ). +// As(Select(t.C("id"), t.C("name")).From(t)) +func (d *DialectBuilder) CreateView(name string) *ViewBuilder { + b := CreateView(name) + b.SetDialect(d.dialect) + return b +} + // AlterTable creates a TableAlter for the configured dialect. // // Dialect(dialect.Postgres). diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 4da6d8faa..c3e79d5c9 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -104,6 +104,31 @@ func TestBuilder(t *testing.T) { Reference(Reference().Table("cards").Columns("id")).OnDelete("SET NULL")), wantQuery: `CREATE TABLE IF NOT EXISTS "users"("id" serial, "card_id" int, PRIMARY KEY("id", "name"), FOREIGN KEY("card_id") REFERENCES "cards"("id") ON DELETE SET NULL)`, }, + { + input: CreateView("clean_users"). + Columns( + Column("id").Type("int"), + Column("name").Type("varchar(255)"), + ). + As(Select("id", "name").From(Table("users"))), + wantQuery: "CREATE VIEW `clean_users` (`id` int, `name` varchar(255)) AS SELECT `id`, `name` FROM `users`", + }, + { + input: Dialect(dialect.Postgres). + CreateView("clean_users"). + Columns( + Column("id").Type("int"), + Column("name").Type("varchar(255)"), + ). + As(Select("id", "name").From(Table("users"))), + wantQuery: `CREATE VIEW "clean_users" ("id" int, "name" varchar(255)) AS SELECT "id", "name" FROM "users"`, + }, + { + input: CreateView("clean_users"). + Schema("schema"). + As(Select("id", "name").From(Table("users"))), + wantQuery: "CREATE VIEW `schema`.`clean_users` AS SELECT `id`, `name` FROM `users`", + }, { input: AlterTable("users"). AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). diff --git a/dialect/sql/schema/atlas.go b/dialect/sql/schema/atlas.go index cfad1399e..01dbf3e07 100644 --- a/dialect/sql/schema/atlas.go +++ b/dialect/sql/schema/atlas.go @@ -10,8 +10,10 @@ import ( "database/sql" "errors" "fmt" + "maps" "net/url" "reflect" + "slices" "sort" "strings" @@ -527,15 +529,7 @@ func (a *Atlas) StateReader(tables ...*Table) migrate.StateReaderFunc { a.sqlDialect = drv } a.setupTables(tables) - ts, err := a.tables(tables) - if err != nil { - return nil, err - } - vs, err := a.views(tables) - if err != nil { - return nil, err - } - return &schema.Realm{Schemas: []*schema.Schema{{Tables: ts, Views: vs}}}, nil + return a.realm(tables) } } @@ -660,6 +654,14 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) { return tx.Commit() } +// For BC reason, we omit the schema qualifier from the migration plan. +// This is currently limiting migrations to a single schema. +// If multi-schema migrations are required, one should use Atlas' schema loader for Ent. +var noQualifierOpt = func(opts *migrate.PlanOptions) { + var noQualifier string + opts.SchemaQualifier = &noQualifier +} + // planInspect creates the current state by inspecting the connected database, computing the current state of the Ent schema // and proceeds to diff the changes to create a migration plan. func (a *Atlas) planInspect(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) { @@ -688,9 +690,15 @@ func (a *Atlas) planInspect(ctx context.Context, conn dialect.ExecQuerier, name if err != nil { return nil, err } - desired := realm.Schemas[0] + var desired *schema.Schema + switch { + case realm != nil && len(realm.Schemas) > 0: + desired = realm.Schemas[0] + default: + desired = &schema.Schema{} + } desired.Name, desired.Attrs = current.Name, current.Attrs - return a.diff(ctx, name, current, desired, a.types[len(types):]) + return a.diff(ctx, name, current, desired, a.types[len(types):], noQualifierOpt) } func (a *Atlas) planReplay(ctx context.Context, name string, tables []*Table) (*migrate.Plan, error) { @@ -749,12 +757,7 @@ func (a *Atlas) planReplay(ctx context.Context, name string, tables []*Table) (* } return a.diff(ctx, name, current, &schema.Schema{Name: current.Name, Attrs: current.Attrs, Tables: desired}, a.types[len(types):], - // For BC reason, we omit the schema qualifier from the migration scripts, - // but that is currently limiting versioned migration to a single schema. - func(opts *migrate.PlanOptions) { - var noQualifier string - opts.SchemaQualifier = &noQualifier - }, + noQualifierOpt, ) } @@ -836,14 +839,33 @@ func (d *db) ExecContext(ctx context.Context, query string, args ...any) (sql.Re return r, nil } -// tables converts an Ent table slice to an atlas table slice -func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) { +// tables converts an Ent table slice to an atlas tables. +func (a *Atlas) realm(tables []*Table) (*schema.Realm, error) { var ( + sm = make(map[string]*schema.Schema) byT = make(map[*Table]*schema.Table) - ts = make([]*schema.Table, 0, len(tables)) ) for _, et := range tables { + if _, ok := sm[et.Schema]; !ok { + sm[et.Schema] = schema.New(et.Schema) + } + s := sm[et.Schema] if et.View { + if et.Annotation == nil || et.Annotation.ViewAs == "" && et.Annotation.ViewFor[a.dialect] == "" { + continue // defined externally + } + def := et.Annotation.ViewFor[a.dialect] + if def == "" { + def = et.Annotation.ViewAs + } + av := schema.NewView(et.Name, def) + if et.Comment != "" { + av.SetComment(et.Comment) + } + if err := a.aVColumns(et, av); err != nil { + return nil, err + } + s.AddViews(av) continue } at := schema.NewTable(et.Name) @@ -871,7 +893,7 @@ func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) { if err := a.aIndexes(et, at); err != nil { return nil, err } - ts = append(ts, at) + s.AddTables(at) byT[et] = at } for _, t1 := range tables { @@ -892,7 +914,7 @@ func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) { fk2.AddColumns(c2) } var refT *schema.Table - for _, t2 := range ts { + for _, t2 := range sm[fk1.RefTable.Schema].Tables { if t2.Name == fk1.RefTable.Name { refT = t2 break @@ -912,31 +934,27 @@ func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) { t2.AddForeignKeys(fk2) } } - return ts, nil + ss := slices.SortedFunc(maps.Values(sm), func(a, b *schema.Schema) int { + return strings.Compare(a.Name, b.Name) + }) + // In case there only is one schema, do not qualify the schema name. + if len(ss) == 1 { + ss[0].Name = "" + } + return &schema.Realm{Schemas: ss}, nil } -// tables converts an Ent table slice to an atlas table slice -func (a *Atlas) views(tables []*Table) ([]*schema.View, error) { - vs := make([]*schema.View, 0, len(tables)) - for _, et := range tables { - // Not a view, or the view defined externally. - if !et.View || et.Annotation == nil || (et.Annotation.ViewAs == "" && et.Annotation.ViewFor[a.dialect] == "") { - continue - } - def := et.Annotation.ViewFor[a.dialect] - if def == "" { - def = et.Annotation.ViewAs - } - av := schema.NewView(et.Name, def) - if et.Comment != "" { - av.SetComment(et.Comment) - } - if err := a.aVColumns(et, av); err != nil { - return nil, err - } - vs = append(vs, av) +// tables converts an Ent table slice to an atlas table slice. +func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) { + r, err := a.realm(tables) + if err != nil { + return nil, err } - return vs, nil + var ts []*schema.Table + for _, s := range r.Schemas { + ts = append(ts, s.Tables...) + } + return ts, nil } func (a *Atlas) aColumns(et *Table, at *schema.Table) error { diff --git a/dialect/sql/schema/schema.go b/dialect/sql/schema/schema.go index 15a2d6f76..03e221281 100644 --- a/dialect/sql/schema/schema.go +++ b/dialect/sql/schema/schema.go @@ -625,6 +625,13 @@ func Dump(ctx context.Context, dialect, version string, tables []*Table, opts .. if err != nil { return "", err } + // Since the Atlas version bundled with Ent does not support view management, + // simply spit out the definition instead of letting Atlas plan them. + var vs []*schema.View + for _, s := range r.Schemas { + vs = append(vs, s.Views...) + s.Views = nil + } var c schema.Changes if slices.ContainsFunc(tables, func(t *Table) bool { return t.Schema != "" }) { c, err = d.RealmDiff(&schema.Realm{}, r) @@ -638,6 +645,23 @@ func Dump(ctx context.Context, dialect, version string, tables []*Table, opts .. if err != nil { return "", err } + for _, v := range vs { + q, _ := sql.Dialect(dialect). + CreateView(v.Name). + Schema(v.Schema.Name). + Columns(func(cols []*schema.Column) (bs []*sql.ColumnBuilder) { + for _, c := range cols { + bs = append(bs, sql.Dialect(dialect).Column(c.Name).Type(c.Type.Raw)) + } + return + }(v.Columns)...). + As(sql.Raw(v.Def)). + Query() + p.Changes = append(p.Changes, &migrate.Change{ + Cmd: q, + Comment: fmt.Sprintf("Add %q view", v.Name), + }) + } f, err := migrate.DefaultFormatter.FormatFile(p) if err != nil { return "", err diff --git a/dialect/sql/schema/schema_test.go b/dialect/sql/schema/schema_test.go index efd8c2316..b44ecef2c 100644 --- a/dialect/sql/schema/schema_test.go +++ b/dialect/sql/schema/schema_test.go @@ -7,9 +7,9 @@ package schema import ( "context" "fmt" + "strings" "testing" - "ariga.io/atlas/sql/migrate" "entgo.io/ent/dialect" "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema/field" @@ -196,18 +196,100 @@ func TestDump(t *testing.T) { RefColumns: users.Columns[:1], OnDelete: SetDefault, }) - tables = []*Table{users, pets} + petsWithoutFur := &Table{ + Name: "pets_without_fur", + View: true, + Columns: append(pets.Columns[:2], pets.Columns[3]), + Annotation: entsql.View("SELECT id, name, owner_id FROM pets"), + } + tables = []*Table{users, pets, petsWithoutFur} my := func(length int) string { - return fmt.Sprintf("-- Create \"users\" table\nCREATE TABLE `users` (`id` bigint NOT NULL, `name` varchar(%d) NOT NULL, `spouse_id` bigint NOT NULL, PRIMARY KEY (`id`), INDEX `name` (`name`), FOREIGN KEY (`spouse_id`) REFERENCES `users` (`id`) ON UPDATE SET DEFAULT) CHARSET utf8mb4 COLLATE utf8mb4_bin;\n-- Create \"pets\" table\nCREATE TABLE `pets` (`id` bigint NOT NULL, `name` varchar(%d) NOT NULL, `fur_color` enum('black','white') NOT NULL, `owner_id` bigint NOT NULL, UNIQUE INDEX `name` (`name` DESC), FOREIGN KEY (`owner_id`) REFERENCES `users` (`id`) ON DELETE SET DEFAULT) CHARSET utf8mb4 COLLATE utf8mb4_bin;\n", length, length) + return fmt.Sprintf(strings.ReplaceAll(`-- Add new schema named "s1" +CREATE DATABASE $s1$; +-- Add new schema named "s2" +CREATE DATABASE $s2$; +-- Add new schema named "s3" +CREATE DATABASE $s3$; +-- Create "users" table +CREATE TABLE $s1$.$users$ ( + $id$ bigint NOT NULL, + $name$ varchar(%d) NOT NULL, + $spouse_id$ bigint NOT NULL, + PRIMARY KEY ($id$), + INDEX $name$ ($name$), + FOREIGN KEY ($spouse_id$) REFERENCES $s1$.$users$ ($id$) ON UPDATE SET DEFAULT +) CHARSET utf8mb4 COLLATE utf8mb4_bin; +-- Create "pets" table +CREATE TABLE $s2$.$pets$ ( + $id$ bigint NOT NULL, + $name$ varchar(%d) NOT NULL, + $owner_id$ bigint NOT NULL, + $owner_id$ bigint NOT NULL, + UNIQUE INDEX $name$ ($name$ DESC), + FOREIGN KEY ($owner_id$) REFERENCES $s1$.$users$ ($id$) ON DELETE SET DEFAULT +) CHARSET utf8mb4 COLLATE utf8mb4_bin; +-- Add "pets_without_fur" view +CREATE VIEW $s3$.$pets_without_fur$ ($id$, $name$, $owner_id$) AS SELECT id, name, owner_id FROM pets; +`, "$", "`"), length, length) } - pg := "-- Create \"users\" table\nCREATE TABLE \"users\" (\"id\" bigint NOT NULL, \"name\" character varying NOT NULL, \"spouse_id\" bigint NOT NULL, PRIMARY KEY (\"id\"), FOREIGN KEY (\"spouse_id\") REFERENCES \"users\" (\"id\") ON UPDATE SET DEFAULT);\n-- Create index \"name\" to table: \"users\"\nCREATE INDEX \"name\" ON \"users\" (\"name\");\n-- Create \"pets\" table\nCREATE TABLE \"pets\" (\"id\" bigint NOT NULL, \"name\" character varying NOT NULL, \"fur_color\" character varying NOT NULL, \"owner_id\" bigint NOT NULL, FOREIGN KEY (\"owner_id\") REFERENCES \"users\" (\"id\") ON DELETE SET DEFAULT);\n-- Create index \"name\" to table: \"pets\"\nCREATE UNIQUE INDEX \"name\" ON \"pets\" (\"name\" DESC);\n" + pg := `-- Add new schema named "s1" +CREATE SCHEMA "s1"; +-- Add new schema named "s2" +CREATE SCHEMA "s2"; +-- Add new schema named "s3" +CREATE SCHEMA "s3"; +-- Create "users" table +CREATE TABLE "s1"."users" ( + "id" bigint NOT NULL, + "name" character varying NOT NULL, + "spouse_id" bigint NOT NULL, + PRIMARY KEY ("id"), + FOREIGN KEY ("spouse_id") REFERENCES "s1"."users" ("id") ON UPDATE SET DEFAULT +); +-- Create index "name" to table: "users" +CREATE INDEX "name" ON "s1"."users" ("name"); +-- Create "pets" table +CREATE TABLE "s2"."pets" ( + "id" bigint NOT NULL, + "name" character varying NOT NULL, + "owner_id" bigint NOT NULL, + "owner_id" bigint NOT NULL, + FOREIGN KEY ("owner_id") REFERENCES "s1"."users" ("id") ON DELETE SET DEFAULT +); +-- Create index "name" to table: "pets" +CREATE UNIQUE INDEX "name" ON "s2"."pets" ("name" DESC); +-- Add "pets_without_fur" view +CREATE VIEW "s3"."pets_without_fur" ("id", "name", "owner_id") AS SELECT id, name, owner_id FROM pets; +` for _, tt := range []struct{ dialect, version, expected string }{ { dialect.SQLite, "", - "-- Create \"users\" table\nCREATE TABLE `users` (`id` integer NOT NULL, `name` text NOT NULL, `spouse_id` integer NOT NULL, PRIMARY KEY (`id`), FOREIGN KEY (`spouse_id`) REFERENCES `users` (`id`) ON UPDATE SET DEFAULT);\n-- Create index \"name\" to table: \"users\"\nCREATE INDEX `name` ON `users` (`name`);\n-- Create \"pets\" table\nCREATE TABLE `pets` (`id` integer NOT NULL, `name` text NOT NULL, `fur_color` text NOT NULL, `owner_id` integer NOT NULL, FOREIGN KEY (`owner_id`) REFERENCES `users` (`id`) ON DELETE SET DEFAULT);\n-- Create index \"name\" to table: \"pets\"\nCREATE UNIQUE INDEX `name` ON `pets` (`name` DESC);\n", + strings.ReplaceAll(`-- Create "users" table +CREATE TABLE $users$ ( + $id$ integer NOT NULL, + $name$ text NOT NULL, + $spouse_id$ integer NOT NULL, + PRIMARY KEY ($id$), + FOREIGN KEY ($spouse_id$) REFERENCES $users$ ($id$) ON UPDATE SET DEFAULT +); +-- Create index "name" to table: "users" +CREATE INDEX $name$ ON $users$ ($name$); +-- Create "pets" table +CREATE TABLE $pets$ ( + $id$ integer NOT NULL, + $name$ text NOT NULL, + $owner_id$ integer NOT NULL, + $owner_id$ integer NOT NULL, + FOREIGN KEY ($owner_id$) REFERENCES $users$ ($id$) ON DELETE SET DEFAULT +); +-- Create index "name" to table: "pets" +CREATE UNIQUE INDEX $name$ ON $pets$ ($name$ DESC); +-- Add "pets_without_fur" view +CREATE VIEW $pets_without_fur$ ($id$, $name$, $owner_id$) AS SELECT id, name, owner_id FROM pets; +`, "$", "`"), }, {dialect.MySQL, "5.6", my(191)}, {dialect.MySQL, "5.7", my(255)}, @@ -217,10 +299,17 @@ func TestDump(t *testing.T) { {dialect.Postgres, "14", pg}, {dialect.Postgres, "15", pg}, } { - t.Run(fmt.Sprintf("%s:%s", tt.dialect, tt.version), func(t *testing.T) { - ac, err := Dump(context.Background(), tt.dialect, tt.version, tables, func(o *migrate.PlanOptions) { - o.Indent = "" - }) + n := tt.dialect + if tt.version != "" { + n += ":" + tt.version + } + if tt.dialect != dialect.SQLite { + tables[0].Schema = "s1" + tables[1].Schema = "s2" + tables[2].Schema = "s3" + } + t.Run(n, func(t *testing.T) { + ac, err := Dump(context.Background(), tt.dialect, tt.version, tables) require.NoError(t, err) require.Equal(t, tt.expected, ac) }) diff --git a/entc/gen/func.go b/entc/gen/func.go index 7a31087fb..345b62c97 100644 --- a/entc/gen/func.go +++ b/entc/gen/func.go @@ -74,6 +74,7 @@ var ( "slist": list[string], "fail": fail, "replace": strings.ReplaceAll, + "allZero": allZero, } rules = ruleset() acronyms = make(map[string]struct{}) @@ -540,3 +541,13 @@ func jsonString(v any) (string, error) { } return string(b), nil } + +// allZero reports whether all given values are the zero value of their type. +func allZero(v ...any) bool { + for _, x := range v { + if !reflect.ValueOf(x).IsZero() { + return false + } + } + return true +} diff --git a/entc/gen/graph_test.go b/entc/gen/graph_test.go index 57f3b7e62..d07c505de 100644 --- a/entc/gen/graph_test.go +++ b/entc/gen/graph_test.go @@ -11,6 +11,7 @@ import ( "reflect" "testing" + "entgo.io/ent/dialect/entsql" "entgo.io/ent/entc/load" "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" @@ -378,6 +379,46 @@ func TestAbortDuplicateFK(t *testing.T) { require.EqualError(t, err, `duplicate foreign-key symbol "owner_id" found in tables "cars" and "pets"`) } +func TestMultiSchemaAnnotation(t *testing.T) { + antFn := func(s string) map[string]any { + return map[string]any{entsql.Annotation{}.Name(): map[string]string{"schema": s}} + } + var ( + user = &load.Schema{ + Name: "User", + Edges: []*load.Edge{ + {Name: "pets", Type: "Pet"}, + {Name: "cars", Type: "Car", Annotations: antFn("two")}, + }, + Annotations: antFn("one"), + } + pet = &load.Schema{ + Name: "Pet", + Edges: []*load.Edge{ + {Name: "owner", Type: "User", RefName: "pets", Inverse: true}, + }, + Annotations: antFn("two"), + } + car = &load.Schema{ + Name: "Car", + Edges: []*load.Edge{ + {Name: "owner", Type: "User", RefName: "cars", Inverse: true}, + }, + Annotations: antFn("two"), + } + ) + g, err := NewGraph(&Config{Package: "entc/gen", Storage: drivers[0]}, user, pet, car) + require.NoError(t, err) + ts, err := g.Tables() + require.NoError(t, err) + require.Len(t, ts, 5) + require.Equal(t, "one", ts[0].Schema) // user + require.Equal(t, "two", ts[1].Schema) // pet + require.Equal(t, "two", ts[2].Schema) // car + require.Equal(t, "one", ts[3].Schema) // user<>pets join table user lives in owner schema + require.Equal(t, "two", ts[4].Schema) // user<>cars edge has annotation and lives in specified schema +} + func TestEnsureCorrectFK(t *testing.T) { var ( user = &load.Schema{ diff --git a/entc/gen/storage.go b/entc/gen/storage.go index e68ebc463..9a01d683a 100644 --- a/entc/gen/storage.go +++ b/entc/gen/storage.go @@ -6,8 +6,9 @@ package gen import ( "fmt" + "maps" "reflect" - "sort" + "slices" "strings" "entgo.io/ent/dialect/gremlin/graph/dsl" @@ -175,12 +176,7 @@ func (g *Graph) TableSchemas() ([]string, error) { } } } - names := make([]string, 0, len(all)) - for s := range all { - names = append(names, s) - } - sort.Strings(names) - return names, nil + return slices.Sorted(maps.Keys(all)), nil } // TableSchema returns the schema name of where the type table resides (intentionally exported). diff --git a/entc/gen/template/migrate/schema.tmpl b/entc/gen/template/migrate/schema.tmpl index 002230b37..0834f7afe 100644 --- a/entc/gen/template/migrate/schema.tmpl +++ b/entc/gen/template/migrate/schema.tmpl @@ -205,36 +205,38 @@ func init() { {{ $table }}.ForeignKeys[{{ $i }}].RefTable = {{ pascal $fk.RefTable.Name | printf "%sTable" }} {{- end }} {{- with $ant := $t.Annotation }} - {{ $table }}.Annotation = &entsql.Annotation{ - {{- with $ant.Table }} - Table: "{{ . }}", - {{- end }} - {{- with $ant.Charset }} - Charset: "{{ . }}", - {{- end }} - {{- with $ant.Collation }} - Collation: "{{ . }}", - {{- end }} - {{- with $ant.Options }} - Options: {{ quote . }}, - {{- end }} - {{- with $ant.Check }} - Check: {{ quote . }}, - {{- end }} - {{- with $ant.IncrementStart }} - IncrementStart: func(i int) *int { return &i }({{ . }}), - {{- end }} - } - {{- with $ant.Incremental }} - {{ $table }}.Annotation.Incremental = new(bool) - *{{ $table }}.Annotation.Incremental = {{ with indirect . }}true{{ else }}false{{ end }} - {{- end }} - {{- with $keys := keys $ant.Checks }} - {{ $table }}.Annotation.Checks = map[string]string{ - {{- range $k := $keys }} - "{{ $k }}": {{ index $ant.Checks $k | quote }}, + {{- if not (allZero $ant.Table $ant.Charset $ant.Collation $ant.Options $ant.Check $ant.IncrementStart $ant.Incremental $ant.Checks) }} + {{ $table }}.Annotation = &entsql.Annotation{ + {{- with $ant.Table }} + Table: "{{ . }}", + {{- end }} + {{- with $ant.Charset }} + Charset: "{{ . }}", + {{- end }} + {{- with $ant.Collation }} + Collation: "{{ . }}", + {{- end }} + {{- with $ant.Options }} + Options: {{ quote . }}, + {{- end }} + {{- with $ant.Check }} + Check: {{ quote . }}, + {{- end }} + {{- with $ant.IncrementStart }} + IncrementStart: func(i int) *int { return &i }({{ . }}), {{- end }} } + {{- with $ant.Incremental }} + {{ $table }}.Annotation.Incremental = new(bool) + *{{ $table }}.Annotation.Incremental = {{ with indirect . }}true{{ else }}false{{ end }} + {{- end }} + {{- with $keys := keys $ant.Checks }} + {{ $table }}.Annotation.Checks = map[string]string{ + {{- range $k := $keys }} + "{{ $k }}": {{ index $ant.Checks $k | quote }}, + {{- end }} + } + {{- end }} {{- end }} {{- end }} {{- end }} diff --git a/entc/integration/multischema/ent/cleanuser.go b/entc/integration/multischema/ent/cleanuser.go new file mode 100644 index 000000000..285051322 --- /dev/null +++ b/entc/integration/multischema/ent/cleanuser.go @@ -0,0 +1,89 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/entc/integration/multischema/ent/cleanuser" +) + +// CleanUser is the model entity for the CleanUser schema. +type CleanUser struct { + config `json:"-"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*CleanUser) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case cleanuser.FieldName: + values[i] = new(sql.NullString) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the CleanUser fields. +func (cu *CleanUser) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case cleanuser.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + cu.Name = value.String + } + default: + cu.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the CleanUser. +// This includes values selected through modifiers, order, etc. +func (cu *CleanUser) Value(name string) (ent.Value, error) { + return cu.selectValues.Get(name) +} + +// Unwrap unwraps the CleanUser entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (cu *CleanUser) Unwrap() *CleanUser { + _tx, ok := cu.config.driver.(*txDriver) + if !ok { + panic("ent: CleanUser is not a transactional entity") + } + cu.config.driver = _tx.drv + return cu +} + +// String implements the fmt.Stringer. +func (cu *CleanUser) String() string { + var builder strings.Builder + builder.WriteString("CleanUser(") + builder.WriteString("name=") + builder.WriteString(cu.Name) + builder.WriteByte(')') + return builder.String() +} + +// CleanUsers is a parsable slice of CleanUser. +type CleanUsers []*CleanUser diff --git a/entc/integration/multischema/ent/cleanuser/cleanuser.go b/entc/integration/multischema/ent/cleanuser/cleanuser.go new file mode 100644 index 000000000..ac8d6aeb7 --- /dev/null +++ b/entc/integration/multischema/ent/cleanuser/cleanuser.go @@ -0,0 +1,43 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated by ent, DO NOT EDIT. + +package cleanuser + +import ( + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the cleanuser type in the database. + Label = "clean_user" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // Table holds the table name of the cleanuser in the database. + Table = "clean_users" +) + +// Columns holds all SQL columns for cleanuser fields. +var Columns = []string{ + FieldName, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// OrderOption defines the ordering options for the CleanUser queries. +type OrderOption func(*sql.Selector) + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} diff --git a/entc/integration/multischema/ent/cleanuser/where.go b/entc/integration/multischema/ent/cleanuser/where.go new file mode 100644 index 000000000..aa9ac144d --- /dev/null +++ b/entc/integration/multischema/ent/cleanuser/where.go @@ -0,0 +1,97 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated by ent, DO NOT EDIT. + +package cleanuser + +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/entc/integration/multischema/ent/predicate" +) + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldEQ(FieldName, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.CleanUser { + return predicate.CleanUser(sql.FieldContainsFold(FieldName, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.CleanUser) predicate.CleanUser { + return predicate.CleanUser(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.CleanUser) predicate.CleanUser { + return predicate.CleanUser(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.CleanUser) predicate.CleanUser { + return predicate.CleanUser(sql.NotPredicates(p)) +} diff --git a/entc/integration/multischema/ent/cleanuser_query.go b/entc/integration/multischema/ent/cleanuser_query.go new file mode 100644 index 000000000..d2f3e3468 --- /dev/null +++ b/entc/integration/multischema/ent/cleanuser_query.go @@ -0,0 +1,486 @@ +// Copyright 2019-present Facebook Inc. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/multischema/ent/cleanuser" + "entgo.io/ent/entc/integration/multischema/ent/internal" + "entgo.io/ent/entc/integration/multischema/ent/predicate" +) + +// CleanUserQuery is the builder for querying CleanUser entities. +type CleanUserQuery struct { + config + ctx *QueryContext + order []cleanuser.OrderOption + inters []Interceptor + predicates []predicate.CleanUser + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the CleanUserQuery builder. +func (cuq *CleanUserQuery) Where(ps ...predicate.CleanUser) *CleanUserQuery { + cuq.predicates = append(cuq.predicates, ps...) + return cuq +} + +// Limit the number of records to be returned by this query. +func (cuq *CleanUserQuery) Limit(limit int) *CleanUserQuery { + cuq.ctx.Limit = &limit + return cuq +} + +// Offset to start from. +func (cuq *CleanUserQuery) Offset(offset int) *CleanUserQuery { + cuq.ctx.Offset = &offset + return cuq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (cuq *CleanUserQuery) Unique(unique bool) *CleanUserQuery { + cuq.ctx.Unique = &unique + return cuq +} + +// Order specifies how the records should be ordered. +func (cuq *CleanUserQuery) Order(o ...cleanuser.OrderOption) *CleanUserQuery { + cuq.order = append(cuq.order, o...) + return cuq +} + +// First returns the first CleanUser entity from the query. +// Returns a *NotFoundError when no CleanUser was found. +func (cuq *CleanUserQuery) First(ctx context.Context) (*CleanUser, error) { + nodes, err := cuq.Limit(1).All(setContextOp(ctx, cuq.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{cleanuser.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (cuq *CleanUserQuery) FirstX(ctx context.Context) *CleanUser { + node, err := cuq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// Only returns a single CleanUser entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one CleanUser entity is found. +// Returns a *NotFoundError when no CleanUser entities are found. +func (cuq *CleanUserQuery) Only(ctx context.Context) (*CleanUser, error) { + nodes, err := cuq.Limit(2).All(setContextOp(ctx, cuq.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{cleanuser.Label} + default: + return nil, &NotSingularError{cleanuser.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (cuq *CleanUserQuery) OnlyX(ctx context.Context) *CleanUser { + node, err := cuq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// All executes the query and returns a list of CleanUsers. +func (cuq *CleanUserQuery) All(ctx context.Context) ([]*CleanUser, error) { + ctx = setContextOp(ctx, cuq.ctx, ent.OpQueryAll) + if err := cuq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*CleanUser, *CleanUserQuery]() + return withInterceptors[[]*CleanUser](ctx, cuq, qr, cuq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (cuq *CleanUserQuery) AllX(ctx context.Context) []*CleanUser { + nodes, err := cuq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// Count returns the count of the given query. +func (cuq *CleanUserQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, cuq.ctx, ent.OpQueryCount) + if err := cuq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, cuq, querierCount[*CleanUserQuery](), cuq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (cuq *CleanUserQuery) CountX(ctx context.Context) int { + count, err := cuq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (cuq *CleanUserQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, cuq.ctx, ent.OpQueryExist) + switch _, err := cuq.First(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (cuq *CleanUserQuery) ExistX(ctx context.Context) bool { + exist, err := cuq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the CleanUserQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (cuq *CleanUserQuery) Clone() *CleanUserQuery { + if cuq == nil { + return nil + } + return &CleanUserQuery{ + config: cuq.config, + ctx: cuq.ctx.Clone(), + order: append([]cleanuser.OrderOption{}, cuq.order...), + inters: append([]Interceptor{}, cuq.inters...), + predicates: append([]predicate.CleanUser{}, cuq.predicates...), + // clone intermediate query. + sql: cuq.sql.Clone(), + path: cuq.path, + modifiers: append([]func(*sql.Selector){}, cuq.modifiers...), + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// Name string `json:"name,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.CleanUser.Query(). +// GroupBy(cleanuser.FieldName). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (cuq *CleanUserQuery) GroupBy(field string, fields ...string) *CleanUserGroupBy { + cuq.ctx.Fields = append([]string{field}, fields...) + grbuild := &CleanUserGroupBy{build: cuq} + grbuild.flds = &cuq.ctx.Fields + grbuild.label = cleanuser.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// Name string `json:"name,omitempty"` +// } +// +// client.CleanUser.Query(). +// Select(cleanuser.FieldName). +// Scan(ctx, &v) +func (cuq *CleanUserQuery) Select(fields ...string) *CleanUserSelect { + cuq.ctx.Fields = append(cuq.ctx.Fields, fields...) + sbuild := &CleanUserSelect{CleanUserQuery: cuq} + sbuild.label = cleanuser.Label + sbuild.flds, sbuild.scan = &cuq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a CleanUserSelect configured with the given aggregations. +func (cuq *CleanUserQuery) Aggregate(fns ...AggregateFunc) *CleanUserSelect { + return cuq.Select().Aggregate(fns...) +} + +func (cuq *CleanUserQuery) prepareQuery(ctx context.Context) error { + for _, inter := range cuq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, cuq); err != nil { + return err + } + } + } + for _, f := range cuq.ctx.Fields { + if !cleanuser.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if cuq.path != nil { + prev, err := cuq.path(ctx) + if err != nil { + return err + } + cuq.sql = prev + } + return nil +} + +func (cuq *CleanUserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*CleanUser, error) { + var ( + nodes = []*CleanUser{} + _spec = cuq.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*CleanUser).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &CleanUser{config: cuq.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + _spec.Node.Schema = cuq.schemaConfig.CleanUser + ctx = internal.NewSchemaConfigContext(ctx, cuq.schemaConfig) + if len(cuq.modifiers) > 0 { + _spec.Modifiers = cuq.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, cuq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (cuq *CleanUserQuery) sqlCount(ctx context.Context) (int, error) { + _spec := cuq.querySpec() + _spec.Node.Schema = cuq.schemaConfig.CleanUser + ctx = internal.NewSchemaConfigContext(ctx, cuq.schemaConfig) + if len(cuq.modifiers) > 0 { + _spec.Modifiers = cuq.modifiers + } + _spec.Node.Columns = cuq.ctx.Fields + if len(cuq.ctx.Fields) > 0 { + _spec.Unique = cuq.ctx.Unique != nil && *cuq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, cuq.driver, _spec) +} + +func (cuq *CleanUserQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(cleanuser.Table, cleanuser.Columns, nil) + _spec.From = cuq.sql + if unique := cuq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if cuq.path != nil { + _spec.Unique = true + } + if fields := cuq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + for i := range fields { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if ps := cuq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := cuq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := cuq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := cuq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (cuq *CleanUserQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(cuq.driver.Dialect()) + t1 := builder.Table(cleanuser.Table) + columns := cuq.ctx.Fields + if len(columns) == 0 { + columns = cleanuser.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if cuq.sql != nil { + selector = cuq.sql + selector.Select(selector.Columns(columns...)...) + } + if cuq.ctx.Unique != nil && *cuq.ctx.Unique { + selector.Distinct() + } + t1.Schema(cuq.schemaConfig.CleanUser) + ctx = internal.NewSchemaConfigContext(ctx, cuq.schemaConfig) + selector.WithContext(ctx) + for _, m := range cuq.modifiers { + m(selector) + } + for _, p := range cuq.predicates { + p(selector) + } + for _, p := range cuq.order { + p(selector) + } + if offset := cuq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := cuq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// Modify adds a query modifier for attaching custom logic to queries. +func (cuq *CleanUserQuery) Modify(modifiers ...func(s *sql.Selector)) *CleanUserSelect { + cuq.modifiers = append(cuq.modifiers, modifiers...) + return cuq.Select() +} + +// CleanUserGroupBy is the group-by builder for CleanUser entities. +type CleanUserGroupBy struct { + selector + build *CleanUserQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (cugb *CleanUserGroupBy) Aggregate(fns ...AggregateFunc) *CleanUserGroupBy { + cugb.fns = append(cugb.fns, fns...) + return cugb +} + +// Scan applies the selector query and scans the result into the given value. +func (cugb *CleanUserGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, cugb.build.ctx, ent.OpQueryGroupBy) + if err := cugb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*CleanUserQuery, *CleanUserGroupBy](ctx, cugb.build, cugb, cugb.build.inters, v) +} + +func (cugb *CleanUserGroupBy) sqlScan(ctx context.Context, root *CleanUserQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(cugb.fns)) + for _, fn := range cugb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*cugb.flds)+len(cugb.fns)) + for _, f := range *cugb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*cugb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := cugb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// CleanUserSelect is the builder for selecting fields of CleanUser entities. +type CleanUserSelect struct { + *CleanUserQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (cus *CleanUserSelect) Aggregate(fns ...AggregateFunc) *CleanUserSelect { + cus.fns = append(cus.fns, fns...) + return cus +} + +// Scan applies the selector query and scans the result into the given value. +func (cus *CleanUserSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, cus.ctx, ent.OpQuerySelect) + if err := cus.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*CleanUserQuery, *CleanUserSelect](ctx, cus.CleanUserQuery, cus, cus.inters, v) +} + +func (cus *CleanUserSelect) sqlScan(ctx context.Context, root *CleanUserQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(cus.fns)) + for _, fn := range cus.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*cus.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := cus.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// Modify adds a query modifier for attaching custom logic to queries. +func (cus *CleanUserSelect) Modify(modifiers ...func(s *sql.Selector)) *CleanUserSelect { + cus.modifiers = append(cus.modifiers, modifiers...) + return cus +} diff --git a/entc/integration/multischema/ent/client.go b/entc/integration/multischema/ent/client.go index 694a1973b..4d08e0bb3 100644 --- a/entc/integration/multischema/ent/client.go +++ b/entc/integration/multischema/ent/client.go @@ -32,6 +32,8 @@ type Client struct { config // Schema is the client for creating, migrating and dropping schema. Schema *migrate.Schema + // CleanUser is the client for interacting with the CleanUser builders. + CleanUser *CleanUserClient // Friendship is the client for interacting with the Friendship builders. Friendship *FriendshipClient // Group is the client for interacting with the Group builders. @@ -51,6 +53,7 @@ func NewClient(opts ...Option) *Client { func (c *Client) init() { c.Schema = migrate.NewSchema(c.driver) + c.CleanUser = NewCleanUserClient(c.config) c.Friendship = NewFriendshipClient(c.config) c.Group = NewGroupClient(c.config) c.Pet = NewPetClient(c.config) @@ -80,6 +83,7 @@ type ( // newConfig creates a new config for the client. func newConfig(opts ...Option) config { cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}} + cfg.schemaConfig = DefaultSchemaConfig cfg.options(opts...) return cfg } @@ -149,6 +153,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { return &Tx{ ctx: ctx, config: cfg, + CleanUser: NewCleanUserClient(cfg), Friendship: NewFriendshipClient(cfg), Group: NewGroupClient(cfg), Pet: NewPetClient(cfg), @@ -172,6 +177,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) return &Tx{ ctx: ctx, config: cfg, + CleanUser: NewCleanUserClient(cfg), Friendship: NewFriendshipClient(cfg), Group: NewGroupClient(cfg), Pet: NewPetClient(cfg), @@ -182,7 +188,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) // Debug returns a new debug-client. It's used to get verbose logging on specific operations. // // client.Debug(). -// Friendship. +// CleanUser. // Query(). // Count(ctx) func (c *Client) Debug() *Client { @@ -213,6 +219,7 @@ func (c *Client) Use(hooks ...Hook) { // Intercept adds the query interceptors to all the entity clients. // In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. func (c *Client) Intercept(interceptors ...Interceptor) { + c.CleanUser.Intercept(interceptors...) c.Friendship.Intercept(interceptors...) c.Group.Intercept(interceptors...) c.Pet.Intercept(interceptors...) @@ -235,6 +242,36 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { } } +// CleanUserClient is a client for the CleanUser schema. +type CleanUserClient struct { + config +} + +// NewCleanUserClient returns a client for the CleanUser from the given config. +func NewCleanUserClient(c config) *CleanUserClient { + return &CleanUserClient{config: c} +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `cleanuser.Intercept(f(g(h())))`. +func (c *CleanUserClient) Intercept(interceptors ...Interceptor) { + c.inters.CleanUser = append(c.inters.CleanUser, interceptors...) +} + +// Query returns a query builder for CleanUser. +func (c *CleanUserClient) Query() *CleanUserQuery { + return &CleanUserQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeCleanUser}, + inters: c.Interceptors(), + } +} + +// Interceptors returns the client interceptors. +func (c *CleanUserClient) Interceptors() []Interceptor { + return c.inters.CleanUser +} + // FriendshipClient is a client for the Friendship schema. type FriendshipClient struct { config @@ -925,7 +962,7 @@ type ( Friendship, Group, Pet, User []ent.Hook } inters struct { - Friendship, Group, Pet, User []ent.Interceptor + CleanUser, Friendship, Group, Pet, User []ent.Interceptor } ) @@ -935,6 +972,19 @@ func SchemaConfigFromContext(ctx context.Context) SchemaConfig { return internal.SchemaConfigFromContext(ctx) } +var ( + // DefaultSchemaConfig represents the default schema names for all tables as defined in ent/schema. + DefaultSchemaConfig = SchemaConfig{ + CleanUser: tableSchemas[0], + Friendship: tableSchemas[1], + Group: tableSchemas[1], + GroupUsers: tableSchemas[1], + Pet: tableSchemas[0], + User: tableSchemas[0], + } + tableSchemas = [...]string{"db1", "db2"} +) + // SchemaConfig represents alternative schema names for all tables // that can be passed at runtime. type SchemaConfig = internal.SchemaConfig diff --git a/entc/integration/multischema/ent/ent.go b/entc/integration/multischema/ent/ent.go index 171c20129..ce5b4c4b9 100644 --- a/entc/integration/multischema/ent/ent.go +++ b/entc/integration/multischema/ent/ent.go @@ -16,6 +16,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/multischema/ent/cleanuser" "entgo.io/ent/entc/integration/multischema/ent/friendship" "entgo.io/ent/entc/integration/multischema/ent/group" "entgo.io/ent/entc/integration/multischema/ent/pet" @@ -80,6 +81,7 @@ var ( func checkColumn(table, column string) error { initCheck.Do(func() { columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ + cleanuser.Table: cleanuser.ValidColumn, friendship.Table: friendship.ValidColumn, group.Table: group.ValidColumn, pet.Table: pet.ValidColumn, diff --git a/entc/integration/multischema/ent/internal/schemaconfig.go b/entc/integration/multischema/ent/internal/schemaconfig.go index 711db9535..6dc2cea8a 100644 --- a/entc/integration/multischema/ent/internal/schemaconfig.go +++ b/entc/integration/multischema/ent/internal/schemaconfig.go @@ -11,6 +11,7 @@ import "context" // SchemaConfig represents alternative schema names for all tables // that can be passed at runtime. type SchemaConfig struct { + CleanUser string // CleanUser table. Friendship string // Friendship table. Group string // Group table. GroupUsers string // Group-users->User table. diff --git a/entc/integration/multischema/ent/mutation.go b/entc/integration/multischema/ent/mutation.go index 71940ae08..d928519c6 100644 --- a/entc/integration/multischema/ent/mutation.go +++ b/entc/integration/multischema/ent/mutation.go @@ -31,6 +31,7 @@ const ( OpUpdateOne = ent.OpUpdateOne // Node types. + TypeCleanUser = "CleanUser" TypeFriendship = "Friendship" TypeGroup = "Group" TypePet = "Pet" diff --git a/entc/integration/multischema/ent/predicate/predicate.go b/entc/integration/multischema/ent/predicate/predicate.go index 62ec0bbad..de817ea25 100644 --- a/entc/integration/multischema/ent/predicate/predicate.go +++ b/entc/integration/multischema/ent/predicate/predicate.go @@ -10,6 +10,9 @@ import ( "entgo.io/ent/dialect/sql" ) +// CleanUser is the predicate function for cleanuser builders. +type CleanUser func(*sql.Selector) + // Friendship is the predicate function for friendship builders. type Friendship func(*sql.Selector) diff --git a/entc/integration/multischema/ent/schema/base.go b/entc/integration/multischema/ent/schema/base.go new file mode 100644 index 000000000..70bc2c8a9 --- /dev/null +++ b/entc/integration/multischema/ent/schema/base.go @@ -0,0 +1,19 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" +) + +// base holds the default configuration for most schemas in this package. +type base struct { + ent.Schema +} + +// Annotations of the base schema. +func (base) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Schema("db1"), + } +} diff --git a/entc/integration/multischema/ent/schema/friendship.go b/entc/integration/multischema/ent/schema/friendship.go index 0a788983c..48fdd4712 100644 --- a/entc/integration/multischema/ent/schema/friendship.go +++ b/entc/integration/multischema/ent/schema/friendship.go @@ -8,6 +8,8 @@ import ( "time" "entgo.io/ent" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "entgo.io/ent/schema/index" @@ -54,3 +56,10 @@ func (Friendship) Indexes() []ent.Index { index.Fields("created_at"), } } + +// Annotations of the Friendship. +func (Friendship) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Schema("db2"), + } +} diff --git a/entc/integration/multischema/ent/schema/group.go b/entc/integration/multischema/ent/schema/group.go index 0232d5d65..7d309cb73 100644 --- a/entc/integration/multischema/ent/schema/group.go +++ b/entc/integration/multischema/ent/schema/group.go @@ -6,6 +6,8 @@ package schema import ( "entgo.io/ent" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" ) @@ -29,3 +31,10 @@ func (Group) Edges() []ent.Edge { edge.To("users", User.Type), } } + +// Annotations of the Group. +func (Group) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Schema("db2"), + } +} diff --git a/entc/integration/multischema/ent/schema/pet.go b/entc/integration/multischema/ent/schema/pet.go index 3e5514278..81647183c 100644 --- a/entc/integration/multischema/ent/schema/pet.go +++ b/entc/integration/multischema/ent/schema/pet.go @@ -12,7 +12,7 @@ import ( // Pet holds the schema definition for the Pet entity. type Pet struct { - ent.Schema + base } // Fields of the Pet. diff --git a/entc/integration/multischema/ent/schema/user.go b/entc/integration/multischema/ent/schema/user.go index 842b7ecff..70be1781b 100644 --- a/entc/integration/multischema/ent/schema/user.go +++ b/entc/integration/multischema/ent/schema/user.go @@ -6,13 +6,15 @@ package schema import ( "entgo.io/ent" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" ) // User holds the schema definition for the User entity. type User struct { - ent.Schema + base } // Fields of the User. @@ -33,3 +35,23 @@ func (User) Edges() []ent.Edge { Through("friendships", Friendship.Type), } } + +// CleanUser represents a user without its PII field. +type CleanUser struct { + ent.View +} + +// Annotations of the CleanUser. +func (CleanUser) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.View("SELECT id, name FROM users"), + entsql.Schema("db1"), + } +} + +// Fields of the CleanUser. +func (CleanUser) Fields() []ent.Field { + return []ent.Field{ + field.String("name"), + } +} diff --git a/entc/integration/multischema/ent/tx.go b/entc/integration/multischema/ent/tx.go index 094af0255..64d4eb4b1 100644 --- a/entc/integration/multischema/ent/tx.go +++ b/entc/integration/multischema/ent/tx.go @@ -16,6 +16,8 @@ import ( // Tx is a transactional client that is created by calling Client.Tx(). type Tx struct { config + // CleanUser is the client for interacting with the CleanUser builders. + CleanUser *CleanUserClient // Friendship is the client for interacting with the Friendship builders. Friendship *FriendshipClient // Group is the client for interacting with the Group builders. @@ -155,6 +157,7 @@ func (tx *Tx) Client() *Client { } func (tx *Tx) init() { + tx.CleanUser = NewCleanUserClient(tx.config) tx.Friendship = NewFriendshipClient(tx.config) tx.Group = NewGroupClient(tx.config) tx.Pet = NewPetClient(tx.config) @@ -168,7 +171,7 @@ func (tx *Tx) init() { // of them in order to commit or rollback the transaction. // // If a closed transaction is embedded in one of the generated entities, and the entity -// applies a query, for example: Friendship.QueryXXX(), the query will be executed +// applies a query, for example: CleanUser.QueryXXX(), the query will be executed // through the driver which created this transaction. // // Note that txDriver is not goroutine safe. diff --git a/entc/integration/multischema/multischema_test.go b/entc/integration/multischema/multischema_test.go index 1eaf56ac5..c6fbdb3a7 100644 --- a/entc/integration/multischema/multischema_test.go +++ b/entc/integration/multischema/multischema_test.go @@ -12,11 +12,10 @@ import ( "testing" "ariga.io/atlas-go-sdk/atlasexec" - atlas "ariga.io/atlas/sql/schema" + "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/schema" "entgo.io/ent/entc/integration/multischema/ent" - "entgo.io/ent/entc/integration/multischema/ent/friendship" "entgo.io/ent/entc/integration/multischema/ent/group" "entgo.io/ent/entc/integration/multischema/ent/migrate" "entgo.io/ent/entc/integration/multischema/ent/pet" @@ -31,16 +30,27 @@ import ( ) func TestMySQL(t *testing.T) { - db, err := sql.Open("mysql", "root:pass@tcp(localhost:3308)/?parseTime=true") + db, err := sql.Open("mysql", "root:pass@tcp(localhost:3308)/?parseTime=true&multiStatements=true") require.NoError(t, err) - defer db.Close() ctx := context.Background() - _, err = db.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS db1") - require.NoError(t, err, "creating database") - _, err = db.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS db2") - require.NoError(t, err, "creating database") - defer db.ExecContext(ctx, "DROP DATABASE IF EXISTS db1") - defer db.ExecContext(ctx, "DROP DATABASE IF EXISTS db2") + t.Cleanup(func() { + db.ExecContext(ctx, "SET foreign_key_checks = 0") + db.ExecContext(ctx, "DROP DATABASE IF EXISTS db1") + db.ExecContext(ctx, "DROP DATABASE IF EXISTS db2") + db.ExecContext(ctx, "SET foreign_key_checks = 1") + db.Close() + }) + + migrate.PetsTable.Schema = "db1" + migrate.UsersTable.Schema = "db1" + migrate.GroupsTable.Schema = "db2" + migrate.GroupUsersTable.Schema = "db2" + migrate.FriendshipsTable.Schema = "db2" + + pl, err := schema.Dump(ctx, dialect.MySQL, "8.0.19", migrate.Tables) + require.NoError(t, err) + _, err = db.ExecContext(ctx, pl) + require.NoError(t, err) // Default schema for the connection is db1. db1, err := sql.Open("mysql", "root:pass@tcp(localhost:3308)/db1?parseTime=true") @@ -63,7 +73,6 @@ func TestMySQL(t *testing.T) { Friendship: "db2", } client := ent.NewClient(ent.Driver(db1), ent.AlternateSchema(cfg)) - setupSchema(t, client, cfg) pedro := client.Pet.Create().SetName("Pedro").SaveX(ctx) groups := client.Group.CreateBulk( client.Group.Create().SetName("GitHub"), @@ -255,20 +264,3 @@ func TestVersionedMigration(t *testing.T) { require.Len(t, users[0].Edges.Friendships, 1) require.Len(t, users[1].Edges.Friendships, 1) } - -func setupSchema(t *testing.T, client *ent.Client, cfg ent.SchemaConfig) { - err := client.Schema.Create( - context.Background(), - migrate.WithForeignKeys(false), - schema.WithDiffHook(func(next schema.Differ) schema.Differ { - return schema.DiffFunc(func(current, desired *atlas.Schema) ([]atlas.Change, error) { - for tt, s := range map[string]string{group.Table: cfg.Group, group.UsersTable: cfg.GroupUsers, friendship.Table: cfg.Friendship} { - t1, ok := desired.Table(tt) - require.True(t, ok) - t1.SetSchema(atlas.New(s)) - } - return next.Diff(current, desired) - }) - })) - require.NoError(t, err) -} diff --git a/entc/integration/multischema/versioned/migrate/schema.go b/entc/integration/multischema/versioned/migrate/schema.go index caf905034..aac06ffa3 100644 --- a/entc/integration/multischema/versioned/migrate/schema.go +++ b/entc/integration/multischema/versioned/migrate/schema.go @@ -7,7 +7,6 @@ package migrate import ( - "entgo.io/ent/dialect/entsql" "entgo.io/ent/dialect/sql/schema" "entgo.io/ent/schema/field" ) @@ -159,14 +158,9 @@ var ( func init() { FriendshipsTable.ForeignKeys[0].RefTable = UsersTable FriendshipsTable.ForeignKeys[1].RefTable = UsersTable - FriendshipsTable.Annotation = &entsql.Annotation{} - GroupsTable.Annotation = &entsql.Annotation{} PetsTable.ForeignKeys[0].RefTable = UsersTable - PetsTable.Annotation = &entsql.Annotation{} - UsersTable.Annotation = &entsql.Annotation{} GroupUsersTable.ForeignKeys[0].RefTable = GroupsTable GroupUsersTable.ForeignKeys[1].RefTable = UsersTable UserFollowingTable.ForeignKeys[0].RefTable = UsersTable UserFollowingTable.ForeignKeys[1].RefTable = UsersTable - UserFollowingTable.Annotation = &entsql.Annotation{} }