From 2840d0dd52defa260f8cd6462af5654c5e100cd6 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 20 Jul 2023 09:42:03 +0200 Subject: [PATCH 1/9] Add failing test --- sync2/handler2/handler_test.go | 41 +++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/sync2/handler2/handler_test.go b/sync2/handler2/handler_test.go index fa315228..173edbf3 100644 --- a/sync2/handler2/handler_test.go +++ b/sync2/handler2/handler_test.go @@ -1,6 +1,8 @@ package handler2_test import ( + "context" + "encoding/json" "os" "reflect" "sync" @@ -92,12 +94,14 @@ func (p *mockPub) WaitForPayloadType(t string) chan struct{} { return ch } -func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}) { +func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}, wantTimeOut bool) { select { case <-ch: return case <-time.After(time.Second): - t.Fatalf("DoWait: timed out waiting: %s", errMsg) + if !wantTimeOut { + t.Fatalf("DoWait: timed out waiting: %s", errMsg) + } } } @@ -155,7 +159,7 @@ func TestHandlerFreshEnsurePolling(t *testing.T) { DeviceID: deviceID, AccessTokenHash: tok.AccessTokenHash, }) - pub.DoWait(t, "didn't see V2InitialSyncComplete", ch) + pub.DoWait(t, "didn't see V2InitialSyncComplete", ch, false) // make sure we polled with the token i.e it did a db hit pMap.assertCallExists(t, pollInfo{ @@ -169,3 +173,34 @@ func TestHandlerFreshEnsurePolling(t *testing.T) { }) } + +func TestSetTypingConcurrently(t *testing.T) { + store := state.NewStorage(postgresURI) + v2Store := sync2.NewStore(postgresURI, "secret") + pMap := &mockPollerMap{} + pub := newMockPub() + sub := &mockSub{} + h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false, time.Minute) + assertNoError(t, err) + ctx := context.Background() + + roomID := "!typing:localhost" + + typingType := pubsub.V2Typing{} + + // Call SetTyping twice, this may happen with pollers for the same user + go h.SetTyping(ctx, roomID, json.RawMessage(`{"content":"user_ids":["@alice:localhost"]}`)) + go h.SetTyping(ctx, roomID, json.RawMessage(`{"content":"user_ids":["@alice:localhost"]}`)) + + // Wait for the event to be published + ch := pub.WaitForPayloadType(typingType.Type()) + pub.DoWait(t, "didn't see V2Typing", ch, false) + ch = pub.WaitForPayloadType(typingType.Type()) + // Wait again, but this time we expect to timeout. + pub.DoWait(t, "saw unexpected V2Typing", ch, true) + + // We expect only one call to Notify, as the hashes should match + if gotCalls := len(pub.calls); gotCalls != 1 { + t.Fatalf("expected only one call to notify, got %d", gotCalls) + } +} From ca263b261e07a272b3d2c07cc9cc93407fa3cda2 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 20 Jul 2023 12:31:18 +0200 Subject: [PATCH 2/9] Try to call SetTyping more synchronized, fix wrong JSON --- sync2/handler2/handler_test.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/sync2/handler2/handler_test.go b/sync2/handler2/handler_test.go index 173edbf3..398a3da9 100644 --- a/sync2/handler2/handler_test.go +++ b/sync2/handler2/handler_test.go @@ -97,6 +97,9 @@ func (p *mockPub) WaitForPayloadType(t string) chan struct{} { func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}, wantTimeOut bool) { select { case <-ch: + if wantTimeOut { + t.Fatalf("expected to timeout, but received on channel") + } return case <-time.After(time.Second): if !wantTimeOut { @@ -188,9 +191,19 @@ func TestSetTypingConcurrently(t *testing.T) { typingType := pubsub.V2Typing{} + // startSignal is used to synchronize calling SetTyping + startSignal := make(chan struct{}) // Call SetTyping twice, this may happen with pollers for the same user - go h.SetTyping(ctx, roomID, json.RawMessage(`{"content":"user_ids":["@alice:localhost"]}`)) - go h.SetTyping(ctx, roomID, json.RawMessage(`{"content":"user_ids":["@alice:localhost"]}`)) + go func() { + <-startSignal + h.SetTyping(ctx, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`)) + }() + go func() { + <-startSignal + h.SetTyping(ctx, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`)) + }() + + close(startSignal) // Wait for the event to be published ch := pub.WaitForPayloadType(typingType.Type()) From e56615856f8f497c19c73ed6af9aa28e7504466d Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 20 Jul 2023 13:06:11 +0200 Subject: [PATCH 3/9] Protect typingMap with a mutex, sort userIDs before hashing --- sync2/handler2/handler.go | 17 +++++++-- sync2/handler2/handler_unexported_test.go | 42 +++++++++++++++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 sync2/handler2/handler_unexported_test.go diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index 15d8037e..5275a8a7 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -6,6 +6,7 @@ import ( "fmt" "hash/fnv" "os" + "sort" "sync" "time" @@ -43,6 +44,7 @@ type Handler struct { } // room_id => fnv_hash([typing user ids]) typingMap map[string]uint64 + typingMu *sync.Mutex deviceDataTicker *sync2.DeviceDataTicker e2eeWorkerPool *internal.WorkerPool @@ -65,6 +67,7 @@ func NewHandler( Notif int }), typingMap: make(map[string]uint64), + typingMu: &sync.Mutex{}, deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration), e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded } @@ -337,6 +340,10 @@ func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.Ra func (h *Handler) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) { next := typingHash(ephEvent) + // protect typingMap with a lock, so concurrent calls to SetTyping see the correct map + h.typingMu.Lock() + defer h.typingMu.Unlock() + existing := h.typingMap[roomID] if existing == next { return @@ -493,8 +500,14 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) { func typingHash(ephEvent json.RawMessage) uint64 { h := fnv.New64a() - for _, userID := range gjson.ParseBytes(ephEvent).Get("content.user_ids").Array() { - h.Write([]byte(userID.Str)) + parsedUserIDs := gjson.ParseBytes(ephEvent).Get("content.user_ids").Array() + userIDs := make([]string, len(parsedUserIDs)) + for i := range parsedUserIDs { + userIDs[i] = parsedUserIDs[i].Str + } + sort.Strings(userIDs) + for _, userID := range userIDs { + _, _ = h.Write([]byte(userID)) } return h.Sum64() } diff --git a/sync2/handler2/handler_unexported_test.go b/sync2/handler2/handler_unexported_test.go new file mode 100644 index 00000000..060bb259 --- /dev/null +++ b/sync2/handler2/handler_unexported_test.go @@ -0,0 +1,42 @@ +package handler2 + +import ( + "encoding/json" + "testing" +) + +func Test_typingHash(t *testing.T) { + tests := []struct { + name string + ephEvent json.RawMessage + want uint64 + }{ + { + name: "doesn't fall over if list is empty", + ephEvent: json.RawMessage(`{"content":{"user_ids":[]}}`), + want: 14695981039346656037, + }, + { + name: "hash alice typing", + ephEvent: json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`), + want: 16709353342265369358, + }, + { + name: "hash alice and bob typing", + ephEvent: json.RawMessage(`{"content":{"user_ids":["@alice:localhost","@bob:localhost"]}}`), + want: 11071889279173799154, + }, + { + name: "hash bob and alice typing", + ephEvent: json.RawMessage(`{"content":{"user_ids":["@bob:localhost","@alice:localhost"]}}`), + want: 11071889279173799154, // this should be the same as the previous + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := typingHash(tt.ephEvent); got != tt.want { + t.Errorf("typingHash() = %v, want %v", got, tt.want) + } + }) + } +} From 8dc8d4897f4b20fcd72fde4284990b8b194394e0 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 24 Jul 2023 08:40:23 +0200 Subject: [PATCH 4/9] Let only one device handle typing notifications --- sync2/handler2/handler.go | 34 +++++++++++++++++++++++++++++----- sync2/handler2/handler_test.go | 4 ++-- sync2/poller.go | 8 ++++---- sync2/poller_test.go | 2 +- 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index 5275a8a7..e1750a26 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -46,6 +46,9 @@ type Handler struct { typingMap map[string]uint64 typingMu *sync.Mutex + // room_id -> device_id, stores which device is allowed to update typing notifications + typingDeviceHandler map[string]string + deviceDataTicker *sync2.DeviceDataTicker e2eeWorkerPool *internal.WorkerPool @@ -66,10 +69,11 @@ func NewHandler( Highlight int Notif int }), - typingMap: make(map[string]uint64), - typingMu: &sync.Mutex{}, - deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration), - e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded + typingMap: make(map[string]uint64), + typingMu: &sync.Mutex{}, + typingDeviceHandler: make(map[string]string), + deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration), + e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded } if enablePrometheus { @@ -169,6 +173,12 @@ func (h *Handler) updateMetrics() { } func (h *Handler) OnTerminated(ctx context.Context, userID, deviceID string) { + // Check if this device is handling any typing notifications, of so, remove it + for roomID, devID := range h.typingDeviceHandler { + if devID == deviceID { + delete(h.typingDeviceHandler, roomID) + } + } h.updateMetrics() } @@ -338,7 +348,16 @@ func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.Ra return res.PrependTimelineEvents } -func (h *Handler) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) { +func (h *Handler) SetTyping(ctx context.Context, deviceID string, roomID string, ephEvent json.RawMessage) { + existingDevice := h.typingDeviceHandler[roomID] + if existingDevice != "" && existingDevice != deviceID { + // A different device is already handling typing notifications for this room + return + } else if existingDevice == "" { + // We're the first to call SetTyping, assign our deviceID + h.typingDeviceHandler[roomID] = deviceID + } + next := typingHash(ephEvent) // protect typingMap with a lock, so concurrent calls to SetTyping see the correct map h.typingMu.Lock() @@ -462,6 +481,11 @@ func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string) { logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to retire invite") internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) } + + // Remove room from the typing deviceHandler map, this ensures we always + // have a device handling typing notifications for a given room. + delete(h.typingDeviceHandler, roomID) + h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2LeaveRoom{ UserID: userID, RoomID: roomID, diff --git a/sync2/handler2/handler_test.go b/sync2/handler2/handler_test.go index 398a3da9..dbca780c 100644 --- a/sync2/handler2/handler_test.go +++ b/sync2/handler2/handler_test.go @@ -196,11 +196,11 @@ func TestSetTypingConcurrently(t *testing.T) { // Call SetTyping twice, this may happen with pollers for the same user go func() { <-startSignal - h.SetTyping(ctx, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`)) + h.SetTyping(ctx, "aliceDevice", roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`)) }() go func() { <-startSignal - h.SetTyping(ctx, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`)) + h.SetTyping(ctx, "bobDevice", roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`)) }() close(startSignal) diff --git a/sync2/poller.go b/sync2/poller.go index 062366a6..79864b46 100644 --- a/sync2/poller.go +++ b/sync2/poller.go @@ -39,7 +39,7 @@ type V2DataReceiver interface { // If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB. Initialise(ctx context.Context, roomID string, state []json.RawMessage) []json.RawMessage // snapshot ID? // SetTyping indicates which users are typing. - SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) + SetTyping(ctx context.Context, deviceID string, roomID string, ephEvent json.RawMessage) // Sent when there is a new receipt OnReceipt(ctx context.Context, userID, roomID, ephEventType string, ephEvent json.RawMessage) // AddToDeviceMessages adds this chunk of to_device messages. Preserve the ordering. @@ -282,11 +282,11 @@ func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json. wg.Wait() return } -func (h *PollerMap) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) { +func (h *PollerMap) SetTyping(ctx context.Context, deviceID string, roomID string, ephEvent json.RawMessage) { var wg sync.WaitGroup wg.Add(1) h.executor <- func() { - h.callbacks.SetTyping(ctx, roomID, ephEvent) + h.callbacks.SetTyping(ctx, deviceID, roomID, ephEvent) wg.Done() } wg.Wait() @@ -691,7 +691,7 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) { switch ephEventType { case "m.typing": typingCalls++ - p.receiver.SetTyping(ctx, roomID, ephEvent) + p.receiver.SetTyping(ctx, p.deviceID, roomID, ephEvent) case "m.receipt": receiptCalls++ p.receiver.OnReceipt(ctx, p.userID, roomID, ephEventType, ephEvent) diff --git a/sync2/poller_test.go b/sync2/poller_test.go index 04c1c0f8..637feba7 100644 --- a/sync2/poller_test.go +++ b/sync2/poller_test.go @@ -598,7 +598,7 @@ func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state // timeline. Untested here---return nil for now. return nil } -func (a *mockDataReceiver) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) { +func (a *mockDataReceiver) SetTyping(ctx context.Context, deviceID string, roomID string, ephEvent json.RawMessage) { } func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, deviceID, since string) { s.pollerIDToSince[PollerID{UserID: userID, DeviceID: deviceID}] = since From 3a2001f07dd00a13cc9ff465b740c8f64fffae0e Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 27 Jul 2023 12:33:10 +0200 Subject: [PATCH 5/9] Use PollerID instead of device ID --- sync2/handler2/handler.go | 72 ++++++++--------------- sync2/handler2/handler_test.go | 4 +- sync2/handler2/handler_unexported_test.go | 42 ------------- sync2/poller.go | 19 +++--- sync2/poller_test.go | 4 +- 5 files changed, 40 insertions(+), 101 deletions(-) delete mode 100644 sync2/handler2/handler_unexported_test.go diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index e1750a26..98184f1e 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -4,9 +4,7 @@ import ( "context" "encoding/json" "fmt" - "hash/fnv" "os" - "sort" "sync" "time" @@ -42,12 +40,9 @@ type Handler struct { Highlight int Notif int } - // room_id => fnv_hash([typing user ids]) - typingMap map[string]uint64 - typingMu *sync.Mutex - - // room_id -> device_id, stores which device is allowed to update typing notifications - typingDeviceHandler map[string]string + // room_id -> PollerID, stores which Poller is allowed to update typing notifications + typingHandler map[string]sync2.PollerID + typingMu *sync.Mutex deviceDataTicker *sync2.DeviceDataTicker e2eeWorkerPool *internal.WorkerPool @@ -69,11 +64,10 @@ func NewHandler( Highlight int Notif int }), - typingMap: make(map[string]uint64), - typingMu: &sync.Mutex{}, - typingDeviceHandler: make(map[string]string), - deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration), - e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded + typingMu: &sync.Mutex{}, + typingHandler: make(map[string]sync2.PollerID), + deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration), + e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded } if enablePrometheus { @@ -172,11 +166,13 @@ func (h *Handler) updateMetrics() { h.numPollers.Set(float64(h.pMap.NumPollers())) } -func (h *Handler) OnTerminated(ctx context.Context, userID, deviceID string) { +func (h *Handler) OnTerminated(ctx context.Context, pollerID sync2.PollerID) { // Check if this device is handling any typing notifications, of so, remove it - for roomID, devID := range h.typingDeviceHandler { - if devID == deviceID { - delete(h.typingDeviceHandler, roomID) + h.typingMu.Lock() + defer h.typingMu.Unlock() + for roomID, devID := range h.typingHandler { + if devID == pollerID { + delete(h.typingHandler, roomID) } } h.updateMetrics() @@ -348,26 +344,20 @@ func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.Ra return res.PrependTimelineEvents } -func (h *Handler) SetTyping(ctx context.Context, deviceID string, roomID string, ephEvent json.RawMessage) { - existingDevice := h.typingDeviceHandler[roomID] - if existingDevice != "" && existingDevice != deviceID { - // A different device is already handling typing notifications for this room - return - } else if existingDevice == "" { - // We're the first to call SetTyping, assign our deviceID - h.typingDeviceHandler[roomID] = deviceID - } - - next := typingHash(ephEvent) - // protect typingMap with a lock, so concurrent calls to SetTyping see the correct map +func (h *Handler) SetTyping(ctx context.Context, pollerID sync2.PollerID, roomID string, ephEvent json.RawMessage) { h.typingMu.Lock() defer h.typingMu.Unlock() - existing := h.typingMap[roomID] - if existing == next { + existingDevice := h.typingHandler[roomID] + isPollerAssigned := existingDevice.DeviceID != "" && existingDevice.UserID != "" + if isPollerAssigned && existingDevice != pollerID { + // A different device is already handling typing notifications for this room return + } else if !isPollerAssigned { + // We're the first to call SetTyping, assign our pollerID + h.typingHandler[roomID] = pollerID } - h.typingMap[roomID] = next + // we don't persist this for long term storage as typing notifs are inherently ephemeral. // So rather than maintaining them forever, they will naturally expire when we terminate. h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Typing{ @@ -484,7 +474,9 @@ func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string) { // Remove room from the typing deviceHandler map, this ensures we always // have a device handling typing notifications for a given room. - delete(h.typingDeviceHandler, roomID) + h.typingMu.Lock() + defer h.typingMu.Unlock() + delete(h.typingHandler, roomID) h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2LeaveRoom{ UserID: userID, @@ -521,17 +513,3 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) { }) }() } - -func typingHash(ephEvent json.RawMessage) uint64 { - h := fnv.New64a() - parsedUserIDs := gjson.ParseBytes(ephEvent).Get("content.user_ids").Array() - userIDs := make([]string, len(parsedUserIDs)) - for i := range parsedUserIDs { - userIDs[i] = parsedUserIDs[i].Str - } - sort.Strings(userIDs) - for _, userID := range userIDs { - _, _ = h.Write([]byte(userID)) - } - return h.Sum64() -} diff --git a/sync2/handler2/handler_test.go b/sync2/handler2/handler_test.go index dbca780c..4c349b16 100644 --- a/sync2/handler2/handler_test.go +++ b/sync2/handler2/handler_test.go @@ -196,11 +196,11 @@ func TestSetTypingConcurrently(t *testing.T) { // Call SetTyping twice, this may happen with pollers for the same user go func() { <-startSignal - h.SetTyping(ctx, "aliceDevice", roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`)) + h.SetTyping(ctx, sync2.PollerID{UserID: "@alice", DeviceID: "aliceDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`)) }() go func() { <-startSignal - h.SetTyping(ctx, "bobDevice", roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`)) + h.SetTyping(ctx, sync2.PollerID{UserID: "@bob", DeviceID: "bobDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`)) }() close(startSignal) diff --git a/sync2/handler2/handler_unexported_test.go b/sync2/handler2/handler_unexported_test.go deleted file mode 100644 index 060bb259..00000000 --- a/sync2/handler2/handler_unexported_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package handler2 - -import ( - "encoding/json" - "testing" -) - -func Test_typingHash(t *testing.T) { - tests := []struct { - name string - ephEvent json.RawMessage - want uint64 - }{ - { - name: "doesn't fall over if list is empty", - ephEvent: json.RawMessage(`{"content":{"user_ids":[]}}`), - want: 14695981039346656037, - }, - { - name: "hash alice typing", - ephEvent: json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`), - want: 16709353342265369358, - }, - { - name: "hash alice and bob typing", - ephEvent: json.RawMessage(`{"content":{"user_ids":["@alice:localhost","@bob:localhost"]}}`), - want: 11071889279173799154, - }, - { - name: "hash bob and alice typing", - ephEvent: json.RawMessage(`{"content":{"user_ids":["@bob:localhost","@alice:localhost"]}}`), - want: 11071889279173799154, // this should be the same as the previous - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := typingHash(tt.ephEvent); got != tt.want { - t.Errorf("typingHash() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/sync2/poller.go b/sync2/poller.go index 79864b46..68d51214 100644 --- a/sync2/poller.go +++ b/sync2/poller.go @@ -39,7 +39,7 @@ type V2DataReceiver interface { // If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB. Initialise(ctx context.Context, roomID string, state []json.RawMessage) []json.RawMessage // snapshot ID? // SetTyping indicates which users are typing. - SetTyping(ctx context.Context, deviceID string, roomID string, ephEvent json.RawMessage) + SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) // Sent when there is a new receipt OnReceipt(ctx context.Context, userID, roomID, ephEventType string, ephEvent json.RawMessage) // AddToDeviceMessages adds this chunk of to_device messages. Preserve the ordering. @@ -55,7 +55,7 @@ type V2DataReceiver interface { // Sent when there is a _change_ in E2EE data, not all the time OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) // Sent when the poll loop terminates - OnTerminated(ctx context.Context, userID, deviceID string) + OnTerminated(ctx context.Context, pollerID PollerID) // Sent when the token gets a 401 response OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) } @@ -282,11 +282,11 @@ func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json. wg.Wait() return } -func (h *PollerMap) SetTyping(ctx context.Context, deviceID string, roomID string, ephEvent json.RawMessage) { +func (h *PollerMap) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) { var wg sync.WaitGroup wg.Add(1) h.executor <- func() { - h.callbacks.SetTyping(ctx, deviceID, roomID, ephEvent) + h.callbacks.SetTyping(ctx, pollerID, roomID, ephEvent) wg.Done() } wg.Wait() @@ -317,8 +317,8 @@ func (h *PollerMap) AddToDeviceMessages(ctx context.Context, userID, deviceID st h.callbacks.AddToDeviceMessages(ctx, userID, deviceID, msgs) } -func (h *PollerMap) OnTerminated(ctx context.Context, userID, deviceID string) { - h.callbacks.OnTerminated(ctx, userID, deviceID) +func (h *PollerMap) OnTerminated(ctx context.Context, pollerID PollerID) { + h.callbacks.OnTerminated(ctx, pollerID) } func (h *PollerMap) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) { @@ -458,7 +458,10 @@ func (p *poller) Poll(since string) { logger.Error().Str("user", p.userID).Str("device", p.deviceID).Msgf("%s. Traceback:\n%s", panicErr, debug.Stack()) internal.GetSentryHubFromContextOrDefault(ctx).RecoverWithContext(ctx, panicErr) } - p.receiver.OnTerminated(ctx, p.userID, p.deviceID) + p.receiver.OnTerminated(ctx, PollerID{ + UserID: p.userID, + DeviceID: p.deviceID, + }) }() state := pollLoopState{ @@ -691,7 +694,7 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) { switch ephEventType { case "m.typing": typingCalls++ - p.receiver.SetTyping(ctx, p.deviceID, roomID, ephEvent) + p.receiver.SetTyping(ctx, PollerID{UserID: p.userID, DeviceID: p.deviceID}, roomID, ephEvent) case "m.receipt": receiptCalls++ p.receiver.OnReceipt(ctx, p.userID, roomID, ephEventType, ephEvent) diff --git a/sync2/poller_test.go b/sync2/poller_test.go index 637feba7..e1308177 100644 --- a/sync2/poller_test.go +++ b/sync2/poller_test.go @@ -598,7 +598,7 @@ func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state // timeline. Untested here---return nil for now. return nil } -func (a *mockDataReceiver) SetTyping(ctx context.Context, deviceID string, roomID string, ephEvent json.RawMessage) { +func (a *mockDataReceiver) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) { } func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, deviceID, since string) { s.pollerIDToSince[PollerID{UserID: userID, DeviceID: deviceID}] = since @@ -620,7 +620,7 @@ func (s *mockDataReceiver) OnInvite(ctx context.Context, userID, roomID string, func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string) {} func (s *mockDataReceiver) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) { } -func (s *mockDataReceiver) OnTerminated(ctx context.Context, userID, deviceID string) {} +func (s *mockDataReceiver) OnTerminated(ctx context.Context, pollerID PollerID) {} func (s *mockDataReceiver) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) { } From 86c4f18c7e6dbd44efb28335c861cd268c29ef64 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 31 Jul 2023 15:16:54 +0200 Subject: [PATCH 6/9] Add test to validate that only one poller is allowed to change typing notifications --- tests-integration/extensions_test.go | 91 ++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/tests-integration/extensions_test.go b/tests-integration/extensions_test.go index c9a9934a..9619b617 100644 --- a/tests-integration/extensions_test.go +++ b/tests-integration/extensions_test.go @@ -604,3 +604,94 @@ func TestExtensionLateEnable(t *testing.T) { }, }) } + +func TestTypingMultiplePoller(t *testing.T) { + pqString := testutils.PrepareDBConnectionString() + // setup code + v2 := runTestV2Server(t) + v3 := runTestServer(t, v2, pqString) + defer v2.close() + defer v3.close() + + roomA := "!a:localhost" + + v2.addAccountWithDeviceID(alice, "first", aliceToken) + v2.addAccountWithDeviceID(bob, "second", bobToken) + + // start the pollers + aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{}) + bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{}) + + // Create the room state and join with Bob + roomState := createRoomState(t, alice, time.Now()) + joinEv := testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{ + "membership": "join", + }) + roomState = append(roomState, joinEv) + + // Queue the response with Alice typing + v2.queueResponse(aliceToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomA: { + State: sync2.EventsResponse{ + Events: roomState, + }, + Ephemeral: sync2.EventsResponse{ + Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@alice:localhost"]}}`)}, + }, + }, + }, + }, + }) + // Wait for the server to have processed the syncv2 response. + // This ensures the poller of Alice will be assigned the typing notification handler. + v2.waitUntilEmpty(t, aliceToken) + + // Queue another response for bob, with bob typing. + // Since Bobs poller isn't allowed to update typing notifications, we should only see + // Alice typing below. + v2.queueResponse(bobToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomA: { + State: sync2.EventsResponse{ + Events: roomState, + }, + Ephemeral: sync2.EventsResponse{ + Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@bob:localhost"]}}`)}}, + }, + }, + }, + }) + // Wait for the server to have processed that event + v2.waitUntilEmpty(t, bobToken) + + // Get the response from v3 + for _, token := range []string{aliceToken, bobToken} { + pos := aliceRes.Pos + if token == bobToken { + pos = bobRes.Pos + } + + res := v3.mustDoV3RequestWithPos(t, token, pos, sync3.Request{ + Extensions: extensions.Request{ + Typing: &extensions.TypingRequest{ + Core: extensions.Core{Enabled: &boolTrue}, + }, + }, + Lists: map[string]sync3.RequestList{"a": { + Ranges: sync3.SliceRanges{ + [2]int64{0, 1}, + }, + Sort: []string{sync3.SortByRecency}, + RoomSubscription: sync3.RoomSubscription{ + TimelineLimit: 0, + }, + }}, + }) + // We expect only Alice typing, as only Alice Poller is "allowed" + // to update typing notifications. + m.MatchResponse(t, res, m.MatchTyping(roomA, []string{alice})) + } +} From 1b1d00db9529d85e352baa2835e54e39107014c6 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Tue, 1 Aug 2023 13:11:28 +0200 Subject: [PATCH 7/9] Update test a bit --- tests-integration/extensions_test.go | 72 ++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 9 deletions(-) diff --git a/tests-integration/extensions_test.go b/tests-integration/extensions_test.go index 9619b617..89977a22 100644 --- a/tests-integration/extensions_test.go +++ b/tests-integration/extensions_test.go @@ -618,10 +618,6 @@ func TestTypingMultiplePoller(t *testing.T) { v2.addAccountWithDeviceID(alice, "first", aliceToken) v2.addAccountWithDeviceID(bob, "second", bobToken) - // start the pollers - aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{}) - bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{}) - // Create the room state and join with Bob roomState := createRoomState(t, alice, time.Now()) joinEv := testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{ @@ -637,6 +633,9 @@ func TestTypingMultiplePoller(t *testing.T) { State: sync2.EventsResponse{ Events: roomState, }, + Timeline: sync2.TimelineResponse{ + Events: []json.RawMessage{joinEv}, + }, Ephemeral: sync2.EventsResponse{ Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@alice:localhost"]}}`)}, }, @@ -644,9 +643,6 @@ func TestTypingMultiplePoller(t *testing.T) { }, }, }) - // Wait for the server to have processed the syncv2 response. - // This ensures the poller of Alice will be assigned the typing notification handler. - v2.waitUntilEmpty(t, aliceToken) // Queue another response for bob, with bob typing. // Since Bobs poller isn't allowed to update typing notifications, we should only see @@ -658,14 +654,19 @@ func TestTypingMultiplePoller(t *testing.T) { State: sync2.EventsResponse{ Events: roomState, }, + Timeline: sync2.TimelineResponse{ + Events: []json.RawMessage{joinEv}, + }, Ephemeral: sync2.EventsResponse{ Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@bob:localhost"]}}`)}}, }, }, }, }) - // Wait for the server to have processed that event - v2.waitUntilEmpty(t, bobToken) + + // start the pollers + aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{}) + bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{}) // Get the response from v3 for _, token := range []string{aliceToken, bobToken} { @@ -693,5 +694,58 @@ func TestTypingMultiplePoller(t *testing.T) { // We expect only Alice typing, as only Alice Poller is "allowed" // to update typing notifications. m.MatchResponse(t, res, m.MatchTyping(roomA, []string{alice})) + if token == bobToken { + bobRes = res + } + if token == aliceToken { + aliceRes = res + } + } + + // Queue the response with Bob typing + v2.queueResponse(aliceToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomA: { + State: sync2.EventsResponse{ + Events: roomState, + }, + Timeline: sync2.TimelineResponse{ + Events: []json.RawMessage{joinEv}, + }, + Ephemeral: sync2.EventsResponse{ + Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@bob:localhost"]}}`)}, + }, + }, + }, + }, + }) + + // Get the response from v3 + for _, token := range []string{aliceToken, bobToken} { + pos := aliceRes.Pos + if token == bobToken { + pos = bobRes.Pos + } + + res := v3.mustDoV3RequestWithPos(t, token, pos, sync3.Request{ + Extensions: extensions.Request{ + Typing: &extensions.TypingRequest{ + Core: extensions.Core{Enabled: &boolTrue}, + }, + }, + Lists: map[string]sync3.RequestList{"a": { + Ranges: sync3.SliceRanges{ + [2]int64{0, 1}, + }, + Sort: []string{sync3.SortByRecency}, + RoomSubscription: sync3.RoomSubscription{ + TimelineLimit: 0, + }, + }}, + }) + // We expect only Bob typing, as only Alice Poller is "allowed" + // to update typing notifications. + m.MatchResponse(t, res, m.MatchTyping(roomA, []string{bob})) } } From d357aa234fd83cb45da4527833bf1e5290d8589e Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Wed, 2 Aug 2023 10:57:22 +0200 Subject: [PATCH 8/9] Update test again to validate that we don't see Charlie typing --- tests-integration/extensions_test.go | 31 ++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/tests-integration/extensions_test.go b/tests-integration/extensions_test.go index 89977a22..704e615c 100644 --- a/tests-integration/extensions_test.go +++ b/tests-integration/extensions_test.go @@ -623,7 +623,6 @@ func TestTypingMultiplePoller(t *testing.T) { joinEv := testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{ "membership": "join", }) - roomState = append(roomState, joinEv) // Queue the response with Alice typing v2.queueResponse(aliceToken, sync2.SyncResponse{ @@ -644,9 +643,7 @@ func TestTypingMultiplePoller(t *testing.T) { }, }) - // Queue another response for bob, with bob typing. - // Since Bobs poller isn't allowed to update typing notifications, we should only see - // Alice typing below. + // Queue another response for Bob with Bob typing. v2.queueResponse(bobToken, sync2.SyncResponse{ Rooms: sync2.SyncRoomsResponse{ Join: map[string]sync2.SyncV2JoinResponse{ @@ -664,7 +661,8 @@ func TestTypingMultiplePoller(t *testing.T) { }, }) - // start the pollers + // Start the pollers. Since Alice's poller is started first, the poller is in + // charge of handling typing notifications for roomA. aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{}) bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{}) @@ -721,7 +719,28 @@ func TestTypingMultiplePoller(t *testing.T) { }, }) - // Get the response from v3 + // Queue another response for Bob with Charlie typing. + // Since Alice's poller is in charge of handling typing notifications, this shouldn't + // show up on future responses. + v2.queueResponse(bobToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomA: { + State: sync2.EventsResponse{ + Events: roomState, + }, + Timeline: sync2.TimelineResponse{ + Events: []json.RawMessage{joinEv}, + }, + Ephemeral: sync2.EventsResponse{ + Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@charlie:localhost"]}}`)}, + }, + }, + }, + }, + }) + + // Check that only Bob is typing and not Charlie. for _, token := range []string{aliceToken, bobToken} { pos := aliceRes.Pos if token == bobToken { From 2ced1986b4420004249f3cb31d4d7fbaa5be1915 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Wed, 2 Aug 2023 14:04:32 +0200 Subject: [PATCH 9/9] Wait for responses to be processed --- tests-integration/extensions_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests-integration/extensions_test.go b/tests-integration/extensions_test.go index 704e615c..11f7c6ff 100644 --- a/tests-integration/extensions_test.go +++ b/tests-integration/extensions_test.go @@ -740,6 +740,10 @@ func TestTypingMultiplePoller(t *testing.T) { }, }) + // Wait for the queued responses to be processed. + v2.waitUntilEmpty(t, aliceToken) + v2.waitUntilEmpty(t, bobToken) + // Check that only Bob is typing and not Charlie. for _, token := range []string{aliceToken, bobToken} { pos := aliceRes.Pos