Skip to content

Commit

Permalink
Protect typingMap with a mutex, sort userIDs before hashing
Browse files Browse the repository at this point in the history
  • Loading branch information
S7evinK committed Jul 20, 2023
1 parent ca263b2 commit e566158
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
17 changes: 15 additions & 2 deletions sync2/handler2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"hash/fnv"
"os"
"sort"
"sync"
"time"

Expand Down Expand Up @@ -43,6 +44,7 @@ type Handler struct {
}
// room_id => fnv_hash([typing user ids])
typingMap map[string]uint64
typingMu *sync.Mutex

deviceDataTicker *sync2.DeviceDataTicker
e2eeWorkerPool *internal.WorkerPool
Expand All @@ -65,6 +67,7 @@ func NewHandler(
Notif int
}),
typingMap: make(map[string]uint64),
typingMu: &sync.Mutex{},
deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded
}
Expand Down Expand Up @@ -337,6 +340,10 @@ func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.Ra

func (h *Handler) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
next := typingHash(ephEvent)
// protect typingMap with a lock, so concurrent calls to SetTyping see the correct map
h.typingMu.Lock()
defer h.typingMu.Unlock()

existing := h.typingMap[roomID]
if existing == next {
return
Expand Down Expand Up @@ -493,8 +500,14 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) {

func typingHash(ephEvent json.RawMessage) uint64 {
h := fnv.New64a()
for _, userID := range gjson.ParseBytes(ephEvent).Get("content.user_ids").Array() {
h.Write([]byte(userID.Str))
parsedUserIDs := gjson.ParseBytes(ephEvent).Get("content.user_ids").Array()
userIDs := make([]string, len(parsedUserIDs))
for i := range parsedUserIDs {
userIDs[i] = parsedUserIDs[i].Str
}
sort.Strings(userIDs)
for _, userID := range userIDs {
_, _ = h.Write([]byte(userID))
}
return h.Sum64()
}
42 changes: 42 additions & 0 deletions sync2/handler2/handler_unexported_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package handler2

import (
"encoding/json"
"testing"
)

func Test_typingHash(t *testing.T) {
tests := []struct {
name string
ephEvent json.RawMessage
want uint64
}{
{
name: "doesn't fall over if list is empty",
ephEvent: json.RawMessage(`{"content":{"user_ids":[]}}`),
want: 14695981039346656037,
},
{
name: "hash alice typing",
ephEvent: json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`),
want: 16709353342265369358,
},
{
name: "hash alice and bob typing",
ephEvent: json.RawMessage(`{"content":{"user_ids":["@alice:localhost","@bob:localhost"]}}`),
want: 11071889279173799154,
},
{
name: "hash bob and alice typing",
ephEvent: json.RawMessage(`{"content":{"user_ids":["@bob:localhost","@alice:localhost"]}}`),
want: 11071889279173799154, // this should be the same as the previous
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := typingHash(tt.ephEvent); got != tt.want {
t.Errorf("typingHash() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit e566158

Please sign in to comment.