Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race conditions in tests #410

Merged
merged 3 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:
- name: Test
run: |
set -euo pipefail
go test -count=1 -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt -hide all
go test -count=1 -race -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt -hide all
shell: bash
env:
POSTGRES_HOST: localhost
Expand Down
4 changes: 4 additions & 0 deletions pubsub/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ func (ps *PubSub) Notify(chanName string, p Payload) error {
return fmt.Errorf("notify with payload %v timed out", p.Type())
}
if ps.bufferSize == 0 {
// for some reason go test -race flags this as racing with calls
// to close(ch), despite the fact that it _should_ be thread-safe :S
ps.mu.Lock()
ch <- &emptyPayload{}
ps.mu.Unlock()
}
return nil
}
Expand Down
12 changes: 7 additions & 5 deletions state/accumulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/matrix-org/sliding-sync/testutils"
"reflect"
"sort"
"sync"
"sync/atomic"
"testing"

"github.com/matrix-org/sliding-sync/testutils"

"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"github.com/matrix-org/sliding-sync/sync2"
Expand Down Expand Up @@ -680,7 +682,7 @@ func TestAccumulatorConcurrency(t *testing.T) {
[]byte(`{"event_id":"con_4", "type":"m.room.name", "state_key":"", "content":{"name":"4"}}`),
[]byte(`{"event_id":"con_5", "type":"m.room.name", "state_key":"", "content":{"name":"5"}}`),
}
totalNumNew := 0
var totalNumNew atomic.Int64
var wg sync.WaitGroup
wg.Add(len(newEvents))
for i := 0; i < len(newEvents); i++ {
Expand All @@ -689,7 +691,7 @@ func TestAccumulatorConcurrency(t *testing.T) {
subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc
err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
result, err := accumulator.Accumulate(txn, userID, roomID, sync2.TimelineResponse{Events: subset})
totalNumNew += result.NumNew
totalNumNew.Add(int64(result.NumNew))
return err
})
if err != nil {
Expand All @@ -698,8 +700,8 @@ func TestAccumulatorConcurrency(t *testing.T) {
}(i)
}
wg.Wait() // wait for all goroutines to finish
if totalNumNew != len(newEvents) {
t.Errorf("got %d total new events, want %d", totalNumNew, len(newEvents))
if int(totalNumNew.Load()) != len(newEvents) {
t.Errorf("got %d total new events, want %d", totalNumNew.Load(), len(newEvents))
}
// check that the name of the room is "5"
snapshot := currentSnapshotNIDs(t, accumulator.snapshotTable, roomID)
Expand Down
47 changes: 32 additions & 15 deletions sync2/poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,23 @@ const initialSinceToken = "0"
var (
timeSinceMu sync.Mutex
timeSinceValue = time.Duration(0) // 0 means use the real impl
timeSleepMu sync.Mutex
timeSleepValue = time.Duration(0) // 0 means use the real impl
timeSleepCheck func(time.Duration) // called to check sleep values
)

func setTimeSinceValue(val time.Duration) {
timeSinceMu.Lock()
defer timeSinceMu.Unlock()
timeSinceValue = val
timeSinceMu.Unlock()
}
func setTimeSleepDelay(val time.Duration, fn ...func(d time.Duration)) {
timeSleepMu.Lock()
defer timeSleepMu.Unlock()
timeSleepValue = val
if len(fn) > 0 {
timeSleepCheck = fn[0]
}
}
func init() {
timeSince = func(t time.Time) time.Duration {
Expand All @@ -41,6 +52,18 @@ func init() {
}
return timeSinceValue
}
timeSleep = func(d time.Duration) {
timeSleepMu.Lock()
defer timeSleepMu.Unlock()
if timeSleepCheck != nil {
timeSleepCheck(d)
}
if timeSleepValue == 0 {
time.Sleep(d)
return
}
time.Sleep(timeSleepValue)
}
}

// Tests that EnsurePolling works in the happy case
Expand Down Expand Up @@ -583,12 +606,10 @@ func TestPollerGivesUpEventually(t *testing.T) {
accumulator, client := newMocks(func(authHeader, since string) (*SyncResponse, int, error) {
return nil, 524, fmt.Errorf("gateway timeout")
})
timeSleep = func(d time.Duration) {
// actually sleep to make sure async actions can happen if any
time.Sleep(1 * time.Microsecond)
}
// actually sleep to make sure async actions can happen if any
setTimeSleepDelay(time.Microsecond)
defer func() { // reset the value after the test runs
timeSleep = time.Sleep
setTimeSleepDelay(0)
}()
var wg sync.WaitGroup
wg.Add(1)
Expand Down Expand Up @@ -654,15 +675,13 @@ func TestPollerBackoff(t *testing.T) {
wantBackoffDuration = errorResponses[i].backoff
return nil, errorResponses[i].code, errorResponses[i].err
})
timeSleep = func(d time.Duration) {
setTimeSleepDelay(time.Millisecond, func(d time.Duration) {
if d != wantBackoffDuration {
t.Errorf("time.Sleep called incorrectly: got %v want %v", d, wantBackoffDuration)
}
// actually sleep to make sure async actions can happen if any
time.Sleep(1 * time.Millisecond)
}
})
defer func() { // reset the value after the test runs
timeSleep = time.Sleep
setTimeSleepDelay(0)
}()
var wg sync.WaitGroup
wg.Add(1)
Expand Down Expand Up @@ -727,12 +746,10 @@ func TestPollerResendsOnCallbackError(t *testing.T) {
pid := PollerID{UserID: "@TestPollerResendsOnCallbackError:localhost", DeviceID: "FOOBAR"}

defer func() { // reset the value after the test runs
timeSleep = time.Sleep
setTimeSleepDelay(0)
}()
// we don't actually want to wait 3s between retries, so monkey patch it out
timeSleep = func(d time.Duration) {
time.Sleep(time.Millisecond)
}
setTimeSleepDelay(time.Millisecond)

testCases := []struct {
name string
Expand Down
9 changes: 5 additions & 4 deletions sync3/connmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"reflect"
"sort"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -201,15 +202,15 @@ func assertDestroyedConns(t *testing.T, cidToConn map[ConnID]*Conn, isDestroyedF
t.Helper()
for cid, conn := range cidToConn {
if isDestroyedFn(cid) {
mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed, true, fmt.Sprintf("conn %+v was not destroyed", cid))
mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed.Load(), true, fmt.Sprintf("conn %+v was not destroyed", cid))
} else {
mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed, false, fmt.Sprintf("conn %+v was destroyed", cid))
mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed.Load(), false, fmt.Sprintf("conn %+v was destroyed", cid))
}
}
}

