diff --git a/.gitignore b/.gitignore index ac372712..cf5efeb4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,6 @@ /syncv3 node_modules + +# Go workspaces +go.work +go.work.sum diff --git a/sync3/dispatcher.go b/sync3/dispatcher.go index 6c71d7e5..9fd21de1 100644 --- a/sync3/dispatcher.go +++ b/sync3/dispatcher.go @@ -159,6 +159,7 @@ func (d *Dispatcher) OnNewEvent( targetUser := "" membership := "" shouldForceInitial := false + leaveAfterJoinOrInvite := false if ed.EventType == "m.room.member" && ed.StateKey != nil { targetUser = *ed.StateKey membership = ed.Content.Get("membership").Str @@ -173,7 +174,7 @@ func (d *Dispatcher) OnNewEvent( case "ban": fallthrough case "leave": - d.jrt.UserLeftRoom(targetUser, ed.RoomID) + leaveAfterJoinOrInvite = d.jrt.UserLeftRoom(targetUser, ed.RoomID) } ed.InviteCount = d.jrt.NumInvitedUsersForRoom(ed.RoomID) } @@ -186,6 +187,11 @@ func (d *Dispatcher) OnNewEvent( return d.ReceiverForUser(userID) != nil }) ed.JoinCount = joinCount + if leaveAfterJoinOrInvite { + // Only tell the target user about a leave if they were previously aware of the + // room. This prevents us from leaking pre-emptive bans. + userIDs = append(userIDs, targetUser) + } d.notifyListeners(ctx, ed, userIDs, targetUser, shouldForceInitial, membership) } @@ -256,13 +262,11 @@ func (d *Dispatcher) notifyListeners(ctx context.Context, ed *caches.EventData, } // per-user listeners - notifiedTarget := false for _, userID := range userIDs { l := d.userToReceiver[userID] if l != nil { edd := *ed if targetUser == userID { - notifiedTarget = true if shouldForceInitial { edd.ForceInitial = true } @@ -270,21 +274,6 @@ func (d *Dispatcher) notifyListeners(ctx context.Context, ed *caches.EventData, l.OnNewEvent(ctx, &edd) } } - if targetUser != "" && !notifiedTarget { // e.g invites/leaves where you aren't joined yet but need to know about it - // We expect invites to come down the invitee's poller, which triggers OnInvite code paths and - // not normal event codepaths. We need the separate code path to ensure invite stripped state - // is sent to the conn and not live data. Hence, if we get the invite event early from a different - // connection, do not send it to the target, as they must wait for the invite on their poller. - if membership != "invite" { - if shouldForceInitial { - ed.ForceInitial = true - } - l := d.userToReceiver[targetUser] - if l != nil { - l.OnNewEvent(ctx, ed) - } - } - } } func (d *Dispatcher) OnInvalidateRoom(ctx context.Context, roomID string) { diff --git a/sync3/tracker.go b/sync3/tracker.go index 3a1c73cf..e33fabfc 100644 --- a/sync3/tracker.go +++ b/sync3/tracker.go @@ -115,19 +115,28 @@ func (t *JoinedRoomsTracker) UsersJoinedRoom(userIDs []string, roomID string) bo } // UserLeftRoom marks the given user as having left the given room. -func (t *JoinedRoomsTracker) UserLeftRoom(userID, roomID string) { +// Returns true if this user _was_ joined or invited to the room before this call, +// and false otherwise. +func (t *JoinedRoomsTracker) UserLeftRoom(userID, roomID string) bool { t.mu.Lock() defer t.mu.Unlock() joinedRooms := t.userIDToJoinedRooms[userID] - delete(joinedRooms, roomID) joinedUsers := t.roomIDToJoinedUsers[roomID] - delete(joinedUsers, userID) invitedUsers := t.roomIDToInvitedUsers[roomID] + + _, wasJoined := joinedUsers[userID] + _, wasInvited := invitedUsers[userID] + + delete(joinedRooms, roomID) + delete(joinedUsers, userID) delete(invitedUsers, userID) t.userIDToJoinedRooms[userID] = joinedRooms t.roomIDToJoinedUsers[roomID] = joinedUsers t.roomIDToInvitedUsers[roomID] = invitedUsers + + return wasJoined || wasInvited } + func (t *JoinedRoomsTracker) JoinedRoomsForUser(userID string) []string { t.mu.RLock() defer t.mu.RUnlock() diff --git a/sync3/tracker_test.go b/sync3/tracker_test.go index 7be2336a..0bc58009 100644 --- a/sync3/tracker_test.go +++ b/sync3/tracker_test.go @@ -1,6 +1,7 @@ package sync3 import ( + "fmt" "sort" "testing" ) @@ -82,6 +83,63 @@ func TestTrackerStartup(t *testing.T) { assertInt(t, jrt.NumInvitedUsersForRoom(roomC), 0) } +func TestJoinedRoomsTracker_UserLeftRoom_ReturnValue(t *testing.T) { + alice := "@alice" + bob := "@bob" + + // Tell the tracker that alice left various rooms. Assert its return value is sensible. + + tcs := []struct { + id string + joined []string + invited []string + expectedResult bool + }{ + { + id: "!a", + joined: []string{alice, bob}, + invited: nil, + expectedResult: true, + }, + { + id: "!b", + joined: []string{alice}, + invited: nil, + expectedResult: true, + }, + { + id: "!c", + joined: []string{bob}, + invited: nil, + expectedResult: false, + }, + { + id: "!d", + joined: nil, + invited: nil, + expectedResult: false, + }, + { + id: "!e", + joined: nil, + invited: []string{alice}, + expectedResult: true, + }, + } + + jrt := NewJoinedRoomsTracker() + for _, tc := range tcs { + jrt.UsersJoinedRoom(tc.joined, tc.id) + jrt.UsersInvitedToRoom(tc.invited, tc.id) + } + + // Tell the tracker that Alice left every room. Check the return value is sensible. + for _, tc := range tcs { + wasJoinedOrInvited := jrt.UserLeftRoom(alice, tc.id) + assertBool(t, fmt.Sprintf("wasJoinedOrInvited[%s]", tc.id), wasJoinedOrInvited, tc.expectedResult) + } +} + func assertBool(t *testing.T, msg string, got, want bool) { t.Helper() if got != want { diff --git a/tests-e2e/membership_transitions_test.go b/tests-e2e/membership_transitions_test.go index 3a1a1bbb..66a54cb6 100644 --- a/tests-e2e/membership_transitions_test.go +++ b/tests-e2e/membership_transitions_test.go @@ -802,3 +802,61 @@ func TestMemberCounts(t *testing.T) { }, })) } + +func TestPreemptiveBanIsNotLeaked(t *testing.T) { + alice := registerNamedUser(t, "alice") + nigel := registerNamedUser(t, "nigel") + + t.Log("Alice creates a public room and a DM with Nigel.") + public := alice.MustCreateRoom(t, map[string]interface{}{"preset": "public_chat"}) + dm := alice.MustCreateRoom(t, map[string]interface{}{"preset": "private_chat", "invite": []string{nigel.UserID}}) + + t.Log("Nigel joins the DM") + nigel.JoinRoom(t, dm, nil) + + t.Log("Alice sends a sentinel message into the DM.") + dmSentinel := alice.SendEventSynced(t, dm, b.Event{ + Type: "m.room.message", + Content: map[string]interface{}{"body": "sentinel, sentinel, where have you been?", "msgtype": "m.text"}, + }) + + t.Log("Nigel does an initial sliding sync.") + nigelRes := nigel.SlidingSync(t, sync3.Request{ + Lists: map[string]sync3.RequestList{ + "a": { + RoomSubscription: sync3.RoomSubscription{ + TimelineLimit: 20, + }, + Ranges: sync3.SliceRanges{{0, 10}}, + }, + }, + }) + t.Log("Nigel sees the sentinel.") + m.MatchResponse(t, nigelRes, m.MatchRoomSubscription(dm, MatchRoomTimelineMostRecent(1, []Event{{ID: dmSentinel}}))) + + t.Log("Alice pre-emptively bans Nigel from the public room.") + alice.MustDo(t, "POST", []string{"_matrix", "client", "v3", "rooms", public, "ban"}, + client.WithJSONBody(t, map[string]any{"user_id": nigel.UserID})) + + t.Log("Alice sliding syncs until she sees the ban.") + alice.SlidingSyncUntilMembership(t, "", public, nigel, "ban") + + t.Log("Alice sends a second sentinel in Nigel's DM.") + dmSentinel2 := alice.SendEventSynced(t, dm, b.Event{ + Type: "m.room.message", + Content: map[string]interface{}{"body": "sentinel 2 placeholder boogaloo", "msgtype": "m.text"}, + }) + + t.Log("Nigel syncs until he sees the second sentinel. He should NOT see his ban event.") + + nigelRes = nigel.SlidingSyncUntil(t, nigelRes.Pos, sync3.Request{}, func(response *sync3.Response) error { + seenPublicRoom := m.MatchRoomSubscription(public) + if seenPublicRoom(response) == nil { + t.Errorf("Nigel had a room subscription for the public room, but shouldn't have.") + m.LogResponse(t)(response) + t.FailNow() + } + seenSentinel := m.MatchRoomSubscription(dm, MatchRoomTimelineMostRecent(1, []Event{{ID: dmSentinel2}})) + return seenSentinel(response) + }) +}