Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure device list updates are robust to race conditions and network failures #432

Merged
merged 3 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions state/device_data_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID)
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
if err != nil {
if err == sql.ErrNoRows {
// if there is no device data for this user, it's not an error.
Expand All @@ -70,6 +70,9 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
if !swap {
return nil // don't swap
}
// the caller will only look at sent, so make sure what is new is now in sent
result.DeviceLists.Sent = result.DeviceLists.New

// swap over the fields
writeBack := *result
writeBack.DeviceLists.Sent = result.DeviceLists.New
Expand Down Expand Up @@ -104,7 +107,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// select what already exists
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID)
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID)
if err != nil && err != sql.ErrNoRows {
return err
}
Expand Down
220 changes: 147 additions & 73 deletions state/device_data_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,20 @@ func assertDeviceData(t *testing.T, g, w internal.DeviceData) {
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits)
assertVal(t, "DeviceLists", g.DeviceLists, w.DeviceLists)
if w.DeviceLists.Sent != nil {
assertVal(t, "DeviceLists.Sent", g.DeviceLists.Sent, w.DeviceLists.Sent)
}
}

func TestDeviceDataTableSwaps(t *testing.T) {
// Tests OTKCounts and FallbackKeyTypes behaviour
func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
db, close := connectToDB(t)
defer close()
table := NewDeviceDataTable(db)
userID := "@bob"
userID := "@TestDeviceDataTableOTKCountAndFallbackKeyTypes"
deviceID := "BOB"

// test accumulating deltas
// these are individual updates from Synapse from /sync v2
deltas := []internal.DeviceData{
{
UserID: userID,
Expand All @@ -46,9 +49,6 @@ func TestDeviceDataTableSwaps(t *testing.T) {
UserID: userID,
DeviceID: deviceID,
FallbackKeyTypes: []string{"foobar"},
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
},
},
{
UserID: userID,
Expand All @@ -60,85 +60,157 @@ func TestDeviceDataTableSwaps(t *testing.T) {
{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"💣"}, nil),
},
},
}

// apply them
for _, dd := range deltas {
err := table.Upsert(&dd)
assertNoError(t, err)
}

// read them without swap, it should have replaced them correctly.
// Because sync v2 sends the complete OTK count and complete fallback key types
// every time, we always use the latest values. Because we aren't swapping, repeated
// reads produce the same result.
for i := 0; i < 3; i++ {
got, err := table.Select(userID, deviceID, false)
mustNotError(t, err)
want := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
OTKCounts: map[string]int{
"foo": 99,
},
FallbackKeyTypes: []string{"foobar"},
}
want.SetFallbackKeysChanged()
want.SetOTKCountChanged()
assertDeviceData(t, *got, want)
}
// now we swap the data. This still returns the same values, but the changed bits are no longer set
// on subsequent reads.
got, err := table.Select(userID, deviceID, true)
mustNotError(t, err)
want := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
OTKCounts: map[string]int{
"foo": 99,
},
FallbackKeyTypes: []string{"foobar"},
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
Sent: map[string]int{},
},
}
want.SetFallbackKeysChanged()
want.SetOTKCountChanged()
// check we can read-only select
assertDeviceData(t, *got, want)

// subsequent read
got, err = table.Select(userID, deviceID, false)
mustNotError(t, err)
want = internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
OTKCounts: map[string]int{
"foo": 99,
},
FallbackKeyTypes: []string{"foobar"},
}
assertDeviceData(t, *got, want)
}

// Tests the DeviceLists field
func TestDeviceDataTableDeviceList(t *testing.T) {
db, close := connectToDB(t)
defer close()
table := NewDeviceDataTable(db)
userID := "@TestDeviceDataTableDeviceList"
deviceID := "BOB"

// these are individual updates from Synapse from /sync v2
deltas := []internal.DeviceData{
{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
},
},
{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"💣"}, nil),
},
},
}
// apply them
for _, dd := range deltas {
err := table.Upsert(&dd)
assertNoError(t, err)
}

