Skip to content

Commit

Permalink
Merge pull request #463 from cloudflare/mitali/add-request-context
Browse files Browse the repository at this point in the history
add a new field to operation struct that can carry request context
  • Loading branch information
mitalirawat committed May 24, 2024
2 parents 482aa4e + ef6c915 commit 8b7f832
Show file tree
Hide file tree
Showing 29 changed files with 1,565 additions and 21 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ require (
github.com/google/certificate-transparency-go v1.1.4 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-retryablehttp v0.7.5 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/cLCKqA=
Expand Down
11 changes: 11 additions & 0 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ const (
TagExtra Tag = 0x14
// TagJaegerSpan contains a binary encoded jaeger span context. See https://www.jaegertracing.io/docs/1.19/client-libraries/#value
TagJaegerSpan Tag = 0x15
// TagReqContext contains request metadata
TagReqContext Tag = 0x16
// TagPadding implies an item with a meaningless payload added for padding.
TagPadding Tag = 0x20
)
Expand Down Expand Up @@ -414,6 +416,7 @@ type Operation struct {
CertID string
CustomFuncName string
JaegerSpan []byte
ReqContext []byte
}

func (o *Operation) String() string {
Expand Down Expand Up @@ -504,6 +507,9 @@ func (o *Operation) Bytes() uint16 {
if o.JaegerSpan != nil {
add(tlvLen(len(o.JaegerSpan)))
}
if o.ReqContext != nil {
add(tlvLen(len(o.ReqContext)))
}
if int(length)+headerSize < paddedLength {
// TODO: Are we sure that's the right behavior?

Expand Down Expand Up @@ -574,6 +580,9 @@ func (o *Operation) MarshalBinary() ([]byte, error) {
if o.JaegerSpan != nil {
b = append(b, tlvBytes(TagJaegerSpan, o.JaegerSpan)...)
}
if o.ReqContext != nil {
b = append(b, tlvBytes(TagReqContext, o.ReqContext)...)
}

if len(b)+headerSize < paddedLength {
// TODO: Are we sure that's the right behavior?
Expand Down Expand Up @@ -660,6 +669,8 @@ func (o *Operation) UnmarshalBinary(body []byte) error {
o.CustomFuncName = string(data)
case TagJaegerSpan:
o.JaegerSpan = data
case TagReqContext:
o.ReqContext = data
default:
// Silently ignore any unknown tags (to allow for new tags to be gradually added to the protocol).
continue
Expand Down
7 changes: 4 additions & 3 deletions protocol/protocol_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions protocol/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ func TestMarshalBinary(t *testing.T) {
// to ensure that the size is calculated correctly.
extra := make([]byte, 100)
payload := make([]byte, 1000)
reqCtx := make([]byte, 100)
rand.Read(extra)
rand.Read(payload)
rand.Read(reqCtx)
op := Operation{
Opcode: OpECDSASignSHA256,
Payload: payload,
Expand All @@ -32,6 +34,7 @@ func TestMarshalBinary(t *testing.T) {
CertID: "SNI",
CustomFuncName: "CustomFuncName",
JaegerSpan: []byte("615f730ad5fe896f:615f730ad5fe896f:1"),
ReqContext: reqCtx,
}
pkt := NewPacket(42, op)
b, err := pkt.MarshalBinary()
Expand Down
72 changes: 54 additions & 18 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"io"
Expand All @@ -30,6 +31,7 @@ import (
textbook_rsa "github.com/cloudflare/gokeyless/server/internal/rsa"

"github.com/cloudflare/cfssl/log"
"github.com/google/uuid"
)

// Server is a Keyless Server capable of performing opaque key operations.
Expand Down Expand Up @@ -284,6 +286,37 @@ func makeErrResponse(pkt *protocol.Packet, err protocol.Error) response {
return response{id: pkt.ID, op: protocol.MakeErrorOp(err)}
}

func addOperationRequestID(op *protocol.Operation) string {
reqContext := make(map[string]interface{})
var reqID string
var gen bool

if len(op.ReqContext) > 0 {
if err := json.Unmarshal(op.ReqContext, &reqContext); err == nil {
if v, ok := reqContext["request_id"]; ok {
return v.(string)
} else {
gen = true
}
} else {
log.Errorf("malformed operation.ReqContext %v, ignoring error", op.ReqContext)
}
}

if len(op.ReqContext) == 0 || gen {
reqID = uuid.New().String()
reqContext["request_id"] = reqID
b, err := json.Marshal(reqContext)
if err == nil {
op.ReqContext = b
} else {
log.Errorf("error marshaling operation.ReqContext %v, ignoring error", reqContext)
reqID = ""
}
}
return reqID
}

func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
spanCtx, err := tracing.SpanContextFromBinary(pkt.Operation.JaegerSpan)
if err != nil {
Expand All @@ -292,14 +325,17 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
span, ctx := opentracing.StartSpanFromContext(context.Background(), "operation execution", ext.RPCServerOption(spanCtx))
defer span.Finish()
tracing.SetOperationSpanTags(span, &pkt.Operation)
reqID := addOperationRequestID(&pkt.Operation)
span.SetTag("request_id", reqID)

log.Debugf("connection %s: limited=false opcode=%s id=%d sni=%s ip=%s ski=%v",
log.Debugf("connection %s: limited=false opcode=%s id=%d sni=%s ip=%s ski=%v request-id=%s",
connName,
pkt.Operation.Opcode,
pkt.Header.ID,
pkt.Operation.SNI,
pkt.Operation.ServerIP,
pkt.Operation.SKI)
pkt.Operation.SKI,
reqID)

var opts crypto.SignerOpts
switch pkt.Operation.Opcode {
Expand Down Expand Up @@ -362,10 +398,10 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
key, err := s.keys.Get(ctx, &pkt.Operation)
logKeyLoadDuration(loadStart)
if err != nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, err)
log.Errorf("failed to load key with sni=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.SKI, reqID, err)
return makeErrResponse(pkt, protocol.ErrInternal)
} else if key == nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, protocol.ErrKeyNotFound)
log.Errorf("failed to load key with sni=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound)
return makeErrResponse(pkt, protocol.ErrKeyNotFound)
}

Expand All @@ -376,14 +412,14 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {

sig, err := key.Sign(rand.Reader, pkt.Operation.Payload, crypto.Hash(0))
if err != nil {
log.Errorf("Connection: %s: Signing error: %v", connName, protocol.ErrCrypto, err)
log.Errorf("Connection: %s: sni=%s ski=%v request-id=%s: Signing error: %v: request-id:%s:", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err, reqID)
// This indicates that a remote keyserver is being used
var remoteConfigurationErr RemoteConfigurationErr
if errors.As(err, &remoteConfigurationErr) {
log.Errorf("Connection %v: %s: Signing error: %v\n", connName, protocol.ErrRemoteConfiguration, err)
log.Errorf("Connection %v: sni=%s ski=%v request-id=%s: %s: Signing error: %v request-id:%s\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err, reqID)
return makeErrResponse(pkt, protocol.ErrRemoteConfiguration)
} else {
log.Errorf("Connection %v: %s: Signing error: %v\n", connName, protocol.ErrCrypto, err)
log.Errorf("Connection %v: sni=%s ski=%v request-id=%s: %s: Signing error: %v request-id:%s\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err, reqID)
return makeErrResponse(pkt, protocol.ErrCrypto)
}
}
Expand All @@ -394,37 +430,37 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
key, err := s.keys.Get(ctx, &pkt.Operation)
logKeyLoadDuration(loadStart)
if err != nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, err)
log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
return makeErrResponse(pkt, protocol.ErrInternal)
} else if key == nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, protocol.ErrKeyNotFound)
log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound)
return makeErrResponse(pkt, protocol.ErrKeyNotFound)
}

if _, ok := key.Public().(*rsa.PublicKey); !ok {
log.Errorf("Connection %v: %s: Key is not RSA", connName, protocol.ErrCrypto)
log.Errorf("Connection %v: sni=%s request-id=%s: %s: Key is not RSA", connName, pkt.Operation.SNI, reqID, protocol.ErrCrypto)
return makeErrResponse(pkt, protocol.ErrCrypto)
}

if rsaKey, ok := key.(*rsa.PrivateKey); ok {
// Decrypt without removing padding; that's the client's responsibility.
ptxt, err := textbook_rsa.Decrypt(rsaKey, pkt.Operation.Payload)
if err != nil {
log.Errorf("connection %v: %v", connName, err)
log.Errorf("connection %v: sni=%s ip=%s ski=%v request-id=%s: %v", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
return makeErrResponse(pkt, protocol.ErrCrypto)
}
return makeRespondResponse(pkt, ptxt)
}

rsaKey, ok := key.(crypto.Decrypter)
if !ok {
log.Errorf("Connection %v: %s: Key is not Decrypter", connName, protocol.ErrCrypto)
log.Errorf("Connection %v: sni=%s request-id=%s: %s: Key is not Decrypter", connName, pkt.Operation.SNI, reqID, protocol.ErrCrypto)
return makeErrResponse(pkt, protocol.ErrCrypto)
}

ptxt, err := rsaKey.Decrypt(nil, pkt.Operation.Payload, nil)
if err != nil {
log.Errorf("Connection %v: %s: Decryption error: %v", connName, protocol.ErrCrypto, err)
log.Errorf("Connection %v: sni=%s ip=%s ski=%v request-id=%s: %s: Decryption error: %v", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err)
return makeErrResponse(pkt, protocol.ErrCrypto)
}

Expand Down Expand Up @@ -457,10 +493,10 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
key, err := s.keys.Get(ctx, &pkt.Operation)
logKeyLoadDuration(loadStart)
if err != nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, err)
log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
return makeErrResponse(pkt, protocol.ErrInternal)
} else if key == nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, protocol.ErrKeyNotFound)
log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound)
return makeErrResponse(pkt, protocol.ErrKeyNotFound)
}

Expand Down Expand Up @@ -490,17 +526,17 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
}
if err != nil {
if attempts > 1 {
log.Debugf("Connection %v: failed sign attempt: %s, %d attempt(s) left", connName, err, attempts-1)
log.Debugf("Connection %v sni=%s ip=%s ski=%v request-id=%s : failed sign attempt: %s, %d attempt(s) left", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err, attempts-1)
continue
} else {
tracing.LogError(span, err)
// This indicates that a remote keyserver is being used
var remoteConfigurationErr RemoteConfigurationErr
if errors.As(err, &remoteConfigurationErr) {
log.Errorf("Connection %v: %s: Signing error: %v\n", connName, protocol.ErrRemoteConfiguration, err)
log.Errorf("Connection %v sni=%s ip=%s ski=%v request-id=%s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err)
return makeErrResponse(pkt, protocol.ErrRemoteConfiguration)
} else {
log.Errorf("Connection %v: %s: Signing error: %v\n", connName, protocol.ErrCrypto, err)
log.Errorf("Connection %v sni=%s ip=%s ski=%v request-id=%s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err)
return makeErrResponse(pkt, protocol.ErrCrypto)
}
}
Expand Down
40 changes: 40 additions & 0 deletions server/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package server

import (
"crypto/rand"
"encoding/json"
"testing"

"github.com/cloudflare/gokeyless/protocol"
"github.com/stretchr/testify/require"
)

func TestRequestID(t *testing.T) {
require := require.New(t)

r := make([]byte, 20)
_, err := rand.Read(r)
require.NoError(err)

// empty byte array in ReqContext
op := protocol.Operation{
Opcode: protocol.OpECDSASignSHA224,
Payload: r,
ReqContext: []byte{},
}
reqID := addOperationRequestID(&op)
require.NotEqual(reqID, "")

// nil byte array
op.ReqContext = nil
reqID = addOperationRequestID(&op)
require.NotEqual(reqID, "")

// Operation.ReqContext contains a map and request id
rc := map[string]interface{}{"request_id": "b76dfaf1-a852-4dc2-98ff-0ba1947a82b6"}
b, err := json.Marshal(rc)
require.NoError(err)
op.ReqContext = b
reqID = addOperationRequestID(&op)
require.Equal(reqID, "b76dfaf1-a852-4dc2-98ff-0ba1947a82b6")
}
41 changes: 41 additions & 0 deletions vendor/github.com/google/uuid/CHANGELOG.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 8b7f832

Please sign in to comment.