From 25f71ca51867778d13a703daecf9ba3128d3cde3 Mon Sep 17 00:00:00 2001 From: Connor Slade Date: Tue, 31 Oct 2023 08:14:51 -0400 Subject: [PATCH] Dynamic dispatch on socket type --- lib/extensions/range.rs | 2 +- lib/proto/websocket/frame.rs | 4 ++-- lib/response.rs | 4 ++-- lib/socket.rs | 33 ++++++++++++++++++++++++++++----- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/lib/extensions/range.rs b/lib/extensions/range.rs index 269becd..6337289 100644 --- a/lib/extensions/range.rs +++ b/lib/extensions/range.rs @@ -43,7 +43,7 @@ struct RangeResponse { impl Middleware for Range { // Inject the Accept-Ranges header into the response. fn post(&self, req: &Request, res: &mut Response) -> MiddleResult { - if req.method != Method::GET || req.method != Method::GET { + if req.method != Method::GET { return MiddleResult::Continue; } diff --git a/lib/proto/websocket/frame.rs b/lib/proto/websocket/frame.rs index 8938795..d695454 100644 --- a/lib/proto/websocket/frame.rs +++ b/lib/proto/websocket/frame.rs @@ -5,7 +5,7 @@ use std::{ }; use super::xor_mask; -use crate::trace::LazyFmt; +use crate::{trace::LazyFmt, socket::SocketStream}; /// ## Frame Layout /// ```plain @@ -126,7 +126,7 @@ impl Frame { buf } - pub fn write(&self, socket: &mut TcpStream) -> io::Result<()> { + pub fn write(&self, socket: &mut SocketStream) -> io::Result<()> { let buf = self.to_bytes(); trace!(Level::Debug, "[WS] Writing: {:?}", buf); diff --git a/lib/response.rs b/lib/response.rs index fe862ae..399052b 100644 --- a/lib/response.rs +++ b/lib/response.rs @@ -9,7 +9,7 @@ use crate::consts; use crate::header::{HeaderName, Headers}; use crate::internal::sync::ForceLockMutex; use crate::proto::http::status::Status; -use crate::socket::Socket; +use crate::socket::{Socket, Stream, SocketStream}; use crate::{ error::Result, header::headers_to_string, internal::handle::Writeable, Content, Header, SetCookie, @@ -366,7 +366,7 @@ impl ResponseBody { /// Writes a ResponseBody to a TcpStream. /// Either in one go if it is static or in chunks if it is a stream. - fn write(&mut self, stream: &mut TcpStream) -> Result<()> { + fn write(&mut self, stream: &mut SocketStream) -> Result<()> { match self { ResponseBody::Empty => {} ResponseBody::Static(data) => stream.write_all(data)?, diff --git a/lib/socket.rs b/lib/socket.rs index 9e85542..e65940f 100644 --- a/lib/socket.rs +++ b/lib/socket.rs @@ -1,5 +1,6 @@ use std::{ - net::TcpStream, + io::{self, Read, Write}, + net::{IpAddr, Shutdown, SocketAddr, TcpStream}, ops::Deref, sync::{ atomic::{AtomicBool, AtomicU64, Ordering}, @@ -12,10 +13,32 @@ use crate::{ response::ResponseFlag, }; +pub type SocketStream = Box; + +pub trait Stream: Read + Write { + fn peer_addr(&self) -> io::Result; + fn try_clone(&self) -> io::Result; + fn shutdown(&self, shutdown: Shutdown) -> io::Result<()>; +} + +impl Stream for TcpStream { + fn peer_addr(&self) -> io::Result { + self.peer_addr() + } + + fn try_clone(&self) -> io::Result { + Ok(self.try_clone().map(Box::new)?) + } + + fn shutdown(&self, shutdown: Shutdown) -> io::Result<()> { + self.shutdown(shutdown) + } +} + /// Socket is a wrapper around TcpStream that allows for sending a response from other threads. pub struct Socket { /// The internal TcpStream. - pub socket: Mutex, + pub socket: Mutex>, /// A unique identifier that uniquely identifies this socket. pub id: u64, /// A barrier that is used to wait for the response to be sent in the case of a guaranteed send. @@ -32,10 +55,10 @@ pub struct Socket { impl Socket { /// Create a new `Socket` from a `TcpStream`. /// Will also create a new unique identifier for the socket. - pub(crate) fn new(socket: TcpStream) -> Self { + pub(crate) fn new(socket: impl Stream + Send + Sync + 'static) -> Self { static ID: AtomicU64 = AtomicU64::new(0); Self { - socket: Mutex::new(socket), + socket: Mutex::new(Box::new(socket)), id: ID.fetch_add(1, Ordering::Relaxed), barrier: Arc::new(SingleBarrier::new()), raw: AtomicBool::new(false), @@ -76,7 +99,7 @@ impl Socket { } impl Deref for Socket { - type Target = Mutex; + type Target = Mutex>; fn deref(&self) -> &Self::Target { &self.socket