Skip to content

Commit

Permalink
Optimize memory usage by avoiding intermediate buffer in message seri…
Browse files Browse the repository at this point in the history
…alization (#928)

* Optimize memory usage by avoiding intermediate buffer in message serialization

This commit replaces the use of an intermediate buffer in the message serialization process with a direct write-to-buffer approach. The original implementation used MustMarshalBinary() which involved an extra memory copy to an intermediate buffer before writing to the final writeBuffer, leading to high memory consumption for large messages. The new WriteTo function writes message data directly to the writeBuffer, significantly reducing memory overhead and CPU time spent on garbage collection.

* add benchmark for write

* benchmark for 1M/4M/8M

* Tidy up new benchmarks

* Maintain older payload write implementation

---------

Co-authored-by: luozhengjie.lzj <[email protected]>
Co-authored-by: Matt Joiner <[email protected]>
  • Loading branch information
3 people committed Apr 25, 2024
1 parent 78e00f9 commit 3f5ef0b
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 17 deletions.
13 changes: 12 additions & 1 deletion peer-conn-msg-writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,21 @@ func (cn *peerConnMsgWriter) run(keepAliveTimeout time.Duration) {
}
}

func (cn *peerConnMsgWriter) writeToBuffer(msg pp.Message) (err error) {
originalLen := cn.writeBuffer.Len()
defer func() {
if err != nil {
// Since an error occurred during buffer write, revert buffer to its original state before the write.
cn.writeBuffer.Truncate(originalLen)
}
}()
return msg.WriteTo(cn.writeBuffer)
}

func (cn *peerConnMsgWriter) write(msg pp.Message) bool {
cn.mu.Lock()
defer cn.mu.Unlock()
cn.writeBuffer.Write(msg.MustMarshalBinary())
cn.writeToBuffer(msg)
cn.writeCond.Broadcast()
return !cn.writeBufferFull()
}
Expand Down
68 changes: 68 additions & 0 deletions peer-conn-msg-writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package torrent

import (
"bytes"
"testing"

"github.com/dustin/go-humanize"

pp "github.com/anacrolix/torrent/peer_protocol"
)

func PieceMsg(length int64) pp.Message {
return pp.Message{
Type: pp.Piece,
Index: pp.Integer(0),
Begin: pp.Integer(0),
Piece: make([]byte, length),
}
}

var benchmarkPieceLengths = []int{defaultChunkSize, 1 << 20, 4 << 20, 8 << 20}

func runBenchmarkWriteToBuffer(b *testing.B, length int64) {
writer := &peerConnMsgWriter{
writeBuffer: &bytes.Buffer{},
}
msg := PieceMsg(length)

b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
//b.StopTimer()
writer.writeBuffer.Reset()
//b.StartTimer()
writer.writeToBuffer(msg)
}
}

func BenchmarkWritePieceMsg(b *testing.B) {
for _, length := range benchmarkPieceLengths {
b.Run(humanize.IBytes(uint64(length)), func(b *testing.B) {
b.Run("ToBuffer", func(b *testing.B) {
b.SetBytes(int64(length))
runBenchmarkWriteToBuffer(b, int64(length))
})
b.Run("MarshalBinary", func(b *testing.B) {
b.SetBytes(int64(length))
runBenchmarkMarshalBinaryWrite(b, int64(length))
})
})
}
}

func runBenchmarkMarshalBinaryWrite(b *testing.B, length int64) {
writer := &peerConnMsgWriter{
writeBuffer: &bytes.Buffer{},
}
msg := PieceMsg(length)

b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
//b.StopTimer()
writer.writeBuffer.Reset()
//b.StartTimer()
writer.writeBuffer.Write(msg.MustMarshalBinary())
}
}
69 changes: 55 additions & 14 deletions peer_protocol/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding"
"encoding/binary"
"fmt"
"io"
)

// This is a lazy union representing all the possible fields for messages. Go doesn't have ADTs, and
Expand Down Expand Up @@ -61,13 +62,14 @@ func (msg Message) MustMarshalBinary() []byte {
return b
}

func (msg Message) MarshalBinary() (data []byte, err error) {
// It might look like you could have a pool of buffers and preallocate the message length
// prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You
// will need a benchmark.
var buf bytes.Buffer
type MessageWriter interface {
io.ByteWriter
io.Writer
}

func (msg *Message) writePayloadTo(buf MessageWriter) (err error) {
mustWrite := func(data any) {
err := binary.Write(&buf, binary.BigEndian, data)
err := binary.Write(buf, binary.BigEndian, data)
if err != nil {
panic(err)
}
Expand All @@ -85,10 +87,10 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
switch msg.Type {
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
case Have, AllowedFast, Suggest:
err = binary.Write(&buf, binary.BigEndian, msg.Index)
err = binary.Write(buf, binary.BigEndian, msg.Index)
case Request, Cancel, Reject:
for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
err = binary.Write(&buf, binary.BigEndian, i)
err = binary.Write(buf, binary.BigEndian, i)
if err != nil {
break
}
Expand All @@ -97,7 +99,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
_, err = buf.Write(marshalBitfield(msg.Bitfield))
case Piece:
for _, i := range []Integer{msg.Index, msg.Begin} {
err = binary.Write(&buf, binary.BigEndian, i)
err = binary.Write(buf, binary.BigEndian, i)
if err != nil {
return
}
Expand All @@ -116,19 +118,43 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
}
_, err = buf.Write(msg.ExtendedPayload)
case Port:
err = binary.Write(&buf, binary.BigEndian, msg.Port)
err = binary.Write(buf, binary.BigEndian, msg.Port)
case HashRequest:
buf.Write(msg.PiecesRoot[:])
writeConsecutive(msg.BaseLayer, msg.Index, msg.Length, msg.ProofLayers)
default:
err = fmt.Errorf("unknown message type: %v", msg.Type)
}
}
data = make([]byte, 4+buf.Len())
binary.BigEndian.PutUint32(data, uint32(buf.Len()))
if buf.Len() != copy(data[4:], buf.Bytes()) {
panic("bad copy")
return
}

func (msg *Message) WriteTo(w MessageWriter) (err error) {
length, err := msg.getPayloadLength()
if err != nil {
return
}
err = binary.Write(w, binary.BigEndian, length)
if err != nil {
return
}
return msg.writePayloadTo(w)
}

func (msg *Message) getPayloadLength() (length Integer, err error) {
var lw lengthWriter
err = msg.writePayloadTo(&lw)
length = lw.n
return
}

func (msg Message) MarshalBinary() (data []byte, err error) {
// It might look like you could have a pool of buffers and preallocate the message length
// prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You
// will need a benchmark.
var buf bytes.Buffer
err = msg.WriteTo(&buf)
data = buf.Bytes()
return
}

Expand Down Expand Up @@ -158,3 +184,18 @@ func (me *Message) UnmarshalBinary(b []byte) error {
}
return nil
}

type lengthWriter struct {
n Integer
}

func (l *lengthWriter) WriteByte(c byte) error {
l.n++
return nil
}

func (l *lengthWriter) Write(p []byte) (n int, err error) {
n = len(p)
l.n += Integer(n)
return
}
3 changes: 1 addition & 2 deletions requesting.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ import (
"time"
"unsafe"

g "github.com/anacrolix/generics"

"github.com/RoaringBitmap/roaring"
g "github.com/anacrolix/generics"
"github.com/anacrolix/generics/heap"
"github.com/anacrolix/log"
"github.com/anacrolix/multiless"
Expand Down

0 comments on commit 3f5ef0b

Please sign in to comment.