// atrg_auth/extractor.rs

//! Axum extractors for authentication.
//!
//! - [`AuthUser`] — optional; returns `None` if not authenticated.
//! - [`RequireAuth`] — strict; rejects with 401 if not authenticated.
5
6use axum::extract::{FromRef, FromRequestParts};
7use axum::http::header;
8use axum::http::request::Parts;
9
10use atrg_core::error::AtrgError;
11use atrg_core::state::AppState;
12
13use crate::jwt;
14use crate::session::{self, AtrgSession, AuthSource};
15
16/// Optional authentication extractor.
17///
18/// Reads the `atrg_session` cookie or `Authorization: Bearer` header
19/// and resolves the session. Returns `AuthUser(None)` if no valid
20/// credential is found — does NOT reject the request.
21///
22/// ```rust,ignore
23/// async fn handler(AuthUser(user): AuthUser) -> impl IntoResponse {
24///     match user {
25///         Some(session) => Json(json!({"did": session.did})),
26///         None => Json(json!({"authenticated": false})),
27///     }
28/// }
29/// ```
30pub struct AuthUser(pub Option<AtrgSession>);
31
32/// Strict authentication extractor.
33///
34/// Same logic as `AuthUser`, but rejects with `401 Unauthorized` JSON
35/// if no valid session is found.
36///
37/// ```rust,ignore
38/// async fn handler(RequireAuth(session): RequireAuth) -> impl IntoResponse {
39///     Json(json!({"did": session.did}))
40/// }
41/// ```
42pub struct RequireAuth(pub AtrgSession);
43
44impl<S> FromRequestParts<S> for AuthUser
45where
46    S: Send + Sync,
47    AppState: axum::extract::FromRef<S>,
48{
49    type Rejection = std::convert::Infallible;
50
51    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
52        let app_state = AppState::from_ref(state);
53        let session = resolve_session(parts, &app_state).await;
54        Ok(AuthUser(session))
55    }
56}
57
58impl<S> FromRequestParts<S> for RequireAuth
59where
60    S: Send + Sync,
61    AppState: axum::extract::FromRef<S>,
62{
63    type Rejection = AtrgError;
64
65    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
66        let app_state = AppState::from_ref(state);
67        let session = resolve_session(parts, &app_state).await;
68        match session {
69            Some(s) => Ok(RequireAuth(s)),
70            None => Err(AtrgError::Auth("unauthenticated".to_string())),
71        }
72    }
73}
74
75/// Core session resolution logic shared by both extractors.
76///
77/// Priority:
78/// 1. `Authorization: Bearer <token>` header
79///    a. If token looks like a JWT → try AT Protocol JWT verification
80///    b. Otherwise → look up as atrg session ID
81/// 2. `atrg_session=<id>` cookie → look up as atrg session ID
82async fn resolve_session(parts: &Parts, state: &AppState) -> Option<AtrgSession> {
83    // Try Authorization header first
84    if let Some(auth_header) = parts.headers.get(header::AUTHORIZATION) {
85        if let Ok(auth_str) = auth_header.to_str() {
86            if let Some(token) = auth_str.strip_prefix("Bearer ") {
87                let token = token.trim();
88                if !token.is_empty() {
89                    return resolve_bearer_token(token, state).await;
90                }
91            }
92        }
93    }
94
95    // Fall back to cookie
96    if let Some(cookie_header) = parts.headers.get(header::COOKIE) {
97        if let Ok(cookies) = cookie_header.to_str() {
98            if let Some(session_id) = extract_cookie_value(cookies, "atrg_session") {
99                return resolve_atrg_session(session_id, state).await;
100            }
101        }
102    }
103
104    None
105}
106
107/// Resolve a bearer token — either JWT or atrg session.
108async fn resolve_bearer_token(token: &str, state: &AppState) -> Option<AtrgSession> {
109    // If it looks like a JWT, try parsing it as an AT Protocol JWT
110    if jwt::looks_like_jwt(token) {
111        if let Ok(claims) = jwt::decode_claims_unverified(token) {
112            // Verify expiration
113            if jwt::verify_expiration(&claims).is_ok() {
114                // Verify audience against our host
115                let host = &state.config.app.host;
116                if jwt::verify_audience(&claims, host).is_ok() || claims.aud.is_none() {
117                    tracing::debug!(
118                        iss = %claims.iss,
119                        sub = %claims.sub,
120                        "accepted AT Protocol JWT (unverified signature — full verification requires identity resolver)"
121                    );
122                    return Some(AtrgSession {
123                        did: claims.sub,
124                        handle: String::new(), // handle not in JWT; caller can resolve
125                        access_token: token.to_string(),
126                        refresh_token: None,
127                        expires_at: claims.exp.unwrap_or(0) as i64,
128                        source: AuthSource::AtprotoJwt,
129                    });
130                }
131            }
132        }
133        // If JWT parsing failed, fall through to session lookup
134    }
135
136    // Try as atrg session token
137    resolve_atrg_session(token, state).await
138}
139
140/// Look up an atrg session by ID from the database.
141async fn resolve_atrg_session(session_id: &str, state: &AppState) -> Option<AtrgSession> {
142    match session::find_session(&state.db, session_id).await {
143        Ok(session) => session,
144        Err(e) => {
145            tracing::warn!(error = %e, "failed to look up session");
146            None
147        }
148    }
149}
150
/// Parse a `Cookie` header string and extract the value for a given name.
///
/// The header is split on `;`. Each piece is trimmed and split at its
/// first `=`, so a value may itself contain `=`. Pieces without `=` are
/// ignored. The name comparison is exact (case-sensitive) after trimming
/// surrounding whitespace.
///
/// Returns the trimmed value of the first matching pair, borrowed from
/// `cookies`, or `None` when no pair has that name. An empty value
/// (`name=`) is returned as `Some("")`.
pub(crate) fn extract_cookie_value<'a>(cookies: &'a str, name: &str) -> Option<&'a str> {
    for pair in cookies.split(';') {
        let Some((key, value)) = pair.trim().split_once('=') else {
            continue;
        };
        if key.trim() == name {
            return Some(value.trim());
        }
    }
    None
}
162
#[cfg(test)]
mod tests {
    //! Unit tests for the cookie-header parsing helper.

