diff --git a/dialect/sql/schema/schema.go b/dialect/sql/schema/schema.go index 3a5ea05b4..4412ca30c 100644 --- a/dialect/sql/schema/schema.go +++ b/dialect/sql/schema/schema.go @@ -45,7 +45,8 @@ type Table struct { ForeignKeys []*ForeignKey Annotation *entsql.Annotation Comment string - View bool // Indicate the table is a view. + View bool // Indicate the table is a view. + Pos string // filename:line of the ent schema definition. } // NewTable returns a new table with the given name. @@ -75,6 +76,12 @@ func (t *Table) SetSchema(s string) *Table { return t } +// SetPos sets the table position. +func (t *Table) SetPos(p string) *Table { + t.Pos = p + return t +} + // AddPrimary adds a new primary key to the table. func (t *Table) AddPrimary(c *Column) *Table { c.Key = PrimaryKey @@ -631,6 +638,27 @@ func Dump(ctx context.Context, dialect, version string, tables []*Table, opts .. Comment: fmt.Sprintf("Add %q view", v.Name), }) } + for _, t := range tables { + p.Directives = append(p.Directives, fmt.Sprintf( + "-- atlas:pos %s%s[type=%s] %s", + func() string { + if t.Schema != "" { + return t.Schema + "[type=schema]." + } + return "" + }(), + func() string { + return t.Name + }(), + func() string { + if t.View { + return "view" + } + return "table" + }(), + t.Pos, + )) + } f, err := migrate.DefaultFormatter.FormatFile(p) if err != nil { return "", err diff --git a/dialect/sql/schema/schema_test.go b/dialect/sql/schema/schema_test.go index 531355d51..139fdcbfa 100644 --- a/dialect/sql/schema/schema_test.go +++ b/dialect/sql/schema/schema_test.go @@ -158,6 +158,7 @@ func TestCopyTables(t *testing.T) { func TestDump(t *testing.T) { users := &Table{ Name: "users", + Pos: "users.go:15", Columns: []*Column{ {Name: "id", Type: field.TypeInt}, {Name: "name", Type: field.TypeString}, @@ -178,6 +179,7 @@ func TestDump(t *testing.T) { users.SetAnnotation(&entsql.Annotation{Table: "Users"}) pets := &Table{ Name: "pets", + Pos: "pets.go:15", Columns: []*Column{ {Name: "id", Type: field.TypeInt}, {Name: "name", Type: field.TypeString}, @@ -199,12 +201,14 @@ func TestDump(t *testing.T) { }) petsWithoutFur := &Table{ Name: "pets_without_fur", + Pos: "pets.go:30", View: true, Columns: append(pets.Columns[:2], pets.Columns[3]), Annotation: entsql.View("SELECT id, name, owner_id FROM pets"), } petNames := &Table{ Name: "pet_names", + Pos: "pets.go:45", View: true, Columns: pets.Columns[1:1], Annotation: entsql.ViewFor(dialect.Postgres, func(s *sql.Selector) { @@ -213,8 +217,7 @@ func TestDump(t *testing.T) { } tables = []*Table{users, pets, petsWithoutFur, petNames} - my := func(length int) string { - return fmt.Sprintf(strings.ReplaceAll(`-- Add new schema named "s1" + my := fmt.Sprintf(strings.ReplaceAll(`-- Add new schema named "s1" CREATE DATABASE $s1$; -- Add new schema named "s2" CREATE DATABASE $s2$; @@ -223,7 +226,7 @@ CREATE DATABASE $s3$; -- Create "users" table CREATE TABLE $s1$.$users$ ( $id$ bigint NOT NULL, - $name$ varchar(%d) NOT NULL, + $name$ varchar(255) NOT NULL, $spouse_id$ bigint NOT NULL, PRIMARY KEY ($id$), INDEX $name$ ($name$), @@ -232,7 +235,7 @@ CREATE TABLE $s1$.$users$ ( -- Create "pets" table CREATE TABLE $s2$.$pets$ ( $id$ bigint NOT NULL, - $name$ varchar(%d) NOT NULL, + $name$ varchar(255) NOT NULL, $owner_id$ bigint NOT NULL, $owner_id$ bigint NOT NULL, UNIQUE INDEX $name$ ($name$ DESC), @@ -240,8 +243,7 @@ CREATE TABLE $s2$.$pets$ ( ) CHARSET utf8mb4 COLLATE utf8mb4_bin; -- Add "pets_without_fur" view CREATE VIEW $s3$.$pets_without_fur$ ($id$, $name$, $owner_id$) AS SELECT id, name, owner_id FROM pets; -`, "$", "`"), length, length) - } +`, "$", "`")) pg := `-- Add new schema named "s1" CREATE SCHEMA "s1"; @@ -302,9 +304,9 @@ CREATE UNIQUE INDEX $name$ ON $pets$ ($name$ DESC); CREATE VIEW $pets_without_fur$ ($id$, $name$, $owner_id$) AS SELECT id, name, owner_id FROM pets; `, "$", "`"), }, - {dialect.MySQL, "5.6", my(255)}, - {dialect.MySQL, "5.7", my(255)}, - {dialect.MySQL, "8", my(255)}, + {dialect.MySQL, "5.6", my}, + {dialect.MySQL, "5.7", my}, + {dialect.MySQL, "8", my}, {dialect.Postgres, "12", pg}, {dialect.Postgres, "13", pg}, {dialect.Postgres, "14", pg}, @@ -314,23 +316,44 @@ CREATE VIEW $pets_without_fur$ ($id$, $name$, $owner_id$) AS SELECT id, name, ow if tt.version != "" { n += ":" + tt.version } + pos := `-- atlas:pos users[type=table] users.go:15 +-- atlas:pos pets[type=table] pets.go:15 +-- atlas:pos pets_without_fur[type=view] pets.go:30 +-- atlas:pos pet_names[type=view] pets.go:45 + +` if tt.dialect != dialect.SQLite { tables[0].Schema = "s1" tables[1].Schema = "s2" tables[2].Schema = "s3" tables[3].Schema = "s3" + pos = `-- atlas:pos s1[type=schema].users[type=table] users.go:15 +-- atlas:pos s2[type=schema].pets[type=table] pets.go:15 +-- atlas:pos s3[type=schema].pets_without_fur[type=view] pets.go:30 +-- atlas:pos s3[type=schema].pet_names[type=view] pets.go:45 + +` } t.Run(n, func(t *testing.T) { ac, err := Dump(context.Background(), tt.dialect, tt.version, tables) require.NoError(t, err) - require.Equal(t, tt.expected, ac) + require.Equal(t, pos+tt.expected, ac) }) t.Run(n+" single schema", func(t *testing.T) { ac, err := Dump(context.Background(), tt.dialect, tt.version, tables[0:1]) require.NoError(t, err) if tt.dialect != dialect.SQLite { - require.True(t, strings.HasPrefix(ac, "-- Add new schema named \"s1\""), strings.Split(ac, "\n")[0]) + require.Contains(t, ac, "s1[type=schema].") + require.NotContains(t, ac, "s2[type=schema].") + require.Contains(t, ac, "-- Add new schema named \"s1\"") } }) + t.Run(n+" no schema", func(t *testing.T) { + tables[0].Schema = "" + ac, err := Dump(context.Background(), tt.dialect, tt.version, tables[0:1]) + require.NoError(t, err) + require.NotContains(t, ac, "[type=schema].") + require.Contains(t, ac, "[type=table]") + }) } } diff --git a/entc/gen/graph.go b/entc/gen/graph.go index 4d6ca9d8d..5e33490ad 100644 --- a/entc/gen/graph.go +++ b/entc/gen/graph.go @@ -638,7 +638,8 @@ func (g *Graph) Tables() (all []*schema.Table, err error) { tables := make(map[string]*schema.Table) for _, n := range g.MutableNodes() { table := schema.NewTable(n.Table()). - SetComment(n.sqlComment()) + SetComment(n.sqlComment()). + SetPos(n.Pos()) if n.HasOneFieldID() { table.AddPrimary(n.ID.PK()) } @@ -728,6 +729,8 @@ func (g *Graph) Tables() (all []*schema.Table, err error) { s1, s2 := fkSymbols(e, c1, c2) all = append(all, &schema.Table{ Name: e.Rel.Table, + // Join tables get the position of the edge owner. + Pos: n.Pos(), // Search for edge annotation, or // default to edge owner annotation. Schema: func() string { diff --git a/entc/gen/graph_test.go b/entc/gen/graph_test.go index d07c505de..d161fcb84 100644 --- a/entc/gen/graph_test.go +++ b/entc/gen/graph_test.go @@ -379,6 +379,62 @@ func TestAbortDuplicateFK(t *testing.T) { require.EqualError(t, err, `duplicate foreign-key symbol "owner_id" found in tables "cars" and "pets"`) } +func TestPosition(t *testing.T) { + antFn := func(s string) map[string]any { + return map[string]any{entsql.Annotation{}.Name(): map[string]string{"schema": s}} + } + var ( + user = &load.Schema{ + Name: "User", + Pos: "user.go:1", + Edges: []*load.Edge{ + {Name: "pets", Type: "Pet"}, + {Name: "cars", Type: "Car", Through: &struct{ N, T string }{N: "car_edge", T: "CarOwner"}}, + }, + Annotations: antFn("one"), + } + pet = &load.Schema{ + Name: "Pet", + Pos: "pet.go:10", + Edges: []*load.Edge{ + {Name: "owner", Type: "User", RefName: "pets", Inverse: true}, + }, + Annotations: antFn("two"), + } + car = &load.Schema{ + Name: "Car", + Pos: "car.go:100", + Edges: []*load.Edge{ + {Name: "owners", Type: "User", RefName: "cars", Inverse: true}, + }, + Annotations: antFn("two"), + } + carOwner = &load.Schema{ + Name: "CarOwner", + Pos: "car_owner.go:1000", + Fields: []*load.Field{ + {Name: "user_id", Info: &field.TypeInfo{Type: field.TypeInt}}, + {Name: "car_id", Info: &field.TypeInfo{Type: field.TypeInt}}, + }, + Edges: []*load.Edge{ + {Name: "owner", Type: "User", Field: "user_id", Unique: true, Required: true}, + {Name: "car", Type: "User", Field: "car_id", Unique: true, Required: true}, + }, + Annotations: antFn("two"), + } + ) + g, err := NewGraph(&Config{Package: "entc/gen", Storage: drivers[0]}, user, pet, car, carOwner) + require.NoError(t, err) + ts, err := g.Tables() + require.NoError(t, err) + require.Len(t, ts, 5) + require.Equal(t, ts[0].Pos, "user.go:1") + require.Equal(t, ts[1].Pos, "pet.go:10") + require.Equal(t, ts[2].Pos, "car.go:100") + require.Equal(t, ts[3].Pos, "car_owner.go:1000") // edge schema has its own position + require.Equal(t, ts[4].Pos, "user.go:1") // user owns the pet edge -> user position +} + func TestMultiSchemaAnnotation(t *testing.T) { antFn := func(s string) map[string]any { return map[string]any{entsql.Annotation{}.Name(): map[string]string{"schema": s}} diff --git a/entc/gen/type.go b/entc/gen/type.go index 381cb0857..96ee633db 100644 --- a/entc/gen/type.go +++ b/entc/gen/type.go @@ -339,6 +339,11 @@ func (t Type) Receiver() string { return "_m" } +// Pos returns the filename:line position information of this type in the schema. +func (t Type) Pos() string { + return t.schema.Pos +} + // hasEdge returns true if this type as an edge (reverse or assoc) // with the given name. func (t Type) hasEdge(name string) bool { diff --git a/entc/load/load.go b/entc/load/load.go index f17bad2c9..0306e09ae 100644 --- a/entc/load/load.go +++ b/entc/load/load.go @@ -15,10 +15,12 @@ import ( "go/parser" "go/token" "go/types" + "maps" "os" "os/exec" "path/filepath" "reflect" + "slices" "sort" "strconv" "strings" @@ -61,7 +63,7 @@ type ( // Load loads the schemas package and build the Go plugin with this info. func (c *Config) Load() (*SchemaSpec, error) { - spec, err := c.load() + spec, pos, err := c.load() if err != nil { return nil, fmt.Errorf("entc/load: parse schema dir: %w", err) } @@ -102,6 +104,9 @@ func (c *Config) Load() (*SchemaSpec, error) { } spec.Schemas = append(spec.Schemas, schema) } + for _, s := range spec.Schemas { + s.Pos = pos[s.Name] + } return spec, nil } @@ -109,33 +114,33 @@ func (c *Config) Load() (*SchemaSpec, error) { var entInterface = reflect.TypeOf(struct{ ent.Interface }{}).Field(0).Type // load the ent/schema info. -func (c *Config) load() (*SchemaSpec, error) { +func (c *Config) load() (*SchemaSpec, map[string]string, error) { pkgs, err := packages.Load(&packages.Config{ BuildFlags: c.BuildFlags, Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedModule, }, c.Path, entInterface.PkgPath()) if err != nil { - return nil, fmt.Errorf("loading package: %w", err) + return nil, nil, fmt.Errorf("loading package: %w", err) } if len(pkgs) < 2 { // Check if the package loading failed due to Go-related // errors, such as 'missing go.sum entry'. if err := golist(c.Path, c.BuildFlags); err != nil { - return nil, err + return nil, nil, err } - return nil, fmt.Errorf("missing package information for: %s", c.Path) + return nil, nil, fmt.Errorf("missing package information for: %s", c.Path) } entPkg, pkg := pkgs[0], pkgs[1] if len(pkg.Errors) != 0 { - return nil, c.loadError(pkg.Errors[0]) + return nil, nil, c.loadError(pkg.Errors[0]) } if len(entPkg.Errors) != 0 { - return nil, entPkg.Errors[0] + return nil, nil, entPkg.Errors[0] } if pkgs[0].PkgPath != entInterface.PkgPath() { entPkg, pkg = pkgs[1], pkgs[0] } - var names []string + names := make(map[string]string) iface := entPkg.Types.Scope().Lookup(entInterface.Name()).Type().Underlying().(*types.Interface) for k, v := range pkg.TypesInfo.Defs { typ, ok := v.(*types.TypeName) @@ -144,18 +149,20 @@ func (c *Config) load() (*SchemaSpec, error) { } spec, ok := k.Obj.Decl.(*ast.TypeSpec) if !ok { - return nil, fmt.Errorf("invalid declaration %T for %s", k.Obj.Decl, k.Name) + return nil, nil, fmt.Errorf("invalid declaration %T for %s", k.Obj.Decl, k.Name) } if _, ok := spec.Type.(*ast.StructType); !ok { - return nil, fmt.Errorf("invalid spec type %T for %s", spec.Type, k.Name) + return nil, nil, fmt.Errorf("invalid spec type %T for %s", spec.Type, k.Name) } - names = append(names, k.Name) + p := pkg.Fset.Position(spec.Pos()) + names[k.Name] = fmt.Sprintf("%s:%d", p.Filename, p.Line) } if len(c.Names) == 0 { - c.Names = names + c.Names = slices.Sorted(maps.Keys(names)) + } else { + sort.Strings(c.Names) } - sort.Strings(c.Names) - return &SchemaSpec{PkgPath: pkg.PkgPath, Module: pkg.Module}, nil + return &SchemaSpec{PkgPath: pkg.PkgPath, Module: pkg.Module}, names, nil } func (c *Config) loadError(perr packages.Error) (err error) { diff --git a/entc/load/schema.go b/entc/load/schema.go index 6150c7845..e11a33a1a 100644 --- a/entc/load/schema.go +++ b/entc/load/schema.go @@ -19,6 +19,7 @@ import ( // Schema represents an ent.Schema that was loaded from a complied user package. type Schema struct { Name string `json:"name,omitempty"` + Pos string `json:"-"` View bool `json:"view,omitempty"` Config ent.Config `json:"config,omitempty"` Edges []*Edge `json:"edges,omitempty"`