From b8f4614bfd12abfd09bfdaa1cc0b7339e64eb508 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Thu, 17 Jun 2021 22:13:19 +0300 Subject: [PATCH] dialect/sql: allow adding check clauses in create table --- dialect/sql/builder.go | 26 ++++++++++++++++++-------- dialect/sql/builder_test.go | 7 +++++-- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 29c508498..ee964f42c 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -107,14 +107,15 @@ func (c *ColumnBuilder) Query() (string, []interface{}) { // TableBuilder is a query builder for `CREATE TABLE` statement. type TableBuilder struct { Builder - name string // table name. - exists bool // check existence. - charset string // table charset. - collation string // table collation. - options string // table options. - columns []Querier // table columns. - primary []string // primary key. - constraints []Querier // foreign keys and indices. + name string // table name. + exists bool // check existence. + charset string // table charset. + collation string // table collation. + options string // table options. + columns []Querier // table columns. + primary []string // primary key. + constraints []Querier // foreign keys and indices. + checks []func(*Builder) // check constraints. } // CreateTable returns a query builder for the `CREATE TABLE` statement. @@ -177,6 +178,12 @@ func (t *TableBuilder) Constraints(fks ...*ForeignKeyBuilder) *TableBuilder { return t } +// Checks adds CHECK clauses to the CREATE TABLE statement. +func (t *TableBuilder) Checks(checks ...func(*Builder)) *TableBuilder { + t.checks = append(t.checks, checks...) + return t +} + // Charset appends the `CHARACTER SET` clause to the statement. MySQL only. func (t *TableBuilder) Charset(s string) *TableBuilder { t.charset = s @@ -218,6 +225,9 @@ func (t *TableBuilder) Query() (string, []interface{}) { if len(t.constraints) > 0 { b.Comma().JoinComma(t.constraints...) } + for _, check := range t.checks { + check(b.Comma()) + } }) if t.charset != "" { t.WriteString(" CHARACTER SET " + t.charset) diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 0963aa44b..7dc824084 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -86,8 +86,11 @@ func TestBuilder(t *testing.T) { ). PrimaryKey("id", "name"). ForeignKeys(ForeignKey().Columns("card_id"). - Reference(Reference().Table("cards").Columns("id")).OnDelete("SET NULL")), - wantQuery: "CREATE TABLE IF NOT EXISTS `users`(`id` int auto_increment, `card_id` int, `doc` longtext CHECK (JSON_VALID(`doc`)), PRIMARY KEY(`id`, `name`), FOREIGN KEY(`card_id`) REFERENCES `cards`(`id`) ON DELETE SET NULL)", + Reference(Reference().Table("cards").Columns("id")).OnDelete("SET NULL")). + Checks(func(b *Builder) { + b.WriteString("CONSTRAINT ").Ident("valid_card").WriteString(" CHECK (").Ident("card_id").WriteString(" > 0)") + }), + wantQuery: "CREATE TABLE IF NOT EXISTS `users`(`id` int auto_increment, `card_id` int, `doc` longtext CHECK (JSON_VALID(`doc`)), PRIMARY KEY(`id`, `name`), FOREIGN KEY(`card_id`) REFERENCES `cards`(`id`) ON DELETE SET NULL, CONSTRAINT `valid_card` CHECK (`card_id` > 0))", }, { input: Dialect(dialect.Postgres).CreateTable("users").