// check we can read-only select. This doesn't modify any fields.
for i := 0; i < 3; i++ {
got, err := table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceData(t, *got, want)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.MapStringInt{}, // until we "swap" we don't consume the New entries
},
})
}
// now swap-er-roo, at this point we still expect the "old" data,
// as it is the first time we swap
// now swap-er-roo, which shifts everything from New into Sent.
got, err := table.Select(userID, deviceID, true)
assertNoError(t, err)
assertDeviceData(t, *got, want)

// changed bits were reset when we swapped
want2 := want
want2.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
New: map[string]int{},
}
want2.ChangedBits = 0
want.ChangedBits = 0
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
},
})

// this is permanent, read-only views show this too.
// Since we have swapped previously, we now expect New to be empty
// and Sent to be set. Swap again to clear Sent.
got, err = table.Select(userID, deviceID, true)
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceData(t, *got, want2)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
},
})

// We now expect empty DeviceLists, as we swapped twice.
got, err = table.Select(userID, deviceID, false)
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
want3 := want2
want3.DeviceLists = internal.DeviceLists{
Sent: map[string]int{},
New: map[string]int{},
}
assertDeviceData(t, *got, want3)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.MapStringInt{},
},
})

// get back the original state
//err = table.DeleteDevice(userID, deviceID)
assertNoError(t, err)
for _, dd := range deltas {
err = table.Upsert(&dd)
assertNoError(t, err)
}
want.SetFallbackKeysChanged()
want.SetOTKCountChanged()
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceData(t, *got, want)

// swap once then add once so both sent and new are populated
// Moves Alice and Bob to Sent
_, err = table.Select(userID, deviceID, true)
// Move original state to Sent by swapping
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
},
})
// Add new entries to New before acknowledging Sent
err = table.Upsert(&internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
Expand All @@ -148,20 +220,18 @@ func TestDeviceDataTableSwaps(t *testing.T) {
})
assertNoError(t, err)

want.ChangedBits = 0

want4 := want
want4.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie"}),
}
// Without swapping, we expect Alice and Bob in Sent, and Bob and Charlie in New
// Reading without swapping does not move New->Sent, so returns the previous value
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceData(t, *got, want4)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
},
})

// another append then consume
// This results in dave to be added to New
// Append even more items to New
err = table.Upsert(&internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
Expand All @@ -170,24 +240,28 @@ func TestDeviceDataTableSwaps(t *testing.T) {
},
})
assertNoError(t, err)

// Now swap: all the combined items in New go into Sent
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
want5 := want4
want5.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}),
}
assertDeviceData(t, *got, want5)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"💣", "dave"}, []string{"charlie", "dave"}),
},
})

// Swapping again clears New
// Swapping again clears Sent out, and since nothing is in New we get an empty list
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
want5 = want4
want5.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}),
New: map[string]int{},
}
assertDeviceData(t, *got, want5)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.MapStringInt{},
},
})

// delete everything, no data returned
assertNoError(t, table.DeleteDevice(userID, deviceID))
Expand Down
42 changes: 42 additions & 0 deletions tests-integration/extensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,48 @@ func TestExtensionE2EE(t *testing.T) {
if time.Since(start) >= (500 * time.Millisecond) {
t.Fatalf("sync request did not return immediately with OTK counts")
}

// check that if we lose a device list update and restart from nothing, we see the same update
v2.queueResponse(alice, sync2.SyncResponse{
DeviceLists: struct {
Changed []string `json:"changed,omitempty"`
Left []string `json:"left,omitempty"`
}{
Changed: wantChanged,
Left: wantLeft,
},
})
v2.waitUntilEmpty(t, alice)
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10}, // doesn't matter
},
}},
// enable the E2EE extension
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
})
m.MatchResponse(t, res, m.MatchDeviceLists(wantChanged, wantLeft))
// we actually lost this update: start again and we should see it.
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10}, // doesn't matter
},
}},
// enable the E2EE extension
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
})
m.MatchResponse(t, res, m.MatchDeviceLists(wantChanged, wantLeft))

}

// Checks that to-device messages are passed from v2 to v3
Expand Down
Loading