Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix V2Typing send multiple times #214

Merged
merged 14 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 33 additions & 18 deletions sync2/handler2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"hash/fnv"
"os"
"sync"
"time"
Expand Down Expand Up @@ -41,8 +40,9 @@ type Handler struct {
Highlight int
Notif int
}
// room_id => fnv_hash([typing user ids])
typingMap map[string]uint64
// room_id -> PollerID, stores which Poller is allowed to update typing notifications
typingHandler map[string]sync2.PollerID
typingMu *sync.Mutex

deviceDataTicker *sync2.DeviceDataTicker
e2eeWorkerPool *internal.WorkerPool
Expand All @@ -64,7 +64,8 @@ func NewHandler(
Highlight int
Notif int
}),
typingMap: make(map[string]uint64),
typingMu: &sync.Mutex{},
typingHandler: make(map[string]sync2.PollerID),
deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded
}
Expand Down Expand Up @@ -165,7 +166,15 @@ func (h *Handler) updateMetrics() {
h.numPollers.Set(float64(h.pMap.NumPollers()))
}

func (h *Handler) OnTerminated(ctx context.Context, userID, deviceID string) {
// OnTerminated is called when the poll loop identified by pollerID has stopped.
// Any room for which that poller was recorded as the typing-notification
// handler is released, so the next poller to call SetTyping for the room can
// take over. The poller metrics are refreshed afterwards.
func (h *Handler) OnTerminated(ctx context.Context, pollerID sync2.PollerID) {
	h.typingMu.Lock()
	defer h.typingMu.Unlock()

	// If this poller is handling typing notifications for any rooms, drop
	// those assignments. Deleting entries while ranging over a map is safe in Go.
	for room, owner := range h.typingHandler {
		if owner != pollerID {
			continue
		}
		delete(h.typingHandler, room)
	}

	h.updateMetrics()
}

Expand Down Expand Up @@ -335,13 +344,20 @@ func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.Ra
return res.PrependTimelineEvents
}

func (h *Handler) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
next := typingHash(ephEvent)
existing := h.typingMap[roomID]
if existing == next {
func (h *Handler) SetTyping(ctx context.Context, pollerID sync2.PollerID, roomID string, ephEvent json.RawMessage) {
h.typingMu.Lock()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even though I couldn't get the test to fail in CI, this was causing a race condition.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the bits which touch typingMap or typingDeviceHandler should be protected.

defer h.typingMu.Unlock()

existingDevice := h.typingHandler[roomID]
isPollerAssigned := existingDevice.DeviceID != "" && existingDevice.UserID != ""
if isPollerAssigned && existingDevice != pollerID {
// A different device is already handling typing notifications for this room
return
} else if !isPollerAssigned {
// We're the first to call SetTyping, assign our pollerID
h.typingHandler[roomID] = pollerID
}
h.typingMap[roomID] = next

// we don't persist this for long term storage as typing notifs are inherently ephemeral.
// So rather than maintaining them forever, they will naturally expire when we terminate.
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Typing{
Expand Down Expand Up @@ -455,6 +471,13 @@ func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string) {
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to retire invite")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
}

// Remove the room from the typingHandler map. This ensures we always
// have a device handling typing notifications for a given room.
h.typingMu.Lock()
defer h.typingMu.Unlock()
delete(h.typingHandler, roomID)

h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2LeaveRoom{
UserID: userID,
RoomID: roomID,
Expand Down Expand Up @@ -490,11 +513,3 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) {
})
}()
}

// typingHash returns a 64-bit FNV-1a digest of the entries found at
// content.user_ids in the given typing event. The IDs are fed to the hasher in
// array order with no separator between them, so the result depends on the
// order of the list.
func typingHash(ephEvent json.RawMessage) uint64 {
	hasher := fnv.New64a()
	typingUsers := gjson.ParseBytes(ephEvent).Get("content.user_ids")
	for _, user := range typingUsers.Array() {
		hasher.Write([]byte(user.Str))
	}
	return hasher.Sum64()
}
54 changes: 51 additions & 3 deletions sync2/handler2/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package handler2_test

import (
"context"
"encoding/json"
"os"
"reflect"
"sync"
Expand Down Expand Up @@ -92,12 +94,17 @@ func (p *mockPub) WaitForPayloadType(t string) chan struct{} {
return ch
}

func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}) {
func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}, wantTimeOut bool) {
select {
case <-ch:
if wantTimeOut {
t.Fatalf("expected to timeout, but received on channel")
}
return
case <-time.After(time.Second):
t.Fatalf("DoWait: timed out waiting: %s", errMsg)
if !wantTimeOut {
t.Fatalf("DoWait: timed out waiting: %s", errMsg)
}
}
}

Expand Down Expand Up @@ -155,7 +162,7 @@ func TestHandlerFreshEnsurePolling(t *testing.T) {
DeviceID: deviceID,
AccessTokenHash: tok.AccessTokenHash,
})
pub.DoWait(t, "didn't see V2InitialSyncComplete", ch)
pub.DoWait(t, "didn't see V2InitialSyncComplete", ch, false)

