From 199563955d9305b7fde221f6239144f8ddea9815 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 13 Mar 2024 13:47:00 +0100 Subject: [PATCH] feat(udp): use recvmmsg Read up to `BATCH_SIZE = 32` with single `recvmmsg` syscall. Previously `neqo_bin::udp::Socket::recv` would use `recvmmsg`, but provide a single buffer to write into only, effectively using `recvmsg` instead of `recvmmsg`. With this commit `Socket::recv` provides `BATCH_SIZE` number of buffers on each `recvmmsg` syscall, thus reading more than one datagram at a time if available. --- neqo-bin/src/udp.rs | 95 +++++++++++++++++++++++++++++---------------- 1 file changed, 61 insertions(+), 34 deletions(-) diff --git a/neqo-bin/src/udp.rs b/neqo-bin/src/udp.rs index 632a1293d7..a0348628ca 100644 --- a/neqo-bin/src/udp.rs +++ b/neqo-bin/src/udp.rs @@ -9,6 +9,7 @@ use std::{ io::{self, IoSliceMut}, + mem::MaybeUninit, net::{SocketAddr, ToSocketAddrs}, slice, }; @@ -17,6 +18,13 @@ use neqo_common::{Datagram, IpTos}; use quinn_udp::{EcnCodepoint, RecvMeta, Transmit, UdpSocketState}; use tokio::io::Interest; +#[cfg(not(any(target_os = "macos", target_os = "ios")))] +// Chosen somewhat arbitrarily; might benefit from additional tuning. +pub(crate) const BATCH_SIZE: usize = 32; + +#[cfg(any(target_os = "macos", target_os = "ios"))] +pub(crate) const BATCH_SIZE: usize = 1; + /// Socket receive buffer size. /// /// Allows reading multiple datagrams in a single [`Socket::recv`] call. @@ -25,7 +33,8 @@ const RECV_BUF_SIZE: usize = u16::MAX as usize; pub struct Socket { socket: tokio::net::UdpSocket, state: UdpSocketState, - recv_buf: Vec, + // TODO: Rename + recv_buf: [Vec; BATCH_SIZE], } impl Socket { @@ -36,7 +45,11 @@ impl Socket { Ok(Self { state: quinn_udp::UdpSocketState::new((&socket).into())?, socket: tokio::net::UdpSocket::from_std(socket)?, - recv_buf: vec![0; RECV_BUF_SIZE], + recv_buf: (0..BATCH_SIZE) + .map(|_| vec![0; RECV_BUF_SIZE]) + .collect::>() + .try_into() + .expect("successful array instantiation"), }) } @@ -77,18 +90,25 @@ impl Socket { /// Receive a UDP datagram on the specified socket. pub fn recv(&mut self, local_address: &SocketAddr) -> Result, io::Error> { - let mut meta = RecvMeta::default(); - - match self.socket.try_io(Interest::READABLE, || { - self.state.recv( - (&self.socket).into(), - &mut [IoSliceMut::new(&mut self.recv_buf)], - slice::from_mut(&mut meta), - ) + let mut metas = [RecvMeta::default(); BATCH_SIZE]; + + // TODO: Safe? + let mut iovs = MaybeUninit::<[IoSliceMut<'_>; BATCH_SIZE]>::uninit(); + for (i, buf) in self.recv_buf.iter_mut().enumerate() { + unsafe { + iovs.as_mut_ptr() + .cast::() + .add(i) + .write(IoSliceMut::new(buf)); + }; + } + let mut iovs = unsafe { iovs.assume_init() }; + + let msgs = match self.socket.try_io(Interest::READABLE, || { + self.state + .recv((&self.socket).into(), &mut iovs, &mut metas) }) { - Ok(n) => { - assert_eq!(n, 1, "only passed one slice"); - } + Ok(n) => n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock || err.kind() == io::ErrorKind::Interrupted => @@ -100,27 +120,34 @@ impl Socket { } }; - if meta.len == 0 { - eprintln!("zero length datagram received?"); - return Ok(vec![]); - } - if meta.len == self.recv_buf.len() { - eprintln!( - "Might have received more than {} bytes", - self.recv_buf.len() - ); - } - - Ok(self.recv_buf[0..meta.len] - .chunks(meta.stride.min(self.recv_buf.len())) - .map(|d| { - Datagram::new( - meta.addr, - *local_address, - meta.ecn.map(|n| IpTos::from(n as u8)).unwrap_or_default(), - None, // TODO: get the real TTL https://github.com/quinn-rs/quinn/issues/1749 - d, - ) + // TODO + // if meta.len == 0 { + // eprintln!("zero length datagram received?"); + // return Ok(vec![]); + // } + // if meta.len == self.recv_buf.len() { + // eprintln!( + // "Might have received more than {} bytes", + // self.recv_buf.len() + // ); + // } + + Ok(metas + .iter() + .zip(iovs.iter()) + .take(msgs) + .flat_map(|(meta, buf)| { + buf[0..meta.len] + .chunks(meta.stride.min(buf.len())) + .map(|d| { + Datagram::new( + meta.addr, + *local_address, + meta.ecn.map(|n| IpTos::from(n as u8)).unwrap_or_default(), + None, // TODO: get the real TTL https://github.com/quinn-rs/quinn/issues/1749 + d, + ) + }) }) .collect()) }