Skip to content

Commit

Permalink
Pass a Context to error handler
Browse files Browse the repository at this point in the history
  • Loading branch information
connorslade committed Dec 23, 2023
1 parent 3c873f6 commit 067b280
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 124 deletions.
24 changes: 12 additions & 12 deletions examples/basic/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use std::{
};

use afire::{
error::AnyResult,
extensions::{Logger, RealIp, ServeStatic},
middleware::Middleware,
route::RouteError,
trace::{set_log_level, Level},
Content, Method, Request, Response, Server,
Content, Context, Method, Request, Response, Server,
};
use anyhow::Result;
use rand::Rng;
use serde::Deserialize;
use serde_json::json;
Expand All @@ -22,7 +22,7 @@ struct App {
analytics: RwLock<HashMap<String, u64>>,
}

fn main() -> Result<()> {
fn main() -> AnyResult<()> {
// Show some helpful information during startup.
// afire log level is global and will affect all afire servers in your application
// (although there is usually only one)
Expand Down Expand Up @@ -110,20 +110,20 @@ impl Middleware for Analytics {

/// Custom error handler that returns JSON for API routes and plain text for other routes.
/// Note: This is just an example, in your own application you should consider making use of the location and error fields of RouteError.
fn error_handler(_server: Arc<Server<App>>, req: Arc<Request>, error: RouteError) -> Response {
if req.path.starts_with("/api") {
Response::new()
.text(json!({
"message": error.message,
}))
.content(Content::JSON)
fn error_handler(ctx: &Context<App>, error: RouteError) -> AnyResult<()> {
if ctx.req.path.starts_with("/api") {
ctx.text(json!({
"message": error.message,
}))
.content(Content::JSON)
} else {
Response::new()
.text(format!("Internal Server Error\n{}", error.message))
ctx.text(format!("Internal Server Error\n{}", error.message))
.content(Content::TXT)
}
.status(error.status)
.headers(error.headers)
.send()?;
Ok(())
}

impl App {
Expand Down
61 changes: 31 additions & 30 deletions examples/paste_bin/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,41 +1,42 @@
use std::sync::Arc;

use afire::{headers::Vary, route::RouteError, Content, HeaderName, Request, Response, Server};
use afire::{error::AnyResult, headers::Vary, route::RouteError, Content, Context, HeaderName};
use serde_json::json;

use crate::app::App;

pub fn error_handler(_server: Arc<Server<App>>, req: Arc<Request>, error: RouteError) -> Response {
if req
pub fn error_handler(ctx: &Context<App>, error: RouteError) -> AnyResult<()> {
if ctx
.req
.headers
.get(HeaderName::Accept)
.map(|x| x == "application/json")
.unwrap_or(false)
{
Response::new()
.text(json!({
"message": error.message,
"location": error.location.map(|x| x.to_string()),
"error": error.error.map(|x| format!("{x:?}")),
}))
.content(Content::JSON)
ctx.text(json!({
"message": error.message,
"location": error.location.map(|x| x.to_string()),
"error": error.error.map(|x| format!("{x:?}")),
}))
.content(Content::JSON)
} else {
Response::new()
.text(format!(
"Internal Server Error\n{}{}{}",
error.message,
error
.error
.map(|x| format!("\n{:?}", x))
.unwrap_or_default(),
error
.location
.map(|x| format!("\n{}", x))
.unwrap_or_default(),
))
.content(Content::TXT)
}
.header(Vary::headers([HeaderName::Accept]))
.status(error.status)
.headers(error.headers)
ctx.text(format!(
"Internal Server Error\n{}{}{}",
error.message,
error
.error
.map(|x| format!("\n{:?}", x))
.unwrap_or_default(),
error
.location
.map(|x| format!("\n{}", x))
.unwrap_or_default(),
))
.content(Content::TXT)
};

ctx.header(Vary::headers([HeaderName::Accept]))
.status(error.status)
.headers(error.headers)
.send()?;

Ok(())
}
8 changes: 6 additions & 2 deletions examples/tmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use afire::{
internal::sync::{ForceLockMutex, ForceLockRwLock},
multipart::MultipartData,
prelude::*,
route::RouteContext,
route::{RouteContext, RouteError},
trace,
trace::DefaultFormatter,
trace::{set_log_formatter, set_log_level, Formatter, Level},
Expand All @@ -30,7 +30,11 @@ const PATH: &str = r#"..."#;
const FILE_TYPE: &str = "...";

fn main() -> Result<(), Box<dyn Error>> {
let mut server = Server::<()>::new("localhost", 8081).workers(4);
let mut server = Server::<()>::new("localhost", 8081)
.workers(4)
.error_handler(|ctx: &Context<()>, error: RouteError| {
Ok(ctx.text(error.message).send()?)
});
set_log_level(Level::Debug);
set_log_formatter(LogFormatter);

Expand Down
10 changes: 8 additions & 2 deletions lib/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
io::Read,
sync::{
atomic::{AtomicU8, Ordering},
Arc, Mutex,
Arc, Mutex, MutexGuard,
},
};

Expand All @@ -18,7 +18,7 @@ use crate::{
/// A collection of data important for handling a request.
/// It includes both the request data, and a reference to the server.
/// You also use it to build and send the response.
pub struct Context<State: 'static + Send + Sync> {
pub struct Context<State: 'static + Send + Sync = ()> {
/// Reference to the server.
pub server: Arc<Server<State>>,
/// The request you are handling.
Expand Down Expand Up @@ -115,6 +115,12 @@ impl<State: 'static + Send + Sync> Context<State> {
.unwrap_or_else(|| panic!("Path parameter #{} does not exist.", idx))
}

/// Gets a reference to the internal response.
/// This is mostly useful for when you need to inspect the current state of the response, or overwrite it in an error handler.
pub fn get_response(&self) -> MutexGuard<Response> {
self.response.force_lock()
}

/// Sends the response to the client.
/// This method must not be called more than once per request.
///
Expand Down
14 changes: 6 additions & 8 deletions lib/internal/event_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,14 @@ impl<State: Send + Sync> EventLoop<State> for TcpEventLoop {
break;
}

let event = match i {
Ok(event) => event,
Err(err) => {
trace!(Level::Error, "Error accepting connection: {err}");
continue;
match i {
Ok(event) => {
let this_server = server.clone();
let event = Arc::new(Socket::new(event));
server.thread_pool.execute(|| handle(event, this_server));
}
Err(err) => trace!(Level::Error, "Error accepting connection: {err}"),
};

let event = Arc::new(Socket::new(event));
handle(event, server.clone());
}
Ok(())
}
Expand Down
24 changes: 14 additions & 10 deletions lib/internal/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,20 @@ where
continue 'outer;
}

// TODO: account for guaranteed send
// TODO: Run through `write` for middleware
let error = RouteError::downcast_error(e);
if let Err(e) = this
.clone()
.error_handler
.handle(this.clone(), req.clone(), error)
.write(req.socket.clone(), &this.default_headers)
{
trace!(Level::Debug, "Error writing error response: {:?}", e);
if sent_response {
trace!(
Level::Error,
"Route handler [{:?}] errored after sending a response.",
route
);
} else {
if let Err(e) = this
.clone()
.error_handler
.handle(&ctx, RouteError::downcast_error(e))
{
trace!(Level::Debug, "Error writing error response: {:?}", e);
}
}
} else if sent_response || req.socket.is_raw() {
// A response has already been sent or another system has taken over the socket.
Expand Down
62 changes: 11 additions & 51 deletions lib/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::{
borrow::Cow,
error::Error,
fmt::{self, Debug, Display},
marker::PhantomData,
panic::Location,
sync::Arc,
};
Expand All @@ -15,7 +14,7 @@ use crate::{
error::{self, AnyResult},
internal::router::PathParameters,
router::Path,
Content, Context, Header, HeaderName, Method, Request, Response, Server, Status,
Content, Context, Header, HeaderName, Method, Request, Status,
};

type Handler<State> = Box<dyn Fn(&Context<State>) -> AnyResult<()> + 'static + Send + Sync>;
Expand Down Expand Up @@ -48,50 +47,16 @@ impl<State: Send + Sync> Debug for Route<State> {
/// For example, you could send JSON for errors if if the request is going to an API route or of its `Accept` header is `application/json` and HTML otherwise.
pub trait ErrorHandler<State: 'static + Send + Sync> {
/// Generates a response from an error.
fn handle(&self, server: Arc<Server<State>>, req: Arc<Request>, error: RouteError) -> Response;
fn handle(&self, ctx: &Context<State>, error: RouteError) -> AnyResult<()>;
}

impl<State, F> ErrorHandler<State> for F
where
State: 'static + Send + Sync,
F: Fn(Arc<Server<State>>, Arc<Request>, RouteError) -> Response + Send + Sync,
F: Fn(&Context<State>, RouteError) -> AnyResult<()> + Send + Sync,
{
fn handle(&self, server: Arc<Server<State>>, req: Arc<Request>, error: RouteError) -> Response {
(self)(server, req, error)
}
}

/// Lets you create an error handler from a function with the signature `Fn(Arc<Server<State>>, RouteError) -> Response`.
pub struct AnonymousErrorHandler<State, F>
where
State: Send + Sync + 'static,
F: Fn(Arc<Server<State>>, Arc<Request>, RouteError) -> Response + Send + Sync,
{
f: F,
_state: PhantomData<State>,
}

impl<State, F> AnonymousErrorHandler<State, F>
where
State: Send + Sync + 'static,
F: Fn(Arc<Server<State>>, Arc<Request>, RouteError) -> Response + Send + Sync,
{
/// Creates a new anonymous error handler.
pub fn new(f: F) -> Self {
Self {
f,
_state: PhantomData,
}
}
}

impl<State, F> ErrorHandler<State> for AnonymousErrorHandler<State, F>
where
State: 'static + Send + Sync,
F: Fn(Arc<Server<State>>, Arc<Request>, RouteError) -> Response + Send + Sync,
{
fn handle(&self, server: Arc<Server<State>>, req: Arc<Request>, error: RouteError) -> Response {
(self.f)(server, req, error)
fn handle(&self, ctx: &Context<State>, error: RouteError) -> AnyResult<()> {
(self)(ctx, error)
}
}

Expand All @@ -115,12 +80,7 @@ where
pub struct DefaultErrorHandler;

impl<State: 'static + Send + Sync> ErrorHandler<State> for DefaultErrorHandler {
fn handle(
&self,
_server: Arc<Server<State>>,
_req: Arc<Request>,
error: RouteError,
) -> Response {
fn handle(&self, ctx: &Context<State>, error: RouteError) -> AnyResult<()> {
let mut message = format!("Internal Server Error\n\n{}", error.message);

if let Some(location) = error.location {
Expand All @@ -131,16 +91,16 @@ impl<State: 'static + Send + Sync> ErrorHandler<State> for DefaultErrorHandler {
message.push_str(&format!("\n\n{:#?}", error));
}

let mut res = Response::new()
.status(error.status)
ctx.status(error.status)
.text(message)
.headers(error.headers);

if !res.headers.has(HeaderName::ContentType) {
res = res.content(Content::TXT);
if !ctx.get_response().headers.has(HeaderName::ContentType) {
ctx.content(Content::TXT);
}

res
ctx.send()?;
Ok(())
}
}

Expand Down
11 changes: 6 additions & 5 deletions lib/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,14 @@ impl<State: Send + Sync> Server<State> {
/// Be sure that your panic handler wont panic, because that will just panic the whole application.
/// ## Example
/// ```rust
/// # use afire::{Server, Response, Status, route::AnonymousErrorHandler};
/// # use afire::{Server, Response, Status, Context, route::RouteError};
/// Server::<()>::new("localhost", 8080)
/// .error_handler(AnonymousErrorHandler::new(|_server, _req, err| {
/// Response::new()
/// .status(Status::InternalServerError)
/// .error_handler(|ctx: &Context, err: RouteError| {
/// ctx.status(Status::InternalServerError)
/// .text(format!("Internal Server Error: {}", err.message))
/// }));
/// .send()?;
/// Ok(())
/// });
/// ```
pub fn error_handler(self, res: impl ErrorHandler<State> + Send + Sync + 'static) -> Self {
trace!("{}Setting Error Handler", emoji("✌"));
Expand Down
8 changes: 4 additions & 4 deletions lib/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,14 @@ pub(crate) fn emoji(_emoji: &str) -> Cow<str> {
/// Enabled with the `tracing` feature
#[macro_export]
macro_rules! trace {
(Level::$level: ident, $($arg: tt) +) => {
(Level::$level: ident, $($arg: tt) +) => {{
#[cfg(feature = "tracing")]
$crate::trace::_trace($crate::trace::Level::$level, format_args!($($arg)+));
};
($($arg: tt) +) => {
}};
($($arg: tt) +) => {{
#[cfg(feature = "tracing")]
$crate::trace::_trace($crate::trace::Level::Trace, format_args!($($arg)+));
};
}};
}

/// A wrapper for [`Display`] or [`Debug`] types that only evaluates the inner value when it is actually used.
Expand Down

0 comments on commit 067b280

Please sign in to comment.