axum
Welcome to the axum book!
This book is mostly based on the examples that can be found in the git repository of axum.
As you flip through the book, you'll notice that for now most of the examples appear without explanations and in no particular order.
I am working on it.
Goals
- Provide an easy way to see the examples in the axum repository with explanations.
- Offer a way to learn how to use axum.
- Improve the examples (adding tests, filling holes if there are any).
- Add README.md files to the examples.

The book can be found at axum.code-maven.com, with the hope that one day it will be accepted as part of the official axum project and displayed on a stand-alone axum website or on the Tokio website.
Notes
This book uses the examples from the source code of the axum project. For this reason the Cargo.toml files declare their dependency on axum with a relative path, using the following entry:
axum = { path = "../../axum" }
This works because the examples are included in the git repository of axum. In real-world applications this line needs to be replaced by a line declaring the version of axum you'd like to use. For example:
axum = "0.8.1"
How to run the examples?
You can run the examples either from the git repository of axum or as standalone applications.
git clone https://github.com/szabgab/axum.git
cd axum
cargo run -p example-NAME
For example, in order to run the hello-world example you need to execute:
cargo run -p example-hello-world
Then you can visit the web site using this address: http://localhost:3000/.
Alternatively, you can copy the content of an example, replace
axum = { path = "../../axum" }
with
axum = "0.8.1"
and then run
cargo run
How to run the tests?
cargo test -p NAME-OF-THE-EXAMPLE
Specifically for the hello-world example:
cargo test -p example-hello-world
More examples, tutorials
This page contains a list of additional axum-related crates, project showcases, and several tutorials taken from the ECOSYSTEM.md file.
Community Projects
If your project isn't listed here and you would like it to be, please feel free to create a PR.
Community maintained axum ecosystem
- axum-server: axum-server is a hyper server implementation designed to be used with axum.
- axum-typed-websockets: axum::extract::ws with type safe messages.
- tower-cookies: Cookie manager middleware
- axum-flash: One-time notifications (aka flash messages) for axum.
- axum-msgpack: MessagePack Extractors for axum.
- axum-sqlx-tx: Request-bound SQLx transactions with automatic commit/rollback based on response.
- aliri_axum and aliri_tower: JWT validation middleware and OAuth2 scopes enforcing extractors.
- ezsockets: Easy to use WebSocket library that integrates with axum.
- axum_session: Database persistent sessions like Python's flask_sessionstore for axum.
- axum_session_auth: Persistent session based user login with rights management for axum.
- axum-auth: High-level http auth extractors for axum.
- axum-keycloak-auth: Protect axum routes with a JWT emitted by Keycloak.
- axum-tungstenite: WebSocket connections for axum directly using tungstenite
- axum-jrpc: Json-rpc extractor for axum
- axum-tracing-opentelemetry: Middlewares and tools to integrate axum + tracing + opentelemetry
- svelte-axum-project: Template and example for Svelte frontend app with axum as backend
- axum-streams: Streaming HTTP body with different formats: JSON, CSV, Protobuf.
- axum-template: Layers, extractors and template engine wrappers for axum based Web MVC applications
- axum-template: GraphQL and REST API, SurrealDb, JWT auth, direct error handling, request logs
- axum-guard-logic: Use AND/OR logic to extract types and check their values against Service inputs.
- axum-casbin-auth: Casbin access control middleware for axum framework
- aide: Code-first Open API documentation generator with axum integration.
- axum-typed-routing: Statically typed routing macros with OpenAPI generation using aide.
- axum-jsonschema: A Json<T> extractor that does JSON schema validation of requests.
- axum-login: Session-based user authentication for axum.
- axum-csrf-sync-pattern: A middleware implementing CSRF STP for AJAX backends and API endpoints.
- axum-otel-metrics: An axum OpenTelemetry Metrics middleware with support for a Prometheus exporter.
- jwt-authorizer: JWT authorization layer for axum (oidc discovery, validation options, claims extraction, etc.)
- axum-typed-multipart: Type safe wrapper for axum::extract::Multipart.
- tower-governor: A Tower service and layer that provides a rate-limiting backend by governor
- axum-restful: A restful framework based on axum and sea-orm, inspired by django-rest-framework.
- springtime-web-axum: A web framework built on Springtime and axum, leveraging dependency injection for easy app development.
- rust-axum-with-google-oauth: website template for Google OAuth authentication on axum, using SQLite with SQLx or MongoDB and MiniJinja.
- axum-htmx: Htmx extractors and request guards for axum.
- axum-prometheus: A middleware library to collect HTTP metrics for axum applications, compatible with all metrics.rs exporters.
- axum-valid: Extractors for data validation using validator, garde, and validify.
- tower-sessions: Sessions as a tower and axum middleware.
- shuttle: Build & ship backends without writing any infrastructure files. Now with axum support.
- socketioxide: An easy to use socket.io server implementation working as a tower layer/service.
- axum-serde: Provides multiple serde-based extractors / responses, also offers a macro to easily customize serde-based extractors / responses.
- loco.rs: A full stack Web and API productivity framework similar to Rails, based on axum.
- axum-test: High level library for writing Cargo tests that run against axum.
- axum-messages: One-time notification messages for axum.
- spring-rs: spring-rs is a microservice framework written in Rust inspired by Java's spring-boot, based on axum
- zino: Zino is a next-generation framework for composable applications which provides full integrations with axum.
- axum-rails-cookie: Extract rails session cookies in axum based apps.
- axum-ws-broadcaster: A broadcasting library for both axum-typed-websockets and axum::extract::ws.
- axum-negotiate-layer: Middleware/Layer for Kerberos/NTLM "Negotiate" authentication.
- axum-kit: Streamline the integration and usage of axum with SQLx and Redis.
- tower_allowed_hosts: Allowed hosts middleware which limits requests to allowed hosts only.
- baxe: Simple macro for defining backend errors once and automatically generate standardized JSON error responses, saving time and reducing complexity
- axum-html-minifier: A middleware that minifies the HTML body content of an axum response.
- static-serve: A helper macro for compressing and embedding static assets in an axum webserver.
Project showcase
- HomeDisk: ☁️ Fast, lightweight and Open Source local cloud for your data.
- Houseflow: House automation platform written in Rust.
- JWT Auth: JWT auth service for educational purposes.
- ROAPI: Create full-fledged APIs for static datasets without writing a single line of code.
- notify.run: HTTP-to-WebPush relay for sending desktop/mobile notifications to yourself, written in Rust.
- turbo.fish (repository): Find out for yourself 😉
- Book Management: CRUD system of book-management with ORM and JWT for educational purposes.
- realworld-axum-sqlx: A Rust implementation of the Realworld demo app spec using axum and SQLx. See https://github.com/davidpdrsn/realworld-axum-sqlx for a fork with up to date dependencies.
- Rustapi: RESTful API template using MongoDB
- axum-postgres-template: Production-ready axum + PostgreSQL application template
- RUSTfulapi: Reusable template for building REST Web Services in Rust. Uses axum and SeaORM.
- Jotsy: Self-hosted notes app powered by Skytable, axum and Tokio
- Svix (repository): Enterprise-ready webhook service
- emojied (repository): Shorten URLs to emojis!
- CLOMonitor (repository): Checks open source projects repositories to verify they meet certain best practices.
- Pinging.net (repository): A new way to check and monitor your internet connection.
- wastebin: A minimalist pastebin service.
- sandbox_axum_observability: A Sandbox/showcase project to experiment with axum and observability (tracing, opentelemetry, jaeger, grafana tempo,...)
- axum_admin: An admin panel built with axum, Sea-orm and Vue 3.
- rgit: A blazingly fast Git repository browser, compatible with- and heavily inspired by cgit.
- Petclinic: A port of Spring Framework's Petclinic showcase project to axum
- axum-middleware-example: An authorization application using axum, Casbin and Diesel, with JWT support.
- circleci-hook: Translate CircleCI WebHooks to OpenTelemetry traces to improve your test insights. Add detail with otel-cli to capture individual commands. Use the TRACEPARENT integration to add details from your tests.
- lishuuro.org: Small chess variant server that uses axum for the backend.
- freedit: A forum powered by Rust.
- axum-http-auth-example: axum http auth example using postgres and redis.
- Deaftone: Lightweight music server with a clean and simple API.
- dropit: Temporary file hosting.
- cobrust: Multiplayer web based snake game.
- meta-cross: Tweaked version of Tic-Tac-Toe.
- httq: HTTP to MQTT trivial proxy.
- Pods-Blitz: Self-hosted podcast publisher. Uses the crates axum-login, password-auth, sqlx and handlebars (for HTML templates).
- ReductStore: A time series database for storing and managing large amounts of blob data
- randoku: A tiny web service which generates random numbers and shuffles lists randomly
- sero: Host static sites with custom subdomains as surge.sh does. But with full control and cool new features. (axum, sea-orm, postgresql)
- Hatsu: 🩵 Self-hosted & Fully-automated ActivityPub Bridge for Static Sites.
- Mini RPS: Mini reverse proxy server, HTTPS, CORS, static file hosting and template engine (minijinja).
Tutorials
- Rust on Nails: A full stack architecture for Rust web applications
- axum-tutorial (website): axum tutorial for beginners
- demo-rust-axum: Demo of Rust and axum web framework
- Introduction to axum (talk): Talk about axum from the Copenhagen Rust Meetup
- Getting Started with Axum: axum tutorial, GET, POST endpoints and serving files
- Using Rust, Axum, PostgreSQL, and Tokio to build a Blog
- Introduction to axum: YouTube playlist
- Rust Axum Full Course: YouTube video
- Deploying Axum projects with Shuttle
- API Development with Rust: REST APIs based on axum
- Building a SaaS with Rust & Next.js: A tutorial for combining Next.js with Rust via axum to make a SaaS.
Crates in use
The following crates are used in the examples.
- anyhow
- askama
- assert-json-diff
- async-session
- axum
- axum-extra
- axum-server
- bb8
- bb8-postgres
- bb8-redis
- brotli
- deadpool-diesel
- diesel
- diesel-async
- diesel_migrations
- eventsource-stream
- flate2
- futures
- futures-executor
- futures-util
- headers
- http
- http-body-util
- hyper
- hyper-util
- jsonwebtoken
- listenfd
- metrics
- metrics-exporter-prometheus
- mime
- minijinja
- mongodb
- oauth2
- openssl
- redis
- reqwest
- reqwest-eventsource
- serde
- serde_json
- sqlx
- thiserror
- tokio
- tokio-native-tls
- tokio-openssl
- tokio-postgres
- tokio-rustls
- tokio-stream
- tokio-tungstenite
- tokio-util
- tower
- tower-http
- tower-service
- tracing
- tracing-subscriber
- uuid
- validator
- zstd
Outline
An outline for the book
- Echo GET: Accept parameters in a GET request.
- Echo POST: Accept parameters in a POST request.
- Path parameters: Accept path parameters.
- Session management
- Templates
- Auto-reload the server
HTTP methods
Introduction
The first steps in writing a web application using axum.
Simple handling of parameters in GET and POST requests and as part of the path.
Hello World
The standard "Hello World" application.
- Create a new crate.
- Add axum and tokio with the "full" feature.
cargo new hello-world
cd hello-world
cargo add axum
cargo add tokio -F full
We also have two additional crates, http-body-util and tower, that we use to test our application.
This is what our Cargo.toml file looks like. As mentioned earlier, in this book we use the axum crate located in the same repository.
In your own project you will have something like this instead: axum = "0.8.1".
[package]
name = "example-hello-world"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
[dev-dependencies]
http-body-util = "0.1.0"
tower = { version = "0.5.2", features = ["util"] }
In our application we need to map the path part of each URL the user might visit to a function that handles that request.
For this we need a function to handle the request, and we need to map the path portion of the URL to that function.
For example, if we would like to handle the URL https://example.org/hello/world, then we need to map the /hello/world path to the appropriate function in our application. In particular, the address of the main page is https://example.org/ and thus its path is /.
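To make the idea of "mapping a path to a function" concrete, here is a minimal, std-only sketch of a routing table. It is not how axum's Router works internally; the names Routes, route, and dispatch and the (status, body) return value are hypothetical, chosen only for this illustration. Unknown paths get a 404 status with an empty body.

```rust
use std::collections::HashMap;

// Hypothetical handler type for this sketch: a plain function returning the body.
type Handler = fn() -> String;

// A hypothetical, minimal routing table: exact path -> handler function.
struct Routes {
    table: HashMap<&'static str, Handler>,
}

impl Routes {
    fn new() -> Self {
        Routes { table: HashMap::new() }
    }

    // Register a handler for an exact path such as "/" or "/hello/world".
    fn route(mut self, path: &'static str, handler: Handler) -> Self {
        self.table.insert(path, handler);
        self
    }

    // Find the handler for a path; unknown paths yield 404 and an empty body.
    fn dispatch(&self, path: &str) -> (u16, String) {
        match self.table.get(path) {
            Some(handler) => (200, handler()),
            None => (404, String::new()),
        }
    }
}

// The equivalent of our main-page handler in this sketch.
fn main_page() -> String {
    "<h1>Hello, World!</h1>".to_string()
}
```

For example, Routes::new().route("/", main_page).dispatch("/") yields (200, "<h1>Hello, World!</h1>"). axum's Router adds much more on top (HTTP methods, path parameters, async handlers), but the core idea of looking up a function by path is the same.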
We define a function to handle the request. To keep it simple, it returns a static string containing an HTML snippet, wrapped in Html so that the response is sent with a text/html content type. The name of the function does not matter.
async fn handler() -> Html<&'static str> {
    Html("<h1>Hello, World!</h1>")
}
We need to map GET requests that arrive at / to this function.
We put the creation of the Router in a separate function to make it easy to test.
fn app() -> Router {
    Router::new().route("/", get(handler))
}
Finally, we need to create our server in our main function.
#[tokio::main]
async fn main() {
    // build our application with a route
    let app = app();

    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on http://{}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}
To run the application type in the following command:
cargo run
This will compile the code, run the server and print the following to the terminal:
listening on http://127.0.0.1:3000
You can then open your browser and visit that address.
You should see Hello, World! in big letters.
Troubleshooting
If running cargo run gives you an error like the one below, then another process is already using port 3000.
Maybe it is another example from the axum repository. You can either find that application, shut it down, and try running this one again,
or you can change the port number in this example from 3000 to 3001 or some other number and try again.
thread 'main' panicked at examples/hello-world/src/main.rs:17:10:
called `Result::unwrap()` on an `Err` value: Os { code: 98, kind: AddrInUse, message: "Address already in use" }
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
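If you would rather have the program step around a busy port than panic, one option is to try a few ports in a row. The sketch below uses std::net::TcpListener to stay self-contained (the example itself uses tokio::net::TcpListener), and bind_first_free is a hypothetical helper, not part of axum. It treats io::ErrorKind::AddrInUse, the error kind shown in the panic message above, as "try the next port" and returns any other error immediately.

```rust
use std::io;
use std::net::TcpListener;

// Hypothetical helper: try `start`, then the following ports, skipping any
// that are already taken. Gives up after `attempts` tries.
fn bind_first_free(host: &str, start: u16, attempts: u16) -> io::Result<TcpListener> {
    let mut last_err = io::Error::new(io::ErrorKind::AddrInUse, "no free port found");
    for offset in 0..attempts {
        // Stop instead of wrapping around if we run past port 65535.
        let Some(port) = start.checked_add(offset) else { break };
        match TcpListener::bind((host, port)) {
            Ok(listener) => return Ok(listener),
            // Port already in use: remember the error and try the next one.
            Err(e) if e.kind() == io::ErrorKind::AddrInUse => last_err = e,
            // Any other error (bad host, missing permissions, ...) is returned as is.
            Err(e) => return Err(e),
        }
    }
    Err(last_err)
}
```

Another option is to bind to port 0, in which case the operating system picks a free port and local_addr() reports which one was chosen. Either way, printing the actual address, as the example already does, keeps the server easy to find.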
Handling other pages
If you try to visit some other page on your server, e.g. http://127.0.0.1:3000/hi, you will get a blank page.
We'll see what happens there and how to make axum display a custom 404 error page.
Checking with curl
The curl command allows you to access web sites from the command line. It is a very handy tool.
Let's see how we can use it with our web site.
Accessing the main page:
$ curl http://localhost:3000/
<h1>Hello, World!</h1>
We can also observe what happens if we try to access a page that does not exist. It seems that nothing happens, which is rather inconvenient: curl prints nothing because the response body is empty. This is the blank page we saw earlier.
$ curl http://localhost:3000/hi
Using the -i flag we can ask curl to also display the HTTP headers the server sent us.
Using the upper-case -I flag, curl sends a HEAD request and prints only the headers sent by the server.
This can be more convenient when we are only interested in the status and the headers.
Accessing the main page we get the following:
$ curl -i http://localhost:3000/
HTTP/1.1 200 OK
content-type: text/html; charset=utf-8
content-length: 22
date: Fri, 14 Mar 2025 08:27:44 GMT
<h1>Hello, World!</h1>
The first line includes the status code. This time it is 200 OK, the success status.
Accessing a page that does not exist, we get the 404 Not Found error status.
$ curl -I http://localhost:3000/hi
HTTP/1.1 404 Not Found
date: Fri, 14 Mar 2025 08:27:41 GMT
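The status line curl prints first has three parts: the protocol version, the numeric status code, and a reason phrase. As a small illustration of reading it, here is a hypothetical std-only helper (not part of curl or axum) that splits such a line, assuming the HTTP/1.x format shown above.

```rust
// Hypothetical helper: split "HTTP/1.1 404 Not Found" into (404, "Not Found").
// Returns None if the line does not look like an HTTP/1.x status line.
fn parse_status_line(line: &str) -> Option<(u16, &str)> {
    let mut parts = line.splitn(3, ' ');
    let version = parts.next()?;
    if !version.starts_with("HTTP/") {
        return None;
    }
    let code = parts.next()?.parse::<u16>().ok()?;
    // The reason phrase may be missing, so default to an empty string.
    let reason = parts.next().unwrap_or("");
    Some((code, reason))
}

// Codes in the 200-299 range mean success; 404 is in the 4xx client-error range.
fn is_success(code: u16) -> bool {
    (200..300).contains(&code)
}
```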
Shutting down our local server
Once you are done with this experiment, you will want to shut down this local web server. Return to the terminal where you ran it and press Ctrl-C.
Editing this example
Feel free to edit this example and see what happens. However, remember that after each change you'll need to stop the server and start it again. This is rather inconvenient. Later we'll see how to make Rust automatically recompile and restart the server every time we make changes.
Improving the 404 page
You might dislike the fact that visiting a non-existent path returns a blank page.
Check out the example showing the 404 handler.
Testing
Writing automated tests for your application can save you a lot of time down the road, and you might even develop your application much faster if, instead of checking it in a browser, you write tests. This is especially true if you are implementing an API, which is designed to be consumed by other software anyway.
In main.rs we need to declare the tests module, which lives in the file tests.rs:
#[cfg(test)]
mod tests;
tests.rs
use super::*;
use axum::{body::Body, http::Request, http::StatusCode};
use http_body_util::BodyExt;
use tower::ServiceExt;

#[tokio::test]
async fn test_main_page() {
    let response = app()
        .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
        .await
        .unwrap();

    assert_eq!(response.status(), StatusCode::OK);

    let body = response.into_body();
    let bytes = body.collect().await.unwrap().to_bytes();
    let html = String::from_utf8(bytes.to_vec()).unwrap();

    assert_eq!(html, "<h1>Hello, World!</h1>");
}
The full example
//! Run with
//!
//! ```not_rust
//! cargo run -p example-hello-world
//! ```

use axum::{response::Html, routing::get, Router};

#[tokio::main]
async fn main() {
    // build our application with a route
    let app = app();

    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on http://{}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

fn app() -> Router {
    Router::new().route("/", get(handler))
}

async fn handler() -> Html<&'static str> {
    Html("<h1>Hello, World!</h1>")
}

#[cfg(test)]
mod tests;
Echo GET - accepting Query params
Show how to accept parameters in a GET request.
Running
cargo run -p example-echo-get
GET the main page
$ curl -i http://localhost:3000/
HTTP/1.1 200 OK
content-type: text/html; charset=utf-8
content-length: 131
date: Tue, 18 Mar 2025 08:04:53 GMT
<form method="get" action="/echo">
<input type="text" name="text">
<input type="submit" value="Echo">
</form>
GET request with parameter
$ curl -i http://localhost:3000/echo?text=Hello+World!
HTTP/1.1 200 OK
content-type: text/html; charset=utf-8
content-length: 29
date: Tue, 18 Mar 2025 08:06:31 GMT
You said: <b>Hello World!</b>
GET request without the parameter
$ curl -i http://localhost:3000/echo
HTTP/1.1 400 Bad Request
content-type: text/plain; charset=utf-8
content-length: 56
date: Tue, 18 Mar 2025 08:05:13 GMT
Failed to deserialize query string: missing field `text`
GET request with parameter name but without value
$ curl -i http://localhost:3000/echo?text=
HTTP/1.1 200 OK
content-type: text/html; charset=utf-8
content-length: 17
date: Tue, 18 Mar 2025 08:07:04 GMT
You said: <b></b>
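The three responses above follow from how Query<Params> deserializes the query string into the Params struct: the key text must be present, an empty value is a valid empty String, and unknown keys such as extra are ignored (the test_echo_extra_param test in the full example checks this). The sketch below reproduces those rules by hand for the single text field using only std. decode_component and extract_text are hypothetical names; axum itself delegates this work to serde and handles arbitrary structs.

```rust
// Value of an ASCII hex digit, or None if the byte is not one.
fn hex_val(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}

// Decode one form-urlencoded component: '+' -> space, "%XX" -> byte 0xXX.
fn decode_component(s: &str) -> String {
    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            b'%' if i + 2 < bytes.len() => match (hex_val(bytes[i + 1]), hex_val(bytes[i + 2])) {
                (Some(hi), Some(lo)) => {
                    out.push(hi * 16 + lo);
                    i += 3;
                }
                // Not a valid escape: keep the '%' literally.
                _ => {
                    out.push(b'%');
                    i += 1;
                }
            },
            b => {
                out.push(b);
                i += 1;
            }
        }
    }
    // This sketch replaces invalid UTF-8 instead of rejecting it.
    String::from_utf8_lossy(&out).into_owned()
}

// Mimic the rules for the `text` field: missing key -> error (axum answers 400),
// empty value -> empty string, other keys -> ignored.
fn extract_text(query: &str) -> Result<String, String> {
    for pair in query.split('&').filter(|p| !p.is_empty()) {
        let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
        if decode_component(key) == "text" {
            return Ok(decode_component(value));
        }
    }
    Err("missing field `text`".to_string())
}
```

If a missing parameter should not be an error, a common approach is to declare the field as Option<String> in Params, so that serde yields None instead of failing.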
Cargo.toml
[package]
name = "example-echo-get"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
[dev-dependencies]
http-body-util = "0.1.0"
tower = { version = "0.5.2", features = ["util"] }
There are two functions handling two paths:
The main page is static HTML
async fn main_page() -> Html<&'static str> {
    Html(
        r#"
    <form method="get" action="/echo">
    <input type="text" name="text">
    <input type="submit" value="Echo">
    </form>
    "#,
    )
}
The echo page
async fn echo(Query(params): Query<Params>) -> Html<String> {
    println!("params: {:?}", params);
    Html(format!(r#"You said: <b>{}</b>"#, params.text))
}
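Note that echo inserts params.text into the HTML unchanged, so a value such as <script>...</script> would be interpreted by the browser as markup rather than shown as text (a reflected cross-site scripting problem). Below is a minimal sketch of escaping the input first; escape_html and echo_body are hypothetical helpers, not part of axum. In a larger application a template engine such as askama or minijinja (both in the crate list) typically takes care of escaping.

```rust
// Hypothetical helper: replace the characters that are special in HTML
// so user input is displayed as text instead of being parsed as markup.
fn escape_html(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for c in input.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#39;"),
            _ => out.push(c),
        }
    }
    out
}

// The same message the echo handler builds, with the text escaped first.
fn echo_body(text: &str) -> String {
    format!("You said: <b>{}</b>", escape_html(text))
}
```

Plain input such as Hello World! is unchanged by the escaping, so the existing tests would keep passing if echo used such a helper.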
Struct describing the parameters
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Params {
    text: String,
}
Mapping the routes to functions
fn app() -> Router {
    Router::new()
        .route("/", get(main_page))
        .route("/echo", get(echo))
}
The full example
use axum::{extract::Query, response::Html, routing::get, Router};
use serde::Deserialize;

#[tokio::main]
async fn main() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app()).await.unwrap();
}

fn app() -> Router {
    Router::new()
        .route("/", get(main_page))
        .route("/echo", get(echo))
}

async fn main_page() -> Html<&'static str> {
    Html(
        r#"
    <form method="get" action="/echo">
    <input type="text" name="text">
    <input type="submit" value="Echo">
    </form>
    "#,
    )
}

async fn echo(Query(params): Query<Params>) -> Html<String> {
    println!("params: {:?}", params);
    Html(format!(r#"You said: <b>{}</b>"#, params.text))
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Params {
    text: String,
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, http::StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[tokio::test]
    async fn test_main_page() {
        let response = app()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert!(html.contains(r#"<form method="get" action="/echo">"#));
    }

    #[tokio::test]
    async fn test_echo_with_data() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/echo?text=Hello+World!")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert_eq!(html, "You said: <b>Hello World!</b>");
    }

    #[tokio::test]
    async fn test_echo_without_data() {
        let response = app()
            .oneshot(Request::builder().uri("/echo").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::BAD_REQUEST); // 400

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert_eq!(
            html,
            "Failed to deserialize query string: missing field `text`"
        );
    }

    #[tokio::test]
    async fn test_echo_missing_value() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/echo?text=")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert_eq!(html, "You said: <b></b>");
    }

    #[tokio::test]
    async fn test_echo_extra_param() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/echo?text=Hello&extra=123")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert_eq!(html, "You said: <b>Hello</b>");
    }
}
Echo POST
Show how to accept parameters in a POST request.
Running
cargo run -p example-echo-post
GET the main page
$ curl -i http://localhost:3000/
HTTP/1.1 200 OK
content-type: text/html; charset=utf-8
content-length: 132
date: Tue, 18 Mar 2025 08:21:36 GMT
<form method="post" action="/echo">
<input type="text" name="text">
<input type="submit" value="Echo">
</form>
POST request setting the header and the data
$ curl -i -X POST \
-H "Content-Type: application/x-www-form-urlencoded" \
--data "text=Hello World!" \
http://localhost:3000/echo
HTTP/1.1 200 OK
content-type: text/html; charset=utf-8
content-length: 29
date: Tue, 18 Mar 2025 08:23:51 GMT
You said: <b>Hello World!</b>
POST missing parameter
$ curl -i -X POST \
-H "Content-Type: application/x-www-form-urlencoded" \
http://localhost:3000/echo
HTTP/1.1 422 Unprocessable Entity
content-type: text/plain; charset=utf-8
content-length: 53
date: Tue, 18 Mar 2025 08:25:39 GMT
Failed to deserialize form body: missing field `text`
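The POST body uses the application/x-www-form-urlencoded format, the same key=value&key=value shape as a query string. curl --data sends the given string as-is, without encoding it (curl also offers --data-urlencode for that). As a sketch of what a properly encoded body looks like, here is a hypothetical encode_form helper following the common form-encoding rule: ASCII letters, digits and *-._ stay as they are, a space becomes +, and every other byte is percent-encoded. Decoders are generally lenient, which is why a literal ! in text=Hello+World! is still accepted.

```rust
// Hypothetical helper: serialize key/value pairs as an
// application/x-www-form-urlencoded body.
fn encode_form(pairs: &[(&str, &str)]) -> String {
    // Append one encoded component to `out`.
    fn encode(s: &str, out: &mut String) {
        for b in s.bytes() {
            match b {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'*' | b'-' | b'.' | b'_' => {
                    out.push(b as char)
                }
                b' ' => out.push('+'),
                _ => out.push_str(&format!("%{:02X}", b)),
            }
        }
    }

    let mut body = String::new();
    for (i, (key, value)) in pairs.iter().enumerate() {
        if i > 0 {
            body.push('&');
        }
        encode(key, &mut body);
        body.push('=');
        encode(value, &mut body);
    }
    body
}
```

For example, encode_form(&[("text", "Hello World!")]) produces text=Hello+World%21.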
Cargo.toml
[package]
name = "example-echo-post"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
[dev-dependencies]
http-body-util = "0.1.0"
mime = "0.3.17"
tower = { version = "0.5.2", features = ["util"] }
The full example
use axum::{
    response::Html,
    routing::{get, post},
    Form, Router,
};
use serde::Deserialize;

#[tokio::main]
async fn main() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app()).await.unwrap();
}

fn app() -> Router {
    Router::new()
        .route("/", get(main_page))
        .route("/echo", post(echo))
}

async fn main_page() -> Html<&'static str> {
    Html(
        r#"
    <form method="post" action="/echo">
    <input type="text" name="text">
    <input type="submit" value="Echo">
    </form>
    "#,
    )
}

async fn echo(Form(params): Form<Params>) -> Html<String> {
    println!("params: {:?}", params);
    Html(format!(r#"You said: <b>{}</b>"#, params.text))
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Params {
    text: String,
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{self, Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[tokio::test]
    async fn test_main_page() {
        let response = app()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert!(html.contains(r#"<form method="post" action="/echo">"#));
    }

    #[tokio::test]
    async fn test_echo_with_data() {
        let response = app()
            .oneshot(
                Request::builder()
                    .method(http::Method::POST)
                    .uri("/echo")
                    .header(
                        http::header::CONTENT_TYPE,
                        mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
                    )
                    .body(Body::from("text=Hello+World!"))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert_eq!(html, "You said: <b>Hello World!</b>");
    }

    #[tokio::test]
    async fn test_echo_without_data() {
        let response = app()
            .oneshot(
                Request::builder()
                    .method(http::Method::POST)
                    .uri("/echo")
                    .header(
                        http::header::CONTENT_TYPE,
                        mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
                    )
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); // 422

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert_eq!(
            html,
            "Failed to deserialize form body: missing field `text`"
        );
    }

    #[tokio::test]
    async fn test_echo_missing_value() {
        let response = app()
            .oneshot(
                Request::builder()
                    .method(http::Method::POST)
                    .uri("/echo")
                    .header(
                        http::header::CONTENT_TYPE,
                        mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
                    )
                    .body(Body::from("text="))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert_eq!(html, "You said: <b></b>");
    }

    #[tokio::test]
    async fn test_echo_extra_param() {
        let response = app()
            .oneshot(
                Request::builder()
                    .method(http::Method::POST)
                    .uri("/echo")
                    .header(
                        http::header::CONTENT_TYPE,
                        mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
                    )
                    .body(Body::from("text=Hello&extra=123"))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert_eq!(html, "You said: <b>Hello</b>");
    }
}
Path parameters
Show how to accept parameters in the path of the request. For example, we want to accept all paths that look like this: https://example.org/user/foobar.
Running
cargo run -p example-path-parameters
GET the main page
$ curl -i http://localhost:3000/
HTTP/1.1 200 OK
content-type: text/html; charset=utf-8
content-length: 89
date: Tue, 18 Mar 2025 09:32:55 GMT
<a href="/user/foo">/user/foo</a><br>
<a href="/user/bar">/user/bar</a><br>
Getting user Foo
$ curl -i http://localhost:3000/user/Foo
HTTP/1.1 200 OK
content-type: text/html; charset=utf-8
content-length: 11
date: Tue, 18 Mar 2025 09:35:45 GMT
Hello, Foo!
Try without a username
$ curl -i http://localhost:3000/user/
HTTP/1.1 404 Not Found
content-length: 0
date: Tue, 18 Mar 2025 09:36:15 GMT
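The 404 for /user/ happens because a {name} segment in the route needs a non-empty value, so the path matches no route at all. Here is a hypothetical, std-only matcher that illustrates that rule for templates like "/user/{name}": literal segments must match exactly, and each {...} segment captures one non-empty path segment. axum's router does considerably more (for example wildcards and percent-decoding of captured values); this is only a sketch of the idea.

```rust
// Hypothetical matcher: returns the captured parameter values in order,
// or None if `path` does not fit `template`.
fn match_route<'a>(template: &str, path: &'a str) -> Option<Vec<&'a str>> {
    let tpl: Vec<&str> = template.split('/').collect();
    let segs: Vec<&'a str> = path.split('/').collect();
    // Different number of segments: cannot match.
    if tpl.len() != segs.len() {
        return None;
    }
    let mut captures = Vec::new();
    for (t, s) in tpl.iter().zip(segs.iter()) {
        if t.starts_with('{') && t.ends_with('}') {
            // A parameter must not be empty, so "/user/" does not match.
            if s.is_empty() {
                return None;
            }
            captures.push(*s);
        } else if t != s {
            // Literal segments must be identical.
            return None;
        }
    }
    Some(captures)
}
```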
Cargo.toml
[package]
name = "example-path-parameters"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
[dev-dependencies]
http-body-util = "0.1.0"
tower = { version = "0.5.2", features = ["util"] }
The whole example
use axum::{extract::Path, response::Html, routing::get, Router};

#[tokio::main]
async fn main() {
    // build our application with a route
    let app = app();

    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on http://{}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

fn app() -> Router {
    Router::new()
        .route("/", get(main_page))
        .route("/user/{name}", get(user_page))
}

async fn main_page() -> Html<&'static str> {
    Html(
        r#"
    <a href="/user/foo">/user/foo</a><br>
    <a href="/user/bar">/user/bar</a><br>
    "#,
    )
}

async fn user_page(Path(name): Path<String>) -> Html<String> {
    println!("user: {}", name);
    Html(format!("Hello, {}!", name))
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, http::StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[tokio::test]
    async fn test_main_page() {
        let response = app()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert!(html.contains(r#"<a href="/user/foo">/user/foo</a><br>"#));
    }

    #[tokio::test]
    async fn test_user_page() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/user/qqrq")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert_eq!(html, "Hello, qqrq!");
    }
}
Input validation
For GET, POST, and path parameters.
- Accepting strings.
- Accepting values in other types (numbers, booleans).
- Accepting only IDs that are in the database.
- Accepting a limited set of values that can be defined in an enum.
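All of these start from a raw string taken from the request. As a rough, std-only sketch of the conversions involved, here are three hypothetical helpers: parse_id and parse_flag rely on the standard FromStr implementations for u32 and bool, and find_user uses a HashMap as a stand-in for a database table. When a handler asks for typed values, axum's extractors perform the equivalent conversions through serde. One reasonable policy, assumed here, is to answer 400 for a value that cannot be parsed and 404 for a well-formed ID that does not exist.

```rust
use std::collections::HashMap;

// Accept only non-negative integers that fit in a u32.
fn parse_id(raw: &str) -> Result<u32, String> {
    raw.parse::<u32>()
        .map_err(|e| format!("invalid id {raw:?}: {e}"))
}

// Accept exactly "true" or "false", which is what bool's FromStr allows.
fn parse_flag(raw: &str) -> Result<bool, String> {
    raw.parse::<bool>()
        .map_err(|e| format!("invalid flag {raw:?}: {e}"))
}

// Accept only IDs that exist; the HashMap stands in for a database table.
fn find_user<'a>(raw: &str, users: &HashMap<u32, &'a str>) -> Result<&'a str, String> {
    let id = parse_id(raw)?;
    users
        .get(&id)
        .copied()
        .ok_or_else(|| format!("no user with id {id}"))
}
```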
versioning - path parameter with fixed values
Sometimes the path parameter comes from a fixed set of values. For example, if we are building an API and the first part of the path is the version number of the API, then we might accept the strings v1, v2, and v3, but no other value.
In this example we see exactly that.
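The heart of the example below is a match that turns the string into one of the Version variants and rejects everything else. As a sketch, the same decision can be expressed with the standard FromStr trait, which makes it easy to unit-test without any HTTP machinery. This is an alternative arrangement, not what the example does: there the match lives inside a FromRequestParts implementation that answers 404 for an unknown version.

```rust
use std::str::FromStr;

#[derive(Debug, PartialEq)]
enum Version {
    V1,
    V2,
    V3,
}

impl FromStr for Version {
    type Err = String;

    // Accept exactly "v1", "v2" or "v3"; any other string is rejected.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "v1" => Ok(Version::V1),
            "v2" => Ok(Version::V2),
            "v3" => Ok(Version::V3),
            other => Err(format!("unknown version: {other}")),
        }
    }
}
```

With this in place, "v2".parse::<Version>() gives Ok(Version::V2), while "v4" or "V1" produce an error.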
[package]
name = "example-versioning"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dev-dependencies]
http-body-util = "0.1.0"
tower = { version = "0.5.2", features = ["util"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-versioning
//! ```
use axum::{
    extract::{FromRequestParts, Path},
    http::{request::Parts, StatusCode},
    response::{Html, IntoResponse, Response},
    routing::get,
    RequestPartsExt, Router,
};
use std::collections::HashMap;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();
    // build our application with some routes
    let app = app();
    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

fn app() -> Router {
    Router::new().route("/{version}/foo", get(handler))
}

async fn handler(version: Version) -> Html<String> {
    Html(format!("received request with version {version:?}"))
}

#[derive(Debug)]
enum Version {
    V1,
    V2,
    V3,
}

impl<S> FromRequestParts<S> for Version
where
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let params: Path<HashMap<String, String>> =
            parts.extract().await.map_err(IntoResponse::into_response)?;
        let version = params
            .get("version")
            .ok_or_else(|| (StatusCode::NOT_FOUND, "version param missing").into_response())?;
        match version.as_str() {
            "v1" => Ok(Version::V1),
            "v2" => Ok(Version::V2),
            "v3" => Ok(Version::V3),
            _ => Err((StatusCode::NOT_FOUND, "unknown version").into_response()),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, http::StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[tokio::test]
    async fn test_v1() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/v1/foo")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        assert_eq!(html, "received request with version V1");
    }

    #[tokio::test]
    async fn test_v4() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/v4/foo")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        assert_eq!(html, "unknown version");
    }
}
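The accept-or-reject decision does not depend on axum at all, so it can live in an ordinary FromStr implementation and be unit-tested without building HTTP requests. The sketch below shows that idea with the standard library only. The error type (a plain String) is an assumption, and calling version.parse() from inside from_request_parts would be our own refactoring, not what the upstream example does.

```rust
use std::str::FromStr;

/// The same fixed set of versions as in the example above.
#[derive(Debug, PartialEq)]
enum Version {
    V1,
    V2,
    V3,
}

impl FromStr for Version {
    type Err = String;

    // Map each accepted path segment to a variant; everything else is rejected,
    // mirroring the "unknown version" branch of `from_request_parts`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "v1" => Ok(Version::V1),
            "v2" => Ok(Version::V2),
            "v3" => Ok(Version::V3),
            _ => Err(format!("unknown version: {s}")),
        }
    }
}

fn main() {
    for segment in ["v1", "v3", "v4"] {
        match segment.parse::<Version>() {
            Ok(v) => println!("{segment} -> {v:?}"),
            Err(e) => println!("{segment} -> error: {e}"),
        }
    }
}
```

With such an impl, the match in from_request_parts could shrink to a single parse() call whose error is mapped to the 404 response.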
Query params - GET requests
This example shows how to get the query parameters from a request.
For example, from the URL
https://example.org/some/path?name=Foo&height=1.87
we would like to extract
name: Foo
height: 1.87
Run
cargo run -p example-query-params-with-empty-strings
Test
cargo test -p example-query-params-with-empty-strings
Use
Using the command line we can check how this works.
If we don't provide any parameters:
$ curl http://localhost:3000
Params { foo: None, bar: None }
We provide an integer as the value of the foo parameter:
$ curl http://localhost:3000?foo=42
Params { foo: Some(42), bar: None }
An integer for foo and a string for bar:
$ curl "http://localhost:3000?foo=42&bar=hello"
Params { foo: Some(42), bar: Some("hello") }
If we provide the name of a parameter but no value, then for the numerical value we still get None, but for the string value we get an empty string.
$ curl "http://localhost:3000?foo=&bar="
Params { foo: None, bar: Some("") }
There are two main ways to send data from the browser to the server. One of them is a GET request, in which the data is appended to the URL after a ? as name=value pairs separated by &. These are called query parameters, and they are visible in the address bar of the browser.
[package]
name = "example-query-params-with-empty-strings"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
http-body-util = "0.1.0"
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.5.2", features = ["util"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-query-params-with-empty-strings
//! ```
use axum::{extract::Query, routing::get, Router};
use serde::{de, Deserialize, Deserializer};
use std::{fmt, str::FromStr};

#[tokio::main]
async fn main() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app()).await.unwrap();
}

fn app() -> Router {
    Router::new().route("/", get(handler))
}

async fn handler(Query(params): Query<Params>) -> String {
    format!("{params:?}")
}

/// See the tests below for which combinations of `foo` and `bar` result in
/// which deserializations.
///
/// This example only shows one possible way to do this. [`serde_with`] provides
/// another way. Use whichever method works best for you.
///
/// [`serde_with`]: https://docs.rs/serde_with/1.11.0/serde_with/rust/string_empty_as_none/index.html
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Params {
    #[serde(default, deserialize_with = "empty_string_as_none")]
    foo: Option<i32>,
    bar: Option<String>,
}

/// Serde deserialization decorator to map empty Strings to None.
fn empty_string_as_none<'de, D, T>(de: D) -> Result<Option<T>, D::Error>
where
    D: Deserializer<'de>,
    T: FromStr,
    T::Err: fmt::Display,
{
    let opt = Option::<String>::deserialize(de)?;
    match opt.as_deref() {
        None | Some("") => Ok(None),
        Some(s) => FromStr::from_str(s).map_err(de::Error::custom).map(Some),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[tokio::test]
    async fn test_something() {
        assert_eq!(
            send_request_get_body("foo=1&bar=bar").await,
            r#"Params { foo: Some(1), bar: Some("bar") }"#,
        );
        assert_eq!(
            send_request_get_body("foo=&bar=bar").await,
            r#"Params { foo: None, bar: Some("bar") }"#,
        );
        assert_eq!(
            send_request_get_body("foo=&bar=").await,
            r#"Params { foo: None, bar: Some("") }"#,
        );
        assert_eq!(
            send_request_get_body("foo=1").await,
            r#"Params { foo: Some(1), bar: None }"#,
        );
        assert_eq!(
            send_request_get_body("bar=bar").await,
            r#"Params { foo: None, bar: Some("bar") }"#,
        );
        assert_eq!(
            send_request_get_body("foo=").await,
            r#"Params { foo: None, bar: None }"#,
        );
        assert_eq!(
            send_request_get_body("bar=").await,
            r#"Params { foo: None, bar: Some("") }"#,
        );
        assert_eq!(
            send_request_get_body("").await,
            r#"Params { foo: None, bar: None }"#,
        );
    }

    async fn send_request_get_body(query: &str) -> String {
        let body = app()
            .oneshot(
                Request::builder()
                    .uri(format!("/?{query}"))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap()
            .into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }
}
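The deserialize_with attribute only changes how foo is handled: an empty foo= becomes None instead of a value that would otherwise fail to parse as an integer, while bar keeps the default behavior, so bar= becomes Some(""). The same rules can be modelled with a hand-written, std-only parser. parse_params is hypothetical: it skips percent-decoding and lets a repeated key overwrite the earlier value, so treat it as a model of the results listed above rather than a replacement for the Query extractor.

```rust
/// Mirrors the example's `Params` struct.
#[derive(Debug, PartialEq)]
struct Params {
    foo: Option<i32>,
    bar: Option<String>,
}

/// Hypothetical, simplified query parser (no percent-decoding, last value wins).
/// `foo` follows `empty_string_as_none`: missing or empty -> None.
/// `bar` is a plain `Option<String>`: missing -> None, empty -> Some("").
fn parse_params(query: &str) -> Result<Params, String> {
    let mut params = Params { foo: None, bar: None };
    for pair in query.split('&').filter(|p| !p.is_empty()) {
        // A pair without '=' is treated as having an empty value.
        let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
        match key {
            "foo" => {
                params.foo = if value.is_empty() {
                    None
                } else {
                    Some(value.parse::<i32>().map_err(|e| format!("foo: {e}"))?)
                };
            }
            "bar" => params.bar = Some(value.to_string()),
            _ => {} // unknown keys are ignored
        }
    }
    Ok(params)
}

fn main() {
    for query in ["foo=42&bar=hello", "foo=&bar=", ""] {
        println!("{query:?} -> {:?}", parse_params(query));
    }
}
```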
Form - accepting POST request
In this example we can see how an application can accept an HTTP POST request submitted by an HTML form.
To see the responses using curl:
Asking for the main page with a GET request
$ curl http://localhost:3000/
This returns the HTML page.
$ curl -X POST \
-H "Content-Type: application/x-www-form-urlencoded" \
--data "name=Foo&email=foo@bar.com" \
http://localhost:3000/
email='foo@bar.com'
name='Foo'
Missing field
$ curl -X POST \
-H "Content-Type: application/x-www-form-urlencoded" \
--data "name=Foo" \
http://localhost:3000/
Failed to deserialize form body: missing field `email`
Extra fields are ignored
$ curl -X POST \
-H "Content-Type: application/x-www-form-urlencoded" \
--data "name=Foo&email=foo@bar.com&age=42" \
http://localhost:3000/
email='foo@bar.com'
name='Foo'
[package]
name = "example-form"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dev-dependencies]
http-body-util = "0.1.3"
mime = "0.3.17"
tower = { version = "0.5.2", features = ["util"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-form
//! ```
use axum::{extract::Form, response::Html, routing::get, Router};
use serde::Deserialize;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();
    // build our application with some routes
    let app = app();
    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

fn app() -> Router {
    Router::new().route("/", get(show_form).post(accept_form))
}

async fn show_form() -> Html<&'static str> {
    Html(
        r#"
        <!doctype html>
        <html>
            <head></head>
            <body>
                <form action="/" method="post">
                    <label for="name">
                        Enter your name:
                        <input type="text" name="name">
                    </label>
                    <label>
                        Enter your email:
                        <input type="text" name="email">
                    </label>
                    <input type="submit" value="Subscribe!">
                </form>
            </body>
        </html>
        "#,
    )
}

#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct Input {
    name: String,
    email: String,
}

async fn accept_form(Form(input): Form<Input>) -> Html<String> {
    dbg!(&input);
    Html(format!(
        "email='{}'\nname='{}'\n",
        &input.email, &input.name
    ))
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{self, Request, StatusCode},
    };
    use http_body_util::BodyExt; // for `collect`
    use tower::ServiceExt; // for `call`, `oneshot`, and `ready`

    #[tokio::test]
    async fn test_get() {
        let app = app();
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        let body = std::str::from_utf8(&body).unwrap();
        assert!(body.contains(r#"<input type="submit" value="Subscribe!">"#));
    }

    #[tokio::test]
    async fn test_post() {
        let app = app();
        let response = app
            .oneshot(
                Request::builder()
                    .method(http::Method::POST)
                    .uri("/")
                    .header(
                        http::header::CONTENT_TYPE,
                        mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
                    )
                    .body(Body::from("name=foo&email=bar@axum"))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        let body = std::str::from_utf8(&body).unwrap();
        assert_eq!(body, "email='bar@axum'\nname='foo'\n");
    }
}
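The Form extractor takes care of decoding the application/x-www-form-urlencoded body: names and values are separated by = and &, a + stands for a space, and %XX encodes a byte (so %40 is @). The hypothetical, std-only parse_input below sketches that decoding and shows why a missing email is an error while an extra age field is ignored. Its error messages only imitate the ones shown above, duplicate keys simply overwrite each other, and it is not how axum implements Form.

```rust
use std::collections::HashMap;

/// Same fields as the example's `Input` struct.
#[derive(Debug, PartialEq)]
struct Input {
    name: String,
    email: String,
}

/// Decode one urlencoded component: '+' becomes a space, "%XX" becomes byte 0xXX.
fn decode(component: &str) -> Result<String, String> {
    let bytes = component.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => out.push(b' '),
            b'%' => {
                // The two characters after '%' must be hex digits.
                let hex = component
                    .get(i + 1..i + 3)
                    .filter(|h| h.bytes().all(|b| b.is_ascii_hexdigit()))
                    .ok_or_else(|| format!("invalid percent escape in {component:?}"))?;
                out.push(u8::from_str_radix(hex, 16).map_err(|e| e.to_string())?);
                i += 2;
            }
            b => out.push(b),
        }
        i += 1;
    }
    String::from_utf8(out).map_err(|e| e.to_string())
}

/// Hypothetical parser for a body such as "name=Foo&email=foo%40bar.com".
fn parse_input(body: &str) -> Result<Input, String> {
    let mut fields: HashMap<String, String> = HashMap::new();
    for pair in body.split('&').filter(|p| !p.is_empty()) {
        let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
        fields.insert(decode(key)?, decode(value)?);
    }
    // Take the two required fields; a missing one is an error.
    let mut take = |field: &str| {
        fields
            .remove(field)
            .ok_or_else(|| format!("missing field `{field}`"))
    };
    // Whatever is left in `fields` (for example `age`) is simply ignored.
    Ok(Input {
        name: take("name")?,
        email: take("email")?,
    })
}

fn main() {
    println!("{:?}", parse_input("name=Foo&email=foo%40bar.com&age=42"));
    println!("{:?}", parse_input("name=Foo"));
}
```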
Validator
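This example shows how to use the validator crate to validate the values sent by the user. The handler uses axum's Form extractor on a GET route: for GET (and HEAD) requests Form reads the fields from the query string, which is why the requests below pass name in the URL.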
Run
cargo run -p example-validator
Test
cargo test -p example-validator
Use
Field not supplied
$ curl -i http://localhost:3000/
HTTP/1.1 400 Bad Request
content-type: text/plain; charset=utf-8
content-length: 48
date: Wed, 19 Mar 2025 15:00:06 GMT
Failed to deserialize form: missing field `name`
Input too short
$ curl -i http://localhost:3000/?name=
HTTP/1.1 400 Bad Request
content-type: text/plain; charset=utf-8
content-length: 48
date: Wed, 19 Mar 2025 15:03:22 GMT
Input validation error: [name: Can not be empty]
Acceptable input
$ curl -i http://localhost:3000/?name=Jo
HTTP/1.1 200 OK
content-type: text/html; charset=utf-8
content-length: 19
date: Wed, 19 Mar 2025 15:03:52 GMT
<h1>Hello, Jo!</h1>
[package]
edition = "2021"
name = "example-validator"
publish = false
version = "0.1.0"
[dependencies]
axum = { path = "../../axum" }
serde = { version = "1.0", features = ["derive"] }
thiserror = "1.0.29"
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
validator = { version = "0.18.1", features = ["derive"] }
[dev-dependencies]
http-body-util = "0.1.0"
tower = { version = "0.5.2", features = ["util"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-validator
//!
//! curl '127.0.0.1:3000?name='
//! -> Input validation error: [name: Can not be empty]
//!
//! curl '127.0.0.1:3000?name=LT'
//! -> <h1>Hello, LT!</h1>
//! ```
use axum::{
    extract::{rejection::FormRejection, Form, FromRequest, Request},
    http::StatusCode,
    response::{Html, IntoResponse, Response},
    routing::get,
    Router,
};
use serde::{de::DeserializeOwned, Deserialize};
use thiserror::Error;
use tokio::net::TcpListener;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use validator::Validate;

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();
    // build our application with a route
    let app = app();
    // run it
    let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

fn app() -> Router {
    Router::new().route("/", get(handler))
}

#[derive(Debug, Deserialize, Validate)]
pub struct NameInput {
    #[validate(length(min = 2, message = "Can not be empty"))]
    pub name: String,
}

async fn handler(ValidatedForm(input): ValidatedForm<NameInput>) -> Html<String> {
    Html(format!("<h1>Hello, {}!</h1>", input.name))
}

#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedForm<T>(pub T);

impl<T, S> FromRequest<S> for ValidatedForm<T>
where
    T: DeserializeOwned + Validate,
    S: Send + Sync,
    Form<T>: FromRequest<S, Rejection = FormRejection>,
{
    type Rejection = ServerError;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let Form(value) = Form::<T>::from_request(req, state).await?;
        value.validate()?;
        Ok(ValidatedForm(value))
    }
}

#[derive(Debug, Error)]
pub enum ServerError {
    #[error(transparent)]
    ValidationError(#[from] validator::ValidationErrors),

    #[error(transparent)]
    AxumFormRejection(#[from] FormRejection),
}

impl IntoResponse for ServerError {
    fn into_response(self) -> Response {
        match self {
            ServerError::ValidationError(_) => {
                let message = format!("Input validation error: [{self}]").replace('\n', ", ");
                (StatusCode::BAD_REQUEST, message)
            }
            ServerError::AxumFormRejection(_) => (StatusCode::BAD_REQUEST, self.to_string()),
        }
        .into_response()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    async fn get_html(response: Response<Body>) -> String {
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_no_param() {
        let response = app()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        let html = get_html(response).await;
        assert_eq!(html, "Failed to deserialize form: missing field `name`");
    }

    #[tokio::test]
    async fn test_with_param_without_value() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/?name=")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        let html = get_html(response).await;
        assert_eq!(html, "Input validation error: [name: Can not be empty]");
    }

    #[tokio::test]
    async fn test_with_param_with_short_value() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/?name=X")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        let html = get_html(response).await;
        assert_eq!(html, "Input validation error: [name: Can not be empty]");
    }

    #[tokio::test]
    async fn test_with_param_and_value() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/?name=LT")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let html = get_html(response).await;
        assert_eq!(html, "<h1>Hello, LT!</h1>");
    }
}
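ValidatedForm runs two steps in a fixed order: first Form deserializes the input, then validate() checks it, and both kinds of failure are funnelled into one ServerError that becomes a 400 response. The std-only sketch below models that control flow. The Validate trait, the toy name= parser, and the character-based length check are assumptions standing in for the validator crate and axum's Form.

```rust
use std::fmt;

/// (field, message) pairs describing what failed validation.
type FieldErrors = Vec<(&'static str, &'static str)>;

/// Hypothetical stand-in for the validator crate's `Validate` trait.
trait Validate {
    fn validate(&self) -> Result<(), FieldErrors>;
}

struct NameInput {
    name: String,
}

impl Validate for NameInput {
    fn validate(&self) -> Result<(), FieldErrors> {
        // Same rule as `length(min = 2, ...)` in the example; counting chars is an assumption.
        if self.name.chars().count() < 2 {
            return Err(vec![("name", "Can not be empty")]);
        }
        Ok(())
    }
}

/// The two failure modes of `ValidatedForm`, modelled as plain data.
#[derive(Debug, PartialEq)]
enum ServerError {
    Rejection(String),
    Validation(FieldErrors),
}

impl fmt::Display for ServerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ServerError::Rejection(msg) => write!(f, "{msg}"),
            ServerError::Validation(errors) => {
                let parts: Vec<String> = errors
                    .iter()
                    .map(|(field, msg)| format!("{field}: {msg}"))
                    .collect();
                write!(f, "Input validation error: [{}]", parts.join(", "))
            }
        }
    }
}

fn extract_validated(query: &str) -> Result<NameInput, ServerError> {
    // Step 1: "deserialize" with a toy parser (hypothetical, no decoding).
    let name = query
        .split('&')
        .find_map(|pair| pair.strip_prefix("name="))
        .ok_or_else(|| ServerError::Rejection("missing field `name`".to_string()))?;
    let input = NameInput { name: name.to_string() };
    // Step 2: validate the deserialized value.
    input.validate().map_err(ServerError::Validation)?;
    Ok(input)
}

fn main() {
    for query in ["", "name=", "name=X", "name=LT"] {
        match extract_validated(query) {
            Ok(input) => println!("{query:?} -> 200 <h1>Hello, {}!</h1>", input.name),
            Err(err) => println!("{query:?} -> 400 {err}"),
        }
    }
}
```

Note that the message Can not be empty is attached to the min = 2 rule, so it is also returned for a one-character name such as X, as test_with_param_with_short_value shows; a message mentioning the minimum length would describe the rule more accurately.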
Middleware
print-request-response
[package]
name = "example-print-request-response"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
http-body-util = "0.1.0"
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-print-request-response
//! ```
use axum::{
    body::{Body, Bytes},
    extract::Request,
    http::StatusCode,
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Router,
};
use http_body_util::BodyExt;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();
    let app = Router::new()
        .route("/", post(|| async move { "Hello from `POST /`" }))
        .layer(middleware::from_fn(print_request_response));
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn print_request_response(
    req: Request,
    next: Next,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    let (parts, body) = req.into_parts();
    let bytes = buffer_and_print("request", body).await?;
    let req = Request::from_parts(parts, Body::from(bytes));

    let res = next.run(req).await;

    let (parts, body) = res.into_parts();
    let bytes = buffer_and_print("response", body).await?;
    let res = Response::from_parts(parts, Body::from(bytes));

    Ok(res)
}

async fn buffer_and_print<B>(direction: &str, body: B) -> Result<Bytes, (StatusCode, String)>
where
    B: axum::body::HttpBody<Data = Bytes>,
    B::Error: std::fmt::Display,
{
    let bytes = match body.collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(err) => {
            return Err((
                StatusCode::BAD_REQUEST,
                format!("failed to read {direction} body: {err}"),
            ));
        }
    };
    if let Ok(body) = std::str::from_utf8(&bytes) {
        tracing::debug!("{direction} body = {body:?}");
    }
    Ok(bytes)
}
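The middleware has to read the whole body into memory because an HTTP body is a stream that can be consumed only once: after logging the bytes it builds a new Body from them and passes the rebuilt request to next.run, and it does the same with the response. The synchronous, std-only model below shows that buffer, log, rebuild sequence. Body here is just a hypothetical iterator of chunks and next is a plain closure, not axum's Next.

```rust
/// Hypothetical stand-in for a streaming HTTP body: chunks that can be read only once.
type Body = std::vec::IntoIter<Vec<u8>>;

/// Collect all chunks into one buffer (the role of `body.collect().await`)
/// and record the body only if it is valid UTF-8, as `buffer_and_print` does.
fn buffer_and_print(direction: &str, body: Body, log: &mut Vec<String>) -> Vec<u8> {
    let bytes: Vec<u8> = body.flatten().collect();
    if let Ok(text) = std::str::from_utf8(&bytes) {
        log.push(format!("{direction} body = {text:?}"));
    }
    bytes
}

/// Synchronous model of the middleware: buffer the request, hand a rebuilt body
/// to `next`, then buffer the response and rebuild it for the caller.
fn print_request_response(
    req: Body,
    next: impl FnOnce(Body) -> Body,
    log: &mut Vec<String>,
) -> Body {
    let bytes = buffer_and_print("request", req, log);
    // The original body was consumed, so a new one is built from the buffered bytes.
    let res = next(vec![bytes].into_iter());
    let bytes = buffer_and_print("response", res, log);
    vec![bytes].into_iter()
}

fn main() {
    let mut log = Vec::new();
    // A "handler" that ignores the request and returns a fixed response body.
    let handler = |_req: Body| vec![b"Hello from `POST /`".to_vec()].into_iter();
    let request = vec![b"hello ".to_vec(), b"world".to_vec()].into_iter();
    let response = print_request_response(request, handler, &mut log);
    let body: Vec<u8> = response.flatten().collect();
    println!("{}", String::from_utf8_lossy(&body));
    for line in &log {
        println!("{line}");
    }
}
```

To try the real example, send any POST request, for example curl -X POST -d hello http://localhost:3000/, and watch the debug log. Because every body is buffered completely, this kind of middleware is probably better suited to debugging than to traffic with large uploads or streaming responses.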
Templating Systems
templates
[package]
name = "example-templates"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
askama = "0.12"
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dev-dependencies]
http-body-util = "0.1.0"
tower = { version = "0.5.2", features = ["util"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-templates
//! ```
use askama::Template;
use axum::{
    extract,
    http::StatusCode,
    response::{Html, IntoResponse, Response},
    routing::get,
    Router,
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();
    // build our application with some routes
    let app = app();
    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

fn app() -> Router {
    Router::new().route("/greet/{name}", get(greet))
}

async fn greet(extract::Path(name): extract::Path<String>) -> impl IntoResponse {
    let template = HelloTemplate { name };
    HtmlTemplate(template)
}

#[derive(Template)]
#[template(path = "hello.html")]
struct HelloTemplate {
    name: String,
}

struct HtmlTemplate<T>(T);

impl<T> IntoResponse for HtmlTemplate<T>
where
    T: Template,
{
    fn into_response(self) -> Response {
        match self.0.render() {
            Ok(html) => Html(html).into_response(),
            Err(err) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to render template. Error: {err}"),
            )
                .into_response(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[tokio::test]
    async fn test_main() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/greet/Foo")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        assert_eq!(html, "<h1>Hello, Foo!</h1>");
    }
}
templates/hello.html
<h1>Hello, {{ name }}!</h1>
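askama turns the template into Rust code at compile time through the Template derive, and for templates with an .html extension it HTML-escapes the values inserted with {{ ... }} by default. The hand-written function below is roughly what rendering this one-line template amounts to; escape_html and render_hello are hypothetical, and askama's exact escaping table may differ.

```rust
/// Hypothetical helper: escape the characters that are special in HTML.
fn escape_html(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for c in input.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#x27;"),
            _ => out.push(c),
        }
    }
    out
}

/// Hand-written equivalent of rendering `<h1>Hello, {{ name }}!</h1>`.
fn render_hello(name: &str) -> String {
    format!("<h1>Hello, {}!</h1>", escape_html(name))
}

fn main() {
    println!("{}", render_hello("Foo"));
    println!("{}", render_hello("<b>bold</b>"));
}
```

This matters here because name comes straight from the URL: a name containing < or > should end up as text on the page, not as markup.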
templates-minijinja
[package]
name = "example-templates-minijinja"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
minijinja = "2.3.1"
tokio = { version = "1.0", features = ["full"] }
[dev-dependencies]
http-body-util = "0.1.0"
tower = { version = "0.5.2", features = ["util"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-templates-minijinja
//! ```
//! Demo for the MiniJinja templating engine.
//! Exposes three pages all sharing the same layout with a minimal nav menu.
use axum::extract::State;
use axum::http::StatusCode;
use axum::{response::Html, routing::get, Router};
use minijinja::{context, Environment};
use std::sync::Arc;

struct AppState {
    env: Environment<'static>,
}

#[tokio::main]
async fn main() {
    let app = app();
    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

fn app() -> Router {
    // init template engine and add templates
    let mut env = Environment::new();
    env.add_template("layout", include_str!("../templates/layout.jinja"))
        .unwrap();
    env.add_template("home", include_str!("../templates/home.jinja"))
        .unwrap();
    env.add_template("content", include_str!("../templates/content.jinja"))
        .unwrap();
    env.add_template("about", include_str!("../templates/about.jinja"))
        .unwrap();
    // pass env to handlers via state
    let app_state = Arc::new(AppState { env });
    // define routes
    Router::new()
        .route("/", get(handler_home))
        .route("/content", get(handler_content))
        .route("/about", get(handler_about))
        .with_state(app_state)
}

// A missing template or a rendering error becomes a 500 response instead of a panic.
async fn handler_home(State(state): State<Arc<AppState>>) -> Result<Html<String>, StatusCode> {
    let template = state
        .env
        .get_template("home")
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let rendered = template
        .render(context! {
            title => "Home",
            welcome_text => "Hello World!",
        })
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    Ok(Html(rendered))
}

async fn handler_content(State(state): State<Arc<AppState>>) -> Result<Html<String>, StatusCode> {
    let template = state
        .env
        .get_template("content")
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let some_example_entries = vec!["Data 1", "Data 2", "Data 3"];
    let rendered = template
        .render(context! {
            title => "Content",
            entries => some_example_entries,
        })
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    Ok(Html(rendered))
}

async fn handler_about(State(state): State<Arc<AppState>>) -> Result<Html<String>, StatusCode> {
    let template = state
        .env
        .get_template("about")
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let rendered = template
        .render(context! {
            title => "About",
            about_text => "Simple demonstration layout for an axum project with minijinja as templating engine.",
        })
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    Ok(Html(rendered))
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, http::StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[tokio::test]
    async fn test_main_page() {
        let response = app()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        check_layout(&html);
        assert!(html.contains("<title>Website Name | Home </title>"));
        assert!(html.contains("<h1>Home</h1>"));
        assert!(html.contains("<p>Hello World!</p>"));
    }

    #[tokio::test]
    async fn test_content_page() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/content")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        check_layout(&html);
        assert!(html.contains("<title>Website Name | Content </title>"));
        assert!(html.contains("<h1>Content</h1>"));
        assert!(html.contains("<li>Data 1</li>"));
    }

    #[tokio::test]
    async fn test_about_page() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/about")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        check_layout(&html);
        assert!(html.contains("<title>Website Name | About </title>"));
        assert!(html.contains("<h1>About</h1>"));
        assert!(html.contains("<p>Simple demonstration layout for an axum project with minijinja as templating engine.</p>"));
    }

    fn check_layout(html: &str) {
        assert!(html.contains(r#"<li><a href="/">Home</a></li>"#));
    }
}
templates/home.jinja
{% extends "layout" %}
{% block title %}{{ super() }} | {{ title }} {% endblock %}
{% block body %}
<h1>{{ title }}</h1>
<p>{{ welcome_text }}</p>
{% endblock %}
templates/about.jinja
{% extends "layout" %}
{% block title %}{{ super() }} | {{ title }} {% endblock %}
{% block body %}
<h1>{{ title }}</h1>
<p>{{ about_text }}</p>
{% endblock %}
templates/content.jinja
{% extends "layout" %}
{% block title %}{{ super() }} | {{ title }} {% endblock %}
{% block body %}
<h1>{{ title }}</h1>
<ul>
{% for data_entry in entries %}
<li>{{ data_entry }}</li>
{% endfor %}
</ul>
{% endblock %}
templates/layout.jinja
<!doctype html>
<html>
<head><title>{% block title %}Website Name{% endblock %}</title></head>
<body>
<nav>
<ul>
<li><a href="/">Home</a></li>
<li><a href="/content">Content</a></li>
<li><a href="/about">About</a></li>
</ul>
</nav>
{% block body %}{% endblock %}
</body>
</html>
examples/templates-minijinja/
├── Cargo.toml
├── src
│ └── main.rs
└── templates
├── about.jinja
├── content.jinja
├── home.jinja
└── layout.jinja
global-404-handler
[package]
name = "example-global-404-handler"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-global-404-handler
//! ```
use axum::{
    http::StatusCode,
    response::{Html, IntoResponse},
    routing::get,
    Router,
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();
    // build our application with a route
    let app = Router::new().route("/", get(handler));
    // add a fallback service for handling routes to unknown paths
    let app = app.fallback(handler_404);
    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn handler() -> Html<&'static str> {
    Html("<h1>Hello, World!</h1>")
}

async fn handler_404() -> impl IntoResponse {
    (StatusCode::NOT_FOUND, "nothing to see here")
}
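A fallback is simply the handler the router uses when no route matches the path. Without one, axum answers unknown paths with an empty 404 response; with .fallback(handler_404) the 404 carries the text nothing to see here. The std-only model below illustrates that lookup order. This Router is a hypothetical stand-in that only matches exact paths and ignores HTTP methods.

```rust
use std::collections::HashMap;

type Handler = fn() -> (u16, String);

/// Hypothetical model of a router: exact paths map to handlers,
/// and any other path goes to the fallback.
struct Router {
    routes: HashMap<&'static str, Handler>,
    fallback: Handler,
}

/// Used when no fallback is set: an empty 404 (modelled on axum's default).
fn default_404() -> (u16, String) {
    (404, String::new())
}

impl Router {
    fn new() -> Self {
        Router { routes: HashMap::new(), fallback: default_404 }
    }

    fn route(mut self, path: &'static str, handler: Handler) -> Self {
        self.routes.insert(path, handler);
        self
    }

    fn fallback(mut self, handler: Handler) -> Self {
        self.fallback = handler;
        self
    }

    fn handle(&self, path: &str) -> (u16, String) {
        match self.routes.get(path) {
            Some(handler) => handler(),
            None => (self.fallback)(),
        }
    }
}

fn handler() -> (u16, String) {
    (200, "<h1>Hello, World!</h1>".to_string())
}

fn handler_404() -> (u16, String) {
    (404, "nothing to see here".to_string())
}

fn main() {
    let app = Router::new().route("/", handler).fallback(handler_404);
    for path in ["/", "/missing"] {
        println!("{path} -> {:?}", app.handle(path));
    }
}
```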
Redirect
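We can redirect a request to another page on our site or to a page on another site.
For this we use the Redirect struct. Its permanent method creates a permanent redirection (308 Permanent Redirect), its temporary method a temporary redirection (307 Temporary Redirect), and its to method a 303 See Other response. In this example the /try page uses a temporary redirection to /land.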
[package]
name = "example-redirect"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
http-body-util = "0.1.0"
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.5.2", features = ["util"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-redirect
//! ```
use axum::{
    response::{Html, Redirect},
    routing::get,
    Router,
};

#[tokio::main]
async fn main() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app()).await.unwrap();
}

fn app() -> Router {
    Router::new()
        .route("/", get(main_page))
        .route("/try", get(try_page))
        .route("/land", get(land_page))
}

async fn main_page() -> Html<&'static str> {
    Html(
        r#"
        <h1>Redirect</h1>
        <a href="/try">Try</a><br>
        <a href="/land">Land</a>
        "#,
    )
}

async fn try_page() -> Redirect {
    Redirect::temporary("/land")
}

async fn land_page() -> Html<&'static str> {
    Html("Landed")
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[tokio::test]
    async fn test_main() {
        let response = app()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        assert!(html.contains("<h1>Redirect</h1>"));
    }

    #[tokio::test]
    async fn test_landed() {
        let response = app()
            .oneshot(Request::builder().uri("/land").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        assert_eq!(html, "Landed");
    }

    #[tokio::test]
    async fn test_try() {
        let response = app()
            .oneshot(Request::builder().uri("/try").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::TEMPORARY_REDIRECT);
        let location = response.headers().get("location").unwrap();
        assert_eq!(location, "/land");
    }
}
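In this example try_page returns Redirect::temporary("/land"), and test_try checks for a 307 status with a location: /land header. The std-only sketch below summarises how the three Redirect constructors map to status codes and what a client sees on the wire; RedirectKind and redirect_response are hypothetical names, not axum APIs.

```rust
/// Hypothetical mirror of the three `Redirect` constructors.
#[derive(Debug, Clone, Copy)]
enum RedirectKind {
    To,        // Redirect::to
    Temporary, // Redirect::temporary
    Permanent, // Redirect::permanent
}

impl RedirectKind {
    /// Status code and reason phrase sent for each kind.
    fn status(self) -> (u16, &'static str) {
        match self {
            RedirectKind::To => (303, "See Other"),
            RedirectKind::Temporary => (307, "Temporary Redirect"),
            RedirectKind::Permanent => (308, "Permanent Redirect"),
        }
    }

    /// 307 and 308 ask the client to repeat the original method; 303 asks for a GET.
    fn preserves_method(self) -> bool {
        !matches!(self, RedirectKind::To)
    }
}

/// Build the status line and `location` header a client would receive.
fn redirect_response(kind: RedirectKind, location: &str) -> String {
    let (code, reason) = kind.status();
    format!("HTTP/1.1 {code} {reason}\r\nlocation: {location}\r\n\r\n")
}

fn main() {
    print!("{}", redirect_response(RedirectKind::Temporary, "/land"));
}
```

Because 303 tells the client to follow up with a GET, Redirect::to is the usual choice after handling a form submission, while temporary and permanent keep a POST a POST.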
Minimal logging setup
axum emits its log events through the tracing crate, and tracing-subscriber is needed to collect and print them, so we need to include both.
{{#include ../../examples/minimal-tracing/Cargo.toml}}
{{#include ../../examples/minimal-tracing/src/main.rs}}
When we start the application with cargo run we'll see a line like this on the terminal:
2025-03-17T08:39:04.089621Z DEBUG example_minimal_tracing: listening on 127.0.0.1:3000
When we access the main page with a browser we'll see two more lines:
2025-03-17T08:39:27.044996Z TRACE axum::serve: connection 127.0.0.1:58560 accepted
2025-03-17T08:39:27.045345Z DEBUG example_minimal_tracing: in handler
anyhow-error-response
[package]
name = "example-anyhow-error-response"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
anyhow = "1.0"
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
[dev-dependencies]
http-body-util = "0.1.0"
tower = { version = "0.5.2", features = ["util"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-anyhow-error-response
//! ```

use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::get,
    Router,
};

#[tokio::main]
async fn main() {
    let app = app();

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn handler() -> Result<(), AppError> {
    try_thing()?;
    Ok(())
}

fn try_thing() -> Result<(), anyhow::Error> {
    anyhow::bail!("it failed!")
}

// Make our own error that wraps `anyhow::Error`.
struct AppError(anyhow::Error);

// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Something went wrong: {}", self.0),
        )
            .into_response()
    }
}

fn app() -> Router {
    Router::new().route("/", get(handler))
}

// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
// `Result<_, AppError>`. That way you don't need to do that manually.
impl<E> From<E> for AppError
where
    E: Into<anyhow::Error>,
{
    fn from(err: E) -> Self {
        Self(err.into())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, http::StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[tokio::test]
    async fn test_main_page() {
        let response = app()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();

        assert_eq!(html, "Something went wrong: it failed!");
    }
}
$ curl -i http://localhost:3000
HTTP/1.1 500 Internal Server Error
content-type: text/plain; charset=utf-8
content-length: 32
date: Sun, 16 Mar 2025 16:18:04 GMT
Something went wrong: it failed!
compression
This example shows how to:
- automatically decompress request bodies when necessary
- compress response bodies based on the Accept-Encoding header.
Running
cargo run -p example-compression
Sending compressed requests
curl -v -g 'http://localhost:3000/' \
-H "Content-Type: application/json" \
-H "Content-Encoding: gzip" \
--compressed \
--data-binary @data/products.json.gz
(Notice the Content-Encoding: gzip in the request, and content-encoding: gzip in the response.)
Sending uncompressed requests
curl -v -g 'http://localhost:3000/' \
-H "Content-Type: application/json" \
--compressed \
--data-binary @data/products.json
Cargo.toml
[package]
name = "example-compression"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
serde_json = "1"
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
tower = "0.5.2"
tower-http = { version = "0.6.1", features = ["compression-full", "decompression-full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dev-dependencies]
assert-json-diff = "2.0"
brotli = "6.0"
flate2 = "1"
http = "1"
zstd = "0.13"
main.rs
use axum::{routing::post, Json, Router};
use serde_json::Value;
use tower::ServiceBuilder;
use tower_http::{compression::CompressionLayer, decompression::RequestDecompressionLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[cfg(test)]
mod tests;

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=trace", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let app: Router = app();

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

fn app() -> Router {
    Router::new().route("/", post(root)).layer(
        ServiceBuilder::new()
            .layer(RequestDecompressionLayer::new())
            .layer(CompressionLayer::new()),
    )
}

async fn root(Json(value): Json<Value>) -> Json<Value> {
    Json(value)
}
tests.rs
use assert_json_diff::assert_json_eq;
use axum::{
    body::{Body, Bytes},
    response::Response,
};
use brotli::enc::BrotliEncoderParams;
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use http::{header, StatusCode};
use serde_json::{json, Value};
use std::io::{Read, Write};
use tower::ServiceExt;

use super::*;

#[tokio::test]
async fn handle_uncompressed_request_bodies() {
    // Given
    let body = json();
    let compressed_request = http::Request::post("/")
        .header(header::CONTENT_TYPE, "application/json")
        .body(json_body(&body))
        .unwrap();

    // When
    let response = app().oneshot(compressed_request).await.unwrap();

    // Then
    assert_eq!(response.status(), StatusCode::OK);
    assert_json_eq!(json_from_response(response).await, json());
}

#[tokio::test]
async fn decompress_gzip_request_bodies() {
    // Given
    let body = compress_gzip(&json());
    let compressed_request = http::Request::post("/")
        .header(header::CONTENT_TYPE, "application/json")
        .header(header::CONTENT_ENCODING, "gzip")
        .body(Body::from(body))
        .unwrap();

    // When
    let response = app().oneshot(compressed_request).await.unwrap();

    // Then
    assert_eq!(response.status(), StatusCode::OK);
    assert_json_eq!(json_from_response(response).await, json());
}

#[tokio::test]
async fn decompress_br_request_bodies() {
    // Given
    let body = compress_br(&json());
    let compressed_request = http::Request::post("/")
        .header(header::CONTENT_TYPE, "application/json")
        .header(header::CONTENT_ENCODING, "br")
        .body(Body::from(body))
        .unwrap();

    // When
    let response = app().oneshot(compressed_request).await.unwrap();

    // Then
    assert_eq!(response.status(), StatusCode::OK);
    assert_json_eq!(json_from_response(response).await, json());
}

#[tokio::test]
async fn decompress_zstd_request_bodies() {
    // Given
    let body = compress_zstd(&json());
    let compressed_request = http::Request::post("/")
        .header(header::CONTENT_TYPE, "application/json")
        .header(header::CONTENT_ENCODING, "zstd")
        .body(Body::from(body))
        .unwrap();

    // When
    let response = app().oneshot(compressed_request).await.unwrap();

    // Then
    assert_eq!(response.status(), StatusCode::OK);
    assert_json_eq!(json_from_response(response).await, json());
}

#[tokio::test]
async fn do_not_compress_response_bodies() {
    // Given
    let request = http::Request::post("/")
        .header(header::CONTENT_TYPE, "application/json")
        .body(json_body(&json()))
        .unwrap();

    // When
    let response = app().oneshot(request).await.unwrap();

    // Then
    assert_eq!(response.status(), StatusCode::OK);
    assert_json_eq!(json_from_response(response).await, json());
}

#[tokio::test]
async fn compress_response_bodies_with_gzip() {
    // Given
    let request = http::Request::post("/")
        .header(header::CONTENT_TYPE, "application/json")
        .header(header::ACCEPT_ENCODING, "gzip")
        .body(json_body(&json()))
        .unwrap();

    // When
    let response = app().oneshot(request).await.unwrap();

    // Then
    assert_eq!(response.status(), StatusCode::OK);
    let response_body = byte_from_response(response).await;
    let mut decoder = GzDecoder::new(response_body.as_ref());
    let mut decompress_body = String::new();
    decoder.read_to_string(&mut decompress_body).unwrap();
    assert_json_eq!(
        serde_json::from_str::<serde_json::Value>(&decompress_body).unwrap(),
        json()
    );
}

#[tokio::test]
async fn compress_response_bodies_with_br() {
    // Given
    let request = http::Request::post("/")
        .header(header::CONTENT_TYPE, "application/json")
        .header(header::ACCEPT_ENCODING, "br")
        .body(json_body(&json()))
        .unwrap();

    // When
    let response = app().oneshot(request).await.unwrap();

    // Then
    assert_eq!(response.status(), StatusCode::OK);
    let response_body = byte_from_response(response).await;
    let mut decompress_body = Vec::new();
    brotli::BrotliDecompress(&mut response_body.as_ref(), &mut decompress_body).unwrap();
    assert_json_eq!(
        serde_json::from_slice::<serde_json::Value>(&decompress_body).unwrap(),
        json()
    );
}

#[tokio::test]
async fn compress_response_bodies_with_zstd() {
    // Given
    let request = http::Request::post("/")
        .header(header::CONTENT_TYPE, "application/json")
        .header(header::ACCEPT_ENCODING, "zstd")
        .body(json_body(&json()))
        .unwrap();

    // When
    let response = app().oneshot(request).await.unwrap();

    // Then
    assert_eq!(response.status(), StatusCode::OK);
    let response_body = byte_from_response(response).await;
    let decompress_body = zstd::stream::decode_all(std::io::Cursor::new(response_body)).unwrap();
    assert_json_eq!(
        serde_json::from_slice::<serde_json::Value>(&decompress_body).unwrap(),
        json()
    );
}

fn json() -> Value {
    json!({
        "name": "foo",
        "mainProduct": {
            "typeId": "product",
            "id": "p1"
        },
    })
}

fn json_body(input: &Value) -> Body {
    Body::from(serde_json::to_vec(&input).unwrap())
}

async fn json_from_response(response: Response) -> Value {
    let body = byte_from_response(response).await;
    body_as_json(body)
}

async fn byte_from_response(response: Response) -> Bytes {
    axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap()
}

fn body_as_json(body: Bytes) -> Value {
    serde_json::from_slice(body.as_ref()).unwrap()
}

fn compress_gzip(json: &Value) -> Vec<u8> {
    let request_body = serde_json::to_vec(&json).unwrap();

    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
    encoder.write_all(&request_body).unwrap();
    encoder.finish().unwrap()
}

fn compress_br(json: &Value) -> Vec<u8> {
    let request_body = serde_json::to_vec(&json).unwrap();
    let mut result = Vec::new();

    let params = BrotliEncoderParams::default();
    let _ = brotli::enc::BrotliCompress(&mut &request_body[..], &mut result, &params).unwrap();

    result
}

fn compress_zstd(json: &Value) -> Vec<u8> {
    let request_body = serde_json::to_vec(&json).unwrap();
    zstd::stream::encode_all(std::io::Cursor::new(request_body), 4).unwrap()
}
{
"products": [
{
"id": 1,
"name": "Product 1"
},
{
"id": 2,
"name": "Product 2"
}
]
}
Next to data/products.json there is also a gzip-compressed copy, data/products.json.gz, which is the file the first curl command sends.
testing
[package]
name = "example-testing"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
http-body-util = "0.1.0"
hyper-util = { version = "0.1", features = ["client", "http1", "client-legacy"] }
mime = "0.3"
serde_json = "1.0"
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.6.1", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dev-dependencies]
tower = { version = "0.5.2", features = ["util"] }
//! Run with
//!
//! ```not_rust
//! cargo test -p example-testing
//! ```

use std::net::SocketAddr;

use axum::{
    extract::ConnectInfo,
    routing::{get, post},
    Json, Router,
};
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app()).await.unwrap();
}

/// Having a function that produces our app makes it easy to call it from tests
/// without having to create an HTTP server.
fn app() -> Router {
    Router::new()
        .route("/", get(|| async { "Hello, World!" }))
        .route(
            "/json",
            post(|payload: Json<serde_json::Value>| async move {
                Json(serde_json::json!({ "data": payload.0 }))
            }),
        )
        .route(
            "/requires-connect-info",
            get(|ConnectInfo(addr): ConnectInfo<SocketAddr>| async move { format!("Hi {addr}") }),
        )
        // We can still add middleware
        .layer(TraceLayer::new_for_http())
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        extract::connect_info::MockConnectInfo,
        http::{self, Request, StatusCode},
    };
    use http_body_util::BodyExt; // for `collect`
    use serde_json::{json, Value};
    use tokio::net::TcpListener;
    use tower::{Service, ServiceExt}; // for `call`, `oneshot`, and `ready`

    #[tokio::test]
    async fn hello_world() {
        let app = app();

        // `Router` implements `tower::Service<Request<Body>>` so we can
        // call it like any tower service, no need to run an HTTP server.
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"Hello, World!");
    }

    #[tokio::test]
    async fn json() {
        let app = app();

        let response = app
            .oneshot(
                Request::builder()
                    .method(http::Method::POST)
                    .uri("/json")
                    .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
                    .body(Body::from(
                        serde_json::to_vec(&json!([1, 2, 3, 4])).unwrap(),
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body().collect().await.unwrap().to_bytes();
        let body: Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(body, json!({ "data": [1, 2, 3, 4] }));
    }

    #[tokio::test]
    async fn not_found() {
        let app = app();

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/does-not-exist")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert!(body.is_empty());
    }

    // You can also spawn a server and talk to it like any other HTTP server:
    #[tokio::test]
    async fn the_real_deal() {
        let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            axum::serve(listener, app()).await.unwrap();
        });

        let client =
            hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
                .build_http();

        let response = client
            .request(
                Request::builder()
                    .uri(format!("http://{addr}"))
                    .header("Host", "localhost")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"Hello, World!");
    }

    // You can use `ready()` and `call()` to avoid using `clone()`
    // in multiple requests
    #[tokio::test]
    async fn multiple_request() {
        let mut app = app().into_service();

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = ServiceExt::<Request<Body>>::ready(&mut app)
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = ServiceExt::<Request<Body>>::ready(&mut app)
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    // Here we're calling `/requires-connect-info` which requires `ConnectInfo`
    //
    // That is normally set with `Router::into_make_service_with_connect_info` but we can't easily
    // use that during tests. The solution is instead to set the `MockConnectInfo` layer during
    // tests.
    #[tokio::test]
    async fn with_into_make_service_with_connect_info() {
        let mut app = app()
            .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 3000))))
            .into_service();

        let request = Request::builder()
            .uri("/requires-connect-info")
            .body(Body::empty())
            .unwrap();
        let response = app.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}
Handle HEAD request
This example shows how to:
- handle HEAD requests in their own route
- handle HEAD requests in a GET route
Running
cargo run -p example-handle-head-request
Sending GET request to GET handler
$ curl -i http://localhost:3000/my-get
HTTP/1.1 200 OK
content-type: text/plain; charset=utf-8
x-some-header: header from GET
content-length: 13
date: Tue, 18 Mar 2025 07:10:38 GMT
body from GET
Sending HEAD request to GET handler
$ curl -I http://localhost:3000/my-get
HTTP/1.1 200 OK
x-some-header: header from HEAD in get-handler
content-length: 0
date: Tue, 18 Mar 2025 07:11:17 GMT
Sending GET request to HEAD handler
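The /my-head route only accepts HEAD, so axum rejects a GET request with 405 Method Not Allowed and lists the allowed method in the allow header: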
$ curl -i http://localhost:3000/my-head
HTTP/1.1 405 Method Not Allowed
allow: HEAD
content-length: 0
date: Tue, 18 Mar 2025 07:12:12 GMT
Sending HEAD request to HEAD handler
$ curl -I http://localhost:3000/my-head
HTTP/1.1 200 OK
x-some-header: header from HEAD in head-handler
content-length: 0
date: Tue, 18 Mar 2025 07:12:50 GMT
[package]
name = "example-handle-head-request"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
[dev-dependencies]
http-body-util = "0.1.0"
hyper = { version = "1.0.0", features = ["full"] }
tower = { version = "0.5.2", features = ["util"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-handle-head-request
//! ```

use axum::response::{IntoResponse, Response};
use axum::{http, routing::get, routing::head, Router};

fn app() -> Router {
    Router::new()
        .route("/my-get", get(get_handler))
        .route("/my-head", head(head_handler))
}

#[tokio::main]
async fn main() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app()).await.unwrap();
}

// GET routes will also be called for HEAD requests but will have the response body removed.
// You can handle the HEAD method explicitly by extracting `http::Method` from the request.
async fn get_handler(method: http::Method) -> Response {
    // it usually only makes sense to special-case HEAD
    // if computing the body has some relevant cost
    if method == http::Method::HEAD {
        return ([("x-some-header", "header from HEAD in get-handler")]).into_response();
    }

    // then do some computing task in GET
    do_some_computing_task();

    ([("x-some-header", "header from GET")], "body from GET").into_response()
}

fn do_some_computing_task() {
    // TODO
}

// HEAD routes will be called only for HEAD requests.
async fn head_handler() -> Response {
    // it usually only makes sense to special-case HEAD
    // if computing the body has some relevant cost
    ([("x-some-header", "header from HEAD in head-handler")]).into_response()
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[tokio::test]
    async fn test_get_from_get_handler() {
        let app = app();
        let response = app
            .oneshot(Request::get("/my-get").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.headers()["x-some-header"], "header from GET");

        let body = response.collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"body from GET");
    }

    #[tokio::test]
    async fn test_implicit_head() {
        let app = app();
        let response = app
            .oneshot(Request::head("/my-get").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers()["x-some-header"],
            "header from HEAD in get-handler"
        );

        let body = response.collect().await.unwrap().to_bytes();
        assert!(body.is_empty());
    }

    #[tokio::test]
    async fn test_get_from_head_handler() {
        let app = app();
        let response = app
            .oneshot(Request::get("/my-head").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
        assert!(!response.headers().contains_key("x-some-header"));
        assert_eq!(response.headers()["allow"], "HEAD");

        let body = response.collect().await.unwrap().to_bytes();
        assert!(body.is_empty());
    }

    #[tokio::test]
    async fn test_head_from_head_handler() {
        let app = app();
        let response = app
            .oneshot(Request::head("/my-head").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers()["x-some-header"],
            "header from HEAD in head-handler"
        );

        let body = response.collect().await.unwrap().to_bytes();
        assert!(body.is_empty());
    }
}
TODOs
[package]
name = "example-todos"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.5.2", features = ["util", "timeout"] }
tower-http = { version = "0.6.1", features = ["add-extension", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.0", features = ["serde", "v4"] }
[dev-dependencies]
http-body-util = "0.1.0"
mime = "0.3.17"
serde_json = "1.0.140"
//! Provides a RESTful web server managing some Todos.
//!
//! API will be:
//!
//! - `GET /todos`: return a JSON list of Todos.
//! - `POST /todos`: create a new Todo.
//! - `PATCH /todos/{id}`: update a specific Todo.
//! - `DELETE /todos/{id}`: delete a specific Todo.
//!
//! Run with
//!
//! ```not_rust
//! cargo run -p example-todos
//! ```

use axum::{
    error_handling::HandleErrorLayer,
    extract::{Path, Query, State},
    http::StatusCode,
    response::IntoResponse,
    routing::{get, patch},
    Json, Router,
};
use serde::{Deserialize, Serialize};
use std::{
    collections::HashMap,
    sync::{Arc, RwLock},
    time::Duration,
};
use tower::{BoxError, ServiceBuilder};
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use uuid::Uuid;

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // Compose the routes
    let app = app();

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

fn app() -> Router {
    let db = Db::default();

    Router::new()
        .route("/todos", get(todos_index).post(todos_create))
        .route("/todos/{id}", patch(todos_update).delete(todos_delete))
        // Add middleware to all routes
        .layer(
            ServiceBuilder::new()
                .layer(HandleErrorLayer::new(|error: BoxError| async move {
                    if error.is::<tower::timeout::error::Elapsed>() {
                        Ok(StatusCode::REQUEST_TIMEOUT)
                    } else {
                        Err((
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Unhandled internal error: {error}"),
                        ))
                    }
                }))
                .timeout(Duration::from_secs(10))
                .layer(TraceLayer::new_for_http())
                .into_inner(),
        )
        .with_state(db)
}

// The query parameters for todos index
#[derive(Debug, Deserialize, Default)]
pub struct Pagination {
    pub offset: Option<usize>,
    pub limit: Option<usize>,
}

async fn todos_index(pagination: Query<Pagination>, State(db): State<Db>) -> impl IntoResponse {
    let todos = db.read().unwrap();

    let todos = todos
        .values()
        .skip(pagination.offset.unwrap_or(0))
        .take(pagination.limit.unwrap_or(usize::MAX))
        .cloned()
        .collect::<Vec<_>>();

    Json(todos)
}

#[derive(Debug, Deserialize)]
struct CreateTodo {
    text: String,
}

async fn todos_create(State(db): State<Db>, Json(input): Json<CreateTodo>) -> impl IntoResponse {
    let todo = Todo {
        id: Uuid::new_v4(),
        text: input.text,
        completed: false,
    };

    db.write().unwrap().insert(todo.id, todo.clone());

    (StatusCode::CREATED, Json(todo))
}

#[derive(Debug, Deserialize)]
struct UpdateTodo {
    text: Option<String>,
    completed: Option<bool>,
}

async fn todos_update(
    Path(id): Path<Uuid>,
    State(db): State<Db>,
    Json(input): Json<UpdateTodo>,
) -> Result<impl IntoResponse, StatusCode> {
    let mut todo = db
        .read()
        .unwrap()
        .get(&id)
        .cloned()
        .ok_or(StatusCode::NOT_FOUND)?;

    if let Some(text) = input.text {
        todo.text = text;
    }

    if let Some(completed) = input.completed {
        todo.completed = completed;
    }

    db.write().unwrap().insert(todo.id, todo.clone());

    Ok(Json(todo))
}

async fn todos_delete(Path(id): Path<Uuid>, State(db): State<Db>) -> impl IntoResponse {
    if db.write().unwrap().remove(&id).is_some() {
        StatusCode::NO_CONTENT
    } else {
        StatusCode::NOT_FOUND
    }
}

type Db = Arc<RwLock<HashMap<Uuid, Todo>>>;

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
struct Todo {
    id: Uuid,
    text: String,
    completed: bool,
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{self, Request, StatusCode},
        routing::RouterIntoService,
    };
    use http_body_util::BodyExt;
    use serde_json::json;
    use tower::{Service, ServiceExt};

    #[tokio::test]
    async fn test_empty_list_of_todos() {
        let response = app()
            .oneshot(
                Request::builder()
                    .uri("/todos")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        let todos = serde_json::from_str::<Vec<Todo>>(&html).unwrap();
        assert_eq!(todos, []);
    }

    #[tokio::test]
    async fn test_add_todo() {
        let mut app = app().into_service();

        let request = Request::builder()
            .method(http::Method::POST)
            .uri("/todos")
            .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::from(json!({"text": "Write more tests!"}).to_string()))
            .unwrap();
        let response = ServiceExt::<Request<Body>>::ready(&mut app)
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::CREATED);
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        let todo = serde_json::from_str::<Todo>(&html).unwrap();
        assert_eq!(todo.text, "Write more tests!");
        assert!(!todo.completed);

        let todos = get_todos(&mut app).await;
        assert_eq!(todos.len(), 1);
        assert_eq!(todos[0], todo);
    }

    #[tokio::test]
    async fn test_complex() {
        let mut app = app().into_service();

        let todos = get_todos(&mut app).await;
        assert_eq!(todos.len(), 0);

        let mut todo0 = add_todo(&mut app, "Add more tests!").await;
        let todo1 = add_todo(&mut app, "Some other thing to do.").await;
        let todo2 = add_todo(&mut app, "Write a book about axum.").await;

        let mut todos = get_todos(&mut app).await;
        assert_eq!(todos.len(), 3);
        // Ensure the order is correct for the tests
        todos.sort_by_key(|todo| todo.text.clone());
        assert_eq!(todos[0], todo0);
        assert_eq!(todos[1], todo1);
        assert_eq!(todos[2], todo2);

        let (status, res) = update_todo(&mut app, todo0.id, &todo0.text, true).await;
        assert_eq!(status, StatusCode::OK);
        todo0.completed = true;
        assert_eq!(res, Some(todo0.clone()));

        let status = delete_todo(&mut app, todo1.id).await;
        assert_eq!(status, StatusCode::NO_CONTENT);

        let mut todos = get_todos(&mut app).await;
        assert_eq!(todos.len(), 2);
        // Ensure the order is correct for the tests
        todos.sort_by_key(|todo| todo.text.clone());
        assert_eq!(todos[0], todo0);
        assert_eq!(todos[1], todo2);

        let status = delete_todo(&mut app, todo1.id).await;
        assert_eq!(status, StatusCode::NOT_FOUND);

        let (status, res) = update_todo(&mut app, todo1.id, "", true).await;
        assert_eq!(status, StatusCode::NOT_FOUND);
        assert_eq!(res, None);
    }

    async fn get_todos(app: &mut RouterIntoService<Body>) -> Vec<Todo> {
        let request = Request::builder()
            .uri("/todos")
            .body(Body::empty())
            .unwrap();
        let response = ServiceExt::<Request<Body>>::ready(app)
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        serde_json::from_str::<Vec<Todo>>(&html).unwrap()
    }

    async fn add_todo(app: &mut RouterIntoService<Body>, text: &str) -> Todo {
        let request = Request::builder()
            .method(http::Method::POST)
            .uri("/todos")
            .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::from(json!({"text": text}).to_string()))
            .unwrap();
        let response = ServiceExt::<Request<Body>>::ready(app)
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::CREATED);
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        serde_json::from_str::<Todo>(&html).unwrap()
    }

    async fn update_todo(
        app: &mut RouterIntoService<Body>,
        id: Uuid,
        text: &str,
        completed: bool,
    ) -> (StatusCode, Option<Todo>) {
        let request = Request::builder()
            .method(http::Method::PATCH)
            .uri(format!("/todos/{id}"))
            .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::from(
                json!({"text": text, "completed": completed}).to_string(),
            ))
            .unwrap();
        let response = ServiceExt::<Request<Body>>::ready(app)
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap();
        let status = response.status();
        if status != StatusCode::OK {
            return (status, None);
        }
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let html = String::from_utf8(bytes.to_vec()).unwrap();
        (status, Some(serde_json::from_str::<Todo>(&html).unwrap()))
    }

    async fn delete_todo(app: &mut RouterIntoService<Body>, id: Uuid) -> StatusCode {
        let request = Request::builder()
            .method(http::Method::DELETE)
            .uri(format!("/todos/{id}"))
            .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::empty())
            .unwrap();
        let response = ServiceExt::<Request<Body>>::ready(app)
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap();
        response.status()
    }
}
$ curl -X POST -H "Content-Type: application/json" -d '{"text": "Hello World!"}' http://localhost:3000/todos
{"id":"ccd0ebd7-f2b3-4395-bf4a-273f1d0c9851","text":"Hello World!","completed":false}
$ curl -X POST -H "Content-Type: application/json" -d '{"text": "Another item"}' http://localhost:3000/todos
{"id":"5903e415-e162-4767-9b57-bf6583e89c3f","text":"Another item","completed":false}
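todos_index applies the optional offset and limit query parameters with skip and take, so a request such as GET /todos?offset=1&limit=1 returns at most one Todo after skipping the first. Because the Todos are stored in a HashMap, the iteration order is not guaranteed, which is also why the tests sort the list before comparing. A sketch of a hypothetical pure helper that sorts before paginating, so the same offset always yields the same items:

```rust
// Hypothetical helper mirroring the skip/take logic of `todos_index`,
// but sorting first so that pages are stable between requests.
fn paginate<T: Clone + Ord>(items: &[T], offset: Option<usize>, limit: Option<usize>) -> Vec<T> {
    let mut sorted = items.to_vec();
    sorted.sort();
    sorted
        .into_iter()
        .skip(offset.unwrap_or(0))
        .take(limit.unwrap_or(usize::MAX))
        .collect()
}

fn main() {
    let texts = ["write tests", "add docs", "fix bug", "review PR"];
    println!("{:?}", paginate(&texts, Some(1), Some(2)));
}
```

Todo does not implement Ord, so in the handler itself you would sort by a key instead, for example with sort_by_key on the text or on a creation timestamp.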
Readme
[package]
name = "example-readme"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dev-dependencies]
http-body-util = "0.1.3"
mime = "0.3.17"
serde_json = "1.0.140"
tower = "0.5.2"
//! Run with
//!
//! ```not_rust
//! cargo run -p example-readme
//! ```

use axum::{
    http::StatusCode,
    response::IntoResponse,
    routing::{get, post},
    Json, Router,
};
use serde::{Deserialize, Serialize};

#[tokio::main]
async fn main() {
    // initialize tracing
    tracing_subscriber::fmt::init();

    // build our application with a route
    let app = app();

    // run our app with hyper
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

// basic handler that responds with a static string
async fn root() -> &'static str {
    "Hello, World!"
}

fn app() -> Router {
    Router::new()
        // `GET /` goes to `root`
        .route("/", get(root))
        // `POST /users` goes to `create_user`
        .route("/users", post(create_user))
}

async fn create_user(
    // this argument tells axum to parse the request body
    // as JSON into a `CreateUser` type
    Json(payload): Json<CreateUser>,
) -> impl IntoResponse {
    // insert your application logic here
    let user = User {
        id: 1337,
        username: payload.username,
    };

    // this will be converted into a JSON response
    // with a status code of `201 Created`
    (StatusCode::CREATED, Json(user))
}

// the input to our `create_user` handler
#[derive(Deserialize)]
struct CreateUser {
    username: String,
}

// the output to our `create_user` handler
#[derive(Serialize)]
struct User {
    id: u64,
    username: String,
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{self, Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use serde_json::json;
    use tower::ServiceExt; // for `oneshot`

    #[tokio::test]
    async fn main_page() {
        let app = app();

        // `Router` implements `tower::Service<Request<Body>>` so we can
        // call it like any tower service, no need to run an HTTP server.
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body().collect().await.unwrap().to_bytes();
        let body = std::str::from_utf8(&body).unwrap();
        assert_eq!(body, "Hello, World!");
    }

    #[tokio::test]
    async fn users_json() {
        let app = app();

        let response = app
            .oneshot(
                Request::builder()
                    .method(http::Method::POST)
                    .uri("/users")
                    .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
                    .body(Body::from(
                        serde_json::to_vec(&json!({"username": "foobar"})).unwrap(),
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::CREATED);

        let body = response.into_body().collect().await.unwrap().to_bytes();
        let body = std::str::from_utf8(&body).unwrap();
        assert_eq!(body, r#"{"id":1337,"username":"foobar"}"#);
    }
}
Using curl
$ curl http://localhost:3000/
Hello, World!
$ curl -X POST -H "Content-Type: application/json" -d '{"username":"foobar"}' http://localhost:3000/users
{"id":1337,"username":"foobar"}
TODO
- async-graphql: see https://github.com/async-graphql/examples.
cors
[package]
name = "example-cors"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.6.1", features = ["cors"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-cors
//! ```

use axum::{
    http::{HeaderValue, Method},
    response::{Html, IntoResponse},
    routing::get,
    Json, Router,
};
use std::net::SocketAddr;
use tower_http::cors::CorsLayer;

#[tokio::main]
async fn main() {
    let frontend = async {
        let app = Router::new().route("/", get(html));
        serve(app, 3000).await;
    };

    let backend = async {
        let app = Router::new().route("/json", get(json)).layer(
            // see https://docs.rs/tower-http/latest/tower_http/cors/index.html
            // for more details
            //
            // note that some requests, such as a POST with
            // `Content-Type: application/json`, also require
            // `.allow_headers([http::header::CONTENT_TYPE])`
            // (see https://github.com/tokio-rs/axum/issues/849)
            CorsLayer::new()
                .allow_origin("http://localhost:3000".parse::<HeaderValue>().unwrap())
                .allow_methods([Method::GET]),
        );
        serve(app, 4000).await;
    };

    tokio::join!(frontend, backend);
}

async fn serve(app: Router, port: u16) {
    let addr = SocketAddr::from(([127, 0, 0, 1], port));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

async fn html() -> impl IntoResponse {
    Html(
        r#"
        <script>
            fetch('http://localhost:4000/json')
              .then(response => response.json())
              .then(data => console.log(data));
        </script>
        "#,
    )
}

async fn json() -> impl IntoResponse {
    Json(vec!["one", "two", "three"])
}
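The page served on port 3000 can read the JSON from port 4000 only because the backend answers with CORS headers that name port 3000 as an allowed origin. The browser enforces the rule; the server only decides which headers to send. The following std-only sketch shows that decision for a single request. `cors_allow_origin` and `cors_allow_method` are hypothetical helpers. This is not how tower-http's `CorsLayer` is implemented, and it ignores preflight requests entirely.

```rust
/// Hypothetical helper: value of the `Access-Control-Allow-Origin` response
/// header for a request, given an allow-list of origins. `None` means
/// "omit the header", so the browser will not expose the response to the script.
fn cors_allow_origin<'a>(request_origin: Option<&'a str>, allowed: &[&str]) -> Option<&'a str> {
    match request_origin {
        // Requests without an `Origin` header (e.g. curl) get no CORS header.
        None => None,
        Some(origin) if allowed.contains(&origin) => Some(origin),
        Some(_) => None,
    }
}

/// Hypothetical helper: is this method in the allowed list?
fn cors_allow_method(method: &str, allowed: &[&str]) -> bool {
    allowed.iter().any(|m| m.eq_ignore_ascii_case(method))
}

fn main() {
    let origins = ["http://localhost:3000"];
    let methods = ["GET"];

    // The frontend on port 3000 calling the backend on port 4000.
    println!("{:?}", cors_allow_origin(Some("http://localhost:3000"), &origins));
    // Any other site: no header, the browser blocks access to the response.
    println!("{:?}", cors_allow_origin(Some("http://other.example"), &origins));
    // Only GET was allowed via `.allow_methods([Method::GET])`.
    println!("{}", cors_allow_method("POST", &methods));
}
```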
consume-body-in-extractor-or-middleware
[package]
name = "example-consume-body-in-extractor-or-middleware"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
http-body-util = "0.1.0"
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-consume-body-in-extractor-or-middleware
//! ```

use axum::{
    body::{Body, Bytes},
    extract::{FromRequest, Request},
    http::StatusCode,
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Router,
};
use http_body_util::BodyExt;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let app = Router::new()
        .route("/", post(handler))
        .layer(middleware::from_fn(print_request_body));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

// middleware that shows how to consume the request body upfront
async fn print_request_body(request: Request, next: Next) -> Result<impl IntoResponse, Response> {
    let request = buffer_request_body(request).await?;

    Ok(next.run(request).await)
}

// the trick is to take the request apart, buffer the body, do what you need to do, then put
// the request back together
async fn buffer_request_body(request: Request) -> Result<Request, Response> {
    let (parts, body) = request.into_parts();

    // this won't work if the body is a long-running stream
    let bytes = body
        .collect()
        .await
        .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?
        .to_bytes();

    do_thing_with_request_body(bytes.clone());

    Ok(Request::from_parts(parts, Body::from(bytes)))
}

fn do_thing_with_request_body(bytes: Bytes) {
    tracing::debug!(body = ?bytes);
}

async fn handler(BufferRequestBody(body): BufferRequestBody) {
    tracing::debug!(?body, "handler received body");
}

// extractor that shows how to consume the request body upfront
struct BufferRequestBody(Bytes);

// we must implement `FromRequest` (and not `FromRequestParts`) to consume the body
impl<S> FromRequest<S> for BufferRequestBody
where
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let body = Bytes::from_request(req, state)
            .await
            .map_err(|err| err.into_response())?;

        do_thing_with_request_body(body.clone());

        Ok(Self(body))
    }
}
customize-extractor-error
This example explores 3 different ways you can create custom rejections for already existing extractors:
- `with_rejection` (src/with_rejection.rs): uses `axum_extra::extract::WithRejection` to transform one rejection into another
- `derive_from_request` (src/derive_from_request.rs): uses the `axum::extract::FromRequest` derive macro to wrap another extractor and customize the rejection
- `custom_extractor` (src/custom_extractor.rs): a manual implementation of `FromRequest` that wraps another extractor
Run with
cargo run -p example-customize-extractor-error
[package]
name = "example-customize-extractor-error"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["macros"] }
axum-extra = { path = "../../axum-extra" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
tokio = { version = "1.20", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-customize-extractor-error
//! ```

mod custom_extractor;
mod derive_from_request;
mod with_rejection;

use axum::{routing::post, Router};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=trace", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // Build our application with some routes
    let app = Router::new()
        .route("/with-rejection", post(with_rejection::handler))
        .route("/custom-extractor", post(custom_extractor::handler))
        .route("/derive-from-request", post(derive_from_request::handler));

    // Run our application
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}
//! Manual implementation of `FromRequest` that wraps another extractor
//!
//! + Powerful API: Implementing `FromRequest` grants access to the whole `Request`
//!   and to `async/await`. This means that you can create more powerful rejections
//! - Boilerplate: Requires creating a new extractor for every custom rejection
//! - Complexity: Manually implementing `FromRequest` results in more complex code
use axum::{
    extract::{rejection::JsonRejection, FromRequest, MatchedPath, Request},
    http::StatusCode,
    response::IntoResponse,
    RequestPartsExt,
};
use serde_json::{json, Value};

pub async fn handler(Json(value): Json<Value>) -> impl IntoResponse {
    // Echo the parsed JSON back. Our own `Json` type does not implement
    // `IntoResponse`, so respond with `axum::Json`.
    axum::Json(dbg!(value))
}

// We define our own `Json` extractor that customizes the error from `axum::Json`
pub struct Json<T>(pub T);

impl<S, T> FromRequest<S> for Json<T>
where
    axum::Json<T>: FromRequest<S, Rejection = JsonRejection>,
    S: Send + Sync,
{
    type Rejection = (StatusCode, axum::Json<Value>);

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let (mut parts, body) = req.into_parts();

        // We can use other extractors to provide better rejection messages.
        // For example, here we are using `axum::extract::MatchedPath` to
        // provide a better error message.
        //
        // Have to run that first since `Json` extraction consumes the request.
        let path = parts
            .extract::<MatchedPath>()
            .await
            .map(|path| path.as_str().to_owned())
            .ok();

        let req = Request::from_parts(parts, body);

        match axum::Json::<T>::from_request(req, state).await {
            Ok(value) => Ok(Self(value.0)),
            // convert the error from `axum::Json` into whatever we want
            Err(rejection) => {
                let payload = json!({
                    "message": rejection.body_text(),
                    "origin": "custom_extractor",
                    "path": path,
                });

                Err((rejection.status(), axum::Json(payload)))
            }
        }
    }
}
//! Uses `axum::extract::FromRequest` to wrap another extractor and customize the
//! rejection
//!
//! + Easy learning curve: Deriving `FromRequest` generates a `FromRequest`
//!   implementation for your type using another extractor. You only need
//!   to provide a `From` impl between the original rejection type and the
//!   target rejection. Crates like [`thiserror`] can provide such conversion
//!   using derive macros.
//! - Boilerplate: Requires deriving `FromRequest` for every custom rejection
//! - There are some known limitations: [FromRequest#known-limitations]
//!
//! [`thiserror`]: https://crates.io/crates/thiserror
//! [FromRequest#known-limitations]: https://docs.rs/axum-macros/*/axum_macros/derive.FromRequest.html#known-limitations
use axum::{
    extract::rejection::JsonRejection, extract::FromRequest, http::StatusCode,
    response::IntoResponse,
};
use serde::Serialize;
use serde_json::{json, Value};

pub async fn handler(Json(value): Json<Value>) -> impl IntoResponse {
    Json(dbg!(value))
}

// create an extractor that internally uses `axum::Json` but has a custom rejection
#[derive(FromRequest)]
#[from_request(via(axum::Json), rejection(ApiError))]
pub struct Json<T>(T);

// We implement `IntoResponse` for our extractor so it can be used as a response
impl<T: Serialize> IntoResponse for Json<T> {
    fn into_response(self) -> axum::response::Response {
        let Self(value) = self;
        axum::Json(value).into_response()
    }
}

// We create our own rejection type
#[derive(Debug)]
pub struct ApiError {
    status: StatusCode,
    message: String,
}

// We implement `From<JsonRejection> for ApiError`
impl From<JsonRejection> for ApiError {
    fn from(rejection: JsonRejection) -> Self {
        Self {
            status: rejection.status(),
            message: rejection.body_text(),
        }
    }
}

// We implement `IntoResponse` so `ApiError` can be used as a response
impl IntoResponse for ApiError {
    fn into_response(self) -> axum::response::Response {
        let payload = json!({
            "message": self.message,
            "origin": "derive_from_request"
        });

        (self.status, axum::Json(payload)).into_response()
    }
}
//! Uses `axum_extra::extract::WithRejection` to transform one rejection into
//! another
//!
//! + Easy learning curve: `WithRejection` acts as a wrapper for another
//!   already existing extractor. You only need to provide a `From` impl
//!   between the original rejection type and the target rejection. Crates like
//!   `thiserror` can provide such conversion using derive macros. See
//!   [`thiserror`]
//! - Verbose types: types become much larger, which makes them difficult to
//!   read. Current limitations on type aliasing make it impossible to destructure
//!   a type alias. See [#1116]
//!
//! [`thiserror`]: https://crates.io/crates/thiserror
//! [#1116]: https://github.com/tokio-rs/axum/issues/1116#issuecomment-1186197684
use axum::{extract::rejection::JsonRejection, response::IntoResponse, Json};
use axum_extra::extract::WithRejection;
use serde_json::{json, Value};
use thiserror::Error;

pub async fn handler(
    // `WithRejection` will extract `Json<Value>` from the request. If it fails,
    // `JsonRejection` will be transformed into `ApiError` and returned as response
    // to the client.
    //
    // The second constructor argument is not meaningful and can be safely ignored
    WithRejection(Json(value), _): WithRejection<Json<Value>, ApiError>,
) -> impl IntoResponse {
    Json(dbg!(value))
}

// We derive `thiserror::Error`
#[derive(Debug, Error)]
pub enum ApiError {
    // The `#[from]` attribute generates `From<JsonRejection> for ApiError`
    // implementation. See `thiserror` docs for more information
    #[error(transparent)]
    JsonExtractorRejection(#[from] JsonRejection),
}

// We implement `IntoResponse` so ApiError can be used as a response
impl IntoResponse for ApiError {
    fn into_response(self) -> axum::response::Response {
        let (status, message) = match self {
            ApiError::JsonExtractorRejection(json_rejection) => {
                (json_rejection.status(), json_rejection.body_text())
            }
        };

        let payload = json!({
            "message": message,
            "origin": "with_rejection"
        });

        (status, Json(payload)).into_response()
    }
}
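All three approaches share one core idea: run an existing extractor, and if it fails, turn its rejection into your own type through a `From` impl. The std-only sketch below isolates that idea. The `Extract` trait is hypothetical and deliberately simplified (synchronous, built from a `&str`). It is not axum's `FromRequest`, and this `WithRejection` is a toy re-implementation of the concept, not `axum_extra`'s type.

```rust
use std::marker::PhantomData;

/// Hypothetical, synchronous stand-in for an extractor trait: build `Self`
/// from raw input or fail with a type-specific rejection.
trait Extract: Sized {
    type Rejection;
    fn extract(input: &str) -> Result<Self, Self::Rejection>;
}

/// An "already existing" extractor with its own rejection type.
#[derive(Debug, PartialEq)]
struct Number(i64);

#[derive(Debug)]
struct NumberRejection(String);

impl Extract for Number {
    type Rejection = NumberRejection;

    fn extract(input: &str) -> Result<Self, Self::Rejection> {
        input
            .trim()
            .parse::<i64>()
            .map(Number)
            .map_err(|err| NumberRejection(err.to_string()))
    }
}

/// Toy version of the `WithRejection<E, R>` idea: run `E` and, on failure,
/// convert `E::Rejection` into `R` through `From`.
struct WithRejection<E, R>(E, PhantomData<R>);

impl<E, R> Extract for WithRejection<E, R>
where
    E: Extract,
    R: From<E::Rejection>,
{
    type Rejection = R;

    fn extract(input: &str) -> Result<Self, Self::Rejection> {
        E::extract(input)
            .map(|value| WithRejection(value, PhantomData))
            .map_err(R::from)
    }
}

/// The application's own rejection, shaped like the JSON payloads above.
#[derive(Debug, PartialEq)]
struct ApiError {
    status: u16,
    message: String,
    origin: &'static str,
}

impl From<NumberRejection> for ApiError {
    fn from(rejection: NumberRejection) -> Self {
        ApiError {
            status: 400,
            message: rejection.0,
            origin: "with_rejection",
        }
    }
}

fn main() {
    for input in ["42", "forty-two"] {
        // Destructure the same way the handler does: `WithRejection(inner, _)`.
        match WithRejection::<Number, ApiError>::extract(input) {
            Ok(WithRejection(Number(n), _)) => println!("extracted {n}"),
            Err(err) => println!("rejected: {err:?}"),
        }
    }
}
```

The `PhantomData` field is the counterpart of the "not meaningful" second constructor argument mentioned in the handler comment: it only exists so the wrapper can carry the target rejection type.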
customize-path-rejection
[package]
name = "example-customize-path-rejection"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-customize-path-rejection
//! ```

use axum::{
    extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts},
    http::{request::Parts, StatusCode},
    response::IntoResponse,
    routing::get,
    Router,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // build our application with a route
    let app = Router::new().route("/users/{user_id}/teams/{team_id}", get(handler));

    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn handler(Path(params): Path<Params>) -> impl IntoResponse {
    axum::Json(params)
}

#[derive(Debug, Deserialize, Serialize)]
struct Params {
    user_id: u32,
    team_id: u32,
}

// We define our own `Path` extractor that customizes the error from `axum::extract::Path`
struct Path<T>(T);

impl<S, T> FromRequestParts<S> for Path<T>
where
    // these trait bounds are copied from `impl FromRequestParts for axum::extract::path::Path`
    T: DeserializeOwned + Send,
    S: Send + Sync,
{
    type Rejection = (StatusCode, axum::Json<PathError>);

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        match axum::extract::Path::<T>::from_request_parts(parts, state).await {
            Ok(value) => Ok(Self(value.0)),
            Err(rejection) => {
                let (status, body) = match rejection {
                    PathRejection::FailedToDeserializePathParams(inner) => {
                        let mut status = StatusCode::BAD_REQUEST;

                        let kind = inner.into_kind();
                        let body = match &kind {
                            ErrorKind::WrongNumberOfParameters { .. } => PathError {
                                message: kind.to_string(),
                                location: None,
                            },

                            ErrorKind::ParseErrorAtKey { key, .. } => PathError {
                                message: kind.to_string(),
                                location: Some(key.clone()),
                            },

                            ErrorKind::ParseErrorAtIndex { index, .. } => PathError {
                                message: kind.to_string(),
                                location: Some(index.to_string()),
                            },

                            ErrorKind::ParseError { .. } => PathError {
                                message: kind.to_string(),
                                location: None,
                            },

                            ErrorKind::InvalidUtf8InPathParam { key } => PathError {
                                message: kind.to_string(),
                                location: Some(key.clone()),
                            },

                            ErrorKind::UnsupportedType { .. } => {
                                // this error is caused by the programmer using an unsupported type
                                // (such as nested maps) so respond with `500` instead
                                status = StatusCode::INTERNAL_SERVER_ERROR;
                                PathError {
                                    message: kind.to_string(),
                                    location: None,
                                }
                            }

                            ErrorKind::Message(msg) => PathError {
                                message: msg.clone(),
                                location: None,
                            },

                            _ => PathError {
                                message: format!("Unhandled deserialization error: {kind}"),
                                location: None,
                            },
                        };

                        (status, body)
                    }
                    PathRejection::MissingPathParams(error) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        PathError {
                            message: error.to_string(),
                            location: None,
                        },
                    ),
                    _ => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        PathError {
                            message: format!("Unhandled path rejection: {rejection}"),
                            location: None,
                        },
                    ),
                };

                Err((status, axum::Json(body)))
            }
        }
    }
}

#[derive(Serialize)]
struct PathError {
    message: String,
    location: Option<String>,
}
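The most useful part of the custom rejection is the `location` field: for `ErrorKind::ParseErrorAtKey` it tells the client which path parameter was wrong. The std-only sketch below reproduces that behaviour for the route `/users/{user_id}/teams/{team_id}` by hand. `extract_params` and `parse_key` are hypothetical, the message wording is illustrative rather than axum's, and axum does the real matching and deserialization through serde.

```rust
/// Mirrors the `PathError` JSON body above, plus the chosen status code.
#[derive(Debug, PartialEq)]
struct PathError {
    status: u16,
    message: String,
    location: Option<String>,
}

#[derive(Debug, PartialEq)]
struct Params {
    user_id: u32,
    team_id: u32,
}

/// Parse one captured value, recording which key failed (like `ParseErrorAtKey`).
fn parse_key(key: &str, raw: &str) -> Result<u32, PathError> {
    raw.parse::<u32>().map_err(|err| PathError {
        status: 400,
        message: format!("cannot parse `{key}` value `{raw}` as u32: {err}"),
        location: Some(key.to_string()),
    })
}

/// Hypothetical matcher for `/users/{user_id}/teams/{team_id}`. Returns `None`
/// when the path does not match the route at all (axum would answer 404
/// before any extractor runs).
fn extract_params(path: &str) -> Option<Result<Params, PathError>> {
    let segments: Vec<&str> = path.trim_start_matches('/').split('/').collect();
    let &[users, user_id, teams, team_id] = segments.as_slice() else {
        return None;
    };
    if users != "users" || teams != "teams" {
        return None;
    }

    let params = parse_key("user_id", user_id).and_then(|user_id| {
        parse_key("team_id", team_id).map(|team_id| Params { user_id, team_id })
    });
    Some(params)
}

fn main() {
    println!("{:?}", extract_params("/users/1/teams/2"));
    println!("{:?}", extract_params("/users/one/teams/2"));
}
```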
dependency-injection
[package]
name = "example-dependency-injection"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["tracing", "macros"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.0", features = ["serde", "v4"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-dependency-injection
//! ```

use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};

use axum::{
    extract::{Path, State},
    http::StatusCode,
    routing::{get, post},
    Json, Router,
};
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use uuid::Uuid;

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let user_repo = InMemoryUserRepo::default();

    // We generally have two ways to inject dependencies:
    //
    // 1. Using trait objects (`dyn SomeTrait`)
    //     - Pros
    //         - Likely leads to simpler code due to fewer type parameters.
    //     - Cons
    //         - Less flexible because we can only use object safe traits
    //         - Small amount of additional runtime overhead due to dynamic dispatch.
    //           This is likely to be negligible.
    // 2. Using generics (`T where T: SomeTrait`)
    //     - Pros
    //         - More flexible since all traits can be used.
    //         - No runtime overhead.
    //     - Cons:
    //         - Additional type parameters and trait bounds can lead to more complex code and
    //           boilerplate.
    //
    // Using trait objects is recommended unless you really need generics.

    let using_dyn = Router::new()
        .route("/users/{id}", get(get_user_dyn))
        .route("/users", post(create_user_dyn))
        .with_state(AppStateDyn {
            user_repo: Arc::new(user_repo.clone()),
        });

    let using_generic = Router::new()
        .route("/users/{id}", get(get_user_generic::<InMemoryUserRepo>))
        .route("/users", post(create_user_generic::<InMemoryUserRepo>))
        .with_state(AppStateGeneric { user_repo });

    let app = Router::new()
        .nest("/dyn", using_dyn)
        .nest("/generic", using_generic);

    let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

#[derive(Clone)]
struct AppStateDyn {
    user_repo: Arc<dyn UserRepo>,
}

#[derive(Clone)]
struct AppStateGeneric<T> {
    user_repo: T,
}

#[derive(Debug, Serialize, Clone)]
struct User {
    id: Uuid,
    name: String,
}

#[derive(Deserialize)]
struct UserParams {
    name: String,
}

async fn create_user_dyn(
    State(state): State<AppStateDyn>,
    Json(params): Json<UserParams>,
) -> Json<User> {
    let user = User {
        id: Uuid::new_v4(),
        name: params.name,
    };

    state.user_repo.save_user(&user);

    Json(user)
}

async fn get_user_dyn(
    State(state): State<AppStateDyn>,
    Path(id): Path<Uuid>,
) -> Result<Json<User>, StatusCode> {
    match state.user_repo.get_user(id) {
        Some(user) => Ok(Json(user)),
        None => Err(StatusCode::NOT_FOUND),
    }
}

async fn create_user_generic<T>(
    State(state): State<AppStateGeneric<T>>,
    Json(params): Json<UserParams>,
) -> Json<User>
where
    T: UserRepo,
{
    let user = User {
        id: Uuid::new_v4(),
        name: params.name,
    };

    state.user_repo.save_user(&user);

    Json(user)
}

async fn get_user_generic<T>(
    State(state): State<AppStateGeneric<T>>,
    Path(id): Path<Uuid>,
) -> Result<Json<User>, StatusCode>
where
    T: UserRepo,
{
    match state.user_repo.get_user(id) {
        Some(user) => Ok(Json(user)),
        None => Err(StatusCode::NOT_FOUND),
    }
}

trait UserRepo: Send + Sync {
    fn get_user(&self, id: Uuid) -> Option<User>;

    fn save_user(&self, user: &User);
}

#[derive(Debug, Clone, Default)]
struct InMemoryUserRepo {
    map: Arc<Mutex<HashMap<Uuid, User>>>,
}

impl UserRepo for InMemoryUserRepo {
    fn get_user(&self, id: Uuid) -> Option<User> {
        self.map.lock().unwrap().get(&id).cloned()
    }

    fn save_user(&self, user: &User) {
        self.map.lock().unwrap().insert(user.id, user.clone());
    }
}
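The two injection styles differ only in how the repository is stored in the state: behind `Arc<dyn UserRepo>` (one concrete state type, dynamic dispatch) or as a type parameter `T: UserRepo` (one instantiation per repository type, static dispatch). The std-only sketch below keeps the example's names but drops axum and replaces `Uuid` with `u64` so it compiles on its own. `lookup_dyn` and `lookup_generic` are hypothetical stand-ins for the `get_user_*` handlers.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

#[derive(Debug, Clone, PartialEq)]
struct User {
    id: u64,
    name: String,
}

trait UserRepo: Send + Sync {
    fn get_user(&self, id: u64) -> Option<User>;
    fn save_user(&self, user: &User);
}

#[derive(Clone, Default)]
struct InMemoryUserRepo {
    map: Arc<Mutex<HashMap<u64, User>>>,
}

impl UserRepo for InMemoryUserRepo {
    fn get_user(&self, id: u64) -> Option<User> {
        self.map.lock().unwrap().get(&id).cloned()
    }

    fn save_user(&self, user: &User) {
        self.map.lock().unwrap().insert(user.id, user.clone());
    }
}

// Style 1: trait object. The state type does not mention the repo type.
#[derive(Clone)]
struct AppStateDyn {
    user_repo: Arc<dyn UserRepo>,
}

fn lookup_dyn(state: &AppStateDyn, id: u64) -> Option<User> {
    state.user_repo.get_user(id) // dynamic dispatch through the vtable
}

// Style 2: generics. The compiler generates a copy per concrete `T`.
#[derive(Clone)]
struct AppStateGeneric<T> {
    user_repo: T,
}

fn lookup_generic<T: UserRepo>(state: &AppStateGeneric<T>, id: u64) -> Option<User> {
    state.user_repo.get_user(id) // static dispatch
}

fn main() {
    let repo = InMemoryUserRepo::default();
    repo.save_user(&User { id: 1, name: "alice".to_string() });

    // Cloning the repo clones the inner `Arc`, so both states see the same map.
    let dyn_state = AppStateDyn { user_repo: Arc::new(repo.clone()) };
    let generic_state = AppStateGeneric { user_repo: repo };

    println!("{:?}", lookup_dyn(&dyn_state, 1));
    println!("{:?}", lookup_generic(&generic_state, 1));
}
```

Because both states share one `Arc<Mutex<HashMap<..>>>`, a user created under `/dyn/users` in the axum example is also visible under `/generic/users/{id}`.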
error-handling
[package]
name = "example-error-handling"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["macros"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.6.1", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Example showing how to convert errors into responses.
//!
//! Run with
//!
//! ```not_rust
//! cargo run -p example-error-handling
//! ```
//!
//! For successful requests the log output will be
//!
//! ```ignore
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_request: started processing request
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_response: finished processing request latency=0 ms status=200
//! ```
//!
//! For failed requests the log output will be
//!
//! ```ignore
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_request: started processing request
//! ERROR request{method=POST uri=/users matched_path="/users"}: example_error_handling: error from time_library err=failed to get time
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_response: finished processing request latency=0 ms status=500
//! ```

use std::{
    collections::HashMap,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc, Mutex,
    },
};

use axum::{
    extract::{rejection::JsonRejection, FromRequest, MatchedPath, Request, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::post,
    Router,
};
use serde::{Deserialize, Serialize};
use time_library::Timestamp;
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let state = AppState::default();

    let app = Router::new()
        // A dummy route that accepts some JSON but sometimes fails
        .route("/users", post(users_create))
        .layer(
            TraceLayer::new_for_http()
                // Create our own span for the request and include the matched path. The matched
                // path is useful for figuring out which handler the request was routed to.
                .make_span_with(|req: &Request| {
                    let method = req.method();
                    let uri = req.uri();

                    // axum automatically adds this extension.
                    let matched_path = req
                        .extensions()
                        .get::<MatchedPath>()
                        .map(|matched_path| matched_path.as_str());

                    tracing::debug_span!("request", %method, %uri, matched_path)
                })
                // By default `TraceLayer` will log 5xx responses but we're doing our specific
                // logging of errors so disable that
                .on_failure(()),
        )
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

#[derive(Default, Clone)]
struct AppState {
    next_id: Arc<AtomicU64>,
    users: Arc<Mutex<HashMap<u64, User>>>,
}

#[derive(Deserialize)]
struct UserParams {
    name: String,
}

#[derive(Serialize, Clone)]
struct User {
    id: u64,
    name: String,
    created_at: Timestamp,
}

async fn users_create(
    State(state): State<AppState>,
    // Make sure to use our own JSON extractor so we get input errors formatted in a way that
    // matches our application
    AppJson(params): AppJson<UserParams>,
) -> Result<AppJson<User>, AppError> {
    let id = state.next_id.fetch_add(1, Ordering::SeqCst);

    // We have implemented `From<time_library::Error> for AppError` which allows us to use `?` to
    // automatically convert the error
    let created_at = Timestamp::now()?;

    let user = User {
        id,
        name: params.name,
        created_at,
    };

    state.users.lock().unwrap().insert(id, user.clone());

    Ok(AppJson(user))
}

// Create our own JSON extractor by wrapping `axum::Json`. This makes it easy to override the
// rejection and provide our own which formats errors to match our application.
//
// `axum::Json` responds with plain text if the input is invalid.
#[derive(FromRequest)]
#[from_request(via(axum::Json), rejection(AppError))]
struct AppJson<T>(T);

impl<T> IntoResponse for AppJson<T>
where
    axum::Json<T>: IntoResponse,
{
    fn into_response(self) -> Response {
        axum::Json(self.0).into_response()
    }
}

// The kinds of errors we can hit in our application.
enum AppError {
    // The request body contained invalid JSON
    JsonRejection(JsonRejection),
    // Some error from a third party library we're using
    TimeError(time_library::Error),
}

// Tell axum how `AppError` should be converted into a response.
//
// This is also a convenient place to log errors.
impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        // How we want error responses to be serialized
        #[derive(Serialize)]
        struct ErrorResponse {
            message: String,
        }

        let (status, message) = match self {
            AppError::JsonRejection(rejection) => {
                // This error is caused by bad user input so don't log it
                (rejection.status(), rejection.body_text())
            }
            AppError::TimeError(err) => {
                // Because `TraceLayer` wraps each request in a span that contains the request
                // method, uri, etc we don't need to include those details here
                tracing::error!(%err, "error from time_library");

                // Don't expose any details about the error to the client
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "Something went wrong".to_owned(),
                )
            }
        };

        (status, AppJson(ErrorResponse { message })).into_response()
    }
}

impl From<JsonRejection> for AppError {
    fn from(rejection: JsonRejection) -> Self {
        Self::JsonRejection(rejection)
    }
}

impl From<time_library::Error> for AppError {
    fn from(error: time_library::Error) -> Self {
        Self::TimeError(error)
    }
}

// Imagine this is some third party library that we're using. It sometimes returns errors which we
// want to log.
mod time_library {
    use std::sync::atomic::{AtomicU64, Ordering};

    use serde::Serialize;

    #[derive(Serialize, Clone)]
    pub struct Timestamp(u64);

    impl Timestamp {
        pub fn now() -> Result<Self, Error> {
            static COUNTER: AtomicU64 = AtomicU64::new(0);

            // Fail on every third call (starting with the first) just to simulate errors
            if COUNTER.fetch_add(1, Ordering::SeqCst) % 3 == 0 {
                Err(Error::FailedToGetTime)
            } else {
                Ok(Self(1337))
            }
        }
    }

    #[derive(Debug)]
    pub enum Error {
        FailedToGetTime,
    }

    impl std::fmt::Display for Error {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "failed to get time")
        }
    }
}
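The example depends on two ingredients: one application error enum with a `From` impl per source error (so `?` converts automatically), and one place that maps each variant to a status code and a client-facing message. The std-only sketch below shows both. `BadInput` and `TimeError` stand in for `JsonRejection` and `time_library::Error`. `status_and_message` is a hypothetical analogue of `IntoResponse for AppError`, and it uses plain `u16` status codes and `eprintln!` in place of axum and tracing.

```rust
use std::fmt;

// Two unrelated error types from different "libraries".
#[derive(Debug)]
struct BadInput(String);

#[derive(Debug)]
struct TimeError;

impl fmt::Display for TimeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "failed to get time")
    }
}

// The application's single error type.
#[derive(Debug)]
enum AppError {
    BadInput(BadInput),
    Time(TimeError),
}

// These `From` impls are what let `?` convert each error automatically.
impl From<BadInput> for AppError {
    fn from(e: BadInput) -> Self {
        AppError::BadInput(e)
    }
}

impl From<TimeError> for AppError {
    fn from(e: TimeError) -> Self {
        AppError::Time(e)
    }
}

impl AppError {
    // User errors are echoed back; internal errors are logged and hidden.
    fn status_and_message(&self) -> (u16, String) {
        match self {
            AppError::BadInput(e) => (400, e.0.clone()),
            AppError::Time(e) => {
                eprintln!("error from time_library: {e}");
                (500, "Something went wrong".to_string())
            }
        }
    }
}

fn parse_name(raw: &str) -> Result<String, BadInput> {
    if raw.is_empty() {
        Err(BadInput("name must not be empty".to_string()))
    } else {
        Ok(raw.to_string())
    }
}

fn now(fail: bool) -> Result<u64, TimeError> {
    if fail { Err(TimeError) } else { Ok(1337) }
}

fn create_user(raw_name: &str, clock_fails: bool) -> Result<(String, u64), AppError> {
    let name = parse_name(raw_name)?; // BadInput -> AppError
    let created_at = now(clock_fails)?; // TimeError -> AppError
    Ok((name, created_at))
}

fn main() {
    for (name, fail) in [("alice", false), ("", false), ("bob", true)] {
        match create_user(name, fail) {
            Ok(user) => println!("200 {user:?}"),
            Err(err) => {
                let (status, message) = err.status_and_message();
                println!("{status} {message}");
            }
        }
    }
}
```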
graceful-shutdown
[package]
name = "example-graceful-shutdown"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["tracing"] }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.6.1", features = ["timeout", "trace"] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-graceful-shutdown
//! kill or ctrl-c
//! ```

use std::time::Duration;

use axum::{routing::get, Router};
use tokio::net::TcpListener;
use tokio::signal;
use tokio::time::sleep;
use tower_http::timeout::TimeoutLayer;
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    // Enable tracing.
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!(
                    "{}=debug,tower_http=debug,axum=trace",
                    env!("CARGO_CRATE_NAME")
                )
                .into()
            }),
        )
        .with(tracing_subscriber::fmt::layer().without_time())
        .init();

    // Create a regular axum app.
    let app = Router::new()
        .route("/slow", get(|| sleep(Duration::from_secs(5))))
        .route("/forever", get(std::future::pending::<()>))
        .layer((
            TraceLayer::new_for_http(),
            // Graceful shutdown will wait for outstanding requests to complete. Add a timeout so
            // requests don't hang forever.
            TimeoutLayer::new(Duration::from_secs(10)),
        ));

    // Create a `TcpListener` using tokio.
    let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap();

    // Run the server with graceful shutdown
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .unwrap();
}

async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }
}
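"Graceful" here means two things: once the shutdown signal arrives, no new work is accepted, and everything already accepted is allowed to finish. The std-only sketch below models only those two rules, using one thread per request and a channel of events. `Event` and `run` are hypothetical. axum uses async tasks rather than threads, and the sketch does not model the `TimeoutLayer`, which in the real example is what stops a request like `/forever` from blocking shutdown indefinitely.

```rust
use std::sync::mpsc::{self, Receiver};
use std::thread::{self, JoinHandle};
use std::time::Duration;

/// Hypothetical events a toy server receives, in arrival order.
enum Event {
    Request { id: u32, work_ms: u64 },
    Shutdown,
}

/// Toy server loop: spawn a worker per request until `Shutdown` arrives,
/// then stop accepting and wait for every accepted request to finish.
/// Returns the ids of completed requests, in acceptance order.
fn run(events: Receiver<Event>) -> Vec<u32> {
    let mut in_flight: Vec<JoinHandle<u32>> = Vec::new();

    for event in events {
        match event {
            Event::Request { id, work_ms } => {
                in_flight.push(thread::spawn(move || {
                    // Simulate a slow handler such as `/slow`.
                    thread::sleep(Duration::from_millis(work_ms));
                    id
                }));
            }
            // Stop accepting: anything queued after this point is never handled.
            Event::Shutdown => break,
        }
    }

    // The "graceful" part: block until every in-flight request completes.
    in_flight
        .into_iter()
        .map(|handle| handle.join().expect("worker panicked"))
        .collect()
}

fn main() {
    let (tx, rx) = mpsc::channel();
    tx.send(Event::Request { id: 1, work_ms: 200 }).unwrap();
    tx.send(Event::Request { id: 2, work_ms: 10 }).unwrap();
    tx.send(Event::Shutdown).unwrap();
    tx.send(Event::Request { id: 3, work_ms: 0 }).unwrap(); // arrives too late

    println!("completed: {:?}", run(rx));
}
```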
jwt
[package]
name = "example-jwt"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
axum-extra = { path = "../../axum-extra", features = ["typed-header"] }
jsonwebtoken = "9.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Example JWT authorization/authentication.
//!
//! Run with
//!
//! ```not_rust
//! JWT_SECRET=secret cargo run -p example-jwt
//! ```

use axum::{
    extract::FromRequestParts,
    http::{request::Parts, StatusCode},
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, RequestPartsExt, Router,
};
use axum_extra::{
    headers::{authorization::Bearer, Authorization},
    TypedHeader,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fmt::Display;
use std::sync::LazyLock;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

// Quick instructions
//
// - get an authorization token:
//
// curl -s \
//     -w '\n' \
//     -H 'Content-Type: application/json' \
//     -d '{"client_id":"foo","client_secret":"bar"}' \
//     http://localhost:3000/authorize
//
// - visit the protected area using the authorized token
//
// curl -s \
//     -w '\n' \
//     -H 'Content-Type: application/json' \
//     -H 'Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjEwMDAwMDAwMDAwfQ.M3LAZmrzUkXDC1q5mSzFAs_kJrwuKz3jOoDmjJ0G4gM' \
//     http://localhost:3000/protected
//
// - try to visit the protected area using an invalid token
//
// curl -s \
//     -w '\n' \
//     -H 'Content-Type: application/json' \
//     -H 'Authorization: Bearer blahblahblah' \
//     http://localhost:3000/protected

static KEYS: LazyLock<Keys> = LazyLock::new(|| {
    let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
    Keys::new(secret.as_bytes())
});

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let app = Router::new()
        .route("/protected", get(protected))
        .route("/authorize", post(authorize));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn protected(claims: Claims) -> Result<String, AuthError> {
    // Send the protected data to the user
    Ok(format!(
        "Welcome to the protected area :)\nYour data:\n{claims}",
    ))
}

async fn authorize(Json(payload): Json<AuthPayload>) -> Result<Json<AuthBody>, AuthError> {
    // Check if the user sent the credentials
    if payload.client_id.is_empty() || payload.client_secret.is_empty() {
        return Err(AuthError::MissingCredentials);
    }
    // Here you can check the user credentials from a database
    if payload.client_id != "foo" || payload.client_secret != "bar" {
        return Err(AuthError::WrongCredentials);
    }
    let claims = Claims {
        sub: "b@b.com".to_owned(),
        company: "ACME".to_owned(),
        // Mandatory expiry time as UTC timestamp
        exp: 2000000000, // May 2033
    };
    // Create the authorization token
    let token = encode(&Header::default(), &claims, &KEYS.encoding)
        .map_err(|_| AuthError::TokenCreation)?;

    // Send the authorized token
    Ok(Json(AuthBody::new(token)))
}

impl Display for Claims {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Email: {}\nCompany: {}", self.sub, self.company)
    }
}

impl AuthBody {
    fn new(access_token: String) -> Self {
        Self {
            access_token,
            token_type: "Bearer".to_string(),
        }
    }
}

impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        // Extract the token from the authorization header
        let TypedHeader(Authorization(bearer)) = parts
            .extract::<TypedHeader<Authorization<Bearer>>>()
            .await
            .map_err(|_| AuthError::InvalidToken)?;
        // Decode the user data
        let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
            .map_err(|_| AuthError::InvalidToken)?;

        Ok(token_data.claims)
    }
}

impl IntoResponse for AuthError {
    fn into_response(self) -> Response {
        let (status, error_message) = match self {
            AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
            AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
            AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
            AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
        };
        let body = Json(json!({
            "error": error_message,
        }));
        (status, body).into_response()
    }
}

struct Keys {
    encoding: EncodingKey,
    decoding: DecodingKey,
}

impl Keys {
    fn new(secret: &[u8]) -> Self {
        Self {
            encoding: EncodingKey::from_secret(secret),
            decoding: DecodingKey::from_secret(secret),
        }
    }
}

#[derive(Debug, Serialize, Deserialize)]
struct Claims {
    sub: String,
    company: String,
    exp: usize,
}

#[derive(Debug, Serialize)]
struct AuthBody {
    access_token: String,
    token_type: String,
}

#[derive(Debug, Deserialize)]
struct AuthPayload {
    client_id: String,
    client_secret: String,
}

#[derive(Debug)]
enum AuthError {
    WrongCredentials,
    MissingCredentials,
    TokenCreation,
    InvalidToken,
}
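A compact JWT is three base64url-encoded segments (header, claims, signature) joined by dots. A string such as `blahblahblah` has no dots, so it cannot even be split into those parts.

The std-only sketch below looks only at that structure. It uses the hypothetical helpers `split_jwt` and `decode_base64url`. It does not verify the signature, which is the job `jsonwebtoken::decode` does in the example.

```rust
/// Splits a compact JWT into its header, claims and signature segments.
/// Returns `None` unless there are exactly three non-empty segments.
fn split_jwt(token: &str) -> Option<(&str, &str, &str)> {
    let mut parts = token.split('.');
    let header = parts.next()?;
    let claims = parts.next()?;
    let signature = parts.next()?;
    let well_formed = parts.next().is_none()
        && !header.is_empty()
        && !claims.is_empty()
        && !signature.is_empty();
    well_formed.then_some((header, claims, signature))
}

/// Maps one base64url character to its 6-bit value.
fn base64url_value(c: u8) -> Option<u32> {
    match c {
        b'A'..=b'Z' => Some(u32::from(c - b'A')),
        b'a'..=b'z' => Some(u32::from(c - b'a') + 26),
        b'0'..=b'9' => Some(u32::from(c - b'0') + 52),
        b'-' => Some(62),
        b'_' => Some(63),
        _ => None,
    }
}

/// Decodes unpadded base64url, the encoding used for JWT segments.
/// Returns `None` on any character outside the base64url alphabet.
fn decode_base64url(input: &str) -> Option<Vec<u8>> {
    let mut out = Vec::with_capacity(input.len() * 3 / 4);
    let mut buffer: u32 = 0; // bits read but not yet emitted
    let mut bits: u32 = 0; // number of valid bits held in `buffer`
    for &c in input.as_bytes() {
        buffer = (buffer << 6) | base64url_value(c)?;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push((buffer >> bits) as u8);
            buffer &= (1 << bits) - 1; // keep only the leftover bits
        }
    }
    Some(out)
}
```

For the sample token in the curl instructions, the header segment decodes to `{"typ":"JWT","alg":"HS256"}`. The claims segment decodes to `{"sub":"b@b.com","company":"ACME","exp":10000000000}`.

That `exp` differs from the `2000000000` that `authorize` puts into freshly issued tokens. This is a reminder that the claims are only encoded, not encrypted, and anyone holding a token can read them.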
http-proxy
[package]
name = "example-http-proxy"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
hyper = { version = "1", features = ["full"] }
hyper-util = "0.1.1"
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.5.2", features = ["make", "util"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! $ cargo run -p example-http-proxy
//! ```
//!
//! In another terminal:
//!
//! ```not_rust
//! $ curl -v -x "127.0.0.1:3000" https://tokio.rs
//! ```
//!
//! Example is based on <https://github.com/hyperium/hyper/blob/master/examples/http_proxy.rs>

use axum::{
    body::Body,
    extract::Request,
    http::{Method, StatusCode},
    response::{IntoResponse, Response},
    routing::get,
    Router,
};

use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::upgrade::Upgraded;
use std::net::SocketAddr;
use tokio::net::{TcpListener, TcpStream};
use tower::Service;
use tower::ServiceExt;

use hyper_util::rt::TokioIo;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=trace,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let router_svc = Router::new().route("/", get(|| async { "Hello, World!" }));

    let tower_service = tower::service_fn(move |req: Request<_>| {
        let router_svc = router_svc.clone();
        let req = req.map(Body::new);
        async move {
            if req.method() == Method::CONNECT {
                proxy(req).await
            } else {
                router_svc.oneshot(req).await.map_err(|err| match err {})
            }
        }
    });

    let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
        tower_service.clone().call(request)
    });

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("listening on {}", addr);

    let listener = TcpListener::bind(addr).await.unwrap();
    loop {
        let (stream, _) = listener.accept().await.unwrap();
        let io = TokioIo::new(stream);
        let hyper_service = hyper_service.clone();
        tokio::task::spawn(async move {
            if let Err(err) = http1::Builder::new()
                .preserve_header_case(true)
                .title_case_headers(true)
                .serve_connection(io, hyper_service)
                .with_upgrades()
                .await
            {
                println!("Failed to serve connection: {:?}", err);
            }
        });
    }
}

async fn proxy(req: Request) -> Result<Response, hyper::Error> {
    tracing::trace!(?req);

    if let Some(host_addr) = req.uri().authority().map(|auth| auth.to_string()) {
        tokio::task::spawn(async move {
            match hyper::upgrade::on(req).await {
                Ok(upgraded) => {
                    if let Err(e) = tunnel(upgraded, host_addr).await {
                        tracing::warn!("server io error: {}", e);
                    };
                }
                Err(e) => tracing::warn!("upgrade error: {}", e),
            }
        });

        Ok(Response::new(Body::empty()))
    } else {
        tracing::warn!("CONNECT host is not socket addr: {:?}", req.uri());
        Ok((
            StatusCode::BAD_REQUEST,
            "CONNECT must be to a socket address",
        )
            .into_response())
    }
}

async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> {
    let mut server = TcpStream::connect(addr).await?;
    let mut upgraded = TokioIo::new(upgraded);

    let (from_client, from_server) =
        tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?;

    tracing::debug!(
        "client wrote {} bytes and received {} bytes",
        from_client,
        from_server
    );

    Ok(())
}
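Requests other than `CONNECT` go to the router. A `CONNECT` request is tunnelled only if its URI has an authority such as `tokio.rs:443`; otherwise the proxy answers `400 Bad Request`. The example passes that authority string straight to `TcpStream::connect`.

If you want to inspect or restrict the target first, a std-only sketch of splitting it into host and port could look like the following. The name `parse_connect_target` is hypothetical, and the check is stricter than what the example does.

```rust
/// Splits a CONNECT target such as `tokio.rs:443` or `[::1]:8080` into host and port.
/// Returns `None` if the port is missing or not a valid `u16`, or if the host is empty.
fn parse_connect_target(authority: &str) -> Option<(&str, u16)> {
    // Split on the last ':' so that bracketed IPv6 hosts keep their inner colons.
    let (host, port) = authority.rsplit_once(':')?;
    let port: u16 = port.parse().ok()?;
    // Reject an empty host, and unbracketed IPv6 literals, which are ambiguous.
    let bracketed = host.starts_with('[') && host.ends_with(']');
    if host.is_empty() || (host.contains(':') && !bracketed) {
        return None;
    }
    Some((host, port))
}
```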
multipart-form
[package]
name = "example-multipart-form"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["multipart"] }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.6.1", features = ["limit", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-multipart-form
//! ```

use axum::{
    extract::{DefaultBodyLimit, Multipart},
    response::Html,
    routing::get,
    Router,
};
use tower_http::limit::RequestBodyLimitLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // build our application with some routes
    let app = Router::new()
        .route("/", get(show_form).post(accept_form))
        .layer(DefaultBodyLimit::disable())
        .layer(RequestBodyLimitLayer::new(
            250 * 1024 * 1024, /* 250mb */
        ))
        .layer(tower_http::trace::TraceLayer::new_for_http());

    // run it with hyper
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn show_form() -> Html<&'static str> {
    Html(
        r#"
        <!doctype html>
        <html>
            <head></head>
            <body>
                <form action="/" method="post" enctype="multipart/form-data">
                    <label>
                        Upload file:
                        <input type="file" name="file" multiple>
                    </label>

                    <input type="submit" value="Upload files">
                </form>
            </body>
        </html>
        "#,
    )
}

async fn accept_form(mut multipart: Multipart) {
    while let Some(field) = multipart.next_field().await.unwrap() {
        let name = field.name().unwrap().to_string();
        let file_name = field.file_name().unwrap().to_string();
        let content_type = field.content_type().unwrap().to_string();
        let data = field.bytes().await.unwrap();

        println!(
            "Length of `{name}` (`{file_name}`: `{content_type}`) is {} bytes",
            data.len()
        );
    }
}
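The router turns off axum's default body limit with `DefaultBodyLimit::disable()`. It then caps requests with tower-http's `RequestBodyLimitLayer` at `250 * 1024 * 1024` bytes, which is 262,144,000 bytes or 250 MiB.

The std-only model below shows the decision such a limit makes from a declared `Content-Length`. It uses the hypothetical `check_body_size` and is a simplification, not tower-http's implementation. A real limit also has to cope with bodies that declare no length, by counting bytes as they stream in.

```rust
/// 250 MiB, the same arithmetic as the example's `250 * 1024 * 1024`.
const UPLOAD_LIMIT: u64 = 250 * 1024 * 1024;

/// Simplified body-limit check based only on the declared `Content-Length`.
/// `Err(413)` stands for "413 Payload Too Large".
fn check_body_size(content_length: Option<u64>, limit: u64) -> Result<(), u16> {
    match content_length {
        // Declared size exceeds the limit: reject before reading the body.
        Some(len) if len > limit => Err(413),
        // Within the limit, or no length declared (a real implementation would
        // keep counting streamed bytes in that case).
        _ => Ok(()),
    }
}
```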
reqwest-response
[package]
name = "example-reqwest-response"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
reqwest = { version = "0.12", features = ["stream"] }
tokio = { version = "1.0", features = ["full"] }
tokio-stream = "0.1"
tower-http = { version = "0.6.1", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-reqwest-response
//! ```

use axum::{
    body::{Body, Bytes},
    extract::State,
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::get,
    Router,
};
use reqwest::Client;
use std::{convert::Infallible, time::Duration};
use tokio_stream::StreamExt;
use tower_http::trace::TraceLayer;
use tracing::Span;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let client = Client::new();

    let app = Router::new()
        .route("/", get(stream_reqwest_response))
        .route("/stream", get(stream_some_data))
        // Add some logging so we can see the streams going through
        .layer(TraceLayer::new_for_http().on_body_chunk(
            |chunk: &Bytes, _latency: Duration, _span: &Span| {
                tracing::debug!("streaming {} bytes", chunk.len());
            },
        ))
        .with_state(client);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn stream_reqwest_response(State(client): State<Client>) -> Response {
    let reqwest_response = match client.get("http://127.0.0.1:3000/stream").send().await {
        Ok(res) => res,
        Err(err) => {
            tracing::error!(%err, "request failed");
            return (StatusCode::BAD_REQUEST, Body::empty()).into_response();
        }
    };

    let mut response_builder = Response::builder().status(reqwest_response.status());
    *response_builder.headers_mut().unwrap() = reqwest_response.headers().clone();
    response_builder
        .body(Body::from_stream(reqwest_response.bytes_stream()))
        // This unwrap is fine because the builder only holds a status and headers
        // copied from an already valid response, so it cannot be in an error state
        .unwrap()
}

async fn stream_some_data() -> Body {
    let stream = tokio_stream::iter(0..5)
        .throttle(Duration::from_secs(1))
        .map(|n| n.to_string())
        .map(Ok::<_, Infallible>);
    Body::from_stream(stream)
}
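`/stream` yields the numbers 0 to 4 as strings, spaced out by `throttle`. The `/` handler fetches `/stream` through reqwest and streams it back, so a client reading `/` should receive the body `01234` in several chunks.

The sketch below is a blocking, std-only stand-in for that producer. The name `throttled_chunks` is hypothetical, and the delay is a parameter so the sketch runs quickly.

```rust
use std::thread;
use std::time::Duration;

/// Blocking stand-in for `stream_some_data`: yields "0" to "4",
/// pausing `delay` between consecutive items, roughly like `throttle`.
fn throttled_chunks(delay: Duration) -> impl Iterator<Item = String> {
    (0..5).map(move |n| {
        // Pause between items, but not before the first one.
        if n > 0 {
            thread::sleep(delay);
        }
        n.to_string()
    })
}
```

Concatenating the chunks gives the same text the streamed HTTP body carries. The difference is that `Body::from_stream` sends each chunk as soon as it is produced instead of collecting them first.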
parse-body-based-on-content-type
[package]
name = "example-parse-body-based-on-content-type"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Parses the request body as JSON or as a URL-encoded form, depending on the
//! request's `Content-Type` header.
//!
//! Run with
//!
//! ```not_rust
//! cargo run -p example-parse-body-based-on-content-type
//! ```

use axum::{
    extract::{FromRequest, Request},
    http::{header::CONTENT_TYPE, StatusCode},
    response::{IntoResponse, Response},
    routing::post,
    Form, Json, RequestExt, Router,
};
use serde::{Deserialize, Serialize};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let app = Router::new().route("/", post(handler));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

#[derive(Debug, Serialize, Deserialize)]
struct Payload {
    foo: String,
}

async fn handler(JsonOrForm(payload): JsonOrForm<Payload>) {
    dbg!(payload);
}

struct JsonOrForm<T>(T);

impl<S, T> FromRequest<S> for JsonOrForm<T>
where
    S: Send + Sync,
    Json<T>: FromRequest<()>,
    Form<T>: FromRequest<()>,
    T: 'static,
{
    type Rejection = Response;

    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
        let content_type_header = req.headers().get(CONTENT_TYPE);
        let content_type = content_type_header.and_then(|value| value.to_str().ok());

        if let Some(content_type) = content_type {
            if content_type.starts_with("application/json") {
                let Json(payload) = req.extract().await.map_err(IntoResponse::into_response)?;
                return Ok(Self(payload));
            }

            if content_type.starts_with("application/x-www-form-urlencoded") {
                let Form(payload) = req.extract().await.map_err(IntoResponse::into_response)?;
                return Ok(Self(payload));
            }
        }

        Err(StatusCode::UNSUPPORTED_MEDIA_TYPE.into_response())
    }
}
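`JsonOrForm` decides by prefix, so `application/json; charset=utf-8` still counts as JSON. A missing, non-UTF-8, or unrecognized `Content-Type` yields `415 Unsupported Media Type`.

That decision can be pulled out into a std-only function and unit tested without building requests. The names `BodyKind` and `classify_content_type` below are hypothetical.

```rust
/// Which body parser to use.
#[derive(Debug, PartialEq)]
enum BodyKind {
    Json,
    Form,
}

/// Mirrors the prefix checks in `JsonOrForm::from_request`.
/// `Err(415)` stands for "415 Unsupported Media Type".
fn classify_content_type(content_type: Option<&str>) -> Result<BodyKind, u16> {
    match content_type {
        Some(ct) if ct.starts_with("application/json") => Ok(BodyKind::Json),
        Some(ct) if ct.starts_with("application/x-www-form-urlencoded") => Ok(BodyKind::Form),
        // Missing header or any other media type.
        _ => Err(415),
    }
}
```

Like the example, this comparison is case-sensitive. A client sending `Application/JSON` would therefore get 415, even though media types are case-insensitive.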
unix-domain-socket
[package]
name = "example-unix-domain-socket"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
http-body-util = "0.1"
hyper = { version = "1.0.0", features = ["full"] }
hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] }
tokio = { version = "1.0", features = ["full"] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-unix-domain-socket
//! ```

#[cfg(unix)]
#[tokio::main]
async fn main() {
    unix::server().await;
}

#[cfg(not(unix))]
fn main() {
    println!("This example requires unix")
}

#[cfg(unix)]
mod unix {
    use axum::{
        body::Body,
        extract::connect_info::{self, ConnectInfo},
        http::{Method, Request, StatusCode},
        routing::get,
        serve::IncomingStream,
        Router,
    };
    use http_body_util::BodyExt;
    use hyper_util::rt::TokioIo;
    use std::{path::PathBuf, sync::Arc};
    use tokio::net::{unix::UCred, UnixListener, UnixStream};
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

    pub async fn server() {
        tracing_subscriber::registry()
            .with(
                tracing_subscriber::EnvFilter::try_from_default_env()
                    .unwrap_or_else(|_| "debug".into()),
            )
            .with(tracing_subscriber::fmt::layer())
            .init();

        let path = PathBuf::from("/tmp/axum/helloworld");

        let _ = tokio::fs::remove_file(&path).await;
        tokio::fs::create_dir_all(path.parent().unwrap())
            .await
            .unwrap();

        let uds = UnixListener::bind(path.clone()).unwrap();
        tokio::spawn(async move {
            let app = Router::new()
                .route("/", get(handler))
                .into_make_service_with_connect_info::<UdsConnectInfo>();

            axum::serve(uds, app).await.unwrap();
        });

        let stream = TokioIo::new(UnixStream::connect(path).await.unwrap());
        let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
        tokio::task::spawn(async move {
            if let Err(err) = conn.await {
                println!("Connection failed: {:?}", err);
            }
        });

        let request = Request::builder()
            .method(Method::GET)
            .uri("http://uri-doesnt-matter.com")
            .body(Body::empty())
            .unwrap();

        let response = sender.send_request(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.collect().await.unwrap().to_bytes();
        let body = String::from_utf8(body.to_vec()).unwrap();
        assert_eq!(body, "Hello, World!");
    }

    async fn handler(ConnectInfo(info): ConnectInfo<UdsConnectInfo>) -> &'static str {
        println!("new connection from `{:?}`", info);

        "Hello, World!"
    }

    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    struct UdsConnectInfo {
        peer_addr: Arc<tokio::net::unix::SocketAddr>,
        peer_cred: UCred,
    }

    impl connect_info::Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
        fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self {
            let peer_addr = stream.io().peer_addr().unwrap();
            let peer_cred = stream.io().peer_cred().unwrap();
            Self {
                peer_addr: Arc::new(peer_addr),
                peer_cred,
            }
        }
    }
}
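HTTP itself does not care whether it travels over TCP or a Unix domain socket. The client in the example just writes a request into the socket and reads the response back, which is why the host in the request URI "doesn't matter".

The std-only sketch below shows that round trip with `std::os::unix::net`. A hand-written HTTP/1.1 response stands in for axum and hyper. The socket path and the `round_trip` name are assumptions for illustration.

```rust
#[cfg(unix)]
fn round_trip() -> std::io::Result<String> {
    use std::io::{Read, Write};
    use std::os::unix::net::{UnixListener, UnixStream};
    use std::thread;

    // Hypothetical socket path; remove a stale socket file from an earlier run.
    let path = std::env::temp_dir().join(format!("uds-sketch-{}.sock", std::process::id()));
    let _ = std::fs::remove_file(&path);
    let listener = UnixListener::bind(&path)?;

    // Server: accept one connection, read until the end of the request head,
    // answer with a fixed response, then close the connection by dropping it.
    let server = thread::spawn(move || -> std::io::Result<()> {
        let (mut stream, _addr) = listener.accept()?;
        let mut request = Vec::new();
        let mut buf = [0u8; 1024];
        while !request.windows(4).any(|w| w == b"\r\n\r\n") {
            let n = stream.read(&mut buf)?;
            if n == 0 {
                break;
            }
            request.extend_from_slice(&buf[..n]);
        }
        let body = "Hello, World!";
        write!(
            stream,
            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
            body.len(),
            body
        )?;
        Ok(())
    });

    // Client: send a GET request and read until the server closes the socket.
    let mut client = UnixStream::connect(&path)?;
    client.write_all(b"GET / HTTP/1.1\r\nHost: uri-doesnt-matter.com\r\n\r\n")?;
    let mut response = String::new();
    client.read_to_string(&mut response)?;

    server.join().expect("server thread panicked")?;
    std::fs::remove_file(&path)?;
    Ok(response)
}
```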
tracing-aka-logging
[package]
name = "example-tracing-aka-logging"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["tracing"] }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.6.1", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-tracing-aka-logging
//! ```

use axum::{
    body::Bytes,
    extract::MatchedPath,
    http::{HeaderMap, Request},
    response::{Html, Response},
    routing::get,
    Router,
};
use std::time::Duration;
use tokio::net::TcpListener;
use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer};
use tracing::{info_span, Span};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                // axum logs rejections from built-in extractors with the `axum::rejection`
                // target, at `TRACE` level. `axum::rejection=trace` enables showing those events
                format!(
                    "{}=debug,tower_http=debug,axum::rejection=trace",
                    env!("CARGO_CRATE_NAME")
                )
                .into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // build our application with a route
    let app = Router::new()
        .route("/", get(handler))
        // `TraceLayer` is provided by tower-http so you have to add that as a dependency.
        // It provides good defaults but is also very customizable.
        //
        // See https://docs.rs/tower-http/0.1.1/tower_http/trace/index.html for more details.
        //
        // If you want to customize the behavior using closures here is how.
        .layer(
            TraceLayer::new_for_http()
                .make_span_with(|request: &Request<_>| {
                    // Log the matched route's path (with placeholders not filled in).
                    // Use request.uri() or OriginalUri if you want the real path.
                    let matched_path = request
                        .extensions()
                        .get::<MatchedPath>()
                        .map(MatchedPath::as_str);

                    info_span!(
                        "http_request",
                        method = ?request.method(),
                        matched_path,
                        some_other_field = tracing::field::Empty,
                    )
                })
                .on_request(|_request: &Request<_>, _span: &Span| {
                    // You can use `_span.record("some_other_field", value)` in one of these
                    // closures to attach a value to the initially empty field in the info_span
                    // created above.
                })
                .on_response(|_response: &Response, _latency: Duration, _span: &Span| {
                    // ...
                })
                .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {
                    // ...
                })
                .on_eos(
                    |_trailers: Option<&HeaderMap>, _stream_duration: Duration, _span: &Span| {
                        // ...
                    },
                )
                .on_failure(
                    |_error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
                        // ...
                    },
                ),
        );

    // run it
    let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn handler() -> Html<&'static str> {
    Html("<h1>Hello, World!</h1>")
}
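The span records `MatchedPath`, which is the route template, rather than the concrete request path. With a hypothetical route `/users/{id}`, requests to `/users/1` and `/users/2` would share one `matched_path` value. This keeps the number of distinct values small no matter how many IDs clients use.

Below is a std-only sketch of matching a path against templates in that `{placeholder}` style. The name `matched_template` is hypothetical; axum's real router works differently and supports much more.

```rust
/// Returns the first template that matches `path`. A `{name}` segment in a
/// template matches any single non-empty path segment; other segments must
/// match exactly, and the segment counts must be equal.
fn matched_template<'a>(templates: &[&'a str], path: &str) -> Option<&'a str> {
    templates.iter().copied().find(|template| {
        let mut t = template.split('/');
        let mut p = path.split('/');
        loop {
            match (t.next(), p.next()) {
                // Both ran out at the same time: every segment matched.
                (None, None) => return true,
                (Some(ts), Some(ps)) => {
                    let placeholder = ts.starts_with('{') && ts.ends_with('}');
                    if !(ts == ps || (placeholder && !ps.is_empty())) {
                        return false;
                    }
                }
                // Different number of segments.
                _ => return false,
            }
        }
    })
}
```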
tls-rustls
[package]
name = "example-tls-rustls"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
axum-extra = { path = "../../axum-extra" }
axum-server = { version = "0.7", features = ["tls-rustls"] }
tokio = { version = "1", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-tls-rustls
//! ```

#![allow(unused_imports)]

use axum::{
    handler::HandlerWithoutStateExt,
    http::{uri::Authority, StatusCode, Uri},
    response::Redirect,
    routing::get,
    BoxError, Router,
};
use axum_extra::extract::Host;
use axum_server::tls_rustls::RustlsConfig;
use std::{net::SocketAddr, path::PathBuf};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[allow(dead_code)]
#[derive(Clone, Copy)]
struct Ports {
    http: u16,
    https: u16,
}

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let ports = Ports {
        http: 7878,
        https: 3000,
    };
    // optional: spawn a second server to redirect http requests to this server
    tokio::spawn(redirect_http_to_https(ports));

    // configure certificate and private key used by https
    let config = RustlsConfig::from_pem_file(
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("self_signed_certs")
            .join("cert.pem"),
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("self_signed_certs")
            .join("key.pem"),
    )
    .await
    .unwrap();

    let app = Router::new().route("/", get(handler));

    // run https server
    let addr = SocketAddr::from(([127, 0, 0, 1], ports.https));
    tracing::debug!("listening on {}", addr);
    axum_server::bind_rustls(addr, config)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

#[allow(dead_code)]
async fn handler() -> &'static str {
    "Hello, World!"
}

#[allow(dead_code)]
async fn redirect_http_to_https(ports: Ports) {
    fn make_https(host: &str, uri: Uri, https_port: u16) -> Result<Uri, BoxError> {
        let mut parts = uri.into_parts();

        parts.scheme = Some(axum::http::uri::Scheme::HTTPS);

        if parts.path_and_query.is_none() {
            parts.path_and_query = Some("/".parse().unwrap());
        }

        let authority: Authority = host.parse()?;
        let bare_host = match authority.port() {
            Some(port_struct) => authority
                .as_str()
                .strip_suffix(port_struct.as_str())
                .unwrap()
                .strip_suffix(':')
                .unwrap(), // if authority.port() is Some(port) then we can be sure authority ends with :{port}
            None => authority.as_str(),
        };

        parts.authority = Some(format!("{bare_host}:{https_port}").parse()?);

        Ok(Uri::from_parts(parts)?)
    }

    let redirect = move |Host(host): Host, uri: Uri| async move {
        match make_https(&host, uri, ports.https) {
            Ok(uri) => Ok(Redirect::permanent(&uri.to_string())),
            Err(error) => {
                tracing::warn!(%error, "failed to convert URI to HTTPS");
                Err(StatusCode::BAD_REQUEST)
            }
        }
    };

    let addr = SocketAddr::from(([127, 0, 0, 1], ports.http));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, redirect.into_make_service())
        .await
        .unwrap();
}
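`make_https` keeps the request's host and path but switches the scheme to `https` and the port to the HTTPS port. With the ports above, a request to `http://localhost:7878/foo?x=1` is redirected to `https://localhost:3000/foo?x=1`.

Below is the same string manipulation without the `http` crate, as a std-only sketch. The name `https_redirect_target` is hypothetical, and it validates less than parsing into `Authority` does.

```rust
/// Builds the HTTPS URL to redirect to from a `Host` header value and the
/// request's path and query (if any), replacing any explicit port.
fn https_redirect_target(host: &str, path_and_query: Option<&str>, https_port: u16) -> String {
    // Drop an explicit numeric port: `localhost:7878` -> `localhost`,
    // `[::1]:7878` -> `[::1]`. A bracketed IPv6 host without a port is kept as-is.
    let bare_host = match host.rsplit_once(':') {
        Some((h, port))
            if !port.is_empty()
                && port.bytes().all(|b| b.is_ascii_digit())
                && (!h.contains(':') || h.ends_with(']')) =>
        {
            h
        }
        _ => host,
    };
    // Like `make_https`, fall back to "/" when there is no path.
    let path = path_and_query.unwrap_or("/");
    format!("https://{bare_host}:{https_port}{path}")
}
```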
There are two more files: self_signed_certs/cert.pem and self_signed_certs/key.pem.
tls-graceful-shutdown
//! Run with
//!
//! ```not_rust
//! cargo run -p example-tls-graceful-shutdown
//! ```

use axum::{
    handler::HandlerWithoutStateExt,
    http::{uri::Authority, StatusCode, Uri},
    response::Redirect,
    routing::get,
    BoxError, Router,
};
use axum_extra::extract::Host;
use axum_server::tls_rustls::RustlsConfig;
use std::{future::Future, net::SocketAddr, path::PathBuf, time::Duration};
use tokio::signal;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[derive(Clone, Copy)]
struct Ports {
    http: u16,
    https: u16,
}

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let ports = Ports {
        http: 7878,
        https: 3000,
    };

    // Create a handle for our TLS server so the shutdown signal can shut down all servers
    let handle = axum_server::Handle::new();
    // Save the future so the same signal also shuts down the redirect server
    let shutdown_future = shutdown_signal(handle.clone());

    // optional: spawn a second server to redirect http requests to this server
    tokio::spawn(redirect_http_to_https(ports, shutdown_future));

    // configure certificate and private key used by https
    let config = RustlsConfig::from_pem_file(
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("self_signed_certs")
            .join("cert.pem"),
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("self_signed_certs")
            .join("key.pem"),
    )
    .await
    .unwrap();

    let app = Router::new().route("/", get(handler));

    // run https server
    let addr = SocketAddr::from(([127, 0, 0, 1], ports.https));
    tracing::debug!("listening on {addr}");
    axum_server::bind_rustls(addr, config)
        .handle(handle)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

async fn shutdown_signal(handle: axum_server::Handle) {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("Received termination signal shutting down");
    // 10 secs is how long docker will wait to force shutdown
    handle.graceful_shutdown(Some(Duration::from_secs(10)));
}

async fn handler() -> &'static str {
    "Hello, World!"
}

async fn redirect_http_to_https<F>(ports: Ports, signal: F)
where
    F: Future<Output = ()> + Send + 'static,
{
    fn make_https(host: &str, uri: Uri, https_port: u16) -> Result<Uri, BoxError> {
        let mut parts = uri.into_parts();

        parts.scheme = Some(axum::http::uri::Scheme::HTTPS);

        if parts.path_and_query.is_none() {
            parts.path_and_query = Some("/".parse().unwrap());
        }

        let authority: Authority = host.parse()?;
        let bare_host = match authority.port() {
            Some(port_struct) => authority
                .as_str()
                .strip_suffix(port_struct.as_str())
                .unwrap()
                .strip_suffix(':')
                .unwrap(), // if authority.port() is Some(port) then we can be sure authority ends with :{port}
            None => authority.as_str(),
        };

        parts.authority = Some(format!("{bare_host}:{https_port}").parse()?);

        Ok(Uri::from_parts(parts)?)
    }

    let redirect = move |Host(host): Host, uri: Uri| async move {
        match make_https(&host, uri, ports.https) {
            Ok(uri) => Ok(Redirect::permanent(&uri.to_string())),
            Err(error) => {
                tracing::warn!(%error, "failed to convert URI to HTTPS");
                Err(StatusCode::BAD_REQUEST)
            }
        }
    };

    let addr = SocketAddr::from(([127, 0, 0, 1], ports.http));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    tracing::debug!("listening on {addr}");
    axum::serve(listener, redirect.into_make_service())
        .with_graceful_shutdown(signal)
        .await
        .unwrap();
}
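`shutdown_signal` waits for whichever comes first, Ctrl+C or SIGTERM. It then asks the TLS server's `Handle` to shut down gracefully with a 10-second grace period; as far as we understand axum-server, in-flight connections get that long before they are closed. The redirect server receives the same future through `with_graceful_shutdown`.

The std-only model below shows only the "wait, but not past a deadline" part. It uses threads in place of connections, and the name `drain_with_deadline` is hypothetical; this is not how axum-server implements it.

```rust
use std::sync::mpsc;
use std::thread;
use std::time::{Duration, Instant};

/// Starts one thread per simulated in-flight request (each takes the given
/// duration), then waits for them for at most `grace`.
/// Returns (finished within the grace period, still running at the deadline).
fn drain_with_deadline(request_durations: &[Duration], grace: Duration) -> (usize, usize) {
    let (done_tx, done_rx) = mpsc::channel();
    for &duration in request_durations {
        let done_tx = done_tx.clone();
        thread::spawn(move || {
            thread::sleep(duration);
            // After the deadline nobody is listening any more; ignore that error.
            let _ = done_tx.send(());
        });
    }
    // Drop the original sender so the channel disconnects once all requests finish.
    drop(done_tx);

    let deadline = Instant::now() + grace;
    let mut finished = 0;
    while finished < request_durations.len() {
        let remaining = deadline.saturating_duration_since(Instant::now());
        if remaining.is_zero() {
            break;
        }
        match done_rx.recv_timeout(remaining) {
            Ok(()) => finished += 1,
            // Deadline reached (or every sender gone): stop waiting.
            Err(_) => break,
        }
    }
    (finished, request_durations.len() - finished)
}
```

Requests that are still running when `drain_with_deadline` returns correspond to connections a server would force-close. This is why the grace period is chosen to match how long the orchestrator (here, Docker) waits before killing the process.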
stream-to-file
[package]
name = "example-stream-to-file"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["multipart"] }
futures = "0.3"
tokio = { version = "1.0", features = ["full"] }
tokio-util = { version = "0.7", features = ["io"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-stream-to-file
//! ```

use axum::{
    body::Bytes,
    extract::{Multipart, Path, Request},
    http::StatusCode,
    response::{Html, Redirect},
    routing::{get, post},
    BoxError, Router,
};
use futures::{Stream, TryStreamExt};
use std::io;
use tokio::{fs::File, io::BufWriter};
use tokio_util::io::StreamReader;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

const UPLOADS_DIRECTORY: &str = "uploads";

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // save files to a separate directory to avoid overwriting files in the current directory
    tokio::fs::create_dir(UPLOADS_DIRECTORY)
        .await
        .expect("failed to create `uploads` directory");

    let app = Router::new()
        .route("/", get(show_form).post(accept_form))
        .route("/file/{file_name}", post(save_request_body));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

// Handler that streams the request body to a file.
//
// POST'ing to `/file/foo.txt` will create a file called `foo.txt`.
async fn save_request_body(
    Path(file_name): Path<String>,
    request: Request,
) -> Result<(), (StatusCode, String)> {
    stream_to_file(&file_name, request.into_body().into_data_stream()).await
}

// Handler that returns HTML for a multipart form.
async fn show_form() -> Html<&'static str> {
    Html(
        r#"
        <!doctype html>
        <html>
            <head>
                <title>Upload something!</title>
            </head>
            <body>
                <form action="/" method="post" enctype="multipart/form-data">
                    <div>
                        <label>
                            Upload file:
                            <input type="file" name="file" multiple>
                        </label>
                    </div>

                    <div>
                        <input type="submit" value="Upload files">
                    </div>
                </form>
            </body>
        </html>
        "#,
    )
}

// Handler that accepts a multipart form upload and streams each field to a file.
async fn accept_form(mut multipart: Multipart) -> Result<Redirect, (StatusCode, String)> {
    while let Ok(Some(field)) = multipart.next_field().await {
        let file_name = if let Some(file_name) = field.file_name() {
            file_name.to_owned()
        } else {
            continue;
        };

        stream_to_file(&file_name, field).await?;
    }

    Ok(Redirect::to("/"))
}

// Save a `Stream` to a file
async fn stream_to_file<S, E>(path: &str, stream: S) -> Result<(), (StatusCode, String)>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: Into<BoxError>,
{
    if !path_is_valid(path) {
        return Err((StatusCode::BAD_REQUEST, "Invalid path".to_owned()));
    }

    async {
        // Convert the stream into an `AsyncRead`.
        let body_with_io_error = stream.map_err(io::Error::other);
        let body_reader = StreamReader::new(body_with_io_error);
        futures::pin_mut!(body_reader);

        // Create the file. `File` implements `AsyncWrite`.
        let path = std::path::Path::new(UPLOADS_DIRECTORY).join(path);
        let mut file = BufWriter::new(File::create(path).await?);

        // Copy the body into the file.
        tokio::io::copy(&mut body_reader, &mut file).await?;

        Ok::<_, io::Error>(())
    }
    .await
    .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))
}

// to prevent directory traversal attacks we ensure the path consists of exactly one normal
// component
fn path_is_valid(path: &str) -> bool {
    let path = std::path::Path::new(path);
    let mut components = path.components().peekable();

    if let Some(first) = components.peek() {
        if !matches!(first, std::path::Component::Normal(_)) {
            return false;
        }
    }

    components.count() == 1
}
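`stream_to_file` does three things. It converts the stream's error type into `io::Error`, wraps the stream in a `StreamReader` so it can be read like a file, and copies it into a buffered file.

The same shape in blocking, std-only Rust is below. It takes an iterator of chunks instead of a `Stream`, and any `Write` instead of a `File`. The name `write_chunks` is hypothetical.

```rust
use std::error::Error;
use std::io::{self, BufWriter, Write};

/// Writes every chunk to `writer` through a `BufWriter`, converting the chunk
/// error type into `io::Error` the same way the example does with
/// `io::Error::other`. Returns the number of bytes written.
fn write_chunks<I, W, E>(chunks: I, writer: W) -> io::Result<u64>
where
    I: IntoIterator<Item = Result<Vec<u8>, E>>,
    W: Write,
    E: Into<Box<dyn Error + Send + Sync>>,
{
    let mut writer = BufWriter::new(writer);
    let mut written = 0u64;
    for chunk in chunks {
        // Stop at the first failed chunk, surfacing it as an `io::Error`.
        let chunk = chunk.map_err(io::Error::other)?;
        writer.write_all(&chunk)?;
        written += chunk.len() as u64;
    }
    // Flush explicitly so write errors are reported instead of lost on drop.
    writer.flush()?;
    Ok(written)
}
```

Because the target is any `Write`, the function can be exercised against a `Vec<u8>` instead of the filesystem. The real handler also runs `path_is_valid` on the file name before creating anything on disk.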
static-file-server
[package]
name = "example-static-file-server"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.5.2", features = ["util"] }
tower-http = { version = "0.6.1", features = ["fs", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dev-dependencies]
http-body-util = "0.1.0"
//! Run with
//!
//! ```not_rust
//! cargo run -p example-static-file-server
//! ```

#[cfg(test)]
mod tests;

use axum::{
    extract::Request, handler::HandlerWithoutStateExt, http::StatusCode, routing::get, Router,
};
use std::net::SocketAddr;
use tower::ServiceExt;
use tower_http::{
    services::{ServeDir, ServeFile},
    trace::TraceLayer,
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    tokio::join!(
        serve(using_serve_dir(), 3001),
        serve(using_serve_dir_with_assets_fallback(), 3002),
        serve(using_serve_dir_only_from_root_via_fallback(), 3003),
        serve(using_serve_dir_with_handler_as_service(), 3004),
        serve(two_serve_dirs(), 3005),
        serve(calling_serve_dir_from_a_handler(), 3006),
        serve(using_serve_file_from_a_route(), 3307),
    );
}

fn using_serve_dir() -> Router {
    // serve the files in the "assets" directory under `/assets`
    Router::new().nest_service("/assets", ServeDir::new("assets"))
}

fn using_serve_dir_with_assets_fallback() -> Router {
    // `ServeDir` allows setting a fallback if an asset is not found
    // so with this `GET /assets/doesnt-exist.jpg` will return `index.html`
    // rather than a 404
    let serve_dir = ServeDir::new("assets").not_found_service(ServeFile::new("assets/index.html"));

    Router::new()
        .route("/foo", get(|| async { "Hi from /foo" }))
        .nest_service("/assets", serve_dir.clone())
        .fallback_service(serve_dir)
}

fn using_serve_dir_only_from_root_via_fallback() -> Router {
    // you can also serve the assets directly from the root (not nested under `/assets`)
    // by only setting a `ServeDir` as the fallback
    let serve_dir = ServeDir::new("assets").not_found_service(ServeFile::new("assets/index.html"));

    Router::new()
        .route("/foo", get(|| async { "Hi from /foo" }))
        .fallback_service(serve_dir)
}

fn using_serve_dir_with_handler_as_service() -> Router {
    async fn handle_404() -> (StatusCode, &'static str) {
        (StatusCode::NOT_FOUND, "Not found")
    }

    // you can convert a handler function to a service
    let service = handle_404.into_service();

    let serve_dir = ServeDir::new("assets").not_found_service(service);

    Router::new()
        .route("/foo", get(|| async { "Hi from /foo" }))
        .fallback_service(serve_dir)
}

fn two_serve_dirs() -> Router {
    // you can also have two `ServeDir`s nested at different paths
    let serve_dir_from_assets = ServeDir::new("assets");
    let serve_dir_from_dist = ServeDir::new("dist");

    Router::new()
        .nest_service("/assets", serve_dir_from_assets)
        .nest_service("/dist", serve_dir_from_dist)
}

#[allow(clippy::let_and_return)]
fn calling_serve_dir_from_a_handler() -> Router {
    // via `tower::Service::call`, or more conveniently `tower::ServiceExt::oneshot` you can
    // call `ServeDir` yourself from a handler
    Router::new().nest_service(
        "/foo",
        get(|request: Request| async {
            let service = ServeDir::new("assets");
            let result = service.oneshot(request).await;
            result
        }),
    )
}

fn using_serve_file_from_a_route() -> Router {
    Router::new().route_service("/foo", ServeFile::new("assets/index.html"))
}

async fn serve(app: Router, port: u16) {
    let addr = SocketAddr::from(([127, 0, 0, 1], port));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app.layer(TraceLayer::new_for_http()))
        .await
        .unwrap();
}
Hi from index.html
console.log("Hello, World!");
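`ServeDir::new("assets")` maps the request path onto a file below the `assets` directory and, when a `not_found_service` is configured, hands misses to that fallback. The std-only sketch below models that lookup so the idea can be tried without running a server. It is not tower-http's implementation: the hypothetical `resolve` and `lookup` helpers, the rejection of `..` segments, and the in-memory list of existing files are assumptions made for illustration.

```rust
use std::path::{Component, Path, PathBuf};

/// Hypothetical helper: map a URL path such as "/script.js" to a path under `root`.
/// Returns `None` if the path tries to escape `root` (for example via "..").
fn resolve(root: &Path, url_path: &str) -> Option<PathBuf> {
    let mut resolved = root.to_path_buf();
    // Treat the URL path as relative by trimming the leading '/'.
    for component in Path::new(url_path.trim_start_matches('/')).components() {
        match component {
            Component::Normal(segment) => resolved.push(segment),
            Component::CurDir => {}
            // "..", a root or a Windows prefix would leave `root`, so refuse them.
            _ => return None,
        }
    }
    Some(resolved)
}

/// Hypothetical lookup: serve the resolved file if it is in `existing`,
/// otherwise fall back to `root/index.html`, like `not_found_service` does.
fn lookup(root: &Path, url_path: &str, existing: &[PathBuf]) -> PathBuf {
    match resolve(root, url_path) {
        Some(path) if existing.contains(&path) => path,
        _ => root.join("index.html"),
    }
}

fn main() {
    let root = Path::new("assets");
    let existing = vec![root.join("index.html"), root.join("script.js")];

    for url in ["/script.js", "/doesnt-exist.jpg", "/../Cargo.toml"] {
        println!("{url} -> {}", lookup(root, url, &existing).display());
    }
}
```

A real `ServeDir` also checks the file system, sets `Content-Type`, and handles directories and range requests; the sketch only captures the path mapping and the fallback decision.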
sse
[package]
name = "example-sse"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
axum-extra = { path = "../../axum-extra", features = ["typed-header"] }
futures = "0.3"
headers = "0.4"
tokio = { version = "1.0", features = ["full"] }
tokio-stream = "0.1"
tower-http = { version = "0.6.1", features = ["fs", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dev-dependencies]
eventsource-stream = "0.2"
reqwest = { version = "0.12", features = ["stream"] }
reqwest-eventsource = "0.6"
//! Run with
//!
//! ```not_rust
//! cargo run -p example-sse
//! ```
//! Test with
//! ```not_rust
//! cargo test -p example-sse
//! ```

use axum::{
    response::sse::{Event, Sse},
    routing::get,
    Router,
};
use axum_extra::TypedHeader;
use futures::stream::{self, Stream};
use std::{convert::Infallible, path::PathBuf, time::Duration};
use tokio_stream::StreamExt as _;
use tower_http::{services::ServeDir, trace::TraceLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // build our application
    let app = app();

    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

fn app() -> Router {
    let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets");
    let static_files_service = ServeDir::new(assets_dir).append_index_html_on_directories(true);
    // build our application with a route
    Router::new()
        .fallback_service(static_files_service)
        .route("/sse", get(sse_handler))
        .layer(TraceLayer::new_for_http())
}

async fn sse_handler(
    TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    println!("`{}` connected", user_agent.as_str());

    // A `Stream` that repeats an event every second
    //
    // You can also create streams from tokio channels using the wrappers in
    // https://docs.rs/tokio-stream
    let stream = stream::repeat_with(|| Event::default().data("hi!"))
        .map(Ok)
        .throttle(Duration::from_secs(1));

    Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::new()
            .interval(Duration::from_secs(1))
            .text("keep-alive-text"),
    )
}

#[cfg(test)]
mod tests {
    use eventsource_stream::Eventsource;
    use tokio::net::TcpListener;

    use super::*;

    #[tokio::test]
    async fn integration_test() {
        // A helper function that spawns our application in the background
        async fn spawn_app(host: impl Into<String>) -> String {
            let host = host.into();
            // Bind to localhost at port 0, which lets the OS assign an available port to us
            let listener = TcpListener::bind(format!("{}:0", host)).await.unwrap();
            // Retrieve the port assigned to us by the OS
            let port = listener.local_addr().unwrap().port();
            tokio::spawn(async {
                axum::serve(listener, app()).await.unwrap();
            });
            // Returns the address (e.g. http://127.0.0.1:{random_port})
            format!("http://{}:{}", host, port)
        }
        let listening_url = spawn_app("127.0.0.1").await;

        let mut event_stream = reqwest::Client::new()
            .get(format!("{}/sse", listening_url))
            .header("User-Agent", "integration_test")
            .send()
            .await
            .unwrap()
            .bytes_stream()
            .eventsource()
            .take(1);

        let mut event_data: Vec<String> = vec![];
        while let Some(event) = event_stream.next().await {
            match event {
                Ok(event) => {
                    // break the loop at the end of the SSE stream
                    if event.data == "[DONE]" {
                        break;
                    }

                    event_data.push(event.data);
                }
                Err(_) => {
                    panic!("Error in event stream");
                }
            }
        }

        assert!(event_data[0] == "hi!");
    }
}
<script src='script.js'></script>
var eventSource = new EventSource('sse');
eventSource.onmessage = function(event) {
console.log('Message from server ', event.data);
}
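On the wire, an SSE response is plain text in the `text/event-stream` format. Each event is one or more `data:` lines terminated by a blank line. Lines starting with `:` are comments that clients ignore, which makes them useful as keep-alives. The browser's `EventSource` in `script.js` receives each event's data through `onmessage`. The sketch below serializes events by hand to show that format. `format_event` and `format_comment` are hypothetical helpers, not axum's encoder, and the exact bytes axum writes for `KeepAlive::text` may differ.

```rust
/// Hypothetical helper: serialize one SSE event carrying `data`.
/// Each line of the payload becomes its own `data:` field, and a blank
/// line terminates the event, as the text/event-stream format requires.
fn format_event(data: &str) -> String {
    let mut out = String::new();
    for line in data.lines() {
        out.push_str("data: ");
        out.push_str(line);
        out.push('\n');
    }
    out.push('\n');
    out
}

/// Hypothetical helper: a comment line, which clients ignore (handy as a keep-alive).
fn format_comment(text: &str) -> String {
    format!(":{text}\n\n")
}

fn main() {
    print!("{}", format_event("hi!"));
    print!("{}", format_comment("keep-alive-text"));
    print!("{}", format_event("line one\nline two"));
}
```

A receiving client joins the `data:` lines of one event with `\n`, so a multi-line payload arrives in `event.data` as a single string.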
simple-router-wasm
[package]
name = "example-simple-router-wasm"
version = "0.1.0"
edition = "2018"
publish = false
[dependencies]
# `default-features = false` to not depend on tokio features which don't support wasm
# you can still pull in tokio manually and only add features that tokio supports for wasm
axum = { path = "../../axum", default-features = false }
# we don't strictly use axum-extra in this example but want to make sure that it
# works in wasm as well
axum-extra = { path = "../../axum-extra", default-features = false }
futures-executor = "0.3.21"
http = "1.0.0"
tower-service = "0.3.1"
[package.metadata.cargo-machete]
ignored = ["axum-extra"]
//! Run with
//!
//! ```not_rust
//! cargo run -p example-simple-router-wasm
//! ```
//!
//! This example shows what using axum in a wasm context might look like. This example should
//! always compile with `--target wasm32-unknown-unknown`.
//!
//! [`mio`](https://docs.rs/mio/latest/mio/index.html), tokio's IO layer, does not support the
//! `wasm32-unknown-unknown` target which is why this crate requires `default-features = false`
//! for axum.
//!
//! Most serverless runtimes expect an exported function that takes in a single request and returns
//! a single response, much like axum's `Handler` trait. In this example, the handler function is
//! `app` with `main` acting as the serverless runtime which originally receives the request and
//! calls the app function.
//!
//! We can use axum's routing, extractors, tower services, and everything else to implement
//! our serverless function, even though we are running axum in a wasm context.

use axum::{
    response::{Html, Response},
    routing::get,
    Router,
};
use futures_executor::block_on;
use http::Request;
use tower_service::Service;

fn main() {
    let request: Request<String> = Request::builder()
        .uri("https://serverless.example/api/")
        .body("Some Body Data".into())
        .unwrap();

    let response: Response = block_on(app(request));
    assert_eq!(200, response.status());
}

#[allow(clippy::let_and_return)]
async fn app(request: Request<String>) -> Response {
    let mut router = Router::new().route("/api/", get(index));
    let response = router.call(request).await.unwrap();
    response
}

async fn index() -> Html<&'static str> {
    Html("<h1>Hello, World!</h1>")
}
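Here `main` plays the serverless runtime and drives the async `app` function to completion with `futures_executor::block_on`. To show what such an executor does, the sketch below implements a minimal std-only `block_on`. It polls the future and parks the thread until the future's waker unparks it. It assumes a native target: `std::thread::park` is not something to rely on under `wasm32-unknown-unknown`, so this is an illustration, not a replacement for `futures-executor`. The `app` function below is a hypothetical stand-in that returns a bare status code.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

/// Waker that unparks the thread blocked inside `block_on`.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

/// Minimal executor: poll the future and park the current thread until woken.
fn block_on<F: Future>(future: F) -> F::Output {
    let mut future = pin!(future);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match future.as_mut().poll(&mut cx) {
            Poll::Ready(output) => return output,
            // A spurious wake-up only causes one extra poll, which is harmless.
            Poll::Pending => thread::park(),
        }
    }
}

/// Hypothetical stand-in for the example's `app`: 200 for "/api/", 404 otherwise.
async fn app(path: &str) -> u16 {
    if path == "/api/" { 200 } else { 404 }
}

fn main() {
    assert_eq!(block_on(app("/api/")), 200);
    println!("status: {}", block_on(app("/missing")));
}
```

`futures-executor` follows the same poll-then-wait idea but adds a local task pool and other conveniences.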
serve-with-hyper
[package]
name = "example-serve-with-hyper"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
hyper = { version = "1.0", features = [] }
hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.5.2", features = ["util"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-serve-with-hyper
//! ```
//!
//! This example shows how to run axum using hyper's low level API.
//!
//! The [hyper-util] crate exists to provide high level utilities but it's still in early stages of
//! development.
//!
//! [hyper-util]: https://crates.io/crates/hyper-util

use std::convert::Infallible;
use std::net::SocketAddr;

use axum::extract::ConnectInfo;
use axum::{extract::Request, routing::get, Router};
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server;
use tokio::net::TcpListener;
use tower::{Service, ServiceExt};

#[tokio::main]
async fn main() {
    tokio::join!(serve_plain(), serve_with_connect_info());
}

async fn serve_plain() {
    // Create a regular axum app.
    let app = Router::new().route("/", get(|| async { "Hello!" }));

    // Create a `TcpListener` using tokio.
    let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap();

    // Continuously accept new connections.
    loop {
        // In this example we discard the remote address. See `fn serve_with_connect_info` for how
        // to expose that.
        let (socket, _remote_addr) = listener.accept().await.unwrap();

        // We don't need to call `poll_ready` because `Router` is always ready.
        let tower_service = app.clone();

        // Spawn a task to handle the connection. That way we can handle multiple connections
        // concurrently.
        tokio::spawn(async move {
            // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
            // `TokioIo` converts between them.
            let socket = TokioIo::new(socket);

            // Hyper also has its own `Service` trait and doesn't use tower. We can use
            // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
            // `tower::Service::call`.
            let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
                // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
                // tower's `Service` requires `&mut self`.
                //
                // We don't need to call `poll_ready` since `Router` is always ready.
                tower_service.clone().call(request)
            });

            // `server::conn::auto::Builder` supports both http1 and http2.
            //
            // `TokioExecutor` tells hyper to use `tokio::spawn` to spawn tasks.
            if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new())
                // `serve_connection_with_upgrades` is required for websockets. If you don't need
                // that you can use `serve_connection` instead.
                .serve_connection_with_upgrades(socket, hyper_service)
                .await
            {
                eprintln!("failed to serve connection: {err:#}");
            }
        });
    }
}

// Similar setup to `serve_plain` but captures the remote address and exposes it through the
// `ConnectInfo` extractor
async fn serve_with_connect_info() {
    let app = Router::new().route(
        "/",
        get(
            |ConnectInfo(remote_addr): ConnectInfo<SocketAddr>| async move {
                format!("Hello {remote_addr}")
            },
        ),
    );

    let mut make_service = app.into_make_service_with_connect_info::<SocketAddr>();

    let listener = TcpListener::bind("0.0.0.0:3001").await.unwrap();

    loop {
        let (socket, remote_addr) = listener.accept().await.unwrap();

        // We don't need to call `poll_ready` because `IntoMakeServiceWithConnectInfo` is always
        // ready.
        let tower_service = unwrap_infallible(make_service.call(remote_addr).await);

        tokio::spawn(async move {
            let socket = TokioIo::new(socket);

            let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
                tower_service.clone().oneshot(request)
            });

            if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new())
                .serve_connection_with_upgrades(socket, hyper_service)
                .await
            {
                eprintln!("failed to serve connection: {err:#}");
            }
        });
    }
}

fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    match result {
        Ok(value) => value,
        Err(err) => match err {},
    }
}
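`unwrap_infallible` works because `Infallible` is an enum with no variants. `match err {}` has no arms, yet it is exhaustive, and the compiler accepts it wherever a value of any type is expected. The `Err` branch can therefore never run and needs no `panic!`. The std-only snippet below demonstrates the same pattern in isolation; `always_ok` is a hypothetical function whose error type promises it never fails.

```rust
use std::convert::Infallible;

fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    match result {
        Ok(value) => value,
        // `Infallible` has no variants, so this empty match is exhaustive
        // and can stand in for a value of type `T`.
        Err(err) => match err {},
    }
}

/// Hypothetical function that can never fail.
fn always_ok(n: u32) -> Result<u32, Infallible> {
    Ok(n * 2)
}

fn main() {
    let doubled = unwrap_infallible(always_ok(21));
    println!("{doubled}");
}
```

Compared with `.unwrap()`, this version documents in the type system that no panic path exists.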
routes-and-handlers-close-together
[package]
name = "example-routes-and-handlers-close-together"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-routes-and-handlers-close-together
//! ```

use axum::{
    routing::{get, post, MethodRouter},
    Router,
};

#[tokio::main]
async fn main() {
    let app = Router::new()
        .merge(root())
        .merge(get_foo())
        .merge(post_foo());

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

fn root() -> Router {
    async fn handler() -> &'static str {
        "Hello, World!"
    }

    route("/", get(handler))
}

fn get_foo() -> Router {
    async fn handler() -> &'static str {
        "Hi from `GET /foo`"
    }

    route("/foo", get(handler))
}

fn post_foo() -> Router {
    async fn handler() -> &'static str {
        "Hi from `POST /foo`"
    }

    route("/foo", post(handler))
}

fn route(path: &str, method_router: MethodRouter<()>) -> Router {
    Router::new().route(path, method_router)
}
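`Router::merge` combines the routes of several routers. This is why `get_foo()` and `post_foo()` can each register `/foo`, one for `GET` and one for `POST`. The toy model below keys handlers by (method, path) to illustrate that idea using only std. It is not axum's implementation. The `ToyRouter` type and its rule of panicking on a duplicate (method, path) pair are assumptions made for this sketch.

```rust
use std::collections::HashMap;

type Handler = fn() -> &'static str;

/// Toy router: handlers keyed by (HTTP method, path).
#[derive(Default)]
struct ToyRouter {
    routes: HashMap<(&'static str, &'static str), Handler>,
}

impl ToyRouter {
    fn route(mut self, method: &'static str, path: &'static str, handler: Handler) -> Self {
        let previous = self.routes.insert((method, path), handler);
        // Registering the same method twice on one path is treated as a bug.
        assert!(previous.is_none(), "duplicate route: {method} {path}");
        self
    }

    fn merge(mut self, other: ToyRouter) -> Self {
        for ((method, path), handler) in other.routes {
            self = self.route(method, path, handler);
        }
        self
    }

    fn call(&self, method: &'static str, path: &'static str) -> Option<&'static str> {
        self.routes.get(&(method, path)).map(|handler| handler())
    }
}

fn main() {
    let root = ToyRouter::default().route("GET", "/", || "Hello, World!");
    let get_foo = ToyRouter::default().route("GET", "/foo", || "Hi from `GET /foo`");
    let post_foo = ToyRouter::default().route("POST", "/foo", || "Hi from `POST /foo`");

    let app = ToyRouter::default().merge(root).merge(get_foo).merge(post_foo);

    println!("{:?}", app.call("POST", "/foo"));
}
```

Unlike this toy, axum distinguishes a known path with an unregistered method (`405 Method Not Allowed`) from an unknown path (`404 Not Found`).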
reverse-proxy
[package]
name = "example-reverse-proxy"
version = "0.1.0"
edition = "2021"
[dependencies]
axum = { path = "../../axum" }
hyper = { version = "1.0.0", features = ["full"] }
hyper-util = { version = "0.1.1", features = ["client-legacy"] }
tokio = { version = "1", features = ["full"] }
//! Reverse proxy listening on "localhost:4000" that forwards requests to the server on
//! "localhost:3000". Only `GET /` is routed to the proxy handler in this example.
//!
//! Run with
//!
//! ```not_rust
//! cargo run -p example-reverse-proxy
//! ```

use axum::{
    body::Body,
    extract::{Request, State},
    http::uri::Uri,
    response::{IntoResponse, Response},
    routing::get,
    Router,
};
use hyper::StatusCode;
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};

type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;

#[tokio::main]
async fn main() {
    tokio::spawn(server());

    let client: Client =
        hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
            .build(HttpConnector::new());

    let app = Router::new().route("/", get(handler)).with_state(client);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:4000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn handler(State(client): State<Client>, mut req: Request) -> Result<Response, StatusCode> {
    let path = req.uri().path();
    let path_query = req
        .uri()
        .path_and_query()
        .map(|v| v.as_str())
        .unwrap_or(path);

    let uri = format!("http://127.0.0.1:3000{}", path_query);

    *req.uri_mut() = Uri::try_from(uri).unwrap();

    Ok(client
        .request(req)
        .await
        .map_err(|_| StatusCode::BAD_REQUEST)?
        .into_response())
}

async fn server() {
    let app = Router::new().route("/", get(|| async { "Hello, world!" }));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}
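The core of `handler` is the URI rewrite. It keeps the incoming path and query string and points the authority at the upstream server before the request is re-sent with the hyper client. The std-only function below isolates that step on plain strings. `upstream_uri` and `UPSTREAM` are hypothetical names, and the real handler works on `http::Uri` values instead.

```rust
/// Hypothetical upstream origin; the example hard-codes the same value.
const UPSTREAM: &str = "http://127.0.0.1:3000";

/// Build the upstream URI from the incoming path and optional query string,
/// mirroring `path_and_query().map(|v| v.as_str()).unwrap_or(path)`.
fn upstream_uri(path: &str, query: Option<&str>) -> String {
    match query {
        Some(q) => format!("{UPSTREAM}{path}?{q}"),
        None => format!("{UPSTREAM}{path}"),
    }
}

fn main() {
    println!("{}", upstream_uri("/", None));
    println!("{}", upstream_uri("/", Some("name=axum")));
}
```

Keeping the query string matters: dropping it would silently change what the upstream server sees.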
request-id
[package]
name = "example-request-id"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
tower = "0.5.2"
tower-http = { version = "0.5", features = ["request-id", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-request-id
//! ```

use axum::{
    http::{HeaderName, Request},
    response::Html,
    routing::get,
    Router,
};
use tower::ServiceBuilder;
use tower_http::{
    request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
    trace::TraceLayer,
};
use tracing::{error, info, info_span};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

const REQUEST_ID_HEADER: &str = "x-request-id";

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                // axum logs rejections from built-in extractors with the `axum::rejection`
                // target, at `TRACE` level. `axum::rejection=trace` enables showing those events
                format!(
                    "{}=debug,tower_http=debug,axum::rejection=trace",
                    env!("CARGO_CRATE_NAME")
                )
                .into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let x_request_id = HeaderName::from_static(REQUEST_ID_HEADER);

    let middleware = ServiceBuilder::new()
        .layer(SetRequestIdLayer::new(
            x_request_id.clone(),
            MakeRequestUuid,
        ))
        .layer(
            TraceLayer::new_for_http().make_span_with(|request: &Request<_>| {
                // Log the request id as generated.
                let request_id = request.headers().get(REQUEST_ID_HEADER);

                match request_id {
                    Some(request_id) => info_span!(
                        "http_request",
                        request_id = ?request_id,
                    ),
                    None => {
                        error!("could not extract request_id");
                        info_span!("http_request")
                    }
                }
            }),
        )
        // copy the request id header from the request to the response
        .layer(PropagateRequestIdLayer::new(x_request_id));

    // build our application with a route
    let app = Router::new().route("/", get(handler)).layer(middleware);

    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn handler() -> Html<&'static str> {
    info!("Hello world!");
    Html("<h1>Hello, World!</h1>")
}
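The middleware stack does two things around the handler. `SetRequestIdLayer` gives each request an `x-request-id` header, generated by `MakeRequestUuid`. `PropagateRequestIdLayer` copies that header onto the response. The std-only sketch below models that flow with plain `HashMap` headers and a hypothetical counter-based id instead of UUIDs. It assumes that an id the client already sent is kept rather than replaced. Treat that detail as an assumption and check the tower-http documentation for the exact rule.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};

const REQUEST_ID_HEADER: &str = "x-request-id";

type Headers = HashMap<String, String>;

/// Hypothetical id generator standing in for `MakeRequestUuid`.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);

/// "Set" step: add a request id unless the client already supplied one (assumed behavior).
fn set_request_id(request: &mut Headers) {
    request
        .entry(REQUEST_ID_HEADER.to_string())
        .or_insert_with(|| format!("req-{}", NEXT_ID.fetch_add(1, Ordering::Relaxed)));
}

/// "Propagate" step: copy the request id from the request onto the response.
fn propagate_request_id(request: &Headers, response: &mut Headers) {
    if let Some(id) = request.get(REQUEST_ID_HEADER) {
        response.insert(REQUEST_ID_HEADER.to_string(), id.clone());
    }
}

/// Run both steps around a trivial handler and return the response headers.
fn handle(mut request: Headers) -> Headers {
    set_request_id(&mut request);
    let mut response = Headers::new();
    propagate_request_id(&request, &mut response);
    response
}

fn main() {
    let generated = handle(Headers::new());
    println!("generated: {:?}", generated.get(REQUEST_ID_HEADER));

    let mut incoming = Headers::new();
    incoming.insert(REQUEST_ID_HEADER.to_string(), "from-client".to_string());
    println!("kept: {:?}", handle(incoming).get(REQUEST_ID_HEADER));
}
```

Because the id is set before `TraceLayer` runs, the span created in `make_span_with` can always read it. That is why the layer order in the example matters.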
prometheus-metrics
[package]
name = "example-prometheus-metrics"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
metrics = { version = "0.23", default-features = false }
metrics-exporter-prometheus = { version = "0.15", default-features = false }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Someday tower-http will hopefully have a metrics middleware; until then you can track
//! metrics like this.
//!
//! Run with
//!
//! ```not_rust
//! cargo run -p example-prometheus-metrics
//! ```

use axum::{
    extract::{MatchedPath, Request},
    middleware::{self, Next},
    response::IntoResponse,
    routing::get,
    Router,
};
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use std::{
    future::ready,
    time::{Duration, Instant},
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

fn metrics_app() -> Router {
    let recorder_handle = setup_metrics_recorder();
    Router::new().route("/metrics", get(move || ready(recorder_handle.render())))
}

fn main_app() -> Router {
    Router::new()
        .route("/fast", get(|| async {}))
        .route(
            "/slow",
            get(|| async {
                tokio::time::sleep(Duration::from_secs(1)).await;
            }),
        )
        .route_layer(middleware::from_fn(track_metrics))
}

async fn start_main_server() {
    let app = main_app();

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn start_metrics_server() {
    let app = metrics_app();

    // NOTE: expose metrics endpoint on a different port
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3001")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // The `/metrics` endpoint should not be publicly available. If behind a reverse proxy, this
    // can be achieved by rejecting requests to `/metrics`. In this example, a second server is
    // started on another port to expose `/metrics`.
    let (_main_server, _metrics_server) = tokio::join!(start_main_server(), start_metrics_server());
}

fn setup_metrics_recorder() -> PrometheusHandle {
    const EXPONENTIAL_SECONDS: &[f64] = &[
        0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0,
    ];

    PrometheusBuilder::new()
        .set_buckets_for_metric(
            Matcher::Full("http_requests_duration_seconds".to_string()),
            EXPONENTIAL_SECONDS,
        )
        .unwrap()
        .install_recorder()
        .unwrap()
}

async fn track_metrics(req: Request, next: Next) -> impl IntoResponse {
    let start = Instant::now();
    let path = if let Some(matched_path) = req.extensions().get::<MatchedPath>() {
        matched_path.as_str().to_owned()
    } else {
        req.uri().path().to_owned()
    };
    let method = req.method().clone();

    let response = next.run(req).await;

    let latency = start.elapsed().as_secs_f64();
    let status = response.status().as_u16().to_string();

    let labels = [
        ("method", method.to_string()),
        ("path", path),
        ("status", status),
    ];

    metrics::counter!("http_requests_total", &labels).increment(1);
    metrics::histogram!("http_requests_duration_seconds", &labels).record(latency);

    response
}
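`setup_metrics_recorder` configures explicit histogram buckets for `http_requests_duration_seconds`. In the Prometheus exposition format, histogram buckets are cumulative. The bucket with upper bound `le` counts every observation less than or equal to that bound, and a final `+Inf` bucket counts all observations. The std-only sketch below computes those counts for two made-up latencies using the example's bucket bounds. It illustrates the format only, leaves out labels such as `method`, `path` and `status`, and is not how `metrics-exporter-prometheus` stores data internally.

```rust
/// The bucket upper bounds used in `setup_metrics_recorder`.
const EXPONENTIAL_SECONDS: &[f64] = &[
    0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0,
];

/// Cumulative bucket counts: entry `i` counts observations `<= bounds[i]`.
/// The total number of observations plays the role of the `+Inf` bucket.
fn cumulative_buckets(bounds: &[f64], observations: &[f64]) -> Vec<u64> {
    bounds
        .iter()
        .map(|&le| observations.iter().filter(|&&v| v <= le).count() as u64)
        .collect()
}

fn main() {
    // Hypothetical latencies: one `/fast` request (~1 ms) and one `/slow` request (~1 s).
    let latencies = [0.001, 1.002];
    let counts = cumulative_buckets(EXPONENTIAL_SECONDS, &latencies);
    for (le, count) in EXPONENTIAL_SECONDS.iter().zip(&counts) {
        println!("http_requests_duration_seconds_bucket{{le=\"{le}\"}} {count}");
    }
    println!(
        "http_requests_duration_seconds_bucket{{le=\"+Inf\"}} {}",
        latencies.len()
    );
}
```

Note that the `/slow` request (just over one second) lands in the `le="2.5"` bucket, not in `le="1"`. This is why bucket bounds should be chosen around the latencies you actually expect.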
oauth
[package]
name = "example-oauth"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
anyhow = "1"
async-session = "3.0.0"
axum = { path = "../../axum" }
axum-extra = { path = "../../axum-extra", features = ["typed-header"] }
http = "1.0.0"
oauth2 = "4.1"
# Use Rustls because it makes it easier to cross-compile on CI
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Example OAuth (Discord) implementation.
//!
//! 1) Create a new application at <https://discord.com/developers/applications>
//! 2) Visit the OAuth2 tab to get your CLIENT_ID and CLIENT_SECRET
//! 3) Add a new redirect URI (for this example: `http://127.0.0.1:3000/auth/authorized`)
//! 4) Run with the following (replacing values appropriately):
//! ```not_rust
//! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth
//! ```

use anyhow::{anyhow, Context, Result};
use async_session::{MemoryStore, Session, SessionStore};
use axum::{
    extract::{FromRef, FromRequestParts, OptionalFromRequestParts, Query, State},
    http::{header::SET_COOKIE, HeaderMap},
    response::{IntoResponse, Redirect, Response},
    routing::get,
    RequestPartsExt, Router,
};
use axum_extra::{headers, typed_header::TypedHeaderRejectionReason, TypedHeader};
use http::{header, request::Parts, StatusCode};
use oauth2::{
    basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
    ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};
use std::{convert::Infallible, env};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

static COOKIE_NAME: &str = "SESSION";
static CSRF_TOKEN: &str = "csrf_token";

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // `MemoryStore` is just used as an example. Don't use this in production.
    let store = MemoryStore::new();
    let oauth_client = oauth_client().unwrap();
    let app_state = AppState {
        store,
        oauth_client,
    };

    let app = Router::new()
        .route("/", get(index))
        .route("/auth/discord", get(discord_auth))
        .route("/auth/authorized", get(login_authorized))
        .route("/protected", get(protected))
        .route("/logout", get(logout))
        .with_state(app_state);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .context("failed to bind TcpListener")
        .unwrap();

    tracing::debug!(
        "listening on {}",
        listener
            .local_addr()
            .context("failed to return local address")
            .unwrap()
    );

    axum::serve(listener, app).await.unwrap();
}

#[derive(Clone)]
struct AppState {
    store: MemoryStore,
    oauth_client: BasicClient,
}

impl FromRef<AppState> for MemoryStore {
    fn from_ref(state: &AppState) -> Self {
        state.store.clone()
    }
}

impl FromRef<AppState> for BasicClient {
    fn from_ref(state: &AppState) -> Self {
        state.oauth_client.clone()
    }
}

fn oauth_client() -> Result<BasicClient, AppError> {
    // Environment variables (* = required):
    // *"CLIENT_ID"     "REPLACE_ME";
    // *"CLIENT_SECRET" "REPLACE_ME";
    //  "REDIRECT_URL"  "http://127.0.0.1:3000/auth/authorized";
    //  "AUTH_URL"      "https://discord.com/api/oauth2/authorize?response_type=code";
    //  "TOKEN_URL"     "https://discord.com/api/oauth2/token";

    let client_id = env::var("CLIENT_ID").context("Missing CLIENT_ID!")?;
    let client_secret = env::var("CLIENT_SECRET").context("Missing CLIENT_SECRET!")?;
    let redirect_url = env::var("REDIRECT_URL")
        .unwrap_or_else(|_| "http://127.0.0.1:3000/auth/authorized".to_string());

    let auth_url = env::var("AUTH_URL").unwrap_or_else(|_| {
        "https://discord.com/api/oauth2/authorize?response_type=code".to_string()
    });

    let token_url = env::var("TOKEN_URL")
        .unwrap_or_else(|_| "https://discord.com/api/oauth2/token".to_string());

    Ok(BasicClient::new(
        ClientId::new(client_id),
        Some(ClientSecret::new(client_secret)),
        AuthUrl::new(auth_url).context("failed to create new authorization server URL")?,
        Some(TokenUrl::new(token_url).context("failed to create new token endpoint URL")?),
    )
    .set_redirect_uri(
        RedirectUrl::new(redirect_url).context("failed to create new redirection URL")?,
    ))
}

// The user data we'll get back from Discord.
// https://discord.com/developers/docs/resources/user#user-object-user-structure
#[derive(Debug, Serialize, Deserialize)]
struct User {
    id: String,
    avatar: Option<String>,
    username: String,
    discriminator: String,
}

// Session is optional
async fn index(user: Option<User>) -> impl IntoResponse {
    match user {
        Some(u) => format!(
            "Hey {}! You're logged in!\nYou may now access `/protected`.\nLog out with `/logout`.",
            u.username
        ),
        None => "You're not logged in.\nVisit `/auth/discord` to do so.".to_string(),
    }
}

async fn discord_auth(
    State(client): State<BasicClient>,
    State(store): State<MemoryStore>,
) -> Result<impl IntoResponse, AppError> {
    let (auth_url, csrf_token) = client
        .authorize_url(CsrfToken::new_random)
        .add_scope(Scope::new("identify".to_string()))
        .url();

    // Create session to store csrf_token
    let mut session = Session::new();
    session
        .insert(CSRF_TOKEN, &csrf_token)
        .context("failed in inserting CSRF token into session")?;

    // Store the session in MemoryStore and retrieve the session cookie
    let cookie = store
        .store_session(session)
        .await
        .context("failed to store CSRF token session")?
        .context("unexpected error retrieving CSRF cookie value")?;

    // Attach the session cookie to the response header
    let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/");
    let mut headers = HeaderMap::new();
    headers.insert(
        SET_COOKIE,
        cookie.parse().context("failed to parse cookie")?,
    );

    Ok((headers, Redirect::to(auth_url.as_ref())))
}

// Valid user session required. If there is none, redirect to the auth page
async fn protected(user: User) -> impl IntoResponse {
    format!("Welcome to the protected area :)\nHere's your info:\n{user:?}")
}

async fn logout(
    State(store): State<MemoryStore>,
    TypedHeader(cookies): TypedHeader<headers::Cookie>,
) -> Result<impl IntoResponse, AppError> {
    let cookie = cookies
        .get(COOKIE_NAME)
        .context("unexpected error getting cookie name")?;

    let session = match store
        .load_session(cookie.to_string())
        .await
        .context("failed to load session")?
    {
        Some(s) => s,
        // No session active, just redirect
        None => return Ok(Redirect::to("/")),
    };

    store
        .destroy_session(session)
        .await
        .context("failed to destroy session")?;

    Ok(Redirect::to("/"))
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct AuthRequest {
    code: String,
    state: String,
}

async fn csrf_token_validation_workflow(
    auth_request: &AuthRequest,
    cookies: &headers::Cookie,
    store: &MemoryStore,
) -> Result<(), AppError> {
    // Extract the cookie from the request
    let cookie = cookies
        .get(COOKIE_NAME)
        .context("unexpected error getting cookie name")?
        .to_string();

    // Load the session
    let session = match store
        .load_session(cookie)
        .await
        .context("failed to load session")?
    {
        Some(session) => session,
        None => return Err(anyhow!("Session not found").into()),
    };

    // Extract the CSRF token from the session
    let stored_csrf_token = session
        .get::<CsrfToken>(CSRF_TOKEN)
        .context("CSRF token not found in session")?
        .to_owned();

    // Clean up the CSRF token session
    store
        .destroy_session(session)
        .await
        .context("Failed to destroy old session")?;

    // Validate that the CSRF token is the same as the one in the auth request
    if *stored_csrf_token.secret() != auth_request.state {
        return Err(anyhow!("CSRF token mismatch").into());
    }

    Ok(())
}

async fn login_authorized(
    Query(query): Query<AuthRequest>,
    State(store): State<MemoryStore>,
    State(oauth_client): State<BasicClient>,
    TypedHeader(cookies): TypedHeader<headers::Cookie>,
) -> Result<impl IntoResponse, AppError> {
    csrf_token_validation_workflow(&query, &cookies, &store).await?;

    // Get an auth token
    let token = oauth_client
        .exchange_code(AuthorizationCode::new(query.code.clone()))
        .request_async(async_http_client)
        .await
        .context("failed in sending request to authorization server")?;

    // Fetch user data from discord
    let client = reqwest::Client::new();
    let user_data: User = client
        // https://discord.com/developers/docs/resources/user#get-current-user
        .get("https://discordapp.com/api/users/@me")
        .bearer_auth(token.access_token().secret())
        .send()
        .await
        .context("failed in sending request to target Url")?
        .json::<User>()
        .await
        .context("failed to deserialize response as JSON")?;

    // Create a new session filled with user data
    let mut session = Session::new();
    session
        .insert("user", &user_data)
        .context("failed in inserting serialized value into session")?;

    // Store session and get corresponding cookie
    let cookie = store
        .store_session(session)
        .await
        .context("failed to store session")?
        .context("unexpected error retrieving cookie value")?;

    // Build the cookie
    let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/");

    // Set cookie
    let mut headers = HeaderMap::new();
    headers.insert(
        SET_COOKIE,
        cookie.parse().context("failed to parse cookie")?,
    );

    Ok((headers, Redirect::to("/")))
}

struct AuthRedirect;

impl IntoResponse for AuthRedirect {
    fn into_response(self) -> Response {
        Redirect::temporary("/auth/discord").into_response()
    }
}

impl<S> FromRequestParts<S> for User
where
    MemoryStore: FromRef<S>,
    S: Send + Sync,
{
    // If anything goes wrong or no session is found, redirect to the auth page
    type Rejection = AuthRedirect;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let store = MemoryStore::from_ref(state);

        let cookies = parts
            .extract::<TypedHeader<headers::Cookie>>()
            .await
            .map_err(|e| match *e.name() {
                header::COOKIE => match e.reason() {
                    TypedHeaderRejectionReason::Missing => AuthRedirect,
                    _ => panic!("unexpected error getting Cookie header(s): {e}"),
                },
                _ => panic!("unexpected error getting cookies: {e}"),
            })?;
        let session_cookie = cookies.get(COOKIE_NAME).ok_or(AuthRedirect)?;

        let session = store
            .load_session(session_cookie.to_string())
            .await
            .unwrap()
            .ok_or(AuthRedirect)?;

        let user = session.get::<User>("user").ok_or(AuthRedirect)?;

        Ok(user)
    }
}

impl<S> OptionalFromRequestParts<S> for User
where
    MemoryStore: FromRef<S>,
    S: Send + Sync,
{
    type Rejection = Infallible;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &S,
    ) -> Result<Option<Self>, Self::Rejection> {
        match <User as FromRequestParts<S>>::from_request_parts(parts, state).await {
            Ok(res) => Ok(Some(res)),
            Err(AuthRedirect) => Ok(None),
        }
    }
}

// Use anyhow, define error and enable '?'
// For a simplified example of using anyhow in axum check /examples/anyhow-error-response
#[derive(Debug)]
struct AppError(anyhow::Error);

// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        tracing::error!("Application error: {:#}", self.0);

        (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
    }
}

// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
// `Result<_, AppError>`. That way you don't need to do that manually.
impl<E> From<E> for AppError
where
    E: Into<anyhow::Error>,
{
    fn from(err: E) -> Self {
        Self(err.into())
    }
}
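The session id travels in a cookie. The handlers build a `Set-Cookie` value with `format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/")`. The browser then sends only the `name=value` pair back in a `Cookie` header, and `headers::Cookie::get(COOKIE_NAME)` reads it from there. The std-only sketch below shows both halves with hypothetical `set_cookie_value` and `cookie_value` helpers. In a real application, prefer a cookie library: this parser ignores quoting and the other details of RFC 6265.

```rust
const COOKIE_NAME: &str = "SESSION";

/// Build the `Set-Cookie` header value the same way the example does.
fn set_cookie_value(session_id: &str) -> String {
    format!("{COOKIE_NAME}={session_id}; SameSite=Lax; HttpOnly; Secure; Path=/")
}

/// Find `name` in a `Cookie` request header such as "a=1; SESSION=abc".
/// Simplified: no quoting or percent-decoding is handled.
fn cookie_value<'a>(cookie_header: &'a str, name: &str) -> Option<&'a str> {
    cookie_header
        .split(';')
        .filter_map(|pair| pair.trim().split_once('='))
        .find(|(key, _)| *key == name)
        .map(|(_, value)| value)
}

fn main() {
    println!("Set-Cookie: {}", set_cookie_value("abc123"));
    // The browser echoes only the name=value pair back, without the attributes.
    let header = "theme=dark; SESSION=abc123";
    println!("{:?}", cookie_value(header, COOKIE_NAME));
}
```

Because of `HttpOnly`, page scripts cannot read the session cookie. `SameSite=Lax` still lets the browser send it on the top-level redirect back from Discord to `/auth/authorized`, which the CSRF check relies on.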
low-level-rustls
[package]
name = "example-low-level-rustls"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
futures-util = { version = "0.3", default-features = false }
hyper = { version = "1.0.0", features = ["full"] }
hyper-util = { version = "0.1", features = ["http2"] }
tokio = { version = "1", features = ["full"] }
tokio-rustls = "0.26"
tower-service = "0.3.2"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-low-level-rustls
//! ```

use axum::{extract::Request, routing::get, Router};
use futures_util::pin_mut;
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use std::{
    path::{Path, PathBuf},
    sync::Arc,
};
use tokio::net::TcpListener;
use tokio_rustls::{
    rustls::pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer},
    rustls::ServerConfig,
    TlsAcceptor,
};
use tower_service::Service;
use tracing::{error, info, warn};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let rustls_config = rustls_server_config(
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("self_signed_certs")
            .join("key.pem"),
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("self_signed_certs")
            .join("cert.pem"),
    );

    let tls_acceptor = TlsAcceptor::from(rustls_config);
    let bind = "[::1]:3000";
    let tcp_listener = TcpListener::bind(bind).await.unwrap();
    info!("HTTPS server listening on {bind}. To contact curl -k https://localhost:3000");
    let app = Router::new().route("/", get(handler));

    pin_mut!(tcp_listener);
    loop {
        let tower_service = app.clone();
        let tls_acceptor = tls_acceptor.clone();

        // Wait for new tcp connection
        let (cnx, addr) = tcp_listener.accept().await.unwrap();

        tokio::spawn(async move {
            // Wait for tls handshake to happen
            let Ok(stream) = tls_acceptor.accept(cnx).await else {
                error!("error during tls handshake connection from {}", addr);
                return;
            };

            // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
            // `TokioIo` converts between them.
            let stream = TokioIo::new(stream);

            // Hyper also has its own `Service` trait and doesn't use tower. We can use
            // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
            // `tower::Service::call`.
            let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
                // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
                // tower's `Service` requires `&mut self`.
                //
                // We don't need to call `poll_ready` since `Router` is always ready.
                tower_service.clone().call(request)
            });

            let ret = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                .serve_connection_with_upgrades(stream, hyper_service)
                .await;

            if let Err(err) = ret {
                warn!("error serving connection from {}: {}", addr, err);
            }
        });
    }
}

async fn handler() -> &'static str {
    "Hello, World!"
}

fn rustls_server_config(key: impl AsRef<Path>, cert: impl AsRef<Path>) -> Arc<ServerConfig> {
    let key = PrivateKeyDer::from_pem_file(key).unwrap();

    let certs = CertificateDer::pem_file_iter(cert)
        .unwrap()
        .map(|cert| cert.unwrap())
        .collect();

    let mut config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .expect("bad certificate/key");

    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

    Arc::new(config)
}
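`config.alpn_protocols` lists the protocols the server is willing to negotiate during the TLS handshake, most preferred first; this is how HTTP/2-capable clients end up on `h2` while others fall back to `http/1.1`, which `hyper_util`'s `auto::Builder` can then serve. As I understand it, rustls picks the first entry in the server's list that the client also offered. The sketch below models that server-preference rule on plain byte slices; `negotiate_alpn` is a hypothetical helper, not a rustls API.

```rust
/// Returns the first server-preferred protocol that the client also offered,
/// or `None` when the lists do not overlap (no ALPN protocol is selected).
fn negotiate_alpn<'a>(server_prefs: &[&'a [u8]], client_offer: &[&[u8]]) -> Option<&'a [u8]> {
    server_prefs
        .iter()
        .copied()
        // Walk the server list in order so the server's preference wins.
        .find(|proto| client_offer.iter().any(|offered| offered == proto))
}
```

For example, with the server list `[b"h2", b"http/1.1"]` from the code above, a client offering `http/1.1` first and `h2` second still ends up on `h2`.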
Additional files:
self_signed_certs/cert.pem
self_signed_certs/key.pem
low-level-openssl
[package]
name = "example-low-level-openssl"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
hyper = { version = "1.0.0", features = ["full"] }
hyper-util = { version = "0.1" }
openssl = "0.10"
tokio = { version = "1", features = ["full"] }
tokio-openssl = "0.6"
tower = { version = "0.5.2", features = ["make"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
use axum::{http::Request, routing::get, Router};
use futures_util::pin_mut;
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use openssl::ssl::{Ssl, SslAcceptor, SslFiletype, SslMethod};
use std::{path::PathBuf, pin::Pin};
use tokio::net::TcpListener;
use tokio_openssl::SslStream;
use tower::Service;
use tracing::{error, info, warn};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls()).unwrap();

    tls_builder
        .set_certificate_file(
            PathBuf::from(env!("CARGO_MANIFEST_DIR"))
                .join("self_signed_certs")
                .join("cert.pem"),
            SslFiletype::PEM,
        )
        .unwrap();

    tls_builder
        .set_private_key_file(
            PathBuf::from(env!("CARGO_MANIFEST_DIR"))
                .join("self_signed_certs")
                .join("key.pem"),
            SslFiletype::PEM,
        )
        .unwrap();

    tls_builder.check_private_key().unwrap();

    let tls_acceptor = tls_builder.build();

    let bind = "[::1]:3000";
    let tcp_listener = TcpListener::bind(bind).await.unwrap();
    info!("HTTPS server listening on {bind}. To contact curl -k https://localhost:3000");
    let app = Router::new().route("/", get(handler));

    pin_mut!(tcp_listener);

    loop {
        let tower_service = app.clone();
        let tls_acceptor = tls_acceptor.clone();

        // Wait for new tcp connection
        let (cnx, addr) = tcp_listener.accept().await.unwrap();

        tokio::spawn(async move {
            let ssl = Ssl::new(tls_acceptor.context()).unwrap();
            let mut tls_stream = SslStream::new(ssl, cnx).unwrap();
            if let Err(err) = SslStream::accept(Pin::new(&mut tls_stream)).await {
                error!(
                    "error during tls handshake connection from {}: {}",
                    addr, err
                );
                return;
            }

            // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
            // `TokioIo` converts between them.
            let stream = TokioIo::new(tls_stream);

            // Hyper also has its own `Service` trait and doesn't use tower. We can use
            // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
            // `tower::Service::call`.
            let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
                // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
                // tower's `Service` requires `&mut self`.
                //
                // We don't need to call `poll_ready` since `Router` is always ready.
                tower_service.clone().call(request)
            });

            let ret = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                .serve_connection_with_upgrades(stream, hyper_service)
                .await;

            if let Err(err) = ret {
                warn!("error serving connection from {}: {}", addr, err);
            }
        });
    }
}

async fn handler() -> &'static str {
    "Hello, World!"
}
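All three low-level TLS examples clone `tower_service` on every request because hyper's `Service::call` only receives `&self`, while tower's `Service::call` needs `&mut self`. The sketch below reproduces that mismatch with two hypothetical traits, `MutService` (standing in for tower) and `SharedService` (standing in for hyper), and a synchronous `call` for brevity, whereas the real traits return futures. The adapter clones the wrapped service so each call gets its own mutable copy.

```rust
/// Stand-in for tower's `Service`: calling it needs exclusive access.
trait MutService {
    fn call(&mut self, req: String) -> String;
}

/// Stand-in for hyper's `Service`: calling it only gets a shared reference.
trait SharedService {
    fn call(&self, req: String) -> String;
}

/// Hypothetical app that mutates its own state on every call.
#[derive(Clone, Default)]
struct App {
    calls: u32,
}

impl MutService for App {
    fn call(&mut self, req: String) -> String {
        self.calls += 1;
        format!("call #{} handled {req}", self.calls)
    }
}

/// Adapter that only has `&self`, so it clones the inner service per request,
/// mirroring `tower_service.clone().call(request)` in the examples.
struct CloneOnCall<S>(S);

impl<S: MutService + Clone> SharedService for CloneOnCall<S> {
    fn call(&self, req: String) -> String {
        self.0.clone().call(req)
    }
}
```

A consequence of this design is that state stored directly inside the service is not shared between calls; each clone starts from the original. That is one reason axum apps keep shared data behind something like `Arc`, as the key-value-store example does.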
Additional files:
self_signed_certs/cert.pem
self_signed_certs/key.pem
low-level-native-tls
[package]
name = "example-low-level-native-tls"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
futures-util = { version = "0.3", default-features = false }
hyper = { version = "1.0.0", features = ["full"] }
hyper-util = { version = "0.1" }
tokio = { version = "1", features = ["full"] }
tokio-native-tls = "0.3.1"
tower-service = "0.3.2"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-low-level-native-tls
//! ```

use axum::{extract::Request, routing::get, Router};
use futures_util::pin_mut;
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use std::path::PathBuf;
use tokio::net::TcpListener;
use tokio_native_tls::{
    native_tls::{Identity, Protocol, TlsAcceptor as NativeTlsAcceptor},
    TlsAcceptor,
};
use tower_service::Service;
use tracing::{error, info, warn};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            // Use this crate's own name for the default filter
            // (the default previously named `example_low_level_rustls`).
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let tls_acceptor = native_tls_acceptor(
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("self_signed_certs")
            .join("key.pem"),
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("self_signed_certs")
            .join("cert.pem"),
    );

    let tls_acceptor = TlsAcceptor::from(tls_acceptor);
    let bind = "[::1]:3000";
    let tcp_listener = TcpListener::bind(bind).await.unwrap();
    info!("HTTPS server listening on {bind}. To contact curl -k https://localhost:3000");
    let app = Router::new().route("/", get(handler));

    pin_mut!(tcp_listener);
    loop {
        let tower_service = app.clone();
        let tls_acceptor = tls_acceptor.clone();

        // Wait for new tcp connection
        let (cnx, addr) = tcp_listener.accept().await.unwrap();

        tokio::spawn(async move {
            // Wait for tls handshake to happen
            let Ok(stream) = tls_acceptor.accept(cnx).await else {
                error!("error during tls handshake connection from {}", addr);
                return;
            };

            // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
            // `TokioIo` converts between them.
            let stream = TokioIo::new(stream);

            // Hyper also has its own `Service` trait and doesn't use tower. We can use
            // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
            // `tower::Service::call`.
            let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
                // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
                // tower's `Service` requires `&mut self`.
                //
                // We don't need to call `poll_ready` since `Router` is always ready.
                tower_service.clone().call(request)
            });

            let ret = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                .serve_connection_with_upgrades(stream, hyper_service)
                .await;

            if let Err(err) = ret {
                warn!("error serving connection from {addr}: {err}");
            }
        });
    }
}

async fn handler() -> &'static str {
    "Hello, World!"
}

fn native_tls_acceptor(key_file: PathBuf, cert_file: PathBuf) -> NativeTlsAcceptor {
    let key_pem = std::fs::read_to_string(&key_file).unwrap();
    let cert_pem = std::fs::read_to_string(&cert_file).unwrap();

    let id = Identity::from_pkcs8(cert_pem.as_bytes(), key_pem.as_bytes()).unwrap();
    NativeTlsAcceptor::builder(id)
        // let's be modern
        .min_protocol_version(Some(Protocol::Tlsv12))
        .build()
        .unwrap()
}
Additional files:
self_signed_certs/cert.pem
self_signed_certs/key.pem
key-value-store
[package]
name = "example-key-value-store"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.5.2", features = ["util", "timeout", "load-shed", "limit"] }
tower-http = { version = "0.6.1", features = [
"add-extension",
"auth",
"compression-full",
"limit",
"trace",
] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Simple in-memory key/value store showing features of axum.
//!
//! Run with:
//!
//! ```not_rust
//! cargo run -p example-key-value-store
//! ```

use axum::{
    body::Bytes,
    error_handling::HandleErrorLayer,
    extract::{DefaultBodyLimit, Path, State},
    handler::Handler,
    http::StatusCode,
    response::IntoResponse,
    routing::{delete, get},
    Router,
};
use std::{
    borrow::Cow,
    collections::HashMap,
    sync::{Arc, RwLock},
    time::Duration,
};
use tower::{BoxError, ServiceBuilder};
use tower_http::{
    compression::CompressionLayer, limit::RequestBodyLimitLayer, trace::TraceLayer,
    validate_request::ValidateRequestHeaderLayer,
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let shared_state = SharedState::default();

    // Build our application by composing routes
    let app = Router::new()
        .route(
            "/{key}",
            // Add compression to `kv_get`
            get(kv_get.layer(CompressionLayer::new()))
                // But don't compress `kv_set`
                .post_service(
                    kv_set
                        .layer((
                            DefaultBodyLimit::disable(),
                            RequestBodyLimitLayer::new(1024 * 5_000 /* ~5mb */),
                        ))
                        .with_state(Arc::clone(&shared_state)),
                ),
        )
        .route("/keys", get(list_keys))
        // Nest our admin routes under `/admin`
        .nest("/admin", admin_routes())
        // Add middleware to all routes
        .layer(
            ServiceBuilder::new()
                // Handle errors from middleware
                .layer(HandleErrorLayer::new(handle_error))
                .load_shed()
                .concurrency_limit(1024)
                .timeout(Duration::from_secs(10))
                .layer(TraceLayer::new_for_http()),
        )
        .with_state(Arc::clone(&shared_state));

    // Run our app with hyper
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

type SharedState = Arc<RwLock<AppState>>;

#[derive(Default)]
struct AppState {
    db: HashMap<String, Bytes>,
}

async fn kv_get(
    Path(key): Path<String>,
    State(state): State<SharedState>,
) -> Result<Bytes, StatusCode> {
    let db = &state.read().unwrap().db;

    if let Some(value) = db.get(&key) {
        Ok(value.clone())
    } else {
        Err(StatusCode::NOT_FOUND)
    }
}

async fn kv_set(Path(key): Path<String>, State(state): State<SharedState>, bytes: Bytes) {
    state.write().unwrap().db.insert(key, bytes);
}

async fn list_keys(State(state): State<SharedState>) -> String {
    let db = &state.read().unwrap().db;

    db.keys()
        .map(|key| key.to_string())
        .collect::<Vec<String>>()
        .join("\n")
}

fn admin_routes() -> Router<SharedState> {
    async fn delete_all_keys(State(state): State<SharedState>) {
        state.write().unwrap().db.clear();
    }

    async fn remove_key(Path(key): Path<String>, State(state): State<SharedState>) {
        state.write().unwrap().db.remove(&key);
    }

    Router::new()
        .route("/keys", delete(delete_all_keys))
        .route("/key/{key}", delete(remove_key))
        // Require bearer auth for all admin routes
        .layer(ValidateRequestHeaderLayer::bearer("secret-token"))
}

async fn handle_error(error: BoxError) -> impl IntoResponse {
    if error.is::<tower::timeout::error::Elapsed>() {
        return (StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out"));
    }

    if error.is::<tower::load_shed::error::Overloaded>() {
        return (
            StatusCode::SERVICE_UNAVAILABLE,
            Cow::from("service is overloaded, try again later"),
        );
    }

    (
        StatusCode::INTERNAL_SERVER_ERROR,
        Cow::from(format!("Unhandled internal error: {error}")),
    )
}
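`handle_error` turns failures raised by the middleware stack (timeout, load shedding) into responses by testing the concrete type behind the `BoxError` with `is::<T>()`. The same check is available on any `Box<dyn Error + Send + Sync>` in the standard library. In the sketch below, `Elapsed` and `Overloaded` are hypothetical local stand-ins for `tower::timeout::error::Elapsed` and `tower::load_shed::error::Overloaded`, and `status_for` returns a plain status code and message instead of an axum response.

```rust
use std::error::Error;
use std::fmt;

type BoxError = Box<dyn Error + Send + Sync>;

// Hypothetical stand-ins for tower's timeout and load-shed error types.
#[derive(Debug)]
struct Elapsed;
#[derive(Debug)]
struct Overloaded;

impl fmt::Display for Elapsed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("request timed out")
    }
}
impl Error for Elapsed {}

impl fmt::Display for Overloaded {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("service overloaded")
    }
}
impl Error for Overloaded {}

/// Same branching as `handle_error`: check the concrete error type first,
/// then fall back to 500 with the error's own message.
fn status_for(error: BoxError) -> (u16, String) {
    if error.is::<Elapsed>() {
        return (408, "request timed out".to_string());
    }
    if error.is::<Overloaded>() {
        return (503, "service is overloaded, try again later".to_string());
    }
    (500, format!("Unhandled internal error: {error}"))
}
```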
auto-reload
This example shows how you can set up a development environment for your axum
service such that whenever the source code changes, the app is recompiled and
restarted. It uses listenfd
to be able to migrate connections from an old
version of the app to a newly-compiled version.
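As I understand it, on Unix systemfd hands the socket over using the systemd socket-activation convention: the parent keeps the listening socket open across rebuilds, sets `LISTEN_FDS` to the number of inherited sockets, and passes them starting at file descriptor 3. The example reaches that socket through `listenfd` and otherwise binds `127.0.0.1:3000` itself. The Unix-only sketch below shows roughly that decision with the standard library; it is a simplification under those assumptions (the real `listenfd` crate handles more cases), and `inherited_or_bind` is a hypothetical name.

```rust
use std::io;
use std::net::TcpListener;
use std::os::unix::io::{FromRawFd, RawFd};

/// First inherited descriptor under the systemd socket-activation convention.
const SD_LISTEN_FDS_START: RawFd = 3;

/// Returns the inherited listener when `listen_fds` (the value of the
/// `LISTEN_FDS` environment variable) reports at least one socket;
/// otherwise binds `fallback_addr`, like the example's `None` branch.
fn inherited_or_bind(listen_fds: Option<&str>, fallback_addr: &str) -> io::Result<TcpListener> {
    let count = listen_fds
        .and_then(|value| value.trim().parse::<usize>().ok())
        .unwrap_or(0);

    if count > 0 {
        // SAFETY: assumes the parent process really passed an open TCP
        // listening socket as fd 3 and nothing else in this process owns it.
        let listener = unsafe { TcpListener::from_raw_fd(SD_LISTEN_FDS_START) };
        return Ok(listener);
    }

    TcpListener::bind(fallback_addr)
}
```

In a real program the first argument would come from `std::env::var("LISTEN_FDS").ok().as_deref()`. The example additionally calls `set_nonblocking(true)` before `TcpListener::from_std`, because tokio expects a standard-library listener to already be in non-blocking mode.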
Setup
cargo install cargo-watch systemfd
Running
systemfd --no-pid -s http::3000 -- cargo watch -x run
[package]
name = "auto-reload"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
listenfd = "1.0.1"
tokio = { version = "1.0", features = ["full"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p auto-reload
//! ```

use axum::{response::Html, routing::get, Router};
use listenfd::ListenFd;
use tokio::net::TcpListener;

#[tokio::main]
async fn main() {
    // build our application with a route
    let app = Router::new().route("/", get(handler));

    let mut listenfd = ListenFd::from_env();
    let listener = match listenfd.take_tcp_listener(0).unwrap() {
        // if we are given a tcp listener on listen fd 0, we use that one
        Some(listener) => {
            listener.set_nonblocking(true).unwrap();
            TcpListener::from_std(listener).unwrap()
        }
        // otherwise fall back to local listening
        None => TcpListener::bind("127.0.0.1:3000").await.unwrap(),
    };

    // run it
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn handler() -> Html<&'static str> {
    Html("<h1>Hello, World!</h1>")
}
websockets
[package]
name = "example-websockets"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["ws"] }
axum-extra = { path = "../../axum-extra", features = ["typed-header"] }
futures = "0.3"
futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] }
headers = "0.4"
tokio = { version = "1.0", features = ["full"] }
tokio-tungstenite = "0.26.0"
tower-http = { version = "0.6.1", features = ["fs", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[[bin]]
name = "example-websockets"
path = "src/main.rs"
[[bin]]
name = "example-client"
path = "src/client.rs"
//! Example websocket server.
//!
//! Run the server with
//! ```not_rust
//! cargo run -p example-websockets --bin example-websockets
//! ```
//!
//! Run a browser client with
//! ```not_rust
//! firefox http://localhost:3000
//! ```
//!
//! Alternatively you can run the rust client (showing two
//! concurrent websocket connections being established) with
//! ```not_rust
//! cargo run -p example-websockets --bin example-client
//! ```

use axum::{
    body::Bytes,
    extract::ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
    response::IntoResponse,
    routing::any,
    Router,
};
use axum_extra::TypedHeader;

use std::ops::ControlFlow;
use std::{net::SocketAddr, path::PathBuf};
use tower_http::{
    services::ServeDir,
    trace::{DefaultMakeSpan, TraceLayer},
};

use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

// allows extracting the IP of the connecting user
use axum::extract::connect_info::ConnectInfo;
use axum::extract::ws::CloseFrame;

// allows splitting the websocket stream into separate TX and RX branches
use futures::{sink::SinkExt, stream::StreamExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets");

    // build our application with some routes
    let app = Router::new()
        .fallback_service(ServeDir::new(assets_dir).append_index_html_on_directories(true))
        .route("/ws", any(ws_handler))
        // logging so we can see what's going on
        .layer(
            TraceLayer::new_for_http()
                .make_span_with(DefaultMakeSpan::default().include_headers(true)),
        );

    // run it with hyper
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .await
    .unwrap();
}

/// The handler for the HTTP request (this gets called when the HTTP request lands at the start
/// of websocket negotiation). After this completes, the actual switching from HTTP to
/// websocket protocol will occur.
/// This is the last point where we can extract TCP/IP metadata such as IP address of the client
/// as well as things from HTTP headers such as user-agent of the browser etc.
async fn ws_handler(
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
    let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
        user_agent.to_string()
    } else {
        String::from("Unknown browser")
    };
    println!("`{user_agent}` at {addr} connected.");
    // finalize the upgrade process by returning upgrade callback.
    // we can customize the callback by sending additional info such as address.
    ws.on_upgrade(move |socket| handle_socket(socket, addr))
}

/// Actual websocket statemachine (one will be spawned per connection)
async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
    // send a ping (unsupported by some browsers) just to kick things off and get a response
    if socket
        .send(Message::Ping(Bytes::from_static(&[1, 2, 3])))
        .await
        .is_ok()
    {
        println!("Pinged {who}...");
    } else {
        println!("Could not send ping {who}!");
        // no Error here since the only thing we can do is to close the connection.
        // If we can not send messages, there is no way to salvage the statemachine anyway.
        return;
    }

    // receive single message from a client (we can either receive or send with socket).
    // this will likely be the Pong for our Ping or a hello message from client.
    // waiting for message from a client will block this task, but will not block other clients'
    // connections.
    if let Some(msg) = socket.recv().await {
        if let Ok(msg) = msg {
            if process_message(msg, who).is_break() {
                return;
            }
        } else {
            println!("client {who} abruptly disconnected");
            return;
        }
    }

    // Since each client gets individual statemachine, we can pause handling
    // when necessary to wait for some external event (in this case illustrated by sleeping).
    // Waiting for this client to finish getting its greetings does not prevent other clients from
    // connecting to server and receiving their greetings.
    for i in 1..5 {
        if socket
            .send(Message::Text(format!("Hi {i} times!").into()))
            .await
            .is_err()
        {
            println!("client {who} abruptly disconnected");
            return;
        }
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    }

    // By splitting socket we can send and receive at the same time. In this example we will send
    // unsolicited messages to client based on some sort of server's internal event (e.g. a timer).
    let (mut sender, mut receiver) = socket.split();

    // Spawn a task that will push several messages to the client (does not matter what client does)
    let mut send_task = tokio::spawn(async move {
        let n_msg = 20;
        for i in 0..n_msg {
            // In case of any websocket error, we exit.
            if sender
                .send(Message::Text(format!("Server message {i} ...").into()))
                .await
                .is_err()
            {
                return i;
            }

            tokio::time::sleep(std::time::Duration::from_millis(300)).await;
        }

        println!("Sending close to {who}...");
        if let Err(e) = sender
            .send(Message::Close(Some(CloseFrame {
                code: axum::extract::ws::close_code::NORMAL,
                reason: Utf8Bytes::from_static("Goodbye"),
            })))
            .await
        {
            println!("Could not send Close due to {e}, probably it is ok?");
        }
        n_msg
    });

    // This second task will receive messages from client and print them on server console
    let mut recv_task = tokio::spawn(async move {
        let mut cnt = 0;
        while let Some(Ok(msg)) = receiver.next().await {
            cnt += 1;
            // print message and break if instructed to do so
            if process_message(msg, who).is_break() {
                break;
            }
        }
        cnt
    });

    // If any one of the tasks exit, abort the other.
    tokio::select! {
        rv_a = (&mut send_task) => {
            match rv_a {
                Ok(a) => println!("{a} messages sent to {who}"),
                Err(a) => println!("Error sending messages {a:?}")
            }
            recv_task.abort();
        },
        rv_b = (&mut recv_task) => {
            match rv_b {
                Ok(b) => println!("Received {b} messages"),
                Err(b) => println!("Error receiving messages {b:?}")
            }
            send_task.abort();
        }
    }

    // returning from the handler closes the websocket connection
    println!("Websocket context {who} destroyed");
}

/// helper to print contents of messages to stdout. Has special treatment for Close.
fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> {
    match msg {
        Message::Text(t) => {
            println!(">>> {who} sent str: {t:?}");
        }
        Message::Binary(d) => {
            println!(">>> {} sent {} bytes: {:?}", who, d.len(), d);
        }
        Message::Close(c) => {
            if let Some(cf) = c {
                println!(
                    ">>> {} sent close with code {} and reason `{}`",
                    who, cf.code, cf.reason
                );
            } else {
                println!(">>> {who} somehow sent close message without CloseFrame");
            }
            return ControlFlow::Break(());
        }

        Message::Pong(v) => {
            println!(">>> {who} sent pong with {v:?}");
        }
        // You should never need to manually handle Message::Ping, as axum's websocket library
        // will do so for you automatically by replying with a Pong that copies the payload `v`,
        // as the spec requires. But if you need the contents of the pings you can see them here.
        Message::Ping(v) => {
            println!(">>> {who} sent ping with {v:?}");
        }
    }
    ControlFlow::Continue(())
}
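`process_message` returns a `ControlFlow` so that the receive loop can stop on a `Close` frame without a separate flag. The sketch below reproduces that control flow with a hypothetical, simplified `Msg` enum in place of axum's `Message` and an iterator in place of the socket; `count_until_close` mirrors the `recv_task` loop, which counts every message, including the `Close` that ends it.

```rust
use std::ops::ControlFlow;

/// Simplified, hypothetical stand-in for axum's `Message`.
enum Msg {
    Text(String),
    Binary(Vec<u8>),
    Close(Option<(u16, String)>),
}

/// Prints the message and asks the caller to stop on `Close`.
fn process(msg: &Msg) -> ControlFlow<()> {
    match msg {
        Msg::Text(t) => println!(">>> sent str: {t:?}"),
        Msg::Binary(d) => println!(">>> sent {} bytes", d.len()),
        Msg::Close(Some((code, reason))) => {
            println!(">>> sent close with code {code} and reason `{reason}`");
            return ControlFlow::Break(());
        }
        Msg::Close(None) => return ControlFlow::Break(()),
    }
    ControlFlow::Continue(())
}

/// Counts messages up to and including the first `Close`, like `recv_task`.
fn count_until_close(incoming: impl IntoIterator<Item = Msg>) -> usize {
    let mut cnt = 0;
    for msg in incoming {
        cnt += 1;
        if process(&msg).is_break() {
            break;
        }
    }
    cnt
}
```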
//! Based on tokio-tungstenite example websocket client, but with multiple
//! concurrent websocket clients in one package
//!
//! This will connect to a server specified in the SERVER with N_CLIENTS
//! concurrent connections, and then flood some test messages over websocket.
//! This will also print whatever it gets into stdout.
//!
//! Note that this is not currently optimized for performance, especially around
//! stdout mutex management. Rather it's intended to show an example of working with axum's
//! websocket server and how the client-side and server-side code can be quite similar.
//!

use futures_util::stream::FuturesUnordered;
use futures_util::{SinkExt, StreamExt};
use std::ops::ControlFlow;
use std::time::Instant;
use tokio_tungstenite::tungstenite::Utf8Bytes;

// we will use tungstenite for websocket client impl (same library as what axum is using)
use tokio_tungstenite::{
    connect_async,
    tungstenite::protocol::{frame::coding::CloseCode, CloseFrame, Message},
};

const N_CLIENTS: usize = 2; // set to desired number
const SERVER: &str = "ws://127.0.0.1:3000/ws";

#[tokio::main]
async fn main() {
    let start_time = Instant::now();
    // spawn several clients that will concurrently talk to the server
    let mut clients = (0..N_CLIENTS)
        .map(|cli| tokio::spawn(spawn_client(cli)))
        .collect::<FuturesUnordered<_>>();

    // wait for all our clients to exit
    while clients.next().await.is_some() {}

    let end_time = Instant::now();

    // total time should be the same no matter how many clients we spawn
    println!(
        "Total time taken {:#?} with {N_CLIENTS} concurrent clients, should be about 6.45 seconds.",
        end_time - start_time
    );
}

// creates a client. quietly exits on failure.
async fn spawn_client(who: usize) {
    let ws_stream = match connect_async(SERVER).await {
        Ok((stream, response)) => {
            println!("Handshake for client {who} has been completed");
            // This will be the HTTP response, same as with server this is the last moment we
            // can still access HTTP stuff.
            println!("Server response was {response:?}");
            stream
        }
        Err(e) => {
            println!("WebSocket handshake for client {who} failed with {e}!");
            return;
        }
    };

    let (mut sender, mut receiver) = ws_stream.split();

    // we can ping the server for start
    sender
        .send(Message::Ping(axum::body::Bytes::from_static(
            b"Hello, Server!",
        )))
        .await
        .expect("Can not send!");

    // spawn an async sender to push some more messages into the server
    let mut send_task = tokio::spawn(async move {
        for i in 1..30 {
            // In any websocket error, break loop.
            if sender
                .send(Message::Text(format!("Message number {i}...").into()))
                .await
                .is_err()
            {
                // just as with server, if send fails there is nothing we can do but exit.
                return;
            }

            tokio::time::sleep(std::time::Duration::from_millis(300)).await;
        }

        // When we are done we may want our client to close connection cleanly.
        println!("Sending close to {who}...");
        if let Err(e) = sender
            .send(Message::Close(Some(CloseFrame {
                code: CloseCode::Normal,
                reason: Utf8Bytes::from_static("Goodbye"),
            })))
            .await
        {
            println!("Could not send Close due to {e:?}, probably it is ok?");
        };
    });

    // receiver just prints whatever it gets
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(msg)) = receiver.next().await {
            // print message and break if instructed to do so
            if process_message(msg, who).is_break() {
                break;
            }
        }
    });

    // wait for either task to finish and kill the other task
    tokio::select! {
        _ = (&mut send_task) => {
            recv_task.abort();
        },
        _ = (&mut recv_task) => {
            send_task.abort();
        }
    }
}

/// Function to handle messages we get (with a slight twist that Frame variant is visible
/// since we are working with the underlying tungstenite library directly without axum here).
fn process_message(msg: Message, who: usize) -> ControlFlow<(), ()> {
    match msg {
        Message::Text(t) => {
            println!(">>> {who} got str: {t:?}");
        }
        Message::Binary(d) => {
            println!(">>> {} got {} bytes: {:?}", who, d.len(), d);
        }
        Message::Close(c) => {
            if let Some(cf) = c {
                println!(
                    ">>> {} got close with code {} and reason `{}`",
                    who, cf.code, cf.reason
                );
            } else {
                println!(">>> {who} somehow got close message without CloseFrame");
            }
            return ControlFlow::Break(());
        }

        Message::Pong(v) => {
            println!(">>> {who} got pong with {v:?}");
        }
        // Just as with axum server, the underlying tungstenite websocket library
        // will handle Ping for you automatically by replying with a Pong that copies the
        // payload `v`, as the spec requires. But if you need the contents of the pings you can
        // see them here.
        Message::Ping(v) => {
            println!(">>> {who} got ping with {v:?}");
        }

        Message::Frame(_) => {
            unreachable!("This is never supposed to happen")
        }
    }
    ControlFlow::Continue(())
}
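The client's claim that the run "should be about 6.45 seconds" regardless of `N_CLIENTS` follows from the server's schedule rather than the client's. Each connection gets 4 greetings (`for i in 1..5`) with a 100 ms pause each, then 20 server messages with a 300 ms pause each. After that, the server sends `Close`, the client's `recv_task` ends, and `select!` aborts the slower `send_task` (29 × 300 ms). Because every connection runs in its own task, these delays overlap across clients instead of adding up. The arithmetic below copies the loop bounds from the server code; `server_schedule` is a hypothetical helper.

```rust
use std::time::Duration;

/// Sleep time the server spends on one connection before sending `Close`,
/// using the loop bounds and delays from `handle_socket`. Handshake, the
/// initial ping/pong and network latency are deliberately left out.
fn server_schedule() -> Duration {
    let greetings = (1..5).count() as u32; // "Hi {i} times!" for i = 1..=4
    let pushes = 20u32; // `n_msg` in the server's send task
    Duration::from_millis(100) * greetings + Duration::from_millis(300) * pushes
}
```

That gives 6.4 s of scheduled sleeps; the remaining ~0.05 s in the printed estimate is presumably connection setup and the initial ping round trip.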
<a>Open the console to see stuff, then refresh to initiate exchange.</a>
<script src='script.js'></script>
const socket = new WebSocket('ws://localhost:3000/ws');
socket.addEventListener('open', function (event) {
socket.send('Hello Server!');
});
socket.addEventListener('message', function (event) {
console.log('Message from server ', event.data);
});
setTimeout(() => {
const obj = { hello: "world" };
const blob = new Blob([JSON.stringify(obj, null, 2)], {
type: "application/json",
});
console.log("Sending blob over websocket");
socket.send(blob);
}, 1000);
setTimeout(() => {
socket.send('About done here...');
console.log("Sending close over websocket");
socket.close(3000, "Crash and Burn!");
}, 3000);
websockets-http2
[package]
name = "example-websockets-http2"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["ws", "http2"] }
axum-server = { version = "0.6", features = ["tls-rustls"] }
tokio = { version = "1", features = ["full"] }
tower-http = { version = "0.5.0", features = ["fs"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-websockets-http2
//! ```

use axum::{
    extract::{
        ws::{self, WebSocketUpgrade},
        State,
    },
    http::Version,
    routing::any,
    Router,
};
use axum_server::tls_rustls::RustlsConfig;
use std::{net::SocketAddr, path::PathBuf};
use tokio::sync::broadcast;
use tower_http::services::ServeDir;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets");

    // configure certificate and private key used by https
    let config = RustlsConfig::from_pem_file(
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("self_signed_certs")
            .join("cert.pem"),
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("self_signed_certs")
            .join("key.pem"),
    )
    .await
    .unwrap();

    // build our application with some routes and a broadcast channel
    let app = Router::new()
        .fallback_service(ServeDir::new(assets_dir).append_index_html_on_directories(true))
        .route("/ws", any(ws_handler))
        .with_state(broadcast::channel::<String>(16).0);

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("listening on {}", addr);

    let mut server = axum_server::bind_rustls(addr, config);

    // IMPORTANT: This is required to advertise our support for HTTP/2 websockets to the client.
    // If you use axum::serve, it is enabled by default.
    server.http_builder().http2().enable_connect_protocol();

    server.serve(app.into_make_service()).await.unwrap();
}

async fn ws_handler(
    ws: WebSocketUpgrade,
    version: Version,
    State(sender): State<broadcast::Sender<String>>,
) -> axum::response::Response {
    tracing::debug!("accepted a WebSocket using {version:?}");
    let mut receiver = sender.subscribe();
    ws.on_upgrade(|mut ws| async move {
        loop {
            tokio::select! {
                // Since `ws` is a `Stream`, it is by nature cancel-safe.
                res = ws.recv() => {
                    match res {
                        Some(Ok(ws::Message::Text(s))) => {
                            let _ = sender.send(s.to_string());
                        }
                        Some(Ok(_)) => {}
                        Some(Err(e)) => tracing::debug!("client disconnected abruptly: {e}"),
                        None => break,
                    }
                }
                // Tokio guarantees that `broadcast::Receiver::recv` is cancel-safe.
                res = receiver.recv() => {
                    match res {
                        Ok(msg) => if let Err(e) = ws.send(ws::Message::Text(msg.into())).await {
                            tracing::debug!("client disconnected abruptly: {e}");
                        }
                        Err(_) => continue,
                    }
                }
            }
        }
    })
}
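This example turns the socket into a small chat room: every text message a client sends goes into a `tokio::sync::broadcast` channel, and every connection holds its own receiver from `sender.subscribe()`, so each message reaches all open connections, including the sender's own. The sketch below models that fan-out with `std::sync::mpsc` channels. `Hub` is a hypothetical type, and unlike tokio's broadcast channel (capacity 16 here) it is unbounded, so it never produces the lagged-receiver errors (`RecvError::Lagged`) that the example's `Err(_) => continue` branch skips over.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};

/// Hypothetical fan-out hub: every subscriber gets its own copy of each message.
#[derive(Default)]
struct Hub {
    subscribers: Vec<Sender<String>>,
}

impl Hub {
    /// Like `broadcast::Sender::subscribe`: only messages sent afterwards are seen.
    fn subscribe(&mut self) -> Receiver<String> {
        let (tx, rx) = channel();
        self.subscribers.push(tx);
        rx
    }

    /// Delivers `msg` to every live subscriber, dropping those whose receiver
    /// is gone, and returns how many subscribers received it.
    fn send(&mut self, msg: &str) -> usize {
        self.subscribers.retain(|tx| tx.send(msg.to_string()).is_ok());
        self.subscribers.len()
    }
}
```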
<p>Open this page in two windows and try sending some messages!</p>
<form action="javascript:void(0)">
<input type="text" name="content" required>
<button>Send</button>
</form>
<div id="messages"></div>
<script src='script.js'></script>
const socket = new WebSocket('wss://localhost:3000/ws');
socket.addEventListener('message', e => {
document.getElementById("messages").append(e.data, document.createElement("br"));
});
const form = document.querySelector("form");
form.addEventListener("submit", () => {
socket.send(form.elements.namedItem("content").value);
form.elements.namedItem("content").value = "";
});
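This example fans messages out through `tokio::sync::broadcast`: every connection calls `subscribe()`, and every message passed to `send()` reaches every current subscriber. The sketch below is a hypothetical std-only model of that behaviour (the `Broadcaster` type is not part of tokio or axum), useful for reasoning about it without an async runtime. It is a simplification: tokio's channel is bounded (capacity 16 here) and a slow receiver gets `RecvError::Lagged`, which the handler's `Err(_) => continue` arm skips over, while this model is unbounded and never lags.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};

/// Hypothetical std-only stand-in for `tokio::sync::broadcast::Sender`:
/// each subscriber owns its own queue, and every message is cloned into
/// all of them.
struct Broadcaster<T: Clone> {
    subscribers: Vec<Sender<T>>,
}

impl<T: Clone> Broadcaster<T> {
    fn new() -> Self {
        Self { subscribers: Vec::new() }
    }

    /// Like `broadcast::Sender::subscribe`: the new receiver only sees
    /// messages sent after this call.
    fn subscribe(&mut self) -> Receiver<T> {
        let (tx, rx) = channel();
        self.subscribers.push(tx);
        rx
    }

    /// Delivers `msg` to every live subscriber, forgets subscribers whose
    /// receiver has been dropped, and returns how many got the message.
    fn send(&mut self, msg: T) -> usize {
        self.subscribers.retain(|tx| tx.send(msg.clone()).is_ok());
        self.subscribers.len()
    }
}
```

A receiver created after a message was sent never sees that message, in the model as in tokio; that ordering is why a connection has to subscribe before anything it should receive is broadcast.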
testing-websockets
[package]
name = "example-testing-websockets"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["ws"] }
futures = "0.3"
tokio = { version = "1.0", features = ["full"] }
tokio-tungstenite = "0.26"
//! Run with
//!
//! ```not_rust
//! cargo test -p example-testing-websockets
//! ```

use axum::{
    extract::{
        ws::{Message, WebSocket},
        WebSocketUpgrade,
    },
    response::Response,
    routing::get,
    Router,
};
use futures::{Sink, SinkExt, Stream, StreamExt};

#[tokio::main]
async fn main() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app()).await.unwrap();
}

fn app() -> Router {
    // WebSocket routes can generally be tested in two ways:
    //
    // - Integration tests where you run the server and connect with a real WebSocket client.
    // - Unit tests where you mock the socket as some generic send/receive type
    //
    // Which version you pick is up to you. Generally we recommend the integration test version
    // unless your app has a lot of setup that makes it hard to run in a test.
    Router::new()
        .route("/integration-testable", get(integration_testable_handler))
        .route("/unit-testable", get(unit_testable_handler))
}

// A WebSocket handler that echoes any message it receives.
//
// This one we'll be integration testing so it can be written in the regular way.
async fn integration_testable_handler(ws: WebSocketUpgrade) -> Response {
    ws.on_upgrade(integration_testable_handle_socket)
}

async fn integration_testable_handle_socket(mut socket: WebSocket) {
    while let Some(Ok(msg)) = socket.recv().await {
        if let Message::Text(msg) = msg {
            if socket
                .send(Message::Text(format!("You said: {msg}").into()))
                .await
                .is_err()
            {
                break;
            }
        }
    }
}

// The unit testable version requires some changes.
//
// By splitting the socket into an `impl Sink` and `impl Stream` we can test without providing a
// real socket and instead using channels, which also implement `Sink` and `Stream`.
async fn unit_testable_handler(ws: WebSocketUpgrade) -> Response {
    ws.on_upgrade(|socket| {
        let (write, read) = socket.split();
        unit_testable_handle_socket(write, read)
    })
}

// The implementation is largely the same as `integration_testable_handle_socket` except we call
// methods from `SinkExt` and `StreamExt`.
async fn unit_testable_handle_socket<W, R>(mut write: W, mut read: R)
where
    W: Sink<Message> + Unpin,
    R: Stream<Item = Result<Message, axum::Error>> + Unpin,
{
    while let Some(Ok(msg)) = read.next().await {
        if let Message::Text(msg) = msg {
            if write
                .send(Message::Text(format!("You said: {msg}").into()))
                .await
                .is_err()
            {
                break;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        future::IntoFuture,
        net::{Ipv4Addr, SocketAddr},
    };
    use tokio_tungstenite::tungstenite;

    // We can integration test one handler by running the server in a background task and
    // connecting to it like any other client would.
    #[tokio::test]
    async fn integration_test() {
        // Bind to localhost on an OS-assigned port (port 0). Connecting to the
        // unspecified address (0.0.0.0) is not supported on every platform.
        let listener = tokio::net::TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))
            .await
            .unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(axum::serve(listener, app()).into_future());

        let (mut socket, _response) =
            tokio_tungstenite::connect_async(format!("ws://{addr}/integration-testable"))
                .await
                .unwrap();

        socket
            .send(tungstenite::Message::text("foo"))
            .await
            .unwrap();

        let msg = match socket.next().await.unwrap().unwrap() {
            tungstenite::Message::Text(msg) => msg,
            other => panic!("expected a text message but got {other:?}"),
        };

        assert_eq!(msg.as_str(), "You said: foo");
    }

    // We can unit test the other handler by creating channels to read and write from.
    #[tokio::test]
    async fn unit_test() {
        // Need to use "futures" channels rather than "tokio" channels as they implement `Sink` and
        // `Stream`
        let (socket_write, mut test_rx) = futures::channel::mpsc::channel(1024);
        let (mut test_tx, socket_read) = futures::channel::mpsc::channel(1024);

        tokio::spawn(unit_testable_handle_socket(socket_write, socket_read));

        test_tx.send(Ok(Message::Text("foo".into()))).await.unwrap();

        let msg = match test_rx.next().await.unwrap() {
            Message::Text(msg) => msg,
            other => panic!("expected a text message but got {other:?}"),
        };

        assert_eq!(msg.as_str(), "You said: foo");
    }
}
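The unit-testable handler works because it only depends on "something to read messages from" and "something to write messages to". The same idea, reduced to synchronous std types, is sketched below: a hypothetical `echo` function reads from any `IntoIterator` and writes into any `Extend` collection, so a test can drive it with plain `Vec`s instead of a socket. `Msg` and `echo` are illustrative names only; this is not the futures `Sink`/`Stream` API used above.

```rust
/// Hypothetical stand-in for the WebSocket message type.
#[derive(Debug, PartialEq)]
enum Msg {
    Text(String),
    Ping,
}

/// Hypothetical synchronous analogue of `unit_testable_handle_socket`:
/// `read` plays the role of the `Stream`, `write` the role of the `Sink`.
fn echo<R, W>(read: R, write: &mut W)
where
    R: IntoIterator<Item = Result<Msg, String>>,
    W: Extend<Msg>,
{
    for item in read {
        // Mirrors `while let Some(Ok(msg))`: the first error ends the loop.
        let Ok(msg) = item else { break };
        // Only text messages are echoed; everything else is ignored.
        if let Msg::Text(text) = msg {
            write.extend(Some(Msg::Text(format!("You said: {text}"))));
        }
    }
}
```

The test side then only has to build the input and inspect the output, just as `unit_test` does with the two `futures` channels.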
chat
[package]
name = "example-chat"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["ws"] }
futures = "0.3"
tokio = { version = "1", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Example chat application.
//!
//! Run with
//!
//! ```not_rust
//! cargo run -p example-chat
//! ```

use axum::{
    extract::{
        ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
        State,
    },
    response::{Html, IntoResponse},
    routing::get,
    Router,
};
use futures::{sink::SinkExt, stream::StreamExt};
use std::{
    collections::HashSet,
    sync::{Arc, Mutex},
};
use tokio::sync::broadcast;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

// Our shared state
struct AppState {
    // We require unique usernames. This tracks which usernames have been taken.
    user_set: Mutex<HashSet<String>>,
    // Channel used to send messages to all connected clients.
    tx: broadcast::Sender<String>,
}

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=trace", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // Set up application state for use with with_state().
    let user_set = Mutex::new(HashSet::new());
    let (tx, _rx) = broadcast::channel(100);

    let app_state = Arc::new(AppState { user_set, tx });

    let app = Router::new()
        .route("/", get(index))
        .route("/websocket", get(websocket_handler))
        .with_state(app_state);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn websocket_handler(
    ws: WebSocketUpgrade,
    State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
    ws.on_upgrade(|socket| websocket(socket, state))
}

// This function deals with a single websocket connection, i.e., a single
// connected client / user, for which we will spawn two independent tasks (for
// receiving / sending chat messages).
async fn websocket(stream: WebSocket, state: Arc<AppState>) {
    // By splitting, we can send and receive at the same time.
    let (mut sender, mut receiver) = stream.split();

    // Username gets set in the receive loop, if it's valid.
    let mut username = String::new();
    // Loop until a text message is found.
    while let Some(Ok(message)) = receiver.next().await {
        if let Message::Text(name) = message {
            // If the username sent by the client is not taken, copy it into `username`.
            check_username(&state, &mut username, name.as_str());

            // A non-empty `username` means the name was accepted: leave the loop.
            // Otherwise the name is taken: tell this client and end the connection.
            if !username.is_empty() {
                break;
            } else {
                // Send the "taken" notice only to this client, not to the broadcast channel.
                let _ = sender
                    .send(Message::Text(Utf8Bytes::from_static(
                        "Username already taken.",
                    )))
                    .await;

                return;
            }
        }
    }

    // The client went away before sending a usable name, so there is nobody
    // to announce and nothing to clean up.
    if username.is_empty() {
        return;
    }

    // We subscribe *before* sending the "joined" message, so that we will also
    // display it to our client.
    let mut rx = state.tx.subscribe();

    // Now send the "joined" message to all subscribers.
    let msg = format!("{username} joined.");
    tracing::debug!("{msg}");
    let _ = state.tx.send(msg);

    // Spawn the first task that will receive broadcast messages and send text
    // messages over the websocket to our client.
    let mut send_task = tokio::spawn(async move {
        while let Ok(msg) = rx.recv().await {
            // Stop on any websocket error.
            if sender.send(Message::text(msg)).await.is_err() {
                break;
            }
        }
    });

    // Clone things we want to pass (move) to the receiving task.
    let tx = state.tx.clone();
    let name = username.clone();

    // Spawn a task that takes messages from the websocket, prepends the user
    // name, and sends them to all broadcast subscribers.
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(Message::Text(text))) = receiver.next().await {
            // Add username before message.
            let _ = tx.send(format!("{name}: {text}"));
        }
    });

    // If either task runs to completion, abort the other.
    tokio::select! {
        _ = &mut send_task => recv_task.abort(),
        _ = &mut recv_task => send_task.abort(),
    };

    // Send "user left" message (similar to "joined" above).
    let msg = format!("{username} left.");
    tracing::debug!("{msg}");
    let _ = state.tx.send(msg);

    // Remove the username from the set so new clients can take it again.
    state.user_set.lock().unwrap().remove(&username);
}

fn check_username(state: &AppState, string: &mut String, name: &str) {
    let mut user_set = state.user_set.lock().unwrap();

    if !user_set.contains(name) {
        user_set.insert(name.to_owned());

        string.push_str(name);
    }
}

// Include utf-8 file at **compile** time.
async fn index() -> Html<&'static str> {
    Html(std::include_str!("../chat.html"))
}
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>WebSocket Chat</title>
</head>
<body>
<h1>WebSocket Chat Example</h1>
<input id="username" style="display:block; width:100px; box-sizing: border-box" type="text" placeholder="username">
<button id="join-chat" type="button">Join Chat</button>
<textarea id="chat" style="display:block; width:600px; height:400px; box-sizing: border-box" cols="30" rows="10"></textarea>
<input id="input" style="display:block; width:600px; box-sizing: border-box" type="text" placeholder="chat">
<script>
const username = document.querySelector("#username");
const join_btn = document.querySelector("#join-chat");
const textarea = document.querySelector("#chat");
const input = document.querySelector("#input");
join_btn.addEventListener("click", function(e) {
this.disabled = true;
const websocket = new WebSocket("ws://localhost:3000/websocket");
websocket.onopen = function() {
console.log("connection opened");
websocket.send(username.value);
}
const btn = this;
websocket.onclose = function() {
console.log("connection closed");
btn.disabled = false;
}
websocket.onmessage = function(e) {
console.log("received message: "+e.data);
textarea.value += e.data+"\r\n";
}
input.onkeydown = function(e) {
if (e.key == "Enter") {
websocket.send(input.value);
input.value = "";
}
}
});
</script>
</body>
</html>
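`check_username` calls `contains` and then `insert` while holding the same lock, so it is already race-free. Since `HashSet::insert` returns `false` (and leaves the set unchanged) when the value is already present, the check and the claim can also be a single call. A minimal std-only sketch, where `try_claim` and `release` are hypothetical helpers rather than part of the example:

```rust
use std::collections::HashSet;
use std::sync::Mutex;

/// Hypothetical replacement for `check_username`: returns `true` if `name`
/// was free and is now reserved, `false` if another client already has it.
fn try_claim(user_set: &Mutex<HashSet<String>>, name: &str) -> bool {
    // `insert` returns `false` without modifying the set when `name` exists.
    user_set.lock().unwrap().insert(name.to_owned())
}

/// Counterpart of the cleanup at the end of `websocket`.
fn release(user_set: &Mutex<HashSet<String>>, name: &str) {
    user_set.lock().unwrap().remove(name);
}
```

With a boolean result, the receive loop could decide on "joined" versus "taken" directly instead of checking whether the `username` string was filled in.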
tokio-redis
[package]
name = "example-tokio-redis"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
bb8 = "0.8.5"
bb8-redis = "0.17.0"
redis = "0.27.2"
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-tokio-redis
//! ```

use axum::{
    extract::{FromRef, FromRequestParts, State},
    http::{request::Parts, StatusCode},
    routing::get,
    Router,
};
use bb8::{Pool, PooledConnection};
use bb8_redis::RedisConnectionManager;
use redis::AsyncCommands;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

use bb8_redis::bb8;

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    tracing::debug!("connecting to redis");
    let manager = RedisConnectionManager::new("redis://localhost").unwrap();
    let pool = bb8::Pool::builder().build(manager).await.unwrap();

    {
        // Check the connection before starting: write `foo` and read it back.
        // This also seeds the key that both handlers read.
        let mut conn = pool.get().await.unwrap();
        conn.set::<&str, &str, ()>("foo", "bar").await.unwrap();
        let result: String = conn.get("foo").await.unwrap();
        assert_eq!(result, "bar");
    }
    tracing::debug!("successfully connected to redis and pinged it");

    // build our application with some routes
    let app = Router::new()
        .route(
            "/",
            get(using_connection_pool_extractor).post(using_connection_extractor),
        )
        .with_state(pool);

    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

type ConnectionPool = Pool<RedisConnectionManager>;

async fn using_connection_pool_extractor(
    State(pool): State<ConnectionPool>,
) -> Result<String, (StatusCode, String)> {
    let mut conn = pool.get().await.map_err(internal_error)?;
    let result: String = conn.get("foo").await.map_err(internal_error)?;
    Ok(result)
}

// we can also write a custom extractor that grabs a connection from the pool
// which setup is appropriate depends on your application
struct DatabaseConnection(PooledConnection<'static, RedisConnectionManager>);

impl<S> FromRequestParts<S> for DatabaseConnection
where
    ConnectionPool: FromRef<S>,
    S: Send + Sync,
{
    type Rejection = (StatusCode, String);

    async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let pool = ConnectionPool::from_ref(state);

        let conn = pool.get_owned().await.map_err(internal_error)?;

        Ok(Self(conn))
    }
}

async fn using_connection_extractor(
    DatabaseConnection(mut conn): DatabaseConnection,
) -> Result<String, (StatusCode, String)> {
    let result: String = conn.get("foo").await.map_err(internal_error)?;

    Ok(result)
}

/// Utility function for mapping any error into a `500 Internal Server Error`
/// response.
fn internal_error<E>(err: E) -> (StatusCode, String)
where
    E: std::error::Error,
{
    (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
}
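The custom extractor is generic over the router state `S` and only requires that a `ConnectionPool` can be obtained from it (`ConnectionPool: FromRef<S>`). The same extractor therefore works whether the state *is* the pool, as here, or a larger struct that contains it. The std-only sketch below models that pattern; the `FromRef` trait is a local copy of the shape of axum's `axum::extract::FromRef` (one method, `from_ref(&T) -> Self`, plus a blanket impl for `Clone` types), and `Pool`, `AppState` and `pool_from_state` are hypothetical.

```rust
use std::sync::Arc;

/// Local copy of the shape of axum's `FromRef` trait.
trait FromRef<T> {
    fn from_ref(input: &T) -> Self;
}

/// As in axum, any `Clone` type can be taken from a reference to itself.
impl<T: Clone> FromRef<T> for T {
    fn from_ref(input: &T) -> Self {
        input.clone()
    }
}

/// Hypothetical pool handle: cheap to clone because it is just an `Arc`.
#[derive(Clone)]
struct Pool(Arc<String>);

/// Hypothetical larger application state that contains a pool.
#[derive(Clone)]
struct AppState {
    pool: Pool,
    app_name: String,
}

impl FromRef<AppState> for Pool {
    fn from_ref(state: &AppState) -> Self {
        state.pool.clone()
    }
}

/// Mirrors the extractor's bound: works for any state that can yield a `Pool`.
fn pool_from_state<S>(state: &S) -> Pool
where
    Pool: FromRef<S>,
{
    Pool::from_ref(state)
}
```

Because the pool handle is reference-counted, `from_ref` only clones a pointer, so calling it on every request stays cheap.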
tokio-postgres
[package]
name = "example-tokio-postgres"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
bb8 = "0.9.0"
bb8-postgres = "0.9.0"
tokio = { version = "1.0", features = ["full"] }
tokio-postgres = "0.7.2"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-tokio-postgres
//! ```

use axum::{
    extract::{FromRef, FromRequestParts, State},
    http::{request::Parts, StatusCode},
    routing::get,
    Router,
};
use bb8::{Pool, PooledConnection};
use bb8_postgres::PostgresConnectionManager;
use tokio_postgres::NoTls;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // set up connection pool
    let manager =
        PostgresConnectionManager::new_from_stringlike("host=localhost user=postgres", NoTls)
            .unwrap();
    let pool = Pool::builder().build(manager).await.unwrap();

    // build our application with some routes
    let app = Router::new()
        .route(
            "/",
            get(using_connection_pool_extractor).post(using_connection_extractor),
        )
        .with_state(pool);

    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

type ConnectionPool = Pool<PostgresConnectionManager<NoTls>>;

async fn using_connection_pool_extractor(
    State(pool): State<ConnectionPool>,
) -> Result<String, (StatusCode, String)> {
    let conn = pool.get().await.map_err(internal_error)?;

    let row = conn
        .query_one("select 1 + 1", &[])
        .await
        .map_err(internal_error)?;
    let two: i32 = row.try_get(0).map_err(internal_error)?;

    Ok(two.to_string())
}

// we can also write a custom extractor that grabs a connection from the pool
// which setup is appropriate depends on your application
struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager<NoTls>>);

impl<S> FromRequestParts<S> for DatabaseConnection
where
    ConnectionPool: FromRef<S>,
    S: Send + Sync,
{
    type Rejection = (StatusCode, String);

    async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let pool = ConnectionPool::from_ref(state);

        let conn = pool.get_owned().await.map_err(internal_error)?;

        Ok(Self(conn))
    }
}

async fn using_connection_extractor(
    DatabaseConnection(conn): DatabaseConnection,
) -> Result<String, (StatusCode, String)> {
    let row = conn
        .query_one("select 1 + 1", &[])
        .await
        .map_err(internal_error)?;
    let two: i32 = row.try_get(0).map_err(internal_error)?;

    Ok(two.to_string())
}

/// Utility function for mapping any error into a `500 Internal Server Error`
/// response.
fn internal_error<E>(err: E) -> (StatusCode, String)
where
    E: std::error::Error,
{
    (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
}
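The extractor calls `pool.get_owned()` rather than `pool.get()`. `ConnectionPool::from_ref(state)` produces a handle that is a local variable of `from_request_parts`, so a connection that borrows from it could not be returned from the function. bb8's pool is reference-counted internally, and `get_owned` returns a `PooledConnection<'static, _>` that keeps its own handle to the pool, which is what the `'static` in `DatabaseConnection` expresses. The std-only sketch below models this with an `Arc`; `Pool`, `OwnedConn` and `extract` are hypothetical names, not bb8 types.

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical pool: a shared list of idle "connections" (plain strings here).
#[derive(Clone)]
struct Pool {
    idle: Arc<Mutex<Vec<String>>>,
}

/// Hypothetical owned guard: it holds its own `Arc` to the pool, so it has
/// no lifetime tied to the `Pool` value it came from.
struct OwnedConn {
    conn: Option<String>,
    idle: Arc<Mutex<Vec<String>>>,
}

impl Drop for OwnedConn {
    // Return the connection to the pool when the guard is dropped.
    fn drop(&mut self) {
        if let Some(conn) = self.conn.take() {
            self.idle.lock().unwrap().push(conn);
        }
    }
}

impl Pool {
    /// Takes an idle connection if there is one (a real pool would wait or
    /// open a new connection instead of returning `None`).
    fn get_owned(&self) -> Option<OwnedConn> {
        let conn = self.idle.lock().unwrap().pop()?;
        Some(OwnedConn { conn: Some(conn), idle: Arc::clone(&self.idle) })
    }
}

/// Shaped like `from_request_parts`: `pool` is a local clone, yet the
/// returned guard stays valid after this function returns.
fn extract(state: &Pool) -> Option<OwnedConn> {
    let pool = state.clone();
    pool.get_owned()
}
```

The trade-off bb8 points out for `get_owned` still applies in the model: an owned guard keeps the pool alive for as long as it exists, so prefer the borrowing `get()` when a plain handler can use it directly.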
sqlx-postgres
[package]
name = "example-sqlx-postgres"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "any", "postgres"] }
//! Example of application using <https://github.com/launchbadge/sqlx>
//!
//! Run with
//!
//! ```not_rust
//! cargo run -p example-sqlx-postgres
//! ```
//!
//! Test with curl:
//!
//! ```not_rust
//! curl 127.0.0.1:3000
//! curl -X POST 127.0.0.1:3000
//! ```

use axum::{
    extract::{FromRef, FromRequestParts, State},
    http::{request::Parts, StatusCode},
    routing::get,
    Router,
};
use sqlx::postgres::{PgPool, PgPoolOptions};
use tokio::net::TcpListener;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

use std::time::Duration;

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let db_connection_str = std::env::var("DATABASE_URL")
        .unwrap_or_else(|_| "postgres://postgres:password@localhost".to_string());

    // set up connection pool
    let pool = PgPoolOptions::new()
        .max_connections(5)
        .acquire_timeout(Duration::from_secs(3))
        .connect(&db_connection_str)
        .await
        .expect("can't connect to database");

    // build our application with some routes
    let app = Router::new()
        .route(
            "/",
            get(using_connection_pool_extractor).post(using_connection_extractor),
        )
        .with_state(pool);

    // run it with hyper
    let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

// we can extract the connection pool with `State`
async fn using_connection_pool_extractor(
    State(pool): State<PgPool>,
) -> Result<String, (StatusCode, String)> {
    sqlx::query_scalar("select 'hello world from pg'")
        .fetch_one(&pool)
        .await
        .map_err(internal_error)
}

// we can also write a custom extractor that grabs a connection from the pool
// which setup is appropriate depends on your application
struct DatabaseConnection(sqlx::pool::PoolConnection<sqlx::Postgres>);

impl<S> FromRequestParts<S> for DatabaseConnection
where
    PgPool: FromRef<S>,
    S: Send + Sync,
{
    type Rejection = (StatusCode, String);

    async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let pool = PgPool::from_ref(state);

        let conn = pool.acquire().await.map_err(internal_error)?;

        Ok(Self(conn))
    }
}

async fn using_connection_extractor(
    DatabaseConnection(mut conn): DatabaseConnection,
) -> Result<String, (StatusCode, String)> {
    sqlx::query_scalar("select 'hello world from pg'")
        .fetch_one(&mut *conn)
        .await
        .map_err(internal_error)
}

/// Utility function for mapping any error into a `500 Internal Server Error`
/// response.
fn internal_error<E>(err: E) -> (StatusCode, String)
where
    E: std::error::Error,
{
    (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
}
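`PgPoolOptions::new().max_connections(5).acquire_timeout(Duration::from_secs(3))` caps the pool at five connections and makes an acquire (including the one `fetch_one(&pool)` performs internally) fail instead of waiting forever when all five are busy; that error then becomes a 500 through `internal_error`. The std-only sketch below models only that checkout rule, using a `Mutex` counter and a `Condvar`. `BoundedPool` and `Slot` are hypothetical, and this is not how sqlx is implemented.

```rust
use std::sync::{Condvar, Mutex};
use std::time::{Duration, Instant};

/// Hypothetical model of a pool's checkout rule: at most `max` connections
/// can be out at once, and `acquire` gives up once `timeout` has passed.
struct BoundedPool {
    in_use: Mutex<usize>,
    freed: Condvar,
    max: usize,
}

/// Hypothetical guard standing in for a checked-out connection.
struct Slot<'a> {
    pool: &'a BoundedPool,
}

impl BoundedPool {
    fn new(max: usize) -> Self {
        Self { in_use: Mutex::new(0), freed: Condvar::new(), max }
    }

    fn acquire(&self, timeout: Duration) -> Result<Slot<'_>, &'static str> {
        let deadline = Instant::now() + timeout;
        let mut in_use = self.in_use.lock().unwrap();
        // Loop because `wait_timeout` can wake up spuriously.
        while *in_use >= self.max {
            let now = Instant::now();
            if now >= deadline {
                return Err("timed out waiting for a free connection");
            }
            in_use = self.freed.wait_timeout(in_use, deadline - now).unwrap().0;
        }
        *in_use += 1;
        Ok(Slot { pool: self })
    }
}

impl Drop for Slot<'_> {
    // Giving the slot back wakes one waiter, if any.
    fn drop(&mut self) {
        *self.pool.in_use.lock().unwrap() -= 1;
        self.pool.freed.notify_one();
    }
}
```

A short timeout turns pool exhaustion into a fast, visible error rather than requests that hang; how large `max_connections` should be depends on the database server's own connection limit.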
mongodb
[package]
name = "example-mongodb"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
mongodb = "3.1.0"
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.6.1", features = ["add-extension", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-mongodb
//! ```

use axum::{
    extract::{Path, State},
    http::StatusCode,
    routing::{delete, get, post, put},
    Json, Router,
};
use mongodb::{
    bson::doc,
    results::{DeleteResult, InsertOneResult, UpdateResult},
    Client, Collection,
};
use serde::{Deserialize, Serialize};
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    // connecting to mongodb
    let db_connection_str = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
        "mongodb://admin:password@127.0.0.1:27017/?authSource=admin".to_string()
    });
    let client = Client::with_uri_str(db_connection_str).await.unwrap();

    // pinging the database
    client
        .database("axum-mongo")
        .run_command(doc! { "ping": 1 })
        .await
        .unwrap();
    println!("Pinged your database. Successfully connected to MongoDB!");

    // set up logging (the per-request logs come from the `TraceLayer` added in `app`)
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    tracing::debug!("Listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app(client)).await.unwrap();
}

// defining routes and state
fn app(client: Client) -> Router {
    let collection: Collection<Member> = client.database("axum-mongo").collection("members");

    // axum 0.8 captures path parameters with `{id}`; the old `:id` syntax
    // makes `Router::route` panic at startup.
    Router::new()
        .route("/create", post(create_member))
        .route("/read/{id}", get(read_member))
        .route("/update", put(update_member))
        .route("/delete/{id}", delete(delete_member))
        .layer(TraceLayer::new_for_http())
        .with_state(collection)
}

// handler to create a new member
async fn create_member(
    State(db): State<Collection<Member>>,
    Json(input): Json<Member>,
) -> Result<Json<InsertOneResult>, (StatusCode, String)> {
    let result = db.insert_one(input).await.map_err(internal_error)?;

    Ok(Json(result))
}

// handler to read an existing member
async fn read_member(
    State(db): State<Collection<Member>>,
    Path(id): Path<u32>,
) -> Result<Json<Option<Member>>, (StatusCode, String)> {
    let result = db
        .find_one(doc! { "_id": id })
        .await
        .map_err(internal_error)?;

    Ok(Json(result))
}

// handler to update an existing member
async fn update_member(
    State(db): State<Collection<Member>>,
    Json(input): Json<Member>,
) -> Result<Json<UpdateResult>, (StatusCode, String)> {
    let result = db
        .replace_one(doc! { "_id": input.id }, input)
        .await
        .map_err(internal_error)?;

    Ok(Json(result))
}

// handler to delete an existing member
async fn delete_member(
    State(db): State<Collection<Member>>,
    Path(id): Path<u32>,
) -> Result<Json<DeleteResult>, (StatusCode, String)> {
    let result = db
        .delete_one(doc! { "_id": id })
        .await
        .map_err(internal_error)?;

    Ok(Json(result))
}

fn internal_error<E>(err: E) -> (StatusCode, String)
where
    E: std::error::Error,
{
    (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
}

// defining Member type
#[derive(Debug, Deserialize, Serialize)]
struct Member {
    #[serde(rename = "_id")]
    id: u32,
    name: String,
    active: bool,
}
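A route such as `/read/{id}` captures one path segment, and `Path(id): Path<u32>` then parses it, so a segment that is not a valid `u32` is rejected before `read_member` runs. The toy matcher below only illustrates that capture-then-parse step with std; `capture` and `member_id` are hypothetical helpers and far simpler than axum's router (one capture, no wildcards).

```rust
/// Hypothetical toy matcher: if `path` has the same segments as `template`,
/// returns the text in the position of the `{...}` segment.
fn capture<'a>(template: &str, path: &'a str) -> Option<&'a str> {
    let template: Vec<&str> = template.split('/').collect();
    let path: Vec<&'a str> = path.split('/').collect();
    if template.len() != path.len() {
        return None;
    }
    let mut captured = None;
    for (t, p) in template.iter().zip(&path) {
        if t.starts_with('{') && t.ends_with('}') {
            // A capture must not be empty, so `/read/` does not match.
            if p.is_empty() {
                return None;
            }
            captured = Some(*p);
        } else if t != p {
            return None;
        }
    }
    captured
}

/// The `Path<u32>` step: parse the captured text into a `u32`.
fn member_id(path: &str) -> Option<u32> {
    capture("/read/{id}", path)?.parse().ok()
}
```

In the real example the parse failure never reaches the handler; axum's `Path` extractor turns it into a rejection response.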
diesel-postgres
[package]
name = "example-diesel-postgres"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["macros"] }
deadpool-diesel = { version = "0.6.1", features = ["postgres"] }
diesel = { version = "2", features = ["postgres"] }
diesel_migrations = "2"
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```not_rust
//! cargo run -p example-diesel-postgres
//! ```
//!
//! Check out the [diesel webpage](https://diesel.rs) for
//! longer guides about diesel
//!
//! Check out the [crates.io source code](https://github.com/rust-lang/crates.io/)
//! for a real world application using axum and diesel

use axum::{
    extract::State,
    http::StatusCode,
    response::Json,
    routing::{get, post},
    Router,
};
use diesel::prelude::*;
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
use std::net::SocketAddr;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

// this embeds the migrations into the application binary
// the migration path is relative to the `CARGO_MANIFEST_DIR`
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/");

// normally part of your generated schema.rs file
table! {
    users (id) {
        id -> Integer,
        name -> Text,
        hair_color -> Nullable<Text>,
    }
}

#[derive(serde::Serialize, Selectable, Queryable)]
struct User {
    id: i32,
    name: String,
    hair_color: Option<String>,
}

#[derive(serde::Deserialize, Insertable)]
#[diesel(table_name = users)]
struct NewUser {
    name: String,
    hair_color: Option<String>,
}

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");

    // set up connection pool
    let manager = deadpool_diesel::postgres::Manager::new(db_url, deadpool_diesel::Runtime::Tokio1);
    let pool = deadpool_diesel::postgres::Pool::builder(manager)
        .build()
        .unwrap();

    // run the migrations on server startup
    {
        let conn = pool.get().await.unwrap();
        conn.interact(|conn| conn.run_pending_migrations(MIGRATIONS).map(|_| ()))
            .await
            .unwrap()
            .unwrap();
    }

    // build our application with some routes
    let app = Router::new()
        .route("/user/list", get(list_users))
        .route("/user/create", post(create_user))
        .with_state(pool);

    // run it with hyper
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("listening on {addr}");
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

async fn create_user(
    State(pool): State<deadpool_diesel::postgres::Pool>,
    Json(new_user): Json<NewUser>,
) -> Result<Json<User>, (StatusCode, String)> {
    let conn = pool.get().await.map_err(internal_error)?;
    let res = conn
        .interact(|conn| {
            diesel::insert_into(users::table)
                .values(new_user)
                .returning(User::as_returning())
                .get_result(conn)
        })
        .await
        .map_err(internal_error)?
        .map_err(internal_error)?;
    Ok(Json(res))
}

async fn list_users(
    State(pool): State<deadpool_diesel::postgres::Pool>,
) -> Result<Json<Vec<User>>, (StatusCode, String)> {
    let conn = pool.get().await.map_err(internal_error)?;
    let res = conn
        .interact(|conn| users::table.select(User::as_select()).load(conn))
        .await
        .map_err(internal_error)?
        .map_err(internal_error)?;
    Ok(Json(res))
}

/// Utility function for mapping any error into a `500 Internal Server Error`
/// response.
fn internal_error<E>(err: E) -> (StatusCode, String)
where
    E: std::error::Error,
{
    (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
}
-- Your SQL goes here
CREATE TABLE "users"(
"id" SERIAL PRIMARY KEY,
"name" TEXT NOT NULL,
"hair_color" TEXT
);
-- This file should undo anything in "up.sql"
DROP TABLE "users";
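Diesel's connections are blocking, so `deadpool_diesel` hands the query closure to `interact`, which runs it on a blocking thread and awaits the result. That is why the handlers call `map_err(internal_error)?` twice: the outer `Result` reports a failure of `interact` itself (for example, a panic inside the closure), and the inner one is Diesel's own query error. A hypothetical std-only analogue built on `std::thread` shows the same nesting:

```rust
use std::thread;

/// Hypothetical analogue of `interact`: runs a blocking closure on another
/// thread. The outer `Result` is `Err` if that thread panicked; the inner
/// `Result` is whatever the closure itself returned.
fn interact<T, E, F>(f: F) -> Result<Result<T, E>, String>
where
    F: FnOnce() -> Result<T, E> + Send + 'static,
    T: Send + 'static,
    E: Send + 'static,
{
    thread::spawn(f)
        .join()
        .map_err(|_| "closure panicked".to_string())
}
```

The real error type returned by `interact` carries more detail than the `String` used here; the point of the sketch is only the two layers of `Result`.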
diesel-async-postgres
[package]
name = "example-diesel-async-postgres"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["macros"] }
bb8 = "0.8"
diesel = "2"
diesel-async = { version = "0.5", features = ["postgres", "bb8"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Run with
//!
//! ```sh
//! export DATABASE_URL=postgres://localhost/your_db
//! diesel migration run
//! cargo run -p example-diesel-async-postgres
//! ```
//!
//! Check out the [diesel webpage](https://diesel.rs) for
//! longer guides about diesel
//!
//! Check out the [crates.io source code](https://github.com/rust-lang/crates.io/)
//! for a real world application using axum and diesel

use axum::{
    extract::{FromRef, FromRequestParts, State},
    http::{request::Parts, StatusCode},
    response::Json,
    routing::{get, post},
    Router,
};
use diesel::prelude::*;
use diesel_async::{
    pooled_connection::AsyncDieselConnectionManager, AsyncPgConnection, RunQueryDsl,
};
use std::net::SocketAddr;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

// normally part of your generated schema.rs file
table! {
    users (id) {
        id -> Integer,
        name -> Text,
        hair_color -> Nullable<Text>,
    }
}

#[derive(serde::Serialize, Selectable, Queryable)]
struct User {
    id: i32,
    name: String,
    hair_color: Option<String>,
}

#[derive(serde::Deserialize, Insertable)]
#[diesel(table_name = users)]
struct NewUser {
    name: String,
    hair_color: Option<String>,
}

type Pool = bb8::Pool<AsyncDieselConnectionManager<AsyncPgConnection>>;

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");

    // set up connection pool
    let config = AsyncDieselConnectionManager::<diesel_async::AsyncPgConnection>::new(db_url);
    let pool = bb8::Pool::builder().build(config).await.unwrap();

    // build our application with some routes
    let app = Router::new()
        .route("/user/list", get(list_users))
        .route("/user/create", post(create_user))
        .with_state(pool);

    // run it with hyper
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("listening on {addr}");
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

async fn create_user(
    State(pool): State<Pool>,
    Json(new_user): Json<NewUser>,
) -> Result<Json<User>, (StatusCode, String)> {
    let mut conn = pool.get().await.map_err(internal_error)?;

    let res = diesel::insert_into(users::table)
        .values(new_user)
        .returning(User::as_returning())
        .get_result(&mut conn)
        .await
        .map_err(internal_error)?;
    Ok(Json(res))
}

// we can also write a custom extractor that grabs a connection from the pool
// which setup is appropriate depends on your application
struct DatabaseConnection(
    bb8::PooledConnection<'static, AsyncDieselConnectionManager<AsyncPgConnection>>,
);

impl<S> FromRequestParts<S> for DatabaseConnection
where
    S: Send + Sync,
    Pool: FromRef<S>,
{
    type Rejection = (StatusCode, String);

    async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let pool = Pool::from_ref(state);

        let conn = pool.get_owned().await.map_err(internal_error)?;

        Ok(Self(conn))
    }
}

async fn list_users(
    DatabaseConnection(mut conn): DatabaseConnection,
) -> Result<Json<Vec<User>>, (StatusCode, String)> {
    let res = users::table
        .select(User::as_select())
        .load(&mut conn)
        .await
        .map_err(internal_error)?;
    Ok(Json(res))
}

/// Utility function for mapping any error into a `500 Internal Server Error`
/// response.
fn internal_error<E>(err: E) -> (StatusCode, String)
where
    E: std::error::Error,
{
    (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
}
-- Your SQL goes here
CREATE TABLE "users"(
"id" SERIAL PRIMARY KEY,
"name" TEXT NOT NULL,
"hair_color" TEXT
);
-- This file should undo anything in "up.sql"
DROP TABLE "users";
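Every database example above ends with the same `internal_error` helper: any `E: std::error::Error` becomes a `(StatusCode, String)` pair, which axum turns into a `500 Internal Server Error` response whose body is the error message. Because the bound only needs `std::error::Error` (and therefore `Display`), `.map_err(internal_error)` works for bb8, sqlx, Diesel and MongoDB errors alike. The std-only sketch below mirrors the helper with a plain `u16` instead of axum's `StatusCode` so it runs without axum; `INTERNAL_SERVER_ERROR` and `parse_id` are hypothetical local names.

```rust
use std::error::Error;

/// Hypothetical stand-in for `StatusCode::INTERNAL_SERVER_ERROR`.
const INTERNAL_SERVER_ERROR: u16 = 500;

/// Same shape as the examples' helper, minus the axum dependency.
fn internal_error<E>(err: E) -> (u16, String)
where
    E: Error,
{
    (INTERNAL_SERVER_ERROR, err.to_string())
}

/// Shows `?` combined with `map_err(internal_error)`, as in the handlers.
fn parse_id(raw: &str) -> Result<u32, (u16, String)> {
    let id = raw.parse::<u32>().map_err(internal_error)?;
    Ok(id)
}
```

Sending `err.to_string()` to the client is convenient in an example; a production service would usually log the detailed error and return a generic message instead.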