Skip to main content

atrg_auth/
session.rs

1//! Session types and database operations.
2
3use rand::Rng;
4use sqlx::SqlitePool;
5
/// Where the credentials behind a request came from.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthSource {
    /// An atrg-issued session token, presented either as a cookie or as a bearer token.
    Atrg,
    /// An AT Protocol JWT issued by the user's PDS.
    AtprotoJwt,
}
14
15/// A resolved authentication session, shared across all auth paths.
16///
17/// Handlers receive this via `AuthUser` or `RequireAuth` extractors.
18/// The `source` field indicates whether the credential was an atrg
19/// session token or an AT Protocol JWT — but most handlers shouldn't
20/// need to check.
21#[derive(Debug, Clone)]
22pub struct AtrgSession {
23    /// The user's DID (e.g. `did:plc:...`).
24    pub did: String,
25    /// The user's handle (e.g. `alice.bsky.social`).
26    pub handle: String,
27    /// The access token for outbound AT Protocol calls.
28    pub access_token: String,
29    /// The refresh token (only present for atrg sessions).
30    pub refresh_token: Option<String>,
31    /// Unix timestamp when this session expires.
32    pub expires_at: i64,
33    /// How the user authenticated.
34    pub source: AuthSource,
35}
36
37/// Generate a cryptographically random session ID (32 bytes, base64url-encoded).
38pub fn generate_session_id() -> String {
39    let mut bytes = [0u8; 32];
40    rand::thread_rng().fill(&mut bytes);
41    base64_url_encode(&bytes)
42}
43
44/// Base64url-encode without padding.
45fn base64_url_encode(data: &[u8]) -> String {
46    use base64::Engine;
47    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
48}
49
50/// Look up a session by ID, filtering out expired sessions.
51pub async fn find_session(
52    pool: &SqlitePool,
53    session_id: &str,
54) -> anyhow::Result<Option<AtrgSession>> {
55    let now = std::time::SystemTime::now()
56        .duration_since(std::time::UNIX_EPOCH)
57        .unwrap_or_default()
58        .as_secs() as i64;
59
60    let row = sqlx::query_as::<_, SessionRow>(
61        "SELECT id, did, handle, access_token, refresh_token, expires_at
62         FROM atrg_sessions
63         WHERE id = ? AND expires_at > ?",
64    )
65    .bind(session_id)
66    .bind(now)
67    .fetch_optional(pool)
68    .await?;
69
70    // Update last_used_at on access
71    if row.is_some() {
72        let _ = sqlx::query("UPDATE atrg_sessions SET last_used_at = unixepoch() WHERE id = ?")
73            .bind(session_id)
74            .execute(pool)
75            .await;
76    }
77
78    Ok(row.map(|r| AtrgSession {
79        did: r.did,
80        handle: r.handle,
81        access_token: r.access_token,
82        refresh_token: r.refresh_token,
83        expires_at: r.expires_at,
84        source: AuthSource::Atrg,
85    }))
86}
87
88/// Insert a new session into the database.
89pub async fn create_session(
90    pool: &SqlitePool,
91    session_id: &str,
92    did: &str,
93    handle: &str,
94    access_token: &str,
95    refresh_token: Option<&str>,
96    expires_at: i64,
97) -> anyhow::Result<()> {
98    sqlx::query(
99        "INSERT INTO atrg_sessions (id, did, handle, access_token, refresh_token, expires_at)
100         VALUES (?, ?, ?, ?, ?, ?)",
101    )
102    .bind(session_id)
103    .bind(did)
104    .bind(handle)
105    .bind(access_token)
106    .bind(refresh_token)
107    .bind(expires_at)
108    .execute(pool)
109    .await?;
110
111    tracing::debug!(did = %did, handle = %handle, "session created");
112    Ok(())
113}
114
115/// Delete a session by ID (logout).
116pub async fn delete_session(pool: &SqlitePool, session_id: &str) -> anyhow::Result<()> {
117    sqlx::query("DELETE FROM atrg_sessions WHERE id = ?")
118        .bind(session_id)
119        .execute(pool)
120        .await?;
121    Ok(())
122}
123
124/// Delete all expired sessions (cleanup).
125pub async fn cleanup_expired_sessions(pool: &SqlitePool) -> anyhow::Result<u64> {
126    let now = std::time::SystemTime::now()
127        .duration_since(std::time::UNIX_EPOCH)
128        .unwrap_or_default()
129        .as_secs() as i64;
130
131    let result = sqlx::query("DELETE FROM atrg_sessions WHERE expires_at <= ?")
132        .bind(now)
133        .execute(pool)
134        .await?;
135
136    let deleted = result.rows_affected();
137    if deleted > 0 {
138        tracing::info!(count = deleted, "cleaned up expired sessions");
139    }
140    Ok(deleted)
141}
142
143/// Delete expired OAuth states (cleanup).
144pub async fn cleanup_expired_oauth_states(pool: &SqlitePool) -> anyhow::Result<u64> {
145    let now = std::time::SystemTime::now()
146        .duration_since(std::time::UNIX_EPOCH)
147        .unwrap_or_default()
148        .as_secs() as i64;
149
150    let result = sqlx::query("DELETE FROM atrg_oauth_states WHERE expires_at <= ?")
151        .bind(now)
152        .execute(pool)
153        .await?;
154
155    Ok(result.rows_affected())
156}
157
158#[derive(sqlx::FromRow)]
159struct SessionRow {
160    #[allow(dead_code)]
161    id: String,
162    did: String,
163    handle: String,
164    access_token: String,
165    refresh_token: Option<String>,
166    expires_at: i64,
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a fresh in-memory database with atrg's internal migrations applied.
    async fn test_pool() -> SqlitePool {
        let pool = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&pool).await.unwrap();
        pool
    }

    /// Current Unix time in seconds, shifted by `offset_secs` (negative = past).
    fn unix_time_offset(offset_secs: i64) -> i64 {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;
        now + offset_secs
    }

    #[test]
    fn generate_session_id_is_unique() {
        let a = generate_session_id();
        let b = generate_session_id();
        assert_ne!(a, b);
        // 32 bytes -> ceil(32 * 4 / 3) = 43 base64 chars once padding is dropped.
        assert_eq!(a.len(), 43, "session id should be 43 chars base64url");
        assert!(
            a.bytes()
                .all(|c| c.is_ascii_alphanumeric() || c == b'-' || c == b'_'),
            "session id should use only the base64url alphabet: {a}"
        );
    }

    #[tokio::test]
    async fn create_and_find_session() {
        let pool = test_pool().await;
        let sid = generate_session_id();
        let expires = unix_time_offset(86400);

        create_session(
            &pool,
            &sid,
            "did:plc:test123",
            "alice.test",
            "tok_abc",
            Some("ref_xyz"),
            expires,
        )
        .await
        .unwrap();

        let session = find_session(&pool, &sid)
            .await
            .unwrap()
            .expect("session should exist");
        assert_eq!(session.did, "did:plc:test123");
        assert_eq!(session.handle, "alice.test");
        assert_eq!(session.access_token, "tok_abc");
        assert_eq!(session.refresh_token.as_deref(), Some("ref_xyz"));
        assert_eq!(session.expires_at, expires);
        assert_eq!(session.source, AuthSource::Atrg);
    }

    #[tokio::test]
    async fn expired_session_not_found() {
        let pool = test_pool().await;
        let sid = generate_session_id();
        // Expired 1 hour ago.
        let expires = unix_time_offset(-3600);

        create_session(
            &pool,
            &sid,
            "did:plc:expired",
            "old.test",
            "tok",
            None,
            expires,
        )
        .await
        .unwrap();

        let session = find_session(&pool, &sid).await.unwrap();
        assert!(session.is_none(), "expired session should not be returned");
    }

    #[tokio::test]
    async fn delete_session_works() {
        let pool = test_pool().await;
        let sid = generate_session_id();
        let expires = unix_time_offset(86400);

        create_session(&pool, &sid, "did:plc:del", "del.test", "tok", None, expires)
            .await
            .unwrap();

        delete_session(&pool, &sid).await.unwrap();
        let session = find_session(&pool, &sid).await.unwrap();
        assert!(session.is_none());

        // Deleting an already-removed session is not an error.
        delete_session(&pool, &sid).await.unwrap();
    }

    #[tokio::test]
    async fn cleanup_expired_sessions_works() {
        let pool = test_pool().await;
        let expired = unix_time_offset(-3600);
        let valid = unix_time_offset(3600);

        create_session(&pool, "expired1", "did:plc:e1", "e1", "tok", None, expired)
            .await
            .unwrap();
        create_session(&pool, "valid1", "did:plc:v1", "v1", "tok", None, valid)
            .await
            .unwrap();

        let deleted = cleanup_expired_sessions(&pool).await.unwrap();
        assert_eq!(deleted, 1);

        // A second sweep has nothing left to remove.
        let deleted_again = cleanup_expired_sessions(&pool).await.unwrap();
        assert_eq!(deleted_again, 0, "cleanup should be idempotent");

        assert!(find_session(&pool, "valid1").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn missing_session_returns_none() {
        let pool = test_pool().await;
        let session = find_session(&pool, "nonexistent").await.unwrap();
        assert!(session.is_none());
    }
}