diff --git a/src/lsps2/payment_queue.rs b/src/lsps2/payment_queue.rs index dfd97ad..3459648 100644 --- a/src/lsps2/payment_queue.rs +++ b/src/lsps2/payment_queue.rs @@ -49,8 +49,8 @@ impl PaymentQueue { position.map(|position| self.payments.remove(position)) } - pub(crate) fn clear(&mut self) -> Vec { - self.payments.drain(..).map(|(_k, v)| v).flatten().collect() + pub(crate) fn clear(&mut self) -> Vec<(PaymentHash, Vec)> { + self.payments.drain(..).collect() } } @@ -109,11 +109,14 @@ mod tests { ); assert_eq!( payment_queue.clear(), - vec![InterceptedHTLC { - intercept_id: InterceptId([1; 32]), - expected_outbound_amount_msat: 300_000_000, - payment_hash: PaymentHash([101; 32]), - }] + vec![( + PaymentHash([101; 32]), + vec![InterceptedHTLC { + intercept_id: InterceptId([1; 32]), + expected_outbound_amount_msat: 300_000_000, + payment_hash: PaymentHash([101; 32]), + }] + )] ); } } diff --git a/src/lsps2/service.rs b/src/lsps2/service.rs index 34025c8..0ce372d 100644 --- a/src/lsps2/service.rs +++ b/src/lsps2/service.rs @@ -362,7 +362,13 @@ impl OutboundJITChannelState { let mut payment_queue_lock = payment_queue.lock().unwrap(); let payment_forwarded = OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id }; - let forward_htlcs = ForwardHTLCsAction(*channel_id, payment_queue_lock.clear()); + let htlcs = payment_queue_lock + .clear() + .into_iter() + .map(|(_, htlcs)| htlcs) + .flatten() + .collect(); + let forward_htlcs = ForwardHTLCsAction(*channel_id, htlcs); Ok((payment_forwarded, Some(forward_htlcs))) }, OutboundJITChannelState::PaymentForwarded { channel_id } => { @@ -898,6 +904,67 @@ where Ok(()) } + /// Used by LSP to fail intercepted htlcs backwards when the channel open fails for any reason. + /// + /// Should be called in response to receiving a [`LSPS2ServiceEvent::OpenChannel`] event. + /// + /// The JIT channel state is reset such that the payer can attempt payment again. + /// [`LSPS2ServiceEvent::OpenChannel`]: crate::lsps2::event::LSPS2ServiceEvent::OpenChannel + pub fn channel_open_failed( + &self, counterparty_node_id: &PublicKey, intercept_scid: u64, + ) -> Result<(), APIError> { + let outer_state_lock = self.per_peer_state.read().unwrap(); + match outer_state_lock.get(counterparty_node_id) { + Some(inner_state_lock) => { + let mut peer_state = inner_state_lock.lock().unwrap(); + + if let Some(jit_channel) = + peer_state.outbound_channels_by_intercept_scid.get_mut(&intercept_scid) + { + let new_state = if let OutboundJITChannelState::PendingChannelOpen { + payment_queue, + .. + } = &jit_channel.state + { + let mut queue = payment_queue.lock().unwrap(); + let payment_hashes = queue + .clear() + .into_iter() + .map(|(payment_hash, _)| payment_hash) + .collect::>(); + for payment_hash in payment_hashes { + self.channel_manager.get_cm().fail_htlc_backwards_with_reason( + &payment_hash, + FailureCode::TemporaryNodeFailure, + ); + } + OutboundJITChannelState::PendingInitialPayment { + payment_queue: payment_queue.clone(), + } + } else { + return Err(APIError::APIMisuseError { + err: format!("Channel is not in the PendingChannelOpen state.",), + }); + }; + jit_channel.state = new_state; + } else { + return Err(APIError::APIMisuseError { + err: format!( + "Could not find a channel with intercept_scid {}", + intercept_scid + ), + }); + } + }, + None => { + return Err(APIError::APIMisuseError { + err: format!("No counterparty state for: {}", counterparty_node_id), + }); + }, + } + Ok(()) + } + /// Forward [`Event::ChannelReady`] event parameters into this function. /// /// Will forward the intercepted HTLC if it matches a channel