diff --git a/spec/unit/rust-crypto/RoomEncryptor.spec.ts b/spec/unit/rust-crypto/RoomEncryptor.spec.ts index 66d21f9d5da..de6863cbce1 100644 --- a/spec/unit/rust-crypto/RoomEncryptor.spec.ts +++ b/spec/unit/rust-crypto/RoomEncryptor.spec.ts @@ -16,16 +16,125 @@ * / */ -import { HistoryVisibility as RustHistoryVisibility } from "@matrix-org/matrix-sdk-crypto-wasm"; - -import { HistoryVisibility } from "../../../src"; -import { toRustHistoryVisibility } from "../../../src/rust-crypto/RoomEncryptor"; - -it.each([ - [HistoryVisibility.Invited, RustHistoryVisibility.Invited], - [HistoryVisibility.Joined, RustHistoryVisibility.Joined], - [HistoryVisibility.Shared, RustHistoryVisibility.Shared], - [HistoryVisibility.WorldReadable, RustHistoryVisibility.WorldReadable], -])("JS HistoryVisibility to Rust HistoryVisibility: converts %s to %s", (historyVisibility, expected) => { - expect(toRustHistoryVisibility(historyVisibility)).toBe(expected); +import { + Curve25519PublicKey, + Ed25519PublicKey, + HistoryVisibility as RustHistoryVisibility, + IdentityKeys, + OlmMachine, +} from "@matrix-org/matrix-sdk-crypto-wasm"; +import { Mocked } from "jest-mock"; + +import { HistoryVisibility, MatrixEvent, Room, RoomMember } from "../../../src"; +import { RoomEncryptor, toRustHistoryVisibility } from "../../../src/rust-crypto/RoomEncryptor"; +import { KeyClaimManager } from "../../../src/rust-crypto/KeyClaimManager"; +import { defer } from "../../../src/utils"; +import { OutgoingRequestsManager } from "../../../src/rust-crypto/OutgoingRequestsManager"; + +describe("RoomEncryptor", () => { + describe("History Visibility", () => { + it.each([ + [HistoryVisibility.Invited, RustHistoryVisibility.Invited], + [HistoryVisibility.Joined, RustHistoryVisibility.Joined], + [HistoryVisibility.Shared, RustHistoryVisibility.Shared], + [HistoryVisibility.WorldReadable, RustHistoryVisibility.WorldReadable], + ])("JS HistoryVisibility to Rust HistoryVisibility: converts %s to %s", (historyVisibility, expected) => { + expect(toRustHistoryVisibility(historyVisibility)).toBe(expected); + }); + }); + + describe("RoomEncryptor", () => { + /** The room encryptor under test */ + let roomEncryptor: RoomEncryptor; + + let mockOlmMachine: Mocked; + let mockKeyClaimManager: Mocked; + let mockOutgoingRequestManager: Mocked; + let mockRoom: Mocked; + + function createMockEvent(text: string): Mocked { + return { + getTxnId: jest.fn().mockReturnValue(""), + getType: jest.fn().mockReturnValue("m.room.message"), + getContent: jest.fn().mockReturnValue({ + body: text, + msgtype: "m.text", + }), + makeEncrypted: jest.fn().mockReturnValue(undefined), + } as unknown as Mocked; + } + + beforeEach(() => { + mockOlmMachine = { + identityKeys: { + curve25519: { + toBase64: jest.fn().mockReturnValue("curve25519"), + } as unknown as Curve25519PublicKey, + ed25519: { + toBase64: jest.fn().mockReturnValue("ed25519"), + } as unknown as Ed25519PublicKey, + } as unknown as Mocked, + shareRoomKey: jest.fn(), + updateTrackedUsers: jest.fn().mockResolvedValue(undefined), + encryptRoomEvent: jest.fn().mockResolvedValue("{}"), + } as unknown as Mocked; + + mockKeyClaimManager = { + ensureSessionsForUsers: jest.fn(), + } as unknown as Mocked; + + mockOutgoingRequestManager = { + doProcessOutgoingRequests: jest.fn().mockResolvedValue(undefined), + } as unknown as Mocked; + + const mockRoomMember = { + userId: "@alice:example.org", + membership: "join", + } as unknown as Mocked; + + mockRoom = { + roomId: "!foo:example.org", + getJoinedMembers: jest.fn().mockReturnValue([mockRoomMember]), + getEncryptionTargetMembers: jest.fn().mockReturnValue([mockRoomMember]), + shouldEncryptForInvitedMembers: jest.fn().mockReturnValue(true), + getHistoryVisibility: jest.fn().mockReturnValue(HistoryVisibility.Invited), + getBlacklistUnverifiedDevices: jest.fn().mockReturnValue(false), + } as unknown as Mocked; + + roomEncryptor = new RoomEncryptor( + mockOlmMachine, + mockKeyClaimManager, + mockOutgoingRequestManager, + mockRoom, + { algorithm: "m.megolm.v1.aes-sha2" }, + ); + }); + + it("should ensure that there is only one shareRoomKey at a time", async () => { + const deferredShare = defer(); + const insideOlmShareRoom = defer(); + + mockOlmMachine.shareRoomKey.mockImplementationOnce(async () => { + insideOlmShareRoom.resolve(); + await deferredShare.promise; + }); + + roomEncryptor.prepareForEncryption(false); + await insideOlmShareRoom.promise; + + // call several times more + roomEncryptor.prepareForEncryption(false); + roomEncryptor.encryptEvent(createMockEvent("Hello"), false); + roomEncryptor.prepareForEncryption(false); + roomEncryptor.encryptEvent(createMockEvent("World"), false); + + expect(mockOlmMachine.shareRoomKey).toHaveBeenCalledTimes(1); + + deferredShare.resolve(); + await roomEncryptor.prepareForEncryption(false); + + // should have been called again + expect(mockOlmMachine.shareRoomKey).toHaveBeenCalledTimes(6); + }); + }); }); diff --git a/src/rust-crypto/RoomEncryptor.ts b/src/rust-crypto/RoomEncryptor.ts index 9f11f79c13b..5a752c53fdc 100644 --- a/src/rust-crypto/RoomEncryptor.ts +++ b/src/rust-crypto/RoomEncryptor.ts @@ -45,6 +45,9 @@ export class RoomEncryptor { /** whether the room members have been loaded and tracked for the first time */ private lazyLoadedMembersResolved = false; + /** Ensures that there is only one call to shareRoomKeys at a time */ + private currentShareRoomKeyPromise = Promise.resolve(); + /** * @param olmMachine - The rust-sdk's OlmMachine * @param keyClaimManager - Our KeyClaimManager, which manages the queue of one-time-key claim requests @@ -198,11 +201,7 @@ export class RoomEncryptor { rustEncryptionSettings.onlyAllowTrustedDevices = this.room.getBlacklistUnverifiedDevices() ?? globalBlacklistUnverifiedDevices; - const shareMessages: ToDeviceRequest[] = await this.olmMachine.shareRoomKey( - new RoomId(this.room.roomId), - userList, - rustEncryptionSettings, - ); + const shareMessages: ToDeviceRequest[] = await this.shareRoomKey(userList, rustEncryptionSettings); if (shareMessages) { for (const m of shareMessages) { await this.outgoingRequestManager.outgoingRequestProcessor.makeOutgoingRequest(m); @@ -210,6 +209,30 @@ export class RoomEncryptor { } } + /** + * The Rust-SDK requires that we only have one shareRoomKey process in flight at once for a room. + * This method ensures that, by only having one call to shareRoomKey active at once (and making them + * queue up in order). + * + * @param userList - list of userIDs to share with + * @param rustEncryptionSettings - encryption settings to use + * + * @returns a promise which resolves to the list of ToDeviceRequests to send + */ + private async shareRoomKey( + userList: UserId[], + rustEncryptionSettings: EncryptionSettings, + ): Promise { + const prom = this.currentShareRoomKeyPromise + .catch(() => { + // any errors in the previous claim will have been reported already, so there is nothing to do here. + // we just throw away the error and start anew. + }) + .then(() => this.olmMachine.shareRoomKey(new RoomId(this.room.roomId), userList, rustEncryptionSettings)); + this.currentShareRoomKeyPromise = prom; + return prom; + } + /** * Discard any existing group session for this room */