Skip to content

Commit

Permalink
Merge pull request #214 from matrix-org/s7evink/typing
Browse files Browse the repository at this point in the history
Fix V2Typing send multiple times
  • Loading branch information
S7evinK committed Aug 2, 2023
2 parents a61a3fd + 2ced198 commit 9862fad
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 31 deletions.
50 changes: 32 additions & 18 deletions sync2/handler2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package handler2
import (
"context"
"encoding/json"
"hash/fnv"
"os"
"sync"
"time"
Expand Down Expand Up @@ -40,8 +39,9 @@ type Handler struct {
Highlight int
Notif int
}
// room_id => fnv_hash([typing user ids])
typingMap map[string]uint64
// room_id -> PollerID, stores which Poller is allowed to update typing notifications
typingHandler map[string]sync2.PollerID
typingMu *sync.Mutex
PendingTxnIDs *sync2.PendingTransactionIDs

deviceDataTicker *sync2.DeviceDataTicker
Expand All @@ -64,7 +64,8 @@ func NewHandler(
Highlight int
Notif int
}),
typingMap: make(map[string]uint64),
typingMu: &sync.Mutex{},
typingHandler: make(map[string]sync2.PollerID),
PendingTxnIDs: sync2.NewPendingTransactionIDs(pMap.DeviceIDs),
deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded
Expand Down Expand Up @@ -166,7 +167,15 @@ func (h *Handler) updateMetrics() {
h.numPollers.Set(float64(h.pMap.NumPollers()))
}

func (h *Handler) OnTerminated(ctx context.Context, userID, deviceID string) {
func (h *Handler) OnTerminated(ctx context.Context, pollerID sync2.PollerID) {
// Check if this device is handling any typing notifications, of so, remove it
h.typingMu.Lock()
defer h.typingMu.Unlock()
for roomID, devID := range h.typingHandler {
if devID == pollerID {
delete(h.typingHandler, roomID)
}
}
h.updateMetrics()
}

Expand Down Expand Up @@ -352,13 +361,20 @@ func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.Ra
return res.PrependTimelineEvents
}

func (h *Handler) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
next := typingHash(ephEvent)
existing := h.typingMap[roomID]
if existing == next {
func (h *Handler) SetTyping(ctx context.Context, pollerID sync2.PollerID, roomID string, ephEvent json.RawMessage) {
h.typingMu.Lock()
defer h.typingMu.Unlock()

existingDevice := h.typingHandler[roomID]
isPollerAssigned := existingDevice.DeviceID != "" && existingDevice.UserID != ""
if isPollerAssigned && existingDevice != pollerID {
// A different device is already handling typing notifications for this room
return
} else if !isPollerAssigned {
// We're the first to call SetTyping, assign our pollerID
h.typingHandler[roomID] = pollerID
}
h.typingMap[roomID] = next

// we don't persist this for long term storage as typing notifs are inherently ephemeral.
// So rather than maintaining them forever, they will naturally expire when we terminate.
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Typing{
Expand Down Expand Up @@ -473,6 +489,12 @@ func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEv
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
}

// Remove room from the typing deviceHandler map, this ensures we always
// have a device handling typing notifications for a given room.
h.typingMu.Lock()
defer h.typingMu.Unlock()
delete(h.typingHandler, roomID)

h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2LeaveRoom{
UserID: userID,
RoomID: roomID,
Expand Down Expand Up @@ -509,11 +531,3 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) {
})
}()
}

func typingHash(ephEvent json.RawMessage) uint64 {
h := fnv.New64a()
for _, userID := range gjson.ParseBytes(ephEvent).Get("content.user_ids").Array() {
h.Write([]byte(userID.Str))
}
return h.Sum64()
}
54 changes: 51 additions & 3 deletions sync2/handler2/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package handler2_test

import (
"context"
"encoding/json"
"os"
"reflect"
"sync"
Expand Down Expand Up @@ -97,12 +99,17 @@ func (p *mockPub) WaitForPayloadType(t string) chan struct{} {
return ch
}

func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}) {
func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}, wantTimeOut bool) {
select {
case <-ch:
if wantTimeOut {
t.Fatalf("expected to timeout, but received on channel")
}
return
case <-time.After(time.Second):
t.Fatalf("DoWait: timed out waiting: %s", errMsg)
if !wantTimeOut {
t.Fatalf("DoWait: timed out waiting: %s", errMsg)
}
}
}

Expand Down Expand Up @@ -160,7 +167,7 @@ func TestHandlerFreshEnsurePolling(t *testing.T) {
DeviceID: deviceID,
AccessTokenHash: tok.AccessTokenHash,
})
pub.DoWait(t, "didn't see V2InitialSyncComplete", ch)
pub.DoWait(t, "didn't see V2InitialSyncComplete", ch, false)

// make sure we polled with the token i.e it did a db hit
pMap.assertCallExists(t, pollInfo{
Expand All @@ -174,3 +181,44 @@ func TestHandlerFreshEnsurePolling(t *testing.T) {
})

}

