1use axum::extract::{Query, State};
13use axum::http::{HeaderValue, StatusCode};
14use axum::response::{IntoResponse, Response};
15use axum::routing::{get, post};
16use axum::{Json, Router};
17
18use atrg_core::error::AtrgError;
19use atrg_core::state::AppState;
20
21use crate::extractor::RequireAuth;
22use crate::session;
23
24pub fn routes() -> Router<AppState> {
26 Router::new()
27 .route("/auth/login", get(login))
28 .route("/auth/callback", get(callback))
29 .route("/auth/logout", post(logout))
30 .route("/auth/session", get(get_session))
31}
32
33pub fn auth_router() -> Router<AppState> {
47 routes()
48 .route("/client-metadata.json", get(client_metadata))
49 .route("/.well-known/oauth-protected-resource", get(well_known))
50}
51
52pub async fn client_metadata(State(state): State<AppState>) -> Json<serde_json::Value> {
56 let config = &state.config.auth;
57 Json(serde_json::json!({
58 "client_id": config.client_id,
59 "client_name": state.config.app.name,
60 "client_uri": format!("http://{}:{}", state.config.app.host, state.config.app.port),
61 "redirect_uris": [config.redirect_uri],
62 "scope": config.scope,
63 "grant_types": ["authorization_code", "refresh_token"],
64 "response_types": ["code"],
65 "application_type": "web",
66 "token_endpoint_auth_method": "none",
67 "dpop_bound_access_tokens": true,
68 }))
69}
70
71pub async fn well_known(State(state): State<AppState>) -> Json<serde_json::Value> {
73 let base_url = format!("http://{}:{}", state.config.app.host, state.config.app.port);
74 Json(serde_json::json!({
75 "resource": base_url,
76 "authorization_servers": [],
77 "scopes_supported": [state.config.auth.scope],
78 "bearer_methods_supported": ["header"],
79 }))
80}
81
82#[derive(serde::Deserialize)]
84pub struct LoginQuery {
85 handle: Option<String>,
87}
88
89async fn login(
95 State(_state): State<AppState>,
96 Query(params): Query<LoginQuery>,
97) -> Result<Response, AtrgError> {
98 let handle = params
99 .handle
100 .filter(|h| !h.trim().is_empty())
101 .ok_or_else(|| AtrgError::BadRequest("missing 'handle' query parameter".to_string()))?;
102
103 tracing::info!(handle = %handle, "OAuth login initiated");
104
105 Ok((
108 StatusCode::OK,
109 Json(serde_json::json!({
110 "status": "oauth_not_yet_wired",
111 "message": "OAuth PKCE flow will be wired via atproto-oauth-axum in the next iteration. For now, use the session injection API for testing.",
112 "handle": handle,
113 })),
114 )
115 .into_response())
116}
117
118async fn callback(State(_state): State<AppState>) -> Result<Response, AtrgError> {
122 Ok((
125 StatusCode::OK,
126 Json(serde_json::json!({
127 "status": "callback_stub",
128 "message": "OAuth callback will be implemented with atproto-oauth-axum.",
129 })),
130 )
131 .into_response())
132}
133
134async fn logout(
138 State(state): State<AppState>,
139 headers: axum::http::HeaderMap,
140) -> Result<Response, AtrgError> {
141 let session_id = extract_session_id(&headers);
143
144 if let Some(sid) = session_id {
145 session::delete_session(&state.db, sid)
146 .await
147 .map_err(AtrgError::Internal)?;
148 tracing::info!("session deleted via logout");
149 }
150
151 let mut response = StatusCode::NO_CONTENT.into_response();
153
154 let is_secure = state.config.app.environment != "development";
155 let cookie_value = format!(
156 "atrg_session=; Path=/; Max-Age=0; HttpOnly; SameSite=Lax{}",
157 if is_secure { "; Secure" } else { "" }
158 );
159
160 if let Ok(val) = HeaderValue::from_str(&cookie_value) {
161 response.headers_mut().insert("set-cookie", val);
162 }
163
164 Ok(response)
165}
166
167async fn get_session(RequireAuth(user): RequireAuth) -> Json<serde_json::Value> {
171 Json(serde_json::json!({
172 "did": user.did,
173 "handle": user.handle,
174 "expires_at": user.expires_at,
175 }))
176}
177
178fn extract_session_id(headers: &axum::http::HeaderMap) -> Option<&str> {
180 if let Some(auth) = headers.get(axum::http::header::AUTHORIZATION) {
182 if let Ok(s) = auth.to_str() {
183 if let Some(token) = s.strip_prefix("Bearer ") {
184 return Some(token.trim());
185 }
186 }
187 }
188
189 if let Some(cookie) = headers.get(axum::http::header::COOKIE) {
191 if let Ok(cookies) = cookie.to_str() {
192 return crate::extractor::extract_cookie_value(cookies, "atrg_session");
193 }
194 }
195
196 None
197}
198
199pub fn spawn_cleanup_task(pool: sqlx::SqlitePool) {
201 tokio::spawn(async move {
202 let mut interval = tokio::time::interval(std::time::Duration::from_secs(600)); loop {
204 interval.tick().await;
205 if let Err(e) = session::cleanup_expired_sessions(&pool).await {
206 tracing::warn!(error = %e, "session cleanup failed");
207 }
208 if let Err(e) = session::cleanup_expired_oauth_states(&pool).await {
209 tracing::warn!(error = %e, "oauth state cleanup failed");
210 }
211 }
212 });
213}
214
#[cfg(test)]
mod tests {
    //! Integration-style tests that drive the auth router in-process against
    //! an in-memory SQLite database.

