[package]
name = "example-oauth"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
anyhow = "1"
async-session = "3.0.0"
axum = { path = "../../axum" }
axum-extra = { path = "../../axum-extra", features = ["typed-header"] }
http = "1.0.0"
oauth2 = "4.1"
# Use Rustls because it makes it easier to cross-compile on CI
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
//! Example OAuth (Discord) implementation.
//!
//! 1) Create a new application at <https://discord.com/developers/applications>
//! 2) Visit the OAuth2 tab to get your CLIENT_ID and CLIENT_SECRET
//! 3) Add a new redirect URI (for this example: `http://127.0.0.1:3000/auth/authorized`)
//! 4) Run with the following (replacing values appropriately):
//! ```not_rust
//! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth
//! ```
use anyhow::{anyhow, Context, Result};
use async_session::{MemoryStore, Session, SessionStore};
use axum::{
extract::{FromRef, FromRequestParts, OptionalFromRequestParts, Query, State},
http::{header::SET_COOKIE, HeaderMap},
response::{IntoResponse, Redirect, Response},
routing::get,
RequestPartsExt, Router,
};
use axum_extra::{headers, typed_header::TypedHeaderRejectionReason, TypedHeader};
use http::{header, request::Parts, StatusCode};
use oauth2::{
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};
use std::{convert::Infallible, env};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
// Name of the cookie that carries the session cookie value to and from the browser.
static COOKIE_NAME: &str = "SESSION";
// Session key under which the OAuth CSRF token is stored while the login is in flight.
static CSRF_TOKEN: &str = "csrf_token";
// Entry point: installs the tracing subscriber, builds the shared state and the router,
// then serves the application on `127.0.0.1:3000`. Every startup failure panics.
#[tokio::main]
async fn main() {
    // Take the log filter from the environment when one is available; otherwise fall back
    // to debug output for this crate.
    let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into());
    tracing_subscriber::registry()
        .with(env_filter)
        .with(tracing_subscriber::fmt::layer())
        .init();

    // `MemoryStore` is just used as an example. Don't use this in production.
    let app_state = AppState {
        store: MemoryStore::new(),
        oauth_client: oauth_client().unwrap(),
    };

    let routes = Router::new()
        .route("/", get(index))
        .route("/auth/discord", get(discord_auth))
        .route("/auth/authorized", get(login_authorized))
        .route("/protected", get(protected))
        .route("/logout", get(logout));
    let app = routes.with_state(app_state);

    let bind_result = tokio::net::TcpListener::bind("127.0.0.1:3000").await;
    let listener = bind_result
        .context("failed to bind TcpListener")
        .unwrap();

    tracing::debug!(
        "listening on {}",
        listener
            .local_addr()
            .context("failed to return local address")
            .unwrap()
    );

    axum::serve(listener, app).await.unwrap();
}
// State shared by all handlers. It is cloned into the router via `with_state`, and the
// individual parts are handed out through the `FromRef` impls for each field type.
#[derive(Clone)]
struct AppState {
    // Where sessions (CSRF token sessions and logged-in user sessions) are kept.
    store: MemoryStore,
    // OAuth2 client built once at startup by `oauth_client()`.
    oauth_client: BasicClient,
}
impl FromRef<AppState> for MemoryStore {
fn from_ref(state: &AppState) -> Self {
state.store.clone()
}
}
impl FromRef<AppState> for BasicClient {
fn from_ref(state: &AppState) -> Self {
state.oauth_client.clone()
}
}
// Builds the OAuth2 client from the process environment.
//
// Environment variables (* = required), with the value used when unset:
//   *"CLIENT_ID"     (no default)
//   *"CLIENT_SECRET" (no default)
//    "REDIRECT_URL"  "http://127.0.0.1:3000/auth/authorized"
//    "AUTH_URL"      "https://discord.com/api/oauth2/authorize?response_type=code"
//    "TOKEN_URL"     "https://discord.com/api/oauth2/token"
//
// Returns an error when a required variable is missing or when one of the URLs is rejected
// by its constructor.
fn oauth_client() -> Result<BasicClient, AppError> {
    // Optional settings: use the environment value when readable, the default otherwise.
    let env_or = |key: &str, default: &str| env::var(key).unwrap_or_else(|_| default.to_string());

    let client_id = env::var("CLIENT_ID").context("Missing CLIENT_ID!")?;
    let client_secret = env::var("CLIENT_SECRET").context("Missing CLIENT_SECRET!")?;
    let redirect_url = env_or("REDIRECT_URL", "http://127.0.0.1:3000/auth/authorized");
    let auth_url = env_or(
        "AUTH_URL",
        "https://discord.com/api/oauth2/authorize?response_type=code",
    );
    let token_url = env_or("TOKEN_URL", "https://discord.com/api/oauth2/token");

    // Validate in the order: authorization URL, token URL, redirect URL.
    let auth_url =
        AuthUrl::new(auth_url).context("failed to create new authorization server URL")?;
    let token_url =
        TokenUrl::new(token_url).context("failed to create new token endpoint URL")?;
    let redirect_url =
        RedirectUrl::new(redirect_url).context("failed to create new redirection URL")?;

    let client = BasicClient::new(
        ClientId::new(client_id),
        Some(ClientSecret::new(client_secret)),
        auth_url,
        Some(token_url),
    )
    .set_redirect_uri(redirect_url);
    Ok(client)
}
// The user data we'll get back from Discord.
// https://discord.com/developers/docs/resources/user#user-object-user-structure
//
// The value is deserialized from the JSON body of the "current user" request and is also
// what gets stored in (and read back from) the session under the "user" key.
#[derive(Debug, Serialize, Deserialize)]
struct User {
    id: String,
    // Optional: deserializes to `None` when the response carries no avatar value.
    avatar: Option<String>,
    username: String,
    discriminator: String,
}
// Landing page. The session is optional here: visitors without one get a hint on how to
// log in, logged-in users are greeted by name.
async fn index(user: Option<User>) -> impl IntoResponse {
    let Some(logged_in) = user else {
        return "You're not logged in.\nVisit `/auth/discord` to do so.".to_string();
    };

    format!(
        "Hey {}! You're logged in!\nYou may now access `/protected`.\nLog out with `/logout`.",
        logged_in.username
    )
}
// Starts the login flow: generates the authorization URL together with a fresh CSRF token,
// stores that token in a new session, and redirects the browser to the authorization URL
// with the session cookie attached.
async fn discord_auth(
    State(client): State<BasicClient>,
    State(store): State<MemoryStore>,
) -> Result<impl IntoResponse, AppError> {
    let (auth_url, csrf_token) = client
        .authorize_url(CsrfToken::new_random)
        .add_scope(Scope::new("identify".to_string()))
        .url();

    // The CSRF token lives in its own short-lived session until the provider calls back.
    let mut csrf_session = Session::new();
    csrf_session
        .insert(CSRF_TOKEN, &csrf_token)
        .context("failed in inserting CSRF token into session")?;

    // Persist the session; the store hands back the value to put into the cookie.
    let stored = store
        .store_session(csrf_session)
        .await
        .context("failed to store CSRF token session")?;
    let cookie_value = stored.context("unexpected error retrieving CSRF cookie value")?;

    // Send the session cookie along with the redirect.
    let set_cookie = format!("{COOKIE_NAME}={cookie_value}; SameSite=Lax; HttpOnly; Secure; Path=/");
    let mut response_headers = HeaderMap::new();
    response_headers.insert(
        SET_COOKIE,
        set_cookie.parse().context("failed to parse cookie")?,
    );

    let redirect = Redirect::to(auth_url.as_ref());
    Ok((response_headers, redirect))
}
// Valid user session required. If there is none, the `User` extractor rejects the request
// and the visitor is redirected to the auth page instead.
async fn protected(user: User) -> impl IntoResponse {
    format!(
        "Welcome to the protected area :)\nHere's your info:\n{:?}",
        user
    )
}
// Logs the visitor out by destroying the session referenced by the session cookie, then
// redirects to `/`.
//
// A request without a session cookie, or whose cookie no longer maps to a stored session,
// is treated as "already logged out" and is simply redirected. Errors from the session
// store itself are reported as `AppError`.
async fn logout(
    State(store): State<MemoryStore>,
    TypedHeader(cookies): TypedHeader<headers::Cookie>,
) -> Result<impl IntoResponse, AppError> {
    // Other cookies may be present without ours; that is not an error, there is just
    // nothing to log out from.
    let Some(cookie) = cookies.get(COOKIE_NAME) else {
        return Ok(Redirect::to("/"));
    };

    let session = store
        .load_session(cookie.to_string())
        .await
        .context("failed to load session")?;

    // No session active means there is nothing to destroy; just redirect.
    if let Some(session) = session {
        store
            .destroy_session(session)
            .await
            .context("failed to destroy session")?;
    }

    Ok(Redirect::to("/"))
}
// Query parameters of the redirect back to `/auth/authorized`.
#[derive(Debug, Deserialize)]
struct AuthRequest {
    // Authorization code that is exchanged for an access token.
    code: String,
    // Value that must equal the CSRF token stored in the session when the flow started.
    state: String,
}
// Checks that the `state` of the authorization callback equals the CSRF token that was
// stored in the session when the login flow started.
//
// The CSRF session is destroyed as part of the check (before the comparison), so it is
// gone whether or not the tokens match.
//
// Errors when the session cookie is missing, the session cannot be loaded or does not
// exist, the session holds no CSRF token, the session cannot be destroyed, or the tokens
// differ.
async fn csrf_token_validation_workflow(
    auth_request: &AuthRequest,
    cookies: &headers::Cookie,
    store: &MemoryStore,
) -> Result<(), AppError> {
    // Extract the cookie from the request
    let cookie = cookies
        .get(COOKIE_NAME)
        .context("unexpected error getting cookie name")?
        .to_string();

    // Load the session
    let Some(session) = store
        .load_session(cookie)
        .await
        .context("failed to load session")?
    else {
        return Err(anyhow!("Session not found").into());
    };

    // Extract the CSRF token from the session; `get` already yields an owned value, so it
    // stays usable after the session is handed to `destroy_session` below.
    let stored_csrf_token = session
        .get::<CsrfToken>(CSRF_TOKEN)
        .context("CSRF token not found in session")?;

    // Cleanup the CSRF token session
    store
        .destroy_session(session)
        .await
        .context("Failed to destroy old session")?;

    // Validate CSRF token is the same as the one in the auth request
    if *stored_csrf_token.secret() != auth_request.state {
        return Err(anyhow!("CSRF token mismatch").into());
    }

    Ok(())
}
// Callback the provider redirects to after the user authorized the application.
//
// Validates the CSRF token, exchanges the authorization code for an access token, fetches
// the user's profile with it, stores the profile in a fresh session and redirects to `/`
// with the new session cookie set.
//
// Any failure along the way is returned as an `AppError`.
async fn login_authorized(
    Query(query): Query<AuthRequest>,
    State(store): State<MemoryStore>,
    State(oauth_client): State<BasicClient>,
    TypedHeader(cookies): TypedHeader<headers::Cookie>,
) -> Result<impl IntoResponse, AppError> {
    csrf_token_validation_workflow(&query, &cookies, &store).await?;

    // Get an auth token
    let token = oauth_client
        .exchange_code(AuthorizationCode::new(query.code))
        .request_async(async_http_client)
        .await
        .context("failed in sending request to authorization server")?;

    // Fetch user data from discord
    let client = reqwest::Client::new();
    let user_data: User = client
        // https://discord.com/developers/docs/resources/user#get-current-user
        // Uses the same `discord.com` host as the default authorize/token endpoints.
        .get("https://discord.com/api/users/@me")
        .bearer_auth(token.access_token().secret())
        .send()
        .await
        .context("failed in sending request to target Url")?
        // Report a non-success status as such, instead of as a JSON decoding failure of
        // whatever error body came back.
        .error_for_status()
        .context("user data request returned an error status")?
        .json::<User>()
        .await
        .context("failed to deserialize response as JSON")?;

    // Create a new session filled with user data
    let mut session = Session::new();
    session
        .insert("user", &user_data)
        .context("failed in inserting serialized value into session")?;

    // Store session and get corresponding cookie
    let cookie = store
        .store_session(session)
        .await
        .context("failed to store session")?
        .context("unexpected error retrieving cookie value")?;

    // Build the cookie
    let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/");

    // Set cookie
    let mut headers = HeaderMap::new();
    headers.insert(
        SET_COOKIE,
        cookie.parse().context("failed to parse cookie")?,
    );

    Ok((headers, Redirect::to("/")))
}
// Rejection used by the `User` extractor: responding with it sends the visitor to the
// login route.
struct AuthRedirect;

