diff --git a/state/device_data_table.go b/state/device_data_table.go index cd9fc34b..1ed3e870 100644 --- a/state/device_data_table.go +++ b/state/device_data_table.go @@ -46,7 +46,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable { func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) { err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { var row DeviceDataRow - err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID) + err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID) if err != nil { if err == sql.ErrNoRows { // if there is no device data for this user, it's not an error. @@ -70,6 +70,9 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in if !swap { return nil // don't swap } + // the caller will only look at sent, so make sure what is new is now in sent + result.DeviceLists.Sent = result.DeviceLists.New + // swap over the fields writeBack := *result writeBack.DeviceLists.Sent = result.DeviceLists.New @@ -104,7 +107,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) { err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { // select what already exists var row DeviceDataRow - err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID) + err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID) if err != nil && err != sql.ErrNoRows { return err } diff --git a/state/device_data_table_test.go b/state/device_data_table_test.go index 00c53cd0..d099eda0 100644 --- a/state/device_data_table_test.go +++ b/state/device_data_table_test.go @@ -22,17 +22,20 @@ func assertDeviceData(t *testing.T, g, w internal.DeviceData) { assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes) assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts) assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits) - assertVal(t, "DeviceLists", g.DeviceLists, w.DeviceLists) + if w.DeviceLists.Sent != nil { + assertVal(t, "DeviceLists.Sent", g.DeviceLists.Sent, w.DeviceLists.Sent) + } } -func TestDeviceDataTableSwaps(t *testing.T) { +// Tests OTKCounts and FallbackKeyTypes behaviour +func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) { db, close := connectToDB(t) defer close() table := NewDeviceDataTable(db) - userID := "@bob" + userID := "@TestDeviceDataTableOTKCountAndFallbackKeyTypes" deviceID := "BOB" - // test accumulating deltas + // these are individual updates from Synapse from /sync v2 deltas := []internal.DeviceData{ { UserID: userID, @@ -46,9 +49,6 @@ func TestDeviceDataTableSwaps(t *testing.T) { UserID: userID, DeviceID: deviceID, FallbackKeyTypes: []string{"foobar"}, - DeviceLists: internal.DeviceLists{ - New: internal.ToDeviceListChangesMap([]string{"alice"}, nil), - }, }, { UserID: userID, @@ -60,16 +60,38 @@ func TestDeviceDataTableSwaps(t *testing.T) { { UserID: userID, DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - New: internal.ToDeviceListChangesMap([]string{"💣"}, nil), - }, }, } + + // apply them for _, dd := range deltas { err := table.Upsert(&dd) assertNoError(t, err) } + // read them without swap, it should have replaced them correctly. + // Because sync v2 sends the complete OTK count and complete fallback key types + // every time, we always use the latest values. Because we aren't swapping, repeated + // reads produce the same result. + for i := 0; i < 3; i++ { + got, err := table.Select(userID, deviceID, false) + mustNotError(t, err) + want := internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + OTKCounts: map[string]int{ + "foo": 99, + }, + FallbackKeyTypes: []string{"foobar"}, + } + want.SetFallbackKeysChanged() + want.SetOTKCountChanged() + assertDeviceData(t, *got, want) + } + // now we swap the data. This still returns the same values, but the changed bits are no longer set + // on subsequent reads. + got, err := table.Select(userID, deviceID, true) + mustNotError(t, err) want := internal.DeviceData{ UserID: userID, DeviceID: deviceID, @@ -77,68 +99,118 @@ func TestDeviceDataTableSwaps(t *testing.T) { "foo": 99, }, FallbackKeyTypes: []string{"foobar"}, - DeviceLists: internal.DeviceLists{ - New: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), - Sent: map[string]int{}, - }, } want.SetFallbackKeysChanged() want.SetOTKCountChanged() - // check we can read-only select + assertDeviceData(t, *got, want) + + // subsequent read + got, err = table.Select(userID, deviceID, false) + mustNotError(t, err) + want = internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + OTKCounts: map[string]int{ + "foo": 99, + }, + FallbackKeyTypes: []string{"foobar"}, + } + assertDeviceData(t, *got, want) +} + +// Tests the DeviceLists field +func TestDeviceDataTableDeviceList(t *testing.T) { + db, close := connectToDB(t) + defer close() + table := NewDeviceDataTable(db) + userID := "@TestDeviceDataTableDeviceList" + deviceID := "BOB" + + // these are individual updates from Synapse from /sync v2 + deltas := []internal.DeviceData{ + { + UserID: userID, + DeviceID: deviceID, + DeviceLists: internal.DeviceLists{ + New: internal.ToDeviceListChangesMap([]string{"alice"}, nil), + }, + }, + { + UserID: userID, + DeviceID: deviceID, + DeviceLists: internal.DeviceLists{ + New: internal.ToDeviceListChangesMap([]string{"💣"}, nil), + }, + }, + } + // apply them + for _, dd := range deltas { + err := table.Upsert(&dd) + assertNoError(t, err) + } + + // check we can read-only select. This doesn't modify any fields. for i := 0; i < 3; i++ { got, err := table.Select(userID, deviceID, false) assertNoError(t, err) - assertDeviceData(t, *got, want) + assertDeviceData(t, *got, internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + DeviceLists: internal.DeviceLists{ + Sent: internal.MapStringInt{}, // until we "swap" we don't consume the New entries + }, + }) } - // now swap-er-roo, at this point we still expect the "old" data, - // as it is the first time we swap + // now swap-er-roo, which shifts everything from New into Sent. got, err := table.Select(userID, deviceID, true) assertNoError(t, err) - assertDeviceData(t, *got, want) - - // changed bits were reset when we swapped - want2 := want - want2.DeviceLists = internal.DeviceLists{ - Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), - New: map[string]int{}, - } - want2.ChangedBits = 0 - want.ChangedBits = 0 + assertDeviceData(t, *got, internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + DeviceLists: internal.DeviceLists{ + Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), + }, + }) // this is permanent, read-only views show this too. - // Since we have swapped previously, we now expect New to be empty - // and Sent to be set. Swap again to clear Sent. - got, err = table.Select(userID, deviceID, true) + got, err = table.Select(userID, deviceID, false) assertNoError(t, err) - assertDeviceData(t, *got, want2) + assertDeviceData(t, *got, internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + DeviceLists: internal.DeviceLists{ + Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), + }, + }) // We now expect empty DeviceLists, as we swapped twice. - got, err = table.Select(userID, deviceID, false) + got, err = table.Select(userID, deviceID, true) assertNoError(t, err) - want3 := want2 - want3.DeviceLists = internal.DeviceLists{ - Sent: map[string]int{}, - New: map[string]int{}, - } - assertDeviceData(t, *got, want3) + assertDeviceData(t, *got, internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + DeviceLists: internal.DeviceLists{ + Sent: internal.MapStringInt{}, + }, + }) // get back the original state - //err = table.DeleteDevice(userID, deviceID) assertNoError(t, err) for _, dd := range deltas { err = table.Upsert(&dd) assertNoError(t, err) } - want.SetFallbackKeysChanged() - want.SetOTKCountChanged() - got, err = table.Select(userID, deviceID, false) - assertNoError(t, err) - assertDeviceData(t, *got, want) - - // swap once then add once so both sent and new are populated - // Moves Alice and Bob to Sent - _, err = table.Select(userID, deviceID, true) + // Move original state to Sent by swapping + got, err = table.Select(userID, deviceID, true) assertNoError(t, err) + assertDeviceData(t, *got, internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + DeviceLists: internal.DeviceLists{ + Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), + }, + }) + // Add new entries to New before acknowledging Sent err = table.Upsert(&internal.DeviceData{ UserID: userID, DeviceID: deviceID, @@ -148,20 +220,18 @@ func TestDeviceDataTableSwaps(t *testing.T) { }) assertNoError(t, err) - want.ChangedBits = 0 - - want4 := want - want4.DeviceLists = internal.DeviceLists{ - Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), - New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie"}), - } - // Without swapping, we expect Alice and Bob in Sent, and Bob and Charlie in New + // Reading without swapping does not move New->Sent, so returns the previous value got, err = table.Select(userID, deviceID, false) assertNoError(t, err) - assertDeviceData(t, *got, want4) + assertDeviceData(t, *got, internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + DeviceLists: internal.DeviceLists{ + Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), + }, + }) - // another append then consume - // This results in dave to be added to New + // Append even more items to New err = table.Upsert(&internal.DeviceData{ UserID: userID, DeviceID: deviceID, @@ -170,24 +240,28 @@ func TestDeviceDataTableSwaps(t *testing.T) { }, }) assertNoError(t, err) + + // Now swap: all the combined items in New go into Sent got, err = table.Select(userID, deviceID, true) assertNoError(t, err) - want5 := want4 - want5.DeviceLists = internal.DeviceLists{ - Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), - New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}), - } - assertDeviceData(t, *got, want5) + assertDeviceData(t, *got, internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + DeviceLists: internal.DeviceLists{ + Sent: internal.ToDeviceListChangesMap([]string{"💣", "dave"}, []string{"charlie", "dave"}), + }, + }) - // Swapping again clears New + // Swapping again clears Sent out, and since nothing is in New we get an empty list got, err = table.Select(userID, deviceID, true) assertNoError(t, err) - want5 = want4 - want5.DeviceLists = internal.DeviceLists{ - Sent: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}), - New: map[string]int{}, - } - assertDeviceData(t, *got, want5) + assertDeviceData(t, *got, internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + DeviceLists: internal.DeviceLists{ + Sent: internal.MapStringInt{}, + }, + }) // delete everything, no data returned assertNoError(t, table.DeleteDevice(userID, deviceID)) diff --git a/tests-integration/extensions_test.go b/tests-integration/extensions_test.go index 9d30d978..b2f271f7 100644 --- a/tests-integration/extensions_test.go +++ b/tests-integration/extensions_test.go @@ -193,6 +193,48 @@ func TestExtensionE2EE(t *testing.T) { if time.Since(start) >= (500 * time.Millisecond) { t.Fatalf("sync request did not return immediately with OTK counts") } + + // check that if we lose a device list update and restart from nothing, we see the same update + v2.queueResponse(alice, sync2.SyncResponse{ + DeviceLists: struct { + Changed []string `json:"changed,omitempty"` + Left []string `json:"left,omitempty"` + }{ + Changed: wantChanged, + Left: wantLeft, + }, + }) + v2.waitUntilEmpty(t, alice) + res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{ + Lists: map[string]sync3.RequestList{"a": { + Ranges: sync3.SliceRanges{ + [2]int64{0, 10}, // doesn't matter + }, + }}, + // enable the E2EE extension + Extensions: extensions.Request{ + E2EE: &extensions.E2EERequest{ + Core: extensions.Core{Enabled: &boolTrue}, + }, + }, + }) + m.MatchResponse(t, res, m.MatchDeviceLists(wantChanged, wantLeft)) + // we actually lost this update: start again and we should see it. + res = v3.mustDoV3Request(t, aliceToken, sync3.Request{ + Lists: map[string]sync3.RequestList{"a": { + Ranges: sync3.SliceRanges{ + [2]int64{0, 10}, // doesn't matter + }, + }}, + // enable the E2EE extension + Extensions: extensions.Request{ + E2EE: &extensions.E2EERequest{ + Core: extensions.Core{Enabled: &boolTrue}, + }, + }, + }) + m.MatchResponse(t, res, m.MatchDeviceLists(wantChanged, wantLeft)) + } // Checks that to-device messages are passed from v2 to v3