    use super::*;
    use std::sync::Arc;

    use atrg_core::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
    use axum::body::Body;
    use http_body_util::BodyExt;
    use hyper::Request;
    use tower::ServiceExt;

    /// Lifetime, in seconds, given to sessions seeded by the tests (24 hours).
    const SESSION_TTL_SECS: i64 = 86400;

    /// Builds an `AppState` backed by a fresh in-memory database with the
    /// internal migrations applied.
    async fn test_state() -> AppState {
        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();
        AppState {
            config: Arc::new(Config {
                app: AppConfig {
                    name: "test".into(),
                    host: "127.0.0.1".into(),
                    port: 3000,
                    secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
                    cors_origins: vec![],
                    environment: "development".into(),
                },
                auth: AuthConfig {
                    client_id: "http://localhost:3000/client-metadata.json".into(),
                    redirect_uri: "http://localhost:3000/auth/callback".into(),
                    scope: "atproto transition:generic".into(),
                },
                database: DatabaseConfig {
                    url: "sqlite::memory:".into(),
                },
                jetstream: None,
                firehose: None,
                feed_generator: None,
                labeler: None,
                rate_limit: None,
            }),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
        }
    }

    /// Builds the router under test from the public `auth_router()` so the
    /// tests exercise the same route table callers get.
    fn test_router(state: AppState) -> Router {
        auth_router().with_state(state)
    }

    /// Sends a single request through a fresh router bound to `state`.
    async fn send(state: AppState, request: Request<Body>) -> axum::response::Response {
        test_router(state).oneshot(request).await.unwrap()
    }

    /// Builds a body-less `GET` request for `uri`.
    fn get_request(uri: &str) -> Request<Body> {
        Request::get(uri).body(Body::empty()).unwrap()
    }

    /// Collects the response body and parses it as JSON.
    async fn body_json(resp: axum::response::Response) -> serde_json::Value {
        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Stores a session for `did`/`handle` that expires `SESSION_TTL_SECS`
    /// from now and returns its freshly generated id.
    async fn seed_session(state: &AppState, did: &str, handle: &str) -> String {
        let sid = session::generate_session_id();
        // Current Unix time comfortably fits in an i64.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;
        session::create_session(
            &state.db,
            &sid,
            did,
            handle,
            "tok",
            None,
            now + SESSION_TTL_SECS,
        )
        .await
        .unwrap();
        sid
    }

    #[tokio::test]
    async fn client_metadata_has_required_fields() {
        let state = test_state().await;
        let resp = send(state, get_request("/client-metadata.json")).await;
        assert_eq!(resp.status(), 200);
        let body = body_json(resp).await;
        assert!(body["client_id"].is_string());
        assert!(body["redirect_uris"].is_array());
        assert!(body["scope"].is_string());
        assert!(body["application_type"].is_string());
        assert!(body["grant_types"].is_array());
        assert!(body["response_types"].is_array());
        assert!(body["dpop_bound_access_tokens"].is_boolean());
    }

    #[tokio::test]
    async fn well_known_returns_json() {
        let state = test_state().await;
        let resp = send(state, get_request("/.well-known/oauth-protected-resource")).await;
        assert_eq!(resp.status(), 200);
        let body = body_json(resp).await;
        assert!(body["resource"].is_string());
        assert!(body["scopes_supported"].is_array());
    }

    #[tokio::test]
    async fn login_without_handle_returns_400() {
        let state = test_state().await;
        let resp = send(state, get_request("/auth/login")).await;
        assert_eq!(resp.status(), 400);
    }

    #[tokio::test]
    async fn session_without_auth_returns_401() {
        let state = test_state().await;
        let resp = send(state, get_request("/auth/session")).await;
        assert_eq!(resp.status(), 401);
    }

    #[tokio::test]
    async fn session_with_valid_session_returns_200() {
        let state = test_state().await;
        let sid = seed_session(&state, "did:plc:test", "alice.test").await;

        let request = Request::get("/auth/session")
            .header("cookie", format!("atrg_session={sid}"))
            .body(Body::empty())
            .unwrap();
        let resp = send(state, request).await;

        assert_eq!(resp.status(), 200);
        let body = body_json(resp).await;
        assert_eq!(body["did"], "did:plc:test");
        assert_eq!(body["handle"], "alice.test");
    }

    #[tokio::test]
    async fn session_with_bearer_token_returns_200() {
        let state = test_state().await;
        let sid = seed_session(&state, "did:plc:bearer", "bob.test").await;

        let request = Request::get("/auth/session")
            .header("authorization", format!("Bearer {sid}"))
            .body(Body::empty())
            .unwrap();
        let resp = send(state, request).await;

        assert_eq!(resp.status(), 200);
        let body = body_json(resp).await;
        assert_eq!(body["did"], "did:plc:bearer");
    }

    #[tokio::test]
    async fn logout_clears_session() {
        let state = test_state().await;
        let sid = seed_session(&state, "did:plc:logout", "logout.test").await;

        let request = Request::post("/auth/logout")
            .header("cookie", format!("atrg_session={sid}"))
            .body(Body::empty())
            .unwrap();
        // Clone the state so the database can be inspected after the request.
        let resp = send(state.clone(), request).await;

        assert_eq!(resp.status(), 204);
        assert!(resp.headers().get("set-cookie").is_some());

        let remaining = session::find_session(&state.db, &sid).await.unwrap();
        assert!(remaining.is_none());
    }
}