impl IntoResponse for AuthRedirect {
    fn into_response(self) -> Response {
        let to_login = Redirect::temporary("/auth/discord");
        to_login.into_response()
    }
}
impl<S> FromRequestParts<S> for User
where
MemoryStore: FromRef<S>,
S: Send + Sync,
{
// If anything goes wrong or no session is found, redirect to the auth page
type Rejection = AuthRedirect;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let store = MemoryStore::from_ref(state);
let cookies = parts
.extract::<TypedHeader<headers::Cookie>>()
.await
.map_err(|e| match *e.name() {
header::COOKIE => match e.reason() {
TypedHeaderRejectionReason::Missing => AuthRedirect,
_ => panic!("unexpected error getting Cookie header(s): {e}"),
},
_ => panic!("unexpected error getting cookies: {e}"),
})?;
let session_cookie = cookies.get(COOKIE_NAME).ok_or(AuthRedirect)?;
let session = store
.load_session(session_cookie.to_string())
.await
.unwrap()
.ok_or(AuthRedirect)?;
let user = session.get::<User>("user").ok_or(AuthRedirect)?;
Ok(user)
}
}
// Optional variant of the `User` extractor (used for `Option<User>` arguments): never
// rejects, and yields `None` wherever the required extractor would have redirected.
impl<S> OptionalFromRequestParts<S> for User
where
    MemoryStore: FromRef<S>,
    S: Send + Sync,
{
    type Rejection = Infallible;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &S,
    ) -> Result<Option<Self>, Self::Rejection> {
        let required = <User as FromRequestParts<S>>::from_request_parts(parts, state).await;
        // `AuthRedirect` is the only possible rejection, so dropping it loses nothing.
        Ok(required.ok())
    }
}
// Application error type: a thin wrapper around `anyhow::Error` so that handlers can use
// `?` and still return something axum can turn into a response.
// For a simplified example of using anyhow in axum check /examples/anyhow-error-response
#[derive(Debug)]
struct AppError(anyhow::Error);
// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
fn into_response(self) -> Response {
tracing::error!("Application error: {:#}", self.0);
(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
}
}
// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
// `Result<_, AppError>`. That way you don't need to do that manually.
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}