Skip to content

Commit

Permalink
refactor(mempool_infra): change remote server to send messages to loc…
Browse files Browse the repository at this point in the history
…al server

commit-id:7e0d8402
  • Loading branch information
Itay-Tsabary-Starkware committed Sep 22, 2024
1 parent 0faafcc commit 98ecb4e
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 109 deletions.
13 changes: 1 addition & 12 deletions crates/batcher/src/communication.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
use std::net::IpAddr;

use async_trait::async_trait;
use starknet_batcher_types::communication::{
BatcherRequest,
BatcherRequestAndResponseSender,
BatcherResponse,
};
use starknet_mempool_infra::component_definitions::ComponentRequestHandler;
use starknet_mempool_infra::component_server::{LocalComponentServer, RemoteComponentServer};
use starknet_mempool_infra::component_server::LocalComponentServer;
use tokio::sync::mpsc::Receiver;

use crate::batcher::Batcher;

pub type LocalBatcherServer = LocalComponentServer<Batcher, BatcherRequest, BatcherResponse>;
pub type RemoteBatcherServer = RemoteComponentServer<Batcher, BatcherRequest, BatcherResponse>;

pub fn create_local_batcher_server(
batcher: Batcher,
Expand All @@ -22,14 +19,6 @@ pub fn create_local_batcher_server(
LocalComponentServer::new(batcher, rx_batcher)
}

pub fn create_remote_batcher_server(
batcher: Batcher,
ip_address: IpAddr,
port: u16,
) -> RemoteBatcherServer {
RemoteComponentServer::new(batcher, ip_address, port)
}

#[async_trait]
impl ComponentRequestHandler<BatcherRequest, BatcherResponse> for Batcher {
async fn handle_request(&mut self, request: BatcherRequest) -> BatcherResponse {
Expand Down
14 changes: 1 addition & 13 deletions crates/consensus_manager/src/communication.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
use std::net::IpAddr;

use async_trait::async_trait;
use starknet_consensus_manager_types::communication::{
ConsensusManagerRequest,
ConsensusManagerRequestAndResponseSender,
ConsensusManagerResponse,
};
use starknet_mempool_infra::component_definitions::ComponentRequestHandler;
use starknet_mempool_infra::component_server::{LocalActiveComponentServer, RemoteComponentServer};
use starknet_mempool_infra::component_server::LocalActiveComponentServer;
use tokio::sync::mpsc::Receiver;

use crate::consensus_manager::ConsensusManager;

pub type LocalConsensusManagerServer =
LocalActiveComponentServer<ConsensusManager, ConsensusManagerRequest, ConsensusManagerResponse>;
pub type RemoteConsensusManagerServer =
RemoteComponentServer<ConsensusManager, ConsensusManagerRequest, ConsensusManagerResponse>;

pub fn create_local_consensus_manager_server(
consensus_manager: ConsensusManager,
Expand All @@ -24,14 +20,6 @@ pub fn create_local_consensus_manager_server(
LocalActiveComponentServer::new(consensus_manager, rx_consensus_manager)
}

pub fn create_remote_consensus_manager_server(
consensus_manager: ConsensusManager,
ip_address: IpAddr,
port: u16,
) -> RemoteConsensusManagerServer {
RemoteComponentServer::new(consensus_manager, ip_address, port)
}

#[async_trait]
impl ComponentRequestHandler<ConsensusManagerRequest, ConsensusManagerResponse>
for ConsensusManager
Expand Down
16 changes: 1 addition & 15 deletions crates/mempool/src/communication.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::net::IpAddr;

use async_trait::async_trait;
use starknet_api::executable_transaction::Transaction;
use starknet_mempool_infra::component_definitions::ComponentRequestHandler;
use starknet_mempool_infra::component_runner::ComponentStarter;
use starknet_mempool_infra::component_server::{LocalComponentServer, RemoteComponentServer};
use starknet_mempool_infra::component_server::LocalComponentServer;
use starknet_mempool_types::communication::{
MempoolRequest,
MempoolRequestAndResponseSender,
Expand All @@ -19,9 +17,6 @@ use crate::mempool::Mempool;
pub type MempoolServer =
LocalComponentServer<MempoolCommunicationWrapper, MempoolRequest, MempoolResponse>;

pub type RemoteMempoolServer =
RemoteComponentServer<MempoolCommunicationWrapper, MempoolRequest, MempoolResponse>;

pub fn create_mempool_server(
mempool: Mempool,
rx_mempool: Receiver<MempoolRequestAndResponseSender>,
Expand All @@ -30,15 +25,6 @@ pub fn create_mempool_server(
LocalComponentServer::new(communication_wrapper, rx_mempool)
}

pub fn create_remote_mempool_server(
mempool: Mempool,
ip_address: IpAddr,
port: u16,
) -> RemoteMempoolServer {
let communication_wrapper = MempoolCommunicationWrapper::new(mempool);
RemoteComponentServer::new(communication_wrapper, ip_address, port)
}

/// Wraps the mempool to enable inbound async communication from other components.
pub struct MempoolCommunicationWrapper {
mempool: Mempool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ where
let (res_tx, mut res_rx) = channel::<Response>(1);
let request_and_res_tx = ComponentRequestAndResponseSender { request, tx: res_tx };
self.tx.send(request_and_res_tx).await.expect("Outbound connection should be open.");

res_rx.recv().await.expect("Inbound connection should be open.")
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;

use async_trait::async_trait;
use bincode::{deserialize, serialize};
Expand All @@ -10,14 +8,10 @@ use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, StatusCode};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::sync::Mutex;

use super::definitions::ComponentServerStarter;
use crate::component_definitions::{
ComponentRequestHandler,
ServerError,
APPLICATION_OCTET_STREAM,
};
use crate::component_client::LocalComponentClient;
use crate::component_definitions::{ServerError, APPLICATION_OCTET_STREAM};

/// The `RemoteComponentServer` struct is a generic server that handles requests and responses for a
/// specified component. It receives requests, processes them using the provided component, and
Expand Down Expand Up @@ -47,6 +41,7 @@ use crate::component_definitions::{
/// use starknet_mempool_infra::component_runner::{ComponentStartError, ComponentStarter};
/// use tokio::task;
///
/// use crate::starknet_mempool_infra::component_client::LocalComponentClient;
/// use crate::starknet_mempool_infra::component_definitions::ComponentRequestHandler;
/// use crate::starknet_mempool_infra::component_server::{
/// ComponentServerStarter,
Expand Down Expand Up @@ -84,67 +79,59 @@ use crate::component_definitions::{
///
/// #[tokio::main]
/// async fn main() {
/// // Instantiate the component.
/// let component = MyComponent {};
/// // Instantiate a local client to communicate with component.
/// let (tx, _rx) = tokio::sync::mpsc::channel(32);
/// let local_client = LocalComponentClient::<MyRequest, MyResponse>::new(tx);
///
/// // Set the ip address and port of the server's socket.
/// let ip_address = std::net::IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
/// let port: u16 = 8080;
///
/// // Instantiate the server.
/// let mut server = RemoteComponentServer::<MyComponent, MyRequest, MyResponse>::new(
/// component, ip_address, port,
/// );
/// let mut server =
/// RemoteComponentServer::<MyRequest, MyResponse>::new(local_client, ip_address, port);
///
/// // Start the server in a new task.
/// task::spawn(async move {
/// server.start().await;
/// });
/// }
/// ```
pub struct RemoteComponentServer<Component, Request, Response>
pub struct RemoteComponentServer<Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: DeserializeOwned + Send + 'static,
Response: Serialize + 'static,
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
{
socket: SocketAddr,
component: Arc<Mutex<Component>>,
_req: PhantomData<Request>,
_res: PhantomData<Response>,
local_client: LocalComponentClient<Request, Response>,
}

impl<Component, Request, Response> RemoteComponentServer<Component, Request, Response>
impl<Request, Response> RemoteComponentServer<Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: DeserializeOwned + Send + 'static,
Response: Serialize + 'static,
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
{
pub fn new(component: Component, ip_address: IpAddr, port: u16) -> Self {
Self {
component: Arc::new(Mutex::new(component)),
socket: SocketAddr::new(ip_address, port),
_req: PhantomData,
_res: PhantomData,
}
pub fn new(
local_client: LocalComponentClient<Request, Response>,
ip_address: IpAddr,
port: u16,
) -> Self {
Self { local_client, socket: SocketAddr::new(ip_address, port) }
}

async fn handler(
http_request: HyperRequest<Body>,
component: Arc<Mutex<Component>>,
local_client: LocalComponentClient<Request, Response>,
) -> Result<HyperResponse<Body>, hyper::Error> {
let body_bytes = to_bytes(http_request.into_body()).await?;
let http_response = match deserialize(&body_bytes) {
Ok(component_request) => {
// Acquire the lock for component computation, release afterwards.
let component_response =
{ component.lock().await.handle_request(component_request).await };
Ok(request) => {
let response = local_client.send(request).await;
HyperResponse::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(&component_response)
.expect("Response serialization should succeed"),
serialize(&response).expect("Response serialization should succeed"),
))
}
Err(error) => {
Expand All @@ -161,19 +148,17 @@ where
}

#[async_trait]
impl<Component, Request, Response> ComponentServerStarter
for RemoteComponentServer<Component, Request, Response>
impl<Request, Response> ComponentServerStarter for RemoteComponentServer<Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
{
async fn start(&mut self) {
let make_svc = make_service_fn(|_conn| {
let component = Arc::clone(&self.component);
let local_client = self.local_client.clone();
async {
Ok::<_, hyper::Error>(service_fn(move |req| {
Self::handler(req, Arc::clone(&component))
Self::handler(req, local_client.clone())
}))
}
});
Expand Down
73 changes: 48 additions & 25 deletions crates/mempool_infra/tests/remote_component_client_server_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,24 @@ use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri};
use rstest::rstest;
use serde::Serialize;
use starknet_mempool_infra::component_client::{ClientError, ClientResult, RemoteComponentClient};
use starknet_mempool_infra::component_client::{
ClientError,
ClientResult,
LocalComponentClient,
RemoteComponentClient,
};
use starknet_mempool_infra::component_definitions::{
ComponentRequestAndResponseSender,
ComponentRequestHandler,
ServerError,
APPLICATION_OCTET_STREAM,
};
use starknet_mempool_infra::component_server::{ComponentServerStarter, RemoteComponentServer};
use starknet_mempool_infra::component_server::{
ComponentServerStarter,
LocalComponentServer,
RemoteComponentServer,
};
use tokio::sync::mpsc::channel;
use tokio::sync::Mutex;
use tokio::task;

Expand Down Expand Up @@ -108,10 +119,10 @@ impl ComponentRequestHandler<ComponentBRequest, ComponentBResponse> for Componen
}

async fn verify_error(
a_client: impl ComponentAClientTrait,
a_remote_client: impl ComponentAClientTrait,
expected_error_contained_keywords: &[&str],
) {
let Err(error) = a_client.a_get_value().await else {
let Err(error) = a_remote_client.a_get_value().await else {
panic!("Expected an error.");
};
assert_error_contains_keywords(error.to_string(), expected_error_contained_keywords)
Expand Down Expand Up @@ -156,29 +167,41 @@ where
}

async fn setup_for_tests(setup_value: ValueB, a_port: u16, b_port: u16) {
let a_client = ComponentAClient::new(LOCAL_IP, a_port, MAX_RETRIES);
let b_client = ComponentBClient::new(LOCAL_IP, b_port, MAX_RETRIES);

let component_a = ComponentA::new(Box::new(b_client));
let component_b = ComponentB::new(setup_value, Box::new(a_client.clone()));

let mut component_a_server = RemoteComponentServer::<
ComponentA,
ComponentARequest,
ComponentAResponse,
>::new(component_a, LOCAL_IP, a_port);
let mut component_b_server = RemoteComponentServer::<
ComponentB,
ComponentBRequest,
ComponentBResponse,
>::new(component_b, LOCAL_IP, b_port);
let a_remote_client = ComponentAClient::new(LOCAL_IP, a_port, MAX_RETRIES);
let b_remote_client = ComponentBClient::new(LOCAL_IP, b_port, MAX_RETRIES);

let component_a = ComponentA::new(Box::new(b_remote_client));
let component_b = ComponentB::new(setup_value, Box::new(a_remote_client.clone()));

let (tx_a, rx_a) =
channel::<ComponentRequestAndResponseSender<ComponentARequest, ComponentAResponse>>(32);
let (tx_b, rx_b) =
channel::<ComponentRequestAndResponseSender<ComponentBRequest, ComponentBResponse>>(32);

let a_local_client = LocalComponentClient::<ComponentARequest, ComponentAResponse>::new(tx_a);
let b_local_client = LocalComponentClient::<ComponentBRequest, ComponentBResponse>::new(tx_b);

let mut component_a_local_server = LocalComponentServer::new(component_a, rx_a);
let mut component_b_local_server = LocalComponentServer::new(component_b, rx_b);

let mut component_a_remote_server =
RemoteComponentServer::new(a_local_client, LOCAL_IP, a_port);
let mut component_b_remote_server =
RemoteComponentServer::new(b_local_client, LOCAL_IP, b_port);

task::spawn(async move {
component_a_local_server.start().await;
});
task::spawn(async move {
component_b_local_server.start().await;
});

task::spawn(async move {
component_a_server.start().await;
component_a_remote_server.start().await;
});

task::spawn(async move {
component_b_server.start().await;
component_b_remote_server.start().await;
});

// Todo(uriel): Get rid of this
Expand All @@ -189,9 +212,9 @@ async fn setup_for_tests(setup_value: ValueB, a_port: u16, b_port: u16) {
async fn test_proper_setup() {
let setup_value: ValueB = 90;
setup_for_tests(setup_value, A_PORT_TEST_SETUP, B_PORT_TEST_SETUP).await;
let a_client = ComponentAClient::new(LOCAL_IP, A_PORT_TEST_SETUP, MAX_RETRIES);
let b_client = ComponentBClient::new(LOCAL_IP, B_PORT_TEST_SETUP, MAX_RETRIES);
test_a_b_functionality(a_client, b_client, setup_value.into()).await;
let a_remote_client = ComponentAClient::new(LOCAL_IP, A_PORT_TEST_SETUP, MAX_RETRIES);
let b_remote_client = ComponentBClient::new(LOCAL_IP, B_PORT_TEST_SETUP, MAX_RETRIES);
test_a_b_functionality(a_remote_client, b_remote_client, setup_value.into()).await;
}

#[tokio::test]
Expand Down

0 comments on commit 98ecb4e

Please sign in to comment.