Skip to main content

atrg_core/
cors.rs

//! CORS layer builder.
//!
//! Constructs a [`tower_http::cors::CorsLayer`] from the allowed origins
//! listed in `atrg.toml` under `[app].cors_origins`.
5
6use axum::http::{self, Method};
7use tower_http::cors::CorsLayer;
8
9/// Build a CORS layer from the configured allowed origins.
10///
11/// - Empty list → same-origin only (no `Access-Control-Allow-Origin` header).
12/// - `["*"]` → fully permissive (all origins, all methods).
13/// - Specific origins → allowlist with credentials support.
14pub fn build_cors_layer(origins: &[String]) -> CorsLayer {
15    if origins.is_empty() {
16        // Same-origin only — no CORS headers emitted.
17        CorsLayer::new()
18    } else if origins.len() == 1 && origins[0] == "*" {
19        tracing::warn!("CORS configured with wildcard '*' — all origins allowed");
20        CorsLayer::permissive()
21    } else {
22        let parsed_origins: Vec<http::HeaderValue> = origins
23            .iter()
24            .filter_map(|o| match o.parse::<http::HeaderValue>() {
25                Ok(v) => Some(v),
26                Err(e) => {
27                    tracing::warn!(origin = %o, error = %e, "Skipping invalid CORS origin");
28                    None
29                }
30            })
31            .collect();
32
33        CorsLayer::new()
34            .allow_origin(parsed_origins)
35            .allow_methods([
36                Method::GET,
37                Method::POST,
38                Method::PUT,
39                Method::DELETE,
40                Method::PATCH,
41                Method::OPTIONS,
42            ])
43            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
44            .allow_credentials(true)
45    }
46}
47
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned `String`s that `build_cors_layer` accepts.
    fn to_owned_origins(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|s| String::from(*s)).collect()
    }

    #[test]
    fn empty_origins_returns_layer() {
        // No origins selects the restrictive (same-origin) branch; building it
        // must not panic.
        let none: &[String] = &[];
        let _ = build_cors_layer(none);
    }

    #[test]
    fn wildcard_returns_permissive_layer() {
        // A lone "*" selects the permissive branch.
        let _ = build_cors_layer(&to_owned_origins(&["*"]));
    }

    #[test]
    fn specific_origins_returns_layer() {
        // Well-formed origins go through the allowlist branch.
        let allowlist =
            to_owned_origins(&["http://localhost:5173", "https://myapp.example.com"]);
        let _ = build_cors_layer(&allowlist);
    }

    #[test]
    fn invalid_origin_is_skipped() {
        // The NUL byte makes the second entry an invalid header value; it has to
        // be filtered out rather than causing a panic.
        let mixed = to_owned_origins(&["http://localhost:5173", "not a valid \x00 header"]);
        let _ = build_cors_layer(&mixed);
    }
}