Skip to content

Commit

Permalink
Add contexted ReadWriter to handshakes
Browse files Browse the repository at this point in the history
  • Loading branch information
anacrolix committed Aug 10, 2024
1 parent b7b97a6 commit 2502dd2
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 11 deletions.
8 changes: 4 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@ func (cl *Client) initiateHandshakes(ctx context.Context, c *PeerConn, t *Torren
// If we're sending the v1 infohash, and we know the v2 infohash, set the v2 upgrade bit. This
// means the peer can send the v2 infohash in the handshake to upgrade the connection.
localReservedBits.SetBit(pp.ExtensionBitV2Upgrade, g.Some(handshakeIh) == t.infoHash && t.infoHashV2.Ok)
ih, err := cl.connBtHandshake(c, &handshakeIh, localReservedBits)
ih, err := cl.connBtHandshake(context.TODO(), c, &handshakeIh, localReservedBits)
if err != nil {
return fmt.Errorf("bittorrent protocol handshake: %w", err)
}
Expand Down Expand Up @@ -1015,7 +1015,7 @@ func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) {
err = errors.New("connection does not have required header obfuscation")
return
}
ih, err := cl.connBtHandshake(c, nil, cl.config.Extensions)
ih, err := cl.connBtHandshake(context.TODO(), c, nil, cl.config.Extensions)
if err != nil {
return nil, fmt.Errorf("during bt handshake: %w", err)
}
Expand All @@ -1039,8 +1039,8 @@ func init() {
&successfulPeerWireProtocolHandshakePeerReservedBytes)
}

