Skip to content

Commit

Permalink
Port from reqwest to ureq (georust#205)
Browse files Browse the repository at this point in the history
Addresses georust#189, reducing `cargo tree -F network | wc -l` from 296 to
190.
  • Loading branch information
TomFryersMidsummer authored and urschrei committed Jul 1, 2024
1 parent 4c4846f commit 3dadba1
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 64 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ geo-types = { version = "0.7.10", optional = true }
libc = "0.2.119"
num-traits = "0.2.14"
thiserror = "1.0.30"
reqwest = { version = "0.12.0", optional = true, default-features = false, features = ["blocking", "rustls-tls"] }
ureq = { version = "2.0.0", optional = true }

[workspace]
members = ["proj-sys"]
Expand All @@ -28,7 +28,7 @@ members = ["proj-sys"]
default = ["geo-types"]
bundled_proj = [ "proj-sys/bundled_proj" ]
pkg_config = [ "proj-sys/pkg_config" ]
network = ["reqwest", "proj-sys/network"]
network = ["ureq", "proj-sys/network"]

[dev-dependencies]
# approx version must match the one used in geo-types
Expand Down
130 changes: 73 additions & 57 deletions src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
// This functionality based on https://github.com/OSGeo/PROJ/blob/master/src/networkfilemanager.cpp#L1675
use proj_sys::{proj_context_set_network_callbacks, PJ_CONTEXT, PROJ_NETWORK_HANDLE};

use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::Method;
use std::collections::HashMap;
use std::ffi::CString;
use std::io::Read;
use std::ops::Range;
use std::os::raw::c_ulonglong;
use std::ptr::{self, NonNull};
use ureq::{Agent, Request, Response};

use crate::proj::{ProjError, _string};
use libc::c_char;
Expand All @@ -30,11 +32,14 @@ const CLIENT: &str = concat!("proj-rs/", env!("CARGO_PKG_VERSION"));
const MAX_RETRIES: u8 = 8;
// S3 sometimes sends these in place of actual client errors, so retry instead of erroring
const RETRY_CODES: [u16; 4] = [429, 500, 502, 504];
const SUCCESS_ERROR_CODES: Range<u16> = 200..300;
const CLIENT_ERROR_CODES: Range<u16> = 400..500;
const SERVER_ERROR_CODES: Range<u16> = 500..600;

/// This struct is cast to `c_void`, then to `PROJ_NETWORK_HANDLE` so it can be passed around
struct HandleData {
url: String,
headers: reqwest::header::HeaderMap,
headers: HashMap<String, String>,
// this raw pointer is handed out to libproj but never returned,
// so a copy of the pointer (raw pointers are Copy) is stored here.
// Note to future self: are you 100% sure that the pointer is never read again
Expand All @@ -43,11 +48,7 @@ struct HandleData {
}

impl HandleData {
fn new(
url: String,
headers: reqwest::header::HeaderMap,
hptr: Option<NonNull<c_char>>,
) -> Self {
fn new(url: String, headers: HashMap<String, String>, hptr: Option<NonNull<c_char>>) -> Self {
Self { url, headers, hptr }
}
}
Expand All @@ -74,36 +75,34 @@ fn get_wait_time_exp(retrycount: i32) -> u64 {

/// Process CDN response: handle retries in case of server error, or early return for client errors
/// Successful retry data is stored into res
fn error_handler(res: &mut Response, rb: RequestBuilder) -> Result<&Response, ProjError> {
let mut status = res.status().as_u16();
fn error_handler(res: &mut Response, rb: Request) -> Result<&Response, ProjError> {
let mut retries = 0;
// Check whether something went wrong on the server, or if it's an S3 retry code
if res.status().is_server_error() || RETRY_CODES.contains(&status) {
if SERVER_ERROR_CODES.contains(&res.status()) || RETRY_CODES.contains(&res.status()) {
// Start retrying: up to MAX_RETRIES
while (res.status().is_server_error() || RETRY_CODES.contains(&status))
while (SERVER_ERROR_CODES.contains(&res.status()) || RETRY_CODES.contains(&res.status()))
&& retries <= MAX_RETRIES
{
retries += 1;
let wait = time::Duration::from_millis(get_wait_time_exp(retries as i32));
thread::sleep(wait);
let retry = rb.try_clone().ok_or(ProjError::RequestCloneError)?;
*res = retry.send()?;
status = res.status().as_u16();
let retry = rb.clone();
*res = retry.call()?;
}
// Not a timeout or known S3 retry code: bail out
} else if res.status().is_client_error() {
} else if CLIENT_ERROR_CODES.contains(&res.status()) {
return Err(ProjError::DownloadError(
res.status().as_str().to_string(),
res.url().to_string(),
res.status_text().to_string(),
res.get_url().to_string(),
retries,
));
}
// Retries have been exhausted OR
// The loop ended prematurely due to a different error
if !res.status().is_success() {
if !SUCCESS_ERROR_CODES.contains(&res.status()) {
return Err(ProjError::DownloadError(
res.status().as_str().to_string(),
res.url().to_string(),
res.status_text().to_string(),
res.get_url().to_string(),
retries,
));
}
Expand Down Expand Up @@ -173,26 +172,35 @@ unsafe fn _network_open(
// RANGE header definition is "bytes=x-y"
let hvalue = format!("bytes={offset}-{end}");
// Create a new client that can be reused for subsequent queries
let clt = Client::builder().build()?;
let req = clt.request(Method::GET, &url);
// this performs the initial byte read, presumably as an error check
let initial = req.try_clone().ok_or(ProjError::RequestCloneError)?;
let with_headers = initial.header("Range", &hvalue).header("Client", CLIENT);
let mut res = with_headers.send()?;
let in_case_of_error = req
.try_clone()
.ok_or(ProjError::RequestCloneError)?
.header("Range", &hvalue);
let clt = Agent::new();
let req = clt.get(&url);
let with_headers = req.set("Range", &hvalue).set("Client", CLIENT);
let in_case_of_error = with_headers.clone();
let mut res = with_headers.call()?;
// hand the response off to the error-handler, continue on success
error_handler(&mut res, in_case_of_error)?;
// Write the initial read length value into the pointer
let contentlength = res.content_length().ok_or(ProjError::ContentLength)? as usize;
out_size_read.write(contentlength);
let headers = res.headers().clone();
let Some(Ok(contentlength)) = res.header("Content-Length").map(str::parse::<usize>) else {
return Err(ProjError::ContentLength);
};
let headers = res
.headers_names()
.into_iter()
.filter_map(|h| {
Some({
let v = res.header(&h)?.to_string();
(h, v)
})
})
.collect();
// Copy the downloaded bytes into the buffer so it can be passed around
res.bytes()?
.as_ptr()
.copy_to_nonoverlapping(buffer.cast(), contentlength.min(size_to_read));
let capacity = contentlength.min(size_to_read);
let mut buf = Vec::with_capacity(capacity);
res.into_reader()
.take(size_to_read as u64)
.read_to_end(&mut buf)?;
out_size_read.write(buf.len());
buf.as_ptr().copy_to_nonoverlapping(buffer.cast(), capacity);
let hd = HandleData::new(url, headers, None);
// heap-allocate the struct and cast it to a void pointer so it can be passed around to PROJ
let hd_boxed = Box::new(hd);
Expand Down Expand Up @@ -255,9 +263,8 @@ unsafe fn _network_get_header_value(
let hvalue = hd
.headers
.get(&lookup)
.ok_or_else(|| ProjError::HeaderError(lookup.to_string()))?
.to_str()?;
let cstr = CString::new(hvalue).unwrap();
.ok_or_else(|| ProjError::HeaderError(lookup.to_string()))?;
let cstr = CString::new(&**hvalue).unwrap();
let header = cstr.into_raw();
// Raw pointers are Copy: the pointer returned by this function is never returned by libproj so
// in order to avoid a memory leak the pointer is copied and stored in the HandleData struct,
Expand Down Expand Up @@ -327,34 +334,43 @@ fn _network_read_range(
let end = offset as usize + size_to_read - 1;
let hvalue = format!("bytes={offset}-{end}");
let hd = unsafe { &mut *(handle as *const c_void as *mut HandleData) };
let clt = Client::builder().build()?;
let initial = clt.request(Method::GET, &hd.url);
let in_case_of_error = initial
.try_clone()
.ok_or(ProjError::RequestCloneError)?
.header("Range", &hvalue)
.header("Client", CLIENT);
let req = in_case_of_error
.try_clone()
.ok_or(ProjError::RequestCloneError)?;
let mut res = req.send()?;
let clt = Agent::new();
let initial = clt.get(&hd.url);
let in_case_of_error = initial.clone().set("Range", &hvalue).set("Client", CLIENT);
let req = in_case_of_error.clone();
let mut res = req.call()?;
// hand the response and retry instance off to the error-handler, continue on success
error_handler(&mut res, in_case_of_error)?;
let headers = res.headers().clone();
let contentlength = res.content_length().ok_or(ProjError::ContentLength)? as usize;
let headers = res
.headers_names()
.into_iter()
.filter_map(|h| {
Some({
let v = res.header(&h)?.to_string();
(h, v)
})
})
.collect();
let Some(Ok(contentlength)) = res.header("Content-Length").map(str::parse::<usize>) else {
return Err(ProjError::ContentLength);
};
// Copy the downloaded bytes into the buffer so it can be passed around
let capacity = contentlength.min(size_to_read);
let mut buf = Vec::with_capacity(capacity);
res.into_reader()
.take(size_to_read as u64)
.read_to_end(&mut buf)?;
unsafe {
res.bytes()?
.as_ptr()
.copy_to_nonoverlapping(buffer.cast::<u8>(), contentlength.min(size_to_read));
buf.as_ptr()
.copy_to_nonoverlapping(buffer.cast::<u8>(), capacity);
}
let err_string = "";
unsafe {
out_error_string.copy_from_nonoverlapping(err_string.as_ptr().cast(), err_string.len());
out_error_string.add(err_string.len()).write(0);
}
hd.headers = headers;
Ok(contentlength)
Ok(buf.len())
}

/// Set up and initialise the grid download callback functions for all subsequent PROJ contexts
Expand Down
18 changes: 13 additions & 5 deletions src/proj.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ impl<T: CoordinateType> Coord<T> for (T, T) {

/// Errors originating in PROJ which can occur during projection and conversion
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ProjError {
/// A projection error
#[error("The projection failed with the following error: {0}")]
Expand All @@ -108,24 +109,31 @@ pub enum ProjError {
Network,
#[error("Could not set remote grid download callbacks")]
RemoteCallbacks,
#[error("Couldn't build request")]
#[error("Couldn't access the network")]
#[cfg(feature = "network")]
BuilderError(#[from] reqwest::Error),
NetworkError(Box<ureq::Error>),
#[error("Couldn't clone request")]
RequestCloneError,
#[error("Could not retrieve content length")]
ContentLength,
#[error("Couldn't retrieve header for key {0}")]
HeaderError(String),
#[cfg(feature = "network")]
#[error("Couldn't convert header value to str")]
HeaderConversion(#[from] reqwest::header::ToStrError),
#[error("Couldn't read response to buffer")]
ReadError(#[from] std::io::Error),
#[error("A {0} error occurred for url {1} after {2} retries")]
DownloadError(String, String, u8),
#[error("The current definition could not be retrieved")]
Definition,
}

#[cfg(feature = "network")]
impl From<ureq::Error> for ProjError {
fn from(e: ureq::Error) -> Self {
Self::NetworkError(Box::new(e))
}
}

#[derive(Error, Debug)]
pub enum ProjCreateError {
#[error("A nul byte was found in the PROJ string definition or CRS argument: {0}")]
Expand Down Expand Up @@ -1503,7 +1511,7 @@ mod test {
let usa_m = MyPoint::new(-115.797615, 37.2647978);
let usa_ft = to_feet.convert(usa_m).unwrap();
assert_relative_eq!(6693625.67217475, usa_ft.x());
assert_relative_eq!(3497301.5918027232, usa_ft.y(), epsilon=1e-8);
assert_relative_eq!(3497301.5918027232, usa_ft.y(), epsilon = 1e-8);
}

#[test]
Expand Down

0 comments on commit 3dadba1

Please sign in to comment.