dialect/sql/schema: add multi schema and view support for schema dump (#4335)

This commit is contained in:
Jannik Clausen
2025-02-21 15:16:17 +01:00
committed by GitHub
parent 727a677465
commit 8b85c83e00
27 changed files with 1254 additions and 134 deletions

View File

@@ -10,8 +10,10 @@ import (
"database/sql"
"errors"
"fmt"
"maps"
"net/url"
"reflect"
"slices"
"sort"
"strings"
@@ -527,15 +529,7 @@ func (a *Atlas) StateReader(tables ...*Table) migrate.StateReaderFunc {
a.sqlDialect = drv
}
a.setupTables(tables)
ts, err := a.tables(tables)
if err != nil {
return nil, err
}
vs, err := a.views(tables)
if err != nil {
return nil, err
}
return &schema.Realm{Schemas: []*schema.Schema{{Tables: ts, Views: vs}}}, nil
return a.realm(tables)
}
}
@@ -660,6 +654,14 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
return tx.Commit()
}
// For BC reason, we omit the schema qualifier from the migration plan.
// This is currently limiting migrations to a single schema.
// If multi-schema migrations are required, one should use Atlas' schema loader for Ent.
var noQualifierOpt = func(opts *migrate.PlanOptions) {
var noQualifier string
opts.SchemaQualifier = &noQualifier
}
// planInspect creates the current state by inspecting the connected database, computing the current state of the Ent schema
// and proceeds to diff the changes to create a migration plan.
func (a *Atlas) planInspect(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) {
@@ -688,9 +690,15 @@ func (a *Atlas) planInspect(ctx context.Context, conn dialect.ExecQuerier, name
if err != nil {
return nil, err
}
desired := realm.Schemas[0]
var desired *schema.Schema
switch {
case realm != nil && len(realm.Schemas) > 0:
desired = realm.Schemas[0]
default:
desired = &schema.Schema{}
}
desired.Name, desired.Attrs = current.Name, current.Attrs
return a.diff(ctx, name, current, desired, a.types[len(types):])
return a.diff(ctx, name, current, desired, a.types[len(types):], noQualifierOpt)
}
func (a *Atlas) planReplay(ctx context.Context, name string, tables []*Table) (*migrate.Plan, error) {
@@ -749,12 +757,7 @@ func (a *Atlas) planReplay(ctx context.Context, name string, tables []*Table) (*
}
return a.diff(ctx, name, current,
&schema.Schema{Name: current.Name, Attrs: current.Attrs, Tables: desired}, a.types[len(types):],
// For BC reason, we omit the schema qualifier from the migration scripts,
// but that is currently limiting versioned migration to a single schema.
func(opts *migrate.PlanOptions) {
var noQualifier string
opts.SchemaQualifier = &noQualifier
},
noQualifierOpt,
)
}
@@ -836,14 +839,33 @@ func (d *db) ExecContext(ctx context.Context, query string, args ...any) (sql.Re
return r, nil
}
// tables converts an Ent table slice to an atlas table slice
func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) {
// tables converts an Ent table slice to an atlas tables.
func (a *Atlas) realm(tables []*Table) (*schema.Realm, error) {
var (
sm = make(map[string]*schema.Schema)
byT = make(map[*Table]*schema.Table)
ts = make([]*schema.Table, 0, len(tables))
)
for _, et := range tables {
if _, ok := sm[et.Schema]; !ok {
sm[et.Schema] = schema.New(et.Schema)
}
s := sm[et.Schema]
if et.View {
if et.Annotation == nil || et.Annotation.ViewAs == "" && et.Annotation.ViewFor[a.dialect] == "" {
continue // defined externally
}
def := et.Annotation.ViewFor[a.dialect]
if def == "" {
def = et.Annotation.ViewAs
}
av := schema.NewView(et.Name, def)
if et.Comment != "" {
av.SetComment(et.Comment)
}
if err := a.aVColumns(et, av); err != nil {
return nil, err
}
s.AddViews(av)
continue
}
at := schema.NewTable(et.Name)
@@ -871,7 +893,7 @@ func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) {
if err := a.aIndexes(et, at); err != nil {
return nil, err
}
ts = append(ts, at)
s.AddTables(at)
byT[et] = at
}
for _, t1 := range tables {
@@ -892,7 +914,7 @@ func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) {
fk2.AddColumns(c2)
}
var refT *schema.Table
for _, t2 := range ts {
for _, t2 := range sm[fk1.RefTable.Schema].Tables {
if t2.Name == fk1.RefTable.Name {
refT = t2
break
@@ -912,31 +934,27 @@ func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) {
t2.AddForeignKeys(fk2)
}
}
return ts, nil
ss := slices.SortedFunc(maps.Values(sm), func(a, b *schema.Schema) int {
return strings.Compare(a.Name, b.Name)
})
// In case there only is one schema, do not qualify the schema name.
if len(ss) == 1 {
ss[0].Name = ""
}
return &schema.Realm{Schemas: ss}, nil
}
// tables converts an Ent table slice to an atlas table slice
func (a *Atlas) views(tables []*Table) ([]*schema.View, error) {
vs := make([]*schema.View, 0, len(tables))
for _, et := range tables {
// Not a view, or the view defined externally.
if !et.View || et.Annotation == nil || (et.Annotation.ViewAs == "" && et.Annotation.ViewFor[a.dialect] == "") {
continue
}
def := et.Annotation.ViewFor[a.dialect]
if def == "" {
def = et.Annotation.ViewAs
}
av := schema.NewView(et.Name, def)
if et.Comment != "" {
av.SetComment(et.Comment)
}
if err := a.aVColumns(et, av); err != nil {
return nil, err
}
vs = append(vs, av)
// tables converts an Ent table slice to an atlas table slice.
func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) {
r, err := a.realm(tables)
if err != nil {
return nil, err
}
return vs, nil
var ts []*schema.Table
for _, s := range r.Schemas {
ts = append(ts, s.Tables...)
}
return ts, nil
}
func (a *Atlas) aColumns(et *Table, at *schema.Table) error {

View File

@@ -625,6 +625,13 @@ func Dump(ctx context.Context, dialect, version string, tables []*Table, opts ..
if err != nil {
return "", err
}
// Since the Atlas version bundled with Ent does not support view management,
// simply spit out the definition instead of letting Atlas plan them.
var vs []*schema.View
for _, s := range r.Schemas {
vs = append(vs, s.Views...)
s.Views = nil
}
var c schema.Changes
if slices.ContainsFunc(tables, func(t *Table) bool { return t.Schema != "" }) {
c, err = d.RealmDiff(&schema.Realm{}, r)
@@ -638,6 +645,23 @@ func Dump(ctx context.Context, dialect, version string, tables []*Table, opts ..
if err != nil {
return "", err
}
for _, v := range vs {
q, _ := sql.Dialect(dialect).
CreateView(v.Name).
Schema(v.Schema.Name).
Columns(func(cols []*schema.Column) (bs []*sql.ColumnBuilder) {
for _, c := range cols {
bs = append(bs, sql.Dialect(dialect).Column(c.Name).Type(c.Type.Raw))
}
return
}(v.Columns)...).
As(sql.Raw(v.Def)).
Query()
p.Changes = append(p.Changes, &migrate.Change{
Cmd: q,
Comment: fmt.Sprintf("Add %q view", v.Name),
})
}
f, err := migrate.DefaultFormatter.FormatFile(p)
if err != nil {
return "", err

View File

@@ -7,9 +7,9 @@ package schema
import (
"context"
"fmt"
"strings"
"testing"
"ariga.io/atlas/sql/migrate"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema/field"
@@ -196,18 +196,100 @@ func TestDump(t *testing.T) {
RefColumns: users.Columns[:1],
OnDelete: SetDefault,
})
tables = []*Table{users, pets}
petsWithoutFur := &Table{
Name: "pets_without_fur",
View: true,
Columns: append(pets.Columns[:2], pets.Columns[3]),
Annotation: entsql.View("SELECT id, name, owner_id FROM pets"),
}
tables = []*Table{users, pets, petsWithoutFur}
my := func(length int) string {
return fmt.Sprintf("-- Create \"users\" table\nCREATE TABLE `users` (`id` bigint NOT NULL, `name` varchar(%d) NOT NULL, `spouse_id` bigint NOT NULL, PRIMARY KEY (`id`), INDEX `name` (`name`), FOREIGN KEY (`spouse_id`) REFERENCES `users` (`id`) ON UPDATE SET DEFAULT) CHARSET utf8mb4 COLLATE utf8mb4_bin;\n-- Create \"pets\" table\nCREATE TABLE `pets` (`id` bigint NOT NULL, `name` varchar(%d) NOT NULL, `fur_color` enum('black','white') NOT NULL, `owner_id` bigint NOT NULL, UNIQUE INDEX `name` (`name` DESC), FOREIGN KEY (`owner_id`) REFERENCES `users` (`id`) ON DELETE SET DEFAULT) CHARSET utf8mb4 COLLATE utf8mb4_bin;\n", length, length)
return fmt.Sprintf(strings.ReplaceAll(`-- Add new schema named "s1"
CREATE DATABASE $s1$;
-- Add new schema named "s2"
CREATE DATABASE $s2$;
-- Add new schema named "s3"
CREATE DATABASE $s3$;
-- Create "users" table
CREATE TABLE $s1$.$users$ (
$id$ bigint NOT NULL,
$name$ varchar(%d) NOT NULL,
$spouse_id$ bigint NOT NULL,
PRIMARY KEY ($id$),
INDEX $name$ ($name$),
FOREIGN KEY ($spouse_id$) REFERENCES $s1$.$users$ ($id$) ON UPDATE SET DEFAULT
) CHARSET utf8mb4 COLLATE utf8mb4_bin;
-- Create "pets" table
CREATE TABLE $s2$.$pets$ (
$id$ bigint NOT NULL,
$name$ varchar(%d) NOT NULL,
$owner_id$ bigint NOT NULL,
$owner_id$ bigint NOT NULL,
UNIQUE INDEX $name$ ($name$ DESC),
FOREIGN KEY ($owner_id$) REFERENCES $s1$.$users$ ($id$) ON DELETE SET DEFAULT
) CHARSET utf8mb4 COLLATE utf8mb4_bin;
-- Add "pets_without_fur" view
CREATE VIEW $s3$.$pets_without_fur$ ($id$, $name$, $owner_id$) AS SELECT id, name, owner_id FROM pets;
`, "$", "`"), length, length)
}
pg := "-- Create \"users\" table\nCREATE TABLE \"users\" (\"id\" bigint NOT NULL, \"name\" character varying NOT NULL, \"spouse_id\" bigint NOT NULL, PRIMARY KEY (\"id\"), FOREIGN KEY (\"spouse_id\") REFERENCES \"users\" (\"id\") ON UPDATE SET DEFAULT);\n-- Create index \"name\" to table: \"users\"\nCREATE INDEX \"name\" ON \"users\" (\"name\");\n-- Create \"pets\" table\nCREATE TABLE \"pets\" (\"id\" bigint NOT NULL, \"name\" character varying NOT NULL, \"fur_color\" character varying NOT NULL, \"owner_id\" bigint NOT NULL, FOREIGN KEY (\"owner_id\") REFERENCES \"users\" (\"id\") ON DELETE SET DEFAULT);\n-- Create index \"name\" to table: \"pets\"\nCREATE UNIQUE INDEX \"name\" ON \"pets\" (\"name\" DESC);\n"
pg := `-- Add new schema named "s1"
CREATE SCHEMA "s1";
-- Add new schema named "s2"
CREATE SCHEMA "s2";
-- Add new schema named "s3"
CREATE SCHEMA "s3";
-- Create "users" table
CREATE TABLE "s1"."users" (
"id" bigint NOT NULL,
"name" character varying NOT NULL,
"spouse_id" bigint NOT NULL,
PRIMARY KEY ("id"),
FOREIGN KEY ("spouse_id") REFERENCES "s1"."users" ("id") ON UPDATE SET DEFAULT
);
-- Create index "name" to table: "users"
CREATE INDEX "name" ON "s1"."users" ("name");
-- Create "pets" table
CREATE TABLE "s2"."pets" (
"id" bigint NOT NULL,
"name" character varying NOT NULL,
"owner_id" bigint NOT NULL,
"owner_id" bigint NOT NULL,
FOREIGN KEY ("owner_id") REFERENCES "s1"."users" ("id") ON DELETE SET DEFAULT
);
-- Create index "name" to table: "pets"
CREATE UNIQUE INDEX "name" ON "s2"."pets" ("name" DESC);
-- Add "pets_without_fur" view
CREATE VIEW "s3"."pets_without_fur" ("id", "name", "owner_id") AS SELECT id, name, owner_id FROM pets;
`
for _, tt := range []struct{ dialect, version, expected string }{
{
dialect.SQLite, "",
"-- Create \"users\" table\nCREATE TABLE `users` (`id` integer NOT NULL, `name` text NOT NULL, `spouse_id` integer NOT NULL, PRIMARY KEY (`id`), FOREIGN KEY (`spouse_id`) REFERENCES `users` (`id`) ON UPDATE SET DEFAULT);\n-- Create index \"name\" to table: \"users\"\nCREATE INDEX `name` ON `users` (`name`);\n-- Create \"pets\" table\nCREATE TABLE `pets` (`id` integer NOT NULL, `name` text NOT NULL, `fur_color` text NOT NULL, `owner_id` integer NOT NULL, FOREIGN KEY (`owner_id`) REFERENCES `users` (`id`) ON DELETE SET DEFAULT);\n-- Create index \"name\" to table: \"pets\"\nCREATE UNIQUE INDEX `name` ON `pets` (`name` DESC);\n",
strings.ReplaceAll(`-- Create "users" table
CREATE TABLE $users$ (
$id$ integer NOT NULL,
$name$ text NOT NULL,
$spouse_id$ integer NOT NULL,
PRIMARY KEY ($id$),
FOREIGN KEY ($spouse_id$) REFERENCES $users$ ($id$) ON UPDATE SET DEFAULT
);
-- Create index "name" to table: "users"
CREATE INDEX $name$ ON $users$ ($name$);
-- Create "pets" table
CREATE TABLE $pets$ (
$id$ integer NOT NULL,
$name$ text NOT NULL,
$owner_id$ integer NOT NULL,
$owner_id$ integer NOT NULL,
FOREIGN KEY ($owner_id$) REFERENCES $users$ ($id$) ON DELETE SET DEFAULT
);
-- Create index "name" to table: "pets"
CREATE UNIQUE INDEX $name$ ON $pets$ ($name$ DESC);
-- Add "pets_without_fur" view
CREATE VIEW $pets_without_fur$ ($id$, $name$, $owner_id$) AS SELECT id, name, owner_id FROM pets;
`, "$", "`"),
},
{dialect.MySQL, "5.6", my(191)},
{dialect.MySQL, "5.7", my(255)},
@@ -217,10 +299,17 @@ func TestDump(t *testing.T) {
{dialect.Postgres, "14", pg},
{dialect.Postgres, "15", pg},
} {
t.Run(fmt.Sprintf("%s:%s", tt.dialect, tt.version), func(t *testing.T) {
ac, err := Dump(context.Background(), tt.dialect, tt.version, tables, func(o *migrate.PlanOptions) {
o.Indent = ""
})
n := tt.dialect
if tt.version != "" {
n += ":" + tt.version
}
if tt.dialect != dialect.SQLite {
tables[0].Schema = "s1"
tables[1].Schema = "s2"
tables[2].Schema = "s3"
}
t.Run(n, func(t *testing.T) {
ac, err := Dump(context.Background(), tt.dialect, tt.version, tables)
require.NoError(t, err)
require.Equal(t, tt.expected, ac)
})