Files
ent/dialect/gremlin/ocgremlin/trace_test.go
Ariel Mashraki fd0a7f9f02 all: facebookincubator/ent => facebook/ent (#660)
ent repository is going to be migrated to facebook organization
2020-08-18 11:05:08 +03:00

255 lines
7.4 KiB
Go

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package ocgremlin
import (
"bytes"
"context"
"errors"
"fmt"
"testing"
"github.com/facebook/ent/dialect/gremlin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"go.opencensus.io/trace"
)
// mockTransport is a testify mock implementing RoundTrip. Tests use it as the
// Base of the Transport under test, programming it with On("RoundTrip", ...)
// and verifying calls with AssertExpectations.
type mockTransport struct {
	mock.Mock
}

// RoundTrip records the call and returns the programmed (response, error)
// pair. A first return value that is not a *gremlin.Response (for example an
// untyped nil) yields a nil response.
func (m *mockTransport) RoundTrip(ctx context.Context, req *gremlin.Request) (*gremlin.Response, error) {
	ret := m.Called(ctx, req)
	var rsp *gremlin.Response
	if r, ok := ret.Get(0).(*gremlin.Response); ok {
		rsp = r
	}
	return rsp, ret.Error(1)
}
// TestTraceTransportRoundTrip verifies that Transport.RoundTrip passes the
// base transport a context carrying a span. When the caller's context already
// holds a span, the new span must belong to the same trace but be a distinct
// span (i.e. a child, not the parent forwarded unchanged).
func TestTraceTransportRoundTrip(t *testing.T) {
	_, parent := trace.StartSpan(context.Background(), "parent")
	// End the parent so it is not left open after the test finishes.
	defer parent.End()
	tests := []struct {
		name   string
		parent *trace.Span
	}{
		{
			name:   "no parent",
			parent: nil,
		},
		{
			name:   "parent",
			parent: parent,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			transport := &mockTransport{}
			transport.On("RoundTrip", mock.Anything, mock.Anything).
				Run(func(args mock.Arguments) {
					span := trace.FromContext(args.Get(0).(context.Context))
					require.NotNil(t, span)
					if tt.parent != nil {
						// Same trace as the caller's span...
						assert.Equal(t, tt.parent.SpanContext().TraceID, span.SpanContext().TraceID)
						// ...but a new span. Without this check the case would also pass
						// if RoundTrip simply forwarded the parent span unchanged.
						assert.NotEqual(t, tt.parent.SpanContext().SpanID, span.SpanContext().SpanID)
					}
				}).
				Return(nil, errors.New("noop")).
				Once()
			defer transport.AssertExpectations(t)
			ctx, req := context.Background(), gremlin.NewEvalRequest("g.V()")
			if tt.parent != nil {
				ctx = trace.NewContext(ctx, tt.parent)
			}
			rt := &Transport{Base: transport}
			// The base transport always fails; only the context it received is
			// under test, so the result is intentionally discarded.
			_, _ = rt.RoundTrip(ctx, req)
		})
	}
}
// testExporter is a trace exporter that collects every exported span in
// memory, in export order, for later inspection by tests. The slice is not
// guarded by a lock, so it is only suitable when spans are exported from a
// single goroutine.
type testExporter struct {
	spans []*trace.SpanData
}

