diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1c630ea2..09e2526b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -59,7 +59,7 @@ jobs: - name: Test run: | set -euo pipefail - go test -count=1 -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt + go test -count=1 -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt -hide all shell: bash env: POSTGRES_HOST: localhost @@ -144,7 +144,7 @@ jobs: - name: Run end-to-end tests run: | set -euo pipefail - ./run-tests.sh -count=1 -v -json . 2>&1 | tee test-e2e-runner.log | gotestfmt + ./run-tests.sh -count=1 -v -json . 2>&1 | tee test-e2e-runner.log | gotestfmt -hide all working-directory: tests-e2e shell: bash env: diff --git a/cmd/syncv3/main.go b/cmd/syncv3/main.go index a43fa3a1..66b19b3d 100644 --- a/cmd/syncv3/main.go +++ b/cmd/syncv3/main.go @@ -173,9 +173,10 @@ func main() { panic("invalid value for " + EnvMaxConns + ": " + args[EnvMaxConns]) } h2, h3 := syncv3.Setup(args[EnvServer], args[EnvDB], args[EnvSecret], syncv3.Opts{ - AddPrometheusMetrics: args[EnvPrometheus] != "", - DBMaxConns: maxConnsInt, - DBConnMaxIdleTime: time.Hour, + AddPrometheusMetrics: args[EnvPrometheus] != "", + DBMaxConns: maxConnsInt, + DBConnMaxIdleTime: time.Hour, + MaxTransactionIDDelay: time.Second, }) go h2.StartV2Pollers() diff --git a/pubsub/v2.go b/pubsub/v2.go index 7dfb01e0..6b85d202 100644 --- a/pubsub/v2.go +++ b/pubsub/v2.go @@ -41,12 +41,15 @@ type V2Accumulate struct { func (*V2Accumulate) Type() string { return "V2Accumulate" } -// V2TransactionID is emitted by a poller when it sees an event with a transaction ID. +// V2TransactionID is emitted by a poller when it sees an event with a transaction ID, +// or when it is certain that no other poller will see a transaction ID for this event +// (the "all-clear"). type V2TransactionID struct { EventID string - UserID string + RoomID string + UserID string // of the sender DeviceID string - TransactionID string + TransactionID string // Note: an empty transaction ID represents the all-clear. NID int64 } diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index 15d8037e..6a9b866a 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -3,7 +3,6 @@ package handler2 import ( "context" "encoding/json" - "fmt" "hash/fnv" "os" "sync" @@ -240,15 +239,24 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev // Remember any transaction IDs that may be unique to this user eventIDsWithTxns := make([]string, 0, len(timeline)) // in timeline order eventIDToTxnID := make(map[string]string, len(timeline)) // event_id -> txn_id + // Also remember events which were sent by this user but lack a transaction ID. 
+ eventIDsLackingTxns := make([]string, 0, len(timeline)) + for _, e := range timeline { - txnID := gjson.GetBytes(e, "unsigned.transaction_id") - if !txnID.Exists() { + parsed := gjson.ParseBytes(e) + eventID := parsed.Get("event_id").Str + + if txnID := parsed.Get("unsigned.transaction_id"); txnID.Exists() { + eventIDsWithTxns = append(eventIDsWithTxns, eventID) + eventIDToTxnID[eventID] = txnID.Str continue } - eventID := gjson.GetBytes(e, "event_id").Str - eventIDsWithTxns = append(eventIDsWithTxns, eventID) - eventIDToTxnID[eventID] = txnID.Str + + if sender := parsed.Get("sender"); sender.Str == userID { + eventIDsLackingTxns = append(eventIDsLackingTxns, eventID) + } } + if len(eventIDToTxnID) > 0 { // persist the txn IDs err := h.Store.TransactionsTable.Insert(userID, deviceID, eventIDToTxnID) @@ -265,56 +273,63 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) return } - if numNew == 0 { - // no new events - return + + // We've updated the database. Now tell any pubsub listeners what we learned. + if numNew != 0 { + h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{ + RoomID: roomID, + PrevBatch: prevBatch, + EventNIDs: latestNIDs, + }) } - h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{ - RoomID: roomID, - PrevBatch: prevBatch, - EventNIDs: latestNIDs, - }) - if len(eventIDToTxnID) > 0 { + if len(eventIDToTxnID) > 0 || len(eventIDsLackingTxns) > 0 { // The call to h.Store.Accumulate above only tells us about new events' NIDS; // for existing events we need to requery the database to fetch them. // Rather than try to reuse work, keep things simple and just fetch NIDs for // all events with txnIDs. var nidsByIDs map[string]int64 + eventIDsToFetch := append(eventIDsWithTxns, eventIDsLackingTxns...) err = sqlutil.WithTransaction(h.Store.DB, func(txn *sqlx.Tx) error { - nidsByIDs, err = h.Store.EventsTable.SelectNIDsByIDs(txn, eventIDsWithTxns) + nidsByIDs, err = h.Store.EventsTable.SelectNIDsByIDs(txn, eventIDsToFetch) return err }) if err != nil { logger.Err(err). Int("timeline", len(timeline)). Int("num_transaction_ids", len(eventIDsWithTxns)). + Int("num_missing_transaction_ids", len(eventIDsLackingTxns)). Str("room", roomID). 
- Msg("V2: failed to fetch nids for events with transaction_ids") + Msg("V2: failed to fetch nids for event transaction_id handling") internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) return } - for _, eventID := range eventIDsWithTxns { + for eventID, nid := range nidsByIDs { txnID, ok := eventIDToTxnID[eventID] - if !ok { - continue - } - nid, ok := nidsByIDs[eventID] - if !ok { - errMsg := "V2: failed to fetch NID for txnID" - logger.Error().Str("user", userID).Str("device", deviceID).Msg(errMsg) - internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(fmt.Errorf("errMsg")) - continue + if ok { + h.pMap.SeenTxnID(eventID) + h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2TransactionID{ + EventID: eventID, + RoomID: roomID, + UserID: userID, + DeviceID: deviceID, + TransactionID: txnID, + NID: nid, + }) + } else { + allClear, _ := h.pMap.MissingTxnID(eventID, userID, deviceID) + if allClear { + h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2TransactionID{ + EventID: eventID, + RoomID: roomID, + UserID: userID, + DeviceID: deviceID, + TransactionID: "", + NID: nid, + }) + } } - - h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2TransactionID{ - EventID: eventID, - UserID: userID, - DeviceID: deviceID, - TransactionID: txnID, - NID: nid, - }) } } } diff --git a/sync2/handler2/handler_test.go b/sync2/handler2/handler_test.go index fa315228..3b69f07b 100644 --- a/sync2/handler2/handler_test.go +++ b/sync2/handler2/handler_test.go @@ -42,6 +42,14 @@ func (p *mockPollerMap) NumPollers() int { } func (p *mockPollerMap) Terminate() {} +func (p *mockPollerMap) MissingTxnID(eventID, userID, deviceID string) (bool, error) { + return false, nil +} + +func (p *mockPollerMap) SeenTxnID(eventID string) error { + return nil +} + func (p *mockPollerMap) EnsurePolling(pid sync2.PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) { p.calls = append(p.calls, pollInfo{ pid: pid, diff --git a/sync2/poller.go b/sync2/poller.go index 062366a6..77e7767f 100644 --- a/sync2/poller.go +++ b/sync2/poller.go @@ -64,6 +64,8 @@ type IPollerMap interface { EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) NumPollers() int Terminate() + MissingTxnID(eventID, userID, deviceID string) (bool, error) + SeenTxnID(eventID string) error } // PollerMap is a map of device ID to Poller @@ -72,6 +74,7 @@ type PollerMap struct { callbacks V2DataReceiver pollerMu *sync.Mutex Pollers map[PollerID]*poller + pendingTxnIDs *PendingTransactionIDs executor chan func() executorRunning bool processHistogramVec *prometheus.HistogramVec @@ -112,6 +115,7 @@ func NewPollerMap(v2Client Client, enablePrometheus bool) *PollerMap { Pollers: make(map[PollerID]*poller), executor: make(chan func(), 0), } + pm.pendingTxnIDs = NewPendingTransactionIDs(pm.deviceIDs) if enablePrometheus { pm.processHistogramVec = prometheus.NewHistogramVec(prometheus.HistogramOpts{ Namespace: "sliding_sync", @@ -195,6 +199,28 @@ func (h *PollerMap) NumPollers() (count int) { return } +// deviceIDs returns the slice of all devices currently being polled for by this user. +// The return value is brand-new and is fully owned by the caller. 
+func (h *PollerMap) deviceIDs(userID string) []string {
+	h.pollerMu.Lock()
+	defer h.pollerMu.Unlock()
+	var devices []string
+	for _, p := range h.Pollers {
+		if !p.terminated.Load() && p.userID == userID {
+			devices = append(devices, p.deviceID)
+		}
+	}
+	return devices
+}
+
+func (h *PollerMap) MissingTxnID(eventID, userID, deviceID string) (bool, error) {
+	return h.pendingTxnIDs.MissingTxnID(eventID, userID, deviceID)
+}
+
+func (h *PollerMap) SeenTxnID(eventID string) error {
+	return h.pendingTxnIDs.SeenTxnID(eventID)
+}
+
 // EnsurePolling makes sure there is a poller for this device, making one if need be.
 // Blocks until at least 1 sync is done if and only if the poller was just created.
 // This ensures that calls to the database will return data.
diff --git a/sync2/txnid.go b/sync2/txnid.go
index 00e0ede9..79e7c006 100644
--- a/sync2/txnid.go
+++ b/sync2/txnid.go
@@ -1,38 +1,121 @@
 package sync2
 
 import (
+	"fmt"
+	"sync"
 	"time"
 
 	"github.com/ReneKroon/ttlcache/v2"
 )
 
-type TransactionIDCache struct {
-	cache *ttlcache.Cache
+type loaderFunc func(userID string) (deviceIDs []string)
+
+// PendingTransactionIDs is (conceptually) a map from event IDs to a list of device IDs.
+// Its keys are the IDs of events we've seen which a) lack a transaction ID, and b) were
+// sent by one of the users we are polling for. The values are the list of the sender's
+// devices whose pollers are yet to see a transaction ID.
+//
+// If another poller sees the same event
+//
+//   - with a transaction ID, it emits a V2TransactionID payload with that ID and
+//     removes the event ID from this map.
+//
+//   - without a transaction ID, it removes the polling device ID from the values
+//     list. If the device ID list is now empty, the poller emits an "all clear"
+//     V2TransactionID payload.
+//
+// This is a best-effort affair to ensure that the rest of the proxy can wait for
+// transaction IDs to appear before transmitting an event down /sync to its sender.
+//
+// It's possible that we add an entry to this map and then the list of remaining
+// device IDs becomes out of date, either due to a new device creation or an
+// existing device expiring. We choose not to handle this case, because it is relatively
+// rare.
+//
+// To avoid the map growing without bound, we use a ttlcache and drop entries
+// after a short period of time.
+type PendingTransactionIDs struct {
+	// mu guards the pending field. See MissingTxnID for rationale.
+	mu      sync.Mutex
+	pending *ttlcache.Cache
+	// loader should provide the list of device IDs currently being polled for this user.
+	loader loaderFunc
 }
 
-func NewTransactionIDCache() *TransactionIDCache {
+func NewPendingTransactionIDs(loader loaderFunc) *PendingTransactionIDs {
 	c := ttlcache.NewCache()
 	c.SetTTL(5 * time.Minute)     // keep transaction IDs for 5 minutes before forgetting about them
 	c.SkipTTLExtensionOnHit(true) // we don't care how many times they ask for the item, 5min is the limit.
-	return &TransactionIDCache{
-		cache: c,
+	return &PendingTransactionIDs{
+		mu:      sync.Mutex{},
+		pending: c,
+		loader:  loader,
 	}
 }
 
-// Store a new transaction ID received via v2 /sync
-func (c *TransactionIDCache) Store(userID, eventID, txnID string) {
-	c.cache.Set(cacheKey(userID, eventID), txnID)
-}
+// MissingTxnID should be called to report that this device ID did not see a
+// transaction ID for this event ID. Returns true if this is the first time we know
+// for sure that we'll never see a txn ID for this event.
+func (c *PendingTransactionIDs) MissingTxnID(eventID, userID, myDeviceID string) (bool, error) { + // While ttlcache is threadsafe, it does not provide a way to atomically update + // (get+set) a value, which means we are still open to races. For example: + // + // - We have three pollers A, B, C. + // - Poller A sees an event without txn id and calls MissingTxnID. + // - `c.pending.Get()` fails, so we load up all device IDs: [A, B, C]. + // - Then `c.pending.Set()` with [B, C]. + // - Poller B sees the same event, also missing txn ID and calls MissingTxnID. + // - Poller C does the same concurrently. + // + // If the Get+Set isn't atomic, then we might do e.g. + // - B gets [B, C] and prepares to write [C]. + // - C gets [B, C] and prepares to write [B]. + // - Last writer wins. Either way, we never write [] and so never return true + // (the all-clear signal.) + // + // This wouldn't be the end of the world (the API process has a maximum delay, and + // the ttlcache will expire the entry), but it would still be nice to avoid it. + c.mu.Lock() + defer c.mu.Unlock() -// Get a transaction ID previously stored. -func (c *TransactionIDCache) Get(userID, eventID string) string { - val, _ := c.cache.Get(cacheKey(userID, eventID)) - if val != nil { - return val.(string) + data, err := c.pending.Get(eventID) + if err == ttlcache.ErrNotFound { + data = c.loader(userID) + } else if err != nil { + return false, fmt.Errorf("PendingTransactionIDs: failed to get device ids: %w", err) } - return "" + + deviceIDs, ok := data.([]string) + if !ok { + return false, fmt.Errorf("PendingTransactionIDs: failed to cast device IDs") + } + + deviceIDs, changed := removeDevice(myDeviceID, deviceIDs) + if changed { + err = c.pending.Set(eventID, deviceIDs) + if err != nil { + return false, fmt.Errorf("PendingTransactionIDs: failed to set device IDs: %w", err) + } + } + return changed && len(deviceIDs) == 0, nil +} + +// SeenTxnID should be called to report that this device saw a transaction ID +// for this event. +func (c *PendingTransactionIDs) SeenTxnID(eventID string) error { + c.mu.Lock() + defer c.mu.Unlock() + return c.pending.Set(eventID, []string{}) } -func cacheKey(userID, eventID string) string { - return userID + " " + eventID +// removeDevice takes a device ID slice and returns a device ID slice with one +// particular string removed. Assumes that the given slice has no duplicates. +// Does not modify the given slice in situ. 
+func removeDevice(device string, devices []string) ([]string, bool) {
+	for i, otherDevice := range devices {
+		if otherDevice == device {
+			// Take a copy rather than append(devices[:i], devices[i+1:]...),
+			// which would shuffle the caller's backing array in situ.
+			removed := make([]string, 0, len(devices)-1)
+			removed = append(removed, devices[:i]...)
+			removed = append(removed, devices[i+1:]...)
+			return removed, true
+		}
+	}
+	return devices, false
 }
diff --git a/sync2/txnid_test.go b/sync2/txnid_test.go
index 72d1967e..f9b15990 100644
--- a/sync2/txnid_test.go
+++ b/sync2/txnid_test.go
@@ -2,54 +2,107 @@ package sync2
 
 import "testing"
 
-func TestTransactionIDCache(t *testing.T) {
-	alice := "@alice:localhost"
-	bob := "@bob:localhost"
-	eventA := "$a:localhost"
-	eventB := "$b:localhost"
-	eventC := "$c:localhost"
-	txn1 := "1"
-	txn2 := "2"
-	cache := NewTransactionIDCache()
-	cache.Store(alice, eventA, txn1)
-	cache.Store(bob, eventB, txn1) // different users can use same txn ID
-	cache.Store(alice, eventC, txn2)
-
-	testCases := []struct {
-		eventID string
-		userID  string
-		want    string
-	}{
-		{
-			eventID: eventA,
-			userID:  alice,
-			want:    txn1,
-		},
-		{
-			eventID: eventB,
-			userID:  bob,
-			want:    txn1,
-		},
-		{
-			eventID: eventC,
-			userID:  alice,
-			want:    txn2,
-		},
-		{
-			eventID: "$invalid",
-			userID:  alice,
-			want:    "",
-		},
-		{
-			eventID: eventA,
-			userID:  "@invalid",
-			want:    "",
-		},
+func TestPendingTransactionIDs(t *testing.T) {
+	pollingDevicesByUser := map[string][]string{
+		"alice": {"A1", "A2"},
+		"bob":   {"B1"},
+		"chris": {},
+		"delia": {"D1", "D2", "D3", "D4"},
+		"enid":  {"E1", "E2"},
 	}
-	for _, tc := range testCases {
-		txnID := cache.Get(tc.userID, tc.eventID)
-		if txnID != tc.want {
-			t.Errorf("%+v: got %v want %v", tc, txnID, tc.want)
+	mockLoad := func(userID string) (deviceIDs []string) {
+		devices, ok := pollingDevicesByUser[userID]
+		if !ok {
+			t.Fatalf("Mock didn't have devices for %s", userID)
 		}
+		newDevices := make([]string, len(devices))
+		copy(newDevices, devices)
+		return newDevices
+	}
+
+	pending := NewPendingTransactionIDs(mockLoad)
+
+	// Alice.
+	// We're tracking two of Alice's devices.
+	allClear, err := pending.MissingTxnID("event1", "alice", "A1")
+	assertNoError(t, err)
+	assertAllClear(t, allClear, false) // waiting on A2
+
+	// If for some reason the poller sees the same event for the same device, we should
+	// still be waiting for A2.
+	allClear, err = pending.MissingTxnID("event1", "alice", "A1")
+	assertNoError(t, err)
+	assertAllClear(t, allClear, false)
+
+	// If for some reason Alice spun up a new device, we are still going to be waiting
+	// for A2.
+	allClear, err = pending.MissingTxnID("event1", "alice", "A_unknown_device")
+	assertNoError(t, err)
+	assertAllClear(t, allClear, false)
+
+	// If A2 sees the event without a txnID, we should emit the all clear signal.
+	allClear, err = pending.MissingTxnID("event1", "alice", "A2")
+	assertNoError(t, err)
+	assertAllClear(t, allClear, true)
+
+	// If for some reason A2 sees the event a second time, we shouldn't re-emit the
+	// all clear signal.
+	allClear, err = pending.MissingTxnID("event1", "alice", "A2")
+	assertNoError(t, err)
+	assertAllClear(t, allClear, false)
+
+	// Bob.
+	// We're only tracking one device for Bob.
+	allClear, err = pending.MissingTxnID("event2", "bob", "B1")
+	assertNoError(t, err)
+	assertAllClear(t, allClear, true) // not waiting on any devices
+
+	// Chris.
+	// We're not tracking any devices for Chris. A MissingTxnID call for him shouldn't
+	// cause anything to explode.
+	allClear, err = pending.MissingTxnID("event3", "chris", "C_unknown_device")
+	assertNoError(t, err)
+	assertAllClear(t, allClear, false)
+
+	// Delia.
+	// Delia is tracking four devices.
+	allClear, err = pending.MissingTxnID("event4", "delia", "D1")
+	assertNoError(t, err)
+	assertAllClear(t, allClear, false) // waiting on D2, D3 and D4
+
+	// One of Delia's devices, say D2, sees a txn ID for event 4.
+	err = pending.SeenTxnID("event4")
+	assertNoError(t, err)
+
+	// The other devices see the event. Neither should emit all clear.
+	allClear, err = pending.MissingTxnID("event4", "delia", "D3")
+	assertNoError(t, err)
+	assertAllClear(t, allClear, false)
+
+	allClear, err = pending.MissingTxnID("event4", "delia", "D4")
+	assertNoError(t, err)
+	assertAllClear(t, allClear, false)
+
+	// Enid.
+	// Enid has two devices. Her first poller (E1) is lucky and sees the transaction ID.
+	err = pending.SeenTxnID("event5")
+	assertNoError(t, err)
+
+	// Her second poller misses the transaction ID, but this shouldn't cause an all clear.
+	allClear, err = pending.MissingTxnID("event5", "enid", "E2")
+	assertNoError(t, err)
+	assertAllClear(t, allClear, false)
+}
+
+func assertAllClear(t *testing.T, got bool, want bool) {
+	t.Helper()
+	if got != want {
+		t.Errorf("Expected allClear=%t, got %t", want, got)
+	}
+}
+
+func assertNoError(t *testing.T, err error) {
+	t.Helper()
+	if err != nil {
+		t.Fatalf("got error: %s", err)
 	}
 }
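Reviewer's note: to make the all-clear contract concrete, here is a minimal, hedged sketch of driving `PendingTransactionIDs` from outside the test suite. The user and device names are invented; only `NewPendingTransactionIDs`, `MissingTxnID` and `SeenTxnID` come from this patch, and the module path is assumed from the imports elsewhere in the diff.

```go
package main

import (
	"fmt"

	"github.com/matrix-org/sliding-sync/sync2"
)

func main() {
	// Loader reporting that @alice:example.org has two devices being polled.
	pending := sync2.NewPendingTransactionIDs(func(userID string) []string {
		return []string{"DEVICE_1", "DEVICE_2"}
	})

	// DEVICE_1's poller saw $event1 without a transaction ID. We are still
	// waiting on DEVICE_2, so this is not yet the all-clear.
	allClear, err := pending.MissingTxnID("$event1", "@alice:example.org", "DEVICE_1")
	fmt.Println(allClear, err) // false <nil>

	// DEVICE_2's poller also misses it: every poller has now reported in, so
	// the caller should emit the all-clear V2TransactionID payload.
	allClear, err = pending.MissingTxnID("$event1", "@alice:example.org", "DEVICE_2")
	fmt.Println(allClear, err) // true <nil>

	// Had DEVICE_2 seen the event with a transaction ID instead, it would call
	// SeenTxnID, which empties the pending device list for that event.
	_ = pending.SeenTxnID("$event1")
}
```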
diff --git a/sync3/caches/global.go b/sync3/caches/global.go
index 10d5c342..fd6d80f6 100644
--- a/sync3/caches/global.go
+++ b/sync3/caches/global.go
@@ -21,6 +21,15 @@ type EventData struct {
 	Content   gjson.Result
 	Timestamp uint64
 	Sender    string
+	// TransactionID is the unsigned.transaction_id field in the event as stored in the
+	// syncv3_events table, or the empty string if there is no such field.
+	//
+	// We may see the event on poller A without a transaction_id, and then later on
+	// poller B with a transaction_id. If this happens, we make a temporary note of the
+	// transaction_id in the syncv3_txns table, but do not edit the persisted event.
+	// This means that this field is not authoritative; we only include it here as a
+	// hint to avoid unnecessary waits for V2TransactionID payloads.
+	TransactionID string
 
 	// the number of joined users in this room. Use this value and don't try to work it out as you
 	// may get it wrong due to Synapse sending duplicate join events(!) This value has them de-duped
diff --git a/sync3/conn.go b/sync3/conn.go
index 33fd87bb..d9d3682a 100644
--- a/sync3/conn.go
+++ b/sync3/conn.go
@@ -32,6 +32,7 @@ type ConnHandler interface {
 	// status code to send back.
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool, start time.Time) (*Response, error) OnUpdate(ctx context.Context, update caches.Update) + PublishEventsUpTo(roomID string, nid int64) Destroy() Alive() bool } diff --git a/sync3/conn_test.go b/sync3/conn_test.go index c326938c..8b148ea6 100644 --- a/sync3/conn_test.go +++ b/sync3/conn_test.go @@ -25,6 +25,7 @@ func (c *connHandlerMock) UserID() string { func (c *connHandlerMock) Destroy() {} func (c *connHandlerMock) Alive() bool { return true } func (c *connHandlerMock) OnUpdate(ctx context.Context, update caches.Update) {} +func (c *connHandlerMock) PublishEventsUpTo(roomID string, nid int64) {} // Test that Conn can send and receive requests based on positions func TestConn(t *testing.T) { diff --git a/sync3/connmap.go b/sync3/connmap.go index 0cae2748..8d08dd27 100644 --- a/sync3/connmap.go +++ b/sync3/connmap.go @@ -208,3 +208,12 @@ func (m *ConnMap) closeConn(conn *Conn) { h.Destroy() m.updateMetrics(len(m.connIDToConn)) } + +func (m *ConnMap) ClearUpdateQueues(userID, roomID string, nid int64) { + m.mu.Lock() + defer m.mu.Unlock() + + for _, conn := range m.userIDToConn[userID] { + conn.handler.PublishEventsUpTo(roomID, nid) + } +} diff --git a/sync3/dispatcher.go b/sync3/dispatcher.go index 3d80ff9a..12ddfeb9 100644 --- a/sync3/dispatcher.go +++ b/sync3/dispatcher.go @@ -87,14 +87,15 @@ func (d *Dispatcher) newEventData(event json.RawMessage, roomID string, latestPo eventType := ev.Get("type").Str return &caches.EventData{ - Event: event, - RoomID: roomID, - EventType: eventType, - StateKey: stateKey, - Content: ev.Get("content"), - NID: latestPos, - Timestamp: ev.Get("origin_server_ts").Uint(), - Sender: ev.Get("sender").Str, + Event: event, + RoomID: roomID, + EventType: eventType, + StateKey: stateKey, + Content: ev.Get("content"), + NID: latestPos, + Timestamp: ev.Get("origin_server_ts").Uint(), + Sender: ev.Get("sender").Str, + TransactionID: ev.Get("unsigned.transaction_id").Str, } } diff --git a/sync3/handler/connstate.go b/sync3/handler/connstate.go index 1611e2a7..c6a39378 100644 --- a/sync3/handler/connstate.go +++ b/sync3/handler/connstate.go @@ -42,7 +42,8 @@ type ConnState struct { // roomID -> latest load pos loadPositions map[string]int64 - live *connStateLive + txnIDWaiter *TxnIDWaiter + live *connStateLive globalCache *caches.GlobalCache userCache *caches.UserCache @@ -59,7 +60,7 @@ type ConnState struct { func NewConnState( userID, deviceID string, userCache *caches.UserCache, globalCache *caches.GlobalCache, ex extensions.HandlerInterface, joinChecker JoinChecker, setupHistVec *prometheus.HistogramVec, histVec *prometheus.HistogramVec, - maxPendingEventUpdates int, + maxPendingEventUpdates int, maxTransactionIDDelay time.Duration, ) *ConnState { cs := &ConnState{ globalCache: globalCache, @@ -80,6 +81,13 @@ func NewConnState( ConnState: cs, updates: make(chan caches.Update, maxPendingEventUpdates), } + cs.txnIDWaiter = NewTxnIDWaiter( + userID, + maxTransactionIDDelay, + func(delayed bool, update caches.Update) { + cs.live.onUpdate(update) + }, + ) // subscribe for updates before loading. We risk seeing dupes but that's fine as load positions // will stop us double-processing. 
cs.userCacheID = cs.userCache.Subsribe(cs) @@ -663,7 +671,8 @@ func (s *ConnState) UserID() string { } func (s *ConnState) OnUpdate(ctx context.Context, up caches.Update) { - s.live.onUpdate(up) + // will eventually call s.live.onUpdate + s.txnIDWaiter.Ingest(up) } // Called by the user cache when updates arrive @@ -679,15 +688,19 @@ func (s *ConnState) OnRoomUpdate(ctx context.Context, up caches.RoomUpdate) { } internal.AssertWithContext(ctx, "missing global room metadata", update.GlobalRoomMetadata() != nil) internal.Logf(ctx, "connstate", "queued update %d", update.EventData.NID) - s.live.onUpdate(update) + s.OnUpdate(ctx, update) case caches.RoomUpdate: internal.AssertWithContext(ctx, "missing global room metadata", update.GlobalRoomMetadata() != nil) - s.live.onUpdate(update) + s.OnUpdate(ctx, update) default: logger.Warn().Str("room_id", up.RoomID()).Msg("OnRoomUpdate unknown update type") } } +func (s *ConnState) PublishEventsUpTo(roomID string, nid int64) { + s.txnIDWaiter.PublishUpToNID(roomID, nid) +} + // clampSliceRangeToListSize helps us to send client-friendly SYNC and INVALIDATE ranges. // // Suppose the client asks for a window on positions [10, 19]. If the list diff --git a/sync3/handler/connstate_test.go b/sync3/handler/connstate_test.go index 17700ea8..46c00d6c 100644 --- a/sync3/handler/connstate_test.go +++ b/sync3/handler/connstate_test.go @@ -107,7 +107,7 @@ func TestConnStateInitial(t *testing.T) { } return result } - cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000) + cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000, 0) if userID != cs.UserID() { t.Fatalf("UserID returned wrong value, got %v want %v", cs.UserID(), userID) } @@ -272,7 +272,7 @@ func TestConnStateMultipleRanges(t *testing.T) { userCache.LazyRoomDataOverride = mockLazyRoomOverride dispatcher.Register(context.Background(), userCache.UserID, userCache) dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, globalCache) - cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000) + cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000, 0) // request first page res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{ @@ -451,7 +451,7 @@ func TestBumpToOutsideRange(t *testing.T) { userCache.LazyRoomDataOverride = mockLazyRoomOverride dispatcher.Register(context.Background(), userCache.UserID, userCache) dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, globalCache) - cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000) + cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000, 0) // Ask for A,B res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{ Lists: map[string]sync3.RequestList{"a": { @@ -562,7 +562,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) { } dispatcher.Register(context.Background(), userCache.UserID, userCache) dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, globalCache) - cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000) + cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 
nil, 1000, 0) // subscribe to room D res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{ RoomSubscriptions: map[string]sync3.RoomSubscription{ diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go index c02c5a29..a8a56c0a 100644 --- a/sync3/handler/handler.go +++ b/sync3/handler/handler.go @@ -58,6 +58,7 @@ type SyncLiveHandler struct { GlobalCache *caches.GlobalCache maxPendingEventUpdates int + maxTransactionIDDelay time.Duration setupHistVec *prometheus.HistogramVec histVec *prometheus.HistogramVec @@ -67,6 +68,7 @@ type SyncLiveHandler struct { func NewSync3Handler( store *state.Storage, storev2 *sync2.Storage, v2Client sync2.Client, secret string, pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, maxPendingEventUpdates int, + maxTransactionIDDelay time.Duration, ) (*SyncLiveHandler, error) { logger.Info().Msg("creating handler") sh := &SyncLiveHandler{ @@ -78,6 +80,7 @@ func NewSync3Handler( Dispatcher: sync3.NewDispatcher(), GlobalCache: caches.NewGlobalCache(store), maxPendingEventUpdates: maxPendingEventUpdates, + maxTransactionIDDelay: maxTransactionIDDelay, } sh.Extensions = &extensions.Handler{ Store: store, @@ -411,7 +414,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ // to check for an existing connection though, as it's possible for the client to call /sync // twice for a new connection. conn, created := h.ConnMap.CreateConn(connID, func() sync3.ConnHandler { - return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.setupHistVec, h.histVec, h.maxPendingEventUpdates) + return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.setupHistVec, h.histVec, h.maxPendingEventUpdates, h.maxTransactionIDDelay) }) if created { log.Info().Msg("created new connection") @@ -622,7 +625,11 @@ func (h *SyncLiveHandler) Accumulate(p *pubsub.V2Accumulate) { func (h *SyncLiveHandler) OnTransactionID(p *pubsub.V2TransactionID) { _, task := internal.StartTask(context.Background(), "TransactionID") defer task.End() - // TODO implement me + + // There is some event E for which we now have a transaction ID, or else now know + // that we will never get a transaction ID. In either case, tell the sender's + // connections to unblock that event in the transaction ID waiter. + h.ConnMap.ClearUpdateQueues(p.UserID, p.RoomID, p.NID) } // Called from the v2 poller, implements V2DataReceiver diff --git a/sync3/handler/txn_id_waiter.go b/sync3/handler/txn_id_waiter.go new file mode 100644 index 00000000..36986344 --- /dev/null +++ b/sync3/handler/txn_id_waiter.go @@ -0,0 +1,94 @@ +package handler + +import ( + "github.com/matrix-org/sliding-sync/sync3/caches" + "sync" + "time" +) + +type TxnIDWaiter struct { + userID string + publish func(delayed bool, update caches.Update) + // mu guards the queues map. + mu sync.Mutex + queues map[string][]*caches.RoomEventUpdate + maxDelay time.Duration +} + +func NewTxnIDWaiter(userID string, maxDelay time.Duration, publish func(bool, caches.Update)) *TxnIDWaiter { + return &TxnIDWaiter{ + userID: userID, + publish: publish, + mu: sync.Mutex{}, + queues: make(map[string][]*caches.RoomEventUpdate), + maxDelay: maxDelay, + // TODO: metric that tracks how long events were queued for. 
+	}
+}
+
+func (t *TxnIDWaiter) Ingest(up caches.Update) {
+	if t.maxDelay <= 0 {
+		t.publish(false, up)
+		return
+	}
+
+	eventUpdate, isEventUpdate := up.(*caches.RoomEventUpdate)
+	if !isEventUpdate {
+		t.publish(false, up)
+		return
+	}
+
+	ed := eventUpdate.EventData
+
+	// An event should be queued if
+	//  - it's a non-state event that our user sent, lacking a txn_id; OR
+	//  - the room already has queued events.
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	_, roomQueued := t.queues[ed.RoomID]
+	missingTxnID := ed.StateKey == nil && ed.Sender == t.userID && ed.TransactionID == ""
+	if !(missingTxnID || roomQueued) {
+		t.publish(false, up)
+		return
+	}
+
+	// We've decided to queue the event.
+	queue, exists := t.queues[ed.RoomID]
+	if !exists {
+		queue = make([]*caches.RoomEventUpdate, 0, 10)
+	}
+	// TODO: bound the queue size?
+	t.queues[ed.RoomID] = append(queue, eventUpdate)
+
+	time.AfterFunc(t.maxDelay, func() { t.PublishUpToNID(ed.RoomID, ed.NID) })
+}
+
+func (t *TxnIDWaiter) PublishUpToNID(roomID string, publishNID int64) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	queue, exists := t.queues[roomID]
+	if !exists {
+		return
+	}
+
+	var i int
+	for i = 0; i < len(queue); i++ {
+		// Scan forwards through the queue until we find an event with nid > publishNID.
+		if queue[i].EventData.NID > publishNID {
+			break
+		}
+	}
+	// Now queue[:i] has events with nid <= publishNID, and queue[i:] has nids > publishNID.
+	// Strip off the first i events from the slice and publish them.
+	toPublish, queue := queue[:i], queue[i:]
+	if len(queue) == 0 {
+		delete(t.queues, roomID)
+	} else {
+		t.queues[roomID] = queue
+	}
+
+	for _, eventUpdate := range toPublish {
+		t.publish(true, eventUpdate)
+	}
+}
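Reviewer's note: a hedged sketch of the waiter's lifecycle using only the API introduced above (`NewTxnIDWaiter`, `Ingest`, `PublishUpToNID`). Identifiers and timings are illustrative. It also demonstrates the ordering rule in `Ingest`: once a room has a queue, later events for that room wait behind it regardless of sender.

```go
package main

import (
	"fmt"
	"time"

	"github.com/matrix-org/sliding-sync/sync3/caches"
	"github.com/matrix-org/sliding-sync/sync3/handler"
)

func main() {
	publish := func(delayed bool, up caches.Update) {
		if rup, ok := up.(*caches.RoomEventUpdate); ok {
			fmt.Printf("published nid=%d (delayed=%t)\n", rup.EventData.NID, delayed)
		}
	}
	// Hold back Alice's own echoes for at most 100ms.
	w := handler.NewTxnIDWaiter("@alice:example.org", 100*time.Millisecond, publish)

	// Alice's own event arrives without a transaction ID: it is queued, and a
	// 100ms flush timer is armed via time.AfterFunc.
	w.Ingest(&caches.RoomEventUpdate{EventData: &caches.EventData{
		RoomID: "!room:example.org", Sender: "@alice:example.org", NID: 1,
	}})

	// Bob's event would normally pass straight through, but the room now has a
	// queue, so it waits behind Alice's event to preserve timeline order.
	w.Ingest(&caches.RoomEventUpdate{EventData: &caches.EventData{
		RoomID: "!room:example.org", Sender: "@bob:example.org", NID: 2,
	}})

	// A V2TransactionID payload (txn ID found, or all-clear) triggers an
	// explicit flush; both events are released in NID order with delayed=true.
	w.PublishUpToNID("!room:example.org", 2)

	// Without that call, the timers armed at Ingest time would have flushed
	// the queue after maxDelay anyway.
	time.Sleep(150 * time.Millisecond)
}
```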
diff --git a/sync3/handler/txn_id_waiter_test.go b/sync3/handler/txn_id_waiter_test.go
new file mode 100644
index 00000000..6f4e53c1
--- /dev/null
+++ b/sync3/handler/txn_id_waiter_test.go
@@ -0,0 +1,390 @@
+package handler
+
+import (
+	"github.com/matrix-org/sliding-sync/sync3/caches"
+	"github.com/tidwall/gjson"
+	"testing"
+	"time"
+)
+
+type publishArg struct {
+	delayed bool
+	update  caches.Update
+}
+
+// Test that
+//   - events are (reported as being) delayed when we expect them to be
+//   - delayed events are automatically published after the maximum delay period
+func TestTxnIDWaiter_QueuingLogic(t *testing.T) {
+	const alice = "alice"
+	const bob = "bob"
+	const room1 = "!theroom"
+	const room2 = "!daszimmer"
+
+	testCases := []struct {
+		Name          string
+		Ingest        []caches.Update
+		WaitForUpdate int
+		ExpectDelayed bool
+	}{
+		{
+			Name:          "empty queue, non-event update",
+			Ingest:        []caches.Update{&caches.AccountDataUpdate{}},
+			WaitForUpdate: 0,
+			ExpectDelayed: false,
+		},
+		{
+			Name: "empty queue, event update, another sender",
+			Ingest: []caches.Update{
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID: room1,
+						Sender: bob,
+					},
+				}},
+			WaitForUpdate: 0,
+			ExpectDelayed: false,
+		},
+		{
+			Name: "empty queue, event update, has txn_id",
+			Ingest: []caches.Update{
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID:        room1,
+						Sender:        alice,
+						TransactionID: "txntxntxn",
+					},
+				}},
+			WaitForUpdate: 0,
+			ExpectDelayed: false,
+		},
+		{
+			Name: "empty queue, event update, no txn_id",
+			Ingest: []caches.Update{
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID:        room1,
+						Sender:        alice,
+						TransactionID: "",
+					},
+				}},
+			WaitForUpdate: 0,
+			ExpectDelayed: true,
+		},
+		{
+			Name: "nonempty queue, non-event update",
+			Ingest: []caches.Update{
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID:        room1,
+						Sender:        alice,
+						TransactionID: "",
+						NID:           1,
+					},
+				},
+				&caches.AccountDataUpdate{},
+			},
+			WaitForUpdate: 1,
+			ExpectDelayed: false, // not a room event, no need to queue behind alice's event
+		},
+		{
+			Name: "empty queue, join event for sender",
+			Ingest: []caches.Update{
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID:        room1,
+						Sender:        alice,
+						TransactionID: "",
+						NID:           1,
+						EventType:     "m.room.member",
+						StateKey:      ptr(alice),
+						Content:       gjson.Parse(`{"membership": "join"}`),
+					},
+				},
+			},
+			WaitForUpdate: 0,
+			ExpectDelayed: false,
+		},
+		{
+			Name: "nonempty queue, join event for sender",
+			Ingest: []caches.Update{
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID:        room1,
+						Sender:        alice,
+						TransactionID: "",
+						NID:           1,
+					},
+				},
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID:        room1,
+						Sender:        alice,
+						TransactionID: "",
+						NID:           2,
+						EventType:     "m.room.member",
+						StateKey:      ptr(alice),
+						Content:       gjson.Parse(`{"membership": "join"}`),
+					},
+				},
+			},
+			WaitForUpdate: 1,
+			ExpectDelayed: true,
+		},
+
+		{
+			Name: "nonempty queue, event update, different sender",
+			Ingest: []caches.Update{
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID:        room1,
+						Sender:        alice,
+						TransactionID: "",
+						NID:           1,
+					},
+				},
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID: room1,
+						Sender: bob,
+						NID:    2,
+					},
+				},
+			},
+			WaitForUpdate: 1,
+			ExpectDelayed: true, // should be queued behind alice's event
+		},
+		{
+			Name: "nonempty queue, event update, has txn_id",
+			Ingest: []caches.Update{
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID:        room1,
+						Sender:        alice,
+						TransactionID: "",
+						NID:           1,
+					},
+				},
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID:        room1,
+						Sender:        alice,
+						NID:           2,
+						TransactionID: "I have a txn",
+					},
+				},
+			},
+			WaitForUpdate: 1,
+			ExpectDelayed: true, // should still be queued behind alice's first event
+		},
+		{
+			Name: "existence of queue only matters per-room",
+			Ingest: []caches.Update{
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID:        room1,
+						Sender:        alice,
+						TransactionID: "",
+						NID:           1,
+					},
+				},
+				&caches.RoomEventUpdate{
+					EventData: &caches.EventData{
+						RoomID:        room2,
+						Sender:        alice,
+						NID:           2,
+						TransactionID: "I have a txn",
+					},
+				},
+			},
+			WaitForUpdate: 1,
+			ExpectDelayed: false, // queue only tracks room1
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.Name, func(t *testing.T) {
+			updates := make(chan publishArg, 100)
+			publish := func(delayed bool, update caches.Update) {
+				updates <- publishArg{delayed, update}
+			}
+
+			w := NewTxnIDWaiter(alice, time.Millisecond, publish)
+
+			for _, up := range tc.Ingest {
+				w.Ingest(up)
+			}
+
+			wantedUpdate := tc.Ingest[tc.WaitForUpdate]
+			var got publishArg
+		WaitForSelectedUpdate:
+			for {
+				select {
+				case got = <-updates:
+					t.Logf("Got update %v", got.update)
+					if got.update == wantedUpdate {
+						break WaitForSelectedUpdate
+					}
+				case <-time.After(5 * time.Millisecond):
+					t.Fatalf("Did not see update %v published", wantedUpdate)
+				}
+			}
+
+			if got.delayed != tc.ExpectDelayed {
+				t.Errorf("Got delayed=%t want delayed=%t", got.delayed, tc.ExpectDelayed)
+			}
+		})
+	}
+}
+
+// Test that PublishUpToNID
+//   - correctly pops off the start of the queue
+//   - is idempotent
+//   - deletes map entry if queue is empty (so that roomQueued is set correctly)
+func TestTxnIDWaiter_PublishUpToNID(t *testing.T) {
+	const alice = "@alice:example.com"
+	const room = "!unimportant"
+	var published []publishArg
+	publish := func(delayed bool, update caches.Update) {
+		published = append(published, publishArg{delayed, update})
+	}
+	// Use a maxDelay of an hour so the flush timer never fires during this test.
+	w := NewTxnIDWaiter(alice, time.Hour, publish)
+	// Ingest 5 events, each of which would be queued by itself.
+	for i := int64(2); i <= 6; i++ {
+		w.Ingest(&caches.RoomEventUpdate{
+			EventData: &caches.EventData{
+				RoomID:        room,
+				Sender:        alice,
+				TransactionID: "",
+				NID:           i,
+			},
+		})
+	}
+
+	t.Log("Queue has nids [2,3,4,5,6]")
+	t.Log("Publishing up to 1 should do nothing")
+	w.PublishUpToNID(room, 1)
+	assertNIDs(t, published, nil)
+
+	t.Log("Publishing up to 3 should yield nids [2, 3] in that order")
+	w.PublishUpToNID(room, 3)
+	assertNIDs(t, published, []int64{2, 3})
+	assertDelayed(t, published[:2])
+
+	t.Log("Publishing up to 3 a second time should do nothing")
+	w.PublishUpToNID(room, 3)
+	assertNIDs(t, published, []int64{2, 3})
+
+	t.Log("Publishing up to 2 at this point should do nothing.")
+	w.PublishUpToNID(room, 2)
+	assertNIDs(t, published, []int64{2, 3})
+
+	t.Log("Publishing up to 6 should yield nids [4, 5, 6] in that order")
+	w.PublishUpToNID(room, 6)
+	assertNIDs(t, published, []int64{2, 3, 4, 5, 6})
+	assertDelayed(t, published[2:5])
+
+	t.Log("Publishing up to 6 a second time should do nothing")
+	w.PublishUpToNID(room, 6)
+	assertNIDs(t, published, []int64{2, 3, 4, 5, 6})
+
+	t.Log("Ingesting another event that doesn't need to be queued; it should be published immediately")
+	w.Ingest(&caches.RoomEventUpdate{
+		EventData: &caches.EventData{
+			RoomID:        room,
+			Sender:        "@notalice:example.com",
+			TransactionID: "",
+			NID:           7,
+		},
+	})
+	assertNIDs(t, published, []int64{2, 3, 4, 5, 6, 7})
+	if published[len(published)-1].delayed {
+		t.Errorf("Final event was delayed, but should have been published immediately")
+	}
+}
+
+// Test that PublishUpToNID only publishes in the given room
+func TestTxnIDWaiter_PublishUpToNID_MultipleRooms(t *testing.T) {
+	const alice = "@alice:example.com"
+	var published []publishArg
+	publish := func(delayed bool, update caches.Update) {
+		published = append(published, publishArg{delayed, update})
+	}
+	// Use a maxDelay of an hour so the flush timer never fires during this test.
+	w := NewTxnIDWaiter(alice, time.Hour, publish)
+	// Ingest four queueable events across two rooms.
+	w.Ingest(&caches.RoomEventUpdate{
+		EventData: &caches.EventData{
+			RoomID:        "!room1",
+			Sender:        alice,
+			TransactionID: "",
+			NID:           1,
+		},
+	})
+	w.Ingest(&caches.RoomEventUpdate{
+		EventData: &caches.EventData{
+			RoomID:        "!room2",
+			Sender:        alice,
+			TransactionID: "",
+			NID:           2,
+		},
+	})
+	w.Ingest(&caches.RoomEventUpdate{
+		EventData: &caches.EventData{
+			RoomID:        "!room2",
+			Sender:        alice,
+			TransactionID: "",
+			NID:           3,
+		},
+	})
+	w.Ingest(&caches.RoomEventUpdate{
+		EventData: &caches.EventData{
+			RoomID:        "!room1",
+			Sender:        alice,
+			TransactionID: "",
+			NID:           4,
+		},
+	})
+
+	t.Log("Queues are [1, 4] and [2, 3]")
+	t.Log("Publish up to NID 4 in room 1 should yield nids [1, 4]")
+	w.PublishUpToNID("!room1", 4)
+	assertNIDs(t, published, []int64{1, 4})
+	assertDelayed(t, published)
+
+	t.Log("Queues are now [] and [2, 3]")
+	t.Log("Publish up to NID 3 in room 2 should yield nids [2, 3]")
+	w.PublishUpToNID("!room2", 3)
+	assertNIDs(t, published, []int64{1, 4, 2, 3})
+	assertDelayed(t, published)
+}
+
+func assertDelayed(t *testing.T, published []publishArg) {
+	t.Helper()
+	for _, p := range published {
+		if !p.delayed {
+			t.Errorf("published arg with NID %d was not delayed, but we expected it to be", p.update.(*caches.RoomEventUpdate).EventData.NID)
+		}
+	}
+}
+
+func assertNIDs(t *testing.T, published []publishArg, expectedNIDs []int64) {
+	t.Helper()
+	if len(published) != len(expectedNIDs) {
+		t.Errorf("Got %d nids, but expected %d", len(published), len(expectedNIDs))
+		return
+	}
+	for i := range published {
+		rup, ok := published[i].update.(*caches.RoomEventUpdate)
+		if !ok {
+			t.Errorf("Update %d (%v) was not a RoomEventUpdate", i, published[i].update)
+			continue
+		}
+		if rup.EventData.NID != expectedNIDs[i] {
+			t.Errorf("Update %d (%v) got nid %d, expected %d", i, *rup, rup.EventData.NID, expectedNIDs[i])
+		}
+	}
+}
+
+func ptr(s string) *string {
+	return &s
+}
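Reviewer's note: it may help to trace how a queued event is eventually released. The chain below is assembled from the pieces in this patch; the sketch only constructs the pubsub payload, with invented values.

```go
package main

import (
	"fmt"

	"github.com/matrix-org/sliding-sync/pubsub"
)

func main() {
	// A poller that resolves the txn ID question either way emits this payload.
	// TransactionID == "" is the all-clear: no poller will ever see a txn ID.
	payload := pubsub.V2TransactionID{
		EventID:       "$event:example.org",
		RoomID:        "!room:example.org",
		UserID:        "@alice:example.org",
		DeviceID:      "DEVICE_1",
		TransactionID: "m1234567890",
		NID:           42,
	}
	// In the API process: SyncLiveHandler.OnTransactionID receives it and calls
	// ConnMap.ClearUpdateQueues(UserID, RoomID, NID), which fans out to each of
	// the sender's connections via ConnHandler.PublishEventsUpTo, which in turn
	// calls TxnIDWaiter.PublishUpToNID(RoomID, NID) to release queued updates.
	fmt.Printf("%+v\n", payload)
}
```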
diff --git a/tests-e2e/membership_transitions_test.go b/tests-e2e/membership_transitions_test.go
index 2271d78e..bb721f3c 100644
--- a/tests-e2e/membership_transitions_test.go
+++ b/tests-e2e/membership_transitions_test.go
@@ -233,17 +233,18 @@ func TestInviteRejection(t *testing.T) {
 }
 
 func TestInviteAcceptance(t *testing.T) {
-	alice := registerNewUser(t)
-	bob := registerNewUser(t)
+	alice := registerNamedUser(t, "alice")
+	bob := registerNamedUser(t, "bob")
 
 	// ensure that invite state correctly propagates. One room will already be in 'invite' state
 	// prior to the first proxy sync, whereas the 2nd will transition.
+	t.Logf("Alice creates two rooms and invites Bob to the first.")
 	firstInviteRoomID := alice.CreateRoom(t, map[string]interface{}{"preset": "private_chat", "name": "First"})
 	alice.InviteRoom(t, firstInviteRoomID, bob.UserID)
 	secondInviteRoomID := alice.CreateRoom(t, map[string]interface{}{"preset": "private_chat", "name": "Second"})
 	t.Logf("first %s second %s", firstInviteRoomID, secondInviteRoomID)
 
-	// sync as bob, we should see 1 invite
+	t.Log("Sync as Bob, requesting invites only. He should see 1 invite")
 	res := bob.SlidingSync(t, sync3.Request{
 		Lists: map[string]sync3.RequestList{
 			"a": {
@@ -273,10 +274,12 @@ func TestInviteAcceptance(t *testing.T) {
 		},
 	}))
 
-	// now invite bob
+	t.Log("Alice invites Bob to room 2.")
 	alice.InviteRoom(t, secondInviteRoomID, bob.UserID)
+	t.Log("Alice syncs until she sees Bob's invite.")
 	alice.SlidingSyncUntilMembership(t, "", secondInviteRoomID, bob, "invite")
 
+	t.Log("Bob syncs. He should see the invite to room 2 as well.")
 	res = bob.SlidingSync(t, sync3.Request{
 		Lists: map[string]sync3.RequestList{
 			"a": {
@@ -304,13 +307,16 @@ func TestInviteAcceptance(t *testing.T) {
 		},
 	}))
 
-	// now accept the invites
+	t.Log("Bob accepts the invites.")
 	bob.JoinRoom(t, firstInviteRoomID, nil)
 	bob.JoinRoom(t, secondInviteRoomID, nil)
+
+	t.Log("Alice syncs until she sees Bob join room 1.")
 	alice.SlidingSyncUntilMembership(t, "", firstInviteRoomID, bob, "join")
+	t.Log("Alice syncs until she sees Bob join room 2.")
 	alice.SlidingSyncUntilMembership(t, "", secondInviteRoomID, bob, "join")
 
-	// the list should be purged
+	t.Log("Bob does an incremental sync")
 	res = bob.SlidingSync(t, sync3.Request{
 		Lists: map[string]sync3.RequestList{
 			"a": {
@@ -318,12 +324,13 @@ func TestInviteAcceptance(t *testing.T) {
 			},
 		},
 	}, WithPos(res.Pos))
+	t.Log("Both of his invites should be purged.")
 	m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0), m.MatchV3Ops(
 		m.MatchV3DeleteOp(1),
 		m.MatchV3DeleteOp(0),
 	)))
 
-	// fresh sync -> no invites
+	t.Log("Bob makes a fresh sliding sync request.")
 	res = bob.SlidingSync(t, sync3.Request{
 		Lists: map[string]sync3.RequestList{
 			"a": {
@@ -334,6 +341,7 @@ func TestInviteAcceptance(t *testing.T) {
 			},
 		},
 	})
+	t.Log("He should see no invites.")
 	m.MatchResponse(t, res, m.MatchNoV3Ops(), m.MatchRoomSubscriptionsStrict(nil), m.MatchList("a", m.MatchV3Count(0)))
 }
 
diff --git a/tests-e2e/transaction_id_test.go b/tests-e2e/transaction_id_test.go
index c49cb5a1..c3a53ee4 100644
--- a/tests-e2e/transaction_id_test.go
+++ b/tests-e2e/transaction_id_test.go
@@ -32,33 +32,10 @@ func TestTransactionIDsAppear(t *testing.T) {
 
 	// we cannot use MatchTimeline here because the Unsigned section contains 'age' which is not
 	// deterministic and MatchTimeline does not do partial matches.
-	matchTransactionID := func(eventID, txnID string) m.RoomMatcher {
-		return func(r sync3.Room) error {
-			for _, ev := range r.Timeline {
-				var got Event
-				if err := json.Unmarshal(ev, &got); err != nil {
-					return fmt.Errorf("failed to unmarshal event: %s", err)
-				}
-				if got.ID != eventID {
-					continue
-				}
-				tx, ok := got.Unsigned["transaction_id"]
-				if !ok {
-					return fmt.Errorf("unsigned block for %s has no transaction_id", eventID)
-				}
-				gotTxnID := tx.(string)
-				if gotTxnID != txnID {
-					return fmt.Errorf("wrong transaction_id, got %s want %s", gotTxnID, txnID)
-				}
-				t.Logf("%s has txn ID %s", eventID, gotTxnID)
-				return nil
-			}
-			return fmt.Errorf("not found event %s", eventID)
-		}
-	}
+
 	m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
 		roomID: {
-			matchTransactionID(eventID, "foobar"),
+			matchTransactionID(t, eventID, "foobar"),
 		},
 	}))
 
@@ -74,8 +51,85 @@ func TestTransactionIDsAppear(t *testing.T) {
 	res = client.SlidingSyncUntilEvent(t, res.Pos, sync3.Request{}, roomID, Event{ID: eventID})
 	m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
 		roomID: {
-			matchTransactionID(eventID, "foobar2"),
+			matchTransactionID(t, eventID, "foobar2"),
 		},
 	}))
 }
+
+// This test has 1 poller expecting a txn ID and 10 others that won't see one.
+// We test that the sending device sees a txn ID. Without the TxnIDWaiter logic in place,
+// this test is likely (but not guaranteed) to fail.
+func TestTransactionIDsAppearWithMultiplePollers(t *testing.T) { + alice := registerNamedUser(t, "alice") + + t.Log("Alice creates a room and syncs until she sees it.") + roomID := alice.CreateRoom(t, map[string]interface{}{}) + res := alice.SlidingSync(t, sync3.Request{ + Lists: map[string]sync3.RequestList{ + "a": { + RoomSubscription: sync3.RoomSubscription{ + TimelineLimit: 10, + }, + Ranges: sync3.SliceRanges{{0, 20}}, + }, + }, + }) + m.MatchResponse(t, res, m.MatchRoomSubscription(roomID)) + + t.Log("Alice makes other devices and starts them syncing.") + for i := 0; i < 10; i++ { + device := *alice + device.Login(t, "password", fmt.Sprintf("device_%d", i)) + device.SlidingSync(t, sync3.Request{ + Lists: map[string]sync3.RequestList{ + "a": { + RoomSubscription: sync3.RoomSubscription{ + TimelineLimit: 10, + }, + Ranges: sync3.SliceRanges{{0, 20}}, + }, + }, + }) + } + + t.Log("Alice sends a message with a transaction ID.") + const txnID = "foobar" + sendRes := alice.MustDoFunc(t, "PUT", []string{"_matrix", "client", "v3", "rooms", roomID, "send", "m.room.message", txnID}, + WithJSONBody(t, map[string]interface{}{ + "msgtype": "m.text", + "body": "Hello, world!", + })) + body := ParseJSON(t, sendRes) + eventID := GetJSONFieldStr(t, body, "event_id") + + t.Log("Alice syncs on her main devices until she sees her message.") + res = alice.SlidingSyncUntilEventID(t, res.Pos, roomID, eventID) + + m.MatchResponse(t, res, m.MatchRoomSubscription(roomID, matchTransactionID(t, eventID, txnID))) +} + +func matchTransactionID(t *testing.T, eventID, txnID string) m.RoomMatcher { + return func(r sync3.Room) error { + for _, ev := range r.Timeline { + var got Event + if err := json.Unmarshal(ev, &got); err != nil { + return fmt.Errorf("failed to unmarshal event: %s", err) + } + if got.ID != eventID { + continue + } + tx, ok := got.Unsigned["transaction_id"] + if !ok { + return fmt.Errorf("unsigned block for %s has no transaction_id", eventID) + } + gotTxnID := tx.(string) + if gotTxnID != txnID { + return fmt.Errorf("wrong transaction_id, got %s want %s", gotTxnID, txnID) + } + t.Logf("%s has txn ID %s", eventID, gotTxnID) + return nil + } + return fmt.Errorf("not found event %s", eventID) + } +} diff --git a/tests-integration/timeline_test.go b/tests-integration/timeline_test.go index cb188d56..62ba9230 100644 --- a/tests-integration/timeline_test.go +++ b/tests-integration/timeline_test.go @@ -3,6 +3,7 @@ package syncv3 import ( "encoding/json" "fmt" + slidingsync "github.com/matrix-org/sliding-sync" "testing" "time" @@ -689,6 +690,132 @@ func TestTimelineTxnID(t *testing.T) { )) } +// TestTimelineTxnID checks that Alice sees her transaction_id if +// - Bob's poller sees Alice's event, +// - Alice's poller sees Alice's event with txn_id, and +// - Alice syncs. +// +// This test is similar but not identical. It checks that Alice sees her transaction_id if +// - Bob's poller sees Alice's event, +// - Alice does an incremental sync, which should omit her event, +// - Alice's poller sees Alice's event with txn_id, and +// - Alice syncs, seeing her event with txn_id. +func TestTimelineTxnIDBuffersForTxnID(t *testing.T) { + pqString := testutils.PrepareDBConnectionString() + // setup code + v2 := runTestV2Server(t) + v3 := runTestServer(t, v2, pqString, slidingsync.Opts{ + // This needs to be greater than the request timeout, which is hardcoded to a + // minimum of 100ms in connStateLive.liveUpdate. 
This ensures that the + // liveUpdate call finishes before the TxnIDWaiter publishes the update, + // meaning that Alice doesn't see her event before the txn ID is known. + MaxTransactionIDDelay: 200 * time.Millisecond, + }) + defer v2.close() + defer v3.close() + roomID := "!a:localhost" + latestTimestamp := time.Now() + t.Log("Alice and Bob are in the same room") + room := roomEvents{ + roomID: roomID, + events: append( + createRoomState(t, alice, latestTimestamp), + testutils.NewJoinEvent(t, bob), + ), + } + v2.addAccount(t, alice, aliceToken) + v2.addAccount(t, bob, bobToken) + v2.queueResponse(alice, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: v2JoinTimeline(room), + }, + NextBatch: "alice_after_initial_poll", + }) + v2.queueResponse(bob, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: v2JoinTimeline(room), + }, + NextBatch: "bob_after_initial_poll", + }) + + t.Log("Alice and Bob make initial sliding syncs.") + aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{ + Lists: map[string]sync3.RequestList{"a": { + Ranges: sync3.SliceRanges{ + [2]int64{0, 10}, + }, + RoomSubscription: sync3.RoomSubscription{ + TimelineLimit: 2, + }, + }, + }, + }) + bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{ + Lists: map[string]sync3.RequestList{"a": { + Ranges: sync3.SliceRanges{ + [2]int64{0, 10}, + }, + RoomSubscription: sync3.RoomSubscription{ + TimelineLimit: 2, + }, + }, + }, + }) + + t.Log("Alice has sent a message... but it arrives down Bob's poller first, without a transaction_id") + txnID := "m1234567890" + newEvent := testutils.NewEvent(t, "m.room.message", alice, map[string]interface{}{"body": "hi"}, testutils.WithUnsigned(map[string]interface{}{ + "transaction_id": txnID, + })) + newEventNoUnsigned, err := sjson.DeleteBytes(newEvent, "unsigned") + if err != nil { + t.Fatalf("failed to delete bytes: %s", err) + } + + v2.queueResponse(bob, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: v2JoinTimeline(roomEvents{ + roomID: roomID, + events: []json.RawMessage{newEventNoUnsigned}, + }), + }, + }) + t.Log("Bob's poller sees the message.") + v2.waitUntilEmpty(t, bob) + + t.Log("Bob makes an incremental sliding sync") + bobRes = v3.mustDoV3RequestWithPos(t, bobToken, bobRes.Pos, sync3.Request{}) + t.Log("Bob should see the message without a transaction_id") + m.MatchResponse(t, bobRes, m.MatchList("a", m.MatchV3Count(1)), m.MatchNoV3Ops(), m.MatchRoomSubscription( + roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{newEventNoUnsigned}), + )) + + t.Log("Alice requests an incremental sliding sync with no request changes.") + aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{}) + t.Log("Alice should see no messages.") + m.MatchResponse(t, aliceRes, m.MatchRoomSubscriptionsStrict(nil)) + + // Now the message arrives down Alice's poller. 
+ v2.queueResponse(alice, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: v2JoinTimeline(roomEvents{ + roomID: roomID, + events: []json.RawMessage{newEvent}, + }), + }, + }) + t.Log("Alice's poller sees the message with transaction_id.") + v2.waitUntilEmpty(t, alice) + + t.Log("Alice makes another incremental sync request.") + aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{}) + t.Log("Alice's sync response includes the message with the txn ID.") + m.MatchResponse(t, aliceRes, m.MatchList("a", m.MatchV3Count(1)), m.MatchNoV3Ops(), m.MatchRoomSubscription( + roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{newEvent}), + )) + +} + // Executes a sync v3 request without a ?pos and asserts that the count, rooms and timeline events m.Match the inputs given. func testTimelineLoadInitialEvents(v3 *testV3Server, token string, count int, wantRooms []roomEvents, numTimelineEventsPerRoom int) func(t *testing.T) { return func(t *testing.T) { diff --git a/tests-integration/v3_test.go b/tests-integration/v3_test.go index de72d954..c29b2126 100644 --- a/tests-integration/v3_test.go +++ b/tests-integration/v3_test.go @@ -370,12 +370,14 @@ func runTestServer(t testutils.TestBenchInterface, v2Server *testV2Server, postg TestingSynchronousPubsub: true, // critical to avoid flakey tests AddPrometheusMetrics: false, MaxPendingEventUpdates: 200, + MaxTransactionIDDelay: 0, // disable the txnID buffering to avoid flakey tests } if len(opts) > 0 { opt := opts[0] combinedOpts.AddPrometheusMetrics = opt.AddPrometheusMetrics combinedOpts.DBConnMaxIdleTime = opt.DBConnMaxIdleTime combinedOpts.DBMaxConns = opt.DBMaxConns + combinedOpts.MaxTransactionIDDelay = opt.MaxTransactionIDDelay if opt.MaxPendingEventUpdates > 0 { combinedOpts.MaxPendingEventUpdates = opt.MaxPendingEventUpdates handler.BufferWaitTime = 5 * time.Millisecond diff --git a/v3.go b/v3.go index e8948e94..72cf37c8 100644 --- a/v3.go +++ b/v3.go @@ -37,6 +37,10 @@ type Opts struct { // if true, publishing messages will block until the consumer has consumed it. // Assumes a single producer and a single consumer. TestingSynchronousPubsub bool + // MaxTransactionIDDelay is the longest amount of time that we will wait for + // confirmation of an event's transaction_id before sending it to its sender. + // Set to 0 to disable this delay mechanism entirely. + MaxTransactionIDDelay time.Duration DBMaxConns int DBConnMaxIdleTime time.Duration @@ -115,7 +119,7 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han pMap.SetCallbacks(h2) // create v3 handler - h3, err := handler.NewSync3Handler(store, storev2, v2Client, secret, pubSub, pubSub, opts.AddPrometheusMetrics, opts.MaxPendingEventUpdates) + h3, err := handler.NewSync3Handler(store, storev2, v2Client, secret, pubSub, pubSub, opts.AddPrometheusMetrics, opts.MaxPendingEventUpdates, opts.MaxTransactionIDDelay) if err != nil { panic(err) }
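Reviewer's note: for embedders, the new behaviour is opted into via `Opts`. A hedged sketch of the wiring, mirroring what `cmd/syncv3/main.go` now does; the server URL, connection string and secret are placeholders.

```go
package main

import (
	"time"

	slidingsync "github.com/matrix-org/sliding-sync"
)

func main() {
	h2, h3 := slidingsync.Setup(
		"https://synapse.example.org",           // destination homeserver
		"postgres://user:pass@localhost/syncv3", // postgres URI (placeholder)
		"secret",                                // secret (placeholder)
		slidingsync.Opts{
			DBMaxConns:        100,
			DBConnMaxIdleTime: time.Hour,
			// Wait up to 1s for an event's txn ID before sending the event
			// to its sender; 0 disables the delay mechanism entirely.
			MaxTransactionIDDelay: time.Second,
		},
	)
	go h2.StartV2Pollers()
	_ = h3 // serve h3 over HTTP, as cmd/syncv3 does
}
```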