// make sure we polled with the token i.e it did a db hit
pMap.assertCallExists(t, pollInfo{
Expand All @@ -169,3 +176,44 @@ func TestHandlerFreshEnsurePolling(t *testing.T) {
})

}

func TestSetTypingConcurrently(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't what I'm interested in testing. The case here should be passing without your change.

The case I'm interested in is when you have 2 pollers receiving delayed typing notifs. For example, if Alice starts typing then stops typing (so [A] then []), the problem is that one poller may be "behind" the other, such that it has yet to see [A] whilst the other "ahead" poller has already seen [A] and []. In this scenario, we flicker with 4 operations instead of 2, as we go [A], [], [A], []. This test does not cover that scenario, nor does the code fix it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't what I'm interested in testing. The case here should be passing without your change.

It doesn't always pass. With luck it does, yea, if the machine is slow enough to execute both calls.

fatal error: concurrent map writes

goroutine 246 [running]:
github.com/matrix-org/sliding-sync/sync2/handler2.(*Handler).SetTyping(0xc0003d2080, {0x0?, 0x0?}, {0xa24ad3, 0x11}, {0xc0004e8090, 0x2d, 0x2d})
        github.com/sliding-sync/sync2/handler2/handler.go:344 +0x96
github.com/matrix-org/sliding-sync/sync2/handler2_test.TestSetTypingConcurrently.func2()
       github.com/sliding-sync/sync2/handler2/handler_test.go:203 +0xd9
created by github.com/matrix-org/sliding-sync/sync2/handler2_test.TestSetTypingConcurrently
       github.com/sliding-sync/sync2/handler2/handler_test.go:201 +0x2b0

which also means that h.typingMap[roomID] returned 0 as the existing value, resulting in duplicate notifications (if the machine is, again, slow enough, that the map writes aren't concurrent :D)

store := state.NewStorage(postgresURI)
v2Store := sync2.NewStore(postgresURI, "secret")
pMap := &mockPollerMap{}
pub := newMockPub()
sub := &mockSub{}
h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false, time.Minute)
assertNoError(t, err)
ctx := context.Background()

roomID := "!typing:localhost"

typingType := pubsub.V2Typing{}

// startSignal is used to synchronize calling SetTyping
startSignal := make(chan struct{})
// Call SetTyping twice, this may happen with pollers for the same user
go func() {
<-startSignal
h.SetTyping(ctx, sync2.PollerID{UserID: "@alice", DeviceID: "aliceDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`))
}()
go func() {
<-startSignal
h.SetTyping(ctx, sync2.PollerID{UserID: "@bob", DeviceID: "bobDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`))
}()

close(startSignal)

// Wait for the event to be published
ch := pub.WaitForPayloadType(typingType.Type())
pub.DoWait(t, "didn't see V2Typing", ch, false)
ch = pub.WaitForPayloadType(typingType.Type())
// Wait again, but this time we expect to timeout.
pub.DoWait(t, "saw unexpected V2Typing", ch, true)

// We expect only one call to Notify, as only the first poller to call SetTyping for a room is allowed to publish its typing updates
if gotCalls := len(pub.calls); gotCalls != 1 {
t.Fatalf("expected only one call to notify, got %d", gotCalls)
}
}
19 changes: 11 additions & 8 deletions sync2/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type V2DataReceiver interface {
// If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB.
Initialise(ctx context.Context, roomID string, state []json.RawMessage) []json.RawMessage // snapshot ID?
// SetTyping indicates which users are typing.
SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage)
SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage)
// Sent when there is a new receipt
OnReceipt(ctx context.Context, userID, roomID, ephEventType string, ephEvent json.RawMessage)
// AddToDeviceMessages adds this chunk of to_device messages. Preserve the ordering.
Expand All @@ -55,7 +55,7 @@ type V2DataReceiver interface {
// Sent when there is a _change_ in E2EE data, not all the time
OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int)
// Sent when the poll loop terminates
OnTerminated(ctx context.Context, userID, deviceID string)
OnTerminated(ctx context.Context, pollerID PollerID)
// Sent when the token gets a 401 response
OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string)
}
Expand Down Expand Up @@ -282,11 +282,11 @@ func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json.
wg.Wait()
return
}
func (h *PollerMap) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
func (h *PollerMap) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) {
var wg sync.WaitGroup
wg.Add(1)
h.executor <- func() {
h.callbacks.SetTyping(ctx, roomID, ephEvent)
h.callbacks.SetTyping(ctx, pollerID, roomID, ephEvent)
wg.Done()
}
wg.Wait()
Expand Down Expand Up @@ -317,8 +317,8 @@ func (h *PollerMap) AddToDeviceMessages(ctx context.Context, userID, deviceID st
h.callbacks.AddToDeviceMessages(ctx, userID, deviceID, msgs)
}

func (h *PollerMap) OnTerminated(ctx context.Context, userID, deviceID string) {
h.callbacks.OnTerminated(ctx, userID, deviceID)
// OnTerminated forwards the termination notice for pollerID to the underlying
// callbacks receiver. The call is made directly on the caller's goroutine; it
// is not queued on h.executor.
func (h *PollerMap) OnTerminated(ctx context.Context, pollerID PollerID) {
h.callbacks.OnTerminated(ctx, pollerID)
}

func (h *PollerMap) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) {
Expand Down Expand Up @@ -458,7 +458,10 @@ func (p *poller) Poll(since string) {
logger.Error().Str("user", p.userID).Str("device", p.deviceID).Msgf("%s. Traceback:\n%s", panicErr, debug.Stack())
internal.GetSentryHubFromContextOrDefault(ctx).RecoverWithContext(ctx, panicErr)
}
p.receiver.OnTerminated(ctx, p.userID, p.deviceID)
p.receiver.OnTerminated(ctx, PollerID{
UserID: p.userID,
DeviceID: p.deviceID,
})
}()

state := pollLoopState{
Expand Down Expand Up @@ -691,7 +694,7 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) {
switch ephEventType {
case "m.typing":
typingCalls++
p.receiver.SetTyping(ctx, roomID, ephEvent)
p.receiver.SetTyping(ctx, PollerID{UserID: p.userID, DeviceID: p.deviceID}, roomID, ephEvent)
case "m.receipt":
receiptCalls++
p.receiver.OnReceipt(ctx, p.userID, roomID, ephEventType, ephEvent)
Expand Down
4 changes: 2 additions & 2 deletions sync2/poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state
// timeline. Untested here---return nil for now.
return nil
}
func (a *mockDataReceiver) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
// SetTyping is a no-op: this mock ignores typing events.
func (a *mockDataReceiver) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) {
}
func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, deviceID, since string) {
s.pollerIDToSince[PollerID{UserID: userID, DeviceID: deviceID}] = since
Expand All @@ -620,7 +620,7 @@ func (s *mockDataReceiver) OnInvite(ctx context.Context, userID, roomID string,
func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string) {}
func (s *mockDataReceiver) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
}
func (s *mockDataReceiver) OnTerminated(ctx context.Context, userID, deviceID string) {}
// OnTerminated is a no-op: this mock ignores poller termination.
func (s *mockDataReceiver) OnTerminated(ctx context.Context, pollerID PollerID) {}
func (s *mockDataReceiver) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) {
}

