[package]
name = "example-error-handling"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["macros"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.6.1", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Example showing how to convert errors into responses.
//!
//! Run with
//!
//! ```not_rust
//! cargo run -p example-error-handling
//! ```
//!
//! For successful requests the log output will be
//!
//! ```ignore
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_request: started processing request
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_response: finished processing request latency=0 ms status=200
//! ```
//!
//! For failed requests the log output will be
//!
//! ```ignore
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_request: started processing request
//! ERROR request{method=POST uri=/users matched_path="/users"}: example_error_handling: error from time_library err=failed to get time
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_response: finished processing request latency=0 ms status=500
//! ```
use std::{
collections::HashMap,
sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
},
};
use axum::{
extract::{rejection::JsonRejection, FromRequest, MatchedPath, Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::post,
Router,
};
use serde::{Deserialize, Serialize};
use time_library::Timestamp;
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
}),
)
.with(tracing_subscriber::fmt::layer())
.init();
let state = AppState::default();
let app = Router::new()
// A dummy route that accepts some JSON but sometimes fails
.route("/users", post(users_create))
.layer(
TraceLayer::new_for_http()
// Create our own span for the request and include the matched path. The matched
// path is useful for figuring out which handler the request was routed to.
.make_span_with(|req: &Request| {
let method = req.method();
let uri = req.uri();
// axum automatically adds this extension.
let matched_path = req
.extensions()
.get::<MatchedPath>()
.map(|matched_path| matched_path.as_str());
tracing::debug_span!("request", %method, %uri, matched_path)
})
// By default `TraceLayer` will log 5xx responses but we're doing our specific
// logging of errors so disable that
.on_failure(()),
)
.with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
tracing::debug!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
// State shared with the handlers through `Router::with_state`.
//
// Both fields sit behind `Arc`, so cloning the state clones the handles rather than the data.
#[derive(Default, Clone)]
struct AppState {
// Counter that `users_create` increments to hand out the next user id.
next_id: Arc<AtomicU64>,
// Users created so far, keyed by their id.
users: Arc<Mutex<HashMap<u64, User>>>,
}
// JSON request body accepted by `users_create`.
#[derive(Deserialize)]
struct UserParams {
// Name for the new user; moved into `User::name`.
name: String,
}
// A user as kept in `AppState::users`; also the JSON body `users_create` returns on success.
#[derive(Serialize, Clone)]
struct User {
// Identifier taken from `AppState::next_id` when the user is created.
id: u64,
// Name supplied by the client in `UserParams`.
name: String,
// Value returned by `time_library::Timestamp::now()` at creation time.
created_at: Timestamp,
}
async fn users_create(
State(state): State<AppState>,
// Make sure to use our own JSON extractor so we get input errors formatted in a way that
// matches our application
AppJson(params): AppJson<UserParams>,
) -> Result<AppJson<User>, AppError> {
let id = state.next_id.fetch_add(1, Ordering::SeqCst);
// We have implemented `From<time_library::Error> for AppError` which allows us to use `?` to
// automatically convert the error
let created_at = Timestamp::now()?;
let user = User {
id,
name: params.name,
created_at,
};
state.users.lock().unwrap().insert(id, user.clone());
Ok(AppJson(user))
}
// Our own JSON extractor, implemented by wrapping `axum::Json`.
//
// `via(axum::Json)` delegates the actual extraction to `axum::Json`, while `rejection(AppError)`
// replaces its rejection type with `AppError`, so failed extractions are formatted like every
// other error in this application (see `impl IntoResponse for AppError`).
//
// `axum::Json` responds with plain text if the input is invalid.
#[derive(FromRequest)]
#[from_request(via(axum::Json), rejection(AppError))]
struct AppJson<T>(T);
// Lets `AppJson<T>` be used as a response type as well: the wrapped value is handed to
// `axum::Json`, which builds the actual response.
impl<T> IntoResponse for AppJson<T>
where
    axum::Json<T>: IntoResponse,
{
    fn into_response(self) -> Response {
        let Self(payload) = self;
        axum::Json(payload).into_response()
    }
}
// The kinds of errors we can hit in our application.
//
// `Debug` is derived so the error can be printed with `{:?}` and used with `Result::unwrap`,
// `Result::expect` and assertion macros; both payload types already implement `Debug`.
#[derive(Debug)]
enum AppError {
    // The JSON extractor rejected the request, e.g. the body contained invalid JSON
    JsonRejection(JsonRejection),
    // Some error from a third party library we're using
    TimeError(time_library::Error),
}
// Teaches axum how to turn an `AppError` into an HTTP response: a status code plus a JSON body of
// the form `{"message": "..."}`.
//
// Since every error passes through here, it is also where errors get logged.
impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        // Shape of the JSON body sent for error responses
        #[derive(Serialize)]
        struct ErrorResponse {
            message: String,
        }

        let (status, message) = match self {
            // Caused by bad user input, so it is reported to the client but not logged
            Self::JsonRejection(rejection) => (rejection.status(), rejection.body_text()),
            Self::TimeError(err) => {
                // `TraceLayer` already wraps the request in a span carrying the method, uri, etc,
                // so none of that needs repeating here. The binding name `err` becomes the log
                // field name.
                tracing::error!(%err, "error from time_library");

                // The client only gets a generic message, never the error details
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    String::from("Something went wrong"),
                )
            }
        };

        let body = AppJson(ErrorResponse { message });
        (status, body).into_response()
    }
}
impl From<JsonRejection> for AppError {
fn from(rejection: JsonRejection) -> Self {
Self::JsonRejection(rejection)
}
}
impl From<time_library::Error> for AppError {
fn from(error: time_library::Error) -> Self {
Self::TimeError(error)
}
}
// Imagine this is some third party library that we're using. It sometimes returns errors which we
// want to log.
mod time_library {
    use std::sync::atomic::{AtomicU64, Ordering};

    use serde::Serialize;

    // Opaque timestamp wrapping a `u64`.
    #[derive(Serialize, Clone)]
    pub struct Timestamp(u64);

    impl Timestamp {
        // Returns the "current" timestamp, or `Error::FailedToGetTime` on some calls.
        //
        // A process-wide call counter drives the simulated failures: a call fails whenever the
        // counter value before the increment is a multiple of three. The counter starts at zero,
        // so the 1st, 4th, 7th, ... calls fail. Successful calls always yield the value 1337.
        pub fn now() -> Result<Self, Error> {
            static COUNTER: AtomicU64 = AtomicU64::new(0);

            // Fail on every third call, starting with the very first, just to simulate errors
            if COUNTER.fetch_add(1, Ordering::SeqCst) % 3 == 0 {
                Err(Error::FailedToGetTime)
            } else {
                Ok(Self(1337))
            }
        }
    }

    // Errors this library can return.
    #[derive(Debug)]
    pub enum Error {
        // The timestamp could not be obtained
        FailedToGetTime,
    }

    impl std::fmt::Display for Error {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // Exhaustive on purpose: a new variant must be given a message to compile
            match self {
                Self::FailedToGetTime => f.write_str("failed to get time"),
            }
        }
    }

    // Lets the error be used as a `dyn std::error::Error` (boxing, `?` into generic error types,
    // error-reporting crates), as is expected of a library error type.
    impl std::error::Error for Error {}
}