From 2cd9a81ab23f083db1228f3f31d6401b1c14c995 Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Fri, 17 May 2024 09:37:38 +0100 Subject: [PATCH 1/9] Add DeviceListTable Shift over unit tests from DeviceDataTable --- state/device_data_table.go | 6 +- state/device_data_table_test.go | 154 -------------------------------- state/device_list_table.go | 110 +++++++++++++++++++++++ state/device_list_table_test.go | 108 ++++++++++++++++++++++ 4 files changed, 222 insertions(+), 156 deletions(-) create mode 100644 state/device_list_table.go create mode 100644 state/device_list_table_test.go diff --git a/state/device_data_table.go b/state/device_data_table.go index 1ed3e870..2c5576d7 100644 --- a/state/device_data_table.go +++ b/state/device_data_table.go @@ -21,7 +21,8 @@ type DeviceDataRow struct { } type DeviceDataTable struct { - db *sqlx.DB + db *sqlx.DB + deviceListTable *DeviceListTable } func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable { @@ -37,7 +38,8 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable { ALTER TABLE syncv3_device_data SET (fillfactor = 90); `) return &DeviceDataTable{ - db: db, + db: db, + deviceListTable: NewDeviceListTable(db), } } diff --git a/state/device_data_table_test.go b/state/device_data_table_test.go index d099eda0..b4fe6ad0 100644 --- a/state/device_data_table_test.go +++ b/state/device_data_table_test.go @@ -118,160 +118,6 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) { assertDeviceData(t, *got, want) } -// Tests the DeviceLists field -func TestDeviceDataTableDeviceList(t *testing.T) { - db, close := connectToDB(t) - defer close() - table := NewDeviceDataTable(db) - userID := "@TestDeviceDataTableDeviceList" - deviceID := "BOB" - - // these are individual updates from Synapse from /sync v2 - deltas := []internal.DeviceData{ - { - UserID: userID, - DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - New: internal.ToDeviceListChangesMap([]string{"alice"}, nil), - }, - }, - { - UserID: userID, - DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - New: internal.ToDeviceListChangesMap([]string{"💣"}, nil), - }, - }, - } - // apply them - for _, dd := range deltas { - err := table.Upsert(&dd) - assertNoError(t, err) - } - - // check we can read-only select. This doesn't modify any fields. - for i := 0; i < 3; i++ { - got, err := table.Select(userID, deviceID, false) - assertNoError(t, err) - assertDeviceData(t, *got, internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - Sent: internal.MapStringInt{}, // until we "swap" we don't consume the New entries - }, - }) - } - // now swap-er-roo, which shifts everything from New into Sent. - got, err := table.Select(userID, deviceID, true) - assertNoError(t, err) - assertDeviceData(t, *got, internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), - }, - }) - - // this is permanent, read-only views show this too. - got, err = table.Select(userID, deviceID, false) - assertNoError(t, err) - assertDeviceData(t, *got, internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), - }, - }) - - // We now expect empty DeviceLists, as we swapped twice. - got, err = table.Select(userID, deviceID, true) - assertNoError(t, err) - assertDeviceData(t, *got, internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - Sent: internal.MapStringInt{}, - }, - }) - - // get back the original state - assertNoError(t, err) - for _, dd := range deltas { - err = table.Upsert(&dd) - assertNoError(t, err) - } - // Move original state to Sent by swapping - got, err = table.Select(userID, deviceID, true) - assertNoError(t, err) - assertDeviceData(t, *got, internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), - }, - }) - // Add new entries to New before acknowledging Sent - err = table.Upsert(&internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie"}), - }, - }) - assertNoError(t, err) - - // Reading without swapping does not move New->Sent, so returns the previous value - got, err = table.Select(userID, deviceID, false) - assertNoError(t, err) - assertDeviceData(t, *got, internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), - }, - }) - - // Append even more items to New - err = table.Upsert(&internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - New: internal.ToDeviceListChangesMap([]string{"dave"}, []string{"dave"}), - }, - }) - assertNoError(t, err) - - // Now swap: all the combined items in New go into Sent - got, err = table.Select(userID, deviceID, true) - assertNoError(t, err) - assertDeviceData(t, *got, internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - Sent: internal.ToDeviceListChangesMap([]string{"💣", "dave"}, []string{"charlie", "dave"}), - }, - }) - - // Swapping again clears Sent out, and since nothing is in New we get an empty list - got, err = table.Select(userID, deviceID, true) - assertNoError(t, err) - assertDeviceData(t, *got, internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - DeviceLists: internal.DeviceLists{ - Sent: internal.MapStringInt{}, - }, - }) - - // delete everything, no data returned - assertNoError(t, table.DeleteDevice(userID, deviceID)) - got, err = table.Select(userID, deviceID, false) - assertNoError(t, err) - if got != nil { - t.Errorf("wanted no data, got %v", got) - } -} - func TestDeviceDataTableBitset(t *testing.T) { db, close := connectToDB(t) defer close() diff --git a/state/device_list_table.go b/state/device_list_table.go new file mode 100644 index 00000000..5baae563 --- /dev/null +++ b/state/device_list_table.go @@ -0,0 +1,110 @@ +package state + +import ( + "fmt" + + "github.com/getsentry/sentry-go" + "github.com/jmoiron/sqlx" + "github.com/matrix-org/sliding-sync/internal" + "github.com/matrix-org/sliding-sync/sqlutil" +) + +const ( + BucketNew = 1 + BucketSent = 2 +) + +type DeviceListTable struct { + db *sqlx.DB +} + +func NewDeviceListTable(db *sqlx.DB) *DeviceListTable { + db.MustExec(` + CREATE TABLE IF NOT EXISTS syncv3_device_list_updates ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_state SMALLINT NOT NULL, + bucket SMALLINT NOT NULL, + UNIQUE(user_id, device_id, target_user_id, bucket) + ); + -- make an index so selecting all the rows is faster + CREATE INDEX IF NOT EXISTS syncv3_device_list_updates_bucket_idx ON syncv3_device_list_updates(user_id, device_id, bucket); + -- Set the fillfactor to 90%, to allow for HOT updates (e.g. we only + -- change the data, not anything indexed like the id) + ALTER TABLE syncv3_device_list_updates SET (fillfactor = 90); + `) + return &DeviceListTable{ + db: db, + } +} + +// Upsert new device list changes. +func (t *DeviceListTable) Upsert(userID, deviceID string, deviceListChanges map[string]int) (err error) { + err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { + for targetUserID, targetState := range deviceListChanges { + if targetState != internal.DeviceListChanged && targetState != internal.DeviceListLeft { + sentry.CaptureException(fmt.Errorf("DeviceListTable.Upsert invalid target_state: %d this is a programming error", targetState)) + continue + } + _, err = txn.Exec( + `INSERT INTO syncv3_device_list_updates(user_id, device_id, target_user_id, target_state, bucket) VALUES($1,$2,$3,$4,$5) + ON CONFLICT (user_id, device_id, target_user_id, bucket) DO UPDATE SET target_state=$4`, + userID, deviceID, targetUserID, targetState, BucketNew, + ) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + sentry.CaptureException(err) + } + return +} + +// Select device list changes for this client. Returns a map of user_id => change enum. +func (t *DeviceListTable) Select(userID, deviceID string, swap bool) (result internal.MapStringInt, err error) { + err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { + if !swap { + // read only view, just return what we previously sent and don't do anything else. + result, err = t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketSent) + return err + } + + // delete the now acknowledged 'sent' data + _, err = txn.Exec(`DELETE FROM syncv3_device_list_updates WHERE user_id=$1 AND device_id=$2 AND bucket=$3`, userID, deviceID, BucketSent) + if err != nil { + return err + } + // grab any 'new' updates + result, err = t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketNew) + if err != nil { + return err + } + + // mark these 'new' updates as 'sent' + _, err = txn.Exec(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4`, BucketSent, userID, deviceID, BucketNew) + return err + }) + return +} + +func (t *DeviceListTable) selectDeviceListChangesInBucket(txn *sqlx.Tx, userID, deviceID string, bucket int) (result internal.MapStringInt, err error) { + rows, err := txn.Query(`SELECT target_user_id, target_state FROM syncv3_device_list_updates WHERE user_id=$1 AND device_id=$2 AND bucket=$3`, userID, deviceID, bucket) + if err != nil { + return nil, err + } + defer rows.Close() + result = make(internal.MapStringInt) + var targetUserID string + var targetState int + for rows.Next() { + if err := rows.Scan(&targetUserID, &targetState); err != nil { + return nil, err + } + result[targetUserID] = targetState + } + return result, rows.Err() +} diff --git a/state/device_list_table_test.go b/state/device_list_table_test.go new file mode 100644 index 00000000..79cf8438 --- /dev/null +++ b/state/device_list_table_test.go @@ -0,0 +1,108 @@ +package state + +import ( + "testing" + + "github.com/matrix-org/sliding-sync/internal" +) + +// Tests the DeviceLists table +func TestDeviceListTable(t *testing.T) { + db, close := connectToDB(t) + defer close() + table := NewDeviceListTable(db) + userID := "@TestDeviceListTable" + deviceID := "BOB" + + // these are individual updates from Synapse from /sync v2 + deltas := []internal.MapStringInt{ + { + "alice": internal.DeviceListChanged, + }, + { + "💣": internal.DeviceListChanged, + }, + } + // apply them + for _, dd := range deltas { + err := table.Upsert(userID, deviceID, dd) + assertNoError(t, err) + } + + // check we can read-only select. This doesn't modify any fields. + for i := 0; i < 3; i++ { + got, err := table.Select(userID, deviceID, false) + assertNoError(t, err) + // until we "swap" we don't consume the New entries + assertVal(t, "unexpected data on swapless select", got, internal.MapStringInt{}) + } + // now swap-er-roo, which shifts everything from New into Sent. + got, err := table.Select(userID, deviceID, true) + assertNoError(t, err) + assertVal(t, "did not select what was upserted on swap select", got, internal.MapStringInt{ + "alice": internal.DeviceListChanged, + "💣": internal.DeviceListChanged, + }) + + // this is permanent, read-only views show this too. + got, err = table.Select(userID, deviceID, false) + assertNoError(t, err) + assertVal(t, "swapless select did not return the same data as before", got, internal.MapStringInt{ + "alice": internal.DeviceListChanged, + "💣": internal.DeviceListChanged, + }) + + // We now expect empty DeviceLists, as we swapped twice. + got, err = table.Select(userID, deviceID, true) + assertNoError(t, err) + assertVal(t, "swap select did not return nothing", got, internal.MapStringInt{}) + + // get back the original state + assertNoError(t, err) + for _, dd := range deltas { + err = table.Upsert(userID, deviceID, dd) + assertNoError(t, err) + } + // Move original state to Sent by swapping + got, err = table.Select(userID, deviceID, true) + assertNoError(t, err) + assertVal(t, "did not select what was upserted on swap select", got, internal.MapStringInt{ + "alice": internal.DeviceListChanged, + "💣": internal.DeviceListChanged, + }) + // Add new entries to New before acknowledging Sent + err = table.Upsert(userID, deviceID, internal.MapStringInt{ + "💣": internal.DeviceListChanged, + "charlie": internal.DeviceListLeft, + }) + assertNoError(t, err) + + // Reading without swapping does not move New->Sent, so returns the previous value + got, err = table.Select(userID, deviceID, false) + assertNoError(t, err) + assertVal(t, "swapless select did not return the same data as before", got, internal.MapStringInt{ + "alice": internal.DeviceListChanged, + "💣": internal.DeviceListChanged, + }) + + // Append even more items to New + err = table.Upsert(userID, deviceID, internal.MapStringInt{ + "charlie": internal.DeviceListChanged, // we previously said "left" for charlie, so as "changed" is newer, we should see "changed" + "dave": internal.DeviceListLeft, + }) + assertNoError(t, err) + + // Now swap: all the combined items in New go into Sent + got, err = table.Select(userID, deviceID, true) + assertNoError(t, err) + assertVal(t, "swap select did not return combined new items", got, internal.MapStringInt{ + "💣": internal.DeviceListChanged, + "charlie": internal.DeviceListChanged, + "dave": internal.DeviceListLeft, + }) + + // Swapping again clears Sent out, and since nothing is in New we get an empty list + got, err = table.Select(userID, deviceID, true) + assertNoError(t, err) + assertVal(t, "swap select did not return combined new items", got, internal.MapStringInt{}) +} From b383ed0d82f2e0fa166aa6a4c4560a7d767ec8f2 Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Fri, 17 May 2024 13:45:14 +0100 Subject: [PATCH 2/9] Add migrations and refactor internal structs --- internal/device_data.go | 102 +++-------- state/device_data_table.go | 74 ++++---- state/device_data_table_test.go | 88 +++++----- state/device_list_table.go | 105 ++++++++---- state/device_list_table_test.go | 12 ++ .../20230802121023_device_data_jsonb_test.go | 5 +- .../20230814183302_cbor_device_data.go | 5 +- .../20230814183302_cbor_device_data_test.go | 80 ++++++++- .../20240517104423_device_list_table.go | 158 ++++++++++++++++++ .../20240517104423_device_list_table_test.go | 7 + sync2/handler2/handler.go | 13 +- sync3/extensions/e2ee.go | 7 +- 12 files changed, 439 insertions(+), 217 deletions(-) create mode 100644 state/migrations/20240517104423_device_list_table.go create mode 100644 state/migrations/20240517104423_device_list_table_test.go diff --git a/internal/device_data.go b/internal/device_data.go index 651fbf15..c5fe767b 100644 --- a/internal/device_data.go +++ b/internal/device_data.go @@ -1,9 +1,5 @@ package internal -import ( - "sync" -) - const ( bitOTKCount int = iota bitFallbackKeyTypes @@ -18,9 +14,22 @@ func isBitSet(n int, bit int) bool { return val > 0 } -// DeviceData contains useful data for this user's device. This list can be expanded without prompting -// schema changes. These values are upserted into the database and persisted forever. +// DeviceData contains useful data for this user's device. type DeviceData struct { + DeviceListChanges + DeviceKeyData + UserID string + DeviceID string +} + +// This is calculated from device_lists table +type DeviceListChanges struct { + DeviceListChanged []string + DeviceListLeft []string +} + +// This gets serialised as CBOR in device_data table +type DeviceKeyData struct { // Contains the latest device_one_time_keys_count values. // Set whenever this field arrives down the v2 poller, and it replaces what was previously there. OTKCounts MapStringInt `json:"otk"` @@ -28,95 +37,22 @@ type DeviceData struct { // Set whenever this field arrives down the v2 poller, and it replaces what was previously there. // If this is a nil slice this means no change. If this is an empty slice then this means the fallback key was used up. FallbackKeyTypes []string `json:"fallback"` - - DeviceLists DeviceLists `json:"dl"` - // bitset for which device data changes are present. They accumulate until they get swapped over // when they get reset ChangedBits int `json:"c"` - - UserID string - DeviceID string } -func (dd *DeviceData) SetOTKCountChanged() { +func (dd *DeviceKeyData) SetOTKCountChanged() { dd.ChangedBits = setBit(dd.ChangedBits, bitOTKCount) } -func (dd *DeviceData) SetFallbackKeysChanged() { +func (dd *DeviceKeyData) SetFallbackKeysChanged() { dd.ChangedBits = setBit(dd.ChangedBits, bitFallbackKeyTypes) } -func (dd *DeviceData) OTKCountChanged() bool { +func (dd *DeviceKeyData) OTKCountChanged() bool { return isBitSet(dd.ChangedBits, bitOTKCount) } -func (dd *DeviceData) FallbackKeysChanged() bool { +func (dd *DeviceKeyData) FallbackKeysChanged() bool { return isBitSet(dd.ChangedBits, bitFallbackKeyTypes) } - -type UserDeviceKey struct { - UserID string - DeviceID string -} - -type DeviceDataMap struct { - deviceDataMu *sync.Mutex - deviceDataMap map[UserDeviceKey]*DeviceData - Pos int64 -} - -func NewDeviceDataMap(startPos int64, devices []DeviceData) *DeviceDataMap { - ddm := &DeviceDataMap{ - deviceDataMu: &sync.Mutex{}, - deviceDataMap: make(map[UserDeviceKey]*DeviceData), - Pos: startPos, - } - for i, dd := range devices { - ddm.deviceDataMap[UserDeviceKey{ - UserID: dd.UserID, - DeviceID: dd.DeviceID, - }] = &devices[i] - } - return ddm -} - -func (d *DeviceDataMap) Get(userID, deviceID string) *DeviceData { - key := UserDeviceKey{ - UserID: userID, - DeviceID: deviceID, - } - d.deviceDataMu.Lock() - defer d.deviceDataMu.Unlock() - dd, ok := d.deviceDataMap[key] - if !ok { - return nil - } - return dd -} - -func (d *DeviceDataMap) Update(dd DeviceData) DeviceData { - key := UserDeviceKey{ - UserID: dd.UserID, - DeviceID: dd.DeviceID, - } - d.deviceDataMu.Lock() - defer d.deviceDataMu.Unlock() - existing, ok := d.deviceDataMap[key] - if !ok { - existing = &DeviceData{ - UserID: dd.UserID, - DeviceID: dd.DeviceID, - } - } - if dd.OTKCounts != nil { - existing.OTKCounts = dd.OTKCounts - } - if dd.FallbackKeyTypes != nil { - existing.FallbackKeyTypes = dd.FallbackKeyTypes - } - existing.DeviceLists = existing.DeviceLists.Combine(dd.DeviceLists) - - d.deviceDataMap[key] = existing - - return *existing -} diff --git a/state/device_data_table.go b/state/device_data_table.go index 2c5576d7..86388fee 100644 --- a/state/device_data_table.go +++ b/state/device_data_table.go @@ -15,9 +15,9 @@ type DeviceDataRow struct { ID int64 `db:"id"` UserID string `db:"user_id"` DeviceID string `db:"device_id"` - // This will contain internal.DeviceData serialised as JSON. It's stored in a single column as we don't + // This will contain internal.DeviceKeyData serialised as JSON. It's stored in a single column as we don't // need to perform searches on this data. - Data []byte `db:"data"` + KeyData []byte `db:"data"` } type DeviceDataTable struct { @@ -47,6 +47,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable { // This should only be called by the v3 HTTP APIs when servicing an E2EE extension request. func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) { err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { + // grab otk counts and fallback key types var row DeviceDataRow err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID) if err != nil { @@ -56,32 +57,38 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in } return err } + result = &internal.DeviceData{} + var keyData *internal.DeviceKeyData // unmarshal to swap - opts := cbor.DecOptions{ - MaxMapPairs: 1000000000, // 1 billion :( + if err = cbor.Unmarshal(row.KeyData, &keyData); err != nil { + return err } - decMode, err := opts.DecMode() + result.UserID = userID + result.DeviceID = deviceID + if keyData != nil { + result.DeviceKeyData = *keyData + } + + deviceListChanges, err := t.deviceListTable.SelectTx(txn, userID, deviceID, swap) if err != nil { return err } - if err = decMode.Unmarshal(row.Data, &result); err != nil { - return err + for targetUserID, targetState := range deviceListChanges { + switch targetState { + case internal.DeviceListChanged: + result.DeviceListChanged = append(result.DeviceListChanged, targetUserID) + case internal.DeviceListLeft: + result.DeviceListLeft = append(result.DeviceListLeft, targetUserID) + } } - result.UserID = userID - result.DeviceID = deviceID if !swap { return nil // don't swap } - // the caller will only look at sent, so make sure what is new is now in sent - result.DeviceLists.Sent = result.DeviceLists.New - // swap over the fields - writeBack := *result - writeBack.DeviceLists.Sent = result.DeviceLists.New - writeBack.DeviceLists.New = make(map[string]int) + writeBack := *keyData writeBack.ChangedBits = 0 - if reflect.DeepEqual(result, &writeBack) { + if reflect.DeepEqual(keyData, &writeBack) { // The update to the DB would be a no-op; don't bother with it. // This helps reduce write usage and the contention on the unique index for // the device_data table. @@ -99,14 +106,13 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in return } -func (t *DeviceDataTable) DeleteDevice(userID, deviceID string) error { - _, err := t.db.Exec(`DELETE FROM syncv3_device_data WHERE user_id = $1 AND device_id = $2`, userID, deviceID) - return err -} - // Upsert combines what is in the database for this user|device with the partial entry `dd` -func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) { +func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[string]int) (err error) { err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { + // Update device lists + if err = t.deviceListTable.UpsertTx(txn, dd.UserID, dd.DeviceID, deviceListChanges); err != nil { + return err + } // select what already exists var row DeviceDataRow err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID) @@ -114,30 +120,22 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) { return err } // unmarshal and combine - var tempDD internal.DeviceData - if len(row.Data) > 0 { - opts := cbor.DecOptions{ - MaxMapPairs: 1000000000, // 1 billion :( - } - decMode, err := opts.DecMode() - if err != nil { - return err - } - if err = decMode.Unmarshal(row.Data, &tempDD); err != nil { + var keyData internal.DeviceKeyData + if len(row.KeyData) > 0 { + if err = cbor.Unmarshal(row.KeyData, &keyData); err != nil { return err } } if dd.FallbackKeyTypes != nil { - tempDD.FallbackKeyTypes = dd.FallbackKeyTypes - tempDD.SetFallbackKeysChanged() + keyData.FallbackKeyTypes = dd.FallbackKeyTypes + keyData.SetFallbackKeysChanged() } if dd.OTKCounts != nil { - tempDD.OTKCounts = dd.OTKCounts - tempDD.SetOTKCountChanged() + keyData.OTKCounts = dd.OTKCounts + keyData.SetOTKCountChanged() } - tempDD.DeviceLists = tempDD.DeviceLists.Combine(dd.DeviceLists) - data, err := cbor.Marshal(tempDD) + data, err := cbor.Marshal(keyData) if err != nil { return err } diff --git a/state/device_data_table_test.go b/state/device_data_table_test.go index b4fe6ad0..eac73095 100644 --- a/state/device_data_table_test.go +++ b/state/device_data_table_test.go @@ -22,9 +22,6 @@ func assertDeviceData(t *testing.T, g, w internal.DeviceData) { assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes) assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts) assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits) - if w.DeviceLists.Sent != nil { - assertVal(t, "DeviceLists.Sent", g.DeviceLists.Sent, w.DeviceLists.Sent) - } } // Tests OTKCounts and FallbackKeyTypes behaviour @@ -40,21 +37,27 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) { { UserID: userID, DeviceID: deviceID, - OTKCounts: map[string]int{ - "foo": 100, - "bar": 92, + DeviceKeyData: internal.DeviceKeyData{ + OTKCounts: map[string]int{ + "foo": 100, + "bar": 92, + }, }, }, { - UserID: userID, - DeviceID: deviceID, - FallbackKeyTypes: []string{"foobar"}, + UserID: userID, + DeviceID: deviceID, + DeviceKeyData: internal.DeviceKeyData{ + FallbackKeyTypes: []string{"foobar"}, + }, }, { UserID: userID, DeviceID: deviceID, - OTKCounts: map[string]int{ - "foo": 99, + DeviceKeyData: internal.DeviceKeyData{ + OTKCounts: map[string]int{ + "foo": 99, + }, }, }, { @@ -65,7 +68,7 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) { // apply them for _, dd := range deltas { - err := table.Upsert(&dd) + err := table.Upsert(&dd, nil) assertNoError(t, err) } @@ -79,10 +82,12 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) { want := internal.DeviceData{ UserID: userID, DeviceID: deviceID, - OTKCounts: map[string]int{ - "foo": 99, + DeviceKeyData: internal.DeviceKeyData{ + OTKCounts: map[string]int{ + "foo": 99, + }, + FallbackKeyTypes: []string{"foobar"}, }, - FallbackKeyTypes: []string{"foobar"}, } want.SetFallbackKeysChanged() want.SetOTKCountChanged() @@ -95,10 +100,12 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) { want := internal.DeviceData{ UserID: userID, DeviceID: deviceID, - OTKCounts: map[string]int{ - "foo": 99, + DeviceKeyData: internal.DeviceKeyData{ + OTKCounts: map[string]int{ + "foo": 99, + }, + FallbackKeyTypes: []string{"foobar"}, }, - FallbackKeyTypes: []string{"foobar"}, } want.SetFallbackKeysChanged() want.SetOTKCountChanged() @@ -110,10 +117,12 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) { want = internal.DeviceData{ UserID: userID, DeviceID: deviceID, - OTKCounts: map[string]int{ - "foo": 99, + DeviceKeyData: internal.DeviceKeyData{ + OTKCounts: map[string]int{ + "foo": 99, + }, + FallbackKeyTypes: []string{"foobar"}, }, - FallbackKeyTypes: []string{"foobar"}, } assertDeviceData(t, *got, want) } @@ -127,29 +136,32 @@ func TestDeviceDataTableBitset(t *testing.T) { otkUpdate := internal.DeviceData{ UserID: userID, DeviceID: deviceID, - OTKCounts: map[string]int{ - "foo": 100, - "bar": 92, + DeviceKeyData: internal.DeviceKeyData{ + OTKCounts: map[string]int{ + "foo": 100, + "bar": 92, + }, }, - DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}}, } fallbakKeyUpdate := internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - FallbackKeyTypes: []string{"foo", "bar"}, - DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}}, + UserID: userID, + DeviceID: deviceID, + DeviceKeyData: internal.DeviceKeyData{ + FallbackKeyTypes: []string{"foo", "bar"}, + }, } bothUpdate := internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - FallbackKeyTypes: []string{"both"}, - OTKCounts: map[string]int{ - "both": 100, + UserID: userID, + DeviceID: deviceID, + DeviceKeyData: internal.DeviceKeyData{ + FallbackKeyTypes: []string{"both"}, + OTKCounts: map[string]int{ + "both": 100, + }, }, - DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}}, } - err := table.Upsert(&otkUpdate) + err := table.Upsert(&otkUpdate, nil) assertNoError(t, err) got, err := table.Select(userID, deviceID, true) assertNoError(t, err) @@ -161,7 +173,7 @@ func TestDeviceDataTableBitset(t *testing.T) { otkUpdate.ChangedBits = 0 assertDeviceData(t, *got, otkUpdate) // now same for fallback keys, but we won't swap them so it should return those diffs - err = table.Upsert(&fallbakKeyUpdate) + err = table.Upsert(&fallbakKeyUpdate, nil) assertNoError(t, err) fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts got, err = table.Select(userID, deviceID, false) @@ -173,7 +185,7 @@ func TestDeviceDataTableBitset(t *testing.T) { fallbakKeyUpdate.SetFallbackKeysChanged() assertDeviceData(t, *got, fallbakKeyUpdate) // updating both works - err = table.Upsert(&bothUpdate) + err = table.Upsert(&bothUpdate, nil) assertNoError(t, err) got, err = table.Select(userID, deviceID, true) assertNoError(t, err) diff --git a/state/device_list_table.go b/state/device_list_table.go index 5baae563..d1ab0b8c 100644 --- a/state/device_list_table.go +++ b/state/device_list_table.go @@ -14,6 +14,14 @@ const ( BucketSent = 2 ) +type DeviceListRow struct { + UserID string `db:"user_id"` + DeviceID string `db:"device_id"` + TargetUserID string `db:"target_user_id"` + TargetState int `db:"target_state"` + Bucket int `db:"bucket"` +} + type DeviceListTable struct { db *sqlx.DB } @@ -39,24 +47,9 @@ func NewDeviceListTable(db *sqlx.DB) *DeviceListTable { } } -// Upsert new device list changes. func (t *DeviceListTable) Upsert(userID, deviceID string, deviceListChanges map[string]int) (err error) { err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { - for targetUserID, targetState := range deviceListChanges { - if targetState != internal.DeviceListChanged && targetState != internal.DeviceListLeft { - sentry.CaptureException(fmt.Errorf("DeviceListTable.Upsert invalid target_state: %d this is a programming error", targetState)) - continue - } - _, err = txn.Exec( - `INSERT INTO syncv3_device_list_updates(user_id, device_id, target_user_id, target_state, bucket) VALUES($1,$2,$3,$4,$5) - ON CONFLICT (user_id, device_id, target_user_id, bucket) DO UPDATE SET target_state=$4`, - userID, deviceID, targetUserID, targetState, BucketNew, - ) - if err != nil { - return err - } - } - return nil + return t.UpsertTx(txn, userID, deviceID, deviceListChanges) }) if err != nil { sentry.CaptureException(err) @@ -64,33 +57,70 @@ func (t *DeviceListTable) Upsert(userID, deviceID string, deviceListChanges map[ return } -// Select device list changes for this client. Returns a map of user_id => change enum. -func (t *DeviceListTable) Select(userID, deviceID string, swap bool) (result internal.MapStringInt, err error) { - err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { - if !swap { - // read only view, just return what we previously sent and don't do anything else. - result, err = t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketSent) - return err - } - - // delete the now acknowledged 'sent' data - _, err = txn.Exec(`DELETE FROM syncv3_device_list_updates WHERE user_id=$1 AND device_id=$2 AND bucket=$3`, userID, deviceID, BucketSent) - if err != nil { - return err +// Upsert new device list changes. +func (t *DeviceListTable) UpsertTx(txn *sqlx.Tx, userID, deviceID string, deviceListChanges map[string]int) (err error) { + if len(deviceListChanges) == 0 { + return nil + } + var deviceListRows []DeviceListRow + for targetUserID, targetState := range deviceListChanges { + if targetState != internal.DeviceListChanged && targetState != internal.DeviceListLeft { + sentry.CaptureException(fmt.Errorf("DeviceListTable.Upsert invalid target_state: %d this is a programming error", targetState)) + continue } - // grab any 'new' updates - result, err = t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketNew) + deviceListRows = append(deviceListRows, DeviceListRow{ + UserID: userID, + DeviceID: deviceID, + TargetUserID: targetUserID, + TargetState: targetState, + Bucket: BucketNew, + }) + } + chunks := sqlutil.Chunkify(5, MaxPostgresParameters, DeviceListChunker(deviceListRows)) + for _, chunk := range chunks { + _, err := txn.NamedExec(` + INSERT INTO syncv3_device_list_updates(user_id, device_id, target_user_id, target_state, bucket) + VALUES(:user_id, :device_id, :target_user_id, :target_state, :bucket) + ON CONFLICT (user_id, device_id, target_user_id, bucket) DO UPDATE SET target_state = EXCLUDED.target_state`, chunk) if err != nil { return err } + } + return nil + return +} - // mark these 'new' updates as 'sent' - _, err = txn.Exec(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4`, BucketSent, userID, deviceID, BucketNew) +func (t *DeviceListTable) Select(userID, deviceID string, swap bool) (result internal.MapStringInt, err error) { + err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { + result, err = t.SelectTx(txn, userID, deviceID, swap) return err }) return } +// Select device list changes for this client. Returns a map of user_id => change enum. +func (t *DeviceListTable) SelectTx(txn *sqlx.Tx, userID, deviceID string, swap bool) (result internal.MapStringInt, err error) { + if !swap { + // read only view, just return what we previously sent and don't do anything else. + return t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketSent) + } + + // delete the now acknowledged 'sent' data + _, err = txn.Exec(`DELETE FROM syncv3_device_list_updates WHERE user_id=$1 AND device_id=$2 AND bucket=$3`, userID, deviceID, BucketSent) + if err != nil { + return nil, err + } + // grab any 'new' updates + result, err = t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketNew) + if err != nil { + return nil, err + } + + // mark these 'new' updates as 'sent' + _, err = txn.Exec(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4`, BucketSent, userID, deviceID, BucketNew) + return result, err +} + func (t *DeviceListTable) selectDeviceListChangesInBucket(txn *sqlx.Tx, userID, deviceID string, bucket int) (result internal.MapStringInt, err error) { rows, err := txn.Query(`SELECT target_user_id, target_state FROM syncv3_device_list_updates WHERE user_id=$1 AND device_id=$2 AND bucket=$3`, userID, deviceID, bucket) if err != nil { @@ -108,3 +138,12 @@ func (t *DeviceListTable) selectDeviceListChangesInBucket(txn *sqlx.Tx, userID, } return result, rows.Err() } + +type DeviceListChunker []DeviceListRow + +func (c DeviceListChunker) Len() int { + return len(c) +} +func (c DeviceListChunker) Subslice(i, j int) sqlutil.Chunker { + return c[i:j] +} diff --git a/state/device_list_table_test.go b/state/device_list_table_test.go index 79cf8438..e7c980fc 100644 --- a/state/device_list_table_test.go +++ b/state/device_list_table_test.go @@ -1,6 +1,7 @@ package state import ( + "fmt" "testing" "github.com/matrix-org/sliding-sync/internal" @@ -105,4 +106,15 @@ func TestDeviceListTable(t *testing.T) { got, err = table.Select(userID, deviceID, true) assertNoError(t, err) assertVal(t, "swap select did not return combined new items", got, internal.MapStringInt{}) + + // large updates work (chunker) + largeUpdate := internal.MapStringInt{} + for i := 0; i < 100000; i++ { + largeUpdate[fmt.Sprintf("user_%d", i)] = internal.DeviceListChanged + } + err = table.Upsert(userID, deviceID, largeUpdate) + assertNoError(t, err) + got, err = table.Select(userID, deviceID, true) + assertNoError(t, err) + assertVal(t, "swap select did not return large items", got, largeUpdate) } diff --git a/state/migrations/20230802121023_device_data_jsonb_test.go b/state/migrations/20230802121023_device_data_jsonb_test.go index eb6c5a8e..9a27bdcb 100644 --- a/state/migrations/20230802121023_device_data_jsonb_test.go +++ b/state/migrations/20230802121023_device_data_jsonb_test.go @@ -7,7 +7,6 @@ import ( "github.com/jmoiron/sqlx" _ "github.com/lib/pq" - "github.com/matrix-org/sliding-sync/internal" "github.com/matrix-org/sliding-sync/testutils" ) @@ -48,8 +47,8 @@ func TestJSONBMigration(t *testing.T) { defer tx.Commit() // insert some "invalid" data - dd := internal.DeviceData{ - DeviceLists: internal.DeviceLists{ + dd := OldDeviceData{ + DeviceLists: OldDeviceLists{ New: map[string]int{"@💣:localhost": 1}, Sent: map[string]int{}, }, diff --git a/state/migrations/20230814183302_cbor_device_data.go b/state/migrations/20230814183302_cbor_device_data.go index 02b7f2a4..b549d933 100644 --- a/state/migrations/20230814183302_cbor_device_data.go +++ b/state/migrations/20230814183302_cbor_device_data.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/fxamacker/cbor/v2" - "github.com/matrix-org/sliding-sync/internal" "github.com/matrix-org/sliding-sync/sync2" "github.com/pressly/goose/v3" ) @@ -59,7 +58,7 @@ func upCborDeviceData(ctx context.Context, tx *sql.Tx) error { } for dd, jsonBytes := range deviceDatas { - var data internal.DeviceData + var data OldDeviceData if err := json.Unmarshal(jsonBytes, &data); err != nil { return fmt.Errorf("failed to unmarshal JSON: %v -> %v", string(jsonBytes), err) } @@ -115,7 +114,7 @@ func downCborDeviceData(ctx context.Context, tx *sql.Tx) error { } for dd, cborBytes := range deviceDatas { - var data internal.DeviceData + var data OldDeviceData if err := cbor.Unmarshal(cborBytes, &data); err != nil { return fmt.Errorf("failed to unmarshal CBOR: %v", err) } diff --git a/state/migrations/20230814183302_cbor_device_data_test.go b/state/migrations/20230814183302_cbor_device_data_test.go index 2537ee82..a6b3b884 100644 --- a/state/migrations/20230814183302_cbor_device_data_test.go +++ b/state/migrations/20230814183302_cbor_device_data_test.go @@ -2,13 +2,15 @@ package migrations import ( "context" + "database/sql" "encoding/json" "reflect" "testing" + "github.com/fxamacker/cbor/v2" + "github.com/jmoiron/sqlx" _ "github.com/lib/pq" - "github.com/matrix-org/sliding-sync/internal" - "github.com/matrix-org/sliding-sync/state" + "github.com/matrix-org/sliding-sync/sqlutil" ) func TestCBORBMigration(t *testing.T) { @@ -30,9 +32,9 @@ func TestCBORBMigration(t *testing.T) { t.Fatal(err) } - rowData := []internal.DeviceData{ + rowData := []OldDeviceData{ { - DeviceLists: internal.DeviceLists{ + DeviceLists: OldDeviceLists{ New: map[string]int{"@bob:localhost": 2}, Sent: map[string]int{}, }, @@ -43,7 +45,7 @@ func TestCBORBMigration(t *testing.T) { UserID: "@alice:localhost", }, { - DeviceLists: internal.DeviceLists{ + DeviceLists: OldDeviceLists{ New: map[string]int{"@💣:localhost": 1, "@bomb:localhost": 2}, Sent: map[string]int{"@sent:localhost": 1}, }, @@ -78,9 +80,8 @@ func TestCBORBMigration(t *testing.T) { tx.Commit() // ensure we can now select it - table := state.NewDeviceDataTable(db) for _, want := range rowData { - got, err := table.Select(want.UserID, want.DeviceID, false) + got, err := OldDeviceDataTableSelect(db, want.UserID, want.DeviceID, false) if err != nil { t.Fatal(err) } @@ -101,7 +102,7 @@ func TestCBORBMigration(t *testing.T) { // ensure it is what we originally inserted for _, want := range rowData { - var got internal.DeviceData + var got OldDeviceData var gotBytes []byte err = tx.QueryRow(`SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, want.UserID, want.DeviceID).Scan(&gotBytes) if err != nil { @@ -119,3 +120,66 @@ func TestCBORBMigration(t *testing.T) { tx.Commit() } + +type OldDeviceDataRow struct { + ID int64 `db:"id"` + UserID string `db:"user_id"` + DeviceID string `db:"device_id"` + // This will contain internal.DeviceData serialised as JSON. It's stored in a single column as we don't + // need to perform searches on this data. + Data []byte `db:"data"` +} + +func OldDeviceDataTableSelect(db *sqlx.DB, userID, deviceID string, swap bool) (result *OldDeviceData, err error) { + err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error { + var row OldDeviceDataRow + err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID) + if err != nil { + if err == sql.ErrNoRows { + // if there is no device data for this user, it's not an error. + return nil + } + return err + } + // unmarshal to swap + opts := cbor.DecOptions{ + MaxMapPairs: 1000000000, // 1 billion :( + } + decMode, err := opts.DecMode() + if err != nil { + return err + } + if err = decMode.Unmarshal(row.Data, &result); err != nil { + return err + } + result.UserID = userID + result.DeviceID = deviceID + if !swap { + return nil // don't swap + } + // the caller will only look at sent, so make sure what is new is now in sent + result.DeviceLists.Sent = result.DeviceLists.New + + // swap over the fields + writeBack := *result + writeBack.DeviceLists.Sent = result.DeviceLists.New + writeBack.DeviceLists.New = make(map[string]int) + writeBack.ChangedBits = 0 + + if reflect.DeepEqual(result, &writeBack) { + // The update to the DB would be a no-op; don't bother with it. + // This helps reduce write usage and the contention on the unique index for + // the device_data table. + return nil + } + // re-marshal and write + data, err := cbor.Marshal(writeBack) + if err != nil { + return err + } + + _, err = txn.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID) + return err + }) + return +} diff --git a/state/migrations/20240517104423_device_list_table.go b/state/migrations/20240517104423_device_list_table.go new file mode 100644 index 00000000..3081bf43 --- /dev/null +++ b/state/migrations/20240517104423_device_list_table.go @@ -0,0 +1,158 @@ +package migrations + +import ( + "context" + "database/sql" + "time" + + "github.com/fxamacker/cbor/v2" + "github.com/lib/pq" + "github.com/matrix-org/sliding-sync/internal" + "github.com/matrix-org/sliding-sync/state" + "github.com/pressly/goose/v3" +) + +type OldDeviceData struct { + // Contains the latest device_one_time_keys_count values. + // Set whenever this field arrives down the v2 poller, and it replaces what was previously there. + OTKCounts internal.MapStringInt `json:"otk"` + // Contains the latest device_unused_fallback_key_types value + // Set whenever this field arrives down the v2 poller, and it replaces what was previously there. + // If this is a nil slice this means no change. If this is an empty slice then this means the fallback key was used up. + FallbackKeyTypes []string `json:"fallback"` + + DeviceLists OldDeviceLists `json:"dl"` + + // bitset for which device data changes are present. They accumulate until they get swapped over + // when they get reset + ChangedBits int `json:"c"` + + UserID string + DeviceID string +} + +type OldDeviceLists struct { + // map user_id -> DeviceList enum + New internal.MapStringInt `json:"n"` + Sent internal.MapStringInt `json:"s"` +} + +func init() { + goose.AddMigrationContext(upDeviceListTable, downDeviceListTable) +} + +func upDeviceListTable(ctx context.Context, tx *sql.Tx) error { + // create the table. It's a bit gross we need to dupe the schema here, but this is the first migration to + // add a new table like this. + _, err := tx.Exec(` + CREATE TABLE IF NOT EXISTS syncv3_device_list_updates ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_state SMALLINT NOT NULL, + bucket SMALLINT NOT NULL, + UNIQUE(user_id, device_id, target_user_id, bucket) + ); + -- make an index so selecting all the rows is faster + CREATE INDEX IF NOT EXISTS syncv3_device_list_updates_bucket_idx ON syncv3_device_list_updates(user_id, device_id, bucket); + -- Set the fillfactor to 90%, to allow for HOT updates (e.g. we only + -- change the data, not anything indexed like the id) + ALTER TABLE syncv3_device_list_updates SET (fillfactor = 90); + `) + if err != nil { + return err + } + + var count int + if err = tx.QueryRow(`SELECT count(*) FROM syncv3_device_data`).Scan(&count); err != nil { + return err + } + logger.Info().Int("count", count).Msg("transferring device list data for devices") + + // scan for existing CBOR (streaming as the CBOR can be large) and for each row: + rows, err := tx.Query(`SELECT user_id, device_id, data FROM syncv3_device_data`) + if err != nil { + return err + } + defer rows.Close() + var userID string + var deviceID string + var data []byte + // every N seconds log an update + updateFrequency := time.Second * 2 + lastUpdate := time.Now() + i := 0 + for rows.Next() { + i++ + if time.Since(lastUpdate) > updateFrequency { + logger.Info().Msgf("%d/%d process device list data", i, count) + lastUpdate = time.Now() + } + // * deserialise the CBOR + if err := rows.Scan(&userID, &deviceID, &data); err != nil { + return err + } + result, err := deserialiseCBOR(data) + if err != nil { + return err + } + + // * transfer the device lists to the new device lists table + // uses a bulk copy that lib/pq supports + stmt, err := tx.Prepare(pq.CopyIn("syncv3_device_list_updates", "user_id", "device_id", "target_user_id", "target_state", "bucket")) + if err != nil { + return err + } + for targetUser, targetState := range result.DeviceLists.New { + if _, err := stmt.Exec(userID, deviceID, targetUser, targetState, state.BucketNew); err != nil { + return err + } + } + for targetUser, targetState := range result.DeviceLists.Sent { + if _, err := stmt.Exec(userID, deviceID, targetUser, targetState, state.BucketSent); err != nil { + return err + } + } + if _, err = stmt.Exec(); err != nil { + return err + } + if err = stmt.Close(); err != nil { + return err + } + + // * delete the device lists from the CBOR and update + result.DeviceLists = OldDeviceLists{ + New: make(internal.MapStringInt), + Sent: make(internal.MapStringInt), + } + data, err := cbor.Marshal(result) + if err != nil { + return err + } + _, err = tx.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID) + if err != nil { + return err + } + } + return rows.Err() +} + +func downDeviceListTable(ctx context.Context, tx *sql.Tx) error { + // no-op: we'll drop the device list updates but still work correctly as new/sent are still in the cbor but are empty + return nil +} + +func deserialiseCBOR(data []byte) (*OldDeviceData, error) { + opts := cbor.DecOptions{ + MaxMapPairs: 1000000000, // 1 billion :( + } + decMode, err := opts.DecMode() + if err != nil { + return nil, err + } + var result *OldDeviceData + if err = decMode.Unmarshal(data, &result); err != nil { + return nil, err + } + return result, nil +} diff --git a/state/migrations/20240517104423_device_list_table_test.go b/state/migrations/20240517104423_device_list_table_test.go new file mode 100644 index 00000000..75212d8d --- /dev/null +++ b/state/migrations/20240517104423_device_list_table_test.go @@ -0,0 +1,7 @@ +package migrations + +import "testing" + +func TestDeviceListTableMigration(t *testing.T) { + +} diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index 40512489..b8f29757 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -234,15 +234,14 @@ func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCo defer wg.Done() // some of these fields may be set partialDD := internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - OTKCounts: otkCounts, - FallbackKeyTypes: fallbackKeyTypes, - DeviceLists: internal.DeviceLists{ - New: deviceListChanges, + UserID: userID, + DeviceID: deviceID, + DeviceKeyData: internal.DeviceKeyData{ + OTKCounts: otkCounts, + FallbackKeyTypes: fallbackKeyTypes, }, } - err := h.Store.DeviceDataTable.Upsert(&partialDD) + err := h.Store.DeviceDataTable.Upsert(&partialDD, deviceListChanges) if err != nil { logger.Err(err).Str("user", userID).Msg("failed to upsert device data") internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) diff --git a/sync3/extensions/e2ee.go b/sync3/extensions/e2ee.go index 91550457..01711d6a 100644 --- a/sync3/extensions/e2ee.go +++ b/sync3/extensions/e2ee.go @@ -70,11 +70,10 @@ func (r *E2EERequest) ProcessInitial(ctx context.Context, res *Response, extCtx extRes.OTKCounts = dd.OTKCounts hasUpdates = true } - changed, left := internal.DeviceListChangesArrays(dd.DeviceLists.Sent) - if len(changed) > 0 || len(left) > 0 { + if len(dd.DeviceListChanged) > 0 || len(dd.DeviceListLeft) > 0 { extRes.DeviceLists = &E2EEDeviceList{ - Changed: changed, - Left: left, + Changed: dd.DeviceListChanged, + Left: dd.DeviceListLeft, } hasUpdates = true } From b6f2f9d2730022ac4ac9d11b47065463df1996fd Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Fri, 17 May 2024 14:48:08 +0100 Subject: [PATCH 3/9] Use a CURSOR --- .../20240517104423_device_list_table.go | 91 ++++++++++++------- .../20240517104423_device_list_table_test.go | 69 +++++++++++++- 2 files changed, 127 insertions(+), 33 deletions(-) diff --git a/state/migrations/20240517104423_device_list_table.go b/state/migrations/20240517104423_device_list_table.go index 3081bf43..477ac985 100644 --- a/state/migrations/20240517104423_device_list_table.go +++ b/state/migrations/20240517104423_device_list_table.go @@ -3,11 +3,13 @@ package migrations import ( "context" "database/sql" + "fmt" + "strings" "time" "github.com/fxamacker/cbor/v2" - "github.com/lib/pq" "github.com/matrix-org/sliding-sync/internal" + "github.com/matrix-org/sliding-sync/sqlutil" "github.com/matrix-org/sliding-sync/state" "github.com/pressly/goose/v3" ) @@ -52,13 +54,7 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error { target_state SMALLINT NOT NULL, bucket SMALLINT NOT NULL, UNIQUE(user_id, device_id, target_user_id, bucket) - ); - -- make an index so selecting all the rows is faster - CREATE INDEX IF NOT EXISTS syncv3_device_list_updates_bucket_idx ON syncv3_device_list_updates(user_id, device_id, bucket); - -- Set the fillfactor to 90%, to allow for HOT updates (e.g. we only - -- change the data, not anything indexed like the id) - ALTER TABLE syncv3_device_list_updates SET (fillfactor = 90); - `) + );`) if err != nil { return err } @@ -69,12 +65,12 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error { } logger.Info().Int("count", count).Msg("transferring device list data for devices") - // scan for existing CBOR (streaming as the CBOR can be large) and for each row: - rows, err := tx.Query(`SELECT user_id, device_id, data FROM syncv3_device_data`) + // scan for existing CBOR (streaming as the CBOR with cursors as it can be large) + _, err = tx.Exec(`DECLARE device_data_migration_cursor CURSOR FOR SELECT user_id, device_id, data FROM syncv3_device_data`) if err != nil { return err } - defer rows.Close() + defer tx.Exec("CLOSE device_data_migration_cursor") var userID string var deviceID string var data []byte @@ -82,42 +78,73 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error { updateFrequency := time.Second * 2 lastUpdate := time.Now() i := 0 - for rows.Next() { + for { + // logging i++ if time.Since(lastUpdate) > updateFrequency { logger.Info().Msgf("%d/%d process device list data", i, count) lastUpdate = time.Now() } - // * deserialise the CBOR - if err := rows.Scan(&userID, &deviceID, &data); err != nil { + + if err := tx.QueryRow( + `FETCH NEXT FROM device_data_migration_cursor`, + ).Scan(&userID, &deviceID, &data); err != nil { + if err == sql.ErrNoRows { + // End of rows. + break + } return err } + + // * deserialise the CBOR result, err := deserialiseCBOR(data) if err != nil { return err } // * transfer the device lists to the new device lists table - // uses a bulk copy that lib/pq supports - stmt, err := tx.Prepare(pq.CopyIn("syncv3_device_list_updates", "user_id", "device_id", "target_user_id", "target_state", "bucket")) - if err != nil { - return err - } + var deviceListRows []state.DeviceListRow for targetUser, targetState := range result.DeviceLists.New { - if _, err := stmt.Exec(userID, deviceID, targetUser, targetState, state.BucketNew); err != nil { - return err - } + deviceListRows = append(deviceListRows, state.DeviceListRow{ + UserID: userID, + DeviceID: deviceID, + TargetUserID: targetUser, + TargetState: targetState, + Bucket: state.BucketNew, + }) } for targetUser, targetState := range result.DeviceLists.Sent { - if _, err := stmt.Exec(userID, deviceID, targetUser, targetState, state.BucketSent); err != nil { - return err - } - } - if _, err = stmt.Exec(); err != nil { - return err + deviceListRows = append(deviceListRows, state.DeviceListRow{ + UserID: userID, + DeviceID: deviceID, + TargetUserID: targetUser, + TargetState: targetState, + Bucket: state.BucketSent, + }) } - if err = stmt.Close(); err != nil { - return err + chunks := sqlutil.Chunkify(5, state.MaxPostgresParameters, state.DeviceListChunker(deviceListRows)) + for _, chunk := range chunks { + var placeholders []string + var vals []interface{} + listChunk := chunk.(state.DeviceListChunker) + for i, deviceListRow := range listChunk { + placeholders = append(placeholders, fmt.Sprintf("($%d,$%d,$%d,$%d,$%d)", + i*5+1, + i*5+2, + i*5+3, + i*5+4, + i*5+5, + )) + vals = append(vals, deviceListRow.UserID, deviceListRow.DeviceID, deviceListRow.TargetUserID, deviceListRow.TargetState, deviceListRow.Bucket) + } + query := fmt.Sprintf( + `INSERT INTO syncv3_device_list_updates(user_id, device_id, target_user_id, target_state, bucket) VALUES %s`, + strings.Join(placeholders, ","), + ) + _, err = tx.ExecContext(ctx, query, vals...) + if err != nil { + return fmt.Errorf("failed to bulk insert: %s", err) + } } // * delete the device lists from the CBOR and update @@ -129,12 +156,12 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error { if err != nil { return err } - _, err = tx.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID) + _, err = tx.ExecContext(ctx, `UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID) if err != nil { return err } } - return rows.Err() + return nil } func downDeviceListTable(ctx context.Context, tx *sql.Tx) error { diff --git a/state/migrations/20240517104423_device_list_table_test.go b/state/migrations/20240517104423_device_list_table_test.go index 75212d8d..a9d53dda 100644 --- a/state/migrations/20240517104423_device_list_table_test.go +++ b/state/migrations/20240517104423_device_list_table_test.go @@ -1,7 +1,74 @@ package migrations -import "testing" +import ( + "context" + "testing" + + "github.com/fxamacker/cbor/v2" +) func TestDeviceListTableMigration(t *testing.T) { + ctx := context.Background() + db, close := connectToDB(t) + defer close() + + // Create the table in the old format (data = JSONB instead of BYTEA) + // and insert some data: we'll make sure that this data is preserved + // after migrating. + _, err := db.Exec(`CREATE TABLE IF NOT EXISTS syncv3_device_data ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + data BYTEA NOT NULL, + UNIQUE(user_id, device_id) + );`) + if err != nil { + t.Fatalf("failed to create table: %s", err) + } + + // insert old data + rowData := []OldDeviceData{ + { + DeviceLists: OldDeviceLists{ + New: map[string]int{"@bob:localhost": 2}, + Sent: map[string]int{}, + }, + ChangedBits: 2, + OTKCounts: map[string]int{"bar": 42}, + FallbackKeyTypes: []string{"narp"}, + DeviceID: "ALICE", + UserID: "@alice:localhost", + }, + { + DeviceLists: OldDeviceLists{ + New: map[string]int{"@💣:localhost": 1, "@bomb:localhost": 2}, + Sent: map[string]int{"@sent:localhost": 1}, + }, + OTKCounts: map[string]int{"foo": 100}, + FallbackKeyTypes: []string{"yep"}, + DeviceID: "BOB", + UserID: "@bob:localhost", + }, + } + for _, data := range rowData { + blob, err := cbor.Marshal(data) + if err != nil { + t.Fatal(err) + } + _, err = db.ExecContext(ctx, `INSERT INTO syncv3_device_data (user_id, device_id, data) VALUES ($1, $2, $3)`, data.UserID, data.DeviceID, blob) + if err != nil { + t.Fatal(err) + } + } + + // now migrate and ensure we didn't lose any data + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + err = upDeviceListTable(ctx, tx) + if err != nil { + t.Fatal(err) + } + tx.Commit() } From fcd9b490f9ab9f022555565d2d8423558eb121da Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Fri, 17 May 2024 15:20:18 +0100 Subject: [PATCH 4/9] Debug logging --- state/migrations/20240517104423_device_list_table.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/state/migrations/20240517104423_device_list_table.go b/state/migrations/20240517104423_device_list_table.go index 477ac985..1d427c0d 100644 --- a/state/migrations/20240517104423_device_list_table.go +++ b/state/migrations/20240517104423_device_list_table.go @@ -143,6 +143,8 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error { ) _, err = tx.ExecContext(ctx, query, vals...) if err != nil { + fmt.Println(query) + fmt.Println(vals...) return fmt.Errorf("failed to bulk insert: %s", err) } } From 5028f93f83a2e1eb02966d2d7e9e2a7e4b7d8bf1 Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Fri, 17 May 2024 15:37:45 +0100 Subject: [PATCH 5/9] If there's no updates, don't insert anything --- state/migrations/20240517104423_device_list_table.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/state/migrations/20240517104423_device_list_table.go b/state/migrations/20240517104423_device_list_table.go index 1d427c0d..5e6d7712 100644 --- a/state/migrations/20240517104423_device_list_table.go +++ b/state/migrations/20240517104423_device_list_table.go @@ -122,6 +122,9 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error { Bucket: state.BucketSent, }) } + if len(deviceListRows) == 0 { + continue + } chunks := sqlutil.Chunkify(5, state.MaxPostgresParameters, state.DeviceListChunker(deviceListRows)) for _, chunk := range chunks { var placeholders []string @@ -143,8 +146,6 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error { ) _, err = tx.ExecContext(ctx, query, vals...) if err != nil { - fmt.Println(query) - fmt.Println(vals...) return fmt.Errorf("failed to bulk insert: %s", err) } } From af1f34861ea10b94db651b7e07f928285dd05934 Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Fri, 17 May 2024 16:09:02 +0100 Subject: [PATCH 6/9] Ensure txns are closed so we can wipe the db for other tests --- state/migrations/20231108122539_clear_stuck_invites_test.go | 1 + testutils/db.go | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/state/migrations/20231108122539_clear_stuck_invites_test.go b/state/migrations/20231108122539_clear_stuck_invites_test.go index 616fd62e..a4b4257d 100644 --- a/state/migrations/20231108122539_clear_stuck_invites_test.go +++ b/state/migrations/20231108122539_clear_stuck_invites_test.go @@ -151,6 +151,7 @@ func TestClearStuckInvites(t *testing.T) { if err != nil { t.Fatal(err) } + defer tx.Rollback() // users in room B (bob) and F (doris) should be reset. tokens, err := tokensTable.TokenForEachDevice(tx) diff --git a/testutils/db.go b/testutils/db.go index d4106aa1..af791021 100644 --- a/testutils/db.go +++ b/testutils/db.go @@ -1,11 +1,13 @@ package testutils import ( + "context" "database/sql" "fmt" "os" "os/exec" "os/user" + "time" ) var Quiet = false @@ -64,7 +66,9 @@ func PrepareDBConnectionString() (connStr string) { if err != nil { panic(err) } - _, err = db.Exec(`DROP SCHEMA public CASCADE; + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + _, err = db.ExecContext(ctx, `DROP SCHEMA public CASCADE; CREATE SCHEMA public;`) if err != nil { panic(err) From 35c9fd4d95eadd5e082b18ac5debb4fc9936a9ed Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Fri, 17 May 2024 17:10:05 +0100 Subject: [PATCH 7/9] Remove spurious return --- state/device_list_table.go | 1 - 1 file changed, 1 deletion(-) diff --git a/state/device_list_table.go b/state/device_list_table.go index d1ab0b8c..889081e7 100644 --- a/state/device_list_table.go +++ b/state/device_list_table.go @@ -87,7 +87,6 @@ func (t *DeviceListTable) UpsertTx(txn *sqlx.Tx, userID, deviceID string, device } } return nil - return } func (t *DeviceListTable) Select(userID, deviceID string, swap bool) (result internal.MapStringInt, err error) { From fdbebaea68dfa3fb0b4e7c91dc260ee53251286d Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Mon, 20 May 2024 08:22:48 +0100 Subject: [PATCH 8/9] Some review comments; swap to UPDATE..RETURNING --- state/device_data_table.go | 16 ++++++++-------- state/device_data_table_test.go | 8 ++++---- state/device_list_table.go | 27 +++++++++++++++++++++------ sync2/handler2/handler.go | 14 ++++---------- 4 files changed, 37 insertions(+), 28 deletions(-) diff --git a/state/device_data_table.go b/state/device_data_table.go index 86388fee..2c53295e 100644 --- a/state/device_data_table.go +++ b/state/device_data_table.go @@ -107,15 +107,15 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in } // Upsert combines what is in the database for this user|device with the partial entry `dd` -func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[string]int) (err error) { +func (t *DeviceDataTable) Upsert(userID, deviceID string, keys internal.DeviceKeyData, deviceListChanges map[string]int) (err error) { err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { // Update device lists - if err = t.deviceListTable.UpsertTx(txn, dd.UserID, dd.DeviceID, deviceListChanges); err != nil { + if err = t.deviceListTable.UpsertTx(txn, userID, deviceID, deviceListChanges); err != nil { return err } // select what already exists var row DeviceDataRow - err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID) + err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID) if err != nil && err != sql.ErrNoRows { return err } @@ -126,12 +126,12 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[ return err } } - if dd.FallbackKeyTypes != nil { - keyData.FallbackKeyTypes = dd.FallbackKeyTypes + if keys.FallbackKeyTypes != nil { + keyData.FallbackKeyTypes = keys.FallbackKeyTypes keyData.SetFallbackKeysChanged() } - if dd.OTKCounts != nil { - keyData.OTKCounts = dd.OTKCounts + if keys.OTKCounts != nil { + keyData.OTKCounts = keys.OTKCounts keyData.SetOTKCountChanged() } @@ -142,7 +142,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[ _, err = txn.Exec( `INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3) ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3`, - dd.UserID, dd.DeviceID, data, + userID, deviceID, data, ) return err }) diff --git a/state/device_data_table_test.go b/state/device_data_table_test.go index eac73095..39e2d05f 100644 --- a/state/device_data_table_test.go +++ b/state/device_data_table_test.go @@ -68,7 +68,7 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) { // apply them for _, dd := range deltas { - err := table.Upsert(&dd, nil) + err := table.Upsert(dd.UserID, dd.DeviceID, dd.DeviceKeyData, nil) assertNoError(t, err) } @@ -161,7 +161,7 @@ func TestDeviceDataTableBitset(t *testing.T) { }, } - err := table.Upsert(&otkUpdate, nil) + err := table.Upsert(otkUpdate.UserID, otkUpdate.DeviceID, otkUpdate.DeviceKeyData, nil) assertNoError(t, err) got, err := table.Select(userID, deviceID, true) assertNoError(t, err) @@ -173,7 +173,7 @@ func TestDeviceDataTableBitset(t *testing.T) { otkUpdate.ChangedBits = 0 assertDeviceData(t, *got, otkUpdate) // now same for fallback keys, but we won't swap them so it should return those diffs - err = table.Upsert(&fallbakKeyUpdate, nil) + err = table.Upsert(fallbakKeyUpdate.UserID, fallbakKeyUpdate.DeviceID, fallbakKeyUpdate.DeviceKeyData, nil) assertNoError(t, err) fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts got, err = table.Select(userID, deviceID, false) @@ -185,7 +185,7 @@ func TestDeviceDataTableBitset(t *testing.T) { fallbakKeyUpdate.SetFallbackKeysChanged() assertDeviceData(t, *got, fallbakKeyUpdate) // updating both works - err = table.Upsert(&bothUpdate, nil) + err = table.Upsert(bothUpdate.UserID, bothUpdate.DeviceID, bothUpdate.DeviceKeyData, nil) assertNoError(t, err) got, err = table.Select(userID, deviceID, true) assertNoError(t, err) diff --git a/state/device_list_table.go b/state/device_list_table.go index 889081e7..350a6066 100644 --- a/state/device_list_table.go +++ b/state/device_list_table.go @@ -109,15 +109,30 @@ func (t *DeviceListTable) SelectTx(txn *sqlx.Tx, userID, deviceID string, swap b if err != nil { return nil, err } - // grab any 'new' updates - result, err = t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketNew) + // grab any 'new' updates and atomically mark these as 'sent'. + // NB: we must not SELECT then UPDATE, because a 'new' row could be inserted after the SELECT and before the UPDATE, which + // would then be incorrectly moved to 'sent' without being returned to the client, dropping the data. This happens because + // the default transaction level is 'read committed', which /allows/ nonrepeatable reads which is: + // > A transaction re-reads data it has previously read and finds that data has been modified by another transaction (that committed since the initial read). + // We could change the isolation level but this incurs extra performance costs in addition to serialisation errors which + // need to be handled. It's easier to just use UPDATE .. RETURNING. Note that we don't require UPDATE .. RETURNING to be + // atomic in any way, it's just that we need to guarantee each things SELECTed is also UPDATEd (so in the scenario above, + // we don't care if the SELECT includes or excludes the 'new' row, but if it is SELECTed it MUST be UPDATEd). + rows, err := txn.Query(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4 RETURNING target_user_id, target_state`, BucketSent, userID, deviceID, BucketNew) if err != nil { return nil, err } - - // mark these 'new' updates as 'sent' - _, err = txn.Exec(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4`, BucketSent, userID, deviceID, BucketNew) - return result, err + defer rows.Close() + result = make(internal.MapStringInt) + var targetUserID string + var targetState int + for rows.Next() { + if err := rows.Scan(&targetUserID, &targetState); err != nil { + return nil, err + } + result[targetUserID] = targetState + } + return result, rows.Err() } func (t *DeviceListTable) selectDeviceListChangesInBucket(txn *sqlx.Tx, userID, deviceID string, bucket int) (result internal.MapStringInt, err error) { diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index b8f29757..7b6ecf5b 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -232,16 +232,10 @@ func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCo wg.Add(1) h.e2eeWorkerPool.Queue(func() { defer wg.Done() - // some of these fields may be set - partialDD := internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - DeviceKeyData: internal.DeviceKeyData{ - OTKCounts: otkCounts, - FallbackKeyTypes: fallbackKeyTypes, - }, - } - err := h.Store.DeviceDataTable.Upsert(&partialDD, deviceListChanges) + err := h.Store.DeviceDataTable.Upsert(userID, deviceID, internal.DeviceKeyData{ + OTKCounts: otkCounts, + FallbackKeyTypes: fallbackKeyTypes, + }, deviceListChanges) if err != nil { logger.Err(err).Str("user", userID).Msg("failed to upsert device data") internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) From 3fc49bd4ea405cfe49628d67a6afeae227efaa13 Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Mon, 20 May 2024 08:37:09 +0100 Subject: [PATCH 9/9] Migration review comments --- .../20240517104423_device_list_table.go | 6 +- .../20240517104423_device_list_table_test.go | 85 +++++++++++++++++++ 2 files changed, 89 insertions(+), 2 deletions(-) diff --git a/state/migrations/20240517104423_device_list_table.go b/state/migrations/20240517104423_device_list_table.go index 5e6d7712..315b79c2 100644 --- a/state/migrations/20240517104423_device_list_table.go +++ b/state/migrations/20240517104423_device_list_table.go @@ -168,8 +168,10 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error { } func downDeviceListTable(ctx context.Context, tx *sql.Tx) error { - // no-op: we'll drop the device list updates but still work correctly as new/sent are still in the cbor but are empty - return nil + // no-op: we'll drop the device list updates but still work correctly as new/sent are still in the cbor but are empty. + // This will lose some device list updates. + _, err := tx.Exec(`DROP TABLE IF EXISTS syncv3_device_list_updates`) + return err } func deserialiseCBOR(data []byte) (*OldDeviceData, error) { diff --git a/state/migrations/20240517104423_device_list_table_test.go b/state/migrations/20240517104423_device_list_table_test.go index a9d53dda..88571cf4 100644 --- a/state/migrations/20240517104423_device_list_table_test.go +++ b/state/migrations/20240517104423_device_list_table_test.go @@ -2,9 +2,12 @@ package migrations import ( "context" + "reflect" "testing" "github.com/fxamacker/cbor/v2" + "github.com/matrix-org/sliding-sync/internal" + "github.com/matrix-org/sliding-sync/state" ) func TestDeviceListTableMigration(t *testing.T) { @@ -71,4 +74,86 @@ func TestDeviceListTableMigration(t *testing.T) { } tx.Commit() + wantSents := []internal.DeviceData{ + { + UserID: "@alice:localhost", + DeviceID: "ALICE", + DeviceKeyData: internal.DeviceKeyData{ + OTKCounts: internal.MapStringInt{ + "bar": 42, + }, + FallbackKeyTypes: []string{"narp"}, + ChangedBits: 2, + }, + }, + { + UserID: "@bob:localhost", + DeviceID: "BOB", + DeviceListChanges: internal.DeviceListChanges{ + DeviceListChanged: []string{"@sent:localhost"}, + }, + DeviceKeyData: internal.DeviceKeyData{ + OTKCounts: internal.MapStringInt{ + "foo": 100, + }, + FallbackKeyTypes: []string{"yep"}, + }, + }, + } + + table := state.NewDeviceDataTable(db) + for _, wantSent := range wantSents { + gotSent, err := table.Select(wantSent.UserID, wantSent.DeviceID, false) + if err != nil { + t.Fatal(err) + } + assertVal(t, "'sent' data was corrupted during the migration", *gotSent, wantSent) + } + + wantNews := []internal.DeviceData{ + { + UserID: "@alice:localhost", + DeviceID: "ALICE", + DeviceListChanges: internal.DeviceListChanges{ + DeviceListLeft: []string{"@bob:localhost"}, + }, + DeviceKeyData: internal.DeviceKeyData{ + OTKCounts: internal.MapStringInt{ + "bar": 42, + }, + FallbackKeyTypes: []string{"narp"}, + ChangedBits: 2, + }, + }, + { + UserID: "@bob:localhost", + DeviceID: "BOB", + DeviceListChanges: internal.DeviceListChanges{ + DeviceListChanged: []string{"@💣:localhost"}, + DeviceListLeft: []string{"@bomb:localhost"}, + }, + DeviceKeyData: internal.DeviceKeyData{ + OTKCounts: internal.MapStringInt{ + "foo": 100, + }, + FallbackKeyTypes: []string{"yep"}, + }, + }, + } + + for _, wantNew := range wantNews { + gotNew, err := table.Select(wantNew.UserID, wantNew.DeviceID, true) + if err != nil { + t.Fatal(err) + } + assertVal(t, "'new' data was corrupted during the migration", *gotNew, wantNew) + } + +} + +func assertVal(t *testing.T, msg string, got, want interface{}) { + t.Helper() + if !reflect.DeepEqual(got, want) { + t.Errorf("%s: got\n%#v\nwant\n%#v", msg, got, want) + } }