1use rand::Rng;
4use sqlx::SqlitePool;
5
/// Where an authenticated session originated.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AuthSource {
    /// Session loaded from the `atrg_sessions` table (what [`find_session`] reports).
    Atrg,
    /// Session derived from an AT Protocol JWT. NOTE(review): constructed outside
    /// this module — confirm the exact verification path with its caller.
    AtprotoJwt,
}
14
15#[derive(Debug, Clone)]
22pub struct AtrgSession {
23 pub did: String,
25 pub handle: String,
27 pub access_token: String,
29 pub refresh_token: Option<String>,
31 pub expires_at: i64,
33 pub source: AuthSource,
35}
36
37pub fn generate_session_id() -> String {
39 let mut bytes = [0u8; 32];
40 rand::thread_rng().fill(&mut bytes);
41 base64_url_encode(&bytes)
42}
43
44fn base64_url_encode(data: &[u8]) -> String {
46 use base64::Engine;
47 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
48}
49
50pub async fn find_session(
52 pool: &SqlitePool,
53 session_id: &str,
54) -> anyhow::Result<Option<AtrgSession>> {
55 let now = std::time::SystemTime::now()
56 .duration_since(std::time::UNIX_EPOCH)
57 .unwrap_or_default()
58 .as_secs() as i64;
59
60 let row = sqlx::query_as::<_, SessionRow>(
61 "SELECT id, did, handle, access_token, refresh_token, expires_at
62 FROM atrg_sessions
63 WHERE id = ? AND expires_at > ?",
64 )
65 .bind(session_id)
66 .bind(now)
67 .fetch_optional(pool)
68 .await?;
69
70 if row.is_some() {
72 let _ = sqlx::query("UPDATE atrg_sessions SET last_used_at = unixepoch() WHERE id = ?")
73 .bind(session_id)
74 .execute(pool)
75 .await;
76 }
77
78 Ok(row.map(|r| AtrgSession {
79 did: r.did,
80 handle: r.handle,
81 access_token: r.access_token,
82 refresh_token: r.refresh_token,
83 expires_at: r.expires_at,
84 source: AuthSource::Atrg,
85 }))
86}
87
88pub async fn create_session(
90 pool: &SqlitePool,
91 session_id: &str,
92 did: &str,
93 handle: &str,
94 access_token: &str,
95 refresh_token: Option<&str>,
96 expires_at: i64,
97) -> anyhow::Result<()> {
98 sqlx::query(
99 "INSERT INTO atrg_sessions (id, did, handle, access_token, refresh_token, expires_at)
100 VALUES (?, ?, ?, ?, ?, ?)",
101 )
102 .bind(session_id)
103 .bind(did)
104 .bind(handle)
105 .bind(access_token)
106 .bind(refresh_token)
107 .bind(expires_at)
108 .execute(pool)
109 .await?;
110
111 tracing::debug!(did = %did, handle = %handle, "session created");
112 Ok(())
113}
114
115pub async fn delete_session(pool: &SqlitePool, session_id: &str) -> anyhow::Result<()> {
117 sqlx::query("DELETE FROM atrg_sessions WHERE id = ?")
118 .bind(session_id)
119 .execute(pool)
120 .await?;
121 Ok(())
122}
123
124pub async fn cleanup_expired_sessions(pool: &SqlitePool) -> anyhow::Result<u64> {
126 let now = std::time::SystemTime::now()
127 .duration_since(std::time::UNIX_EPOCH)
128 .unwrap_or_default()
129 .as_secs() as i64;
130
131 let result = sqlx::query("DELETE FROM atrg_sessions WHERE expires_at <= ?")
132 .bind(now)
133 .execute(pool)
134 .await?;
135
136 let deleted = result.rows_affected();
137 if deleted > 0 {
138 tracing::info!(count = deleted, "cleaned up expired sessions");
139 }
140 Ok(deleted)
141}
142
143pub async fn cleanup_expired_oauth_states(pool: &SqlitePool) -> anyhow::Result<u64> {
145 let now = std::time::SystemTime::now()
146 .duration_since(std::time::UNIX_EPOCH)
147 .unwrap_or_default()
148 .as_secs() as i64;
149
150 let result = sqlx::query("DELETE FROM atrg_oauth_states WHERE expires_at <= ?")
151 .bind(now)
152 .execute(pool)
153 .await?;
154
155 Ok(result.rows_affected())
156}
157
158#[derive(sqlx::FromRow)]
159struct SessionRow {
160 #[allow(dead_code)]
161 id: String,
162 did: String,
163 handle: String,
164 access_token: String,
165 refresh_token: Option<String>,
166 expires_at: i64,
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database and applies the crate's internal migrations.
    async fn test_pool() -> SqlitePool {
        let pool = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&pool).await.unwrap();
        pool
    }

    /// Current Unix time in whole seconds (the unit used for `expires_at`).
    fn unix_now() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64
    }

    // Purely synchronous, so a plain `#[test]` suffices; no Tokio runtime is needed.
    #[test]
    fn generate_session_id_is_unique() {
        let a = generate_session_id();
        let b = generate_session_id();
        assert_ne!(a, b);
        // 32 random bytes encode to exactly 43 characters of unpadded base64url.
        assert_eq!(a.len(), 43, "session id should be 43 chars base64url");
        assert!(
            a.bytes()
                .all(|c| c.is_ascii_alphanumeric() || c == b'-' || c == b'_'),
            "session id should only contain URL-safe base64 characters"
        );
    }

    #[tokio::test]
    async fn create_and_find_session() {
        let pool = test_pool().await;
        let sid = generate_session_id();
        let expires = unix_now() + 86400;

        create_session(
            &pool,
            &sid,
            "did:plc:test123",
            "alice.test",
            "tok_abc",
            Some("ref_xyz"),
            expires,
        )
        .await
        .unwrap();

        let session = find_session(&pool, &sid)
            .await
            .unwrap()
            .expect("session should exist");
        assert_eq!(session.did, "did:plc:test123");
        assert_eq!(session.handle, "alice.test");
        assert_eq!(session.access_token, "tok_abc");
        assert_eq!(session.refresh_token.as_deref(), Some("ref_xyz"));
        assert_eq!(session.expires_at, expires);
        assert_eq!(session.source, AuthSource::Atrg);
    }

    #[tokio::test]
    async fn expired_session_not_found() {
        let pool = test_pool().await;
        let sid = generate_session_id();
        // Expired one hour ago.
        let expires = unix_now() - 3600;

        create_session(&pool, &sid, "did:plc:expired", "old.test", "tok", None, expires)
            .await
            .unwrap();

        let session = find_session(&pool, &sid).await.unwrap();
        assert!(session.is_none(), "expired session should not be returned");
    }

    #[tokio::test]
    async fn delete_session_works() {
        let pool = test_pool().await;
        let sid = generate_session_id();
        let expires = unix_now() + 86400;

        create_session(&pool, &sid, "did:plc:del", "del.test", "tok", None, expires)
            .await
            .unwrap();

        delete_session(&pool, &sid).await.unwrap();
        let session = find_session(&pool, &sid).await.unwrap();
        assert!(session.is_none());
    }

    #[tokio::test]
    async fn delete_missing_session_is_ok() {
        let pool = test_pool().await;
        delete_session(&pool, "nonexistent").await.unwrap();
    }

    #[tokio::test]
    async fn cleanup_expired_sessions_works() {
        let pool = test_pool().await;
        let expired = unix_now() - 3600;
        let valid = expired + 7200;

        create_session(&pool, "expired1", "did:plc:e1", "e1", "tok", None, expired)
            .await
            .unwrap();
        create_session(&pool, "valid1", "did:plc:v1", "v1", "tok", None, valid)
            .await
            .unwrap();

        let deleted = cleanup_expired_sessions(&pool).await.unwrap();
        assert_eq!(deleted, 1);

        // A second pass has nothing left to delete, proving the expired row is gone.
        assert_eq!(cleanup_expired_sessions(&pool).await.unwrap(), 0);
        assert!(find_session(&pool, "valid1").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn missing_session_returns_none() {
        let pool = test_pool().await;
        let session = find_session(&pool, "nonexistent").await.unwrap();
        assert!(session.is_none());
    }
}
295}