Skip to main content

atrg_auth/
routes.rs

//! OAuth and session HTTP routes.
//!
//! These routes are served once the router returned by [`auth_router`] is
//! registered with `AtrgApp::with_auth_routes` and the app is started with
//! `AtrgApp::run()`:
//!
//! - `GET /auth/login?handle=<handle>` — initiate OAuth (currently a stub)
//! - `GET /auth/callback` — OAuth callback (currently a stub)
//! - `POST /auth/logout` — clear session
//! - `GET /auth/session` — current session info (JSON)
//! - `GET /client-metadata.json` — OAuth client metadata
//! - `GET /.well-known/oauth-protected-resource` — OAuth protected resource metadata
11
12use axum::extract::{Query, State};
13use axum::http::{HeaderValue, StatusCode};
14use axum::response::{IntoResponse, Response};
15use axum::routing::{get, post};
16use axum::{Json, Router};
17
18use atrg_core::error::AtrgError;
19use atrg_core::state::AppState;
20
21use crate::extractor::RequireAuth;
22use crate::session;
23
24/// Build the auth router with all authentication routes.
25pub fn routes() -> Router<AppState> {
26    Router::new()
27        .route("/auth/login", get(login))
28        .route("/auth/callback", get(callback))
29        .route("/auth/logout", post(logout))
30        .route("/auth/session", get(get_session))
31}
32
33/// Returns a single router containing all atrg built-in auth routes:
34/// `/auth/*`, `/client-metadata.json`, and `/.well-known/oauth-protected-resource`.
35///
36/// This is the recommended way to wire auth into [`atrg_core::AtrgApp`]:
37///
38/// ```rust,no_run
39/// use atrg_core::AtrgApp;
40///
41/// AtrgApp::new()
42///     .with_auth_routes(atrg_auth::routes::auth_router())
43///     .with_cleanup_task(atrg_auth::routes::spawn_cleanup_task)
44///     .run();
45/// ```
46pub fn auth_router() -> Router<AppState> {
47    routes()
48        .route("/client-metadata.json", get(client_metadata))
49        .route("/.well-known/oauth-protected-resource", get(well_known))
50}
51
52/// OAuth client metadata endpoint.
53///
54/// Returns the JSON document required by the AT Protocol OAuth spec.
55pub async fn client_metadata(State(state): State<AppState>) -> Json<serde_json::Value> {
56    let config = &state.config.auth;
57    Json(serde_json::json!({
58        "client_id": config.client_id,
59        "client_name": state.config.app.name,
60        "client_uri": format!("http://{}:{}", state.config.app.host, state.config.app.port),
61        "redirect_uris": [config.redirect_uri],
62        "scope": config.scope,
63        "grant_types": ["authorization_code", "refresh_token"],
64        "response_types": ["code"],
65        "application_type": "web",
66        "token_endpoint_auth_method": "none",
67        "dpop_bound_access_tokens": true,
68    }))
69}
70
71/// OAuth protected resource metadata endpoint.
72pub async fn well_known(State(state): State<AppState>) -> Json<serde_json::Value> {
73    let base_url = format!("http://{}:{}", state.config.app.host, state.config.app.port);
74    Json(serde_json::json!({
75        "resource": base_url,
76        "authorization_servers": [],
77        "scopes_supported": [state.config.auth.scope],
78        "bearer_methods_supported": ["header"],
79    }))
80}
81
82/// Login query parameters.
83#[derive(serde::Deserialize)]
84pub struct LoginQuery {
85    /// The user's AT Protocol handle.
86    handle: Option<String>,
87}
88
89/// `GET /auth/login?handle=<handle>`
90///
91/// In a full implementation, this initiates the OAuth PKCE flow with the
92/// user's PDS. For now, this is a stub that validates the handle parameter
93/// and returns an error explaining OAuth is not yet wired to a real PDS.
94async fn login(
95    State(_state): State<AppState>,
96    Query(params): Query<LoginQuery>,
97) -> Result<Response, AtrgError> {
98    let handle = params
99        .handle
100        .filter(|h| !h.trim().is_empty())
101        .ok_or_else(|| AtrgError::BadRequest("missing 'handle' query parameter".to_string()))?;
102
103    tracing::info!(handle = %handle, "OAuth login initiated");
104
105    // TODO(phase2-full): Wire up atproto-oauth-axum for real OAuth flow.
106    // For now, return a JSON response explaining the flow.
107    Ok((
108        StatusCode::OK,
109        Json(serde_json::json!({
110            "status": "oauth_not_yet_wired",
111            "message": "OAuth PKCE flow will be wired via atproto-oauth-axum in the next iteration. For now, use the session injection API for testing.",
112            "handle": handle,
113        })),
114    )
115        .into_response())
116}
117
118/// `GET /auth/callback`
119///
120/// OAuth callback handler. Stub for now.
121async fn callback(State(_state): State<AppState>) -> Result<Response, AtrgError> {
122    // TODO(phase2-full): Process OAuth callback, exchange code for tokens,
123    // create session, set cookie, redirect.
124    Ok((
125        StatusCode::OK,
126        Json(serde_json::json!({
127            "status": "callback_stub",
128            "message": "OAuth callback will be implemented with atproto-oauth-axum.",
129        })),
130    )
131        .into_response())
132}
133
134/// `POST /auth/logout`
135///
136/// Clears the session cookie and deletes the session from the database.
137async fn logout(
138    State(state): State<AppState>,
139    headers: axum::http::HeaderMap,
140) -> Result<Response, AtrgError> {
141    // Try to find session ID from cookie or bearer
142    let session_id = extract_session_id(&headers);
143
144    if let Some(sid) = session_id {
145        session::delete_session(&state.db, sid)
146            .await
147            .map_err(AtrgError::Internal)?;
148        tracing::info!("session deleted via logout");
149    }
150
151    // Build response with cookie clearing
152    let mut response = StatusCode::NO_CONTENT.into_response();
153
154    let is_secure = state.config.app.environment != "development";
155    let cookie_value = format!(
156        "atrg_session=; Path=/; Max-Age=0; HttpOnly; SameSite=Lax{}",
157        if is_secure { "; Secure" } else { "" }
158    );
159
160    if let Ok(val) = HeaderValue::from_str(&cookie_value) {
161        response.headers_mut().insert("set-cookie", val);
162    }
163
164    Ok(response)
165}
166
167/// `GET /auth/session`
168///
169/// Returns the current session info or 401.
170async fn get_session(RequireAuth(user): RequireAuth) -> Json<serde_json::Value> {
171    Json(serde_json::json!({
172        "did": user.did,
173        "handle": user.handle,
174        "expires_at": user.expires_at,
175    }))
176}
177
178/// Extract session ID from Authorization header or cookie.
179fn extract_session_id(headers: &axum::http::HeaderMap) -> Option<&str> {
180    // Try bearer token
181    if let Some(auth) = headers.get(axum::http::header::AUTHORIZATION) {
182        if let Ok(s) = auth.to_str() {
183            if let Some(token) = s.strip_prefix("Bearer ") {
184                return Some(token.trim());
185            }
186        }
187    }
188
189    // Try cookie
190    if let Some(cookie) = headers.get(axum::http::header::COOKIE) {
191        if let Ok(cookies) = cookie.to_str() {
192            return crate::extractor::extract_cookie_value(cookies, "atrg_session");
193        }
194    }
195
196    None
197}
198
199/// Spawn a periodic cleanup task for expired OAuth states and sessions.
200pub fn spawn_cleanup_task(pool: sqlx::SqlitePool) {
201    tokio::spawn(async move {
202        let mut interval = tokio::time::interval(std::time::Duration::from_secs(600)); // every 10 min
203        loop {
204            interval.tick().await;
205            if let Err(e) = session::cleanup_expired_sessions(&pool).await {
206                tracing::warn!(error = %e, "session cleanup failed");
207            }
208            if let Err(e) = session::cleanup_expired_oauth_states(&pool).await {
209                tracing::warn!(error = %e, "oauth state cleanup failed");
210            }
211        }
212    });
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use atrg_core::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
    use axum::body::Body;
    use http_body_util::BodyExt;
    use hyper::Request;
    use tower::ServiceExt;

    /// Lifetime of seeded test sessions, in seconds (one day).
    const ONE_DAY_SECS: i64 = 86_400;

    /// Builds an [`AppState`] for tests. It uses a fresh in-memory SQLite
    /// database with the internal migrations applied and runs in the
    /// `development` environment.
    async fn test_state() -> AppState {
        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();
        AppState {
            config: Arc::new(Config {
                app: AppConfig {
                    name: "test".into(),
                    host: "127.0.0.1".into(),
                    port: 3000,
                    secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
                    cors_origins: vec![],
                    environment: "development".into(),
                },
                auth: AuthConfig {
                    client_id: "http://localhost:3000/client-metadata.json".into(),
                    redirect_uri: "http://localhost:3000/auth/callback".into(),
                    scope: "atproto transition:generic".into(),
                },
                database: DatabaseConfig {
                    url: "sqlite::memory:".into(),
                },
                jetstream: None,
                firehose: None,
                feed_generator: None,
                labeler: None,
                rate_limit: None,
            }),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
        }
    }

    /// The router under test. It uses [`auth_router`], the same router that
    /// applications register, so the route wiring is exercised along with the
    /// handlers.
    fn test_router(state: AppState) -> Router {
        auth_router().with_state(state)
    }

    /// Collects a response body and parses it as JSON.
    async fn body_json(resp: axum::response::Response) -> serde_json::Value {
        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Current Unix time plus one day, in seconds.
    fn expires_in_one_day() -> i64 {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is after the Unix epoch")
            .as_secs();
        i64::try_from(now).expect("Unix timestamp fits in i64") + ONE_DAY_SECS
    }

    /// Inserts a session row for `did`/`handle` under `sid`, expiring in one day.
    async fn seed_session(state: &AppState, sid: &str, did: &str, handle: &str) {
        session::create_session(
            &state.db,
            sid,
            did,
            handle,
            "tok",
            None,
            expires_in_one_day(),
        )
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn client_metadata_has_required_fields() {
        let app = test_router(test_state().await);
        let resp = app
            .oneshot(
                Request::get("/client-metadata.json")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        let body = body_json(resp).await;
        assert!(body["client_id"].is_string());
        assert!(body["redirect_uris"].is_array());
        assert!(body["scope"].is_string());
        assert!(body["application_type"].is_string());
        assert!(body["grant_types"].is_array());
        assert!(body["response_types"].is_array());
        assert_eq!(body["dpop_bound_access_tokens"], true);
    }

    #[tokio::test]
    async fn well_known_returns_json() {
        let app = test_router(test_state().await);
        let resp = app
            .oneshot(
                Request::get("/.well-known/oauth-protected-resource")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        let body = body_json(resp).await;
        assert!(body["resource"].is_string());
        assert!(body["scopes_supported"].is_array());
    }

    #[tokio::test]
    async fn login_without_handle_returns_400() {
        let app = test_router(test_state().await);
        let resp = app
            .oneshot(Request::get("/auth/login").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), 400);
    }

    #[tokio::test]
    async fn login_with_blank_handle_returns_400() {
        let app = test_router(test_state().await);
        let resp = app
            .oneshot(
                Request::get("/auth/login?handle=%20%20")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 400);
    }

    #[tokio::test]
    async fn login_with_handle_returns_stub_response() {
        let app = test_router(test_state().await);
        let resp = app
            .oneshot(
                Request::get("/auth/login?handle=alice.test")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        let body = body_json(resp).await;
        assert_eq!(body["status"], "oauth_not_yet_wired");
        assert_eq!(body["handle"], "alice.test");
    }

    #[tokio::test]
    async fn session_without_auth_returns_401() {
        let app = test_router(test_state().await);
        let resp = app
            .oneshot(Request::get("/auth/session").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), 401);
    }

    #[tokio::test]
    async fn session_with_valid_session_returns_200() {
        let state = test_state().await;
        let sid = session::generate_session_id();
        seed_session(&state, &sid, "did:plc:test", "alice.test").await;

        let app = test_router(state);
        let resp = app
            .oneshot(
                Request::get("/auth/session")
                    .header("cookie", format!("atrg_session={sid}"))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        let body = body_json(resp).await;
        assert_eq!(body["did"], "did:plc:test");
        assert_eq!(body["handle"], "alice.test");
    }

    #[tokio::test]
    async fn session_with_bearer_token_returns_200() {
        let state = test_state().await;
        let sid = session::generate_session_id();
        seed_session(&state, &sid, "did:plc:bearer", "bob.test").await;

        let app = test_router(state);
        let resp = app
            .oneshot(
                Request::get("/auth/session")
                    .header("authorization", format!("Bearer {sid}"))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        let body = body_json(resp).await;
        assert_eq!(body["did"], "did:plc:bearer");
    }

    #[tokio::test]
    async fn logout_clears_session() {
        let state = test_state().await;
        let sid = session::generate_session_id();
        seed_session(&state, &sid, "did:plc:logout", "logout.test").await;

        let app = test_router(state.clone());
        let resp = app
            .oneshot(
                Request::post("/auth/logout")
                    .header("cookie", format!("atrg_session={sid}"))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 204);
        assert!(resp.headers().get("set-cookie").is_some());

        // Session should be gone
        let s = session::find_session(&state.db, &sid).await.unwrap();
        assert!(s.is_none());
    }

    #[tokio::test]
    async fn logout_with_bearer_token_deletes_session() {
        let state = test_state().await;
        let sid = session::generate_session_id();
        seed_session(&state, &sid, "did:plc:bearerlogout", "carol.test").await;

        let app = test_router(state.clone());
        let resp = app
            .oneshot(
                Request::post("/auth/logout")
                    .header("authorization", format!("Bearer {sid}"))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 204);

        let s = session::find_session(&state.db, &sid).await.unwrap();
        assert!(s.is_none());
    }

    #[tokio::test]
    async fn logout_without_session_still_clears_cookie() {
        let app = test_router(test_state().await);
        let resp = app
            .oneshot(Request::post("/auth/logout").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), 204);
        let cookie = resp
            .headers()
            .get("set-cookie")
            .expect("logout always sets a clearing cookie")
            .to_str()
            .unwrap();
        assert!(cookie.starts_with("atrg_session=;"));
        assert!(cookie.contains("Max-Age=0"));
        // The test state runs in "development", so the cookie must not be Secure.
        assert!(!cookie.contains("Secure"));
    }
}
448}