1use axum::extract::Request;
4use axum::http::{HeaderName, HeaderValue};
5use axum::middleware::Next;
6
7static REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id");
8
9pub async fn request_id_middleware(mut req: Request, next: Next) -> axum::response::Response {
11 let request_id = req
13 .headers()
14 .get(&REQUEST_ID_HEADER)
15 .and_then(|v| v.to_str().ok())
16 .map(|s| s.to_string())
17 .unwrap_or_else(generate_request_id);
18
19 req.extensions_mut().insert(RequestId(request_id.clone()));
21
22 let mut response = next.run(req).await;
23
24 if let Ok(val) = HeaderValue::from_str(&request_id) {
26 response
27 .headers_mut()
28 .insert(REQUEST_ID_HEADER.clone(), val);
29 }
30
31 response
32}
33
/// Correlation ID attached to a request's extensions by `request_id_middleware`.
///
/// `Display` renders the bare ID string, so it can be used directly in log
/// fields and formatted messages.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);

impl std::fmt::Display for RequestId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
37
38fn generate_request_id() -> String {
39 use rand::Rng;
40 let mut bytes = [0u8; 16];
41 rand::thread_rng().fill(&mut bytes);
42 hex::encode(bytes)
43}
44
#[cfg(test)]
mod tests {
    use super::generate_request_id;

    /// Generated IDs are 16 bytes rendered as hex: exactly 32 hex digits.
    #[test]
    fn generate_request_id_is_32_hex_chars() {
        let generated = generate_request_id();
        assert_eq!(generated.len(), 32);
        assert!(generated.bytes().all(|b| b.is_ascii_hexdigit()));
    }

    /// Two consecutive calls should not collide.
    #[test]
    fn generate_request_id_is_unique() {
        let (first, second) = (generate_request_id(), generate_request_id());
        assert_ne!(first, second);
    }
}