# Manifest for the axum websocket example: one server binary and one standalone client binary.
[package]
name = "example-websockets"
version = "0.1.0"
edition = "2021"
# Prevents `cargo publish` for this example crate.
publish = false
[dependencies]
# `ws` enables axum's WebSocket support (`axum::extract::ws`).
axum = { path = "../../axum", features = ["ws"] }
# `typed-header` enables the `TypedHeader` extractor used for the User-Agent header.
axum-extra = { path = "../../axum-extra", features = ["typed-header"] }
# `SinkExt`/`StreamExt` used by the server to split the socket and send/receive on the halves.
futures = "0.3"
# Same sink/stream traits for the client; only the `sink` and `std` features are needed.
futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] }
# Typed header definitions (`headers::UserAgent`).
headers = "0.4"
tokio = { version = "1.0", features = ["full"] }
# WebSocket client implementation used by the `example-client` binary.
tokio-tungstenite = "0.26.0"
# `fs` for serving the static `assets` directory (`ServeDir`), `trace` for request logging (`TraceLayer`).
tower-http = { version = "0.6.1", features = ["fs", "trace"] }
tracing = "0.1"
# `env-filter` provides `EnvFilter`, letting the environment override the default log filter.
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
# Server binary.
[[bin]]
name = "example-websockets"
path = "src/main.rs"
# Client binary that opens several concurrent websocket connections to the server.
[[bin]]
name = "example-client"
path = "src/client.rs"
//! Example websocket server.
//!
//! Run the server with
//! ```not_rust
//! cargo run -p example-websockets --bin example-websockets
//! ```
//!
//! Run a browser client with
//! ```not_rust
//! firefox http://localhost:3000
//! ```
//!
//! Alternatively you can run the rust client (showing two
//! concurrent websocket connections being established) with
//! ```not_rust
//! cargo run -p example-websockets --bin example-client
//! ```
use axum::{
body::Bytes,
extract::ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
response::IntoResponse,
routing::any,
Router,
};
use axum_extra::TypedHeader;
use std::ops::ControlFlow;
use std::{net::SocketAddr, path::PathBuf};
use tower_http::{
services::ServeDir,
trace::{DefaultMakeSpan, TraceLayer},
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
// Allows extracting the IP address of the connecting client.
use axum::extract::connect_info::ConnectInfo;
use axum::extract::ws::CloseFrame;
// Provides `split` (separate TX and RX halves of the websocket) plus `send`/`next` on those halves.
use futures::{sink::SinkExt, stream::StreamExt};
/// Server entry point: installs logging, builds the router (static assets + `/ws`) and
/// serves it on 127.0.0.1:3000 with per-connection peer addresses available to handlers.
#[tokio::main]
async fn main() {
    // Use the environment-provided filter if there is one; otherwise log this crate and
    // tower_http at debug level.
    let log_filter = tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
        format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
    });
    tracing_subscriber::registry()
        .with(log_filter)
        .with(tracing_subscriber::fmt::layer())
        .init();

    // Everything that is not `/ws` is served from the crate's `assets` directory.
    let static_files = ServeDir::new(PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets"))
        .append_index_html_on_directories(true);

    // Request tracing, with the request headers recorded in each span.
    let http_tracing =
        TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true));

    let app = Router::new()
        .fallback_service(static_files)
        .route("/ws", any(ws_handler))
        .layer(http_tracing);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    let bound_to = listener.local_addr().unwrap();
    tracing::debug!("listening on {}", bound_to);

    // `ConnectInfo<SocketAddr>` in `ws_handler` requires the connect-info make-service.
    let service = app.into_make_service_with_connect_info::<SocketAddr>();
    axum::serve(listener, service).await.unwrap();
}
/// HTTP handler for `/ws`, invoked when the websocket negotiation request arrives.
///
/// This runs before the protocol switch, so it is the last point where connection
/// metadata (the peer's socket address) and HTTP headers (such as the browser's
/// user-agent) are available. It logs who connected, then returns the upgrade response;
/// once the switch completes, the socket is handed to `handle_socket` together with the
/// peer address.
async fn ws_handler(
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
    let agent_label = user_agent.map_or_else(
        || String::from("Unknown browser"),
        |TypedHeader(agent)| agent.to_string(),
    );
    println!("`{agent_label}` at {addr} connected.");

    // `addr` is moved into the callback so the per-connection state machine knows its peer.
    ws.on_upgrade(move |socket| handle_socket(socket, addr))
}
/// Actual websocket state machine (one is spawned per connection).
///
/// The connection goes through three phases:
/// 1. send a Ping and wait for a single message from the client (likely the Pong or a
///    hello message),
/// 2. send a few greetings, pausing between them,
/// 3. split the socket and run a sender task and a receiver task concurrently; as soon as
///    one of them finishes, the other is aborted.
///
/// A failed send, a receive error, the client closing the stream, or a client Close
/// message ends the function early. Returning from the handler closes the websocket
/// connection.
async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
    // send a ping (unsupported by some browsers) just to kick things off and get a response
    if socket
        .send(Message::Ping(Bytes::from_static(&[1, 2, 3])))
        .await
        .is_ok()
    {
        println!("Pinged {who}...");
    } else {
        println!("Could not send ping {who}!");
        // no Error here since the only thing we can do is to close the connection.
        // If we can not send messages, there is no way to salvage the statemachine anyway.
        return;
    }

    // receive single message from a client (we can either receive or send with socket).
    // this will likely be the Pong for our Ping or a hello message from client.
    // waiting for message from a client will block this task, but will not block other client's
    // connections.
    match socket.recv().await {
        Some(Ok(msg)) => {
            if process_message(msg, who).is_break() {
                return;
            }
        }
        Some(Err(_)) => {
            println!("client {who} abruptly disconnected");
            return;
        }
        // The stream ended before the client sent anything. Greeting a closed connection
        // could only fail, so stop here instead of falling through.
        None => {
            println!("client {who} closed the connection before sending a message");
            return;
        }
    }

    // Since each client gets individual statemachine, we can pause handling
    // when necessary to wait for some external event (in this case illustrated by sleeping).
    // Waiting for this client to finish getting its greetings does not prevent other clients from
    // connecting to server and receiving their greetings.
    for i in 1..5 {
        if socket
            .send(Message::Text(format!("Hi {i} times!").into()))
            .await
            .is_err()
        {
            println!("client {who} abruptly disconnected");
            return;
        }
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    }

    // By splitting socket we can send and receive at the same time. In this example we will send
    // unsolicited messages to client based on some sort of server's internal event (e.g. a timer).
    let (mut sender, mut receiver) = socket.split();

    // Spawn a task that will push several messages to the client (does not matter what client
    // does). It yields the number of messages successfully sent.
    let mut send_task = tokio::spawn(async move {
        let n_msg = 20;
        for i in 0..n_msg {
            // In case of any websocket error, we exit.
            if sender
                .send(Message::Text(format!("Server message {i} ...").into()))
                .await
                .is_err()
            {
                return i;
            }
            tokio::time::sleep(std::time::Duration::from_millis(300)).await;
        }

        println!("Sending close to {who}...");
        if let Err(e) = sender
            .send(Message::Close(Some(CloseFrame {
                code: axum::extract::ws::close_code::NORMAL,
                reason: Utf8Bytes::from_static("Goodbye"),
            })))
            .await
        {
            println!("Could not send Close due to {e}, probably it is ok?");
        }
        n_msg
    });

    // This second task will receive messages from client and print them on server console.
    // It yields the number of messages successfully received.
    let mut recv_task = tokio::spawn(async move {
        let mut cnt = 0;
        while let Some(next) = receiver.next().await {
            let msg = match next {
                Ok(msg) => msg,
                // Report read errors instead of silently ending the loop.
                Err(e) => {
                    println!("Error receiving from {who}: {e}");
                    break;
                }
            };
            cnt += 1;
            // print message and break if instructed to do so
            if process_message(msg, who).is_break() {
                break;
            }
        }
        cnt
    });

    // If any one of the tasks exit, abort the other.
    tokio::select! {
        rv_a = (&mut send_task) => {
            match rv_a {
                Ok(a) => println!("{a} messages sent to {who}"),
                Err(a) => println!("Error sending messages {a:?}")
            }
            recv_task.abort();
        },
        rv_b = (&mut recv_task) => {
            match rv_b {
                Ok(b) => println!("Received {b} messages"),
                Err(b) => println!("Error receiving messages {b:?}")
            }
            send_task.abort();
        }
    }

    // returning from the handler closes the websocket connection
    println!("Websocket context {who} destroyed");
}
/// Prints a human-readable description of `msg`, received from `who`, to stdout.
///
/// Returns `ControlFlow::Break(())` for a `Close` message (with or without a close frame)
/// so the caller stops reading, and `ControlFlow::Continue(())` for every other message.
fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> {
    match msg {
        Message::Close(Some(frame)) => {
            println!(
                ">>> {} sent close with code {} and reason `{}`",
                who, frame.code, frame.reason
            );
            ControlFlow::Break(())
        }
        Message::Close(None) => {
            println!(">>> {who} somehow sent close message without CloseFrame");
            ControlFlow::Break(())
        }
        Message::Text(text) => {
            println!(">>> {who} sent str: {text:?}");
            ControlFlow::Continue(())
        }
        Message::Binary(data) => {
            println!(">>> {} sent {} bytes: {:?}", who, data.len(), data);
            ControlFlow::Continue(())
        }
        Message::Pong(payload) => {
            println!(">>> {who} sent pong with {payload:?}");
            ControlFlow::Continue(())
        }
        // axum already answers Pings by replying with a Pong that echoes the payload, so
        // nothing has to be done here; this arm only exposes the payload if you need it.
        Message::Ping(payload) => {
            println!(">>> {who} sent ping with {payload:?}");
            ControlFlow::Continue(())
        }
    }
}
//! Based on tokio-tungstenite example websocket client, but with multiple
//! concurrent websocket clients in one package
//!
//! This connects to the server at the `SERVER` URL using `N_CLIENTS`
//! concurrent connections, and then floods it with test messages over websocket.
//! Every message received is printed to stdout.
//!
//! Note that this is not currently optimized for performance, especially around
//! stdout mutex management. Rather it's intended to show an example of working with axum's
//! websocket server and how the client-side and server-side code can be quite similar.
//!
use futures_util::stream::FuturesUnordered;
use futures_util::{SinkExt, StreamExt};
use std::ops::ControlFlow;
use std::time::Instant;
use tokio_tungstenite::tungstenite::Utf8Bytes;
// we will use tungstenite for websocket client impl (same library as what axum is using)
use tokio_tungstenite::{
connect_async,
tungstenite::protocol::{frame::coding::CloseCode, CloseFrame, Message},
};
/// WebSocket endpoint of the example server.
const SERVER: &str = "ws://127.0.0.1:3000/ws";
/// Number of clients spawned concurrently; change to the desired number.
const N_CLIENTS: usize = 2;
/// Client entry point: spawns `N_CLIENTS` clients concurrently, waits for all of them to
/// finish and reports the total elapsed time.
#[tokio::main]
async fn main() {
    let start_time = Instant::now();
    // spawn several clients that will concurrently talk to the server
    let mut clients = (0..N_CLIENTS)
        .map(|cli| tokio::spawn(spawn_client(cli)))
        .collect::<FuturesUnordered<_>>();

    // Wait for all clients to exit. A client task that panicked or was cancelled surfaces
    // here as a `JoinError`; report it instead of silently discarding it.
    while let Some(result) = clients.next().await {
        if let Err(e) = result {
            println!("A client task failed: {e}");
        }
    }

    // The clients run concurrently, so the total time should be the same no matter how many
    // clients we spawn.
    println!(
        "Total time taken {:#?} with {N_CLIENTS} concurrent clients, should be about 6.45 seconds.",
        start_time.elapsed()
    );
}
/// Runs one client identified by `who`.
///
/// It connects to `SERVER` and sends a Ping. It then concurrently sends a series of text
/// messages followed by a Close, while printing everything it receives; whichever side
/// finishes first aborts the other. Connection failures are printed and the client
/// returns early instead of panicking.
async fn spawn_client(who: usize) {
    let ws_stream = match connect_async(SERVER).await {
        Ok((stream, response)) => {
            println!("Handshake for client {who} has been completed");
            // This will be the HTTP response, same as with server this is the last moment we
            // can still access HTTP stuff.
            println!("Server response was {response:?}");
            stream
        }
        Err(e) => {
            println!("WebSocket handshake for client {who} failed with {e}!");
            return;
        }
    };

    let (mut sender, mut receiver) = ws_stream.split();

    // We can ping the server for start. The server may have dropped the connection right
    // after the handshake; treat that as a failed client rather than panicking the task.
    if let Err(e) = sender
        .send(Message::Ping(axum::body::Bytes::from_static(
            b"Hello, Server!",
        )))
        .await
    {
        println!("Could not send ping from client {who} due to {e}!");
        return;
    }

    // spawn an async sender to push some more messages into the server
    let mut send_task = tokio::spawn(async move {
        for i in 1..30 {
            // On any websocket error, stop sending.
            if sender
                .send(Message::Text(format!("Message number {i}...").into()))
                .await
                .is_err()
            {
                // just as with server, if send fails there is nothing we can do but exit.
                return;
            }
            tokio::time::sleep(std::time::Duration::from_millis(300)).await;
        }

        // When we are done we may want our client to close connection cleanly.
        println!("Sending close to {who}...");
        if let Err(e) = sender
            .send(Message::Close(Some(CloseFrame {
                code: CloseCode::Normal,
                reason: Utf8Bytes::from_static("Goodbye"),
            })))
            .await
        {
            println!("Could not send Close due to {e:?}, probably it is ok?");
        }
    });

    // receiver just prints whatever it gets
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(msg)) = receiver.next().await {
            // print message and break if instructed to do so
            if process_message(msg, who).is_break() {
                break;
            }
        }
    });

    // wait for either task to finish and kill the other task
    tokio::select! {
        _ = (&mut send_task) => {
            recv_task.abort();
        },
        _ = (&mut recv_task) => {
            send_task.abort();
        }
    }
}
/// Prints a description of a message received by client `who` to stdout.
///
/// Unlike the server-side helper, tungstenite's `Message` also has a `Frame` variant to
/// handle, because this client talks to tungstenite directly instead of going through axum.
/// Returns `ControlFlow::Break(())` for `Close` and `ControlFlow::Continue(())` otherwise.
fn process_message(msg: Message, who: usize) -> ControlFlow<(), ()> {
    match msg {
        Message::Close(close) => {
            match close {
                Some(frame) => println!(
                    ">>> {} got close with code {} and reason `{}`",
                    who, frame.code, frame.reason
                ),
                None => println!(">>> {who} somehow got close message without CloseFrame"),
            }
            ControlFlow::Break(())
        }
        Message::Text(text) => {
            println!(">>> {who} got str: {text:?}");
            ControlFlow::Continue(())
        }
        Message::Binary(data) => {
            println!(">>> {} got {} bytes: {:?}", who, data.len(), data);
            ControlFlow::Continue(())
        }
        Message::Pong(payload) => {
            println!(">>> {who} got pong with {payload:?}");
            ControlFlow::Continue(())
        }
        // As on the axum server, tungstenite already answers Pings with a Pong that echoes
        // the payload; this arm only exposes the payload if you need it.
        Message::Ping(payload) => {
            println!(">>> {who} got ping with {payload:?}");
            ControlFlow::Continue(())
        }
        Message::Frame(_) => unreachable!("This is never supposed to happen"),
    }
}
<a>Open the browser console to see the websocket messages, then refresh the page to start a new exchange.</a>
<script src='script.js'></script>
// Browser client for the example websocket server. It greets the server once connected,
// logs every message it receives, sends a JSON blob after 1 s and closes after 3 s.
const socket = new WebSocket('ws://localhost:3000/ws');

// Send `data` only while the connection is open. `send()` throws an InvalidStateError
// while the socket is still connecting, and data handed to a closing/closed socket is
// never transmitted, so the timed sends below must not assume the socket is open.
function sendIfOpen(data) {
    if (socket.readyState === WebSocket.OPEN) {
        socket.send(data);
        return true;
    }
    console.log('Websocket not open (readyState ' + socket.readyState + '), not sending');
    return false;
}

// The 'open' event guarantees the socket is open, so a direct send is safe here.
socket.addEventListener('open', function (event) {
    socket.send('Hello Server!');
});

socket.addEventListener('message', function (event) {
    console.log('Message from server ', event.data);
});

setTimeout(() => {
    const obj = { hello: "world" };
    const blob = new Blob([JSON.stringify(obj, null, 2)], {
        type: "application/json",
    });
    console.log("Sending blob over websocket");
    sendIfOpen(blob);
}, 1000);

setTimeout(() => {
    sendIfOpen('About done here...');
    console.log("Sending close over websocket");
    // Application-defined close code (3000-4999 range) with a human-readable reason.
    socket.close(3000, "Crash and Burn!");
}, 3000);