Expand Down
91 changes: 91 additions & 0 deletions tests-integration/extensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -604,3 +604,94 @@ func TestExtensionLateEnable(t *testing.T) {
},
})
}

func TestTypingMultiplePoller(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()

roomA := "!a:localhost"

v2.addAccountWithDeviceID(alice, "first", aliceToken)
v2.addAccountWithDeviceID(bob, "second", bobToken)

// start the pollers
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{})
bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{})

// Create the room state and join with Bob
roomState := createRoomState(t, alice, time.Now())
joinEv := testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{
"membership": "join",
})
roomState = append(roomState, joinEv)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not required, you're adding this in the timeline.


// Queue the response with Alice typing
v2.queueResponse(aliceToken, sync2.SyncResponse{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do this before you call mustDoV3Request.

Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomA: {
State: sync2.EventsResponse{
Events: roomState,
},
Ephemeral: sync2.EventsResponse{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need something in Timeline as well, else it isn't mimicking Synapse. Put joinEv in it.

Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@alice:localhost"]}}`)},
},
},
},
},
})
// Wait for the server to have processed the syncv2 response.
// This ensures the poller of Alice will be assigned the typing notification handler.
v2.waitUntilEmpty(t, aliceToken)

// Queue another response for bob, with bob typing.
// Since Bob's poller isn't allowed to update typing notifications, we should only see
// Alice typing below.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment makes no sense as that hasn't been decided yet. Move it down to when you start doing v3 reqs.

v2.queueResponse(bobToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomA: {
State: sync2.EventsResponse{
Events: roomState,
},
Ephemeral: sync2.EventsResponse{
Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@bob:localhost"]}}`)}},
},
},
},
})
// Wait for the server to have processed that event
v2.waitUntilEmpty(t, bobToken)

// Get the response from v3
for _, token := range []string{aliceToken, bobToken} {
pos := aliceRes.Pos
if token == bobToken {
pos = bobRes.Pos
}

res := v3.mustDoV3RequestWithPos(t, token, pos, sync3.Request{
Extensions: extensions.Request{
Typing: &extensions.TypingRequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 1},
},
Sort: []string{sync3.SortByRecency},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 0,
},
}},
})
// We expect only Alice typing, as only Alice's poller is "allowed"
// to update typing notifications.
m.MatchResponse(t, res, m.MatchTyping(roomA, []string{alice}))
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And then test that if you update the typing event to be bob then it comes through please.

}
Loading