// Mirror of https://github.com/ent/ent.git (synced 2026-05-22 09:31:45 +03:00): 315 lines, 7.6 KiB, Go.
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
|
|
|
|
package ws
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"entgo.io/ent/dialect/gremlin"
|
|
"entgo.io/ent/dialect/gremlin/encoding"
|
|
"entgo.io/ent/dialect/gremlin/encoding/graphson"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/pkg/errors"
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
const (
	// writeWait is the deadline applied to each write towards the peer.
	writeWait = 5 * time.Second

	// pongWait is how long the reader waits for the peer's next pong
	// before its read deadline expires.
	pongWait = 10 * time.Second

	// pingPeriod is the interval between keepalive pings. It is derived
	// from pongWait (90% of it) so that it always stays below pongWait.
	pingPeriod = pongWait * 9 / 10
)
|
|
|
|
type (
|
|
// A Dialer contains options for connecting to Gremlin server.
|
|
Dialer struct {
|
|
// Underlying websocket dialer.
|
|
websocket.Dialer
|
|
|
|
// Gremlin server basic auth credentials.
|
|
user, pass string
|
|
}
|
|
|
|
// Conn performs operations on a gremlin server.
|
|
Conn struct {
|
|
// Underlying websocket connection.
|
|
conn *websocket.Conn
|
|
|
|
// Credentials for basic authentication.
|
|
user, pass string
|
|
|
|
// Goroutine tracking.
|
|
ctx context.Context
|
|
grp *errgroup.Group
|
|
|
|
// Channel of outbound requests.
|
|
send chan io.Reader
|
|
|
|
// Map of in flight requests.
|
|
inflight sync.Map
|
|
}
|
|
|
|
// inflight tracks request state.
|
|
inflight struct {
|
|
// partially received data
|
|
frags []graphson.RawMessage
|
|
|
|
// response channel
|
|
result chan<- result
|
|
}
|
|
|
|
// represents an execution result.
|
|
result struct {
|
|
rsp *gremlin.Response
|
|
err error
|
|
}
|
|
)
|
|
|
|
var (
|
|
// DefaultDialer is a dialer with all fields set to the default values.
|
|
DefaultDialer = &Dialer{
|
|
Dialer: websocket.Dialer{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
HandshakeTimeout: 5 * time.Second,
|
|
WriteBufferSize: 8192,
|
|
ReadBufferSize: 8192,
|
|
},
|
|
}
|
|
|
|
// ErrConnClosed is returned by the Conn's Execute method when
|
|
// the underlying gremlin server connection is closed.
|
|
ErrConnClosed = errors.New("gremlin: server connection closed")
|
|
|
|
// ErrDuplicateRequest is returned by the Conns Execute method on
|
|
// request identifier key collision.
|
|
ErrDuplicateRequest = errors.New("gremlin: duplicate request")
|
|
)
|
|
|
|
// Dial creates a new connection by calling DialContext with a background context.
|
|
func (d *Dialer) Dial(uri string) (*Conn, error) {
|
|
return d.DialContext(context.Background(), uri)
|
|
}
|
|
|
|
// DialContext creates a new Gremlin connection.
|
|
func (d *Dialer) DialContext(ctx context.Context, uri string) (*Conn, error) {
|
|
c, rsp, err := d.Dialer.DialContext(ctx, uri, nil)
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, "gremlin: dialing uri %s", uri)
|
|
}
|
|
defer rsp.Body.Close()
|
|
|
|
conn := &Conn{
|
|
conn: c,
|
|
user: d.user,
|
|
pass: d.pass,
|
|
send: make(chan io.Reader),
|
|
}
|
|
conn.grp, conn.ctx = errgroup.WithContext(context.Background())
|
|
|
|
conn.grp.Go(conn.sender)
|
|
conn.grp.Go(conn.receiver)
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
// Execute executes a request against a Gremlin server.
|
|
func (c *Conn) Execute(ctx context.Context, req *gremlin.Request) (*gremlin.Response, error) {
|
|
// buffered result channel prevents receiver block on context cancellation
|
|
result := make(chan result, 1)
|
|
|
|
// request id must be unique across inflight request
|
|
if _, loaded := c.inflight.LoadOrStore(req.RequestID, &inflight{result: result}); loaded {
|
|
return nil, ErrDuplicateRequest
|
|
}
|
|
|
|
pr, pw := io.Pipe()
|
|
defer pr.Close()
|
|
|
|
// stream graphson encoding into request
|
|
c.grp.Go(func() error {
|
|
err := graphson.NewEncoder(pw).Encode(req)
|
|
if err != nil {
|
|
err = errors.Wrap(err, "encoding request")
|
|
}
|
|
pw.CloseWithError(err)
|
|
return err
|
|
})
|
|
|
|
// local copy for single write
|
|
send := c.send
|
|
|
|
for {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
c.inflight.Delete(req.RequestID)
|
|
return nil, ErrConnClosed
|
|
case <-ctx.Done():
|
|
c.inflight.Delete(req.RequestID)
|
|
return nil, ctx.Err()
|
|
case send <- pr:
|
|
send = nil
|
|
case result := <-result:
|
|
return result.rsp, result.err
|
|
}
|
|
}
|
|
}
|
|
|
|
// Close connection with a Gremlin server.
|
|
func (c *Conn) Close() error {
|
|
c.grp.Go(func() error { return ErrConnClosed })
|
|
_ = c.grp.Wait()
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) sender() error {
|
|
pinger := time.NewTicker(pingPeriod)
|
|
defer pinger.Stop()
|
|
|
|
// closing connection terminates receiver
|
|
defer c.conn.Close()
|
|
|
|
for {
|
|
select {
|
|
case r := <-c.send:
|
|
// ensure write completes within a window
|
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
|
|
// fetch next message writer
|
|
w, err := c.conn.NextWriter(websocket.BinaryMessage)
|
|
if err != nil {
|
|
return errors.Wrap(err, "getting message writer")
|
|
}
|
|
|
|
// write mime header
|
|
if _, err := w.Write(encoding.GraphSON3Mime); err != nil {
|
|
return errors.Wrap(err, "writing mime header")
|
|
}
|
|
|
|
// write request body
|
|
if _, err := io.Copy(w, r); err != nil {
|
|
return errors.Wrap(err, "writing request")
|
|
}
|
|
|
|
// finish message write
|
|
if err := w.Close(); err != nil {
|
|
return errors.Wrap(err, "closing message writer")
|
|
}
|
|
case <-c.ctx.Done():
|
|
// connection closing
|
|
return c.conn.WriteControl(
|
|
websocket.CloseMessage,
|
|
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
|
|
time.Time{},
|
|
)
|
|
case <-pinger.C:
|
|
// periodic connection keepalive
|
|
if err := c.conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWait)); err != nil {
|
|
return errors.Wrap(err, "writing ping message")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Conn) receiver() error {
|
|
// handle keepalive responses
|
|
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
c.conn.SetPongHandler(func(string) error {
|
|
return c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
})
|
|
|
|
// complete all in flight requests on termination
|
|
defer c.inflight.Range(func(id, ifr interface{}) bool {
|
|
ifr.(*inflight).result <- result{err: ErrConnClosed}
|
|
c.inflight.Delete(id)
|
|
return true
|
|
})
|
|
|
|
for {
|
|
// rely on sender connection close during termination
|
|
_, r, err := c.conn.NextReader()
|
|
if err != nil {
|
|
return errors.Wrap(err, "getting next reader")
|
|
}
|
|
|
|
// decode received response
|
|
var rsp gremlin.Response
|
|
if err := graphson.NewDecoder(r).Decode(&rsp); err != nil {
|
|
return errors.Wrap(err, "reading response")
|
|
}
|
|
|
|
ifr, ok := c.inflight.Load(rsp.RequestID)
|
|
if !ok {
|
|
// context cancellation aborts inflight requests
|
|
continue
|
|
}
|
|
|
|
// handle incoming response
|
|
if done := c.receive(ifr.(*inflight), &rsp); done {
|
|
// stop tracking finished requests
|
|
c.inflight.Delete(rsp.RequestID)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Conn) receive(ifr *inflight, rsp *gremlin.Response) bool {
|
|
result := result{rsp: rsp}
|
|
switch rsp.Status.Code {
|
|
case gremlin.StatusSuccess:
|
|
// quickly handle non fragmented responses
|
|
if ifr.frags == nil {
|
|
break
|
|
}
|
|
// handle fragment
|
|
fallthrough
|
|
case gremlin.StatusPartialContent:
|
|
// append received fragment
|
|
var frag []graphson.RawMessage
|
|
if err := graphson.Unmarshal(rsp.Result.Data, &frag); err != nil {
|
|
result.err = errors.Wrap(err, "decoding response fragment")
|
|
break
|
|
}
|
|
ifr.frags = append(ifr.frags, frag...)
|
|
|
|
// partial response requires additional fragments
|
|
if rsp.Status.Code == gremlin.StatusPartialContent {
|
|
return false
|
|
}
|
|
|
|
// reassemble fragmented response
|
|
if rsp.Result.Data, result.err = graphson.Marshal(ifr.frags); result.err != nil {
|
|
result.err = errors.Wrap(result.err, "assembling fragmented response")
|
|
}
|
|
case gremlin.StatusAuthenticate:
|
|
// receiver should never block
|
|
c.grp.Go(func() error {
|
|
var buf bytes.Buffer
|
|
if err := graphson.NewEncoder(&buf).Encode(
|
|
gremlin.NewAuthRequest(rsp.RequestID, c.user, c.pass),
|
|
); err != nil {
|
|
return errors.Wrap(err, "encoding auth request")
|
|
}
|
|
select {
|
|
case c.send <- &buf:
|
|
case <-c.ctx.Done():
|
|
}
|
|
return c.ctx.Err()
|
|
})
|
|
return false
|
|
}
|
|
|
|
ifr.result <- result
|
|
return true
|
|
}
|