diff --git a/src/lsps2/service.rs b/src/lsps2/service.rs index ca0911e..e27c840 100644 --- a/src/lsps2/service.rs +++ b/src/lsps2/service.rs @@ -81,6 +81,13 @@ enum HTLCInterceptedAction { ForwardPayment(ForwardPaymentAction), } +/// Possible actions that need to be taken when a payment is forwarded. +#[derive(Debug, PartialEq)] +enum PaymentForwardedAction { + ForwardPayment(ForwardPaymentAction), + ForwardHTLCs(ForwardHTLCsAction), +} + /// The forwarding of a payment while skimming the JIT channel opening fee. #[derive(Debug, PartialEq)] struct ForwardPaymentAction(ChannelId, FeePayment); @@ -318,23 +325,42 @@ impl OutboundJITChannelState { } fn payment_forwarded( - &mut self, - ) -> Result<(Self, Option), ChannelStateError> { + &mut self, skimmed_fee_msat: Option, + ) -> Result<(Self, Option), ChannelStateError> { match self { OutboundJITChannelState::PendingPaymentForward { - payment_queue, channel_id, .. + payment_queue, + channel_id, + opening_fee_msat, } => { let mut payment_queue_lock = payment_queue.lock().unwrap(); - let payment_forwarded = - OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id }; - let htlcs = payment_queue_lock - .clear() - .into_iter() - .map(|(_, htlcs)| htlcs) - .flatten() - .collect(); - let forward_htlcs = ForwardHTLCsAction(*channel_id, htlcs); - Ok((payment_forwarded, Some(forward_htlcs))) + + let skimmed_fee_msat = skimmed_fee_msat.unwrap_or(0); + let remaining_fee = opening_fee_msat.saturating_sub(skimmed_fee_msat); + + if remaining_fee > 0 { + let (state, payment_action) = try_get_payment( + Arc::clone(payment_queue), + payment_queue_lock, + *channel_id, + remaining_fee, + ); + Ok((state, payment_action.map(|pa| PaymentForwardedAction::ForwardPayment(pa)))) + } else { + let payment_forwarded = + OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id }; + let htlcs = payment_queue_lock + .clear() + .into_iter() + .map(|(_, htlcs)| htlcs) + .flatten() + .collect(); + let forward_htlcs = ForwardHTLCsAction(*channel_id, htlcs); + Ok(( + payment_forwarded, + Some(PaymentForwardedAction::ForwardHTLCs(forward_htlcs)), + )) + } }, OutboundJITChannelState::PaymentForwarded { channel_id } => { let payment_forwarded = @@ -368,6 +394,10 @@ impl OutboundJITChannel { } } + pub fn has_paid_fee(&self) -> bool { + matches!(self.state, OutboundJITChannelState::PaymentForwarded { .. }) + } + fn htlc_intercepted( &mut self, htlc: InterceptedHTLC, ) -> Result, LightningError> { @@ -391,8 +421,10 @@ impl OutboundJITChannel { Ok(action) } - fn payment_forwarded(&mut self) -> Result, LightningError> { - let (new_state, action) = self.state.payment_forwarded()?; + fn payment_forwarded( + &mut self, skimmed_fee_msat: Option, + ) -> Result, LightningError> { + let (new_state, action) = self.state.payment_forwarded(skimmed_fee_msat)?; self.state = new_state; Ok(action) } @@ -818,7 +850,9 @@ where /// greater or equal to 0.0.107. /// /// [`Event::PaymentForwarded`]: lightning::events::Event::PaymentForwarded - pub fn payment_forwarded(&self, next_channel_id: ChannelId) -> Result<(), APIError> { + pub fn payment_forwarded( + &self, next_channel_id: ChannelId, skimmed_fee_msat: Option, + ) -> Result { if let Some(counterparty_node_id) = self.peer_by_channel_id.read().unwrap().get(&next_channel_id) { @@ -832,8 +866,10 @@ where if let Some(jit_channel) = peer_state.outbound_channels_by_intercept_scid.get_mut(&intercept_scid) { - match jit_channel.payment_forwarded() { - Ok(Some(ForwardHTLCsAction(channel_id, htlcs))) => { + match jit_channel.payment_forwarded(skimmed_fee_msat) { + Ok(Some(PaymentForwardedAction::ForwardHTLCs( + ForwardHTLCsAction(channel_id, htlcs), + ))) => { for htlc in htlcs { self.channel_manager.get_cm().forward_intercepted_htlc( htlc.intercept_id, @@ -843,6 +879,29 @@ where )?; } }, + Ok(Some(PaymentForwardedAction::ForwardPayment( + ForwardPaymentAction( + channel_id, + FeePayment { htlcs, opening_fee_msat }, + ), + ))) => { + let amounts_to_forward_msat = + calculate_amount_to_forward_per_htlc( + &htlcs, + opening_fee_msat, + ); + + for (intercept_id, amount_to_forward_msat) in + amounts_to_forward_msat + { + self.channel_manager.get_cm().forward_intercepted_htlc( + intercept_id, + &channel_id, + *counterparty_node_id, + amount_to_forward_msat, + )?; + } + }, Ok(None) => {}, Err(e) => { return Err(APIError::APIMisuseError { @@ -853,6 +912,7 @@ where }) }, } + return Ok(jit_channel.has_paid_fee()); } } else { return Err(APIError::APIMisuseError { @@ -868,7 +928,7 @@ where } } - Ok(()) + Ok(false) } /// Used by LSP to fail intercepted htlcs backwards when the channel open fails for any reason. @@ -1476,12 +1536,18 @@ mod tests { } state = new_state; } + + // TODO: how do I get the expected skimmed amount here + // Payment completes, queued payments get forwarded. { - let (new_state, action) = state.payment_forwarded().unwrap(); + let (new_state, action) = state.payment_forwarded(Some(100_000_000)).unwrap(); assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. })); match action { - Some(ForwardHTLCsAction(channel_id, htlcs)) => { + Some(PaymentForwardedAction::ForwardHTLCs(ForwardHTLCsAction( + channel_id, + htlcs, + ))) => { assert_eq!(channel_id, ChannelId([200; 32])); assert_eq!( htlcs, @@ -1617,12 +1683,18 @@ mod tests { } state = new_state; } + + // TODO: how do I grab the expected skimmed fee amount here. + // Payment completes, queued payments get forwarded. { - let (new_state, action) = state.payment_forwarded().unwrap(); + let (new_state, action) = state.payment_forwarded(Some(100_000_000)).unwrap(); assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. })); match action { - Some(ForwardHTLCsAction(channel_id, htlcs)) => { + Some(PaymentForwardedAction::ForwardHTLCs(ForwardHTLCsAction( + channel_id, + htlcs, + ))) => { assert_eq!(channel_id, ChannelId([200; 32])); assert_eq!( htlcs,