diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a5bcd30b..4c496bef 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -59,7 +59,7 @@ jobs: - name: Test run: | set -euo pipefail - go test -count=1 -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt -hide all + go test -count=1 -race -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt -hide all shell: bash env: POSTGRES_HOST: localhost diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go index dea3efe1..f3132ae8 100644 --- a/pubsub/pubsub.go +++ b/pubsub/pubsub.go @@ -84,7 +84,11 @@ func (ps *PubSub) Notify(chanName string, p Payload) error { return fmt.Errorf("notify with payload %v timed out", p.Type()) } if ps.bufferSize == 0 { + // for some reason go test -race flags this as racing with calls + // to close(ch), despite the fact that it _should_ be thread-safe :S + ps.mu.Lock() ch <- &emptyPayload{} + ps.mu.Unlock() } return nil } diff --git a/state/accumulator_test.go b/state/accumulator_test.go index db358830..5d95b0b4 100644 --- a/state/accumulator_test.go +++ b/state/accumulator_test.go @@ -4,12 +4,14 @@ import ( "context" "encoding/json" "fmt" - "github.com/matrix-org/sliding-sync/testutils" "reflect" "sort" "sync" + "sync/atomic" "testing" + "github.com/matrix-org/sliding-sync/testutils" + "github.com/jmoiron/sqlx" "github.com/matrix-org/sliding-sync/sqlutil" "github.com/matrix-org/sliding-sync/sync2" @@ -680,7 +682,7 @@ func TestAccumulatorConcurrency(t *testing.T) { []byte(`{"event_id":"con_4", "type":"m.room.name", "state_key":"", "content":{"name":"4"}}`), []byte(`{"event_id":"con_5", "type":"m.room.name", "state_key":"", "content":{"name":"5"}}`), } - totalNumNew := 0 + var totalNumNew atomic.Int64 var wg sync.WaitGroup wg.Add(len(newEvents)) for i := 0; i < len(newEvents); i++ { @@ -689,7 +691,7 @@ func TestAccumulatorConcurrency(t *testing.T) { subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error { result, err := accumulator.Accumulate(txn, userID, roomID, sync2.TimelineResponse{Events: subset}) - totalNumNew += result.NumNew + totalNumNew.Add(int64(result.NumNew)) return err }) if err != nil { @@ -698,8 +700,8 @@ func TestAccumulatorConcurrency(t *testing.T) { }(i) } wg.Wait() // wait for all goroutines to finish - if totalNumNew != len(newEvents) { - t.Errorf("got %d total new events, want %d", totalNumNew, len(newEvents)) + if int(totalNumNew.Load()) != len(newEvents) { + t.Errorf("got %d total new events, want %d", totalNumNew.Load(), len(newEvents)) } // check that the name of the room is "5" snapshot := currentSnapshotNIDs(t, accumulator.snapshotTable, roomID) diff --git a/sync2/poller_test.go b/sync2/poller_test.go index 1ce7a745..4b145aab 100644 --- a/sync2/poller_test.go +++ b/sync2/poller_test.go @@ -25,12 +25,23 @@ const initialSinceToken = "0" var ( timeSinceMu sync.Mutex timeSinceValue = time.Duration(0) // 0 means use the real impl + timeSleepMu sync.Mutex + timeSleepValue = time.Duration(0) // 0 means use the real impl + timeSleepCheck func(time.Duration) // called to check sleep values ) func setTimeSinceValue(val time.Duration) { timeSinceMu.Lock() + defer timeSinceMu.Unlock() timeSinceValue = val - timeSinceMu.Unlock() +} +func setTimeSleepDelay(val time.Duration, fn ...func(d time.Duration)) { + timeSleepMu.Lock() + defer timeSleepMu.Unlock() + timeSleepValue = val + if len(fn) > 0 { + timeSleepCheck = fn[0] + } } func init() { timeSince = func(t time.Time) time.Duration { @@ -41,6 +52,18 @@ func init() { } return timeSinceValue } + timeSleep = func(d time.Duration) { + timeSleepMu.Lock() + defer timeSleepMu.Unlock() + if timeSleepCheck != nil { + timeSleepCheck(d) + } + if timeSleepValue == 0 { + time.Sleep(d) + return + } + time.Sleep(timeSleepValue) + } } // Tests that EnsurePolling works in the happy case @@ -583,12 +606,10 @@ func TestPollerGivesUpEventually(t *testing.T) { accumulator, client := newMocks(func(authHeader, since string) (*SyncResponse, int, error) { return nil, 524, fmt.Errorf("gateway timeout") }) - timeSleep = func(d time.Duration) { - // actually sleep to make sure async actions can happen if any - time.Sleep(1 * time.Microsecond) - } + // actually sleep to make sure async actions can happen if any + setTimeSleepDelay(time.Microsecond) defer func() { // reset the value after the test runs - timeSleep = time.Sleep + setTimeSleepDelay(0) }() var wg sync.WaitGroup wg.Add(1) @@ -654,15 +675,13 @@ func TestPollerBackoff(t *testing.T) { wantBackoffDuration = errorResponses[i].backoff return nil, errorResponses[i].code, errorResponses[i].err }) - timeSleep = func(d time.Duration) { + setTimeSleepDelay(time.Millisecond, func(d time.Duration) { if d != wantBackoffDuration { t.Errorf("time.Sleep called incorrectly: got %v want %v", d, wantBackoffDuration) } - // actually sleep to make sure async actions can happen if any - time.Sleep(1 * time.Millisecond) - } + }) defer func() { // reset the value after the test runs - timeSleep = time.Sleep + setTimeSleepDelay(0) }() var wg sync.WaitGroup wg.Add(1) @@ -727,12 +746,10 @@ func TestPollerResendsOnCallbackError(t *testing.T) { pid := PollerID{UserID: "@TestPollerResendsOnCallbackError:localhost", DeviceID: "FOOBAR"} defer func() { // reset the value after the test runs - timeSleep = time.Sleep + setTimeSleepDelay(0) }() // we don't actually want to wait 3s between retries, so monkey patch it out - timeSleep = func(d time.Duration) { - time.Sleep(time.Millisecond) - } + setTimeSleepDelay(time.Millisecond) testCases := []struct { name string diff --git a/sync3/connmap_test.go b/sync3/connmap_test.go index 8fbcf8cd..adf3aaf4 100644 --- a/sync3/connmap_test.go +++ b/sync3/connmap_test.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "sort" + "sync/atomic" "testing" "time" @@ -201,15 +202,15 @@ func assertDestroyedConns(t *testing.T, cidToConn map[ConnID]*Conn, isDestroyedF t.Helper() for cid, conn := range cidToConn { if isDestroyedFn(cid) { - mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed, true, fmt.Sprintf("conn %+v was not destroyed", cid)) + mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed.Load(), true, fmt.Sprintf("conn %+v was not destroyed", cid)) } else { - mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed, false, fmt.Sprintf("conn %+v was destroyed", cid)) + mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed.Load(), false, fmt.Sprintf("conn %+v was destroyed", cid)) } } } type mockConnHandler struct { - isDestroyed bool + isDestroyed atomic.Bool cancel context.CancelFunc } @@ -219,7 +220,7 @@ func (c *mockConnHandler) OnIncomingRequest(ctx context.Context, cid ConnID, req func (c *mockConnHandler) OnUpdate(ctx context.Context, update caches.Update) {} func (c *mockConnHandler) PublishEventsUpTo(roomID string, nid int64) {} func (c *mockConnHandler) Destroy() { - c.isDestroyed = true + c.isDestroyed.Store(true) } func (c *mockConnHandler) Alive() bool { return true // buffer never fills up diff --git a/tests-integration/poller_test.go b/tests-integration/poller_test.go index 5ce3b4dd..3c3fae06 100644 --- a/tests-integration/poller_test.go +++ b/tests-integration/poller_test.go @@ -4,14 +4,15 @@ import ( "context" "encoding/json" "fmt" - "github.com/jmoiron/sqlx" - "github.com/matrix-org/sliding-sync/sqlutil" "net/http" "os" "sync/atomic" "testing" "time" + "github.com/jmoiron/sqlx" + "github.com/matrix-org/sliding-sync/sqlutil" + "github.com/matrix-org/sliding-sync/sync2" "github.com/matrix-org/sliding-sync/sync3" "github.com/matrix-org/sliding-sync/sync3/extensions" @@ -45,7 +46,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) { // now sync with device B, and check we send the filter up deviceBToken := "DEVICE_B_TOKEN" v2.addAccountWithDeviceID(alice, "B", deviceBToken) - seenInitialRequest := false + var seenInitialRequest atomic.Bool v2.SetCheckRequest(func(token string, req *http.Request) { if token != deviceBToken { return @@ -62,7 +63,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) { timelineLimit := filterJSON.Get("room.timeline.limit").Int() roomsFilter := filterJSON.Get("room.rooms") - if !seenInitialRequest { + if !seenInitialRequest.Load() { // First poll: should be an initial sync, limit 1, excluding all room timelines. if since != "" { t.Errorf("Expected no since token on first poll, but got %v", since) @@ -89,7 +90,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) { } } - seenInitialRequest = true + seenInitialRequest.Store(true) }) wantMsg := json.RawMessage(`{"type":"f","content":{"f":"b"}}`) @@ -110,7 +111,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) { }, }) - if !seenInitialRequest { + if !seenInitialRequest.Load() { t.Fatalf("did not see initial request for 2nd device") } // the first request will not wait for the response before returning due to device A. Poll again