diff --git a/sync2/devices_table.go b/sync2/devices_table.go index ab6ddb9f..30e68df6 100644 --- a/sync2/devices_table.go +++ b/sync2/devices_table.go @@ -32,8 +32,8 @@ func NewDevicesTable(db *sqlx.DB) *DevicesTable { // InsertDevice creates a new devices row with a blank since token if no such row // exists. Otherwise, it does nothing. -func (t *DevicesTable) InsertDevice(userID, deviceID string) error { - _, err := t.db.Exec( +func (t *DevicesTable) InsertDevice(txn *sqlx.Tx, userID, deviceID string) error { + _, err := txn.Exec( ` INSERT INTO syncv3_sync2_devices(user_id, device_id, since) VALUES($1,$2,$3) ON CONFLICT (user_id, device_id) DO NOTHING`, userID, deviceID, "", diff --git a/sync2/devices_table_test.go b/sync2/devices_table_test.go index 5f70846a..1db3564d 100644 --- a/sync2/devices_table_test.go +++ b/sync2/devices_table_test.go @@ -2,6 +2,7 @@ package sync2 import ( "github.com/jmoiron/sqlx" + "github.com/matrix-org/sliding-sync/sqlutil" "os" "sort" "testing" @@ -41,18 +42,25 @@ func TestDevicesTableSinceColumn(t *testing.T) { aliceSecret1 := "mysecret1" aliceSecret2 := "mysecret2" - t.Log("Insert two tokens for Alice.") - aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, time.Now()) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } - aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, time.Now()) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } + var aliceToken, aliceToken2 *Token + _ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) { + t.Log("Insert two tokens for Alice.") + aliceToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, time.Now()) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + aliceToken2, err = tokens.Insert(txn, aliceSecret2, alice, aliceDevice, time.Now()) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } - t.Log("Add a devices row for Alice") - err = devices.InsertDevice(alice, aliceDevice) + t.Log("Add a devices row for Alice") + err = devices.InsertDevice(txn, alice, aliceDevice) + if err != nil { + t.Fatalf("Failed to Insert device: %s", err) + } + return nil + }) t.Log("Pretend we're about to start a poller. Fetch Alice's token along with the since value tracked by the devices table.") accessToken, since, err := tokens.GetTokenAndSince(alice, aliceDevice, aliceToken.AccessTokenHash) @@ -104,40 +112,50 @@ func TestTokenForEachDevice(t *testing.T) { chris := "chris" chrisDevice := "chris_desktop" - t.Log("Add a device for Alice, Bob and Chris.") - err := devices.InsertDevice(alice, aliceDevice) - if err != nil { - t.Fatalf("InsertDevice returned error: %s", err) - } - err = devices.InsertDevice(bob, bobDevice) - if err != nil { - t.Fatalf("InsertDevice returned error: %s", err) - } - err = devices.InsertDevice(chris, chrisDevice) - if err != nil { - t.Fatalf("InsertDevice returned error: %s", err) - } + _ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error { + t.Log("Add a device for Alice, Bob and Chris.") + err := devices.InsertDevice(txn, alice, aliceDevice) + if err != nil { + t.Fatalf("InsertDevice returned error: %s", err) + } + err = devices.InsertDevice(txn, bob, bobDevice) + if err != nil { + t.Fatalf("InsertDevice returned error: %s", err) + } + err = devices.InsertDevice(txn, chris, chrisDevice) + if err != nil { + t.Fatalf("InsertDevice returned error: %s", err) + } + return nil + }) t.Log("Mark Alice's device with a since token.") sinceValue := "s-1-2-3-4" - devices.UpdateDeviceSince(alice, aliceDevice, sinceValue) - - t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.") - aliceLastSeen1 := time.Now() - _, err = tokens.Insert("alice_secret", alice, aliceDevice, aliceLastSeen1) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } - aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute) - aliceToken2, err := tokens.Insert("alice_secret2", alice, aliceDevice, aliceLastSeen2) + err := devices.UpdateDeviceSince(alice, aliceDevice, sinceValue) if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } - bobToken, err := tokens.Insert("bob_secret", bob, bobDevice, time.Time{}) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) + t.Fatalf("UpdateDeviceSince returned error: %s", err) } + var aliceToken2, bobToken *Token + _ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error { + t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.") + aliceLastSeen1 := time.Now() + _, err = tokens.Insert(txn, "alice_secret", alice, aliceDevice, aliceLastSeen1) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute) + aliceToken2, err = tokens.Insert(txn, "alice_secret2", alice, aliceDevice, aliceLastSeen2) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + bobToken, err = tokens.Insert(txn, "bob_secret", bob, bobDevice, time.Time{}) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + return nil + }) + t.Log("Fetch a token for every device") gotTokens, err := tokens.TokenForEachDevice(nil) if err != nil { diff --git a/sync2/handler2/handler_test.go b/sync2/handler2/handler_test.go index 13a2597f..20f064ab 100644 --- a/sync2/handler2/handler_test.go +++ b/sync2/handler2/handler_test.go @@ -1,6 +1,8 @@ package handler2_test import ( + "github.com/jmoiron/sqlx" + "github.com/matrix-org/sliding-sync/sqlutil" "os" "reflect" "sync" @@ -131,11 +133,15 @@ func TestHandlerFreshEnsurePolling(t *testing.T) { deviceID := "ALICE" token := "aliceToken" - // the device and token needs to already exist prior to EnsurePolling - err = v2Store.DevicesTable.InsertDevice(alice, deviceID) - assertNoError(t, err) - tok, err := v2Store.TokensTable.Insert(token, alice, deviceID, time.Now()) - assertNoError(t, err) + var tok *sync2.Token + sqlutil.WithTransaction(v2Store.DB, func(txn *sqlx.Tx) error { + // the device and token needs to already exist prior to EnsurePolling + err = v2Store.DevicesTable.InsertDevice(txn, alice, deviceID) + assertNoError(t, err) + tok, err = v2Store.TokensTable.Insert(txn, token, alice, deviceID, time.Now()) + assertNoError(t, err) + return nil + }) payloadInitialSyncComplete := pubsub.V2InitialSyncComplete{ UserID: alice, diff --git a/sync2/tokens_table.go b/sync2/tokens_table.go index 066c6508..961e7bdd 100644 --- a/sync2/tokens_table.go +++ b/sync2/tokens_table.go @@ -171,10 +171,10 @@ func (t *TokensTable) TokenForEachDevice(txn *sqlx.Tx) (tokens []TokenForPoller, } // Insert a new token into the table. -func (t *TokensTable) Insert(plaintextToken, userID, deviceID string, lastSeen time.Time) (*Token, error) { +func (t *TokensTable) Insert(txn *sqlx.Tx, plaintextToken, userID, deviceID string, lastSeen time.Time) (*Token, error) { hashedToken := hashToken(plaintextToken) encToken := t.encrypt(plaintextToken) - _, err := t.db.Exec( + _, err := txn.Exec( `INSERT INTO syncv3_sync2_tokens(token_hash, token_encrypted, user_id, device_id, last_seen) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (token_hash) DO NOTHING;`, diff --git a/sync2/tokens_table_test.go b/sync2/tokens_table_test.go index 9249077e..c787b2a0 100644 --- a/sync2/tokens_table_test.go +++ b/sync2/tokens_table_test.go @@ -1,6 +1,8 @@ package sync2 import ( + "github.com/jmoiron/sqlx" + "github.com/matrix-org/sliding-sync/sqlutil" "testing" "time" ) @@ -26,27 +28,31 @@ func TestTokensTable(t *testing.T) { aliceSecret1 := "mysecret1" aliceToken1FirstSeen := time.Now() - // Test a single token - t.Log("Insert a new token from Alice.") - aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } - - t.Log("The returned Token struct should have been populated correctly.") - assertEqualTokens(t, tokens, aliceToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) - - t.Log("Reinsert the same token.") - reinsertedToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } + var aliceToken, reinsertedToken *Token + _ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) { + // Test a single token + t.Log("Insert a new token from Alice.") + aliceToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + + t.Log("The returned Token struct should have been populated correctly.") + assertEqualTokens(t, tokens, aliceToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) + + t.Log("Reinsert the same token.") + reinsertedToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + return nil + }) t.Log("This should yield an equal Token struct.") assertEqualTokens(t, tokens, reinsertedToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) t.Log("Try to mark Alice's token as being used after an hour.") - err = tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1FirstSeen.Add(time.Hour)) + err := tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1FirstSeen.Add(time.Hour)) if err != nil { t.Fatalf("Failed to update last seen: %s", err) } @@ -74,17 +80,20 @@ func TestTokensTable(t *testing.T) { } assertEqualTokens(t, tokens, fetchedToken, aliceSecret1, alice, aliceDevice, aliceToken1LastSeen) - // Test a second token for Alice - t.Log("Insert a second token for Alice.") - aliceSecret2 := "mysecret2" - aliceToken2FirstSeen := aliceToken1LastSeen.Add(time.Minute) - aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } - - t.Log("The returned Token struct should have been populated correctly.") - assertEqualTokens(t, tokens, aliceToken2, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen) + _ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error { + // Test a second token for Alice + t.Log("Insert a second token for Alice.") + aliceSecret2 := "mysecret2" + aliceToken2FirstSeen := aliceToken1LastSeen.Add(time.Minute) + aliceToken2, err := tokens.Insert(txn, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + + t.Log("The returned Token struct should have been populated correctly.") + assertEqualTokens(t, tokens, aliceToken2, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen) + return nil + }) } func TestDeletingTokens(t *testing.T) { @@ -94,11 +103,15 @@ func TestDeletingTokens(t *testing.T) { t.Log("Insert a new token from Alice.") accessToken := "mytoken" - token, err := tokens.Insert(accessToken, "@bob:builders.com", "device", time.Time{}) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } + var token *Token + err := sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) { + token, err = tokens.Insert(txn, accessToken, "@bob:builders.com", "device", time.Time{}) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + return nil + }) t.Log("We should be able to fetch this token without error.") _, err = tokens.Token(accessToken) if err != nil { diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go index 0f2c1009..98d4e610 100644 --- a/sync3/handler/handler.go +++ b/sync3/handler/handler.go @@ -420,14 +420,14 @@ func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string, logger var token *sync2.Token err = sqlutil.WithTransaction(h.V2Store.DB, func(txn *sqlx.Tx) error { // Create a brand-new row for this token. - token, err = h.V2Store.TokensTable.Insert(accessToken, userID, deviceID, time.Now()) + token, err = h.V2Store.TokensTable.Insert(txn, accessToken, userID, deviceID, time.Now()) if err != nil { logger.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 token") return err } // Ensure we have a device row for this token. - err = h.V2Store.DevicesTable.InsertDevice(userID, deviceID) + err = h.V2Store.DevicesTable.InsertDevice(txn, userID, deviceID) if err != nil { log.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 device") return err