From 6a1829cc3339bd596f5732bba9af0b3c2ed0cb98 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Fri, 4 Sep 2020 14:13:29 +0300 Subject: [PATCH] dialect/sql: add DotPath option to json option (#725) --- .github/workflows/ci.yml | 4 +- dialect/sql/builder.go | 80 ++++++++++++++++++++++++++-- dialect/sql/builder_test.go | 102 +++++++++++++++++++++++++++++++++++- 3 files changed, 177 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 952b3421c..b206d105b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -276,9 +276,9 @@ jobs: run: git checkout origin/master - name: Run integration on origin/master working-directory: entc/integration - run: go test -race -count=2 ./... + run: go test -race -count=2 -tags='json1' ./... - name: Checkout previous HEAD run: git checkout - - name: Run integration on HEAD working-directory: entc/integration - run: go test -race -count=2 . + run: go test -race -count=2 -tags='json1' ./... diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 3c8852eef..6a9df6433 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -898,14 +898,17 @@ func (p *Predicate) EQ(col string, arg interface{}) *Predicate { } // JSONHasKey calls Predicate.JSONHasKey. -func JSONHasKey(col string, path ...string) *Predicate { - return P().JSONHasKey(col, path...) +func JSONHasKey(col string, path string) *Predicate { + return P().JSONHasKey(col, path) } -// JSONHasKey return a predicate for checking that a JSON key exists and not NULL; -func (p *Predicate) JSONHasKey(col string, path ...string) *Predicate { +// JSONHasKey return a predicate for checking that a JSON key exists and not NULL. +// +// P().JSONHasKey("column", "a.b[2].c") +// +func (p *Predicate) JSONHasKey(col string, path string) *Predicate { return p.Append(func(b *Builder) { - b.JSONPath(col, Path(path...)).WriteOp(OpNotNull) + b.JSONPath(col, DotPath(path)).WriteOp(OpNotNull) }) } @@ -1956,6 +1959,19 @@ func Path(path ...string) JSONOption { } } +// DotPath is similar to Path, but accepts string with dot format. +// +// b.JSONPath("column", DotPath("a.b[2].c")) +// b.JSONPath("column", DotPath("a.b.c")) +// +// Note that DotPath is ignored if the input is invalid. +func DotPath(dotpath string) JSONOption { + path, _ := ParsePath(dotpath) + return func(p *JSONPath) { + p.path = path + } +} + // Unquote indicates that the result value should be unquoted. // // b.JSONPath("column", Path("a", "b", "[1]", "c"), Unquote(true)) @@ -2365,6 +2381,60 @@ func (d *DialectBuilder) DropIndex(name string) *DropIndexBuilder { return b } +// ParsePath parses the "dotpath" for the DotPath option. +// +// "a.b" => ["a", "b"] +// "a[1][2]" => ["a", "[1]", "[2]"] +// "a.\"b.c\" => ["a", "\"b.c\""] +// +func ParsePath(dotpath string) ([]string, error) { + var ( + i, p int + path []string + ) + for i < len(dotpath) { + switch r := dotpath[i]; { + case r == '"': + if i == len(dotpath)-1 { + return nil, fmt.Errorf("unexpected quote") + } + idx := strings.IndexRune(dotpath[i+1:], '"') + if idx == -1 || idx == 0 { + return nil, fmt.Errorf("unbalanced quote") + } + i += idx + 2 + case r == '[': + if p != i { + path = append(path, dotpath[p:i]) + } + p = i + if i == len(dotpath)-1 { + return nil, fmt.Errorf("unexpected bracket") + } + idx := strings.IndexRune(dotpath[i:], ']') + if idx == -1 || idx == 1 { + return nil, fmt.Errorf("unbalanced bracket") + } + if !isNumber(dotpath[i+1 : i+idx]) { + return nil, fmt.Errorf("invalid index %q", dotpath[i:i+idx+1]) + } + i += idx + 1 + case r == '.' || r == ']': + if p != i { + path = append(path, dotpath[p:i]) + } + i++ + p = i + default: + i++ + } + } + if p != i { + path = append(path, dotpath[p:i]) + } + return path, nil +} + // isNumber reports whether the string is a number (category N). func isNumber(s string) bool { for _, r := range s { diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index f49c242cc..9f986c436 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1272,16 +1272,59 @@ WHERE { input: Select("*"). From(Table("test")). - Where(JSONHasKey("j", "a", "*", "c")), + Where(JSONHasKey("j", "a.*.c")), wantQuery: "SELECT * FROM `test` WHERE JSON_EXTRACT(`j`, \"$.a.*.c\") IS NOT NULL", }, { input: Dialect(dialect.Postgres). Select("*"). From(Table("test")). - Where(JSONHasKey("j", "a", "b", "c")), + Where(JSONHasKey("j", "a.b.c")), wantQuery: `SELECT * FROM "test" WHERE "j"->'a'->'b'->'c' IS NOT NULL`, }, + { + input: Dialect(dialect.Postgres). + Select("*"). + From(Table("test")). + Where(JSONHasKey("j", "a.b.c")), + wantQuery: `SELECT * FROM "test" WHERE "j"->'a'->'b'->'c' IS NOT NULL`, + }, + { + input: Dialect(dialect.Postgres). + Select("*"). + From(Table("users")). + Where(P(func(b *Builder) { + b.JSONPath("a", DotPath("b.c[1].d"), Cast("int")) + b.WriteOp(OpEQ) + b.Arg(1) + })), + wantQuery: `SELECT * FROM "users" WHERE CAST("a"->'b'->'c'->1->'d' AS int) = $1`, + wantArgs: []interface{}{1}, + }, + { + input: Dialect(dialect.MySQL). + Select("*"). + From(Table("users")). + Where(P(func(b *Builder) { + b.JSONPath("a", DotPath("b.c[1].d")) + b.WriteOp(OpEQ) + b.Arg("a") + })), + wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, \"$.b.c[1].d\") = ?", + wantArgs: []interface{}{"a"}, + }, + { + input: Dialect(dialect.MySQL). + Select("*"). + From(Table("users")). + Where(P(func(b *Builder) { + b.JSONPath("a", DotPath("b.\"c[1]\".d[1][2].e")) + b.WriteOp(OpEQ) + b.Arg("a") + })), + wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, \"$.b.\"c[1]\".d[1][2].e\") = ?", + wantArgs: []interface{}{"a"}, + }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { @@ -1291,3 +1334,58 @@ WHERE }) } } + +func TestParsePath(t *testing.T) { + tests := []struct { + input string + wantPath []string + wantErr bool + }{ + { + input: "a.b.c", + wantPath: []string{"a", "b", "c"}, + }, + { + input: "a[1][2]", + wantPath: []string{"a", "[1]", "[2]"}, + }, + { + input: "a[1][2].b", + wantPath: []string{"a", "[1]", "[2]", "b"}, + }, + { + input: `a."b.c[0]"`, + wantPath: []string{"a", `"b.c[0]"`}, + }, + { + input: `a."b.c[0]".d`, + wantPath: []string{"a", `"b.c[0]"`, "d"}, + }, + { + input: `...`, + }, + { + input: `.a.b.`, + wantPath: []string{"a", "b"}, + }, + { + input: `a."`, + wantErr: true, + }, + { + input: `a[`, + wantErr: true, + }, + { + input: `a[a]`, + wantErr: true, + }, + } + for i, tt := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + path, err := ParsePath(tt.input) + require.Equal(t, tt.wantPath, path) + require.Equal(t, tt.wantErr, err != nil) + }) + } +}