Skip to content

Commit

Permalink
Restart broken Olm sessions ([MSC1719](matrix-org/matrix-spec-proposa…
Browse files Browse the repository at this point in the history
  • Loading branch information
bmarty committed Apr 20, 2020
1 parent 3615ca6 commit a6368c4
Show file tree
Hide file tree
Showing 12 changed files with 202 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Improvements 🙌:
- Emoji Verification | It's not the same butterfly! (#1220)
- Cross-Signing | Composer decoration: shields (#1077)
- Cross-Signing | Migrate existing keybackup to cross signing with 4S from mobile (#1197)
- Restart broken Olm sessions ([MSC1719](https://github.com/matrix-org/matrix-doc/pull/1719))

Bugfix 🐛:
- Fix summary notification staying after "mark as read"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ import im.vector.matrix.android.common.CryptoTestHelper
import org.amshove.kluent.shouldBe
import org.amshove.kluent.shouldBeEqualTo
import org.junit.Before
import org.junit.BeforeClass
import org.junit.FixMethodOrder
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.MethodSorters
import timber.log.Timber
import java.util.concurrent.CountDownLatch

/**
Expand Down Expand Up @@ -99,13 +101,13 @@ class UnwedgingTest : InstrumentedTest {

override fun onTimelineUpdated(snapshot: List<TimelineEvent>) {
val decryptedEventReceivedByBob = snapshot.filter { it.root.getClearType() == EventType.MESSAGE }
Timber.d("Bob can now decrypt ${decryptedEventReceivedByBob.size} messages")
if (decryptedEventReceivedByBob.size == 3) {
bobFinalLatch.countDown()
}
}
}
bobTimeline.addListener(bobHasThreeDecryptedEventsListener)


var latch = CountDownLatch(1)
var bobEventsListener = createEventListener(latch, 1)
Expand All @@ -128,7 +130,7 @@ class UnwedgingTest : InstrumentedTest {
sessionIdsForBob!!.size shouldBe 1
val olmSession = aliceCryptoStore.getDeviceSession(sessionIdsForBob.first(), bobSession.cryptoService().getMyDevice().identityKey()!!)!!

// Sam join the room
// Sam join the room, so it will force a new session creation
val samSession = mCryptoTestHelper.createSamAccountAndInviteToTheRoom(roomFromAlicePOV)

latch = CountDownLatch(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ object EventType {
// Relation Events
const val REACTION = "m.reaction"

// Unwedging
internal const val DUMMY = "m.dummy"

private val STATE_EVENTS = listOf(
STATE_ROOM_NAME,
STATE_ROOM_TOPIC,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import im.vector.matrix.android.api.session.room.model.RoomMemberSummary
import im.vector.matrix.android.internal.crypto.actions.MegolmSessionDataImporter
import im.vector.matrix.android.internal.crypto.actions.SetDeviceVerificationAction
import im.vector.matrix.android.internal.crypto.algorithms.IMXEncrypting
import im.vector.matrix.android.internal.crypto.algorithms.megolm.MXMegolmDecryption
import im.vector.matrix.android.internal.crypto.algorithms.megolm.MXMegolmEncryptionFactory
import im.vector.matrix.android.internal.crypto.algorithms.olm.MXOlmEncryptionFactory
import im.vector.matrix.android.internal.crypto.crosssigning.DefaultCrossSigningService
Expand Down Expand Up @@ -179,6 +180,10 @@ internal class DefaultCryptoService @Inject constructor(
private val isStarting = AtomicBoolean(false)
private val isStarted = AtomicBoolean(false)

// The date of the last time we forced establishment
// of a new session for each user:device.
private val lastNewSessionForcedDates = MXUsersDevicesMap<Long>()

fun onStateEvent(roomId: String, event: Event) {
when {
event.getClearType() == EventType.STATE_ROOM_ENCRYPTION -> onRoomEncryptionEvent(roomId, event)
Expand Down Expand Up @@ -675,11 +680,52 @@ internal class DefaultCryptoService @Inject constructor(
Timber.e("## decryptEvent() : $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, reason)
} else {
return alg.decryptEvent(event, timeline)
try {
return alg.decryptEvent(event, timeline)
} catch (mxCryptoError: MXCryptoError) {
if (mxCryptoError is MXCryptoError.Base
&& mxCryptoError.errorType == MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE
&& alg is MXMegolmDecryption) {
// TODO Do it on decryption thread like on iOS?
markOlmSessionForUnwedging(event, alg)
}
throw mxCryptoError
}
}
}
}

private fun markOlmSessionForUnwedging(event: Event, mxMegolmDecryption: MXMegolmDecryption) {
val senderId = event.senderId ?: return
val encryptedMessage = event.content.toModel<EncryptedEventContent>() ?: return
val deviceKey = encryptedMessage.senderKey ?: return
encryptedMessage.algorithm?.takeIf { it == MXCRYPTO_ALGORITHM_MEGOLM } ?: return

if (senderId == userId
&& deviceKey == olmDevice.deviceCurve25519Key) {
Timber.d("[MXCrypto] markOlmSessionForUnwedging: Do not unwedge ourselves")
return
}

val lastForcedDate = lastNewSessionForcedDates.getObject(senderId, deviceKey) ?: 0
val now = System.currentTimeMillis()
if (now - lastForcedDate < CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS) {
Timber.d("[MXCrypto] markOlmSessionForUnwedging: New session already forced with device at $lastForcedDate. Not forcing another")
return
}

// Establish a new olm session with this device since we're failing to decrypt messages
// on a current session.
val deviceInfo = getDeviceInfo(senderId, deviceKey) ?: return Unit.also {
Timber.d("[MXCrypto] markOlmSessionForUnwedging: Couldn't find device for identity key $deviceKey: not re-establishing session")
}

Timber.d("[MXCrypto] markOlmSessionForUnwedging from $senderId:${deviceInfo.deviceId}")
lastNewSessionForcedDates.setObject(senderId, deviceKey, now)

mxMegolmDecryption.markOlmSessionForUnwedging(senderId, deviceInfo)
}

/**
* Reset replay attack data for the given timeline.
*
Expand Down Expand Up @@ -1189,4 +1235,8 @@ internal class DefaultCryptoService @Inject constructor(

@VisibleForTesting
val cryptoStoreForTesting = cryptoStore

companion object {
const val CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS = 3_600_000 // one hour
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ import im.vector.matrix.android.internal.crypto.model.rest.GossipingToDeviceObje
import im.vector.matrix.android.internal.crypto.store.IMXCryptoStore
import im.vector.matrix.android.internal.di.SessionId
import im.vector.matrix.android.internal.session.SessionScope
import im.vector.matrix.android.internal.util.MatrixCoroutineDispatchers
import im.vector.matrix.android.internal.worker.WorkerParamsFactory
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import timber.log.Timber
import javax.inject.Inject

Expand All @@ -43,7 +46,10 @@ internal class IncomingGossipingRequestManager @Inject constructor(
private val cryptoStore: IMXCryptoStore,
private val cryptoConfig: MXCryptoConfig,
private val gossipingWorkManager: GossipingWorkManager,
private val roomDecryptorProvider: RoomDecryptorProvider) {
private val roomEncryptorsStore: RoomEncryptorsStore,
private val roomDecryptorProvider: RoomDecryptorProvider,
private val coroutineDispatchers: MatrixCoroutineDispatchers,
private val cryptoCoroutineScope: CoroutineScope) {

// list of IncomingRoomKeyRequests/IncomingRoomKeyRequestCancellations
// we received in the current sync.
Expand Down Expand Up @@ -178,17 +184,42 @@ internal class IncomingGossipingRequestManager @Inject constructor(
}

private fun processIncomingRoomKeyRequest(request: IncomingRoomKeyRequest) {
val userId = request.userId
val deviceId = request.deviceId
val body = request.requestBody
val roomId = body!!.roomId
val alg = body.algorithm
val userId = request.userId ?: return
val deviceId = request.deviceId ?: return
val body = request.requestBody ?: return
val roomId = body.roomId ?: return
val alg = body.algorithm ?: return

Timber.v("## GOSSIP processIncomingRoomKeyRequest from $userId:$deviceId for $roomId / ${body.sessionId} id ${request.requestId}")
if (userId == null || credentials.userId != userId) {
// TODO: determine if we sent this device the keys already: in
Timber.w("## GOSSIP processReceivedGossipingRequests() : Ignoring room key request from other user for now")
cryptoStore.updateGossipingRequestState(request, GossipingRequestState.REJECTED)
if (credentials.userId != userId) {
Timber.w("## GOSSIP processReceivedGossipingRequests() : room key request from other user")
val senderKey = body.senderKey ?: return Unit
.also { Timber.w("missing senderKey") }
.also { cryptoStore.updateGossipingRequestState(request, GossipingRequestState.REJECTED) }
val sessionId = body.sessionId ?: return Unit
.also { Timber.w("missing sessionId") }
.also { cryptoStore.updateGossipingRequestState(request, GossipingRequestState.REJECTED) }

if (alg != MXCRYPTO_ALGORITHM_MEGOLM) {
return Unit
.also { Timber.w("Only megolm is accepted here") }
.also { cryptoStore.updateGossipingRequestState(request, GossipingRequestState.REJECTED) }
}

val roomEncryptor = roomEncryptorsStore.get(roomId) ?: return Unit
.also { Timber.w("no room Encryptor") }
.also { cryptoStore.updateGossipingRequestState(request, GossipingRequestState.REJECTED) }

cryptoCoroutineScope.launch(coroutineDispatchers.crypto) {
val isSuccess = roomEncryptor.reshareKey(sessionId, userId, deviceId, senderKey)

if (isSuccess) {
cryptoStore.updateGossipingRequestState(request, GossipingRequestState.ACCEPTED)
} else {
cryptoStore.updateGossipingRequestState(request, GossipingRequestState.UNABLE_TO_PROCESS)
}
}
cryptoStore.updateGossipingRequestState(request, GossipingRequestState.RE_REQUESTED)
return
}
// TODO: should we queue up requests we don't yet have keys for, in case they turn up later?
Expand Down Expand Up @@ -219,7 +250,7 @@ internal class IncomingGossipingRequestManager @Inject constructor(
cryptoStore.updateGossipingRequestState(request, GossipingRequestState.REJECTED)
}
// if the device is verified already, share the keys
val device = cryptoStore.getUserDevice(userId, deviceId!!)
val device = cryptoStore.getUserDevice(userId, deviceId)
if (device != null) {
if (device.isVerified) {
Timber.v("## GOSSIP processReceivedGossipingRequests() : device is already verified: sharing keys")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
package im.vector.matrix.android.internal.crypto

import im.vector.matrix.android.internal.crypto.algorithms.IMXEncrypting
import im.vector.matrix.android.internal.session.SessionScope
import javax.inject.Inject

@SessionScope
internal class RoomEncryptorsStore @Inject constructor() {

// MXEncrypting instance for each room.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ import im.vector.matrix.android.internal.crypto.tasks.ClaimOneTimeKeysForUsersDe
import timber.log.Timber
import javax.inject.Inject

internal class EnsureOlmSessionsForDevicesAction @Inject constructor(private val olmDevice: MXOlmDevice,
private val oneTimeKeysForUsersDeviceTask: ClaimOneTimeKeysForUsersDeviceTask) {
internal class EnsureOlmSessionsForDevicesAction @Inject constructor(
private val olmDevice: MXOlmDevice,
private val oneTimeKeysForUsersDeviceTask: ClaimOneTimeKeysForUsersDeviceTask) {

suspend fun handle(devicesByUser: Map<String, List<CryptoDeviceInfo>>): MXUsersDevicesMap<MXOlmSessionResult> {
suspend fun handle(devicesByUser: Map<String, List<CryptoDeviceInfo>>, force: Boolean = false): MXUsersDevicesMap<MXOlmSessionResult> {
val devicesWithoutSession = ArrayList<CryptoDeviceInfo>()

val results = MXUsersDevicesMap<MXOlmSessionResult>()
Expand All @@ -40,7 +41,7 @@ internal class EnsureOlmSessionsForDevicesAction @Inject constructor(private val

val sessionId = olmDevice.getSessionId(key!!)

if (sessionId.isNullOrEmpty()) {
if (sessionId.isNullOrEmpty() || force) {
devicesWithoutSession.add(deviceInfo)
}

Expand Down Expand Up @@ -80,7 +81,7 @@ internal class EnsureOlmSessionsForDevicesAction @Inject constructor(private val
if (null != deviceIds) {
for (deviceId in deviceIds) {
val olmSessionResult = results.getObject(userId, deviceId)
if (olmSessionResult!!.sessionId != null) {
if (olmSessionResult!!.sessionId != null && !force) {
// We already have a result for this device
continue
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,20 @@ internal interface IMXEncrypting {
* @return the encrypted content
*/
suspend fun encryptEventContent(eventContent: Content, eventType: String, userIds: List<String>): Content

/**
* Re-shares a session key with devices if the key has already been
* sent to them.
*
* @param sessionId The id of the outbound session to share.
* @param userId The id of the user who owns the target device.
* @param deviceId The id of the target device.
* @param senderKey The key of the originating device for the session.
*
* @return true in case of success
*/
suspend fun reshareKey(sessionId: String,
userId: String,
deviceId: String,
senderKey: String): Boolean
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import im.vector.matrix.android.internal.crypto.actions.EnsureOlmSessionsForDevi
import im.vector.matrix.android.internal.crypto.actions.MessageEncrypter
import im.vector.matrix.android.internal.crypto.algorithms.IMXDecrypting
import im.vector.matrix.android.internal.crypto.keysbackup.DefaultKeysBackupService
import im.vector.matrix.android.internal.crypto.model.CryptoDeviceInfo
import im.vector.matrix.android.internal.crypto.model.MXUsersDevicesMap
import im.vector.matrix.android.internal.crypto.model.event.EncryptedEventContent
import im.vector.matrix.android.internal.crypto.model.event.RoomKeyContent
Expand Down Expand Up @@ -346,4 +347,25 @@ internal class MXMegolmDecryption(private val userId: String,
}
}
}

fun markOlmSessionForUnwedging(senderId: String, deviceInfo: CryptoDeviceInfo) {
cryptoCoroutineScope.launch(coroutineDispatchers.crypto) {
ensureOlmSessionsForDevicesAction.handle(mapOf(senderId to listOf(deviceInfo)), force = true)

// Now send a blank message on that session so the other side knows about it.
// (The keyshare request is sent in the clear so that won't do)
// We send this first such that, as long as the toDevice messages arrive in the
// same order we sent them, the other end will get this first, set up the new session,
// then get the keyshare request and send the key over this new session (because it
// is the session it has most recently received a message on).
val payloadJson = mapOf<String, Any>("type" to EventType.DUMMY)

val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo))
val sendToDeviceMap = MXUsersDevicesMap<Any>()
sendToDeviceMap.setObject(senderId, deviceInfo.deviceId, encodedPayload)
Timber.v("## markOlmSessionForUnwedging() : sending to $senderId:${deviceInfo.deviceId}")
val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap)
sendToDeviceTask.execute(sendToDeviceParams)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,49 @@ internal class MXMegolmEncryption(
throw MXCryptoError.UnknownDevice(unknownDevices)
}
}

override suspend fun reshareKey(sessionId: String,
userId: String,
deviceId: String,
senderKey: String): Boolean {
Timber.d("[MXMegolmEncryption] reshareKey: $sessionId to $userId:$deviceId")
val deviceInfo = cryptoStore.getUserDevice(userId, deviceId) ?: return false
.also { Timber.w("Device not found") }

// Get the chain index of the key we previously sent this device
val chainIndex = outboundSession?.sharedWithDevices?.getObject(userId, deviceId)?.toLong() ?: return false
.also { Timber.w("[MXMegolmEncryption] reshareKey : ERROR : Never share megolm with this device") }

val devicesByUser = mapOf(userId to listOf(deviceInfo))
val usersDeviceMap = ensureOlmSessionsForDevicesAction.handle(devicesByUser)
val olmSessionResult = usersDeviceMap.getObject(userId, deviceId)
olmSessionResult?.sessionId
?: // no session with this device, probably because there were no one-time keys.
// ensureOlmSessionsForDevicesAction has already done the logging, so just skip it.
return false

Timber.d("[MXMegolmEncryption] reshareKey: sharing keys for session $senderKey|$sessionId:$chainIndex with device $userId:$deviceId")

val payloadJson = mutableMapOf<String, Any>("type" to EventType.FORWARDED_ROOM_KEY)

runCatching { olmDevice.getInboundGroupSession(sessionId, senderKey, roomId) }
.fold(
{
// TODO
payloadJson["content"] = it.exportKeys(chainIndex) ?: ""
},
{
// TODO
}

)

val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo))
val sendToDeviceMap = MXUsersDevicesMap<Any>()
sendToDeviceMap.setObject(userId, deviceId, encodedPayload)
Timber.v("## shareKeysWithDevice() : sending to $userId:$deviceId")
val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap)
sendToDeviceTask.execute(sendToDeviceParams)
return true
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,9 @@ internal class MXOlmEncryption(
deviceListManager.downloadKeys(users, false)
ensureOlmSessionsForUsersAction.handle(users)
}

override suspend fun reshareKey(sessionId: String, userId: String, deviceId: String, senderKey: String): Boolean {
// No need for olm
return false
}
}
Loading

0 comments on commit a6368c4

Please sign in to comment.