func TestSetTypingConcurrently(t *testing.T) {
store := state.NewStorage(postgresURI)
v2Store := sync2.NewStore(postgresURI, "secret")
pMap := &mockPollerMap{}
pub := newMockPub()
sub := &mockSub{}
h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false, time.Minute)
assertNoError(t, err)
ctx := context.Background()

roomID := "!typing:localhost"

typingType := pubsub.V2Typing{}

// startSignal is used to synchronize calling SetTyping
startSignal := make(chan struct{})
// Call SetTyping twice, this may happen with pollers for the same user
go func() {
<-startSignal
h.SetTyping(ctx, sync2.PollerID{UserID: "@alice", DeviceID: "aliceDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`))
}()
go func() {
<-startSignal
h.SetTyping(ctx, sync2.PollerID{UserID: "@bob", DeviceID: "bobDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`))
}()

close(startSignal)

// Wait for the event to be published
ch := pub.WaitForPayloadType(typingType.Type())
pub.DoWait(t, "didn't see V2Typing", ch, false)
ch = pub.WaitForPayloadType(typingType.Type())
// Wait again, but this time we expect to timeout.
pub.DoWait(t, "saw unexpected V2Typing", ch, true)

// We expect only one call to Notify, as the hashes should match
if gotCalls := len(pub.calls); gotCalls != 1 {
t.Fatalf("expected only one call to notify, got %d", gotCalls)
}
}
19 changes: 11 additions & 8 deletions sync2/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type V2DataReceiver interface {
// If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB.
Initialise(ctx context.Context, roomID string, state []json.RawMessage) []json.RawMessage // snapshot ID?
// SetTyping indicates which users are typing.
SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage)
SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage)
// Sent when there is a new receipt
OnReceipt(ctx context.Context, userID, roomID, ephEventType string, ephEvent json.RawMessage)
// AddToDeviceMessages adds this chunk of to_device messages. Preserve the ordering.
Expand All @@ -55,7 +55,7 @@ type V2DataReceiver interface {
// Sent when there is a _change_ in E2EE data, not all the time
OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int)
// Sent when the poll loop terminates
OnTerminated(ctx context.Context, userID, deviceID string)
OnTerminated(ctx context.Context, pollerID PollerID)
// Sent when the token gets a 401 response
OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string)
}
Expand Down Expand Up @@ -297,11 +297,11 @@ func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json.
wg.Wait()
return
}
func (h *PollerMap) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
func (h *PollerMap) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) {
var wg sync.WaitGroup
wg.Add(1)
h.executor <- func() {
h.callbacks.SetTyping(ctx, roomID, ephEvent)
h.callbacks.SetTyping(ctx, pollerID, roomID, ephEvent)
wg.Done()
}
wg.Wait()
Expand Down Expand Up @@ -332,8 +332,8 @@ func (h *PollerMap) AddToDeviceMessages(ctx context.Context, userID, deviceID st
h.callbacks.AddToDeviceMessages(ctx, userID, deviceID, msgs)
}

func (h *PollerMap) OnTerminated(ctx context.Context, userID, deviceID string) {
h.callbacks.OnTerminated(ctx, userID, deviceID)
func (h *PollerMap) OnTerminated(ctx context.Context, pollerID PollerID) {
h.callbacks.OnTerminated(ctx, pollerID)
}

func (h *PollerMap) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) {
Expand Down Expand Up @@ -473,7 +473,10 @@ func (p *poller) Poll(since string) {
logger.Error().Str("user", p.userID).Str("device", p.deviceID).Msgf("%s. Traceback:\n%s", panicErr, debug.Stack())
internal.GetSentryHubFromContextOrDefault(ctx).RecoverWithContext(ctx, panicErr)
}
p.receiver.OnTerminated(ctx, p.userID, p.deviceID)
p.receiver.OnTerminated(ctx, PollerID{
UserID: p.userID,
DeviceID: p.deviceID,
})
}()

state := pollLoopState{
Expand Down Expand Up @@ -706,7 +709,7 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) {
switch ephEventType {
case "m.typing":
typingCalls++
p.receiver.SetTyping(ctx, roomID, ephEvent)
p.receiver.SetTyping(ctx, PollerID{UserID: p.userID, DeviceID: p.deviceID}, roomID, ephEvent)
case "m.receipt":
receiptCalls++
p.receiver.OnReceipt(ctx, p.userID, roomID, ephEventType, ephEvent)
Expand Down
4 changes: 2 additions & 2 deletions sync2/poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state
// timeline. Untested here---return nil for now.
return nil
}
func (a *mockDataReceiver) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
func (a *mockDataReceiver) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) {
}
func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, deviceID, since string) {
s.pollerIDToSince[PollerID{UserID: userID, DeviceID: deviceID}] = since
Expand All @@ -621,7 +621,7 @@ func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string
}
func (s *mockDataReceiver) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
}
func (s *mockDataReceiver) OnTerminated(ctx context.Context, userID, deviceID string) {}
func (s *mockDataReceiver) OnTerminated(ctx context.Context, pollerID PollerID) {}
func (s *mockDataReceiver) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) {
}

Expand Down
Loading

0 comments on commit 9862fad

Please sign in to comment.