    use super::*;

    #[test]
    fn extract_cookie_value_present() {
        let cookies = "foo=bar; atrg_session=abc123; other=val";
        assert_eq!(
            extract_cookie_value(cookies, "atrg_session"),
            Some("abc123")
        );
    }

    #[test]
    fn extract_cookie_value_missing() {
        let cookies = "foo=bar; other=val";
        assert_eq!(extract_cookie_value(cookies, "atrg_session"), None);
    }

    #[test]
    fn extract_cookie_value_empty() {
        assert_eq!(extract_cookie_value("", "atrg_session"), None);
    }

    #[test]
    fn extract_cookie_value_single() {
        assert_eq!(
            extract_cookie_value("atrg_session=xyz", "atrg_session"),
            Some("xyz")
        );
    }

    // Whitespace around the name and value is trimmed away.
    #[test]
    fn extract_cookie_value_trims_whitespace() {
        assert_eq!(
            extract_cookie_value("  atrg_session =  abc  ;foo=bar", "atrg_session"),
            Some("abc")
        );
    }

    // Only the first `=` separates name from value.
    #[test]
    fn extract_cookie_value_keeps_equals_in_value() {
        assert_eq!(
            extract_cookie_value("atrg_session=a=b=c", "atrg_session"),
            Some("a=b=c")
        );
    }

    // A piece without `=` is skipped and does not stop the search.
    #[test]
    fn extract_cookie_value_skips_valueless_pairs() {
        assert_eq!(
            extract_cookie_value("flag; atrg_session=abc", "atrg_session"),
            Some("abc")
        );
    }

    // When a name appears more than once, the first occurrence wins.
    #[test]
    fn extract_cookie_value_first_match_wins() {
        assert_eq!(
            extract_cookie_value("atrg_session=one; atrg_session=two", "atrg_session"),
            Some("one")
        );
    }
}