From f12ef91829238c175ac4f12ed22f2366d4419cdb Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Thu, 25 Mar 2021 16:26:44 +0200 Subject: [PATCH] entc/gen: privatize table columns check --- dialect/sql/builder.go | 5 + entc/gen/internal/bindata.go | 12 +- entc/gen/template/dialect/sql/by.tmpl | 48 +++++--- entc/gen/template/dialect/sql/group.tmpl | 2 +- entc/gen/template/dialect/sql/query.tmpl | 4 +- .../cascadelete/ent/comment_query.go | 6 +- entc/integration/cascadelete/ent/ent.go | 82 +++++++++----- .../integration/cascadelete/ent/post_query.go | 6 +- .../integration/cascadelete/ent/user_query.go | 6 +- entc/integration/config/ent/ent.go | 78 ++++++++----- entc/integration/config/ent/user_query.go | 6 +- entc/integration/customid/ent/blob_query.go | 6 +- entc/integration/customid/ent/car_query.go | 6 +- entc/integration/customid/ent/ent.go | 88 ++++++++++----- entc/integration/customid/ent/group_query.go | 6 +- .../integration/customid/ent/mixinid_query.go | 6 +- entc/integration/customid/ent/pet_query.go | 6 +- entc/integration/customid/ent/user_query.go | 6 +- entc/integration/edgefield/ent/card_query.go | 6 +- entc/integration/edgefield/ent/ent.go | 88 ++++++++++----- entc/integration/edgefield/ent/info_query.go | 6 +- .../edgefield/ent/metadata_query.go | 6 +- entc/integration/edgefield/ent/pet_query.go | 6 +- entc/integration/edgefield/ent/post_query.go | 6 +- entc/integration/edgefield/ent/user_query.go | 6 +- entc/integration/ent/card_query.go | 6 +- entc/integration/ent/comment_query.go | 6 +- entc/integration/ent/ent.go | 104 +++++++++++++----- entc/integration/ent/fieldtype_query.go | 6 +- entc/integration/ent/file_query.go | 6 +- entc/integration/ent/filetype_query.go | 6 +- entc/integration/ent/goods_query.go | 6 +- entc/integration/ent/group_query.go | 6 +- entc/integration/ent/groupinfo_query.go | 6 +- entc/integration/ent/item_query.go | 6 +- entc/integration/ent/node_query.go | 6 +- entc/integration/ent/pet_query.go | 6 +- entc/integration/ent/spec_query.go | 6 +- entc/integration/ent/task_query.go | 6 +- entc/integration/ent/user_query.go | 6 +- entc/integration/hooks/ent/card_query.go | 6 +- entc/integration/hooks/ent/ent.go | 80 +++++++++----- entc/integration/hooks/ent/user_query.go | 6 +- entc/integration/idtype/ent/ent.go | 78 ++++++++----- entc/integration/idtype/ent/user_query.go | 6 +- entc/integration/integration_test.go | 31 +++++- entc/integration/json/ent/ent.go | 78 ++++++++----- entc/integration/json/ent/user_query.go | 6 +- entc/integration/migrate/entv1/car_query.go | 6 +- .../migrate/entv1/conversion_query.go | 6 +- .../migrate/entv1/customtype_query.go | 6 +- entc/integration/migrate/entv1/ent.go | 84 +++++++++----- entc/integration/migrate/entv1/user_query.go | 6 +- entc/integration/migrate/entv2/car_query.go | 6 +- .../migrate/entv2/conversion_query.go | 6 +- .../migrate/entv2/customtype_query.go | 6 +- entc/integration/migrate/entv2/ent.go | 90 ++++++++++----- entc/integration/migrate/entv2/group_query.go | 6 +- entc/integration/migrate/entv2/media_query.go | 6 +- entc/integration/migrate/entv2/pet_query.go | 6 +- entc/integration/migrate/entv2/user_query.go | 6 +- entc/integration/multischema/ent/ent.go | 82 +++++++++----- .../multischema/ent/group_query.go | 6 +- entc/integration/multischema/ent/pet_query.go | 6 +- .../integration/multischema/ent/user_query.go | 6 +- entc/integration/privacy/ent/ent.go | 82 +++++++++----- entc/integration/privacy/ent/task_query.go | 6 +- entc/integration/privacy/ent/team_query.go | 6 +- entc/integration/privacy/ent/user_query.go | 6 +- entc/integration/template/ent/ent.go | 82 +++++++++----- entc/integration/template/ent/group_query.go | 6 +- entc/integration/template/ent/pet_query.go | 6 +- entc/integration/template/ent/user_query.go | 6 +- examples/edgeindex/ent/city_query.go | 6 +- examples/edgeindex/ent/ent.go | 80 +++++++++----- examples/edgeindex/ent/street_query.go | 6 +- examples/entcpkg/ent/ent.go | 78 ++++++++----- examples/entcpkg/ent/user_query.go | 6 +- examples/m2m2types/ent/ent.go | 80 +++++++++----- examples/m2m2types/ent/group_query.go | 6 +- examples/m2m2types/ent/user_query.go | 6 +- examples/m2mbidi/ent/ent.go | 78 ++++++++----- examples/m2mbidi/ent/user_query.go | 6 +- examples/m2mrecur/ent/ent.go | 78 ++++++++----- examples/m2mrecur/ent/user_query.go | 6 +- examples/o2m2types/ent/ent.go | 80 +++++++++----- examples/o2m2types/ent/pet_query.go | 6 +- examples/o2m2types/ent/user_query.go | 6 +- examples/o2mrecur/ent/ent.go | 78 ++++++++----- examples/o2mrecur/ent/node_query.go | 6 +- examples/o2o2types/ent/card_query.go | 6 +- examples/o2o2types/ent/ent.go | 80 +++++++++----- examples/o2o2types/ent/user_query.go | 6 +- examples/o2obidi/ent/ent.go | 78 ++++++++----- examples/o2obidi/ent/user_query.go | 6 +- examples/o2orecur/ent/ent.go | 78 ++++++++----- examples/o2orecur/ent/node_query.go | 6 +- examples/privacyadmin/ent/ent.go | 78 ++++++++----- examples/privacyadmin/ent/user_query.go | 6 +- examples/privacytenant/ent/ent.go | 82 +++++++++----- examples/privacytenant/ent/group_query.go | 6 +- examples/privacytenant/ent/tenant_query.go | 6 +- examples/privacytenant/ent/user_query.go | 6 +- examples/start/ent/car_query.go | 6 +- examples/start/ent/ent.go | 82 +++++++++----- examples/start/ent/group_query.go | 6 +- examples/start/ent/user_query.go | 6 +- examples/traversal/ent/ent.go | 82 +++++++++----- examples/traversal/ent/group_query.go | 6 +- examples/traversal/ent/pet_query.go | 6 +- examples/traversal/ent/user_query.go | 6 +- 111 files changed, 1790 insertions(+), 988 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index cdaa49549..39918cb8b 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1846,6 +1846,11 @@ func (s *Selector) Table() *SelectTable { return s.from.(*SelectTable) } +// TableName returns the name of the selected table. +func (s *Selector) TableName() string { + return s.Table().name +} + // Join appends a `JOIN` clause to the statement. func (s *Selector) Join(t TableView) *Selector { return s.join("JOIN", t) diff --git a/entc/gen/internal/bindata.go b/entc/gen/internal/bindata.go index 620f96291..f9a9b72bc 100644 --- a/entc/gen/internal/bindata.go +++ b/entc/gen/internal/bindata.go @@ -612,7 +612,7 @@ func templateDialectGremlinUpdateTmpl() (*asset, error) { return a, nil } -var _templateDialectSqlByTmpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x9c\x53\x5f\x4f\xdb\x3e\x14\x7d\x4e\x3e\xc5\xfd\x45\xfc\xa6\xa4\x6a\x13\xc6\xdb\x90\xf6\xd0\x21\x90\x90\x36\xf6\xc0\xb4\x3d\x22\x63\x5f\xa7\x16\xc6\x4e\x6d\x87\xa9\xb2\xfc\xdd\x27\xdb\x69\xe9\x0a\x03\xb6\x87\x4a\xf1\xfd\x77\xce\x3d\xf7\xd4\xfb\x6e\x56\x9e\xe9\x61\x63\x44\xbf\x72\x70\x72\xfc\xfe\xc3\x62\x30\x68\x51\x39\xb8\x20\x14\x6f\xb5\xbe\x83\x4b\x45\x5b\x58\x4a\x09\xa9\xc8\x42\xcc\x9b\x07\x64\x6d\xf9\x6d\x25\x2c\x58\x3d\x1a\x8a\x40\x35\x43\x10\x16\xa4\xa0\xa8\x2c\x32\x18\x15\x43\x03\x6e\x85\xb0\x1c\x08\x5d\x21\x9c\xb4\xc7\xdb\x2c\x70\x3d\x2a\x56\x0a\x95\xf2\x9f\x2f\xcf\xce\xaf\xae\xcf\x81\x0b\x89\x30\xc5\x8c\xd6\x0e\x98\x30\x48\x9d\x36\x1b\xd0\x1c\xdc\x1e\x98\x33\x88\x6d\x39\xeb\x42\x28\x4b\xef\x81\x21\x17\x0a\xa1\x62\x82\x48\xa4\xae\xb3\x6b\xd9\x69\xc3\xd0\x74\x56\xf4\x8a\xb8\xd1\x60\x05\x8b\x10\xca\xa2\xeb\xe0\x6b\x4c\x5c\x8c\x8a\x02\x19\x06\x29\xd0\x02\x51\x90\xaa\x85\xea\x41\x67\x78\xbb\x96\x60\x51\x26\xf4\xb6\x2c\xdc\x66\xc0\xbd\x46\x3e\x2a\x5a\xcf\xec\x5a\xb6\xd7\x53\xcd\x3c\xc7\xac\x8b\x43\x1a\xb8\xd5\x5a\x36\xa5\xf7\x0b\x40\xc5\xe0\x55\x96\xb1\x77\x22\x18\x7b\x8e\x38\x9c\x7e\x84\xa3\xf6\x9a\xea\x01\xdb\x84\x98\x72\x19\x02\x0e\x80\xe9\x0a\xe9\xdd\x33\xf0\xe0\xcb\xa2\xe0\xda\xc0\xcd\x1c\xd2\x40\x43\x54\x8f\xc0\x05\x4a\x66\x53\xb2\x10\x3c\x77\xd7\x3c\x57\x17\x85\x6d\xd3\x96\x9f\x36\x75\x04\xf1\x3e\x72\x09\xa1\xe6\x4d\x13\xd3\x01\x50\x5a\xdc\x95\x2e\x19\x3b\x37\x46\x9b\xfa\xdd\x77\x22\x05\x23\x4e\x68\x95\x02\xfe\x8a\xdc\xe3\x29\xf0\x39\xa0\x31\xa7\xc0\xef\x5d\x9b\xe2\xbc\xae\x84\x7a\x88\xb5\x99\x06\xfc\xbf\x86\xc8\x70\xab\x7e\x35\x07\xde\x84\x0c\x55\xa6\x5f\x38\x10\xb1\x9b\x01\x1d\xad\xd3\xf7\xb0\xbb\x6c\x9a\xd0\x1b\x3d\x0e\x8b\xdb\x4d\xd2\x21\xf2\x80\x64\x8e\x3f\xa8\x9e\xaa\x9f\x78\x23\x1d\x79\xd9\xf7\x06\x7b\xe2\xf0\x2f\x0e\x0d\xf9\xf5\x96\x7b\x67\x64\x62\xab\x7f\xba\x68\x7e\x65\xfd\x0d\xba\xd1\xa8\x68\xd4\x76\x69\x6b\xae\x6a\x3b\x35\x36\xf3\x48\xa2\x79\xaa\xdd\x0b\x84\x9e\x18\x50\x3d\xef\xc0\x94\xfc\x29\xdc\xea\x22\x9d\x6f\xaf\xe6\xc7\x2e\xf8\xc2\x62\xde\x83\xe0\xfb\x03\x42\x48\x94\xbd\xcf\xc6\x0a\xe1\x26\x7e\x26\xc6\xaf\xec\x1f\x99\x1c\xce\x9a\x2c\xfd\xdf\xe4\xe9\x18\x6d\xde\x6e\xd6\x58\xfe\x56\xc3\x26\xd5\x26\xc3\x26\x98\x6c\xda\xed\x51\xaa\x6a\x6b\xe1\xbd\x0b\xfc\x76\xb2\xac\x04\xae\x93\xd4\xd5\x17\x24\xaa\x82\x10\x96\x0f\xfd\xa3\x14\xe9\xcf\xa7\xf2\x47\x1e\x51\x3f\xa3\x9f\x6d\xcf\xa6\x4d\x1f\x3b\xab\x59\xb5\xeb\x39\xf4\xc1\xaf\x00\x00\x00\xff\xff\x2c\x66\xf4\x6e\xf2\x05\x00\x00") +var _templateDialectSqlByTmpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xa4\x54\x4d\x6f\xe3\x36\x10\x3d\x4b\xbf\x62\x56\x48\x0a\xc9\x70\xe4\xed\xde\xea\x62\x0f\xae\x91\x00\x0b\xb4\x69\x81\x2c\xda\xc3\x62\x51\xd0\xd4\x50\x26\x4c\x93\x32\x49\x25\x35\x04\xfd\xf7\x82\x43\xd9\x91\xd6\x9b\x20\x68\x6f\x36\xe7\xe3\xbd\x79\xf3\x34\x5d\xb7\x98\xa5\x6b\xd3\x1c\xad\xac\xb7\x1e\x3e\xbc\xff\xf1\xa7\x9b\xc6\xa2\x43\xed\xe1\x8e\x71\xdc\x18\xb3\x83\x4f\x9a\x97\xb0\x52\x0a\x28\xc9\x41\x88\xdb\x47\xac\xca\xf4\xf3\x56\x3a\x70\xa6\xb5\x1c\x81\x9b\x0a\x41\x3a\x50\x92\xa3\x76\x58\x41\xab\x2b\xb4\xe0\xb7\x08\xab\x86\xf1\x2d\xc2\x87\xf2\xfd\x29\x0a\xc2\xb4\xba\x4a\xa5\xa6\xf8\xaf\x9f\xd6\xb7\xf7\x0f\xb7\x20\xa4\x42\x18\xde\xac\x31\x1e\x2a\x69\x91\x7b\x63\x8f\x60\x04\xf8\x11\x98\xb7\x88\x65\x3a\x5b\xf4\x7d\x9a\x76\x1d\x54\x28\xa4\x46\xc8\x2a\xc9\x14\x72\xbf\x70\x07\xb5\x30\xb6\x42\xbb\x70\xb2\xd6\xcc\xb7\x16\x33\xb8\xe9\xfb\x74\xb1\x80\xdf\xc3\xfb\x5d\xab\x39\xb0\xa6\x51\x12\x1d\x30\x0d\x94\x2c\x75\x0d\x26\xa2\xbb\x83\x02\x87\x8a\xc0\xcb\xd4\x1f\x1b\x1c\xd5\x89\x56\xf3\x7c\xe6\x0e\xaa\x7c\x18\x52\x8a\x34\x74\xe6\x46\xb5\x7b\xbd\xde\x22\xdf\xa1\x05\x8b\xbe\xb5\xda\x01\xa3\x7c\x2f\x8d\x06\xa9\x2b\xc9\x99\x47\x07\x52\x10\x4c\xac\x00\xfc\x47\x3a\xef\x4e\x93\xd7\xf2\x11\xf5\x10\x2a\xd3\x50\x3c\xed\x9c\x7b\xb6\x51\x08\xce\x07\xc2\x45\x64\x73\xfa\x83\xd6\x1a\x0b\x5d\x9a\xf0\x90\xeb\x60\xf9\x11\xf6\xac\xf9\x12\xc3\x5f\x27\xa9\x1b\x63\x54\x97\x26\x49\xd7\xdd\x80\x65\xba\x46\xb8\xd2\x21\xff\xaa\xbc\x37\x15\x3a\xe8\xfb\x34\x09\x51\xb8\xd2\xe5\x1f\x8c\xef\x58\x8d\xd0\xf7\xe5\xe7\x00\xbe\x84\x8b\xf7\x3f\x99\x92\xd5\x9a\x78\xce\x87\xae\xa8\x2b\xea\xd2\x0f\x74\xe6\x60\x76\x01\x21\x72\xfb\x42\x63\x7c\x4d\x13\x29\xe0\x9d\xd9\x05\xd2\x49\xd4\xec\x85\x91\x9e\xc3\x7b\x5f\xde\x86\x57\x91\x67\x5d\x07\x1b\xe6\x10\xae\xca\xb5\xd1\x42\xd6\x23\x4a\x4b\x68\xf5\x4e\x9b\x27\x0d\x51\xb0\xeb\x43\x36\x8f\x3f\x8b\x34\x49\x22\xaf\x31\xe0\xb0\x8c\x4b\xdc\x40\x90\x38\x0f\x29\xc5\xff\x63\x33\xe0\x5c\x1f\x40\x18\x3b\xe6\x16\x03\x53\x8e\x27\x14\x2d\x55\x20\xdc\xa7\x23\x61\x5f\x77\x7e\x98\x29\x9a\x9e\x96\x71\x25\xe2\x72\x1f\xb8\x69\xb0\x24\x1f\x53\x2c\x6a\x0d\x53\x3b\xd3\x80\x34\x32\xed\x6b\xe2\x3e\x17\x2d\x70\xcf\xf6\x98\x17\x81\x65\x98\xe2\xef\x39\x50\xff\x68\x24\x21\x51\x55\x2e\xaa\x24\x45\x50\xf2\xbc\xf6\x5c\x14\x3f\xd3\xc3\xbb\x8f\x61\xa6\x98\x93\xb8\x72\x55\x55\x24\x62\xfe\x03\x19\x89\x85\xef\x85\x1e\xba\x00\xb4\x04\x31\x0f\x55\xcb\x89\xda\xe7\x8f\x96\x56\xb5\x84\xeb\xa7\x8c\xb2\x8a\x3e\xd0\x22\xf5\x12\x57\xd2\x77\xfb\xcb\x31\x0f\x03\x06\xd7\x0a\xe8\xfb\xdc\x95\xeb\x5c\x14\xc5\xd9\x08\x53\x55\x17\x33\xe0\xad\xf3\x66\x0f\xe7\xf3\x41\xbb\xaa\xad\x69\x9b\x9b\xcd\xf1\xf9\x8b\xa6\x0b\xf4\xc2\x1a\x28\xfb\xdb\x03\x94\xd0\x2d\x59\xd5\xb5\xc5\x9a\x79\x7c\xe9\x9e\x0c\x2e\x7c\xcb\xb6\x23\x0c\x73\xd9\x6b\xfb\x8c\xed\x26\xbe\x0d\x09\x2b\x97\x0b\x9d\xbb\x62\x1e\x40\x8a\x4b\x21\x5e\x01\xbc\xb0\x97\xfe\xbe\xbf\x28\xf8\x24\xfd\xf6\x2e\x98\x62\x9c\xf3\xd7\xf9\xf1\x6d\xc4\x43\x27\x29\xc6\xcd\xe2\x89\x7a\xab\x4f\x2f\xad\x18\x9a\xfc\x47\x3b\x86\xd2\xef\x58\x92\xb4\x79\xd9\x92\x27\xed\xb3\xec\x64\xd0\xf1\x9d\x1c\x6f\xa6\xeb\xc2\xac\x78\x20\x61\xb3\xdf\x90\xe9\x0c\xfa\x7e\xf5\x58\x77\x1d\xa0\x72\xe1\xa2\x90\x99\x75\xfc\x11\x5b\xe4\xb1\x6a\xa2\x10\x39\x9d\x06\x7d\xae\xcc\x66\xd9\xb9\xe6\xdb\xad\xff\x1b\x00\x00\xff\xff\x15\x9b\xac\x20\x12\x08\x00\x00") func templateDialectSqlByTmplBytes() ([]byte, error) { return bindataRead( @@ -627,7 +627,7 @@ func templateDialectSqlByTmpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "template/dialect/sql/by.tmpl", size: 1522, mode: os.FileMode(420), modTime: time.Unix(1, 0)} + info := bindataFileInfo{name: "template/dialect/sql/by.tmpl", size: 2066, mode: os.FileMode(420), modTime: time.Unix(1, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -792,7 +792,7 @@ func templateDialectSqlGlobalsTmpl() (*asset, error) { return a, nil } -var _templateDialectSqlGroupTmpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x94\x54\x5f\x6b\x1b\x3f\x10\x7c\xf6\x7d\x8a\x8d\xc9\x2f\xdc\xf9\x77\x91\xd3\xbc\x35\x25\x0f\x89\x49\x4b\xa0\x84\xb6\x2e\x7d\x29\xa5\x28\xd2\xca\x16\x91\xa5\xb3\xa4\x73\x62\x0e\x7d\xf7\xb2\xf2\x39\x7f\xdc\xba\xa1\x60\xb0\xd8\x19\xed\x8c\x76\xc7\xee\xba\xf1\xa8\x98\xb8\x66\xed\xf5\x6c\x1e\xe1\xf4\xe4\xcd\xdb\xe3\xc6\x63\x40\x1b\xe1\x3d\x17\x78\xeb\xdc\x1d\x5c\x5b\xc1\xe0\xc2\x18\xc8\xa4\x00\x84\xfb\x15\x4a\x56\x7c\x9d\xeb\x00\xc1\xb5\x5e\x20\x08\x27\x11\x74\x00\xa3\x05\xda\x80\x12\x5a\x2b\xd1\x43\x9c\x23\x5c\x34\x5c\xcc\x11\x4e\xd9\xc9\x16\x05\xe5\x5a\x2b\x0b\x6d\x33\xfe\xf1\x7a\x72\x75\x33\xbd\x02\xa5\x0d\x42\x5f\xf3\xce\x45\x90\xda\xa3\x88\xce\xaf\xc1\x29\x88\xcf\xc4\xa2\x47\x64\xc5\x68\x9c\x52\x51\x74\x1d\x48\x54\xda\x22\x0c\xa5\xe6\x06\x45\x1c\x87\xa5\x19\xcf\xbc\x6b\x9b\x21\xa4\x44\x84\xc3\xdb\x56\x1b\xb2\x73\x76\x0e\x0d\x0f\x82\x1b\x38\x64\x53\xe1\x1a\x64\x97\x3d\xd2\x13\x3d\x0a\xd4\xab\x0d\xf3\xf1\xfc\x78\x9d\xf4\x54\x6b\x05\x94\x2f\xb8\x29\xc1\xe8\xb9\x4a\x4a\x15\x84\xa5\x99\x0a\x6e\x4b\x11\x1f\x40\x38\x1b\xf1\x21\xb2\xc9\xe6\xbb\x86\x15\x68\x1b\xd1\x2b\x2e\xb0\x4b\x15\xa0\xf7\xce\x43\x57\x0c\x94\xf3\xf0\xb3\x06\x95\xd5\xb9\x9d\x21\xec\xe8\x30\xa5\xd1\xc8\x40\xdc\x81\x56\x70\x40\x30\xfb\xc4\xc5\x1d\x9f\x21\xc1\xdf\xb8\xd1\x72\xe2\x4c\xbb\xb0\xa5\xaa\x32\x6d\xe0\x31\xb6\xde\xc2\x51\xc6\x78\xd4\xce\x5e\x91\x5e\x77\xc3\x17\x78\x06\xaa\x26\xf9\x33\x50\x8b\xc8\x72\x5d\x95\x43\x6d\x57\xc4\x85\x2c\x06\xff\x2d\x81\x7c\xe5\x81\x1e\xdf\xae\x87\x35\xa8\x2a\x15\x83\x41\x2a\xe8\x13\xd0\xe4\x1d\x91\xe5\x5d\xb3\x61\x69\x3e\xb7\xe8\xd7\x65\x55\x90\x5b\xf4\x99\xb5\xbd\x41\x72\x65\xf5\x2e\x97\x0f\xce\xc1\x6a\x93\xfd\xf6\x76\xd1\xfb\xdc\xdf\xbb\xfb\x40\xb7\x8e\xc2\xd2\xb0\x2f\xee\x3e\x74\xa9\x18\x2c\xa9\x6b\x0d\xdc\xcf\xc2\x8b\x8e\x7f\x50\xdb\xf5\x24\x3d\x9d\x7a\xa6\x88\x0f\x35\x3c\x6b\x56\x03\xc9\xbd\xea\x49\xa2\x42\x9f\xa9\x6c\x62\x5c\x40\x52\xec\x29\xe4\x92\xf6\x3e\xa5\xa4\x97\x44\xa9\x61\x55\x15\xa9\xf8\x97\xe0\xf4\xcf\x80\x51\xee\xb6\x1d\x70\xf7\xea\xb0\x8b\x81\xc8\xab\xcf\x43\x59\xf0\x3b\x2c\xbf\xff\x08\xd1\x6b\x3b\xab\xe1\xa4\x06\x83\x76\x57\xbe\xcf\x53\x05\xff\xff\x86\x12\x68\x43\x55\x3d\x35\x3d\x07\xde\x34\x68\x65\xd9\x17\xea\x3d\xe9\x64\x8c\x55\x4f\x59\xb6\x7f\x09\xb3\xdd\x24\x79\xbf\x80\xb2\xe5\xf6\xcd\x1b\xb5\x3d\x61\x27\x9b\xe9\x69\x09\xdb\x3c\x6c\x66\xb7\x6d\x47\xbe\xd8\x07\x8a\xf1\xe5\x7a\xcf\x20\xb2\xf5\xfc\x4f\x80\x56\xd2\xcf\xfd\x57\x00\x00\x00\xff\xff\x39\xd5\x28\xb4\x22\x05\x00\x00") +var _templateDialectSqlGroupTmpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x94\x53\x5d\x6b\x1b\x3b\x10\x7d\xde\xfd\x15\x13\x93\x1b\x76\x7d\x37\x72\x6e\xde\x6e\x2e\x79\x48\x4c\x6e\x09\x94\xd0\xd6\xa5\x2f\xa5\x14\x45\x1a\xd9\x22\xb2\xb4\x96\xb4\x4e\xcc\xa2\xff\x5e\x46\x5e\xe7\xc3\x6d\x1a\x0a\x06\x2f\x3a\x67\xe6\x1c\xcd\x1c\xf5\xfd\x64\x5c\x4e\x5d\xbb\xf1\x7a\xbe\x88\x70\x7a\xf2\xcf\xbf\xc7\xad\xc7\x80\x36\xc2\xff\x5c\xe0\xad\x73\x77\x70\x6d\x05\x83\x0b\x63\x20\x93\x02\x10\xee\xd7\x28\x59\xf9\x79\xa1\x03\x04\xd7\x79\x81\x20\x9c\x44\xd0\x01\x8c\x16\x68\x03\x4a\xe8\xac\x44\x0f\x71\x81\x70\xd1\x72\xb1\x40\x38\x65\x27\x3b\x14\x94\xeb\xac\x2c\xb5\xcd\xf8\xfb\xeb\xe9\xd5\xcd\xec\x0a\x94\x36\x08\xc3\x99\x77\x2e\x82\xd4\x1e\x45\x74\x7e\x03\x4e\x41\x7c\x26\x16\x3d\x22\x2b\xc7\x93\x94\xca\xb2\xef\x41\xa2\xd2\x16\x61\x24\x35\x37\x28\xe2\x24\xac\xcc\x64\xee\x5d\xd7\x8e\x20\x25\x22\x1c\xde\x76\xda\x90\x9d\xb3\x73\x68\x79\x10\xdc\xc0\x21\x9b\x09\xd7\x22\xbb\x1c\x90\x81\xe8\x51\xa0\x5e\x6f\x99\x8f\xdf\x8f\xe5\xa4\xa7\x3a\x2b\xa0\x7a\xc1\x4d\x09\xc6\xcf\x55\x52\xaa\x21\xac\xcc\x4c\x70\x5b\x89\xf8\x00\xc2\xd9\x88\x0f\x91\x4d\xb7\xff\x0d\xac\x41\xdb\x88\x5e\x71\x81\x7d\xaa\x01\xbd\x77\x1e\xfa\xb2\x50\xce\xc3\xf7\x06\x54\x56\xe7\x76\x8e\xb0\xa7\xc3\x94\x46\x23\x03\x71\x0b\xad\xe0\x80\x60\xf6\x81\x8b\x3b\x3e\x47\x82\xbf\x70\xa3\xe5\xd4\x99\x6e\x69\x2b\x55\x67\x5a\xe1\x31\x76\xde\xc2\x51\xc6\x78\xd4\xce\x5e\x91\x5e\x7f\xc3\x97\x78\x06\xaa\x21\xf9\x33\x50\xcb\xc8\xf2\xb9\xaa\x46\xda\xae\x89\x0b\x59\x0c\xfe\x5a\x01\xf9\xca\x03\x3d\xbe\xdd\x8c\x1a\x50\x75\x2a\x8b\x22\x95\xf4\x0b\x68\xf2\x8e\xc8\xf2\xbe\xd9\xb0\x32\x1f\x3b\xf4\x9b\xaa\x2e\xc9\x2d\xfa\xcc\xda\x55\x90\x5c\x55\xff\x97\x8f\x0f\xce\xc1\x6a\x93\xfd\x0e\x76\xd1\xfb\xdc\xdf\xbb\xfb\x40\x55\x47\x61\x65\xd8\x27\x77\x1f\xfa\x54\x16\x2b\xea\xda\x00\xf7\xf3\xf0\xa2\xe3\x2f\xd4\xf6\x3d\x49\x4f\x5f\x03\x53\xc4\x87\x06\x9e\x35\x6b\x80\xe4\xde\xf4\x24\x51\xa1\xcf\x54\x36\x35\x2e\x20\x29\x0e\x14\x72\x49\x7b\x9f\x51\xd2\x2b\xa2\x34\xb0\xae\xcb\x54\xfe\x49\x70\x86\x6b\xc0\x38\x77\xdb\x0d\xb8\x7f\x73\xd8\x65\x21\xf2\xea\xf3\x50\x96\xfc\x0e\xab\xaf\xdf\x42\xf4\xda\xce\x1b\x38\x69\xc0\xa0\xdd\x97\x1f\xf2\x54\xc3\xdf\x3f\xa1\x04\xda\x50\xd7\x4f\x4d\xcf\x81\xb7\x2d\x5a\x59\x0d\x07\xcd\x2b\xe9\x64\x8c\xd5\x4f\x59\xb6\xbf\x09\xb3\xdd\x26\xf9\x75\x01\x65\xab\xdd\x9d\xc9\x49\x7a\x9a\xf3\x6e\xe5\xdb\xf1\xec\x2a\x48\x9a\xbd\xa3\xa4\x5e\x6e\x5e\xb9\x6b\x76\x97\x1f\x3b\x5a\x49\x2f\xfa\x47\x00\x00\x00\xff\xff\x5b\xd3\x68\x36\x05\x05\x00\x00") func templateDialectSqlGroupTmplBytes() ([]byte, error) { return bindataRead( @@ -807,7 +807,7 @@ func templateDialectSqlGroupTmpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "template/dialect/sql/group.tmpl", size: 1314, mode: os.FileMode(420), modTime: time.Unix(1, 0)} + info := bindataFileInfo{name: "template/dialect/sql/group.tmpl", size: 1285, mode: os.FileMode(420), modTime: time.Unix(1, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -872,7 +872,7 @@ func templateDialectSqlPredicateTmpl() (*asset, error) { return a, nil } -var _templateDialectSqlQueryTmpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xec\x3b\x6b\x6f\xdb\xc6\xb2\x9f\xa5\x5f\x31\x25\x9c\x40\x32\x64\xda\xc9\xbd\xb8\xc0\x95\xe1\x0b\xf8\xc6\x36\x20\xa4\x75\x72\xea\xb4\xfd\x60\x18\x2d\x4d\x2e\xe5\x85\xa8\x25\x4d\x2e\xfd\x80\xa2\xff\x7e\x30\xb3\x0f\x2e\x5f\x92\xec\x36\xed\xc1\xe9\xf9\x90\xd8\xdc\x9d\xd9\x99\x9d\xf7\xce\xae\x57\xab\xc3\xfd\xe1\x87\x34\x7b\xce\xf9\xfc\x4e\xc2\xfb\xa3\x77\xff\x7b\x90\xe5\xac\x60\x42\xc2\x45\x10\xb2\xdb\x34\x5d\xc0\x4c\x84\x3e\x9c\x26\x09\x10\x50\x01\x38\x9f\x3f\xb0\xc8\x1f\x7e\xb9\xe3\x05\x14\x69\x99\x87\x0c\xc2\x34\x62\xc0\x0b\x48\x78\xc8\x44\xc1\x22\x28\x45\xc4\x72\x90\x77\x0c\x4e\xb3\x20\xbc\x63\xf0\xde\x3f\x32\xb3\x10\xa7\xa5\x88\x86\x5c\xd0\xfc\xf7\xb3\x0f\xe7\x97\x57\xe7\x10\xf3\x84\x81\x1e\xcb\xd3\x54\x42\xc4\x73\x16\xca\x34\x7f\x86\x34\x06\xe9\x10\x93\x39\x63\xfe\x70\xff\x70\xbd\x1e\x0e\x71\x0f\x70\x1a\x45\x5c\xf2\x54\x04\x09\xc4\x9c\x25\x51\x01\x71\xaa\x88\xdf\x96\x3c\x89\x58\xee\x03\x41\xaf\x56\x10\xb1\x98\x0b\x06\x5e\xc4\x83\x84\x85\xf2\xb0\xb8\x4f\x0e\xef\x4b\x96\x3f\x1f\x2a\x4c\x0f\xd6\xeb\xe1\x60\xb5\x3a\x80\x47\x2e\xef\x60\xcf\xff\x49\xb0\xa7\x2c\xcd\x25\x8b\x2e\xd2\x9c\xf1\xb9\xf8\xc8\x9e\x0b\x02\x1a\x20\xc4\xc5\xc7\x02\x6e\xd3\x34\x51\x38\x4c\x44\x40\x74\xec\xaf\x1b\x69\x7a\x0a\x18\xf6\xb2\xc5\x1c\xa6\x27\xb0\xe7\x5f\x85\x69\xc6\xfc\xcf\x41\xb8\x08\xe6\xcc\xcc\xea\x4d\x20\x44\x16\x14\x61\x90\x58\xc0\xff\xd7\x33\x1a\x30\x67\x21\xe3\x0f\x0a\xd2\xfe\x6e\xd1\x91\x9b\xb8\x14\x21\x8c\x6a\xb0\xeb\x35\xec\xbb\x54\xd6\xeb\x31\x14\xf7\xc9\x69\x92\x8c\x42\xf9\x04\x61\x2a\x24\x7b\x92\xfe\x07\xf5\x73\x0c\xa3\xeb\x1b\x82\xf7\x2f\x83\x25\xb2\x38\x01\x96\xe7\x69\x3e\x86\xd5\x70\xf0\x10\xe4\x30\x1a\x0e\x06\x22\x8d\x58\x01\x27\xd0\x00\x5d\xa1\xd4\x76\x93\xad\x15\xee\x09\x34\xb8\xf5\xf5\x8c\x5e\x4a\xcb\x79\x30\xf8\xb5\xc8\x58\xd8\x01\x4e\x92\xbe\xca\x58\x38\x1a\xd7\xa9\x9f\x47\x73\x66\xa8\x25\x69\x10\xb1\xe8\xcb\x73\xa6\xd8\x5e\xad\x20\x61\x02\x7c\x58\xaf\x6f\x50\xbb\x2b\x84\x21\xdc\x3c\x10\x73\x06\x7b\x0c\x45\xec\x6b\x64\x9c\xa9\xd3\xc4\x6f\xe6\x9f\x07\x73\x96\x7f\x9f\x06\xd1\x05\x9a\x16\x0a\xfa\xbb\x13\x10\x3c\x99\xd8\xd5\x2c\xf3\x83\x75\x63\x3b\xe3\x5d\x8d\xd0\x05\xbb\xf8\xe8\xee\x69\xc0\x63\x14\x86\xe6\x98\x4f\x1c\xae\x57\x2b\xe0\x31\xcc\x25\xec\x71\x38\x42\xc6\xbe\x7e\x45\x50\x45\xfc\x65\x9b\xb1\x68\xa0\x84\xe4\x28\x4e\xe6\x25\xa3\x31\xcb\x67\xb5\x5f\x1e\x83\x01\x54\x78\xa4\x3e\xff\x32\x8d\x98\xff\x21\x4d\xca\xa5\xc0\x15\x82\x2c\x63\x22\x1a\xb5\xe7\x26\xa4\x66\xc7\x51\x7c\x47\x30\xbe\xef\x8f\xb5\x4c\x5d\xa2\x6a\x95\xab\x30\x10\x3f\x07\x49\x49\x8a\x46\x77\x18\x85\x9a\xdc\xf5\x4d\x21\x73\x2e\xe6\x64\xe2\x5c\x48\x96\xc7\x41\xc8\x56\x35\x03\x27\xcb\x46\x29\xbe\xad\xd9\x75\x98\x8a\x98\xcf\xa7\x2d\xdb\x53\xe3\x6b\xc7\x23\xf4\x8e\xe8\x73\x02\xf8\x03\x59\xcd\x99\x2c\x73\x41\x9f\x7e\x61\x19\x34\x9c\x8d\x87\x03\xcb\xfe\x69\x51\xf0\xb9\xe8\x63\x7d\x02\x0f\x6a\x6b\xb5\x0d\x8c\xd5\x06\x88\x7f\x1e\xa3\x65\x2b\xfa\x63\x38\x39\x81\x23\x25\x7f\xcd\x41\xbc\x94\xfe\x39\x02\xc7\x23\xcf\x04\xa6\xf5\x7a\x0a\x9a\x6c\x18\x24\x09\x8b\x48\x73\x69\x29\xe9\x93\x8b\x39\x54\x32\xf5\x70\x37\x6b\x47\x4e\x44\xe8\xba\x22\x79\xf0\xee\xa6\xdf\x0b\x69\xff\x34\xe0\xd7\x1d\xd2\xf9\x6a\xba\xbd\x2b\xba\x80\xb8\xac\x0b\xcf\x88\x44\x09\x11\x51\x31\x61\x24\x49\xfa\x08\xcb\x52\x06\x12\xf9\xc7\x4c\x51\xdc\x27\xf3\x3c\xc8\xee\xfc\x7f\x98\x78\x01\xb7\xcf\x80\xa9\x90\x3d\x49\x26\x0a\x9e\x8a\x02\xd2\x1c\xca\x02\xf3\x1a\x5b\x66\x49\x20\x59\xe1\x53\x5e\x71\xf6\x23\x97\x59\x52\xe0\xc6\x97\x81\x0c\xef\xbe\x68\xb8\xae\x7c\x83\xea\x3c\xdc\x57\xf9\xc6\x0d\x2d\xb8\x02\x25\x03\xb5\x54\xe5\xe4\x4f\x86\xaa\x86\xd9\xab\x50\x8d\x34\xdc\xdf\x79\x8c\x6a\xc7\x95\xea\x5b\x43\x37\x2a\x30\xb4\x4f\x5a\xe6\x1a\xe5\xf8\xdb\x04\xc8\xd4\xc6\xc7\x84\xaf\xbc\x9c\x8c\xc4\x88\x9a\x27\xe4\x12\x24\xd0\x1e\x7b\x72\xb4\x82\x66\xce\x13\x2b\x7d\x37\x80\xd6\xb4\x6f\x65\x48\xf2\x8e\x60\x0f\xbc\x1f\x59\xe8\x39\x1c\x7a\x08\xed\x21\xae\x11\x8a\x55\x44\x97\x80\x19\x46\x2c\xb4\x1c\x2e\xe6\x9e\x09\xd6\x7d\xd2\x6a\x33\xfc\xa2\x6c\xf9\x21\x2d\x85\xec\xc9\x97\x5c\x48\x37\x84\xa8\x3c\x35\xdd\x92\xa8\xfe\x46\x86\xaa\x45\x6f\xf7\x45\xb2\xdc\xd9\x4a\x5f\xa6\xa7\xf3\x27\x5e\xf4\xe9\x09\x13\xbd\xab\x28\x31\x31\x0e\xd4\xe4\xc0\x55\xf8\xd8\x7a\x5a\xdb\x53\xe2\x20\x29\xd8\xa4\x37\xa8\x86\x77\x2c\x5c\x00\x43\x96\x98\x08\xd9\x14\xde\x3c\x7a\x44\x53\xc5\x2a\x63\x92\xf0\x7f\x70\xf4\x52\x93\x74\x6c\x09\xf6\x3b\x2c\xc6\xb5\xc3\xb7\xed\x79\xdc\x03\x6a\x60\xea\x4c\xe2\xb7\x99\x1b\x7c\x09\x6e\x13\x36\x6d\x25\x61\x1a\xa6\xf2\x46\xe7\xe9\x36\x88\x49\xe0\x08\x34\x3b\x73\x09\x50\x61\x61\x29\x0c\x30\xda\x4f\x55\x3d\x4f\x25\x88\x3f\x3b\xf3\x71\x0c\x35\x56\x48\x53\x7c\x12\xa8\x5a\xb3\x4d\xcb\xa0\x11\x46\x20\xa4\x41\xa0\xff\xe9\xbf\x8b\x3c\x5d\xb6\xd3\x76\x71\x4f\x35\xda\x4f\x82\xdf\x97\x6c\x4a\x75\xcc\xc4\x44\xbb\x92\x06\xbb\xac\x42\xcd\x1c\x1b\x08\xc7\x1c\x54\xea\x56\xcb\xc1\x09\xec\x2b\x08\xb3\xa2\x3e\xb2\x74\xac\xa8\x66\x8e\x29\xc2\xaa\xdf\xc7\x68\x0c\xce\x9a\x8d\x7a\x69\x19\x2c\xd8\xa8\x2a\x06\x8e\x26\x2e\xea\xb8\x0f\xeb\x05\x55\x56\x87\x44\x71\x59\x3c\x6e\x71\x3a\x75\x50\x68\xd0\x3b\x5a\xe9\x42\x4f\x7d\x5e\xf3\x1b\x94\xc9\x0e\x2b\xbe\xb2\x22\xb4\x64\x4c\xf1\x87\xff\x94\x88\xb3\x4e\xf1\x66\x39\x8b\x78\x88\xd1\x51\x89\x38\x6b\x89\xf7\xb3\x81\x30\x05\x57\xc1\x12\x3a\x93\x92\x53\xf9\x57\xfa\x4b\x95\x86\x4d\x29\x64\xa6\xb4\xcd\x90\x29\x8b\xda\xe6\x2e\xe1\x4b\x2e\xbb\x18\xa4\x89\x63\x3d\xdf\xb2\xa7\xef\x69\xf8\x04\xf6\x69\xde\x2c\x96\xc6\x71\xc1\x3a\x57\x53\x33\xc7\x06\xa2\xb5\xde\x27\x35\x7e\x02\xfb\x0a\x62\xb3\xf0\xd2\x3c\x62\x79\x9f\xdc\x3e\xe1\xe4\x1f\x28\xb3\xb6\x21\xfe\x1c\x24\x3c\x52\xaa\x6f\x08\x54\x87\x4d\x62\x64\xa8\x0e\xe0\x9b\xaa\x04\x43\xc2\x53\x69\x6a\x38\xac\x65\xde\x20\xc2\xea\x01\x96\x4c\xde\xa5\x51\x01\x32\xa5\x14\x4c\x98\x07\x26\xde\xee\x9c\x7d\x5f\x93\x7c\x03\xdb\xda\x30\x29\x78\x4b\x06\xee\x4f\xc0\xfd\x0d\x8a\x5d\x7a\x15\x8e\xa0\xfe\xca\xb6\x04\xa5\xa8\xee\x04\x5e\x33\x2f\x34\x23\x87\x41\x9c\x3a\x53\x5b\x6a\x92\xd1\xf5\x84\x9d\xc6\x38\x29\xdf\x21\x92\x69\x1e\x51\x56\x1b\x75\xe6\xba\xf1\x70\x60\xcd\xdb\xc1\x50\x5c\x8c\xe4\x3b\x13\x9d\x5a\xd8\x7a\x1c\x8f\xaa\xf4\x0f\x13\xd1\x48\xbe\x53\xd5\x44\x47\x3e\x72\xdd\xd5\x52\xec\xac\x4c\x1c\x00\xc3\x87\xfd\xde\x91\x9b\xed\x47\xa5\x4a\xce\xdf\xb8\xf6\xd4\x64\xbe\x65\xfd\x89\x21\xe8\xd7\x09\x64\x55\x14\xea\x4f\x14\x24\xff\xcc\x8d\xe5\x3b\x2d\x40\xc1\xb2\x81\xbb\x2d\xa6\xbd\x3e\x9c\x1f\x1e\xea\x94\xc1\x0b\x58\x06\x22\x0a\xa8\x89\x8a\x5c\x6a\xd8\x30\x09\xca\x82\xf9\xf0\x0b\x83\x42\x06\xb9\x54\x38\xa4\x9b\x88\xc5\x41\x99\x48\x75\x68\x9e\x40\x20\x22\x48\x1f\x58\x9e\xf3\x88\x01\x97\x70\xcb\xd0\x1a\x78\x0c\x82\xb1\x88\x45\xbe\x6b\x6c\x2a\x7f\x8c\x74\xf6\x18\xab\xfc\x34\x5a\x06\xf2\xce\xff\x21\x78\x9a\x09\xf9\x5f\xef\xc7\xaf\x4e\x79\x96\x8a\x5a\x55\xe5\xbc\x5a\x9d\x6c\x20\x86\xeb\x7a\x48\x3b\xdc\x57\xf1\xfa\x30\x0b\xd4\xfe\xb8\x60\x45\x15\xc6\x61\xce\x04\xcb\x03\x8c\xb1\x24\x22\x82\x4a\x63\x08\x60\xce\x1f\x98\x00\x16\xcd\xd9\x2e\xed\x63\xc4\xab\x02\xf4\x9e\x20\xcb\xa4\x62\x06\x39\x40\x72\xd4\x12\x79\xd4\x22\x77\x18\x88\xf3\x74\xa9\x29\x28\x5c\xe6\xf6\x82\xf1\xcc\x5b\x5b\x06\x19\xc2\x65\x50\x03\x98\x90\x90\xff\x79\x8e\xd6\x8e\xb3\xc4\xbe\x4c\x6b\xeb\xf1\x08\x3d\xd4\x59\x73\x46\x03\x07\x16\xc0\x0d\xd0\x06\xe6\xc7\x4a\x29\xf5\x18\xd7\x11\x72\x6c\x50\x1e\xd7\xda\x0e\x46\x65\xe7\x79\x3e\xda\xad\x9b\x50\x48\x96\xd5\x3a\x16\x97\xec\xf1\x4a\xb2\x6c\x84\x16\x60\xcf\x09\x18\x2a\x91\x0b\xd1\x3e\x7a\x40\x6b\x5c\x0d\x34\x0e\x01\x96\xb7\xf1\xc4\x5d\xf9\x4b\x3a\x52\xcd\x4e\x3a\x67\x74\x2f\xde\x9e\x74\x46\x1b\x95\x71\x6d\x71\x54\xe4\xc8\x7e\x29\xa4\x1f\x59\x42\x88\xc4\x93\x1a\x9a\x15\x33\xf1\xc0\xf2\xa2\x1a\x6b\x6d\x87\x29\x7e\x9a\xa7\x1a\x54\x25\x8f\x71\xfa\x87\xf7\x3f\x28\xed\xea\xd6\x74\xc7\x0a\x9f\x3f\x3a\xe8\xbe\xef\xdb\x0e\x6d\x52\xb0\x6d\xb8\x2a\x44\x39\xf8\x6e\x7b\x57\xe1\xe2\xd6\x77\xeb\x62\xa0\x7a\xbf\x71\x12\x41\x9f\xf8\x96\x09\x04\xa5\xa2\x7c\x6c\xbd\x06\xc7\x78\xaf\x98\xbc\x64\x7c\x7e\x77\x9b\xe6\xc5\xd6\xaa\x63\x02\x68\xfc\xe3\x9e\xd8\x85\x31\x62\x7b\xec\x0a\x54\xb8\x72\xe2\x8a\x0d\x63\xd4\x1f\xdd\xe5\x16\x2c\x4f\x97\xff\x96\x61\x8c\xc0\x78\xd4\x15\xc1\x66\x67\x7f\x62\xe4\xe1\xd1\x7f\x62\xce\xdf\x20\xe6\xa0\x37\xfc\x15\x31\xe7\x77\x06\x9c\x0d\x91\xa1\xde\x4e\xdf\xe8\xe5\x9b\xfd\xd1\xdc\x4e\xa8\xb0\xd1\xe1\x8f\x3d\x17\x8e\xc7\x1a\xc3\xa9\x21\xea\xe6\xa7\x84\x19\x2f\xb4\x86\xa8\x15\xa6\x77\xfd\xb3\x2a\x67\x75\x43\x4c\xdd\x56\x50\xdf\x80\x47\x15\xf4\x32\xc8\xae\xdd\x4e\x23\xac\xd7\xcd\xab\xee\x06\xb6\xae\xfc\xcd\x65\x97\xd2\xb1\xba\xea\x53\x6d\x0c\x1e\x15\xd7\x14\x7a\x67\x67\x37\xa0\x6e\xc3\x68\x1c\x99\xb4\xcd\xac\x78\xa1\xaf\x02\xfd\xd9\x19\x2d\xeb\x5e\x84\x29\x61\x5c\xc9\xbc\x0c\xa5\xbd\x7a\xb5\x37\xeb\x3a\x06\xb8\xf7\xeb\xea\xb6\xd5\xdc\xc4\x0f\x06\x18\x68\x71\x97\xd7\x37\xf5\xa0\xa1\x77\x68\x61\x6c\xff\xd0\x88\xa1\x05\x7a\xd3\xb8\xce\x27\x5e\xe9\xbf\x8e\x5e\x32\x72\x5f\xeb\x27\x0f\x06\x38\x34\x6d\x80\x54\xb3\x03\x1d\x83\xa6\x5d\x41\x49\x41\xf4\x74\x9d\x37\xc4\xa7\x0d\x8d\xe8\x8e\x98\xa4\x50\xf4\x0f\xdb\xfb\x9b\xea\x36\x56\x67\xff\x6a\x30\x28\xfc\x5f\xee\x58\x4e\x61\xd6\x9f\x99\xcb\xc7\x1d\x88\x5d\xab\x4b\xf8\xc6\x4e\xdf\xa1\x3b\x26\xf4\xeb\x91\xf5\xcc\x9b\x09\xc4\x0b\x3a\x9d\x8f\x5d\x0e\xc9\xc9\xd2\x92\x72\xa2\x87\xe4\x2f\xcb\x24\x99\x09\xf9\x3f\xff\xed\xd9\x3b\x7e\x32\xe6\x9f\x0a\x96\x9f\x91\x63\x9b\xfb\x7d\xc4\x3a\x51\x93\x88\xa4\x15\x5c\x85\x02\xbb\x3c\x17\x1b\x57\xaf\x6c\xa4\x4d\x83\x0b\x24\x51\x41\xf4\x12\xaa\xae\x8f\xb5\xa8\xc7\x70\xfd\xde\xbd\xc0\xd6\x92\xd6\x27\x86\xc6\xdc\x5b\xb3\x1f\xb4\xff\x89\xba\x98\xe7\x82\xbe\xd6\xae\xb4\xd4\x15\xb6\xa6\x90\x96\x72\x02\x5c\x40\xcf\x2d\x39\xba\x04\x81\xa4\x0b\xdc\x7e\x5a\x4a\x7f\xb4\x5f\xd1\x51\x5a\xc0\x18\xf6\x5d\xba\x80\xaf\x5f\x81\x91\x3c\xab\xb8\x34\xe8\xbe\x51\x2f\x05\x7b\xca\x58\x28\x59\x04\x3c\x52\x47\x6c\xaa\xdb\xd0\xfd\x0e\xd2\x52\x7a\x7a\x61\xfd\xaa\x84\x71\x61\x38\xe0\x42\x33\x40\x3b\x6b\xd3\x47\x59\xff\x3e\xf2\x5c\x34\xa8\xa7\xa5\x24\xa5\xe8\x08\xdd\xb8\x9a\x3d\xcd\xe7\x1e\x78\xb8\x6f\x0f\x3c\x8a\x4b\x1e\x99\x13\x78\x46\xcd\x9e\xd5\xca\xee\xd7\xb4\x87\xcb\xf7\x4b\x75\x89\xef\x99\x77\x2a\x8e\x9d\x0c\xb8\xd8\xce\x11\x17\x0e\x43\xd6\xf8\x6a\x6c\x29\xeb\xf8\xc3\xb8\xc2\x60\x6d\xf5\x14\x15\xd7\x46\x70\x37\x35\x2d\xed\xa6\x17\xca\x24\x3c\x42\xd3\xa4\x98\x3c\x85\x37\x0f\xde\x04\xcc\x92\x75\x0d\xf1\x18\x93\x8f\x22\x4c\xd0\xd7\x5a\x40\x37\xc7\x35\x92\x26\x03\xd8\x84\xa3\x07\xd0\x03\x3a\x96\xad\x2f\x55\xc7\xaa\xc6\xab\xc7\x2b\x03\xf7\x2c\xef\x78\x9c\x2d\x6d\xb6\x54\x63\xb5\x4c\xf1\x82\xb2\x6c\xf0\xaa\xc2\xac\xa6\xd8\xda\xd5\xf6\xf6\x3a\x6d\x63\xa5\xd6\x7c\xe1\xd4\x7a\xef\xd4\xf9\x1a\x83\x92\xfc\xab\x5e\x63\xd4\x3b\x28\x8e\x3d\xfd\xa6\xca\x24\x95\xd3\x3d\x95\x79\x74\xc6\xf6\xa6\xf0\xe6\xf1\x37\x73\xc7\xac\x55\x44\xe0\x3a\x89\x75\x9f\x36\x66\x67\x33\x61\x8c\xc6\x66\x21\x61\x2a\x4d\x7b\x4d\xae\x16\xd2\x0f\x07\xc7\xce\xae\x7b\xb9\xa6\xbe\x8f\x66\xc3\xd4\x52\x4e\x21\x65\x28\x68\x4c\xfd\x38\xc3\x35\x78\x3c\x5f\xdd\x0c\xdb\x6e\xd6\x27\x1a\xc7\xd5\x1a\x92\x51\xae\xa7\xf0\x58\x04\x6f\x1e\x7e\x9b\x80\xb0\x05\x99\x52\x70\xe3\xb2\xca\x2d\xf4\x14\x73\xd7\xfc\x66\x73\xd9\xe6\x3e\xff\xda\x0c\x3c\x01\xe1\x90\xb6\x47\x01\x2c\x0d\x54\xe2\xfd\xf4\x28\x2e\x3e\x1a\xeb\x8a\xdc\x9a\xb7\xb3\x78\xeb\x2a\x7e\xf1\xd7\xae\x02\x78\xb7\xca\x6f\x83\x34\xa8\x74\x88\x49\x4f\x7b\xcc\x79\x9a\xe7\x7a\x0a\xee\x22\x5e\xa8\x5b\x7f\xff\x92\x27\x09\xd6\x6f\x55\xa8\x8f\xc1\x0a\x09\x29\xc7\x8b\x96\x30\xeb\xf9\x2e\x4c\x85\xe4\x42\x3f\x38\x1c\x74\x39\xe4\x40\x31\xa4\x2b\x97\x2e\xd2\xfb\x36\xb2\x6f\xa6\x3d\x6c\x46\x5f\x2d\xc8\xeb\x78\x51\x0f\xbd\xb5\xa8\x4b\x11\x37\x5e\xb8\x06\xe5\xe0\xd5\x8d\x43\x0f\x4e\xac\x10\x5e\xec\xae\xff\x42\xae\x6a\x36\xf4\x3b\x9c\x35\x56\x26\x74\xb0\x60\xcf\xca\x71\x2b\xf5\x19\xff\xfd\xd6\xae\x2b\x7a\xbc\xf1\x35\xa7\xce\x3e\xc7\xeb\x3d\x79\x6e\x73\xb8\xee\xf3\x24\x6d\xca\x3d\x53\x9a\xe3\xa8\x9e\x30\x47\x52\xfc\xac\xf9\x25\xf3\x3f\xd9\xf3\xf4\xae\xf2\xd9\x70\x22\xed\x78\xeb\xec\xe4\x6e\x8d\xb1\xe9\x8d\xb3\x36\xfa\xf6\xd3\xe2\xfa\xc2\xae\x6b\xd8\x3b\x3b\xbf\xcd\xd5\x68\xd3\x81\xee\x05\xe7\xb9\x56\x53\xaa\x7e\x4e\x5b\xff\x55\xde\xb7\x2d\xfc\xaa\x29\xb1\x21\xb8\xed\x14\xa0\xe3\x45\xcf\xa1\xa3\xe5\xc9\x3b\xb8\x2f\x2f\x68\x25\xdc\x1c\x65\xe2\x86\x17\x77\x47\x74\xb7\xf4\x36\xf6\xbd\x53\x7c\x8f\x17\x7f\x4a\x20\x6a\x6c\x67\x47\xd6\xda\xd1\x6b\x7b\x3f\xc8\x9e\xc7\xf5\x2b\xb3\xf5\x5a\x54\x9d\x04\x27\xb1\x6c\xa9\x38\x6a\x47\x9b\x66\xff\x71\xfd\xaa\x5e\xa1\x7b\x7a\xb2\xad\xc1\x20\xaf\xfd\xc1\xcb\x69\x3e\xaf\xe6\xe8\x1d\x97\x3b\x5b\x59\xa6\xba\x93\x28\x93\x44\x62\x90\x74\x40\x9c\xde\xc2\xd0\x18\xef\x5d\x50\x7c\xce\x59\xcc\x9f\x1c\x14\xaf\xb8\x4f\x3c\xdd\x2f\x26\xed\x29\x3d\x68\x6c\x45\x88\x98\xb3\xb7\x0a\x4e\x73\x5a\xc9\x58\xa4\xd2\xe2\x75\xe9\x0f\x97\x0d\x9c\xfd\xbc\xf0\x6f\x82\x0e\xb3\x9c\x65\x41\xce\xe8\x71\xa8\x23\xb1\xfe\x3f\x11\xda\xf1\x0e\x57\x87\x8d\xb8\xff\x95\x82\xf3\x6c\x0f\xfd\x62\xc3\xf3\x84\x51\x3c\xae\xfd\xd9\xc0\x5b\x9a\xa3\x8b\x28\xf2\x99\x15\x3a\xc3\x14\x62\x0a\x5c\xd3\xde\x07\xb0\x5c\x3c\x20\x9e\x7a\xb9\x07\x6f\xee\xc9\x5d\xd4\x5f\x46\x4d\x20\x1e\xdb\x37\x5d\x8e\xec\xfe\x19\x00\x00\xff\xff\x80\x4c\x37\xd5\xb4\x36\x00\x00") +var _templateDialectSqlQueryTmpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xec\x3b\x6d\x6f\xdb\x46\xd2\x9f\xa5\x5f\x31\x25\x9c\x40\x32\x14\x3a\xc9\xf3\xe0\x01\x1e\x19\x3e\xc0\x17\xdb\x80\x90\xd6\xc9\xd5\x69\xfb\xc1\x30\x5a\x9a\x1c\xca\x0b\x51\x4b\x9a\x5c\xf9\x05\x8a\xfe\xfb\x61\x66\x5f\xb4\x7c\x91\x2c\xbb\x97\xf6\x70\xbd\x0f\x49\xc4\xdd\x99\x9d\xd9\x79\xdf\xd9\xcd\x72\x79\xb0\xdf\xff\x90\x17\x8f\xa5\x98\xde\x28\x78\xff\xf6\xdd\xff\xbf\x29\x4a\xac\x50\x2a\x38\x8b\x62\xbc\xce\xf3\x19\x4c\x64\x1c\xc2\x71\x96\x01\x03\x55\x40\xf3\xe5\x1d\x26\x61\xff\xcb\x8d\xa8\xa0\xca\x17\x65\x8c\x10\xe7\x09\x82\xa8\x20\x13\x31\xca\x0a\x13\x58\xc8\x04\x4b\x50\x37\x08\xc7\x45\x14\xdf\x20\xbc\x0f\xdf\xda\x59\x48\xf3\x85\x4c\xfa\x42\xf2\xfc\xf7\x93\x0f\xa7\xe7\x17\xa7\x90\x8a\x0c\xc1\x8c\x95\x79\xae\x20\x11\x25\xc6\x2a\x2f\x1f\x21\x4f\x41\x79\xc4\x54\x89\x18\xf6\xf7\x0f\x56\xab\x7e\x9f\xf6\x00\xc7\x49\x22\x94\xc8\x65\x94\x41\x2a\x30\x4b\x2a\x48\x73\x4d\xfc\x7a\x21\xb2\x04\xcb\x10\x18\x7a\xb9\x84\x04\x53\x21\x11\x82\x44\x44\x19\xc6\xea\xa0\xba\xcd\x0e\x6e\x17\x58\x3e\x1e\x68\xcc\x00\x56\xab\x7e\x6f\xb9\x7c\x03\xf7\x42\xdd\xc0\x5e\xf8\x93\xc4\x87\x22\x2f\x15\x26\x67\x79\x89\x62\x2a\x3f\xe2\x63\xc5\x40\x3d\x82\x38\xfb\x58\xc1\x75\x9e\x67\x1a\x07\x65\x02\x4c\xc7\xfd\xdc\x4a\x33\xd0\xc0\xb0\x57\xcc\xa6\x30\x3e\x82\xbd\xf0\x22\xce\x0b\x0c\x3f\x47\xf1\x2c\x9a\xa2\x9d\x35\x9b\x20\x88\x22\xaa\xe2\x28\x73\x80\x7f\x37\x33\x06\xb0\xc4\x18\xc5\x9d\x86\x74\xbf\x1d\x3a\x71\x93\x2e\x64\x0c\x83\x1a\xec\x6a\x05\xfb\x3e\x95\xd5\x6a\x08\xd5\x6d\x76\x9c\x65\x83\x58\x3d\x40\x9c\x4b\x85\x0f\x2a\xfc\xa0\xff\x1d\xc2\xe0\xf2\x8a\xe1\xc3\xf3\x68\x4e\x2c\x8e\x00\xcb\x32\x2f\x87\xb0\xec\xf7\xee\xa2\x12\x06\xfd\x5e\x4f\xe6\x09\x56\x70\x04\x0d\xd0\x25\x49\x6d\x37\xd9\x3a\xe1\x1e\x41\x83\xdb\xd0\xcc\x98\xa5\x8c\x9c\x7b\xbd\x5f\xab\x02\xe3\x0e\x70\x96\xf4\x45\x81\xf1\x60\x58\xa7\x7e\x9a\x4c\xd1\x52\xcb\xf2\x28\xc1\xe4\xcb\x63\xa1\xd9\x5e\x2e\x21\x43\x09\x21\xac\x56\x57\xa4\xdd\x25\xc1\x30\x6e\x19\xc9\x29\xc2\x1e\x92\x88\x43\x83\x4c\x33\x75\x9a\xf4\x8d\xe1\x69\x34\xc5\xf2\xfb\x3c\x4a\xce\xc8\xb4\x48\xd0\xdf\x1d\x81\x14\xd9\xc8\xad\xe6\x98\xef\xad\x1a\xdb\x19\xee\x6a\x84\x3e\xd8\xd9\x47\x7f\x4f\x3d\x91\x92\x30\x0c\xc7\x62\xe4\x71\xbd\x5c\x82\x48\x61\xaa\x60\x4f\xc0\x5b\x62\xec\xeb\x57\x02\xd5\xc4\x9f\xb7\x19\x87\x06\x5a\x48\x9e\xe2\x54\xb9\x40\x1e\x73\x7c\xae\xf7\x2b\x52\xb0\x80\x1a\x8f\xd5\x17\x9e\xe7\x09\x86\x1f\xf2\x6c\x31\x97\xb4\x42\x54\x14\x28\x93\x41\x7b\x6e\xc4\x6a\xf6\x1c\x25\xf4\x04\x13\x86\xe1\xd0\xc8\xd4\x27\xaa\x57\xb9\x88\x23\xf9\x73\x94\x2d\x58\xd1\xe4\x0e\x83\xd8\x90\xbb\xbc\xaa\x54\x29\xe4\x94\x4d\x5c\x48\x85\x65\x1a\xc5\xb8\xac\x19\x38\x5b\x36\x49\xf1\x75\xcd\xae\xe3\x5c\xa6\x62\x3a\x6e\xd9\x9e\x1e\x5f\x79\x1e\x61\x76\xc4\x9f\x23\xa0\x7f\x88\xd5\x12\xd5\xa2\x94\xfc\x19\x56\x8e\x41\xcb\xd9\xb0\xdf\x73\xec\x1f\x57\x95\x98\xca\x4d\xac\x8f\xe0\x4e\x6f\xad\xb6\x81\xa1\xde\x00\xf3\x2f\x52\xb2\x6c\x4d\x7f\x08\x47\x47\xf0\x56\xcb\xdf\x70\x90\xce\x55\x78\x4a\xc0\xe9\x20\xb0\x81\x69\xb5\x1a\x83\x21\x1b\x47\x59\x86\x09\x6b\x2e\x5f\x28\xfe\x14\x72\x0a\x6b\x99\x06\xb4\x9b\x95\x27\x27\x26\x74\xb9\x26\xf9\xe6\xdd\xd5\x66\x2f\xe4\xfd\xf3\x40\x58\x77\x48\xef\xab\xe9\xf6\xbe\xe8\x22\xe6\xb2\x2e\x3c\x2b\x12\x2d\x44\x42\xa5\x84\x91\x65\xf9\x3d\xcc\x17\x2a\x52\xc4\x3f\x65\x8a\xea\x36\x9b\x96\x51\x71\x13\xfe\xc3\xc6\x0b\xb8\x7e\x04\x4a\x85\xf8\xa0\x50\x56\x22\x97\x15\xe4\x25\x2c\x2a\xca\x6b\x38\x2f\xb2\x48\x61\x15\x72\x5e\xf1\xf6\xa3\xe6\x45\x56\xd1\xc6\xe7\x91\x8a\x6f\xbe\x18\xb8\xae\x7c\x43\xea\x3c\xd8\xd7\xf9\xc6\x0f\x2d\xb4\x02\x27\x03\xbd\xd4\xda\xc9\x1f\x2c\x55\x03\xb3\xb7\x46\xb5\xd2\xf0\x7f\x8b\x94\xd4\x4e\x2b\xd5\xb7\x46\x6e\x54\x51\x68\x1f\xb5\xcc\x35\x29\xe9\xd7\x08\xd8\xd4\x86\x87\x8c\xaf\xbd\x9c\x8d\xc4\x8a\x5a\x64\xec\x12\x2c\xd0\x0d\xf6\xe4\x69\x85\xcc\x5c\x64\x4e\xfa\x7e\x00\xad\x69\xdf\xc9\x90\xe5\x9d\xc0\x1e\x04\x3f\x62\x1c\x78\x1c\x06\x04\x1d\x10\xae\x15\x8a\x53\x44\x97\x80\x91\x22\x16\x59\x8e\x90\xd3\xc0\x06\xeb\x4d\xd2\x6a\x33\xfc\xac\x6c\xf9\x21\x5f\x48\xb5\x21\x5f\x0a\xa9\xfc\x10\xa2\xf3\xd4\xf8\x89\x44\xf5\x17\x32\x54\x23\x7a\xb7\x2f\x96\xe5\xce\x56\xfa\x3c\x3d\x9d\x3e\x88\x6a\x93\x9e\x28\xd1\xfb\x8a\x92\x23\xeb\x40\x4d\x0e\x7c\x85\x0f\x9d\xa7\xb5\x3d\x25\x8d\xb2\x0a\x47\x1b\x83\x6a\x7c\x83\xf1\x0c\x90\x58\x42\x19\xe3\x18\x5e\xdd\x07\x4c\x53\xc7\x2a\x6b\x92\xf0\x37\x78\xfb\x5c\x93\xf4\x6c\x09\xf6\x3b\x2c\xc6\xb7\xc3\xd7\xed\x79\xda\x03\x69\x60\xec\x4d\xd2\xb7\x9d\xeb\x7d\x89\xae\x33\x1c\xb7\x92\x30\x0f\x73\x79\x63\xf2\x74\x1b\xc4\x26\x70\x02\x9a\x9c\xf8\x04\xb8\xb0\x70\x14\x7a\x14\xed\xc7\xba\x9e\xe7\x12\x24\x9c\x9c\x84\x34\x46\x1a\xab\x94\x2d\x3e\x19\x54\xaf\xd9\xa6\x65\xd1\x18\x23\x92\xca\x22\xf0\xdf\xfc\xd7\x59\x99\xcf\xdb\x69\xbb\xba\xe5\x1a\xed\x27\x29\x6e\x17\x38\xe6\x3a\x66\x64\xa3\xdd\x82\x07\xbb\xac\x42\xcf\x1c\x5a\x08\xcf\x1c\x74\xea\xd6\xcb\xc1\x11\xec\x6b\x08\xbb\xa2\x39\xb2\x74\xac\xa8\x67\x0e\x39\xc2\xea\xdf\x43\x32\x06\x6f\xcd\x46\xbd\x34\x8f\x66\x38\x58\x17\x03\x6f\x47\x3e\xea\x70\x13\xd6\x33\xaa\xac\x0e\x89\xd2\xb2\x74\xdc\x12\x7c\xea\xe0\xd0\x60\x76\xb4\x34\x85\x9e\xfe\xbc\x14\x57\x24\x93\x1d\x56\x7c\x61\x45\xe8\xc8\xd8\xe2\x8f\xfe\x68\x11\x17\x9d\xe2\x2d\x4a\x4c\x44\x4c\xd1\x51\x8b\xb8\x68\x89\xf7\xb3\x85\xb0\x05\x57\x85\x19\x9f\x49\xd9\xa9\xc2\x0b\xf3\xa5\x4b\xc3\xa6\x14\x0a\x5b\xda\x16\xc4\x94\x43\x6d\x73\x97\x89\xb9\x50\x5d\x0c\xf2\xc4\xa1\x99\x6f\xd9\xd3\xf7\x3c\x7c\x04\xfb\x3c\x6f\x17\xcb\xd3\xb4\xc2\xce\xd5\xf4\xcc\xa1\x85\x68\xad\xf7\x49\x8f\x1f\xc1\xbe\x86\xd8\x2e\xbc\xbc\x4c\xb0\xdc\x24\xb7\x4f\x34\xf9\xed\x64\x66\x22\x23\xd3\xea\xeb\x33\xf6\xb6\x42\xc0\xae\x12\xe8\x4c\xd4\xef\xd7\x92\x6b\x94\x50\x81\x00\x73\x54\x37\x79\x52\x81\xca\x39\xcb\x32\xe6\x1b\x1b\x52\x77\x4e\xb0\x2f\xc9\xaf\x91\xeb\x5e\xd8\x2c\xfb\x44\x92\xdd\x9c\x63\x37\xf7\x20\x76\x69\x47\x78\x82\xfa\x33\x3b\x0f\x9c\x85\xba\x73\x74\xcd\x82\xc8\x52\x3c\x06\x69\xea\x44\x6f\xa9\x49\xc6\x94\x0c\x6e\x9a\x42\xa1\x7a\x47\x48\xb6\x3f\xc4\x89\x6b\xd0\x99\xce\x86\xfd\x9e\xb3\x60\x0f\x43\x73\x31\x50\xef\x6c\x00\x6a\x61\x9b\x71\x3a\x8d\xf2\x1f\xca\x35\x03\xf5\x4e\x17\x0c\x1d\x29\xc7\xf7\x48\x47\xb1\xb3\xf8\xf0\x00\x2c\x1f\xee\x7b\x47\x6e\x9e\x3e\x0d\xad\xe5\xfc\x8d\xcb\x4b\x43\xe6\x5b\x96\x98\x14\x65\x7e\x1d\x41\xb1\x0e\x34\x9b\x73\x01\xcb\xbf\xf0\x43\xcf\x4e\x0b\x70\x3c\xec\xc4\x7d\x61\x50\x3e\x38\x30\x81\x5f\x54\x30\x8f\x64\x12\x71\x2b\x94\x18\x31\xb0\x71\x16\x2d\x2a\x0c\xe1\x17\x84\x4a\x45\xa5\xd2\x38\x2c\xfe\x04\xd3\x68\x91\x29\x7d\xf4\x1d\x41\x24\x13\xc8\xef\xb0\x2c\x45\x82\x20\x14\x5c\x23\x29\x5c\xa4\x20\x11\x13\x4c\x42\xdf\x9e\x74\x16\x18\x98\x1c\x30\xd4\x59\x66\x30\x8f\xd4\x4d\xf8\x43\xf4\x30\x91\xea\x7f\xde\x0f\x5f\x9c\xb8\x1c\x15\xbd\xaa\xce\x5c\xb5\x6a\xd7\x42\xf4\x57\xf5\xa8\x75\xb0\xaf\x43\xf2\x41\x11\xe9\xfd\x09\x89\xd5\x3a\x52\xc3\x14\x25\x96\x11\x85\x51\x16\x11\x43\xe5\x29\x44\x30\x15\x77\x28\x01\x93\x29\xee\xd2\x04\x26\xbc\x75\x0c\xde\x93\x6c\x7c\x5c\x92\x10\x07\x44\x8e\x1b\x1b\xf7\x46\xe4\x1e\x03\x69\x99\xcf\x0d\x05\x8d\x8b\x7e\x47\x97\x4e\xae\xb5\x65\x88\x21\x5a\x86\x34\x40\x39\x87\xf8\x9f\x96\x64\xd0\x34\xcb\xec\xab\xbc\xb6\x9e\x48\xc8\x09\xbd\x35\x27\x3c\xf0\xc6\x01\xf8\x31\xd8\xc2\xfc\xb8\x56\x4a\x3d\x8c\x75\x44\x15\x17\x77\x87\xb5\xe6\x81\x55\xd9\x69\x59\x0e\x76\xeb\x09\x54\x0a\x8b\x5a\xdf\xe1\x1c\xef\x2f\x14\x16\x03\xb2\x00\x57\xed\x53\x34\x24\x2e\x64\xfb\x00\x01\xad\x71\x3d\xd0\x28\xe5\x1d\x6f\xc3\x91\xbf\xf2\x97\x7c\xa0\x5b\x96\x7c\x5a\xe8\x5e\xbc\x3d\xe9\x8d\x36\xea\xdb\xda\xe2\xa4\xc8\x81\xfb\xd2\x48\x3f\x62\xc6\x88\xcc\x93\x1e\x9a\x54\x13\x79\x87\x65\xb5\x1e\x6b\x6d\x07\x35\x3f\xcd\xb3\x09\xa9\x52\xa4\x34\xfd\xc3\xfb\x1f\xb4\x76\x4d\x83\xb9\x63\x85\xcf\x1f\x3d\xf4\x30\x0c\x5d\x9f\x35\xab\xf0\x29\x5c\x9d\x10\x3c\x7c\xbf\x49\xab\x71\x69\xeb\xbb\xf5\x22\x48\xbd\xdf\x38\x4f\x90\x4f\x7c\xcb\x1c\x41\x52\xd1\x3e\xb6\x5a\x81\x67\xbc\x17\xa8\xce\x51\x4c\x6f\xae\xf3\xb2\x7a\xb2\xb0\x18\x01\x19\xff\x70\x43\xec\xa2\x18\xf1\x74\xec\x8a\x74\xb8\xf2\xe2\x8a\x0b\x63\xdc\xe5\xdc\xe5\x2e\xab\xcc\xe7\xff\x91\x61\x8c\xc1\x44\xd2\x15\xc1\x26\x27\x7f\x60\xe4\x11\xc9\x7f\x63\xce\x5f\x20\xe6\x90\x37\xfc\x19\x31\xe7\x77\x06\x9c\x2d\x91\xa1\xde\x14\xdf\xea\xe5\xdb\xfd\xd1\xde\x31\xe8\xb0\xd1\xe1\x8f\x1b\xae\x0d\x0f\x0d\x86\x57\x43\xd4\xcd\x4f\x0b\x33\x9d\x19\x0d\x71\x43\xcb\xec\xfa\x67\x5d\xce\x9a\xb6\x96\xbe\x73\xe0\xd6\x80\x48\xd6\xd0\xf3\xa8\xb8\xf4\xfb\x85\xb0\x5a\x35\x2f\xac\x1b\xd8\xa6\xb8\xb7\x57\x56\x5a\xc7\xfa\xc2\x4e\x37\x23\x44\x52\x5d\x72\xe8\x9d\x9c\x5c\x81\xbe\xd3\xe2\x71\x62\xd2\xb5\xa4\xd2\x99\xb9\xd0\x0b\x27\x27\xbc\xac\x7f\x9d\xa5\x85\x71\xa1\xca\x45\xac\xdc\x05\xaa\xbb\x1f\x37\x31\xc0\xbf\x25\xd7\x77\xa6\xf6\x3e\xbd\xd7\xa3\x40\x4b\xbb\xbc\xbc\xaa\x07\x0d\xb3\x43\x07\xe3\xba\x80\x56\x0c\x2d\xd0\xab\xc6\xa5\x3c\xf3\xca\x7f\x75\x74\x84\x89\xfb\x5a\x57\xb8\xd7\xa3\xa1\x71\x03\x64\x3d\xdb\x33\x31\x68\xdc\x15\x94\x34\xc4\x86\xde\xf1\x96\xf8\xb4\xa5\x9d\xdc\x11\x93\x34\x8a\xf9\xc7\x75\xf0\xc6\xa6\x19\xd5\xd9\x85\xea\xf5\xaa\xf0\x97\x1b\x2c\x39\xcc\x86\x13\x7b\x85\xb8\x03\xb1\x4b\x7d\x95\xde\xd8\xe9\x3b\x72\xc7\x8c\x7f\xbe\x75\x9e\x79\x35\x82\x74\xc6\x07\xf0\xa1\xcf\x21\x3b\x59\xbe\xe0\x9c\x18\x10\xf9\xf3\x45\x96\x4d\xa4\xfa\xbf\xff\x0d\xdc\x4d\x3d\x1b\xf3\x4f\x15\x96\x27\xec\xd8\xf6\x96\x9e\xb0\x8e\xf4\x24\x21\x19\x05\xaf\x43\x81\x5b\x5e\xc8\xad\xab\xaf\x6d\xa4\x4d\x43\x48\x22\xb1\x86\xd8\x48\x68\x7d\x09\x6c\x44\x3d\x84\xcb\xf7\xfe\x35\xb4\x91\xb4\x39\x31\x34\xe6\x5e\xdb\xfd\x90\xfd\x8f\xf4\xf5\xba\x90\xfc\xb5\xf2\xa5\xa5\x2f\xa2\x0d\x85\x7c\xa1\x46\x20\x24\x6c\xb8\xeb\x26\x97\x60\x90\x7c\x46\xdb\xcf\x17\x2a\x1c\xec\xaf\xe9\x68\x2d\x50\x0c\xfb\x2e\x9f\xc1\xd7\xaf\x80\x2c\xcf\x75\x5c\xea\x75\xdf\x8b\x2f\x24\x3e\x14\x18\x2b\x4c\x40\x24\xfa\x88\xcd\x75\x1b\xb9\xdf\x9b\x7c\xa1\x02\xb3\xb0\x79\x1b\x82\x42\x5a\x0e\x84\x34\x0c\xf0\xce\xda\xf4\x49\xd6\xbf\x8f\xbc\x90\x0d\xea\xf9\x42\xb1\x52\x4c\x84\x6e\x5c\xb0\x1e\x97\xd3\x00\x02\xda\x77\x00\x01\xc7\xa5\x80\xcd\x09\x02\xab\xe6\xc0\x69\x65\xf7\xcb\xd6\x83\xf9\xfb\xb9\xbe\x8a\x0f\xec\x6b\x13\xcf\x4e\x7a\x42\x3e\xcd\x91\x90\x1e\x43\xce\xf8\x6a\x6c\x69\xeb\xf8\x97\x71\x45\xc1\xda\xe9\x29\xa9\x2e\xad\xe0\xae\x6a\x5a\xda\x4d\x2f\x9c\x49\x44\x42\xa6\xc9\x31\x79\x0c\xaf\xee\x82\x11\xd8\x25\xeb\x1a\x12\x29\x25\x1f\x4d\x98\xa1\x2f\x8d\x80\xae\x0e\x6b\x24\x6d\x06\x70\x09\xc7\x0c\x90\x07\x74\x2c\x5b\x5f\xaa\x8e\xb5\x1e\x5f\x3f\x41\xe9\xf9\x67\x79\xcf\xe3\x5c\x69\xf3\x44\x35\x56\xcb\x14\xcf\x28\xcb\x7a\x2f\x2a\xcc\x6a\x8a\xad\x5d\x50\x3f\x5d\xa7\x6d\xad\xd4\x9a\xef\x94\x5a\xaf\x96\x3a\xdf\x54\x70\x92\x7f\xd1\x9b\x8a\x7a\x07\xc5\xb3\xa7\xdf\x74\x99\xa4\x73\x7a\xa0\x33\x8f\xc9\xd8\xc1\x18\x5e\xdd\xff\x66\x6f\x8a\x8d\x8a\x18\xdc\x24\xb1\xee\xd3\xc6\xe4\x64\x22\xad\xd1\xb8\x2c\x24\x6d\xa5\xe9\x2e\xbb\xf5\x42\xe6\xf9\xdf\xd0\xdb\xf5\x46\xae\xb9\xef\x63\xd8\xb0\xb5\x94\x57\x48\x59\x0a\x06\xd3\x3c\xb1\xf0\x0d\x9e\xce\x57\x57\xfd\xb6\x9b\x6d\x12\x8d\xe7\x6a\x0d\xc9\x68\xd7\xd3\x78\x98\xc0\xab\xbb\xdf\x46\x20\x5d\x41\xa6\x15\xdc\xb8\x72\xf2\x0b\x3d\xcd\xdc\xa5\xb8\xda\x5e\xb6\xf9\x8f\xb8\xb6\x03\x8f\x40\x7a\xa4\xdd\x51\x80\x4a\x03\x9d\x78\x3f\xdd\xcb\xb3\x8f\xd6\xba\x12\xbf\xe6\xed\x2c\xde\xba\x8a\x5f\xfa\xd9\x55\x00\xef\x56\xf9\x6d\x91\x06\x97\x0e\x29\xeb\x69\x0f\xbd\x07\x76\xbe\xa7\xd0\x2e\xd2\x99\xbe\xbb\x0f\xcf\x45\x96\x51\xfd\xb6\x0e\xf5\x29\x38\x21\x11\xe5\x74\xd6\x12\x66\x3d\xdf\xc5\xb9\x54\x42\x9a\x67\x83\xbd\x2e\x87\xec\x69\x86\x4c\xe5\xd2\x45\x7a\xdf\x45\xf6\xed\xb4\xfb\xcd\xe8\x6b\x04\x79\x99\xce\xea\xa1\xb7\x16\x75\x39\xe2\xa6\x33\xdf\xa0\x3c\xbc\xba\x71\x98\xc1\x91\x13\xc2\xb3\xdd\xf5\xdf\xc8\x55\xed\x86\x7e\x87\xb3\xa6\xda\x84\xde\xcc\xf0\x51\x3b\xee\x5a\x7d\xd6\x7f\xbf\xb5\xeb\xca\x0d\xde\xf8\x92\x53\xe7\x26\xc7\xdb\x78\xf2\x7c\xca\xe1\xba\xcf\x93\xbc\x29\xff\x4c\x69\x8f\xa3\x66\xc2\x1e\x49\xe9\xb3\xe6\x97\x18\x7e\x72\xe7\xe9\x5d\xe5\xb3\xe5\x44\xda\xf1\x62\xd9\xcb\xdd\x06\x63\xdb\x4b\x65\x63\xf4\xed\x07\xc2\xf5\x85\x7d\xd7\x70\xd7\x72\x61\x9b\xab\xc1\xb6\x03\xdd\x33\xce\x73\xad\xa6\x54\xfd\x9c\xb6\xfa\xb3\xbc\xef\xa9\xf0\xab\xa7\xe4\x96\xe0\xb6\x53\x80\x4e\x67\x1b\x0e\x1d\x2d\x4f\xde\xc1\x7d\x45\xc5\x2b\xd1\xe6\x38\x13\x37\xbc\xb8\x3b\xa2\xfb\xa5\xb7\xb5\xef\x9d\xe2\x7b\x3a\xfb\x43\x02\x51\x63\x3b\x3b\xb2\xd6\x8e\x5e\x4f\xf7\x83\xdc\x79\xdc\xbc\x15\x5b\xad\xe4\xba\x93\xe0\x25\x96\x27\x2a\x8e\xda\xd1\xa6\xd9\x7f\x5c\xbd\xa8\x57\xe8\x9f\x9e\x5c\x6b\x30\x2a\x6b\xff\x6d\xe5\xb8\x9c\xae\xe7\xf8\x35\x96\x3f\xbb\xb6\x4c\x7d\x27\xb1\xc8\x32\x45\x41\xd2\x03\xf1\x7a\x0b\x7d\x6b\xbc\x37\x51\xf5\xb9\xc4\x54\x3c\x78\x28\x41\x75\x9b\x05\xa6\x5f\xcc\xda\xd3\x7a\x30\xd8\x9a\x10\x33\xe7\x6e\x15\xbc\xe6\xb4\x96\xb1\xcc\x95\xc3\xeb\xd2\x1f\x2d\x1b\x79\xfb\x79\xe6\xff\xec\x39\x28\x4a\x2c\xa2\x12\xf9\x89\xa7\x27\xb1\xcd\xff\xd1\x67\xc7\x3b\x5c\x13\x36\xd2\xcd\x0f\x11\xbc\xc7\x77\xe4\x17\xcd\x98\xf7\x73\x94\x89\x44\x87\xbc\x41\x3a\xac\x3d\xfe\x7f\xcd\x73\x7c\x11\xc5\x3e\xb3\x24\x67\x18\x43\xca\x81\x6b\xbc\xf1\x19\xab\x90\x77\x84\xa7\xdf\xdf\xc1\xab\x5b\x76\x17\xfd\xff\x9b\x46\x90\x0e\xdd\xb3\x2d\x4f\x76\xff\x0c\x00\x00\xff\xff\x34\x7a\x8a\x3d\x7a\x36\x00\x00") func templateDialectSqlQueryTmplBytes() ([]byte, error) { return bindataRead( @@ -887,7 +887,7 @@ func templateDialectSqlQueryTmpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "template/dialect/sql/query.tmpl", size: 14004, mode: os.FileMode(420), modTime: time.Unix(1, 0)} + info := bindataFileInfo{name: "template/dialect/sql/query.tmpl", size: 13946, mode: os.FileMode(420), modTime: time.Unix(1, 0)} a := &asset{bytes: bytes, info: info} return a, nil } diff --git a/entc/gen/template/dialect/sql/by.tmpl b/entc/gen/template/dialect/sql/by.tmpl index e3f70fdd7..9c440d8fe 100644 --- a/entc/gen/template/dialect/sql/by.tmpl +++ b/entc/gen/template/dialect/sql/by.tmpl @@ -5,41 +5,63 @@ in the LICENSE file in the root directory of this source tree. */}} {{ define "dialect/sql/order/signature" -}} - // OrderFunc applies an ordering on the sql selector. - type OrderFunc func(*sql.Selector, func(string) bool) +// OrderFunc applies an ordering on the sql selector. +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + {{- range $n := $.Nodes }} + {{ $n.Package }}.Table: {{ $n.Package }}.ValidColumn, + {{- end }} + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("{{ base $.Config.Package }}: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("{{ base $.Config.Package }}: unknown column %q for table %q", column, table) + } + return nil + } +} {{- end }} {{ define "dialect/sql/order/func" -}} {{- $f := $.Scope.Func -}} - func(s *sql.Selector, check func(string) bool) { + func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.{{ $f }}(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.{{ $f }}(s.C(f))) } } {{- end }} {{/* custom signature for group-by function */}} {{ define "dialect/sql/group/signature" -}} - type AggregateFunc func(*sql.Selector, func(string) bool) string + type AggregateFunc func(*sql.Selector) string {{- end }} {{ define "dialect/sql/group/as" -}} - func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + func(s *sql.Selector) string { + return sql.As(fn(s), end) } {{- end }} {{ define "dialect/sql/group/func" -}} {{- $fn := $.Scope.Func -}} {{- $withField := $.Scope.WithField -}} - func(s *sql.Selector, {{ if $withField }}check{{ else }}_{{ end }} func(string) bool) string { + func(s *sql.Selector) string { {{- if $withField }} - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } {{- end }} diff --git a/entc/gen/template/dialect/sql/group.tmpl b/entc/gen/template/dialect/sql/group.tmpl index 2b6bfcccd..06889d6bd 100644 --- a/entc/gen/template/dialect/sql/group.tmpl +++ b/entc/gen/template/dialect/sql/group.tmpl @@ -33,7 +33,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery() *sql.Selector { columns := make([]string, 0, len({{ $receiver }}.fields) + len({{ $receiver}}.fns)) columns = append(columns, {{ $receiver }}.fields...) for _, fn := range {{ $receiver }}.fns { - columns = append(columns, fn(selector, {{ $.Package }}.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy({{ $receiver }}.fields...) } diff --git a/entc/gen/template/dialect/sql/query.tmpl b/entc/gen/template/dialect/sql/query.tmpl index 8b4f8e8f9..bf1721a0d 100644 --- a/entc/gen/template/dialect/sql/query.tmpl +++ b/entc/gen/template/dialect/sql/query.tmpl @@ -136,7 +136,7 @@ func ({{ $receiver }} *{{ $builder }}) querySpec() *sqlgraph.QuerySpec { if ps := {{ $receiver }}.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, {{ $.Package }}.ValidColumn) + ps[i](selector) } } } @@ -177,7 +177,7 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select p(selector) } for _, p := range {{ $receiver }}.order { - p(selector, {{ $.Package }}.ValidColumn) + p(selector) } if offset := {{ $receiver }}.offset; offset != nil { // limit is mandatory for offset clause. We start diff --git a/entc/integration/cascadelete/ent/comment_query.go b/entc/integration/cascadelete/ent/comment_query.go index d6fb4d937..f838b2821 100644 --- a/entc/integration/cascadelete/ent/comment_query.go +++ b/entc/integration/cascadelete/ent/comment_query.go @@ -460,7 +460,7 @@ func (cq *CommentQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, comment.ValidColumn) + ps[i](selector) } } } @@ -479,7 +479,7 @@ func (cq *CommentQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, comment.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -745,7 +745,7 @@ func (cgb *CommentGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, comment.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/entc/integration/cascadelete/ent/ent.go b/entc/integration/cascadelete/ent/ent.go index 28dee448a..35c3f55a4 100644 --- a/entc/integration/cascadelete/ent/ent.go +++ b/entc/integration/cascadelete/ent/ent.go @@ -14,6 +14,9 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/cascadelete/ent/comment" + "entgo.io/ent/entc/integration/cascadelete/ent/post" + "entgo.io/ent/entc/integration/cascadelete/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +32,57 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + comment.Table: comment.ValidColumn, + post.Table: post.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +91,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +117,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +129,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +141,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/cascadelete/ent/post_query.go b/entc/integration/cascadelete/ent/post_query.go index 742c5a952..dabc62434 100644 --- a/entc/integration/cascadelete/ent/post_query.go +++ b/entc/integration/cascadelete/ent/post_query.go @@ -523,7 +523,7 @@ func (pq *PostQuery) querySpec() *sqlgraph.QuerySpec { if ps := pq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, post.ValidColumn) + ps[i](selector) } } } @@ -542,7 +542,7 @@ func (pq *PostQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range pq.order { - p(selector, post.ValidColumn) + p(selector) } if offset := pq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -808,7 +808,7 @@ func (pgb *PostGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) columns = append(columns, pgb.fields...) for _, fn := range pgb.fns { - columns = append(columns, fn(selector, post.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(pgb.fields...) } diff --git a/entc/integration/cascadelete/ent/user_query.go b/entc/integration/cascadelete/ent/user_query.go index 966dfffb5..631539c2b 100644 --- a/entc/integration/cascadelete/ent/user_query.go +++ b/entc/integration/cascadelete/ent/user_query.go @@ -460,7 +460,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -479,7 +479,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -745,7 +745,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/entc/integration/config/ent/ent.go b/entc/integration/config/ent/ent.go index 28dee448a..740254841 100644 --- a/entc/integration/config/ent/ent.go +++ b/entc/integration/config/ent/ent.go @@ -14,6 +14,7 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/config/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +30,55 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +87,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +113,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +125,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +137,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/config/ent/user_query.go b/entc/integration/config/ent/user_query.go index 339c5088e..84f79a6e3 100644 --- a/entc/integration/config/ent/user_query.go +++ b/entc/integration/config/ent/user_query.go @@ -392,7 +392,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -411,7 +411,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -677,7 +677,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/entc/integration/customid/ent/blob_query.go b/entc/integration/customid/ent/blob_query.go index 19315fc3a..d30818c9e 100644 --- a/entc/integration/customid/ent/blob_query.go +++ b/entc/integration/customid/ent/blob_query.go @@ -573,7 +573,7 @@ func (bq *BlobQuery) querySpec() *sqlgraph.QuerySpec { if ps := bq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, blob.ValidColumn) + ps[i](selector) } } } @@ -592,7 +592,7 @@ func (bq *BlobQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range bq.order { - p(selector, blob.ValidColumn) + p(selector) } if offset := bq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -858,7 +858,7 @@ func (bgb *BlobGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(bgb.fields)+len(bgb.fns)) columns = append(columns, bgb.fields...) for _, fn := range bgb.fns { - columns = append(columns, fn(selector, blob.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(bgb.fields...) } diff --git a/entc/integration/customid/ent/car_query.go b/entc/integration/customid/ent/car_query.go index b9903804d..b26ae3125 100644 --- a/entc/integration/customid/ent/car_query.go +++ b/entc/integration/customid/ent/car_query.go @@ -471,7 +471,7 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, car.ValidColumn) + ps[i](selector) } } } @@ -490,7 +490,7 @@ func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, car.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -756,7 +756,7 @@ func (cgb *CarGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, car.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/entc/integration/customid/ent/ent.go b/entc/integration/customid/ent/ent.go index 28dee448a..b880028fa 100644 --- a/entc/integration/customid/ent/ent.go +++ b/entc/integration/customid/ent/ent.go @@ -14,6 +14,12 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/customid/ent/blob" + "entgo.io/ent/entc/integration/customid/ent/car" + "entgo.io/ent/entc/integration/customid/ent/group" + "entgo.io/ent/entc/integration/customid/ent/mixinid" + "entgo.io/ent/entc/integration/customid/ent/pet" + "entgo.io/ent/entc/integration/customid/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +35,60 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + blob.Table: blob.ValidColumn, + car.Table: car.ValidColumn, + group.Table: group.ValidColumn, + mixinid.Table: mixinid.ValidColumn, + pet.Table: pet.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +97,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +123,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +135,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +147,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/customid/ent/group_query.go b/entc/integration/customid/ent/group_query.go index d228b5098..e1ea8a285 100644 --- a/entc/integration/customid/ent/group_query.go +++ b/entc/integration/customid/ent/group_query.go @@ -476,7 +476,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { if ps := gq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, group.ValidColumn) + ps[i](selector) } } } @@ -495,7 +495,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range gq.order { - p(selector, group.ValidColumn) + p(selector) } if offset := gq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -761,7 +761,7 @@ func (ggb *GroupGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ggb.fields)+len(ggb.fns)) columns = append(columns, ggb.fields...) for _, fn := range ggb.fns { - columns = append(columns, fn(selector, group.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ggb.fields...) } diff --git a/entc/integration/customid/ent/mixinid_query.go b/entc/integration/customid/ent/mixinid_query.go index ae105c056..254ef72fc 100644 --- a/entc/integration/customid/ent/mixinid_query.go +++ b/entc/integration/customid/ent/mixinid_query.go @@ -393,7 +393,7 @@ func (miq *MixinIDQuery) querySpec() *sqlgraph.QuerySpec { if ps := miq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, mixinid.ValidColumn) + ps[i](selector) } } } @@ -412,7 +412,7 @@ func (miq *MixinIDQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range miq.order { - p(selector, mixinid.ValidColumn) + p(selector) } if offset := miq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -678,7 +678,7 @@ func (migb *MixinIDGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(migb.fields)+len(migb.fns)) columns = append(columns, migb.fields...) for _, fn := range migb.fns { - columns = append(columns, fn(selector, mixinid.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(migb.fields...) } diff --git a/entc/integration/customid/ent/pet_query.go b/entc/integration/customid/ent/pet_query.go index 547a03947..03a5fea18 100644 --- a/entc/integration/customid/ent/pet_query.go +++ b/entc/integration/customid/ent/pet_query.go @@ -680,7 +680,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { if ps := pq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, pet.ValidColumn) + ps[i](selector) } } } @@ -699,7 +699,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range pq.order { - p(selector, pet.ValidColumn) + p(selector) } if offset := pq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -965,7 +965,7 @@ func (pgb *PetGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) columns = append(columns, pgb.fields...) for _, fn := range pgb.fns { - columns = append(columns, fn(selector, pet.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(pgb.fields...) } diff --git a/entc/integration/customid/ent/user_query.go b/entc/integration/customid/ent/user_query.go index 0e3cd34d7..d7ac9398d 100644 --- a/entc/integration/customid/ent/user_query.go +++ b/entc/integration/customid/ent/user_query.go @@ -680,7 +680,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -699,7 +699,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -965,7 +965,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/entc/integration/edgefield/ent/card_query.go b/entc/integration/edgefield/ent/card_query.go index 8dbae4fcf..a61fddbf9 100644 --- a/entc/integration/edgefield/ent/card_query.go +++ b/entc/integration/edgefield/ent/card_query.go @@ -460,7 +460,7 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, card.ValidColumn) + ps[i](selector) } } } @@ -479,7 +479,7 @@ func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, card.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -745,7 +745,7 @@ func (cgb *CardGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, card.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/entc/integration/edgefield/ent/ent.go b/entc/integration/edgefield/ent/ent.go index 28dee448a..495778ebd 100644 --- a/entc/integration/edgefield/ent/ent.go +++ b/entc/integration/edgefield/ent/ent.go @@ -14,6 +14,12 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/edgefield/ent/card" + "entgo.io/ent/entc/integration/edgefield/ent/info" + "entgo.io/ent/entc/integration/edgefield/ent/metadata" + "entgo.io/ent/entc/integration/edgefield/ent/pet" + "entgo.io/ent/entc/integration/edgefield/ent/post" + "entgo.io/ent/entc/integration/edgefield/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +35,60 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + card.Table: card.ValidColumn, + info.Table: info.ValidColumn, + metadata.Table: metadata.ValidColumn, + pet.Table: pet.ValidColumn, + post.Table: post.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +97,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +123,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +135,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +147,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/edgefield/ent/info_query.go b/entc/integration/edgefield/ent/info_query.go index 28c28fb19..bdc9ebc21 100644 --- a/entc/integration/edgefield/ent/info_query.go +++ b/entc/integration/edgefield/ent/info_query.go @@ -460,7 +460,7 @@ func (iq *InfoQuery) querySpec() *sqlgraph.QuerySpec { if ps := iq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, info.ValidColumn) + ps[i](selector) } } } @@ -479,7 +479,7 @@ func (iq *InfoQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range iq.order { - p(selector, info.ValidColumn) + p(selector) } if offset := iq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -745,7 +745,7 @@ func (igb *InfoGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(igb.fields)+len(igb.fns)) columns = append(columns, igb.fields...) for _, fn := range igb.fns { - columns = append(columns, fn(selector, info.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(igb.fields...) } diff --git a/entc/integration/edgefield/ent/metadata_query.go b/entc/integration/edgefield/ent/metadata_query.go index 0b35f7472..ea20cdbbf 100644 --- a/entc/integration/edgefield/ent/metadata_query.go +++ b/entc/integration/edgefield/ent/metadata_query.go @@ -460,7 +460,7 @@ func (mq *MetadataQuery) querySpec() *sqlgraph.QuerySpec { if ps := mq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, metadata.ValidColumn) + ps[i](selector) } } } @@ -479,7 +479,7 @@ func (mq *MetadataQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range mq.order { - p(selector, metadata.ValidColumn) + p(selector) } if offset := mq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -745,7 +745,7 @@ func (mgb *MetadataGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(mgb.fields)+len(mgb.fns)) columns = append(columns, mgb.fields...) for _, fn := range mgb.fns { - columns = append(columns, fn(selector, metadata.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(mgb.fields...) } diff --git a/entc/integration/edgefield/ent/pet_query.go b/entc/integration/edgefield/ent/pet_query.go index 3c77b8479..6d5e7f07c 100644 --- a/entc/integration/edgefield/ent/pet_query.go +++ b/entc/integration/edgefield/ent/pet_query.go @@ -460,7 +460,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { if ps := pq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, pet.ValidColumn) + ps[i](selector) } } } @@ -479,7 +479,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range pq.order { - p(selector, pet.ValidColumn) + p(selector) } if offset := pq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -745,7 +745,7 @@ func (pgb *PetGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) columns = append(columns, pgb.fields...) for _, fn := range pgb.fns { - columns = append(columns, fn(selector, pet.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(pgb.fields...) } diff --git a/entc/integration/edgefield/ent/post_query.go b/entc/integration/edgefield/ent/post_query.go index 4b5c2932c..d91e163b2 100644 --- a/entc/integration/edgefield/ent/post_query.go +++ b/entc/integration/edgefield/ent/post_query.go @@ -463,7 +463,7 @@ func (pq *PostQuery) querySpec() *sqlgraph.QuerySpec { if ps := pq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, post.ValidColumn) + ps[i](selector) } } } @@ -482,7 +482,7 @@ func (pq *PostQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range pq.order { - p(selector, post.ValidColumn) + p(selector) } if offset := pq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -748,7 +748,7 @@ func (pgb *PostGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) columns = append(columns, pgb.fields...) for _, fn := range pgb.fns { - columns = append(columns, fn(selector, post.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(pgb.fields...) } diff --git a/entc/integration/edgefield/ent/user_query.go b/entc/integration/edgefield/ent/user_query.go index b3d554733..479fb4062 100644 --- a/entc/integration/edgefield/ent/user_query.go +++ b/entc/integration/edgefield/ent/user_query.go @@ -829,7 +829,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -848,7 +848,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -1114,7 +1114,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/entc/integration/ent/card_query.go b/entc/integration/ent/card_query.go index 4417d0dde..6e1f39a35 100644 --- a/entc/integration/ent/card_query.go +++ b/entc/integration/ent/card_query.go @@ -574,7 +574,7 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, card.ValidColumn) + ps[i](selector) } } } @@ -593,7 +593,7 @@ func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, card.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -859,7 +859,7 @@ func (cgb *CardGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, card.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/entc/integration/ent/comment_query.go b/entc/integration/ent/comment_query.go index 96ac5b58d..7713ef878 100644 --- a/entc/integration/ent/comment_query.go +++ b/entc/integration/ent/comment_query.go @@ -392,7 +392,7 @@ func (cq *CommentQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, comment.ValidColumn) + ps[i](selector) } } } @@ -411,7 +411,7 @@ func (cq *CommentQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, comment.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -677,7 +677,7 @@ func (cgb *CommentGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, comment.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/entc/integration/ent/ent.go b/entc/integration/ent/ent.go index 28dee448a..3c2f8d86c 100644 --- a/entc/integration/ent/ent.go +++ b/entc/integration/ent/ent.go @@ -14,6 +14,20 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/ent/card" + "entgo.io/ent/entc/integration/ent/comment" + "entgo.io/ent/entc/integration/ent/fieldtype" + "entgo.io/ent/entc/integration/ent/file" + "entgo.io/ent/entc/integration/ent/filetype" + "entgo.io/ent/entc/integration/ent/goods" + "entgo.io/ent/entc/integration/ent/group" + "entgo.io/ent/entc/integration/ent/groupinfo" + "entgo.io/ent/entc/integration/ent/item" + "entgo.io/ent/entc/integration/ent/node" + "entgo.io/ent/entc/integration/ent/pet" + "entgo.io/ent/entc/integration/ent/spec" + "entgo.io/ent/entc/integration/ent/task" + "entgo.io/ent/entc/integration/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +43,68 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + card.Table: card.ValidColumn, + comment.Table: comment.ValidColumn, + fieldtype.Table: fieldtype.ValidColumn, + file.Table: file.ValidColumn, + filetype.Table: filetype.ValidColumn, + goods.Table: goods.ValidColumn, + group.Table: group.ValidColumn, + groupinfo.Table: groupinfo.ValidColumn, + item.Table: item.ValidColumn, + node.Table: node.ValidColumn, + pet.Table: pet.ValidColumn, + spec.Table: spec.ValidColumn, + task.Table: task.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +113,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +139,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +151,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +163,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/ent/fieldtype_query.go b/entc/integration/ent/fieldtype_query.go index 225d3af1c..281dc58a1 100644 --- a/entc/integration/ent/fieldtype_query.go +++ b/entc/integration/ent/fieldtype_query.go @@ -397,7 +397,7 @@ func (ftq *FieldTypeQuery) querySpec() *sqlgraph.QuerySpec { if ps := ftq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, fieldtype.ValidColumn) + ps[i](selector) } } } @@ -416,7 +416,7 @@ func (ftq *FieldTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range ftq.order { - p(selector, fieldtype.ValidColumn) + p(selector) } if offset := ftq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -682,7 +682,7 @@ func (ftgb *FieldTypeGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ftgb.fields)+len(ftgb.fns)) columns = append(columns, ftgb.fields...) for _, fn := range ftgb.fns { - columns = append(columns, fn(selector, fieldtype.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ftgb.fields...) } diff --git a/entc/integration/ent/file_query.go b/entc/integration/ent/file_query.go index 3ed6c3073..da112facf 100644 --- a/entc/integration/ent/file_query.go +++ b/entc/integration/ent/file_query.go @@ -604,7 +604,7 @@ func (fq *FileQuery) querySpec() *sqlgraph.QuerySpec { if ps := fq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, file.ValidColumn) + ps[i](selector) } } } @@ -623,7 +623,7 @@ func (fq *FileQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range fq.order { - p(selector, file.ValidColumn) + p(selector) } if offset := fq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -889,7 +889,7 @@ func (fgb *FileGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(fgb.fields)+len(fgb.fns)) columns = append(columns, fgb.fields...) for _, fn := range fgb.fns { - columns = append(columns, fn(selector, file.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(fgb.fields...) } diff --git a/entc/integration/ent/filetype_query.go b/entc/integration/ent/filetype_query.go index d944f82b0..cde069e55 100644 --- a/entc/integration/ent/filetype_query.go +++ b/entc/integration/ent/filetype_query.go @@ -464,7 +464,7 @@ func (ftq *FileTypeQuery) querySpec() *sqlgraph.QuerySpec { if ps := ftq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, filetype.ValidColumn) + ps[i](selector) } } } @@ -483,7 +483,7 @@ func (ftq *FileTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range ftq.order { - p(selector, filetype.ValidColumn) + p(selector) } if offset := ftq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -749,7 +749,7 @@ func (ftgb *FileTypeGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ftgb.fields)+len(ftgb.fns)) columns = append(columns, ftgb.fields...) for _, fn := range ftgb.fns { - columns = append(columns, fn(selector, filetype.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ftgb.fields...) } diff --git a/entc/integration/ent/goods_query.go b/entc/integration/ent/goods_query.go index d2ae98f71..8319c7736 100644 --- a/entc/integration/ent/goods_query.go +++ b/entc/integration/ent/goods_query.go @@ -368,7 +368,7 @@ func (gq *GoodsQuery) querySpec() *sqlgraph.QuerySpec { if ps := gq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, goods.ValidColumn) + ps[i](selector) } } } @@ -387,7 +387,7 @@ func (gq *GoodsQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range gq.order { - p(selector, goods.ValidColumn) + p(selector) } if offset := gq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -653,7 +653,7 @@ func (ggb *GoodsGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ggb.fields)+len(ggb.fns)) columns = append(columns, ggb.fields...) for _, fn := range ggb.fns { - columns = append(columns, fn(selector, goods.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ggb.fields...) } diff --git a/entc/integration/ent/group_query.go b/entc/integration/ent/group_query.go index 3c1be3ef3..8e9c0e9d6 100644 --- a/entc/integration/ent/group_query.go +++ b/entc/integration/ent/group_query.go @@ -705,7 +705,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { if ps := gq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, group.ValidColumn) + ps[i](selector) } } } @@ -724,7 +724,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range gq.order { - p(selector, group.ValidColumn) + p(selector) } if offset := gq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -990,7 +990,7 @@ func (ggb *GroupGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ggb.fields)+len(ggb.fns)) columns = append(columns, ggb.fields...) for _, fn := range ggb.fns { - columns = append(columns, fn(selector, group.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ggb.fields...) } diff --git a/entc/integration/ent/groupinfo_query.go b/entc/integration/ent/groupinfo_query.go index 56eef9159..555c4f731 100644 --- a/entc/integration/ent/groupinfo_query.go +++ b/entc/integration/ent/groupinfo_query.go @@ -464,7 +464,7 @@ func (giq *GroupInfoQuery) querySpec() *sqlgraph.QuerySpec { if ps := giq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, groupinfo.ValidColumn) + ps[i](selector) } } } @@ -483,7 +483,7 @@ func (giq *GroupInfoQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range giq.order { - p(selector, groupinfo.ValidColumn) + p(selector) } if offset := giq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -749,7 +749,7 @@ func (gigb *GroupInfoGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(gigb.fields)+len(gigb.fns)) columns = append(columns, gigb.fields...) for _, fn := range gigb.fns { - columns = append(columns, fn(selector, groupinfo.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(gigb.fields...) } diff --git a/entc/integration/ent/item_query.go b/entc/integration/ent/item_query.go index 9dc273ee1..d6008f322 100644 --- a/entc/integration/ent/item_query.go +++ b/entc/integration/ent/item_query.go @@ -368,7 +368,7 @@ func (iq *ItemQuery) querySpec() *sqlgraph.QuerySpec { if ps := iq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, item.ValidColumn) + ps[i](selector) } } } @@ -387,7 +387,7 @@ func (iq *ItemQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range iq.order { - p(selector, item.ValidColumn) + p(selector) } if offset := iq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -653,7 +653,7 @@ func (igb *ItemGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(igb.fields)+len(igb.fns)) columns = append(columns, igb.fields...) for _, fn := range igb.fns { - columns = append(columns, fn(selector, item.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(igb.fields...) } diff --git a/entc/integration/ent/node_query.go b/entc/integration/ent/node_query.go index d759a8819..d9c895704 100644 --- a/entc/integration/ent/node_query.go +++ b/entc/integration/ent/node_query.go @@ -535,7 +535,7 @@ func (nq *NodeQuery) querySpec() *sqlgraph.QuerySpec { if ps := nq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, node.ValidColumn) + ps[i](selector) } } } @@ -554,7 +554,7 @@ func (nq *NodeQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range nq.order { - p(selector, node.ValidColumn) + p(selector) } if offset := nq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -820,7 +820,7 @@ func (ngb *NodeGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ngb.fields)+len(ngb.fns)) columns = append(columns, ngb.fields...) for _, fn := range ngb.fns { - columns = append(columns, fn(selector, node.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ngb.fields...) } diff --git a/entc/integration/ent/pet_query.go b/entc/integration/ent/pet_query.go index 0896ff895..0b0bd5d14 100644 --- a/entc/integration/ent/pet_query.go +++ b/entc/integration/ent/pet_query.go @@ -536,7 +536,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { if ps := pq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, pet.ValidColumn) + ps[i](selector) } } } @@ -555,7 +555,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range pq.order { - p(selector, pet.ValidColumn) + p(selector) } if offset := pq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -821,7 +821,7 @@ func (pgb *PetGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) columns = append(columns, pgb.fields...) for _, fn := range pgb.fns { - columns = append(columns, fn(selector, pet.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(pgb.fields...) } diff --git a/entc/integration/ent/spec_query.go b/entc/integration/ent/spec_query.go index 5498f8ac3..e74d7dbd5 100644 --- a/entc/integration/ent/spec_query.go +++ b/entc/integration/ent/spec_query.go @@ -476,7 +476,7 @@ func (sq *SpecQuery) querySpec() *sqlgraph.QuerySpec { if ps := sq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, spec.ValidColumn) + ps[i](selector) } } } @@ -495,7 +495,7 @@ func (sq *SpecQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range sq.order { - p(selector, spec.ValidColumn) + p(selector) } if offset := sq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -761,7 +761,7 @@ func (sgb *SpecGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(sgb.fields)+len(sgb.fns)) columns = append(columns, sgb.fields...) for _, fn := range sgb.fns { - columns = append(columns, fn(selector, spec.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(sgb.fields...) } diff --git a/entc/integration/ent/task_query.go b/entc/integration/ent/task_query.go index 183fa7bda..cef476d3b 100644 --- a/entc/integration/ent/task_query.go +++ b/entc/integration/ent/task_query.go @@ -392,7 +392,7 @@ func (tq *TaskQuery) querySpec() *sqlgraph.QuerySpec { if ps := tq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, task.ValidColumn) + ps[i](selector) } } } @@ -411,7 +411,7 @@ func (tq *TaskQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range tq.order { - p(selector, task.ValidColumn) + p(selector) } if offset := tq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -677,7 +677,7 @@ func (tgb *TaskGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(tgb.fields)+len(tgb.fns)) columns = append(columns, tgb.fields...) for _, fn := range tgb.fns { - columns = append(columns, fn(selector, task.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(tgb.fields...) } diff --git a/entc/integration/ent/user_query.go b/entc/integration/ent/user_query.go index bd453ff0f..66a94767f 100644 --- a/entc/integration/ent/user_query.go +++ b/entc/integration/ent/user_query.go @@ -1267,7 +1267,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -1286,7 +1286,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -1552,7 +1552,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/entc/integration/hooks/ent/card_query.go b/entc/integration/hooks/ent/card_query.go index 85186e2a7..5fc73647e 100644 --- a/entc/integration/hooks/ent/card_query.go +++ b/entc/integration/hooks/ent/card_query.go @@ -471,7 +471,7 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, card.ValidColumn) + ps[i](selector) } } } @@ -490,7 +490,7 @@ func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, card.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -756,7 +756,7 @@ func (cgb *CardGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, card.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/entc/integration/hooks/ent/ent.go b/entc/integration/hooks/ent/ent.go index 28dee448a..5797f17c9 100644 --- a/entc/integration/hooks/ent/ent.go +++ b/entc/integration/hooks/ent/ent.go @@ -14,6 +14,8 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/hooks/ent/card" + "entgo.io/ent/entc/integration/hooks/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +31,56 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + card.Table: card.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +89,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +115,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +127,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +139,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/hooks/ent/user_query.go b/entc/integration/hooks/ent/user_query.go index 6553bbb4a..8338bb770 100644 --- a/entc/integration/hooks/ent/user_query.go +++ b/entc/integration/hooks/ent/user_query.go @@ -638,7 +638,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -657,7 +657,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -923,7 +923,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/entc/integration/idtype/ent/ent.go b/entc/integration/idtype/ent/ent.go index 28dee448a..8db527434 100644 --- a/entc/integration/idtype/ent/ent.go +++ b/entc/integration/idtype/ent/ent.go @@ -14,6 +14,7 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/idtype/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +30,55 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +87,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +113,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +125,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +137,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/idtype/ent/user_query.go b/entc/integration/idtype/ent/user_query.go index 588bb7cf7..1b24cf743 100644 --- a/entc/integration/idtype/ent/user_query.go +++ b/entc/integration/idtype/ent/user_query.go @@ -673,7 +673,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -692,7 +692,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -958,7 +958,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index 35635171f..585affe6b 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -21,6 +21,7 @@ import ( "time" "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/ent" "entgo.io/ent/entc/integration/ent/enttest" @@ -360,6 +361,30 @@ func Select(t *testing.T, client *ent.Client) { require.NotEmpty(a8m.Age) require.Empty(a8m.Name) require.Empty(a8m.Nickname) + + pets := client.Pet.CreateBulk( + client.Pet.Create().SetName("a"), + client.Pet.Create().SetName("b"), + client.Pet.Create().SetName("c"), + client.Pet.Create().SetName("b"), + ).SaveX(ctx) + client.User.Create().SetName("foo").SetAge(20).AddPets(pets[0], pets[1]).SaveX(ctx) + client.User.Create().SetName("bar").SetAge(20).AddPets(pets[2], pets[3]).SaveX(ctx) + names = client.Pet.Query().Order(ent.Asc(pet.FieldID)).Select(pet.FieldName).StringsX(ctx) + require.Equal([]string{"a", "b", "c", "b"}, names) + names = client.Pet.Query().Order(ent.Asc(pet.FieldName)).Select(pet.FieldName).StringsX(ctx) + require.Equal([]string{"a", "b", "b", "c"}, names) + names = client.Pet.Query(). + Order(func(s *entsql.Selector) { + // Join with user table for ordering by owner-name + // and pet-name (edge + field ordering). + t := entsql.Table(user.Table) + s.Join(t).On(s.C(pet.OwnerColumn), t.C(user.FieldID)) + s.OrderBy(t.C(user.FieldName), s.C(pet.FieldName)) + }). + Select(pet.FieldName). + StringsX(ctx) + require.Equal([]string{"b", "c", "a", "b"}, names) } func Predicate(t *testing.T, client *ent.Client) { @@ -654,11 +679,11 @@ func Relation(t *testing.T, client *ent.Client) { _, err = client.Group.Query().GroupBy("unknown_field").String(ctx) require.EqualError(err, "invalid field \"unknown_field\" for group-by") _, err = client.User.Query().Order(ent.Asc("invalid")).Only(ctx) - require.EqualError(err, "invalid field \"invalid\" for ordering") + require.EqualError(err, "ordering error: ent: unknown column \"invalid\" for table \"users\"") _, err = client.User.Query().Order(ent.Asc("invalid")).QueryFollowing().Only(ctx) - require.EqualError(err, "invalid field \"invalid\" for ordering") + require.EqualError(err, "ordering error: ent: unknown column \"invalid\" for table \"users\"") _, err = client.User.Query().GroupBy("name").Aggregate(ent.Sum("invalid")).String(ctx) - require.EqualError(err, "invalid field \"invalid\" for grouping") + require.EqualError(err, "grouping error: ent: unknown column \"invalid\" for table \"users\"") t.Log("query using edge-with predicate") require.Len(usr.QueryGroups().Where(group.HasInfoWith(groupinfo.Desc("group info"))).AllX(ctx), 1) diff --git a/entc/integration/json/ent/ent.go b/entc/integration/json/ent/ent.go index 28dee448a..92b75da67 100644 --- a/entc/integration/json/ent/ent.go +++ b/entc/integration/json/ent/ent.go @@ -14,6 +14,7 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/json/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +30,55 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +87,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +113,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +125,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +137,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/json/ent/user_query.go b/entc/integration/json/ent/user_query.go index 0e6a0c115..b64339662 100644 --- a/entc/integration/json/ent/user_query.go +++ b/entc/integration/json/ent/user_query.go @@ -392,7 +392,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -411,7 +411,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -677,7 +677,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/entc/integration/migrate/entv1/car_query.go b/entc/integration/migrate/entv1/car_query.go index d44180a7a..94e4143ff 100644 --- a/entc/integration/migrate/entv1/car_query.go +++ b/entc/integration/migrate/entv1/car_query.go @@ -447,7 +447,7 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, car.ValidColumn) + ps[i](selector) } } } @@ -466,7 +466,7 @@ func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, car.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -732,7 +732,7 @@ func (cgb *CarGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, car.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/entc/integration/migrate/entv1/conversion_query.go b/entc/integration/migrate/entv1/conversion_query.go index 28d70b87e..cf5a296d2 100644 --- a/entc/integration/migrate/entv1/conversion_query.go +++ b/entc/integration/migrate/entv1/conversion_query.go @@ -392,7 +392,7 @@ func (cq *ConversionQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, conversion.ValidColumn) + ps[i](selector) } } } @@ -411,7 +411,7 @@ func (cq *ConversionQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, conversion.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -677,7 +677,7 @@ func (cgb *ConversionGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, conversion.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/entc/integration/migrate/entv1/customtype_query.go b/entc/integration/migrate/entv1/customtype_query.go index 7f0f90294..6928952a6 100644 --- a/entc/integration/migrate/entv1/customtype_query.go +++ b/entc/integration/migrate/entv1/customtype_query.go @@ -392,7 +392,7 @@ func (ctq *CustomTypeQuery) querySpec() *sqlgraph.QuerySpec { if ps := ctq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, customtype.ValidColumn) + ps[i](selector) } } } @@ -411,7 +411,7 @@ func (ctq *CustomTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range ctq.order { - p(selector, customtype.ValidColumn) + p(selector) } if offset := ctq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -677,7 +677,7 @@ func (ctgb *CustomTypeGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ctgb.fields)+len(ctgb.fns)) columns = append(columns, ctgb.fields...) for _, fn := range ctgb.fns { - columns = append(columns, fn(selector, customtype.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ctgb.fields...) } diff --git a/entc/integration/migrate/entv1/ent.go b/entc/integration/migrate/entv1/ent.go index 114f78381..48deaaa13 100644 --- a/entc/integration/migrate/entv1/ent.go +++ b/entc/integration/migrate/entv1/ent.go @@ -14,6 +14,10 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/migrate/entv1/car" + "entgo.io/ent/entc/integration/migrate/entv1/conversion" + "entgo.io/ent/entc/integration/migrate/entv1/customtype" + "entgo.io/ent/entc/integration/migrate/entv1/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +33,58 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + car.Table: car.ValidColumn, + conversion.Table: conversion.ValidColumn, + customtype.Table: customtype.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("entv1: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("entv1: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +93,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +119,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +131,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +143,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/migrate/entv1/user_query.go b/entc/integration/migrate/entv1/user_query.go index a71408a4e..5120f95a6 100644 --- a/entc/integration/migrate/entv1/user_query.go +++ b/entc/integration/migrate/entv1/user_query.go @@ -666,7 +666,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -685,7 +685,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -951,7 +951,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/entc/integration/migrate/entv2/car_query.go b/entc/integration/migrate/entv2/car_query.go index 3578edbb4..35e559530 100644 --- a/entc/integration/migrate/entv2/car_query.go +++ b/entc/integration/migrate/entv2/car_query.go @@ -447,7 +447,7 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, car.ValidColumn) + ps[i](selector) } } } @@ -466,7 +466,7 @@ func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, car.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -732,7 +732,7 @@ func (cgb *CarGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, car.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/entc/integration/migrate/entv2/conversion_query.go b/entc/integration/migrate/entv2/conversion_query.go index c1963fcfa..2e30471db 100644 --- a/entc/integration/migrate/entv2/conversion_query.go +++ b/entc/integration/migrate/entv2/conversion_query.go @@ -392,7 +392,7 @@ func (cq *ConversionQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, conversion.ValidColumn) + ps[i](selector) } } } @@ -411,7 +411,7 @@ func (cq *ConversionQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, conversion.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -677,7 +677,7 @@ func (cgb *ConversionGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, conversion.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/entc/integration/migrate/entv2/customtype_query.go b/entc/integration/migrate/entv2/customtype_query.go index ab44176a8..5dd002aa3 100644 --- a/entc/integration/migrate/entv2/customtype_query.go +++ b/entc/integration/migrate/entv2/customtype_query.go @@ -392,7 +392,7 @@ func (ctq *CustomTypeQuery) querySpec() *sqlgraph.QuerySpec { if ps := ctq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, customtype.ValidColumn) + ps[i](selector) } } } @@ -411,7 +411,7 @@ func (ctq *CustomTypeQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range ctq.order { - p(selector, customtype.ValidColumn) + p(selector) } if offset := ctq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -677,7 +677,7 @@ func (ctgb *CustomTypeGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ctgb.fields)+len(ctgb.fns)) columns = append(columns, ctgb.fields...) for _, fn := range ctgb.fns { - columns = append(columns, fn(selector, customtype.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ctgb.fields...) } diff --git a/entc/integration/migrate/entv2/ent.go b/entc/integration/migrate/entv2/ent.go index b802e2d89..e5553c120 100644 --- a/entc/integration/migrate/entv2/ent.go +++ b/entc/integration/migrate/entv2/ent.go @@ -14,6 +14,13 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/migrate/entv2/car" + "entgo.io/ent/entc/integration/migrate/entv2/conversion" + "entgo.io/ent/entc/integration/migrate/entv2/customtype" + "entgo.io/ent/entc/integration/migrate/entv2/group" + "entgo.io/ent/entc/integration/migrate/entv2/media" + "entgo.io/ent/entc/integration/migrate/entv2/pet" + "entgo.io/ent/entc/integration/migrate/entv2/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +36,61 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + car.Table: car.ValidColumn, + conversion.Table: conversion.ValidColumn, + customtype.Table: customtype.ValidColumn, + group.Table: group.ValidColumn, + media.Table: media.ValidColumn, + pet.Table: pet.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("entv2: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("entv2: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +99,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +125,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +137,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +149,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/migrate/entv2/group_query.go b/entc/integration/migrate/entv2/group_query.go index d070399aa..cec48523e 100644 --- a/entc/integration/migrate/entv2/group_query.go +++ b/entc/integration/migrate/entv2/group_query.go @@ -368,7 +368,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { if ps := gq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, group.ValidColumn) + ps[i](selector) } } } @@ -387,7 +387,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range gq.order { - p(selector, group.ValidColumn) + p(selector) } if offset := gq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -653,7 +653,7 @@ func (ggb *GroupGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ggb.fields)+len(ggb.fns)) columns = append(columns, ggb.fields...) for _, fn := range ggb.fns { - columns = append(columns, fn(selector, group.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ggb.fields...) } diff --git a/entc/integration/migrate/entv2/media_query.go b/entc/integration/migrate/entv2/media_query.go index 0b7f671ab..d87586c7e 100644 --- a/entc/integration/migrate/entv2/media_query.go +++ b/entc/integration/migrate/entv2/media_query.go @@ -392,7 +392,7 @@ func (mq *MediaQuery) querySpec() *sqlgraph.QuerySpec { if ps := mq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, media.ValidColumn) + ps[i](selector) } } } @@ -411,7 +411,7 @@ func (mq *MediaQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range mq.order { - p(selector, media.ValidColumn) + p(selector) } if offset := mq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -677,7 +677,7 @@ func (mgb *MediaGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(mgb.fields)+len(mgb.fns)) columns = append(columns, mgb.fields...) for _, fn := range mgb.fns { - columns = append(columns, fn(selector, media.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(mgb.fields...) } diff --git a/entc/integration/migrate/entv2/pet_query.go b/entc/integration/migrate/entv2/pet_query.go index be5c0ab93..ebced8f07 100644 --- a/entc/integration/migrate/entv2/pet_query.go +++ b/entc/integration/migrate/entv2/pet_query.go @@ -447,7 +447,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { if ps := pq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, pet.ValidColumn) + ps[i](selector) } } } @@ -466,7 +466,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range pq.order { - p(selector, pet.ValidColumn) + p(selector) } if offset := pq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -732,7 +732,7 @@ func (pgb *PetGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) columns = append(columns, pgb.fields...) for _, fn := range pgb.fns { - columns = append(columns, fn(selector, pet.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(pgb.fields...) } diff --git a/entc/integration/migrate/entv2/user_query.go b/entc/integration/migrate/entv2/user_query.go index c4ce22c23..9daef8ff6 100644 --- a/entc/integration/migrate/entv2/user_query.go +++ b/entc/integration/migrate/entv2/user_query.go @@ -630,7 +630,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -649,7 +649,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -915,7 +915,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/entc/integration/multischema/ent/ent.go b/entc/integration/multischema/ent/ent.go index 28dee448a..7417c927f 100644 --- a/entc/integration/multischema/ent/ent.go +++ b/entc/integration/multischema/ent/ent.go @@ -14,6 +14,9 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/multischema/ent/group" + "entgo.io/ent/entc/integration/multischema/ent/pet" + "entgo.io/ent/entc/integration/multischema/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +32,57 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + group.Table: group.ValidColumn, + pet.Table: pet.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +91,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +117,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +129,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +141,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/multischema/ent/group_query.go b/entc/integration/multischema/ent/group_query.go index 765bb200d..e98417e5b 100644 --- a/entc/integration/multischema/ent/group_query.go +++ b/entc/integration/multischema/ent/group_query.go @@ -509,7 +509,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { if ps := gq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, group.ValidColumn) + ps[i](selector) } } } @@ -531,7 +531,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range gq.order { - p(selector, group.ValidColumn) + p(selector) } if offset := gq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -797,7 +797,7 @@ func (ggb *GroupGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ggb.fields)+len(ggb.fns)) columns = append(columns, ggb.fields...) for _, fn := range ggb.fns { - columns = append(columns, fn(selector, group.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ggb.fields...) } diff --git a/entc/integration/multischema/ent/pet_query.go b/entc/integration/multischema/ent/pet_query.go index d625ac906..c4df20141 100644 --- a/entc/integration/multischema/ent/pet_query.go +++ b/entc/integration/multischema/ent/pet_query.go @@ -479,7 +479,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { if ps := pq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, pet.ValidColumn) + ps[i](selector) } } } @@ -501,7 +501,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range pq.order { - p(selector, pet.ValidColumn) + p(selector) } if offset := pq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -767,7 +767,7 @@ func (pgb *PetGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) columns = append(columns, pgb.fields...) for _, fn := range pgb.fns { - columns = append(columns, fn(selector, pet.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(pgb.fields...) } diff --git a/entc/integration/multischema/ent/user_query.go b/entc/integration/multischema/ent/user_query.go index 5468e13b5..4c3cdcfec 100644 --- a/entc/integration/multischema/ent/user_query.go +++ b/entc/integration/multischema/ent/user_query.go @@ -578,7 +578,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -600,7 +600,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -866,7 +866,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/entc/integration/privacy/ent/ent.go b/entc/integration/privacy/ent/ent.go index 28dee448a..74fb50c2c 100644 --- a/entc/integration/privacy/ent/ent.go +++ b/entc/integration/privacy/ent/ent.go @@ -14,6 +14,9 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/privacy/ent/task" + "entgo.io/ent/entc/integration/privacy/ent/team" + "entgo.io/ent/entc/integration/privacy/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +32,57 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + task.Table: task.ValidColumn, + team.Table: team.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +91,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +117,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +129,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +141,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/privacy/ent/task_query.go b/entc/integration/privacy/ent/task_query.go index baac33b6b..cd53e90bf 100644 --- a/entc/integration/privacy/ent/task_query.go +++ b/entc/integration/privacy/ent/task_query.go @@ -580,7 +580,7 @@ func (tq *TaskQuery) querySpec() *sqlgraph.QuerySpec { if ps := tq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, task.ValidColumn) + ps[i](selector) } } } @@ -599,7 +599,7 @@ func (tq *TaskQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range tq.order { - p(selector, task.ValidColumn) + p(selector) } if offset := tq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -865,7 +865,7 @@ func (tgb *TaskGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(tgb.fields)+len(tgb.fns)) columns = append(columns, tgb.fields...) for _, fn := range tgb.fns { - columns = append(columns, fn(selector, task.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(tgb.fields...) } diff --git a/entc/integration/privacy/ent/team_query.go b/entc/integration/privacy/ent/team_query.go index 44a103a22..17ad12171 100644 --- a/entc/integration/privacy/ent/team_query.go +++ b/entc/integration/privacy/ent/team_query.go @@ -608,7 +608,7 @@ func (tq *TeamQuery) querySpec() *sqlgraph.QuerySpec { if ps := tq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, team.ValidColumn) + ps[i](selector) } } } @@ -627,7 +627,7 @@ func (tq *TeamQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range tq.order { - p(selector, team.ValidColumn) + p(selector) } if offset := tq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -893,7 +893,7 @@ func (tgb *TeamGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(tgb.fields)+len(tgb.fns)) columns = append(columns, tgb.fields...) for _, fn := range tgb.fns { - columns = append(columns, fn(selector, team.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(tgb.fields...) } diff --git a/entc/integration/privacy/ent/user_query.go b/entc/integration/privacy/ent/user_query.go index aa19c50d6..cf7b8e996 100644 --- a/entc/integration/privacy/ent/user_query.go +++ b/entc/integration/privacy/ent/user_query.go @@ -572,7 +572,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -591,7 +591,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -857,7 +857,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/entc/integration/template/ent/ent.go b/entc/integration/template/ent/ent.go index 28dee448a..9de96adfb 100644 --- a/entc/integration/template/ent/ent.go +++ b/entc/integration/template/ent/ent.go @@ -14,6 +14,9 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/entc/integration/template/ent/group" + "entgo.io/ent/entc/integration/template/ent/pet" + "entgo.io/ent/entc/integration/template/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +32,57 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + group.Table: group.ValidColumn, + pet.Table: pet.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +91,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +117,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +129,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +141,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/entc/integration/template/ent/group_query.go b/entc/integration/template/ent/group_query.go index 2fca380d3..6a2316f1f 100644 --- a/entc/integration/template/ent/group_query.go +++ b/entc/integration/template/ent/group_query.go @@ -392,7 +392,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { if ps := gq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, group.ValidColumn) + ps[i](selector) } } } @@ -411,7 +411,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range gq.order { - p(selector, group.ValidColumn) + p(selector) } if offset := gq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -677,7 +677,7 @@ func (ggb *GroupGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ggb.fields)+len(ggb.fns)) columns = append(columns, ggb.fields...) for _, fn := range ggb.fns { - columns = append(columns, fn(selector, group.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ggb.fields...) } diff --git a/entc/integration/template/ent/pet_query.go b/entc/integration/template/ent/pet_query.go index b95d49456..0fb0bb8b1 100644 --- a/entc/integration/template/ent/pet_query.go +++ b/entc/integration/template/ent/pet_query.go @@ -471,7 +471,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { if ps := pq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, pet.ValidColumn) + ps[i](selector) } } } @@ -490,7 +490,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range pq.order { - p(selector, pet.ValidColumn) + p(selector) } if offset := pq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -756,7 +756,7 @@ func (pgb *PetGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) columns = append(columns, pgb.fields...) for _, fn := range pgb.fns { - columns = append(columns, fn(selector, pet.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(pgb.fields...) } diff --git a/entc/integration/template/ent/user_query.go b/entc/integration/template/ent/user_query.go index 4282094bb..f102e034d 100644 --- a/entc/integration/template/ent/user_query.go +++ b/entc/integration/template/ent/user_query.go @@ -565,7 +565,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -584,7 +584,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -850,7 +850,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/examples/edgeindex/ent/city_query.go b/examples/edgeindex/ent/city_query.go index 1b2ceba60..3948930f6 100644 --- a/examples/edgeindex/ent/city_query.go +++ b/examples/edgeindex/ent/city_query.go @@ -464,7 +464,7 @@ func (cq *CityQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, city.ValidColumn) + ps[i](selector) } } } @@ -483,7 +483,7 @@ func (cq *CityQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, city.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -749,7 +749,7 @@ func (cgb *CityGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, city.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/examples/edgeindex/ent/ent.go b/examples/edgeindex/ent/ent.go index 28dee448a..1ae66c8f3 100644 --- a/examples/edgeindex/ent/ent.go +++ b/examples/edgeindex/ent/ent.go @@ -14,6 +14,8 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/edgeindex/ent/city" + "entgo.io/ent/examples/edgeindex/ent/street" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +31,56 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + city.Table: city.ValidColumn, + street.Table: street.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +89,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +115,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +127,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +139,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/edgeindex/ent/street_query.go b/examples/edgeindex/ent/street_query.go index 2175c2a2c..8bc993e76 100644 --- a/examples/edgeindex/ent/street_query.go +++ b/examples/edgeindex/ent/street_query.go @@ -471,7 +471,7 @@ func (sq *StreetQuery) querySpec() *sqlgraph.QuerySpec { if ps := sq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, street.ValidColumn) + ps[i](selector) } } } @@ -490,7 +490,7 @@ func (sq *StreetQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range sq.order { - p(selector, street.ValidColumn) + p(selector) } if offset := sq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -756,7 +756,7 @@ func (sgb *StreetGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(sgb.fields)+len(sgb.fns)) columns = append(columns, sgb.fields...) for _, fn := range sgb.fns { - columns = append(columns, fn(selector, street.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(sgb.fields...) } diff --git a/examples/entcpkg/ent/ent.go b/examples/entcpkg/ent/ent.go index 28dee448a..e766f3ee1 100644 --- a/examples/entcpkg/ent/ent.go +++ b/examples/entcpkg/ent/ent.go @@ -14,6 +14,7 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/entcpkg/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +30,55 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +87,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +113,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +125,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +137,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/entcpkg/ent/user_query.go b/examples/entcpkg/ent/user_query.go index b9a9137eb..9c489d753 100644 --- a/examples/entcpkg/ent/user_query.go +++ b/examples/entcpkg/ent/user_query.go @@ -392,7 +392,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -411,7 +411,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -677,7 +677,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/examples/m2m2types/ent/ent.go b/examples/m2m2types/ent/ent.go index 28dee448a..58b1f30be 100644 --- a/examples/m2m2types/ent/ent.go +++ b/examples/m2m2types/ent/ent.go @@ -14,6 +14,8 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/m2m2types/ent/group" + "entgo.io/ent/examples/m2m2types/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +31,56 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + group.Table: group.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +89,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +115,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +127,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +139,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/m2m2types/ent/group_query.go b/examples/m2m2types/ent/group_query.go index 120a56963..f75041e76 100644 --- a/examples/m2m2types/ent/group_query.go +++ b/examples/m2m2types/ent/group_query.go @@ -500,7 +500,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { if ps := gq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, group.ValidColumn) + ps[i](selector) } } } @@ -519,7 +519,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range gq.order { - p(selector, group.ValidColumn) + p(selector) } if offset := gq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -785,7 +785,7 @@ func (ggb *GroupGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ggb.fields)+len(ggb.fns)) columns = append(columns, ggb.fields...) for _, fn := range ggb.fns { - columns = append(columns, fn(selector, group.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ggb.fields...) } diff --git a/examples/m2m2types/ent/user_query.go b/examples/m2m2types/ent/user_query.go index 9c45f3d8b..5a233fcc5 100644 --- a/examples/m2m2types/ent/user_query.go +++ b/examples/m2m2types/ent/user_query.go @@ -500,7 +500,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -519,7 +519,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -785,7 +785,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/examples/m2mbidi/ent/ent.go b/examples/m2mbidi/ent/ent.go index 28dee448a..82e82e8f1 100644 --- a/examples/m2mbidi/ent/ent.go +++ b/examples/m2mbidi/ent/ent.go @@ -14,6 +14,7 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/m2mbidi/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +30,55 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +87,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +113,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +125,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +137,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/m2mbidi/ent/user_query.go b/examples/m2mbidi/ent/user_query.go index 3a82e26c0..a6e361a3d 100644 --- a/examples/m2mbidi/ent/user_query.go +++ b/examples/m2mbidi/ent/user_query.go @@ -499,7 +499,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -518,7 +518,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -784,7 +784,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/examples/m2mrecur/ent/ent.go b/examples/m2mrecur/ent/ent.go index 28dee448a..ae38ef67f 100644 --- a/examples/m2mrecur/ent/ent.go +++ b/examples/m2mrecur/ent/ent.go @@ -14,6 +14,7 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/m2mrecur/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +30,55 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +87,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +113,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +125,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +137,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/m2mrecur/ent/user_query.go b/examples/m2mrecur/ent/user_query.go index fcf991ef5..6666fa25e 100644 --- a/examples/m2mrecur/ent/user_query.go +++ b/examples/m2mrecur/ent/user_query.go @@ -600,7 +600,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -619,7 +619,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -885,7 +885,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/examples/o2m2types/ent/ent.go b/examples/o2m2types/ent/ent.go index 28dee448a..9356cddbe 100644 --- a/examples/o2m2types/ent/ent.go +++ b/examples/o2m2types/ent/ent.go @@ -14,6 +14,8 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/o2m2types/ent/pet" + "entgo.io/ent/examples/o2m2types/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +31,56 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + pet.Table: pet.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +89,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +115,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +127,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +139,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/o2m2types/ent/pet_query.go b/examples/o2m2types/ent/pet_query.go index a0cd5d96d..285de5362 100644 --- a/examples/o2m2types/ent/pet_query.go +++ b/examples/o2m2types/ent/pet_query.go @@ -471,7 +471,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { if ps := pq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, pet.ValidColumn) + ps[i](selector) } } } @@ -490,7 +490,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range pq.order { - p(selector, pet.ValidColumn) + p(selector) } if offset := pq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -756,7 +756,7 @@ func (pgb *PetGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) columns = append(columns, pgb.fields...) for _, fn := range pgb.fns { - columns = append(columns, fn(selector, pet.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(pgb.fields...) } diff --git a/examples/o2m2types/ent/user_query.go b/examples/o2m2types/ent/user_query.go index 952059651..016653c20 100644 --- a/examples/o2m2types/ent/user_query.go +++ b/examples/o2m2types/ent/user_query.go @@ -464,7 +464,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -483,7 +483,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -749,7 +749,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/examples/o2mrecur/ent/ent.go b/examples/o2mrecur/ent/ent.go index 28dee448a..672b0125c 100644 --- a/examples/o2mrecur/ent/ent.go +++ b/examples/o2mrecur/ent/ent.go @@ -14,6 +14,7 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/o2mrecur/ent/node" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +30,55 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + node.Table: node.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +87,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +113,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +125,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +137,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/o2mrecur/ent/node_query.go b/examples/o2mrecur/ent/node_query.go index 1c0741b46..1d7050c9c 100644 --- a/examples/o2mrecur/ent/node_query.go +++ b/examples/o2mrecur/ent/node_query.go @@ -536,7 +536,7 @@ func (nq *NodeQuery) querySpec() *sqlgraph.QuerySpec { if ps := nq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, node.ValidColumn) + ps[i](selector) } } } @@ -555,7 +555,7 @@ func (nq *NodeQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range nq.order { - p(selector, node.ValidColumn) + p(selector) } if offset := nq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -821,7 +821,7 @@ func (ngb *NodeGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ngb.fields)+len(ngb.fns)) columns = append(columns, ngb.fields...) for _, fn := range ngb.fns { - columns = append(columns, fn(selector, node.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ngb.fields...) } diff --git a/examples/o2o2types/ent/card_query.go b/examples/o2o2types/ent/card_query.go index 4e8ee05cb..79f105730 100644 --- a/examples/o2o2types/ent/card_query.go +++ b/examples/o2o2types/ent/card_query.go @@ -471,7 +471,7 @@ func (cq *CardQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, card.ValidColumn) + ps[i](selector) } } } @@ -490,7 +490,7 @@ func (cq *CardQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, card.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -756,7 +756,7 @@ func (cgb *CardGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, card.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/examples/o2o2types/ent/ent.go b/examples/o2o2types/ent/ent.go index 28dee448a..ad931673d 100644 --- a/examples/o2o2types/ent/ent.go +++ b/examples/o2o2types/ent/ent.go @@ -14,6 +14,8 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/o2o2types/ent/card" + "entgo.io/ent/examples/o2o2types/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +31,56 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + card.Table: card.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +89,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +115,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +127,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +139,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/o2o2types/ent/user_query.go b/examples/o2o2types/ent/user_query.go index 4b0baeb43..a5e412938 100644 --- a/examples/o2o2types/ent/user_query.go +++ b/examples/o2o2types/ent/user_query.go @@ -463,7 +463,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -482,7 +482,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -748,7 +748,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/examples/o2obidi/ent/ent.go b/examples/o2obidi/ent/ent.go index 28dee448a..48b61f07b 100644 --- a/examples/o2obidi/ent/ent.go +++ b/examples/o2obidi/ent/ent.go @@ -14,6 +14,7 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/o2obidi/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +30,55 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +87,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +113,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +125,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +137,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/o2obidi/ent/user_query.go b/examples/o2obidi/ent/user_query.go index def3fa2d6..341580779 100644 --- a/examples/o2obidi/ent/user_query.go +++ b/examples/o2obidi/ent/user_query.go @@ -470,7 +470,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -489,7 +489,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -755,7 +755,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/examples/o2orecur/ent/ent.go b/examples/o2orecur/ent/ent.go index 28dee448a..bc81621f1 100644 --- a/examples/o2orecur/ent/ent.go +++ b/examples/o2orecur/ent/ent.go @@ -14,6 +14,7 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/o2orecur/ent/node" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +30,55 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + node.Table: node.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +87,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +113,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +125,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +137,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/o2orecur/ent/node_query.go b/examples/o2orecur/ent/node_query.go index 76da350a1..1373c99fe 100644 --- a/examples/o2orecur/ent/node_query.go +++ b/examples/o2orecur/ent/node_query.go @@ -535,7 +535,7 @@ func (nq *NodeQuery) querySpec() *sqlgraph.QuerySpec { if ps := nq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, node.ValidColumn) + ps[i](selector) } } } @@ -554,7 +554,7 @@ func (nq *NodeQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range nq.order { - p(selector, node.ValidColumn) + p(selector) } if offset := nq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -820,7 +820,7 @@ func (ngb *NodeGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ngb.fields)+len(ngb.fns)) columns = append(columns, ngb.fields...) for _, fn := range ngb.fns { - columns = append(columns, fn(selector, node.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ngb.fields...) } diff --git a/examples/privacyadmin/ent/ent.go b/examples/privacyadmin/ent/ent.go index 28dee448a..033bf1227 100644 --- a/examples/privacyadmin/ent/ent.go +++ b/examples/privacyadmin/ent/ent.go @@ -14,6 +14,7 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/privacyadmin/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +30,55 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +87,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +113,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +125,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +137,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/privacyadmin/ent/user_query.go b/examples/privacyadmin/ent/user_query.go index e3a464915..63f6f130d 100644 --- a/examples/privacyadmin/ent/user_query.go +++ b/examples/privacyadmin/ent/user_query.go @@ -398,7 +398,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -417,7 +417,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -683,7 +683,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/examples/privacytenant/ent/ent.go b/examples/privacytenant/ent/ent.go index 28dee448a..8536ef284 100644 --- a/examples/privacytenant/ent/ent.go +++ b/examples/privacytenant/ent/ent.go @@ -14,6 +14,9 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/privacytenant/ent/group" + "entgo.io/ent/examples/privacytenant/ent/tenant" + "entgo.io/ent/examples/privacytenant/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +32,57 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + group.Table: group.ValidColumn, + tenant.Table: tenant.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +91,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +117,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +129,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +141,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/privacytenant/ent/group_query.go b/examples/privacytenant/ent/group_query.go index c3941d9ce..ff753ecce 100644 --- a/examples/privacytenant/ent/group_query.go +++ b/examples/privacytenant/ent/group_query.go @@ -580,7 +580,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { if ps := gq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, group.ValidColumn) + ps[i](selector) } } } @@ -599,7 +599,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range gq.order { - p(selector, group.ValidColumn) + p(selector) } if offset := gq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -865,7 +865,7 @@ func (ggb *GroupGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ggb.fields)+len(ggb.fns)) columns = append(columns, ggb.fields...) for _, fn := range ggb.fns { - columns = append(columns, fn(selector, group.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ggb.fields...) } diff --git a/examples/privacytenant/ent/tenant_query.go b/examples/privacytenant/ent/tenant_query.go index 291392586..5a555173c 100644 --- a/examples/privacytenant/ent/tenant_query.go +++ b/examples/privacytenant/ent/tenant_query.go @@ -398,7 +398,7 @@ func (tq *TenantQuery) querySpec() *sqlgraph.QuerySpec { if ps := tq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, tenant.ValidColumn) + ps[i](selector) } } } @@ -417,7 +417,7 @@ func (tq *TenantQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range tq.order { - p(selector, tenant.ValidColumn) + p(selector) } if offset := tq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -683,7 +683,7 @@ func (tgb *TenantGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(tgb.fields)+len(tgb.fns)) columns = append(columns, tgb.fields...) for _, fn := range tgb.fns { - columns = append(columns, fn(selector, tenant.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(tgb.fields...) } diff --git a/examples/privacytenant/ent/user_query.go b/examples/privacytenant/ent/user_query.go index 9a6546896..cdf5e72b7 100644 --- a/examples/privacytenant/ent/user_query.go +++ b/examples/privacytenant/ent/user_query.go @@ -580,7 +580,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -599,7 +599,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -865,7 +865,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/examples/start/ent/car_query.go b/examples/start/ent/car_query.go index a5c349dd4..f49fe29c6 100644 --- a/examples/start/ent/car_query.go +++ b/examples/start/ent/car_query.go @@ -471,7 +471,7 @@ func (cq *CarQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, car.ValidColumn) + ps[i](selector) } } } @@ -490,7 +490,7 @@ func (cq *CarQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, car.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -756,7 +756,7 @@ func (cgb *CarGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, car.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/examples/start/ent/ent.go b/examples/start/ent/ent.go index 28dee448a..ec139ffe8 100644 --- a/examples/start/ent/ent.go +++ b/examples/start/ent/ent.go @@ -14,6 +14,9 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/start/ent/car" + "entgo.io/ent/examples/start/ent/group" + "entgo.io/ent/examples/start/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +32,57 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + car.Table: car.ValidColumn, + group.Table: group.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +91,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +117,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +129,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +141,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/start/ent/group_query.go b/examples/start/ent/group_query.go index 7ec0ee099..54c099252 100644 --- a/examples/start/ent/group_query.go +++ b/examples/start/ent/group_query.go @@ -500,7 +500,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { if ps := gq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, group.ValidColumn) + ps[i](selector) } } } @@ -519,7 +519,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range gq.order { - p(selector, group.ValidColumn) + p(selector) } if offset := gq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -785,7 +785,7 @@ func (ggb *GroupGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ggb.fields)+len(ggb.fns)) columns = append(columns, ggb.fields...) for _, fn := range ggb.fns { - columns = append(columns, fn(selector, group.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ggb.fields...) } diff --git a/examples/start/ent/user_query.go b/examples/start/ent/user_query.go index b774cfeb6..e3449adb3 100644 --- a/examples/start/ent/user_query.go +++ b/examples/start/ent/user_query.go @@ -566,7 +566,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -585,7 +585,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -851,7 +851,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) } diff --git a/examples/traversal/ent/ent.go b/examples/traversal/ent/ent.go index 28dee448a..6cd31fd24 100644 --- a/examples/traversal/ent/ent.go +++ b/examples/traversal/ent/ent.go @@ -14,6 +14,9 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/examples/traversal/ent/group" + "entgo.io/ent/examples/traversal/ent/pet" + "entgo.io/ent/examples/traversal/ent/user" ) // ent aliases to avoid import conflicts in user's code. @@ -29,36 +32,57 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + group.Table: group.ValidColumn, + pet.Table: pet.ValidColumn, + user.Table: user.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("ent: unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("ent: unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ordering error: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -67,23 +91,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -92,9 +117,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -103,9 +129,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -114,9 +141,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("grouping error: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/examples/traversal/ent/group_query.go b/examples/traversal/ent/group_query.go index 55e4ed81a..27f78e7df 100644 --- a/examples/traversal/ent/group_query.go +++ b/examples/traversal/ent/group_query.go @@ -573,7 +573,7 @@ func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { if ps := gq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, group.ValidColumn) + ps[i](selector) } } } @@ -592,7 +592,7 @@ func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range gq.order { - p(selector, group.ValidColumn) + p(selector) } if offset := gq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -858,7 +858,7 @@ func (ggb *GroupGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ggb.fields)+len(ggb.fns)) columns = append(columns, ggb.fields...) for _, fn := range ggb.fns { - columns = append(columns, fn(selector, group.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ggb.fields...) } diff --git a/examples/traversal/ent/pet_query.go b/examples/traversal/ent/pet_query.go index 7b0354ccb..5582e0cfd 100644 --- a/examples/traversal/ent/pet_query.go +++ b/examples/traversal/ent/pet_query.go @@ -573,7 +573,7 @@ func (pq *PetQuery) querySpec() *sqlgraph.QuerySpec { if ps := pq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, pet.ValidColumn) + ps[i](selector) } } } @@ -592,7 +592,7 @@ func (pq *PetQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range pq.order { - p(selector, pet.ValidColumn) + p(selector) } if offset := pq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -858,7 +858,7 @@ func (pgb *PetGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) columns = append(columns, pgb.fields...) for _, fn := range pgb.fns { - columns = append(columns, fn(selector, pet.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(pgb.fields...) } diff --git a/examples/traversal/ent/user_query.go b/examples/traversal/ent/user_query.go index a07834b23..84bf1b1e6 100644 --- a/examples/traversal/ent/user_query.go +++ b/examples/traversal/ent/user_query.go @@ -732,7 +732,7 @@ func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { if ps := uq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, user.ValidColumn) + ps[i](selector) } } } @@ -751,7 +751,7 @@ func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range uq.order { - p(selector, user.ValidColumn) + p(selector) } if offset := uq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -1017,7 +1017,7 @@ func (ugb *UserGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) columns = append(columns, ugb.fields...) for _, fn := range ugb.fns { - columns = append(columns, fn(selector, user.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ugb.fields...) }