dialect/sql/schema: make WriteDriver friendlier (#3119)

Also, add a guide for writing and executing data migrations files.
This commit is contained in:
Ariel Mashraki
2022-11-27 13:27:15 +02:00
committed by GitHub
parent 2840921231
commit f7109f0274
20 changed files with 936 additions and 45 deletions

View File

@@ -5,44 +5,252 @@
package schema
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strconv"
"strings"
"time"
"unicode"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"ariga.io/atlas/sql/migrate"
)
// WriteDriver is a driver that writes all driver exec operations to its writer.
type WriteDriver struct {
dialect.Driver // underlying driver.
io.Writer // target for exec statements.
type (
// WriteDriver is a driver that writes all driver exec operations to its writer.
// Note that this driver is used only for printing or writing statements to SQL
// files, and may require manual changes to the generated SQL statements.
WriteDriver struct {
dialect.Driver // optional driver for query calls.
io.Writer // target for exec statements.
FormatFunc func(string) (string, error)
}
// DirWriter implements the io.Writer interface
// for writing to an Atlas managed directory.
DirWriter struct {
Dir migrate.Dir // target directory.
Formatter migrate.Formatter // optional formatter.
b bytes.Buffer // working buffer.
changes []*migrate.Change // changes to flush.
}
)
// Write implements the io.Writer interface by buffering p until the
// next Change or Flush call consumes it.
func (d *DirWriter) Write(p []byte) (int, error) {
	return d.b.Write(p)
}
// Change converts all statements written so far into a migration
// change documented with the given comment, and resets the buffer
// for the next change.
func (d *DirWriter) Change(comment string) {
	// The formatter re-appends the statement terminator and newline,
	// so both are stripped from the buffered statement first.
	stmt := strings.TrimRight(d.b.String(), ";\n")
	d.changes = append(d.changes, &migrate.Change{Comment: comment, Cmd: stmt})
	d.b.Reset()
}
// Exec writes its query and calls the underlying driver Exec method.
func (w *WriteDriver) Exec(_ context.Context, query string, _, _ any) error {
// Flush writes all recorded changes as a single migration plan named
// `name` to the target directory. It fails if the working buffer holds
// statements that were not documented with Change, or if there are no
// changes to write.
func (d *DirWriter) Flush(name string) error {
	if d.b.Len() != 0 {
		return fmt.Errorf("writer has undocumented change. Use Change or FlushChange instead")
	}
	if len(d.changes) == 0 {
		return errors.New("writer has no changes to flush")
	}
	plan := &migrate.Plan{
		Name:    name,
		Changes: d.changes,
	}
	return migrate.NewPlanner(nil, d.Dir, migrate.PlanFormat(d.Formatter)).WritePlan(plan)
}
// FlushChange combines Change and Flush: it documents all buffered
// statements with the given comment, then writes the plan named
// `name` to the directory.
func (d *DirWriter) FlushChange(name, comment string) error {
	d.Change(comment)
	return d.Flush(name)
}
// NewWriteDriver creates a dialect.Driver that writes all driver exec
// statements to w instead of executing them. The returned driver
// reports the given dialect through an internal no-op driver.
func NewWriteDriver(dialect string, w io.Writer) *WriteDriver {
	drv := &WriteDriver{
		Driver: nopDriver{dialect: dialect},
		Writer: w,
	}
	return drv
}
// Exec implements the dialect.Driver.Exec method.
//
// Instead of executing the statement, it writes it — terminated by a
// semicolon — to the underlying writer. Bound arguments, if given, are
// inlined into the statement text, and res, when it is a *sql.Result,
// is set to a zero result since nothing was actually executed.
func (w *WriteDriver) Exec(_ context.Context, query string, args, res any) error {
	if rr, ok := res.(*sql.Result); ok {
		*rr = noResult{}
	}
	if !strings.HasSuffix(query, ";") {
		query += ";"
	}
	if args != nil {
		// Do not shadow args with the assertion result: on failure the
		// shadowed value would be a zero []any and %T would report
		// "[]interface {}" instead of the actual argument type.
		vs, ok := args.([]any)
		if !ok {
			return fmt.Errorf("unexpected args type: %T", args)
		}
		query = w.expandArgs(query, vs)
	}
	_, err := io.WriteString(w, query+"\n")
	return err
}
// Query implements the dialect.Driver.Query method.
//
// INSERT/UPDATE statements are written like Exec calls and res, when it
// is a *sql.Rows, is set to an empty row set. Any other query is
// delegated to the underlying driver, or rejected when only the
// internal no-op driver is configured.
func (w *WriteDriver) Query(ctx context.Context, query string, args, res any) error {
	if strings.HasPrefix(query, "INSERT") || strings.HasPrefix(query, "UPDATE") {
		if err := w.Exec(ctx, query, args, nil); err != nil {
			return err
		}
		if rr, ok := res.(*sql.Rows); ok {
			*rr = sql.Rows{ColumnScanner: noRows{}}
		}
		// The statement was written successfully; without this return
		// the function would fall through and report the written
		// statement as an unsupported query.
		return nil
	}
	switch w.Driver.(type) {
	case nil, nopDriver:
		return errors.New("query is not supported by the WriteDriver")
	default:
		return w.Driver.Query(ctx, query, args, res)
	}
}
// expandArgs combines the arguments and the statement into a single
// statement to print or write into a file (before editing).
// Note, the output may be incorrect or unsafe SQL and require manual changes.
func (w *WriteDriver) expandArgs(query string, args []any) string {
	var (
		b    strings.Builder
		p    = w.placeholder()        // placeholder marker byte for the dialect ('?' or '$').
		scan = w.scanPlaceholder()    // resolves the argument index following a marker.
	)
	for i := 0; i < len(query); i++ {
	Top:
		switch query[i] {
		case p:
			idx, size := scan(query[i+1:])
			// Unrecognized placeholder, or index out of range of the
			// given arguments: bail out and return the query untouched.
			if idx < 0 || idx >= len(args) {
				return query
			}
			// Skip the digits consumed by the scanner (Postgres only;
			// size is 0 for positional '?' placeholders).
			i += size
			v, err := w.formatArg(args[idx])
			if err != nil {
				// Unexpected formatting error.
				return query
			}
			b.WriteString(v)
		// String or identifier: copy the quoted region verbatim so a
		// placeholder-like byte inside quotes is not expanded.
		case '\'', '"', '`':
			for j := i + 1; j < len(query); j++ {
				switch query[j] {
				case '\\':
					// Skip the escaped byte following a backslash.
					j++
				case query[i]:
					// Matching closing quote found; emit the literal.
					b.WriteString(query[i : j+1])
					i = j
					break Top
				}
			}
			// Unexpected EOS: unterminated quote, return unmodified.
			return query
		default:
			b.WriteByte(query[i])
		}
	}
	return b.String()
}
// scanPlaceholder returns a function that, given the text following a
// placeholder marker, reports the zero-based argument index and the
// number of bytes consumed from the input.
func (w *WriteDriver) scanPlaceholder() func(string) (int, int) {
	if w.Dialect() != dialect.Postgres {
		// Non-Postgres dialects use positional placeholders: each
		// occurrence simply maps to the next argument in order.
		next := -1
		return func(string) (int, int) {
			next++
			return next, 0
		}
	}
	// Postgres placeholders carry an explicit 1-based index ($1, $2, ...).
	return func(s string) (int, int) {
		var n int
		for n < len(s) && unicode.IsDigit(rune(s[n])) {
			n++
		}
		pos, err := strconv.ParseInt(s[:n], 10, 64)
		if err != nil {
			return -1, 0
		}
		// Convert the 1-based placeholder to a 0-based argument index.
		return int(pos) - 1, n
	}
}
// placeholder returns the parameter marker byte used by the
// configured dialect.
func (w *WriteDriver) placeholder() byte {
	switch w.Dialect() {
	case dialect.Postgres:
		return '$'
	default:
		return '?'
	}
}
// formatArg renders a single bound argument as SQL text. If a custom
// FormatFunc is configured, it takes precedence over the built-in
// rendering. Values that cannot be rendered portably (binary, time)
// are emitted as template placeholders for manual editing.
func (w *WriteDriver) formatArg(v any) (string, error) {
	if w.FormatFunc != nil {
		return w.FormatFunc(fmt.Sprint(v))
	}
	// quote single-quotes a string literal, doubling embedded quotes.
	quote := func(s string) string {
		return "'" + strings.ReplaceAll(s, "'", "''") + "'"
	}
	switch v := v.(type) {
	case nil:
		return "NULL", nil
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return fmt.Sprintf("%d", v), nil
	case float32, float64:
		return fmt.Sprintf("%g", v), nil
	case bool:
		if v {
			return "1", nil
		}
		return "0", nil
	case string:
		return quote(v), nil
	case json.RawMessage:
		return quote(string(v)), nil
	case []byte:
		return "{{ BINARY_VALUE }}", nil
	case time.Time:
		return "{{ TIME_VALUE }}", nil
	default:
		return "{{ VALUE }}", nil
	}
}
// Tx implements the dialect.Driver.Tx method by returning a no-op
// transaction wrapping the driver itself. No BEGIN/COMMIT markers are
// written, since the generated file holds plain statements (the test
// below asserts the first written line is the statement itself).
// NOTE(review): the original span contained two return statements —
// interleaved old/new versions — leaving the second unreachable; only
// the dialect.NopTx form is kept.
func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) {
	return dialect.NopTx(w), nil
}
// Commit writes the transaction commit.
func (w *WriteDriver) Commit() error {
_, err := io.WriteString(w, "COMMIT;\n")
return err
// noResult represents a zero result. It is assigned to *sql.Result
// targets by Exec, since no statement is actually executed.
type noResult struct{}

func (noResult) LastInsertId() (int64, error) { return 0, nil }
func (noResult) RowsAffected() (int64, error) { return 0, nil }
// noRows represents no rows. It is assigned to *sql.Rows targets by
// Query when an INSERT/UPDATE statement is written instead of executed.
type noRows struct{ sql.ColumnScanner }

func (noRows) Close() error { return nil }
func (noRows) Err() error   { return nil }
func (noRows) Next() bool   { return false }
// nopDriver is the default underlying driver used by NewWriteDriver.
// It carries the configured dialect name and discards queries.
type nopDriver struct {
	dialect.Driver
	dialect string // dialect name reported by Dialect().
}
// Rollback writes the transaction rollback.
func (w *WriteDriver) Rollback() error {
_, err := io.WriteString(w, "ROLLBACK;\n")
return err
// Dialect returns the dialect name the no-op driver was created with.
func (d nopDriver) Dialect() string { return d.dialect }

// Query implements dialect.Driver.Query as a no-op.
func (nopDriver) Query(context.Context, string, any, any) error {
	return nil
}

View File

@@ -7,47 +7,82 @@ package schema
import (
"bytes"
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
"ariga.io/atlas/sql/migrate"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqljson"
"github.com/stretchr/testify/require"
)
func TestWriteDriver(t *testing.T) {
b := &bytes.Buffer{}
w := WriteDriver{Driver: nopDriver{}, Writer: b}
w := NewWriteDriver(dialect.MySQL, b)
ctx := context.Background()
tx, err := w.Tx(ctx)
require.NoError(t, err)
err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil)
require.NoError(t, err)
err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil)
require.NoError(t, err)
require.EqualError(t, err, "query is not supported by the WriteDriver")
err = tx.Exec(ctx, "ALTER TABLE `users` ADD COLUMN `age` int", nil, nil)
require.NoError(t, err)
err = tx.Exec(ctx, "ALTER TABLE `users` ADD COLUMN `NAME` varchar(100);", nil, nil)
require.NoError(t, err)
err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil)
require.NoError(t, err)
require.NoError(t, tx.Commit())
lines := strings.Split(b.String(), "\n")
require.Equal(t, "BEGIN;", lines[0])
require.Equal(t, "ALTER TABLE `users` ADD COLUMN `age` int;", lines[1])
require.Equal(t, "ALTER TABLE `users` ADD COLUMN `NAME` varchar(100);", lines[2])
require.Equal(t, "COMMIT;", lines[3])
require.Empty(t, lines[4], "file ends with blank line")
require.Len(t, lines, 3)
require.Equal(t, "ALTER TABLE `users` ADD COLUMN `age` int;", lines[0])
require.Equal(t, "ALTER TABLE `users` ADD COLUMN `NAME` varchar(100);", lines[1])
require.Empty(t, lines[2], "file ends with blank line")
b.Reset()
query, args := sql.Update("users").Schema("test").Set("a", 1).Set("b", "a").Set("c", "'c'").Set("d", true).Where(sql.EQ("p", 0.2)).Query()
err = w.Exec(ctx, query, args, nil)
require.NoError(t, err)
require.Equal(t, "UPDATE `test`.`users` SET `a` = 1, `b` = 'a', `c` = '''c''', `d` = 1 WHERE `p` = 0.2;\n", b.String())
b.Reset()
query, args = sql.Dialect(dialect.MySQL).Update("users").Schema("test").Set("a", "{}").Where(sqljson.ValueIsNull("a")).Query()
err = w.Exec(ctx, query, args, nil)
require.NoError(t, err)
require.Equal(t, "UPDATE `test`.`users` SET `a` = '{}' WHERE JSON_CONTAINS(`a`, 'null', '$');\n", b.String())
b.Reset()
w = NewWriteDriver(dialect.Postgres, b)
query, args = sql.Dialect(dialect.Postgres).Update("users").Set("a", 1).Set("b", time.Now()).Query()
err = w.Exec(ctx, query, args, nil)
require.NoError(t, err)
require.Equal(t, `UPDATE "users" SET "a" = 1, "b" = {{ TIME_VALUE }};`+"\n", b.String())
b.Reset()
err = w.Exec(ctx, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id`, nil, nil)
require.NoError(t, err)
require.Equal(t, `INSERT INTO "users" (name) VALUES("a8m") RETURNING id;`+"\n", b.String())
}
type nopDriver struct {
dialect.Driver
}
func (nopDriver) Exec(context.Context, string, any, any) error {
return nil
}
func (nopDriver) Query(context.Context, string, any, any) error {
return nil
// TestDirWriter verifies that statements recorded through a
// WriteDriver backed by a DirWriter are flushed to an Atlas migration
// directory as a commented SQL file plus its atlas.sum integrity file.
func TestDirWriter(t *testing.T) {
	p := t.TempDir()
	dir, err := migrate.NewLocalDir(p)
	require.NoError(t, err)
	w := &DirWriter{Dir: dir}
	drv := NewWriteDriver(dialect.MySQL, w)
	// Each Exec buffers a statement; Change documents it with a comment.
	require.NoError(t, drv.Exec(context.Background(), "UPDATE `test`.`users` SET `a` = ?", []any{1}, nil))
	w.Change("Comment 1.")
	require.NoError(t, drv.Exec(context.Background(), "UPDATE `test`.`users` SET `b` = ?", []any{2}, nil))
	w.Change("Comment 2.")
	require.NoError(t, w.Flush("migration_file"))
	// Flush produces the migration file and the atlas.sum file.
	files, err := os.ReadDir(p)
	require.NoError(t, err)
	require.Len(t, files, 2)
	require.Contains(t, files[0].Name(), "_migration_file.sql")
	buf, err := os.ReadFile(filepath.Join(p, files[0].Name()))
	require.NoError(t, err)
	// Placeholders are expanded and each change carries its comment.
	require.Equal(t, "-- Comment 1.\nUPDATE `test`.`users` SET `a` = 1;\n-- Comment 2.\nUPDATE `test`.`users` SET `b` = 2;\n", string(buf))
	require.Equal(t, "atlas.sum", files[1].Name())
}

316
doc/md/data-migrations.mdx Normal file
View File

@@ -0,0 +1,316 @@
---
id: data-migrations
title: Data Migrations
---
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
Migrations are usually used for changing the database schema, but in some cases, there is a need to modify the data
stored in the database. For example, adding seed data, or back-filling empty columns with custom default values.
Migrations of this type are called data migrations. In this document, we will discuss how to use Ent to plan data
migrations and integrate them into your regular schema migrations workflow.
### Migration Types
Ent currently supports two types of migrations, [versioned migration](versioned-migrations.mdx) and [declarative migration](migrate)
(also known as automatic migration). Data migrations can be executed in both types of migrations.
## Versioned Migrations
When using versioned migrations, data migrations should be stored on the same `migrations` directory and executed the
same way as regular migrations. It is recommended, however, to store data migrations and schema migrations in separate
files so that they can be easily tested.
The format used for such migrations is SQL, as the file can be safely executed (and stored without changes) even if
the Ent schema was modified and the generated code is not compatible with the data migration file anymore.
There are two ways to create data migration scripts: manual and generated. With manual editing, users write all the SQL
statements and can control exactly what will be executed. Alternatively, users can use Ent to generate the data migrations
for them. It is recommended to verify that the generated file was correctly generated, as in some cases it may need to
be manually fixed or edited.
### Manual Creation
1\. If you don't have Atlas installed, check out its [getting-started](https://atlasgo.io/getting-started/#installation)
guide.
2\. Create a new migration file using [Atlas](https://atlasgo.io/versioned/new):
```shell
atlas migrate new <migration_name> \
--dir "file://my/project/migrations"
```
3\. Edit the migration file and add the custom data migration there. For example:
```sql title="ent/migrate/migrations/20221126185750_backfill_data.sql"
-- Backfill NULL or null tags with a default value.
UPDATE `users` SET `tags` = '["foo","bar"]' WHERE `tags` IS NULL OR JSON_CONTAINS(`tags`, 'null', '$');
```
4\. Update the migration directory [integrity file](https://atlasgo.io/concepts/migration-directory-integrity):
```shell
atlas migrate hash \
--dir "file://my/project/migrations"
```
Check out the [Testing](#testing) section below if you're unsure how to test the data migration file.
### Generated Scripts
Currently, Ent provides initial support for generating data migration files. By using this option, users can simplify the
process of writing complex SQL statements manually in most cases. Still, it is recommended to verify that the generated
file was correctly generated, as in some edge cases it may need to be manually edited.
1\. Create your [versioned-migration setup](/docs/versioned/intro), in case it
is not set.
2\. Create your first data-migration function. Below, you will find some examples that demonstrate how to write such a
function:
<Tabs>
<TabItem value="single" label="Single Statement" default>
```go title="ent/migrate/migratedata/migratedata.go"
package migratedata
// BackfillUnknown back-fills all empty users' names with the default value 'Unknown'.
func BackfillUnknown(dir *migrate.LocalDir) error {
w := &schema.DirWriter{Dir: dir}
client := ent.NewClient(ent.Driver(schema.NewWriteDriver(dialect.MySQL, w)))
// Change all empty names to 'unknown'.
err := client.User.
Update().
Where(
user.NameEQ(""),
).
SetName("Unknown").
Exec(context.Background())
if err != nil {
return fmt.Errorf("failed generating statement: %w", err)
}
// Write the content to the migration directory.
return w.FlushChange(
"unknown_names",
"Backfill all empty user names with default value 'unknown'.",
)
}
```
Then, using this function in `ent/migrate/main.go` will generate the following migration file:
```sql title="migrations/20221126185750_unknown_names.sql"
-- Backfill all empty user names with default value 'unknown'.
UPDATE `users` SET `name` = 'Unknown' WHERE `users`.`name` = '';
```
</TabItem>
<TabItem value="multi" label="Multi Statement">
```go title="ent/migrate/migratedata/migratedata.go"
package migratedata
// BackfillUserTags is used to generate the migration file '20221126185750_backfill_user_tags.sql'.
func BackfillUserTags(dir *migrate.LocalDir) error {
w := &schema.DirWriter{Dir: dir}
client := ent.NewClient(ent.Driver(schema.NewWriteDriver(dialect.MySQL, w)))
// Add defaults "foo" and "bar" tags for users without any.
err := client.User.
Update().
Where(func(s *sql.Selector) {
s.Where(
sql.Or(
sql.IsNull(user.FieldTags),
sqljson.ValueIsNull(user.FieldTags),
),
)
}).
SetTags([]string{"foo", "bar"}).
Exec(context.Background())
if err != nil {
return fmt.Errorf("failed generating backfill statement: %w", err)
}
// Document all changes until now with a custom comment.
w.Change("Backfill NULL or null tags with a default value.")
// Append the "org" special tag for users with a specific prefix or suffix.
err = client.User.
Update().
Where(
user.Or(
user.NameHasPrefix("org-"),
user.NameHasSuffix("-org"),
),
// Append to only those without this tag.
func(s *sql.Selector) {
s.Where(
sql.Not(sqljson.ValueContains(user.FieldTags, "org")),
)
},
).
AppendTags([]string{"org"}).
Exec(context.Background())
if err != nil {
return fmt.Errorf("failed generating backfill statement: %w", err)
}
// Document all changes until now with a custom comment.
w.Change("Append the 'org' tag for organization accounts in case they don't have it.")
// Write the content to the migration directory.
return w.Flush("backfill_user_tags")
}
```
Then, using this function in `ent/migrate/main.go` will generate the following migration file:
```sql title="migrations/20221126185750_backfill_user_tags.sql"
-- Backfill NULL or null tags with a default value.
UPDATE `users` SET `tags` = '["foo","bar"]' WHERE `tags` IS NULL OR JSON_CONTAINS(`tags`, 'null', '$');
-- Append the 'org' tag for organization accounts in case they don't have it.
UPDATE `users` SET `tags` = CASE WHEN (JSON_TYPE(JSON_EXTRACT(`tags`, '$')) IS NULL OR JSON_TYPE(JSON_EXTRACT(`tags`, '$')) = 'NULL') THEN JSON_ARRAY('org') ELSE JSON_ARRAY_APPEND(`tags`, '$', 'org') END WHERE (`users`.`name` LIKE 'org-%' OR `users`.`name` LIKE '%-org') AND (NOT (JSON_CONTAINS(`tags`, '"org"', '$') = 1));
```
</TabItem>
<TabItem value="seed" label="Data Seeding">
```go title="ent/migrate/migratedata/migratedata.go"
package migratedata
// SeedUsers add the initial users to the database.
func SeedUsers(dir *migrate.LocalDir) error {
w := &schema.DirWriter{Dir: dir}
client := ent.NewClient(ent.Driver(schema.NewWriteDriver(dialect.MySQL, w)))
// The statement that generates the INSERT statement.
err := client.User.CreateBulk(
client.User.Create().SetName("a8m").SetAge(1).SetTags([]string{"foo"}),
client.User.Create().SetName("nati").SetAge(1).SetTags([]string{"bar"}),
).Exec(context.Background())
if err != nil {
return fmt.Errorf("failed generating statement: %w", err)
}
// Write the content to the migration directory.
return w.FlushChange(
"seed_users",
"Add the initial users to the database.",
)
}
```
Then, using this function in `ent/migrate/main.go` will generate the following migration file:
```sql title="migrations/20221126212120_seed_users.sql"
-- Add the initial users to the database.
INSERT INTO `users` (`age`, `name`, `tags`) VALUES (1, 'a8m', '["foo"]'), (1, 'nati', '["bar"]');
```
</TabItem>
</Tabs>
3\. In case the generated file was edited, the migration directory [integrity file](https://atlasgo.io/concepts/migration-directory-integrity)
needs to be updated with the following command:
```shell
atlas migrate hash \
--dir "file://my/project/migrations"
```
### Testing
After adding the migration files, it is highly recommended that you apply them on a local database to ensure they are
valid and achieve the intended results. The following process can be done manually or automated by a program.
1\. Execute all migration files until the last created one, the data migration file:
```shell
# Total number of files.
number_of_files=$(ls ent/migrate/migrations/*.sql | wc -l)
# Execute all files without the latest.
atlas migrate apply $((number_of_files-1)) \
--dir "file://my/project/migrations" \
-u "mysql://root:pass@localhost:3306/test"
```
2\. Ensure the last migration file is pending execution:
```shell
atlas migrate status \
--dir "file://my/project/migrations" \
-u "mysql://root:pass@localhost:3306/test"
Migration Status: PENDING
-- Current Version: <VERSION_N-1>
-- Next Version: <VERSION_N>
-- Executed Files: <N-1>
-- Pending Files: 1
```
3\. Fill the local database with temporary data that represents the production database before running the data
migration file.
4\. Run `atlas migrate apply` and ensure it was executed successfully.
```shell
atlas migrate apply \
--dir "file://my/project/migrations" \
-u "mysql://root:pass@localhost:3306/test"
```
Note, by using `atlas schema clean` you can clean the database you use for local development and repeat this process
until the data migration file achieves the desired result.
## Automatic Migrations
In the declarative workflow, data migrations are implemented using Diff or Apply [Hooks](migrate#atlas-diff-and-apply-hooks).
This is because, unlike the versioned option, migrations of this type do not hold a name or a version when they are applied.
Therefore, when a data migration is written using hooks, the type of the `schema.Change` must be checked before its
execution to ensure the data migration was not applied more than once.
```go
func FillNullValues(dbdialect string) schema.ApplyHook {
return func(next schema.Applier) schema.Applier {
return schema.ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error {
//highlight-next-line-info
// Search the schema.Change that triggers the data migration.
hasC := func() bool {
for _, c := range plan.Changes {
m, ok := c.Source.(*schema.ModifyTable)
if ok && m.T.Name == user.Table && schema.Changes(m.Changes).IndexModifyColumn(user.FieldName) != -1 {
return true
}
}
return false
}()
// Change was found, apply the data migration.
if hasC {
//highlight-info-start
// At this stage, there are three ways to UPDATE the NULL values to "Unknown".
// Append a custom migrate.Change to migrate.Plan, execute an SQL statement
// directly on the dialect.ExecQuerier, or use the generated ent.Client.
//highlight-info-end
// Create a temporary client from the migration connection.
client := ent.NewClient(
ent.Driver(sql.NewDriver(dbdialect, sql.Conn{ExecQuerier: conn.(*sql.Tx)})),
)
if err := client.User.
Update().
SetName("Unknown").
Where(user.NameIsNil()).
Exec(ctx); err != nil {
return err
}
}
return next.Apply(ctx, conn, plan)
})
}
}
```
For more examples, check out the [Apply Hook](migrate.md#apply-hook-example) examples section.

View File

@@ -704,7 +704,7 @@ You have a checksum error in your migration directory.
This happens if you manually create or edit a migration file.
Please check your migration files and run
'atlas migrate hash --force'
'atlas migrate hash'
to re-hash the contents and resolve the error.
@@ -716,7 +716,7 @@ file:
```shell
# Recompute the sum file.
atlas migrate hash --dir file://<path-to-your-migration-directory> --force
atlas migrate hash --dir file://<path-to-your-migration-directory>
```
Back to the problem above, if team A would land their changes on master first and team B would now attempt to land

View File

@@ -97,6 +97,7 @@ const config = {
{
className: 'code-block-info-line',
line: 'highlight-next-line-info',
block: {start: 'highlight-info-start', end: 'highlight-info-end'},
},
],
},

View File

@@ -44,6 +44,7 @@ module.exports = {
items: [
'migrate',
'versioned-migrations',
'data-migrations',
'dialects',
],
collapsed: false,

View File

@@ -0,0 +1,118 @@
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
// Package migratedata holds the functions for generating data migration files.
// It exists here for documentation and reference purpose only and has no runtime
// effect on the actual migration files.
package migratedata
import (
"context"
"fmt"
"ariga.io/atlas/sql/migrate"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/schema"
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/examples/migration/ent"
"entgo.io/ent/examples/migration/ent/user"
)
// BackfillUserTags was used to generate the migration file '20221126185750_backfill_user_tags.sql'.
// It exists here for documentation purpose only, and can be used as a reference for future data migrations.
//
// The ent client below is backed by a write driver: the builder calls
// record SQL statements into the directory writer instead of executing
// them against a database.
func BackfillUserTags(dir *migrate.LocalDir) error {
	w := &schema.DirWriter{Dir: dir}
	client := ent.NewClient(ent.Driver(schema.NewWriteDriver(dialect.MySQL, w)))
	// Add defaults "foo" and "bar" tags for users without any.
	err := client.User.
		Update().
		Where(func(s *sql.Selector) {
			s.Where(
				sql.Or(
					sql.IsNull(user.FieldTags),
					sqljson.ValueIsNull(user.FieldTags),
				),
			)
		}).
		SetTags([]string{"foo", "bar"}).
		Exec(context.Background())
	if err != nil {
		return fmt.Errorf("failed generating backfill statement: %w", err)
	}
	// Document all changes until now with a custom comment.
	w.Change("Backfill NULL or null tags with a default value.")
	// Append the "org" special tag for users with a specific prefix or suffix.
	err = client.User.
		Update().
		Where(
			user.Or(
				user.NameHasPrefix("org-"),
				user.NameHasSuffix("-org"),
			),
			// Append to only those without this tag.
			func(s *sql.Selector) {
				s.Where(
					sql.Not(sqljson.ValueContains(user.FieldTags, "org")),
				)
			},
		).
		AppendTags([]string{"org"}).
		Exec(context.Background())
	if err != nil {
		return fmt.Errorf("failed generating backfill statement: %w", err)
	}
	// Document all changes until now with a custom comment.
	w.Change("Append the 'org' tag for organization accounts in case they don't have it.")
	// Write the content to the migration directory.
	return w.Flush("backfill_user_tags")
}
// BackfillUnknown back-fills all empty users' names with the default value 'Unknown'.
//
// Like the other helpers in this package, the client records the UPDATE
// statement through a write driver rather than executing it; FlushChange
// then writes the single documented change as a migration file.
func BackfillUnknown(dir *migrate.LocalDir) error {
	w := &schema.DirWriter{Dir: dir}
	client := ent.NewClient(ent.Driver(schema.NewWriteDriver(dialect.MySQL, w)))
	// Change all empty names to 'unknown'.
	err := client.User.
		Update().
		Where(
			user.NameEQ(""),
		).
		SetName("Unknown").
		Exec(context.Background())
	if err != nil {
		return fmt.Errorf("failed generating statement: %w", err)
	}
	// Write the content to the migration directory.
	return w.FlushChange(
		"unknown_names",
		"Backfill all empty user names with default value 'unknown'.",
	)
}
// SeedUsers add the initial users to the database.
//
// The bulk-create call is recorded as an INSERT statement by the write
// driver; FlushChange writes it as a documented migration file.
func SeedUsers(dir *migrate.LocalDir) error {
	w := &schema.DirWriter{Dir: dir}
	client := ent.NewClient(ent.Driver(schema.NewWriteDriver(dialect.MySQL, w)))
	// The statement that generates the INSERT statement.
	err := client.User.CreateBulk(
		client.User.Create().SetName("a8m").SetAge(1).SetTags([]string{"foo"}),
		client.User.Create().SetName("nati").SetAge(1).SetTags([]string{"bar"}),
	).Exec(context.Background())
	if err != nil {
		return fmt.Errorf("failed generating statement: %w", err)
	}
	// Write the content to the migration directory.
	return w.FlushChange(
		"seed_users",
		"Add the initial users to the database.",
	)
}

View File

@@ -0,0 +1,2 @@
-- modify "users" table
ALTER TABLE `users` ADD COLUMN `tags` json NULL;

View File

@@ -0,0 +1,4 @@
-- Backfill NULL or null tags with a default value.
UPDATE `users` SET `tags` = '["foo","bar"]' WHERE `tags` IS NULL OR JSON_CONTAINS(`tags`, 'null', '$');
-- Append the 'org' tag for organization accounts in case they don't have it.
UPDATE `users` SET `tags` = CASE WHEN (JSON_TYPE(JSON_EXTRACT(`tags`, '$')) IS NULL OR JSON_TYPE(JSON_EXTRACT(`tags`, '$')) = 'NULL') THEN JSON_ARRAY('org') ELSE JSON_ARRAY_APPEND(`tags`, '$', 'org') END WHERE (`users`.`name` LIKE 'org-%' OR `users`.`name` LIKE '%-org') AND (NOT (JSON_CONTAINS(`tags`, '"org"', '$') = 1));

View File

@@ -1,4 +1,6 @@
h1:uJ67zyvQoeKI01UWjiq/dHlKBBjzMcgwKdDKq4YCLy8=
h1:bAsy4+3Td11QEAWYiqYyHN46U6bAY7kyc16todRIwJg=
20221114082343_create_users.sql h1:EeOPKbOeYUHO830P1MCXKUsxT3wLk/hSgcpT9GzlGZQ=
20221114090322_add_age.sql h1:weRCTqS5cqKyrvIjGXk7Bn24YUXo0nhLNWD0EZBcaoQ=
20221114101516_add_name.sql h1:leQgI2JN3URZyujlNzX3gI53MSiTA02MKI/G9cnMu6I=
20221126173531_add_user_tags.sql h1:9yq4x41NnP8gqBg3IUbHKwE7gP/23Mx3ylf5IorLoFs=
20221126185750_backfill_user_tags.sql h1:qBPBJD4ry+uwj1eGPmTr5COJwv8Sjo2o+OR7TWFj6qs=

View File

@@ -18,6 +18,7 @@ var (
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "age", Type: field.TypeFloat64},
{Name: "name", Type: field.TypeString},
{Name: "tags", Type: field.TypeJSON, Nullable: true},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{

View File

@@ -39,6 +39,8 @@ type UserMutation struct {
age *float64
addage *float64
name *string
tags *[]string
appendtags []string
clearedFields map[string]struct{}
done bool
oldValue func(context.Context) (*User, error)
@@ -235,6 +237,71 @@ func (m *UserMutation) ResetName() {
m.name = nil
}
// SetTags sets the "tags" field.
func (m *UserMutation) SetTags(s []string) {
m.tags = &s
m.appendtags = nil
}
// Tags returns the value of the "tags" field in the mutation.
func (m *UserMutation) Tags() (r []string, exists bool) {
v := m.tags
if v == nil {
return
}
return *v, true
}
// OldTags returns the old "tags" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldTags(ctx context.Context) (v []string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldTags is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldTags requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldTags: %w", err)
}
return oldValue.Tags, nil
}
// AppendTags adds s to the "tags" field.
func (m *UserMutation) AppendTags(s []string) {
m.appendtags = append(m.appendtags, s...)
}
// AppendedTags returns the list of values that were appended to the "tags" field in this mutation.
func (m *UserMutation) AppendedTags() ([]string, bool) {
if len(m.appendtags) == 0 {
return nil, false
}
return m.appendtags, true
}
// ClearTags clears the value of the "tags" field.
func (m *UserMutation) ClearTags() {
m.tags = nil
m.appendtags = nil
m.clearedFields[user.FieldTags] = struct{}{}
}
// TagsCleared returns if the "tags" field was cleared in this mutation.
func (m *UserMutation) TagsCleared() bool {
_, ok := m.clearedFields[user.FieldTags]
return ok
}
// ResetTags resets all changes to the "tags" field.
func (m *UserMutation) ResetTags() {
m.tags = nil
m.appendtags = nil
delete(m.clearedFields, user.FieldTags)
}
// Where appends a list predicates to the UserMutation builder.
func (m *UserMutation) Where(ps ...predicate.User) {
m.predicates = append(m.predicates, ps...)
@@ -254,13 +321,16 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
fields := make([]string, 0, 2)
fields := make([]string, 0, 3)
if m.age != nil {
fields = append(fields, user.FieldAge)
}
if m.name != nil {
fields = append(fields, user.FieldName)
}
if m.tags != nil {
fields = append(fields, user.FieldTags)
}
return fields
}
@@ -273,6 +343,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.Age()
case user.FieldName:
return m.Name()
case user.FieldTags:
return m.Tags()
}
return nil, false
}
@@ -286,6 +358,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldAge(ctx)
case user.FieldName:
return m.OldName(ctx)
case user.FieldTags:
return m.OldTags(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -309,6 +383,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetName(v)
return nil
case user.FieldTags:
v, ok := value.([]string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetTags(v)
return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -353,7 +434,11 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
// ClearedFields returns all nullable fields that were cleared during this
// mutation.
func (m *UserMutation) ClearedFields() []string {
return nil
var fields []string
if m.FieldCleared(user.FieldTags) {
fields = append(fields, user.FieldTags)
}
return fields
}
// FieldCleared returns a boolean indicating if a field with the given name was
@@ -366,6 +451,11 @@ func (m *UserMutation) FieldCleared(name string) bool {
// ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema.
func (m *UserMutation) ClearField(name string) error {
	// "tags" is the only clearable field handled here.
	if name == user.FieldTags {
		m.ClearTags()
		return nil
	}
	return fmt.Errorf("unknown User nullable field %s", name)
}
@@ -379,6 +469,9 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldName:
m.ResetName()
return nil
case user.FieldTags:
m.ResetTags()
return nil
}
return fmt.Errorf("unknown User field %s", name)
}

View File

@@ -17,6 +17,8 @@ func (User) Fields() []ent.Field {
return []ent.Field{
field.Float("age"),
field.String("name"),
field.Strings("tags").
Optional(),
}
}

View File

@@ -7,6 +7,7 @@
package ent
import (
"encoding/json"
"fmt"
"strings"
@@ -23,6 +24,8 @@ type User struct {
Age float64 `json:"age,omitempty"`
// Name holds the value of the "name" field.
Name string `json:"name,omitempty"`
// Tags holds the value of the "tags" field.
Tags []string `json:"tags,omitempty"`
}
// scanValues returns the types for scanning values from sql.Rows.
@@ -30,6 +33,8 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case user.FieldTags:
values[i] = new([]byte)
case user.FieldAge:
values[i] = new(sql.NullFloat64)
case user.FieldID:
@@ -69,6 +74,14 @@ func (u *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
u.Name = value.String
}
case user.FieldTags:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field tags", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &u.Tags); err != nil {
return fmt.Errorf("unmarshal field tags: %w", err)
}
}
}
}
return nil
@@ -102,6 +115,9 @@ func (u *User) String() string {
builder.WriteString(", ")
builder.WriteString("name=")
builder.WriteString(u.Name)
builder.WriteString(", ")
builder.WriteString("tags=")
builder.WriteString(fmt.Sprintf("%v", u.Tags))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -15,6 +15,8 @@ const (
FieldAge = "age"
// FieldName holds the string denoting the name field in the database.
FieldName = "name"
// FieldTags holds the string denoting the tags field in the database.
FieldTags = "tags"
// Table holds the table name of the user in the database.
Table = "users"
)
@@ -24,6 +26,7 @@ var Columns = []string{
FieldID,
FieldAge,
FieldName,
FieldTags,
}
// ValidColumn reports if the column name is valid (part of the table columns).

View File

@@ -259,6 +259,20 @@ func NameContainsFold(v string) predicate.User {
})
}
// TagsIsNil applies the IsNil predicate on the "tags" field.
func TagsIsNil() predicate.User {
	p := func(sel *sql.Selector) {
		sel.Where(sql.IsNull(sel.C(FieldTags)))
	}
	return predicate.User(p)
}
// TagsNotNil applies the NotNil predicate on the "tags" field.
func TagsNotNil() predicate.User {
	p := func(sel *sql.Selector) {
		sel.Where(sql.NotNull(sel.C(FieldTags)))
	}
	return predicate.User(p)
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.User) predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@@ -35,6 +35,12 @@ func (uc *UserCreate) SetName(s string) *UserCreate {
return uc
}
// SetTags sets the "tags" field.
func (uc *UserCreate) SetTags(s []string) *UserCreate {
	m := uc.mutation
	m.SetTags(s)
	return uc
}
// Mutation returns the UserMutation object of the builder.
func (uc *UserCreate) Mutation() *UserMutation {
return uc.mutation
@@ -152,6 +158,10 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldName, field.TypeString, value)
_node.Name = value
}
if value, ok := uc.mutation.Tags(); ok {
_spec.SetField(user.FieldTags, field.TypeJSON, value)
_node.Tags = value
}
return _node, _spec
}

View File

@@ -13,6 +13,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/examples/migration/ent/predicate"
"entgo.io/ent/examples/migration/ent/user"
"entgo.io/ent/schema/field"
@@ -50,6 +51,24 @@ func (uu *UserUpdate) SetName(s string) *UserUpdate {
return uu
}
// SetTags sets the "tags" field.
func (uu *UserUpdate) SetTags(s []string) *UserUpdate {
	m := uu.mutation
	m.SetTags(s)
	return uu
}
// AppendTags appends s to the "tags" field.
func (uu *UserUpdate) AppendTags(s []string) *UserUpdate {
	m := uu.mutation
	m.AppendTags(s)
	return uu
}
// ClearTags clears the value of the "tags" field.
func (uu *UserUpdate) ClearTags() *UserUpdate {
	m := uu.mutation
	m.ClearTags()
	return uu
}
// Mutation returns the UserMutation object of the builder.
func (uu *UserUpdate) Mutation() *UserMutation {
return uu.mutation
@@ -136,6 +155,17 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
if value, ok := uu.mutation.Name(); ok {
_spec.SetField(user.FieldName, field.TypeString, value)
}
if value, ok := uu.mutation.Tags(); ok {
_spec.SetField(user.FieldTags, field.TypeJSON, value)
}
if value, ok := uu.mutation.AppendedTags(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, user.FieldTags, value)
})
}
if uu.mutation.TagsCleared() {
_spec.ClearField(user.FieldTags, field.TypeJSON)
}
if n, err = sqlgraph.UpdateNodes(ctx, uu.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@@ -174,6 +204,24 @@ func (uuo *UserUpdateOne) SetName(s string) *UserUpdateOne {
return uuo
}
// SetTags sets the "tags" field.
func (uuo *UserUpdateOne) SetTags(s []string) *UserUpdateOne {
	m := uuo.mutation
	m.SetTags(s)
	return uuo
}
// AppendTags appends s to the "tags" field.
func (uuo *UserUpdateOne) AppendTags(s []string) *UserUpdateOne {
	m := uuo.mutation
	m.AppendTags(s)
	return uuo
}
// ClearTags clears the value of the "tags" field.
func (uuo *UserUpdateOne) ClearTags() *UserUpdateOne {
	m := uuo.mutation
	m.ClearTags()
	return uuo
}
// Mutation returns the UserMutation object of the builder.
func (uuo *UserUpdateOne) Mutation() *UserMutation {
return uuo.mutation
@@ -290,6 +338,17 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error)
if value, ok := uuo.mutation.Name(); ok {
_spec.SetField(user.FieldName, field.TypeString, value)
}
if value, ok := uuo.mutation.Tags(); ok {
_spec.SetField(user.FieldTags, field.TypeJSON, value)
}
if value, ok := uuo.mutation.AppendedTags(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, user.FieldTags, value)
})
}
if uuo.mutation.TagsCleared() {
_spec.ClearField(user.FieldTags, field.TypeJSON)
}
_node = &User{config: uuo.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

2
go.mod
View File

@@ -3,7 +3,7 @@ module entgo.io/ent
go 1.19
require (
ariga.io/atlas v0.8.3-0.20221116151337-9e4e9cbf3baf
ariga.io/atlas v0.8.3
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/go-openapi/inflect v0.19.0
github.com/go-sql-driver/mysql v1.6.0

6
go.sum
View File

@@ -1,5 +1,11 @@
ariga.io/atlas v0.8.3-0.20221116151337-9e4e9cbf3baf h1:tq28xcfFAtxk75ej1IwK+yIbRYC0fqNZkHljcVbYrOs=
ariga.io/atlas v0.8.3-0.20221116151337-9e4e9cbf3baf/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE=
ariga.io/atlas v0.8.3-0.20221125075918-4131ef186252 h1:hnvvUZiyn51tGJevOE5jSpcUb+tEF1ISpAddbeRJIF8=
ariga.io/atlas v0.8.3-0.20221125075918-4131ef186252/go.mod h1:T230JFcENj4ZZzMkZrXFDSkv+2kXkUgpJ5FQQ5hMcKU=
ariga.io/atlas v0.8.3-0.20221125104408-af32ea84f3de h1:mCehveiRKPRzLZqvtG6RgNpU+CFAqOtcBI5lre8bYlE=
ariga.io/atlas v0.8.3-0.20221125104408-af32ea84f3de/go.mod h1:T230JFcENj4ZZzMkZrXFDSkv+2kXkUgpJ5FQQ5hMcKU=
ariga.io/atlas v0.8.3 h1:nddOywkhr/62Cwa+UsGgO35lAhUYh52XGVsbFwGzWZM=
ariga.io/atlas v0.8.3/go.mod h1:T230JFcENj4ZZzMkZrXFDSkv+2kXkUgpJ5FQQ5hMcKU=
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=