func (cl *Client) connBtHandshake(c *PeerConn, ih *metainfo.Hash, reservedBits PeerExtensionBits) (ret metainfo.Hash, err error) {
res, err := pp.Handshake(c.rw(), ih, cl.peerID, reservedBits)
func (cl *Client) connBtHandshake(ctx context.Context, c *PeerConn, ih *metainfo.Hash, reservedBits PeerExtensionBits) (ret metainfo.Hash, err error) {
res, err := pp.Handshake(ctx, c.rw(), ih, cl.peerID, reservedBits)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/RoaringBitmap/roaring v1.2.3
github.com/ajwerner/btree v0.0.0-20211221152037-f427b3e689c0
github.com/alexflint/go-arg v1.4.3
github.com/anacrolix/bargle v0.0.0-20220630015206-d7a4d433886a
github.com/anacrolix/bargle v0.0.0-20221014000746-4f2739072e9d
github.com/anacrolix/chansync v0.4.1-0.20240627045151-1aa1ac392fe8
github.com/anacrolix/dht/v2 v2.19.2-0.20221121215055-066ad8494444
github.com/anacrolix/envpprof v1.3.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ github.com/anacrolix/backtrace v0.0.0-20221205112523-22a61db8f82e h1:A0Ty9UeyBDI
github.com/anacrolix/backtrace v0.0.0-20221205112523-22a61db8f82e/go.mod h1:4YFqy+788tLJWtin2jNliYVJi+8aDejG9zcu/2/pONw=
github.com/anacrolix/bargle v0.0.0-20220630015206-d7a4d433886a h1:KCP9QvHlLoUQBOaTf/YCuOzG91Ym1cPB6S68O4Q3puo=
github.com/anacrolix/bargle v0.0.0-20220630015206-d7a4d433886a/go.mod h1:9xUiZbkh+94FbiIAL1HXpAIBa832f3Mp07rRPl5c5RQ=
github.com/anacrolix/bargle v0.0.0-20221014000746-4f2739072e9d h1:ypNOsIwvdumNRlqWj/hsnLs5TyQWQOylwi+T9Qs454A=
github.com/anacrolix/bargle v0.0.0-20221014000746-4f2739072e9d/go.mod h1:9xUiZbkh+94FbiIAL1HXpAIBa832f3Mp07rRPl5c5RQ=
github.com/anacrolix/chansync v0.4.1-0.20240627045151-1aa1ac392fe8 h1:eyb0bBaQKMOh5Se/Qg54shijc8K4zpQiOjEhKFADkQM=
github.com/anacrolix/chansync v0.4.1-0.20240627045151-1aa1ac392fe8/go.mod h1:DZsatdsdXxD0WiwcGl0nJVwyjCKMDv+knl1q2iBjA2k=
github.com/anacrolix/dht/v2 v2.19.2-0.20221121215055-066ad8494444 h1:8V0K09lrGoeT2KRJNOtspA7q+OMxGwQqK/Ug0IiaaRE=
Expand Down
4 changes: 2 additions & 2 deletions mse/ctxrw.go → internal/ctxrw/ctxrw.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package mse
package ctxrw

import (
"context"
Expand Down Expand Up @@ -41,7 +41,7 @@ func (me contextedWriter) Write(p []byte) (n int, err error) {
return contextedReadOrWrite(me.ctx, me.w.Write, p)
}

func contextedReadWriter(ctx context.Context, rw io.ReadWriter) io.ReadWriter {
func WrapReadWriter(ctx context.Context, rw io.ReadWriter) io.ReadWriter {
return struct {
io.Reader
io.Writer
Expand Down
5 changes: 3 additions & 2 deletions mse/mse.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"errors"
"expvar"
"fmt"
"github.com/anacrolix/torrent/internal/ctxrw"
"io"
"math"
"math/big"
Expand Down Expand Up @@ -554,7 +555,7 @@ func InitiateHandshakeContext(
) {
h := handshake{
conn: rw,
ctxConn: contextedReadWriter(ctx, rw),
ctxConn: ctxrw.WrapReadWriter(ctx, rw),
initer: true,
skey: skey,
ia: initialPayload,
Expand Down Expand Up @@ -589,7 +590,7 @@ func ReceiveHandshakeEx(
) (ret HandshakeResult) {
h := handshake{
conn: rw,
ctxConn: contextedReadWriter(ctx, rw),
ctxConn: ctxrw.WrapReadWriter(ctx, rw),
initer: false,
skeys: skeys,
chooseMethod: selectCrypto,
Expand Down
9 changes: 8 additions & 1 deletion peer_protocol/handshake.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package peer_protocol

import (
"context"
"encoding/hex"
"errors"
"fmt"
"github.com/anacrolix/torrent/internal/ctxrw"
"io"
"math/bits"
"strconv"
Expand Down Expand Up @@ -122,10 +124,15 @@ type HandshakeResult struct {
// connection. Returns ok if the Handshake was successful, and err if there was an unexpected
// condition other than the peer simply abandoning the Handshake.
func Handshake(
sock io.ReadWriter, ih *metainfo.Hash, peerID [20]byte, extensions PeerExtensionBits,
ctx context.Context,
sock io.ReadWriter,
ih *metainfo.Hash,
peerID [20]byte,
extensions PeerExtensionBits,
) (
res HandshakeResult, err error,
) {
sock = ctxrw.WrapReadWriter(ctx, sock)
// Bytes to be sent to the peer. Should never block the sender.
postCh := make(chan []byte, 4)
// A single error value sent when the writer completes.
Expand Down
3 changes: 2 additions & 1 deletion torrent_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package torrent

import (
"context"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -187,7 +188,7 @@ func TestTorrentMetainfoIncompleteMetadata(t *testing.T) {

var pex PeerExtensionBits
pex.SetBit(pp.ExtensionBitLtep, true)
hr, err := pp.Handshake(nc, &ih, [20]byte{}, pex)
hr, err := pp.Handshake(context.Background(), nc, &ih, [20]byte{}, pex)
require.NoError(t, err)
assert.True(t, hr.PeerExtensionBits.GetBit(pp.ExtensionBitLtep))
assert.EqualValues(t, cl.PeerID(), hr.PeerID)
Expand Down

0 comments on commit 2502dd2

Please sign in to comment.