// ExportSpan appends sd to the collected spans.
func (e *testExporter) ExportSpan(sd *trace.SpanData) {
	e.spans = append(e.spans, sd)
}
// TestEndToEnd sends a successful eval request through a Transport with query
// recording enabled, then checks that exactly one span was exported and that
// it carries exactly the five expected gremlin attributes.
//
// NOTE(review): ApplyConfig switches the global trace configuration to
// AlwaysSample and nothing here reverts it, so it stays in effect for tests
// that run afterwards.
func TestEndToEnd(t *testing.T) {
	trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()})
	exporter := &testExporter{}
	trace.RegisterExporter(exporter)
	defer trace.UnregisterExporter(exporter)

	req := gremlin.NewEvalRequest("g.V()")
	rsp := new(gremlin.Response)
	rsp.RequestID = req.RequestID
	rsp.Status.Code = 200
	rsp.Status.Message = "OK"

	base := new(mockTransport)
	base.On("RoundTrip", mock.Anything, mock.Anything).Return(rsp, nil).Once()
	defer base.AssertExpectations(t)

	rt := &Transport{Base: base, WithQuery: true}
	_, err := rt.RoundTrip(context.Background(), req)
	require.NoError(t, err)

	require.Len(t, exporter.spans, 1)
	attrs := exporter.spans[0].Attributes
	expected := []struct {
		key   string
		value interface{}
	}{
		{"gremlin.request_id", req.RequestID},
		{"gremlin.operation", req.Operation},
		{"gremlin.query", req.Arguments[gremlin.ArgsGremlin]},
		{"gremlin.code", int64(200)},
		{"gremlin.message", "OK"},
	}
	assert.Len(t, attrs, len(expected))
	for _, e := range expected {
		assert.Equal(t, e.value, attrs[e.key])
	}
}
// TestRequestAttributes checks the span attributes produced by requestAttrs
// with query recording enabled. It covers an eval request without bindings,
// an eval request with bindings of several types, and an authentication
// request.
func TestRequestAttributes(t *testing.T) {
	tests := []struct {
		name      string
		makeReq   func() *gremlin.Request
		wantAttrs []trace.Attribute
	}{
		{
			name: "Query without bindings",
			makeReq: func() *gremlin.Request {
				req := gremlin.NewEvalRequest("g.E().count()")
				req.RequestID = "a8b5c664-03ca-4175-a9e7-569b46f3551c"
				return req
			},
			wantAttrs: []trace.Attribute{
				trace.StringAttribute("gremlin.request_id", "a8b5c664-03ca-4175-a9e7-569b46f3551c"),
				trace.StringAttribute("gremlin.operation", "eval"),
				trace.StringAttribute("gremlin.query", "g.E().count()"),
			},
		},
		{
			name: "Query with bindings",
			makeReq: func() *gremlin.Request {
				bindings := map[string]interface{}{
					"$1": "user", "$2": int64(42),
					"$3": 3.14, "$4": bytes.Repeat([]byte{0xff}, 257),
					"$5": true, "$6": nil,
				}
				req := gremlin.NewEvalRequest(
					`g.V().hasLabel($1).has("age",$2).has("v",$3).limit($4).valueMap($5)`,
					gremlin.WithBindings(bindings),
				)
				req.RequestID = "d3d986fa-bd22-41bd-b2f7-ef2f1f639260"
				return req
			},
			wantAttrs: []trace.Attribute{
				trace.StringAttribute("gremlin.request_id", "d3d986fa-bd22-41bd-b2f7-ef2f1f639260"),
				trace.StringAttribute("gremlin.operation", "eval"),
				trace.StringAttribute("gremlin.query", `g.V().hasLabel($1).has("age",$2).has("v",$3).limit($4).valueMap($5)`),
				trace.StringAttribute("gremlin.binding.$1", "user"),
				trace.Int64Attribute("gremlin.binding.$2", 42),
				trace.Float64Attribute("gremlin.binding.$3", 3.14),
				// Non-primitive binding values are expected as their %v form cut to
				// the first 256 characters.
				trace.StringAttribute("gremlin.binding.$4", func() string {
					str := fmt.Sprintf("%v", bytes.Repeat([]byte{0xff}, 256))
					return str[:256]
				}()),
				trace.BoolAttribute("gremlin.binding.$5", true),
				// A nil binding is expected as an empty string.
				trace.StringAttribute("gremlin.binding.$6", ""),
			},
		},
		{
			name: "Authentication",
			makeReq: func() *gremlin.Request {
				return gremlin.NewAuthRequest(
					"d239d950-59a1-41a7-a103-908f976ebd89",
					"user", "pass",
				)
			},
			wantAttrs: []trace.Attribute{
				trace.StringAttribute("gremlin.request_id", "d239d950-59a1-41a7-a103-908f976ebd89"),
				trace.StringAttribute("gremlin.operation", "authentication"),
				trace.StringAttribute("gremlin.query", ""),
			},
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			req := tt.makeReq()
			attrs := requestAttrs(req, true)
			// Compare as multisets: order is irrelevant (bindings come from a map),
			// but every expected attribute must appear exactly once. A per-element
			// Contains plus a Len check would let a duplicated attribute mask a
			// missing one.
			assert.ElementsMatch(t, tt.wantAttrs, attrs)
		})
	}
}
// TestResponseAttributes verifies the exact attribute list responseAttrs
// builds from a response status: the status code alone when no message is
// set, and the code followed by the message when one is present.
func TestResponseAttributes(t *testing.T) {
	cases := []struct {
		name  string
		setup func(rsp *gremlin.Response)
		want  []trace.Attribute
	}{
		{
			name:  "Success no message",
			setup: func(rsp *gremlin.Response) { rsp.Status.Code = 204 },
			want:  []trace.Attribute{trace.Int64Attribute("gremlin.code", 204)},
		},
		{
			name: "Authenticate with message",
			setup: func(rsp *gremlin.Response) {
				rsp.Status.Code = 407
				rsp.Status.Message = "login required"
			},
			want: []trace.Attribute{
				trace.Int64Attribute("gremlin.code", 407),
				trace.StringAttribute("gremlin.message", "login required"),
			},
		},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			// Start from a zero-value response and apply only the fields under test.
			rsp := &gremlin.Response{}
			tc.setup(rsp)
			assert.Equal(t, tc.want, responseAttrs(rsp))
		})
	}
}
// TestTraceStatus checks how TraceStatus maps gremlin response status codes to
// OpenCensus trace statuses, including an unrecognized code (600) that should
// map to StatusCodeUnknown with an empty message.
func TestTraceStatus(t *testing.T) {
	tests := []struct {
		in   int
		want trace.Status
	}{
		{200, trace.Status{Code: trace.StatusCodeOK, Message: "Success"}},
		{204, trace.Status{Code: trace.StatusCodeOK, Message: "No Content"}},
		{206, trace.Status{Code: trace.StatusCodeOK, Message: "Partial Content"}},
		{401, trace.Status{Code: trace.StatusCodePermissionDenied, Message: "Unauthorized"}},
		{407, trace.Status{Code: trace.StatusCodeUnauthenticated, Message: "Authenticate"}},
		{498, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: "Malformed Request"}},
		{499, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: "Invalid Request Arguments"}},
		{500, trace.Status{Code: trace.StatusCodeInternal, Message: "Server Error"}},
		{597, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: "Script Evaluation Error"}},
		{598, trace.Status{Code: trace.StatusCodeDeadlineExceeded, Message: "Server Timeout"}},
		{599, trace.Status{Code: trace.StatusCodeInternal, Message: "Server Serialization Error"}},
		{600, trace.Status{Code: trace.StatusCodeUnknown, Message: ""}},
	}
	for _, tt := range tests {
		// Include the input code so a failure identifies which row is wrong.
		assert.Equal(t, tt.want, TraceStatus(tt.in), "gremlin status code %d", tt.in)
	}
}