Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sliding Sync: Slight optimization when fetching state for the room (get_events_as_list(...)) #17718

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
1 change: 1 addition & 0 deletions changelog.d/17718.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Slight optimization when fetching state/events for Sliding Sync.
8 changes: 3 additions & 5 deletions synapse/handlers/sliding_sync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,11 @@ async def get_current_state_at(
to_token=to_token,
)

event_map = await self.store.get_events(list(state_ids.values()))
events = await self.store.get_events_as_list(list(state_ids.values()))

state_map = {}
for key, event_id in state_ids.items():
event = event_map.get(event_id)
if event:
state_map[key] = event
for event in events:
state_map[(event.type, event.state_key)] = event
Comment on lines +446 to +450
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One slight optimization: use `get_events_as_list(...)` directly instead of `get_events(...)`. `get_events(...)` just turns the result from `get_events_as_list(...)` into a dict, and since we're only iterating over the events, we don't need the dict/map.


return state_map

Expand Down
41 changes: 37 additions & 4 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,13 @@
current_context,
make_deferred_yieldable,
)
from synapse.logging.opentracing import start_active_span, tag_args, trace
from synapse.logging.opentracing import (
SynapseTags,
set_tag,
start_active_span,
tag_args,
trace,
)
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
Expand Down Expand Up @@ -525,6 +531,7 @@ async def get_event(

return event

@trace
async def get_events(
self,
event_ids: Collection[str],
Expand Down Expand Up @@ -556,6 +563,11 @@ async def get_events(
Returns:
A mapping from event_id to event.
"""
set_tag(
SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
str(len(event_ids)),
)

events = await self.get_events_as_list(
event_ids,
redact_behaviour=redact_behaviour,
Expand Down Expand Up @@ -603,6 +615,10 @@ async def get_events_as_list(
Note that the returned list may be smaller than the list of event
IDs if not all events could be fetched.
"""
set_tag(
SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
str(len(event_ids)),
)

if not event_ids:
return []
Expand Down Expand Up @@ -723,10 +739,11 @@ async def get_events_as_list(

return events

@trace
@cancellable
async def get_unredacted_events_from_cache_or_db(
self,
event_ids: Iterable[str],
event_ids: Collection[str],
Copy link
Collaborator Author

@MadLittleMods MadLittleMods Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to `Collection` because we iterate over `event_ids` multiple times (this was already the case even before I added the tracing tag).

allow_rejected: bool = False,
) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
Expand All @@ -748,6 +765,11 @@ async def get_unredacted_events_from_cache_or_db(
Returns:
map from event id to result
"""
set_tag(
SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
str(len(event_ids)),
)

# Shortcut: check if we have any events in the *in memory* cache - this function
# may be called repeatedly for the same event so at this point we cannot reach
# out to any external cache for performance reasons. The external cache is
Expand Down Expand Up @@ -936,7 +958,7 @@ async def _get_events_from_cache(
events, update_metrics=update_metrics
)

missing_event_ids = (e for e in events if e not in event_map)
missing_event_ids = [e for e in events if e not in event_map]
event_map.update(
await self._get_events_from_external_cache(
events=missing_event_ids,
Expand All @@ -946,8 +968,9 @@ async def _get_events_from_cache(

return event_map

@trace
async def _get_events_from_external_cache(
self, events: Iterable[str], update_metrics: bool = True
self, events: Collection[str], update_metrics: bool = True
Copy link
Collaborator Author

@MadLittleMods MadLittleMods Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one really was just an `Iterable` before (it was only iterated once), but it has been updated to `Collection` now that we also read its length for the tracing tag.

) -> Dict[str, EventCacheEntry]:
"""Fetch events from any configured external cache.

Expand All @@ -957,6 +980,10 @@ async def _get_events_from_external_cache(
events: list of event_ids to fetch
update_metrics: Whether to update the cache hit ratio metrics
"""
set_tag(
SynapseTags.FUNC_ARG_PREFIX + "events.length",
str(len(events)),
)
event_map = {}

for event_id in events:
Expand Down Expand Up @@ -1222,6 +1249,7 @@ def fire_errback(exc: Exception) -> None:
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire_errback, e)

@trace
async def _get_events_from_db(
self, event_ids: Collection[str]
) -> Dict[str, EventCacheEntry]:
Expand All @@ -1240,6 +1268,11 @@ async def _get_events_from_db(
map from event id to result. May return extra events which
weren't asked for.
"""
set_tag(
SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
str(len(event_ids)),
)

fetched_event_ids: Set[str] = set()
fetched_events: Dict[str, _EventRow] = {}

Expand Down
49 changes: 48 additions & 1 deletion tests/storage/databases/main/test_events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#
import json
from contextlib import contextmanager
from typing import Generator, List, Tuple
from typing import Generator, List, Set, Tuple
from unittest import mock

from twisted.enterprise.adbapi import ConnectionPool
Expand Down Expand Up @@ -295,6 +295,53 @@ def test_dedupe(self) -> None:
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)


class GetEventsTestCase(unittest.HomeserverTestCase):
"""Test `get_events(...)`/`get_events_as_list(...)`"""

servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main

def test_get_lots_of_messages(self) -> None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New test, added because I used it as a test bench to get consistent traces (a previous iteration of this PR included tracing). The tracing-specific parts have since been removed.

Feel free to nit to remove the whole thing.

"""Sanity check that `get_events(...)`/`get_events_as_list(...)` works"""
num_events = 100

user_id = self.register_user("user", "pass")
user_tok = self.login(user_id, "pass")

room_id = self.helper.create_room_as(user_id, tok=user_tok)

event_ids: Set[str] = set()
for i in range(num_events):
event = self.get_success(
inject_event(
self.hs,
room_id=room_id,
type="m.room.message",
sender=user_id,
content={
"body": f"foo{i}",
"msgtype": "m.text",
},
)
)
event_ids.add(event.event_id)

# Sanity check that we actually created the events
self.assertEqual(len(event_ids), num_events)

# This is the function under test
fetched_event_map = self.get_success(self.store.get_events(event_ids))

# Sanity check that we got the events back
self.assertIncludes(fetched_event_map.keys(), event_ids, exact=True)


class DatabaseOutageTestCase(unittest.HomeserverTestCase):
"""Test event fetching during a database outage."""

Expand Down
Loading