From 11ee29416031164ae8f5aec2c437cacd89321229 Mon Sep 17 00:00:00 2001 From: Yair Bakalchuk Date: Sun, 15 Sep 2024 17:02:28 +0300 Subject: [PATCH] feat(batcher): mock BlockBuilder for tests --- Cargo.lock | 3 + Cargo.toml | 1 + crates/batcher/Cargo.toml | 4 + crates/batcher/src/proposal_manager.rs | 116 +++++---- crates/batcher/src/proposal_manager_test.rs | 245 ++++++++++++++++++-- 5 files changed, 308 insertions(+), 61 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a7c14642cd..3a1c53e85b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9691,8 +9691,11 @@ version = "0.0.0" dependencies = [ "assert_matches", "async-trait", + "futures", + "mempool_test_utils", "mockall", "papyrus_config", + "rstest", "serde", "starknet_api", "starknet_batcher_types", diff --git a/Cargo.toml b/Cargo.toml index 48008f0fb5..6d6466a590 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,7 @@ assert_matches = "1.5.0" async-recursion = "1.1.0" async-stream = "0.3.3" async-trait = "0.1.79" +atomic_refcell = "0.1.13" axum = "0.6.12" base64 = "0.13.0" bincode = "1.3.3" diff --git a/crates/batcher/Cargo.toml b/crates/batcher/Cargo.toml index 933ebc7eb4..ec8f6a3861 100644 --- a/crates/batcher/Cargo.toml +++ b/crates/batcher/Cargo.toml @@ -24,4 +24,8 @@ validator.workspace = true [dev-dependencies] assert_matches.workspace = true +futures.workspace = true +mempool_test_utils.workspace = true mockall.workspace = true +rstest.workspace = true +starknet_api = { workspace = true, features = ["testing"] } diff --git a/crates/batcher/src/proposal_manager.rs b/crates/batcher/src/proposal_manager.rs index a9e8a6a99e..b322d530a4 100644 --- a/crates/batcher/src/proposal_manager.rs +++ b/crates/batcher/src/proposal_manager.rs @@ -1,6 +1,9 @@ use std::collections::BTreeMap; use std::sync::Arc; +use async_trait::async_trait; +#[cfg(test)] +use mockall::automock; use papyrus_config::dumping::{ser_param, SerializeConfig}; use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam}; use serde::{Deserialize, Serialize}; @@ -17,20 +20,32 @@ pub type ProposalId = u64; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ProposalManagerConfig { + pub block_builder_next_txs_buffer_size: usize, pub max_txs_per_mempool_request: usize, pub outstream_content_buffer_size: usize, } impl Default for ProposalManagerConfig { fn default() -> Self { - // TODO: Get correct value for default max_txs_per_mempool_request. - Self { max_txs_per_mempool_request: 10, outstream_content_buffer_size: 100 } + // TODO: Get correct default values. + Self { + block_builder_next_txs_buffer_size: 100, + max_txs_per_mempool_request: 10, + outstream_content_buffer_size: 100, + } } } impl SerializeConfig for ProposalManagerConfig { fn dump(&self) -> BTreeMap { BTreeMap::from_iter([ + ser_param( + "block_builder_next_txs_buffer_size", + &self.block_builder_next_txs_buffer_size, + "Maximum transactions to fill in the stream buffer for the block builder before \ + blocking", + ParamPrivacyInput::Public, + ), ser_param( "max_txs_per_mempool_request", &self.max_txs_per_mempool_request, @@ -81,13 +96,28 @@ pub(crate) struct ProposalManager { /// At any given time, there can be only one proposal being actively executed (either proposed /// or validated). active_proposal: Arc>>, + active_proposal_handle: Option, + // Use a factory object, to be able to mock BlockBuilder in tests. + block_builder_factory: Arc, } +type ActiveTaskHandle = tokio::task::JoinHandle>; + impl ProposalManager { // TODO: Remove dead_code attribute. #[allow(dead_code)] - pub fn new(config: ProposalManagerConfig, mempool_client: SharedMempoolClient) -> Self { - Self { config, mempool_client, active_proposal: Arc::new(Mutex::new(None)) } + pub fn new( + config: ProposalManagerConfig, + mempool_client: SharedMempoolClient, + block_builder_factory: Arc, + ) -> Self { + Self { + config, + mempool_client, + active_proposal: Arc::new(Mutex::new(None)), + block_builder_factory, + active_proposal_handle: None, + } } /// Starts a new block proposal generation task for the given proposal_id and height with @@ -104,11 +134,13 @@ impl ProposalManager { info!("Starting generation of a new proposal with id {}.", proposal_id); self.set_active_proposal(proposal_id).await?; - let block_builder = block_builder::BlockBuilder {}; - let _handle = tokio::spawn( + let block_builder = self.block_builder_factory.create_block_builder(); + + self.active_proposal_handle = Some(tokio::spawn( BuildProposalTask { mempool_client: self.mempool_client.clone(), output_content_sender, + block_builder_next_txs_buffer_size: self.config.block_builder_next_txs_buffer_size, max_txs_per_mempool_request: self.config.max_txs_per_mempool_request, block_builder, active_proposal: self.active_proposal.clone(), @@ -116,7 +148,7 @@ impl ProposalManager { } .run() .in_current_span(), - ); + )); Ok(()) } @@ -137,30 +169,14 @@ impl ProposalManager { debug!("Set proposal {} as the one being generated.", proposal_id); Ok(()) } -} - -// TODO: Should be defined elsewhere. -#[allow(dead_code)] -mod block_builder { - use starknet_api::executable_transaction::Transaction; - use tokio_stream::Stream; - - #[derive(Debug, PartialEq)] - pub enum Status { - Building, - Ready, - Timeout, - } - - pub struct BlockBuilder {} - impl BlockBuilder { - pub async fn build_block( - &self, - _deadline: tokio::time::Instant, - _mempool_tx_stream: impl Stream, - _output_content_sender: tokio::sync::mpsc::Sender, - ) { + // A helper function for testing purposes (to be able to await the active proposal). + // TODO: Consider making the tests a nested module to allow them to access private members. + #[cfg(test)] + pub async fn await_active_proposal(&mut self) -> Option> { + match self.active_proposal_handle.take() { + Some(handle) => Some(handle.await.unwrap()), + None => None, } } } @@ -170,7 +186,8 @@ struct BuildProposalTask { mempool_client: SharedMempoolClient, output_content_sender: tokio::sync::mpsc::Sender, max_txs_per_mempool_request: usize, - block_builder: block_builder::BlockBuilder, + block_builder_next_txs_buffer_size: usize, + block_builder: Arc, active_proposal: Arc>>, deadline: tokio::time::Instant, } @@ -178,11 +195,10 @@ struct BuildProposalTask { #[allow(dead_code)] impl BuildProposalTask { async fn run(mut self) -> ProposalsManagerResult<()> { - // TODO: Should we use a different config for the stream buffer size? // We convert the receiver to a stream and pass it to the block builder while using the // sender to feed the stream. let (mempool_tx_sender, mempool_tx_receiver) = - tokio::sync::mpsc::channel::(self.max_txs_per_mempool_request); + tokio::sync::mpsc::channel::(self.block_builder_next_txs_buffer_size); let mempool_tx_stream = ReceiverStream::new(mempool_tx_receiver); let building_future = self.block_builder.build_block( self.deadline, @@ -225,6 +241,11 @@ impl BuildProposalTask { loop { // TODO: Get L1 transactions. let mempool_txs = match mempool_client.get_txs(max_txs_per_mempool_request).await { + Ok(txs) if txs.is_empty() => { + // TODO: Consider sleeping for a while. + tokio::task::yield_now().await; + continue; + } Ok(txs) => txs, Err(e) => return e.into(), }; @@ -233,13 +254,10 @@ impl BuildProposalTask { mempool_txs.len() ); for tx in mempool_txs { - if let Err(e) = mempool_tx_sender.send(tx).await.map_err(|err| { - // TODO: should we return the rest of the txs to the mempool? - error!("Failed to send transaction to the block builder: {}.", err); - ProposalManagerError::InternalError - }) { - return e; - } + mempool_tx_sender + .send(tx) + .await + .expect("Channel should remain open during feeding mempool transactions."); } } } @@ -249,3 +267,21 @@ impl BuildProposalTask { *proposal_id = None; } } + +pub type InputTxStream = ReceiverStream; +pub type OutputTxStream = ReceiverStream; + +#[async_trait] +pub trait BlockBuilderTrait: Send + Sync { + async fn build_block( + &self, + deadline: tokio::time::Instant, + tx_stream: InputTxStream, + output_content_sender: tokio::sync::mpsc::Sender, + ); +} + +#[cfg_attr(test, automock)] +pub trait BlockBuilderFactoryTrait: Send + Sync { + fn create_block_builder(&self) -> Arc; +} diff --git a/crates/batcher/src/proposal_manager_test.rs b/crates/batcher/src/proposal_manager_test.rs index 4bd5416955..36d1845fac 100644 --- a/crates/batcher/src/proposal_manager_test.rs +++ b/crates/batcher/src/proposal_manager_test.rs @@ -1,37 +1,177 @@ +use std::ops::Range; use std::sync::Arc; use assert_matches::assert_matches; +use async_trait::async_trait; +use futures::future::BoxFuture; +use futures::FutureExt; +#[cfg(test)] +use mockall::automock; +use rstest::{fixture, rstest}; +use starknet_api::executable_transaction::Transaction; +use starknet_api::felt; +use starknet_api::test_utils::invoke::{executable_invoke_tx, InvokeTxArgs}; +use starknet_api::transaction::TransactionHash; use starknet_mempool_types::communication::MockMempoolClient; +use tokio_stream::StreamExt; -use crate::proposal_manager::{ProposalManager, ProposalManagerConfig, ProposalManagerError}; +use crate::proposal_manager::{ + BlockBuilderTrait, + InputTxStream, + MockBlockBuilderFactoryTrait, + OutputTxStream, + ProposalManager, + ProposalManagerConfig, + ProposalManagerError, +}; -const GENERATION_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(1); +#[fixture] +fn proposal_manager_config() -> ProposalManagerConfig { + ProposalManagerConfig::default() +} + +// TODO: Figure out how to pass expectations to the mock. +#[fixture] +fn block_builder_factory() -> MockBlockBuilderFactoryTrait { + MockBlockBuilderFactoryTrait::new() +} + +// TODO: Figure out how to pass expectations to the mock. +#[fixture] +fn mempool_client() -> MockMempoolClient { + MockMempoolClient::new() +} + +#[fixture] +fn output_streaming( + proposal_manager_config: ProposalManagerConfig, +) -> (tokio::sync::mpsc::Sender, OutputTxStream) { + let (output_content_sender, output_content_receiver) = + tokio::sync::mpsc::channel(proposal_manager_config.outstream_content_buffer_size); + let stream = tokio_stream::wrappers::ReceiverStream::new(output_content_receiver); + (output_content_sender, stream) +} +#[rstest] #[tokio::test] -async fn multiple_proposals_generation_fails() { - let mut mempool_client = MockMempoolClient::new(); +async fn proposal_generation_success( + proposal_manager_config: ProposalManagerConfig, + mut block_builder_factory: MockBlockBuilderFactoryTrait, + mut mempool_client: MockMempoolClient, + output_streaming: (tokio::sync::mpsc::Sender, OutputTxStream), +) { + let n_txs = 2 * proposal_manager_config.max_txs_per_mempool_request; + block_builder_factory + .expect_create_block_builder() + .once() + .returning(move || simulate_build_block(Some(n_txs))); + + mempool_client.expect_get_txs().once().returning(|max_n_txs| Ok(test_txs(0..max_n_txs))); + + mempool_client + .expect_get_txs() + .once() + .returning(|max_n_txs| Ok(test_txs(max_n_txs..2 * max_n_txs))); + mempool_client.expect_get_txs().returning(|_| Ok(vec![])); - let mut proposals_manager = - ProposalManager::new(ProposalManagerConfig::default(), Arc::new(mempool_client)); - let (output_content_sender, _rx) = tokio::sync::mpsc::channel(1); - proposals_manager - .build_block_proposal( - 0, - tokio::time::Instant::now() + GENERATION_TIMEOUT, - output_content_sender, - ) + + let mut proposal_manager = ProposalManager::new( + proposal_manager_config.clone(), + Arc::new(mempool_client), + Arc::new(block_builder_factory), + ); + + let (output_content_sender, stream) = output_streaming; + proposal_manager + .build_block_proposal(0, arbitrary_deadline(), output_content_sender) .await .unwrap(); - let (another_output_content_sender, _another_rx) = tokio::sync::mpsc::channel(1); - let another_generate_request = proposals_manager - .build_block_proposal( - 1, - tokio::time::Instant::now() + GENERATION_TIMEOUT, - another_output_content_sender, - ) - .await; + assert_matches!(proposal_manager.await_active_proposal().await, Some(Ok(()))); + let proposal_content: Vec<_> = stream.collect().await; + assert_eq!(proposal_content, test_txs(0..n_txs)); +} + +#[rstest] +#[tokio::test] +async fn consecutive_proposal_generations_success( + proposal_manager_config: ProposalManagerConfig, + mut block_builder_factory: MockBlockBuilderFactoryTrait, + mut mempool_client: MockMempoolClient, +) { + let n_txs = proposal_manager_config.max_txs_per_mempool_request; + block_builder_factory + .expect_create_block_builder() + .times(2) + .returning(move || simulate_build_block(Some(n_txs))); + + let expected_txs = test_txs(0..proposal_manager_config.max_txs_per_mempool_request); + let mempool_txs = expected_txs.clone(); + mempool_client.expect_get_txs().returning(move |_max_n_txs| Ok(mempool_txs.clone())); + + let mut proposal_manager = ProposalManager::new( + proposal_manager_config.clone(), + Arc::new(mempool_client), + Arc::new(block_builder_factory), + ); + + let (output_content_sender, stream) = output_streaming(proposal_manager_config.clone()); + proposal_manager + .build_block_proposal(0, arbitrary_deadline(), output_content_sender) + .await + .unwrap(); + + // Make sure the first proposal generated successfully. + assert_matches!(proposal_manager.await_active_proposal().await, Some(Ok(()))); + let v: Vec<_> = stream.collect().await; + assert_eq!(v, expected_txs); + + let (output_content_sender, stream) = output_streaming(proposal_manager_config); + proposal_manager + .build_block_proposal(1, arbitrary_deadline(), output_content_sender) + .await + .unwrap(); + + // Make sure the proposal generated successfully. + assert_matches!(proposal_manager.await_active_proposal().await, Some(Ok(()))); + let proposal_content: Vec<_> = stream.collect().await; + assert_eq!(proposal_content, expected_txs); +} + +#[rstest] +#[tokio::test] +async fn multiple_proposals_generation_fail( + proposal_manager_config: ProposalManagerConfig, + mut block_builder_factory: MockBlockBuilderFactoryTrait, + mut mempool_client: MockMempoolClient, +) { + // The block builder will never stop. + block_builder_factory + .expect_create_block_builder() + .once() + .returning(|| simulate_build_block(None)); + + mempool_client.expect_get_txs().returning(|_| Ok(vec![])); + let mut proposal_manager = ProposalManager::new( + proposal_manager_config.clone(), + Arc::new(mempool_client), + Arc::new(block_builder_factory), + ); + + // A proposal that will never finish. + let (output_content_sender, _stream) = output_streaming(proposal_manager_config.clone()); + proposal_manager + .build_block_proposal(0, arbitrary_deadline(), output_content_sender) + .await + .unwrap(); + + // Try to generate another proposal while the first one is still being generated. + let (another_output_content_sender, _another_stream) = + output_streaming(proposal_manager_config); + let another_generate_request = proposal_manager + .build_block_proposal(1, arbitrary_deadline(), another_output_content_sender) + .await; assert_matches!( another_generate_request, Err(ProposalManagerError::AlreadyGeneratingProposal { @@ -40,3 +180,66 @@ async fn multiple_proposals_generation_fails() { }) if current_generating_proposal_id == 0 && new_proposal_id == 1 ); } + +fn arbitrary_deadline() -> tokio::time::Instant { + const GENERATION_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(1); + tokio::time::Instant::now() + GENERATION_TIMEOUT +} + +fn test_txs(tx_hash_range: Range) -> Vec { + tx_hash_range + .map(|i| { + Transaction::Invoke(executable_invoke_tx(InvokeTxArgs { + tx_hash: TransactionHash(felt!(u128::try_from(i).unwrap())), + ..Default::default() + })) + }) + .collect() +} + +fn simulate_build_block(n_txs: Option) -> Arc { + let mut mock_block_builder = MockBlockBuilderTraitWrapper::new(); + mock_block_builder.expect_build_block().return_once( + move |deadline, mempool_tx_stream, output_content_sender| { + simulate_block_builder(deadline, mempool_tx_stream, output_content_sender, n_txs) + .boxed() + }, + ); + Arc::new(mock_block_builder) +} + +async fn simulate_block_builder( + _deadline: tokio::time::Instant, + mempool_tx_stream: InputTxStream, + output_sender: tokio::sync::mpsc::Sender, + n_txs_to_take: Option, +) { + let mut mempool_tx_stream = mempool_tx_stream.take(n_txs_to_take.unwrap_or(usize::MAX)); + while let Some(tx) = mempool_tx_stream.next().await { + output_sender.send(tx).await.unwrap(); + } +} + +// A wrapper trait to allow mocking the BlockBuilderTrait in tests. +#[cfg_attr(test, automock)] +trait BlockBuilderTraitWrapper: Send + Sync { + // Equivalent to: async fn build_block(&self, deadline: tokio::time::Instant); + fn build_block( + &self, + deadline: tokio::time::Instant, + tx_stream: InputTxStream, + output_content_sender: tokio::sync::mpsc::Sender, + ) -> BoxFuture<'_, ()>; +} + +#[async_trait] +impl BlockBuilderTrait for T { + async fn build_block( + &self, + deadline: tokio::time::Instant, + tx_stream: InputTxStream, + output_content_sender: tokio::sync::mpsc::Sender, + ) { + self.build_block(deadline, tx_stream, output_content_sender).await + } +}