diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index f89a20f6..aed37237 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -3,7 +3,6 @@ package handler2 import ( "context" "encoding/json" - "hash/fnv" "os" "sync" "time" @@ -40,8 +39,9 @@ type Handler struct { Highlight int Notif int } - // room_id => fnv_hash([typing user ids]) - typingMap map[string]uint64 + // room_id -> PollerID, stores which Poller is allowed to update typing notifications + typingHandler map[string]sync2.PollerID + typingMu *sync.Mutex PendingTxnIDs *sync2.PendingTransactionIDs deviceDataTicker *sync2.DeviceDataTicker @@ -64,7 +64,8 @@ func NewHandler( Highlight int Notif int }), - typingMap: make(map[string]uint64), + typingMu: &sync.Mutex{}, + typingHandler: make(map[string]sync2.PollerID), PendingTxnIDs: sync2.NewPendingTransactionIDs(pMap.DeviceIDs), deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration), e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded @@ -166,7 +167,15 @@ func (h *Handler) updateMetrics() { h.numPollers.Set(float64(h.pMap.NumPollers())) } -func (h *Handler) OnTerminated(ctx context.Context, userID, deviceID string) { +func (h *Handler) OnTerminated(ctx context.Context, pollerID sync2.PollerID) { + // Check if this poller is handling typing notifications for any rooms; if so, remove it + h.typingMu.Lock() + defer h.typingMu.Unlock() + for roomID, devID := range h.typingHandler { + if devID == pollerID { + delete(h.typingHandler, roomID) + } + } h.updateMetrics() } @@ -352,13 +361,20 @@ func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.Ra return res.PrependTimelineEvents } -func (h *Handler) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) { - next := typingHash(ephEvent) - existing := h.typingMap[roomID] - if existing == next { +func (h *Handler) SetTyping(ctx context.Context, pollerID sync2.PollerID, roomID string, 
ephEvent json.RawMessage) { + h.typingMu.Lock() + defer h.typingMu.Unlock() + + existingDevice := h.typingHandler[roomID] + isPollerAssigned := existingDevice.DeviceID != "" && existingDevice.UserID != "" + if isPollerAssigned && existingDevice != pollerID { + // A different device is already handling typing notifications for this room return + } else if !isPollerAssigned { + // We're the first to call SetTyping, assign our pollerID + h.typingHandler[roomID] = pollerID } - h.typingMap[roomID] = next + // we don't persist this for long term storage as typing notifs are inherently ephemeral. // So rather than maintaining them forever, they will naturally expire when we terminate. h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Typing{ @@ -473,6 +489,12 @@ func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEv internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) } + // Remove the room from the typingHandler map so that the next poller to call + // SetTyping for this room takes over handling its typing notifications. 
+ h.typingMu.Lock() + defer h.typingMu.Unlock() + delete(h.typingHandler, roomID) + h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2LeaveRoom{ UserID: userID, RoomID: roomID, @@ -509,11 +531,3 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) { }) }() } - -func typingHash(ephEvent json.RawMessage) uint64 { - h := fnv.New64a() - for _, userID := range gjson.ParseBytes(ephEvent).Get("content.user_ids").Array() { - h.Write([]byte(userID.Str)) - } - return h.Sum64() -} diff --git a/sync2/handler2/handler_test.go b/sync2/handler2/handler_test.go index b123292b..bbd53e9f 100644 --- a/sync2/handler2/handler_test.go +++ b/sync2/handler2/handler_test.go @@ -1,6 +1,8 @@ package handler2_test import ( + "context" + "encoding/json" "os" "reflect" "sync" @@ -97,12 +99,17 @@ func (p *mockPub) WaitForPayloadType(t string) chan struct{} { return ch } -func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}) { +func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}, wantTimeOut bool) { select { case <-ch: + if wantTimeOut { + t.Fatalf("expected to timeout, but received on channel") + } return case <-time.After(time.Second): - t.Fatalf("DoWait: timed out waiting: %s", errMsg) + if !wantTimeOut { + t.Fatalf("DoWait: timed out waiting: %s", errMsg) + } } } @@ -160,7 +167,7 @@ func TestHandlerFreshEnsurePolling(t *testing.T) { DeviceID: deviceID, AccessTokenHash: tok.AccessTokenHash, }) - pub.DoWait(t, "didn't see V2InitialSyncComplete", ch) + pub.DoWait(t, "didn't see V2InitialSyncComplete", ch, false) // make sure we polled with the token i.e it did a db hit pMap.assertCallExists(t, pollInfo{ @@ -174,3 +181,44 @@ func TestHandlerFreshEnsurePolling(t *testing.T) { }) } + +func TestSetTypingConcurrently(t *testing.T) { + store := state.NewStorage(postgresURI) + v2Store := sync2.NewStore(postgresURI, "secret") + pMap := &mockPollerMap{} + pub := newMockPub() + sub := &mockSub{} + h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false, 
time.Minute) + assertNoError(t, err) + ctx := context.Background() + + roomID := "!typing:localhost" + + typingType := pubsub.V2Typing{} + + // startSignal is used to synchronize calling SetTyping + startSignal := make(chan struct{}) + // Call SetTyping from two different pollers for the same room at the same time + go func() { + <-startSignal + h.SetTyping(ctx, sync2.PollerID{UserID: "@alice", DeviceID: "aliceDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`)) + }() + go func() { + <-startSignal + h.SetTyping(ctx, sync2.PollerID{UserID: "@bob", DeviceID: "bobDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`)) + }() + + close(startSignal) + + // Wait for the event to be published + ch := pub.WaitForPayloadType(typingType.Type()) + pub.DoWait(t, "didn't see V2Typing", ch, false) + ch = pub.WaitForPayloadType(typingType.Type()) + // Wait again, but this time we expect to timeout. + pub.DoWait(t, "saw unexpected V2Typing", ch, true) + + // We expect only one call to Notify, as only the first poller to call SetTyping may publish for the room + if gotCalls := len(pub.calls); gotCalls != 1 { + t.Fatalf("expected only one call to notify, got %d", gotCalls) + } +} diff --git a/sync2/poller.go b/sync2/poller.go index c5784eb5..d5791a66 100644 --- a/sync2/poller.go +++ b/sync2/poller.go @@ -39,7 +39,7 @@ type V2DataReceiver interface { // If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB. Initialise(ctx context.Context, roomID string, state []json.RawMessage) []json.RawMessage // snapshot ID? // SetTyping indicates which users are typing. - SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) + SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) // Sent when there is a new receipt OnReceipt(ctx context.Context, userID, roomID, ephEventType string, ephEvent json.RawMessage) // AddToDeviceMessages adds this chunk of to_device messages. 
Preserve the ordering. @@ -55,7 +55,7 @@ type V2DataReceiver interface { // Sent when there is a _change_ in E2EE data, not all the time OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) // Sent when the poll loop terminates - OnTerminated(ctx context.Context, userID, deviceID string) + OnTerminated(ctx context.Context, pollerID PollerID) // Sent when the token gets a 401 response OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) } @@ -297,11 +297,11 @@ func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json. wg.Wait() return } -func (h *PollerMap) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) { +func (h *PollerMap) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) { var wg sync.WaitGroup wg.Add(1) h.executor <- func() { - h.callbacks.SetTyping(ctx, roomID, ephEvent) + h.callbacks.SetTyping(ctx, pollerID, roomID, ephEvent) wg.Done() } wg.Wait() @@ -332,8 +332,8 @@ func (h *PollerMap) AddToDeviceMessages(ctx context.Context, userID, deviceID st h.callbacks.AddToDeviceMessages(ctx, userID, deviceID, msgs) } -func (h *PollerMap) OnTerminated(ctx context.Context, userID, deviceID string) { - h.callbacks.OnTerminated(ctx, userID, deviceID) +func (h *PollerMap) OnTerminated(ctx context.Context, pollerID PollerID) { + h.callbacks.OnTerminated(ctx, pollerID) } func (h *PollerMap) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) { @@ -473,7 +473,10 @@ func (p *poller) Poll(since string) { logger.Error().Str("user", p.userID).Str("device", p.deviceID).Msgf("%s. 
Traceback:\n%s", panicErr, debug.Stack()) internal.GetSentryHubFromContextOrDefault(ctx).RecoverWithContext(ctx, panicErr) } - p.receiver.OnTerminated(ctx, p.userID, p.deviceID) + p.receiver.OnTerminated(ctx, PollerID{ + UserID: p.userID, + DeviceID: p.deviceID, + }) }() state := pollLoopState{ @@ -706,7 +709,7 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) { switch ephEventType { case "m.typing": typingCalls++ - p.receiver.SetTyping(ctx, roomID, ephEvent) + p.receiver.SetTyping(ctx, PollerID{UserID: p.userID, DeviceID: p.deviceID}, roomID, ephEvent) case "m.receipt": receiptCalls++ p.receiver.OnReceipt(ctx, p.userID, roomID, ephEventType, ephEvent) diff --git a/sync2/poller_test.go b/sync2/poller_test.go index 5cba979c..9094e785 100644 --- a/sync2/poller_test.go +++ b/sync2/poller_test.go @@ -598,7 +598,7 @@ func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state // timeline. Untested here---return nil for now. return nil } -func (a *mockDataReceiver) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) { +func (a *mockDataReceiver) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) { } func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, deviceID, since string) { s.pollerIDToSince[PollerID{UserID: userID, DeviceID: deviceID}] = since @@ -621,7 +621,7 @@ func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string } func (s *mockDataReceiver) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) { } -func (s *mockDataReceiver) OnTerminated(ctx context.Context, userID, deviceID string) {} +func (s *mockDataReceiver) OnTerminated(ctx context.Context, pollerID PollerID) {} func (s *mockDataReceiver) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) { } diff --git 
a/tests-integration/extensions_test.go b/tests-integration/extensions_test.go index c9a9934a..11f7c6ff 100644 --- a/tests-integration/extensions_test.go +++ b/tests-integration/extensions_test.go @@ -604,3 +604,171 @@ func TestExtensionLateEnable(t *testing.T) { }, }) } + +func TestTypingMultiplePoller(t *testing.T) { + pqString := testutils.PrepareDBConnectionString() + // setup code + v2 := runTestV2Server(t) + v3 := runTestServer(t, v2, pqString) + defer v2.close() + defer v3.close() + + roomA := "!a:localhost" + + v2.addAccountWithDeviceID(alice, "first", aliceToken) + v2.addAccountWithDeviceID(bob, "second", bobToken) + + // Create the room state and join with Bob + roomState := createRoomState(t, alice, time.Now()) + joinEv := testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{ + "membership": "join", + }) + + // Queue a response for Alice with Alice typing + v2.queueResponse(aliceToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomA: { + State: sync2.EventsResponse{ + Events: roomState, + }, + Timeline: sync2.TimelineResponse{ + Events: []json.RawMessage{joinEv}, + }, + Ephemeral: sync2.EventsResponse{ + Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@alice:localhost"]}}`)}, + }, + }, + }, + }, + }) + + // Queue another response for Bob with Bob typing. + v2.queueResponse(bobToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomA: { + State: sync2.EventsResponse{ + Events: roomState, + }, + Timeline: sync2.TimelineResponse{ + Events: []json.RawMessage{joinEv}, + }, + Ephemeral: sync2.EventsResponse{ + Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@bob:localhost"]}}`)}}, + }, + }, + }, + }) + + // Start the pollers. Since Alice's poller is started first, it is in + // charge of handling typing notifications for roomA. 
+ aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{}) + bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{}) + + // Get the response from v3 + for _, token := range []string{aliceToken, bobToken} { + pos := aliceRes.Pos + if token == bobToken { + pos = bobRes.Pos + } + + res := v3.mustDoV3RequestWithPos(t, token, pos, sync3.Request{ + Extensions: extensions.Request{ + Typing: &extensions.TypingRequest{ + Core: extensions.Core{Enabled: &boolTrue}, + }, + }, + Lists: map[string]sync3.RequestList{"a": { + Ranges: sync3.SliceRanges{ + [2]int64{0, 1}, + }, + Sort: []string{sync3.SortByRecency}, + RoomSubscription: sync3.RoomSubscription{ + TimelineLimit: 0, + }, + }}, + }) + // We expect only Alice typing, as only Alice's poller is "allowed" + // to update typing notifications. + m.MatchResponse(t, res, m.MatchTyping(roomA, []string{alice})) + if token == bobToken { + bobRes = res + } + if token == aliceToken { + aliceRes = res + } + } + + // Queue a response for Alice with Bob typing + v2.queueResponse(aliceToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomA: { + State: sync2.EventsResponse{ + Events: roomState, + }, + Timeline: sync2.TimelineResponse{ + Events: []json.RawMessage{joinEv}, + }, + Ephemeral: sync2.EventsResponse{ + Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@bob:localhost"]}}`)}, + }, + }, + }, + }, + }) + + // Queue another response for Bob with Charlie typing. + // Since Alice's poller is in charge of handling typing notifications, this shouldn't + // show up on future responses. 
+ v2.queueResponse(bobToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomA: { + State: sync2.EventsResponse{ + Events: roomState, + }, + Timeline: sync2.TimelineResponse{ + Events: []json.RawMessage{joinEv}, + }, + Ephemeral: sync2.EventsResponse{ + Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@charlie:localhost"]}}`)}, + }, + }, + }, + }, + }) + + // Wait for the queued responses to be processed. + v2.waitUntilEmpty(t, aliceToken) + v2.waitUntilEmpty(t, bobToken) + + // Check that only Bob is typing and not Charlie. + for _, token := range []string{aliceToken, bobToken} { + pos := aliceRes.Pos + if token == bobToken { + pos = bobRes.Pos + } + + res := v3.mustDoV3RequestWithPos(t, token, pos, sync3.Request{ + Extensions: extensions.Request{ + Typing: &extensions.TypingRequest{ + Core: extensions.Core{Enabled: &boolTrue}, + }, + }, + Lists: map[string]sync3.RequestList{"a": { + Ranges: sync3.SliceRanges{ + [2]int64{0, 1}, + }, + Sort: []string{sync3.SortByRecency}, + RoomSubscription: sync3.RoomSubscription{ + TimelineLimit: 0, + }, + }}, + }) + // We expect only Bob typing, as only Alice's poller is "allowed" + // to update typing notifications. + m.MatchResponse(t, res, m.MatchTyping(roomA, []string{bob})) + } +}