From 66375638886c6e0145fbeda7fe9d8ac08d288aed Mon Sep 17 00:00:00 2001 From: nazeh Date: Mon, 16 Sep 2024 19:13:41 +0300 Subject: [PATCH] test(legacy): add test to overriding the resolved socket's port --- src/client/legacy/connect/http.rs | 36 ++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/src/client/legacy/connect/http.rs b/src/client/legacy/connect/http.rs index 27b5458..bcad03f 100644 --- a/src/client/legacy/connect/http.rs +++ b/src/client/legacy/connect/http.rs @@ -433,11 +433,7 @@ where .map_err(ConnectError::dns)?; let addrs = addrs .map(|mut addr| { - // Respect explicit ports in the URI, - // and non `0` ports resolved from a custom dns resolver. - if dst.port().is_some() || addr.port() == 0 { - addr.set_port(port) - }; + set_port(&mut addr, port, dst.port().is_some()); addr }) @@ -830,9 +826,19 @@ impl ConnectingTcp<'_> { } } +/// Respect explicit ports in the URI, if none, either +/// keep non `0` ports resolved from a custom dns resolver, +/// or use the default port for the scheme. +fn set_port(addr: &mut SocketAddr, host_port: u16, explicit: bool) { + if explicit || addr.port() == 0 { + addr.set_port(host_port) + }; +} + #[cfg(test)] mod tests { use std::io; + use std::net::SocketAddr; use ::http::Uri; @@ -841,6 +847,8 @@ mod tests { use super::super::sealed::{Connect, ConnectSvc}; use super::{Config, ConnectError, HttpConnector}; + use super::set_port; + async fn connect( connector: C, dst: Uri, @@ -1239,4 +1247,22 @@ mod tests { panic!("test failed"); } } + + #[test] + fn test_set_port() { + // Respect explicit ports no matter what the resolved port is. + let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881)); + set_port(&mut addr, 42, true); + assert_eq!(addr.port(), 42); + + // Ignore default host port, and use the socket port instead. + let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881)); + set_port(&mut addr, 443, false); + assert_eq!(addr.port(), 6881); + + // Use the default port if the resolved port is `0`. + let mut addr = SocketAddr::from(([0, 0, 0, 0], 0)); + set_port(&mut addr, 443, false); + assert_eq!(addr.port(), 443); + } }