type mockConnHandler struct {
isDestroyed bool
isDestroyed atomic.Bool
cancel context.CancelFunc
}

Expand All @@ -219,7 +220,7 @@ func (c *mockConnHandler) OnIncomingRequest(ctx context.Context, cid ConnID, req
func (c *mockConnHandler) OnUpdate(ctx context.Context, update caches.Update) {}
func (c *mockConnHandler) PublishEventsUpTo(roomID string, nid int64) {}
func (c *mockConnHandler) Destroy() {
c.isDestroyed = true
c.isDestroyed.Store(true)
}
func (c *mockConnHandler) Alive() bool {
return true // buffer never fills up
Expand Down
13 changes: 7 additions & 6 deletions tests-integration/poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"net/http"
"os"
"sync/atomic"
"testing"
"time"

"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"

"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/sync3/extensions"
Expand Down Expand Up @@ -45,7 +46,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
// now sync with device B, and check we send the filter up
deviceBToken := "DEVICE_B_TOKEN"
v2.addAccountWithDeviceID(alice, "B", deviceBToken)
seenInitialRequest := false
var seenInitialRequest atomic.Bool
v2.SetCheckRequest(func(token string, req *http.Request) {
if token != deviceBToken {
return
Expand All @@ -62,7 +63,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
timelineLimit := filterJSON.Get("room.timeline.limit").Int()
roomsFilter := filterJSON.Get("room.rooms")

if !seenInitialRequest {
if !seenInitialRequest.Load() {
// First poll: should be an initial sync, limit 1, excluding all room timelines.
if since != "" {
t.Errorf("Expected no since token on first poll, but got %v", since)
Expand All @@ -89,7 +90,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
}
}

seenInitialRequest = true
seenInitialRequest.Store(true)
})

wantMsg := json.RawMessage(`{"type":"f","content":{"f":"b"}}`)
Expand All @@ -110,7 +111,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
},
})

if !seenInitialRequest {
if !seenInitialRequest.Load() {
t.Fatalf("did not see initial request for 2nd device")
}
// the first request will not wait for the response before returning due to device A. Poll again
Expand Down
Loading