1use axum::extract::{FromRef, FromRequestParts};
7use axum::http::header;
8use axum::http::request::Parts;
9
10use atrg_core::error::AtrgError;
11use atrg_core::state::AppState;
12
13use crate::jwt;
14use crate::session::{self, AtrgSession, AuthSource};
15
/// Optional-authentication extractor.
///
/// Holds `Some(session)` when `resolve_session` finds a usable credential on the
/// request, and `None` otherwise. Its `FromRequestParts` impl never rejects
/// (`Rejection = Infallible`), so handlers using it serve both anonymous and
/// signed-in callers.
pub struct AuthUser(pub Option<AtrgSession>);
31
/// Mandatory-authentication extractor.
///
/// Wraps the session resolved by `resolve_session`. When no session can be
/// resolved, extraction fails with `AtrgError::Auth("unauthenticated")`. How that
/// error is turned into a response is decided by `AtrgError`, which is defined
/// elsewhere.
pub struct RequireAuth(pub AtrgSession);
43
44impl<S> FromRequestParts<S> for AuthUser
45where
46 S: Send + Sync,
47 AppState: axum::extract::FromRef<S>,
48{
49 type Rejection = std::convert::Infallible;
50
51 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
52 let app_state = AppState::from_ref(state);
53 let session = resolve_session(parts, &app_state).await;
54 Ok(AuthUser(session))
55 }
56}
57
58impl<S> FromRequestParts<S> for RequireAuth
59where
60 S: Send + Sync,
61 AppState: axum::extract::FromRef<S>,
62{
63 type Rejection = AtrgError;
64
65 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
66 let app_state = AppState::from_ref(state);
67 let session = resolve_session(parts, &app_state).await;
68 match session {
69 Some(s) => Ok(RequireAuth(s)),
70 None => Err(AtrgError::Auth("unauthenticated".to_string())),
71 }
72 }
73}
74
75async fn resolve_session(parts: &Parts, state: &AppState) -> Option<AtrgSession> {
83 if let Some(auth_header) = parts.headers.get(header::AUTHORIZATION) {
85 if let Ok(auth_str) = auth_header.to_str() {
86 if let Some(token) = auth_str.strip_prefix("Bearer ") {
87 let token = token.trim();
88 if !token.is_empty() {
89 return resolve_bearer_token(token, state).await;
90 }
91 }
92 }
93 }
94
95 if let Some(cookie_header) = parts.headers.get(header::COOKIE) {
97 if let Ok(cookies) = cookie_header.to_str() {
98 if let Some(session_id) = extract_cookie_value(cookies, "atrg_session") {
99 return resolve_atrg_session(session_id, state).await;
100 }
101 }
102 }
103
104 None
105}
106
107async fn resolve_bearer_token(token: &str, state: &AppState) -> Option<AtrgSession> {
109 if jwt::looks_like_jwt(token) {
111 if let Ok(claims) = jwt::decode_claims_unverified(token) {
112 if jwt::verify_expiration(&claims).is_ok() {
114 let host = &state.config.app.host;
116 if jwt::verify_audience(&claims, host).is_ok() || claims.aud.is_none() {
117 tracing::debug!(
118 iss = %claims.iss,
119 sub = %claims.sub,
120 "accepted AT Protocol JWT (unverified signature — full verification requires identity resolver)"
121 );
122 return Some(AtrgSession {
123 did: claims.sub,
124 handle: String::new(), access_token: token.to_string(),
126 refresh_token: None,
127 expires_at: claims.exp.unwrap_or(0) as i64,
128 source: AuthSource::AtprotoJwt,
129 });
130 }
131 }
132 }
133 }
135
136 resolve_atrg_session(token, state).await
138}
139
140async fn resolve_atrg_session(session_id: &str, state: &AppState) -> Option<AtrgSession> {
142 match session::find_session(&state.db, session_id).await {
143 Ok(session) => session,
144 Err(e) => {
145 tracing::warn!(error = %e, "failed to look up session");
146 None
147 }
148 }
149}
150
/// Returns the value of the first cookie named `name` in a `Cookie` header string.
///
/// Pairs are separated by `;`. Whitespace around each pair, key and value is ignored.
/// Pairs without `=` are skipped, and the value keeps any further `=` characters.
/// An empty value (`name=`) is returned as `Some("")`.
pub(crate) fn extract_cookie_value<'a>(cookies: &'a str, name: &str) -> Option<&'a str> {
    for pair in cookies.split(';') {
        let Some((key, value)) = pair.trim().split_once('=') else {
            continue;
        };
        if key.trim() == name {
            return Some(value.trim());
        }
    }
    None
}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_cookie_value_present() {
        let cookies = "foo=bar; atrg_session=abc123; other=val";
        assert_eq!(
            extract_cookie_value(cookies, "atrg_session"),
            Some("abc123")
        );
    }

    #[test]
    fn extract_cookie_value_missing() {
        let cookies = "foo=bar; other=val";
        assert_eq!(extract_cookie_value(cookies, "atrg_session"), None);
    }

    #[test]
    fn extract_cookie_value_empty() {
        assert_eq!(extract_cookie_value("", "atrg_session"), None);
    }

    #[test]
    fn extract_cookie_value_single() {
        assert_eq!(
            extract_cookie_value("atrg_session=xyz", "atrg_session"),
            Some("xyz")
        );
    }

    #[test]
    fn extract_cookie_value_trims_whitespace_around_key_and_value() {
        assert_eq!(
            extract_cookie_value("  atrg_session =  abc  ; foo=bar", "atrg_session"),
            Some("abc")
        );
    }

    #[test]
    fn extract_cookie_value_keeps_equals_in_value() {
        assert_eq!(
            extract_cookie_value("atrg_session=a=b==", "atrg_session"),
            Some("a=b==")
        );
    }

    #[test]
    fn extract_cookie_value_skips_pairs_without_equals() {
        assert_eq!(
            extract_cookie_value("flag; atrg_session=ok", "atrg_session"),
            Some("ok")
        );
        assert_eq!(extract_cookie_value("atrg_session", "atrg_session"), None);
    }

    #[test]
    fn extract_cookie_value_requires_exact_name() {
        assert_eq!(
            extract_cookie_value("xatrg_session=1; atrg_session_old=2", "atrg_session"),
            None
        );
    }

    #[test]
    fn extract_cookie_value_first_match_wins() {
        assert_eq!(
            extract_cookie_value("atrg_session=first; atrg_session=second", "atrg_session"),
            Some("first")
        );
    }

    #[test]
    fn extract_cookie_value_empty_value_is_returned_verbatim() {
        assert_eq!(extract_cookie_value("atrg_session=", "atrg_session"), Some(""));
    }
}