mirror of
https://github.com/ent/ent.git
synced 2026-05-24 09:31:56 +03:00
391 lines
9.4 KiB
Go
391 lines
9.4 KiB
Go
// Copyright 2019-present Facebook Inc. All rights reserved.
|
|
// This source code is licensed under the Apache 2.0 license found
|
|
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package ws
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strconv"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/facebook/ent/dialect/gremlin"
|
|
"github.com/facebook/ent/dialect/gremlin/encoding/graphson"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// conn wraps a server-side websocket connection and adds the gremlin
// request/response helpers used by the fake servers in these tests.
type conn struct {
	*websocket.Conn
}
|
|
|
|
func (c conn) ReadRequest() (*gremlin.Request, error) {
|
|
_, data, err := c.ReadMessage()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var req gremlin.Request
|
|
if err := graphson.Unmarshal(data[data[0]+1:], &req); err != nil {
|
|
return nil, err
|
|
}
|
|
return &req, nil
|
|
}
|
|
|
|
func (c conn) WriteResponse(rsp *gremlin.Response) error {
|
|
data, err := graphson.Marshal(rsp)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.WriteMessage(websocket.BinaryMessage, data)
|
|
}
|
|
|
|
func serve(handler func(conn)) *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
upgrader := websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024}
|
|
c, _ := upgrader.Upgrade(w, r, nil)
|
|
defer c.Close()
|
|
handler(conn{c})
|
|
for {
|
|
_, _, err := c.ReadMessage()
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
}))
|
|
}
|
|
|
|
func TestConnectClosure(t *testing.T) {
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
defer wg.Wait()
|
|
|
|
srv := serve(func(conn conn) {
|
|
defer wg.Done()
|
|
_, _, err := conn.ReadMessage()
|
|
assert.True(t, websocket.IsCloseError(err, websocket.CloseNormalClosure))
|
|
})
|
|
defer srv.Close()
|
|
|
|
conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String())
|
|
require.NoError(t, err)
|
|
|
|
err = conn.Close()
|
|
assert.NoError(t, err)
|
|
|
|
_, err = conn.Execute(context.Background(), gremlin.NewEvalRequest("g.V()"))
|
|
assert.EqualError(t, err, ErrConnClosed.Error())
|
|
}
|
|
|
|
func TestSimpleQuery(t *testing.T) {
|
|
srv := serve(func(conn conn) {
|
|
typ, data, err := conn.ReadMessage()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, websocket.BinaryMessage, typ)
|
|
|
|
var req gremlin.Request
|
|
err = graphson.Unmarshal(data[data[0]+1:], &req)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "g.V()", req.Arguments["gremlin"])
|
|
|
|
rsp := gremlin.Response{RequestID: req.RequestID}
|
|
rsp.Status.Code = gremlin.StatusNoContent
|
|
err = conn.WriteResponse(&rsp)
|
|
require.NoError(t, err)
|
|
})
|
|
defer srv.Close()
|
|
|
|
conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer assert.Condition(t, func() bool { return assert.NoError(t, conn.Close()) })
|
|
|
|
rsp, err := conn.Execute(context.Background(), gremlin.NewEvalRequest("g.V()"))
|
|
assert.NoError(t, err)
|
|
require.NotNil(t, rsp)
|
|
assert.Equal(t, gremlin.StatusNoContent, rsp.Status.Code)
|
|
}
|
|
|
|
func TestDuplicateRequest(t *testing.T) {
|
|
// skip until flakiness will be fixed.
|
|
t.SkipNow()
|
|
srv := serve(func(conn conn) {
|
|
req, err := conn.ReadRequest()
|
|
require.NoError(t, err)
|
|
|
|
rsp := gremlin.Response{RequestID: req.RequestID}
|
|
rsp.Status.Code = gremlin.StatusNoContent
|
|
err = conn.WriteResponse(&rsp)
|
|
require.NoError(t, err)
|
|
})
|
|
defer srv.Close()
|
|
|
|
conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
var errors [2]error
|
|
req := gremlin.NewEvalRequest("g.V()")
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(errors))
|
|
|
|
for i := range errors {
|
|
go func(i int) {
|
|
_, errors[i] = conn.Execute(context.Background(), req)
|
|
wg.Done()
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
|
|
err = errors[0]
|
|
if err == nil {
|
|
err = errors[1]
|
|
}
|
|
assert.EqualError(t, err, ErrDuplicateRequest.Error())
|
|
}
|
|
|
|
func TestConnectCancellation(t *testing.T) {
|
|
srv := serve(func(conn) {})
|
|
defer srv.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
conn, err := DefaultDialer.DialContext(ctx, "ws://"+srv.Listener.Addr().String())
|
|
assert.Error(t, err)
|
|
assert.Nil(t, conn)
|
|
}
|
|
|
|
func TestQueryCancellation(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
srv := serve(func(conn conn) {
|
|
if _, _, err := conn.ReadMessage(); err == nil {
|
|
cancel()
|
|
}
|
|
})
|
|
defer srv.Close()
|
|
|
|
conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
_, err = conn.Execute(ctx, gremlin.NewEvalRequest("g.E()"))
|
|
assert.EqualError(t, err, context.Canceled.Error())
|
|
}
|
|
|
|
func TestBadResponse(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
mangle func(*gremlin.Response) *gremlin.Response
|
|
}{
|
|
{
|
|
name: "NoStatus",
|
|
mangle: func(rsp *gremlin.Response) *gremlin.Response {
|
|
return rsp
|
|
},
|
|
},
|
|
{
|
|
name: "Malformed",
|
|
mangle: func(rsp *gremlin.Response) *gremlin.Response {
|
|
rsp.Status.Code = gremlin.StatusMalformedRequest
|
|
rsp.Status.Message = "bad request"
|
|
return rsp
|
|
},
|
|
},
|
|
{
|
|
name: "Unknown",
|
|
mangle: func(rsp *gremlin.Response) *gremlin.Response {
|
|
rsp.Status.Code = 424242
|
|
return rsp
|
|
},
|
|
},
|
|
}
|
|
|
|
srv := serve(func(conn conn) {
|
|
for {
|
|
req, err := conn.ReadRequest()
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
idx, err := strconv.ParseInt(req.Arguments["gremlin"].(string), 10, 0)
|
|
require.NoError(t, err)
|
|
|
|
err = conn.WriteResponse(tests[idx].mangle(&gremlin.Response{RequestID: req.RequestID}))
|
|
require.NoError(t, err)
|
|
}
|
|
})
|
|
defer srv.Close()
|
|
|
|
conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(tests))
|
|
|
|
ctx := context.Background()
|
|
for i, tc := range tests {
|
|
i, tc := i, tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
defer wg.Done()
|
|
rsp, err := conn.Execute(ctx, gremlin.NewEvalRequest(strconv.FormatInt(int64(i), 10)))
|
|
assert.NoError(t, err)
|
|
assert.True(t, rsp.IsErr())
|
|
})
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestServerHangup(t *testing.T) {
|
|
// skip until flakiness will be fixed.
|
|
t.SkipNow()
|
|
srv := serve(func(conn conn) { _ = conn.Close() })
|
|
defer srv.Close()
|
|
|
|
conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
_, err = conn.Execute(context.Background(), gremlin.NewEvalRequest("g.V()"))
|
|
assert.EqualError(t, err, ErrConnClosed.Error())
|
|
assert.Error(t, conn.ctx.Err())
|
|
}
|
|
|
|
func TestCanceledLongRequest(t *testing.T) {
|
|
// skip until flakiness will be fixed.
|
|
t.SkipNow()
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
srv := serve(func(conn conn) {
|
|
var responses [3]*gremlin.Response
|
|
for i := 0; i < len(responses); i++ {
|
|
req, err := conn.ReadRequest()
|
|
require.NoError(t, err)
|
|
|
|
rsp := gremlin.Response{RequestID: req.RequestID}
|
|
rsp.Status.Code = gremlin.StatusSuccess
|
|
rsp.Result.Data = graphson.RawMessage(`"ok"`)
|
|
responses[i] = &rsp
|
|
}
|
|
|
|
cancel()
|
|
|
|
responses[0], responses[2] = responses[2], responses[0]
|
|
for i := 0; i < len(responses); i++ {
|
|
err := conn.WriteResponse(responses[i])
|
|
require.NoError(t, err)
|
|
}
|
|
})
|
|
defer srv.Close()
|
|
|
|
conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(3)
|
|
defer wg.Wait()
|
|
|
|
for i := 0; i < 3; i++ {
|
|
go func(ctx context.Context, idx int) {
|
|
defer wg.Done()
|
|
rsp, err := conn.Execute(ctx, gremlin.NewEvalRequest("g.V()"))
|
|
if idx > 0 {
|
|
assert.NoError(t, err)
|
|
assert.EqualValues(t, []byte(`"ok"`), rsp.Result.Data)
|
|
} else {
|
|
assert.EqualError(t, err, context.Canceled.Error())
|
|
}
|
|
}(ctx, i)
|
|
ctx = context.Background()
|
|
}
|
|
}
|
|
|
|
func TestPartialResponse(t *testing.T) {
|
|
type kv struct {
|
|
Key string
|
|
Value int
|
|
}
|
|
kvs := []kv{
|
|
{"one", 1},
|
|
{"two", 2},
|
|
{"three", 3},
|
|
}
|
|
srv := serve(func(conn conn) {
|
|
req, err := conn.ReadRequest()
|
|
require.NoError(t, err)
|
|
|
|
for i := range kvs {
|
|
data, err := graphson.Marshal([]kv{kvs[i]})
|
|
require.NoError(t, err)
|
|
|
|
rsp := gremlin.Response{RequestID: req.RequestID}
|
|
rsp.Result.Data = graphson.RawMessage(data)
|
|
|
|
if i != len(kvs)-1 {
|
|
rsp.Status.Code = gremlin.StatusPartialContent
|
|
} else {
|
|
rsp.Status.Code = gremlin.StatusSuccess
|
|
}
|
|
|
|
err = conn.WriteResponse(&rsp)
|
|
require.NoError(t, err)
|
|
}
|
|
})
|
|
defer srv.Close()
|
|
|
|
conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
rsp, err := conn.Execute(context.Background(), gremlin.NewEvalRequest("g.E()"))
|
|
assert.NoError(t, err)
|
|
|
|
var result []kv
|
|
err = graphson.Unmarshal(rsp.Result.Data, &result)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, kvs, result)
|
|
}
|
|
|
|
func TestAuthentication(t *testing.T) {
|
|
user, pass := "username", "password"
|
|
srv := serve(func(conn conn) {
|
|
req, err := conn.ReadRequest()
|
|
require.NoError(t, err)
|
|
|
|
rsp := gremlin.Response{RequestID: req.RequestID}
|
|
rsp.Status.Code = gremlin.StatusAuthenticate
|
|
err = conn.WriteResponse(&rsp)
|
|
require.NoError(t, err)
|
|
|
|
areq, err := conn.ReadRequest()
|
|
require.NoError(t, err)
|
|
|
|
var acreds gremlin.Credentials
|
|
err = acreds.UnmarshalText([]byte(areq.Arguments["sasl"].(string)))
|
|
assert.NoError(t, err)
|
|
areq.Arguments["sasl"] = acreds
|
|
assert.Equal(t, gremlin.NewAuthRequest(req.RequestID, user, pass), areq)
|
|
|
|
rsp = gremlin.Response{RequestID: req.RequestID}
|
|
rsp.Status.Code = gremlin.StatusNoContent
|
|
err = conn.WriteResponse(&rsp)
|
|
require.NoError(t, err)
|
|
})
|
|
defer srv.Close()
|
|
|
|
dialer := *DefaultDialer
|
|
dialer.user = user
|
|
dialer.pass = pass
|
|
|
|
client, err := dialer.Dial("ws://" + srv.Listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer client.Close()
|
|
|
|
_, err = client.Execute(context.Background(), gremlin.NewEvalRequest("g.E().drop()"))
|
|
assert.NoError(t, err)
|
|
}
|