1use serde::de::DeserializeOwned;
6#[allow(unused_imports)]
7use tracing::debug;
8
9use crate::at_uri::AtUri;
10use crate::blob;
11use crate::error::RepoError;
12use crate::tid::Tid;
13use crate::types::{BlobRef, Page, Record, StrongRef};
14
15pub struct Repo {
32 http: reqwest::Client,
33 pds_endpoint: String,
34 access_token: String,
35 did: String,
36}
37
38impl Repo {
39 pub fn new(http: &reqwest::Client, pds_endpoint: &str, access_token: &str, did: &str) -> Self {
41 Self {
42 http: http.clone(),
43 pds_endpoint: pds_endpoint.trim_end_matches('/').to_string(),
44 access_token: access_token.to_string(),
45 did: did.to_string(),
46 }
47 }
48
49 pub fn from_session(
51 http: &reqwest::Client,
52 session: &atrg_auth::AtrgSession,
53 pds_endpoint: &str,
54 ) -> Self {
55 Self::new(http, pds_endpoint, &session.access_token, &session.did)
56 }
57
58 pub fn did(&self) -> &str {
60 &self.did
61 }
62
63 pub fn pds_endpoint(&self) -> &str {
65 &self.pds_endpoint
66 }
67
68 pub async fn get_record<T: DeserializeOwned>(
72 &self,
73 uri: &AtUri,
74 ) -> Result<Record<T>, RepoError> {
75 let url = format!("{}/xrpc/com.atproto.repo.getRecord", self.pds_endpoint);
76
77 debug!(uri = %uri, "getting record");
78
79 let resp = self
80 .http
81 .get(&url)
82 .bearer_auth(&self.access_token)
83 .query(&[
84 ("repo", uri.authority.as_str()),
85 ("collection", uri.collection.as_str()),
86 ("rkey", uri.rkey.as_str()),
87 ])
88 .send()
89 .await?;
90
91 if !resp.status().is_success() {
92 let status = resp.status();
93 let body = resp.text().await.unwrap_or_default();
94 if status.as_u16() == 404 {
95 return Err(RepoError::NotFound);
96 }
97 return Err(RepoError::Pds(format!(
98 "getRecord failed ({}): {}",
99 status, body
100 )));
101 }
102
103 let json: serde_json::Value = resp.json().await?;
104
105 let record_uri = json["uri"].as_str().unwrap_or_default().to_string();
106 let cid = json["cid"].as_str().unwrap_or_default().to_string();
107 let value: T = serde_json::from_value(json["value"].clone())
108 .map_err(|e| RepoError::Internal(e.into()))?;
109
110 Ok(Record {
111 uri: record_uri,
112 cid,
113 value,
114 })
115 }
116
117 pub async fn list_records<T: DeserializeOwned>(
121 &self,
122 collection: &str,
123 cursor: Option<&str>,
124 limit: Option<usize>,
125 ) -> Result<Page<Record<T>>, RepoError> {
126 let url = format!("{}/xrpc/com.atproto.repo.listRecords", self.pds_endpoint);
127
128 debug!(collection, cursor, limit, "listing records");
129
130 let mut query = vec![("repo", self.did.as_str()), ("collection", collection)];
131
132 let limit_str;
133 if let Some(l) = limit {
134 limit_str = l.to_string();
135 query.push(("limit", &limit_str));
136 }
137
138 if let Some(c) = cursor {
139 query.push(("cursor", c));
140 }
141
142 let resp = self
143 .http
144 .get(&url)
145 .bearer_auth(&self.access_token)
146 .query(&query)
147 .send()
148 .await?;
149
150 if !resp.status().is_success() {
151 let status = resp.status();
152 let body = resp.text().await.unwrap_or_default();
153 return Err(RepoError::Pds(format!(
154 "listRecords failed ({}): {}",
155 status, body
156 )));
157 }
158
159 let json: serde_json::Value = resp.json().await?;
160
161 let cursor_out = json["cursor"].as_str().map(String::from);
162
163 let records_json = json["records"].as_array().cloned().unwrap_or_default();
164
165 let mut records = Vec::with_capacity(records_json.len());
166 for r in records_json {
167 let uri = r["uri"].as_str().unwrap_or_default().to_string();
168 let cid = r["cid"].as_str().unwrap_or_default().to_string();
169 let value: T = serde_json::from_value(r["value"].clone())
170 .map_err(|e| RepoError::Internal(e.into()))?;
171 records.push(Record { uri, cid, value });
172 }
173
174 Ok(Page {
175 records,
176 cursor: cursor_out,
177 })
178 }
179
180 pub async fn create_record(
184 &self,
185 collection: &str,
186 record: &serde_json::Value,
187 ) -> Result<StrongRef, RepoError> {
188 let url = format!("{}/xrpc/com.atproto.repo.createRecord", self.pds_endpoint);
189
190 debug!(collection, "creating record");
191
192 let body = serde_json::json!({
193 "repo": self.did,
194 "collection": collection,
195 "record": record,
196 });
197
198 let resp = self
199 .http
200 .post(&url)
201 .bearer_auth(&self.access_token)
202 .json(&body)
203 .send()
204 .await?;
205
206 if !resp.status().is_success() {
207 let status = resp.status();
208 let body = resp.text().await.unwrap_or_default();
209 return Err(RepoError::Pds(format!(
210 "createRecord failed ({}): {}",
211 status, body
212 )));
213 }
214
215 let json: serde_json::Value = resp.json().await?;
216
217 Ok(StrongRef {
218 uri: json["uri"].as_str().unwrap_or_default().to_string(),
219 cid: json["cid"].as_str().unwrap_or_default().to_string(),
220 })
221 }
222
223 pub async fn put_record(
227 &self,
228 collection: &str,
229 rkey: &str,
230 record: &serde_json::Value,
231 ) -> Result<StrongRef, RepoError> {
232 let url = format!("{}/xrpc/com.atproto.repo.putRecord", self.pds_endpoint);
233
234 debug!(collection, rkey, "putting record");
235
236 let body = serde_json::json!({
237 "repo": self.did,
238 "collection": collection,
239 "rkey": rkey,
240 "record": record,
241 });
242
243 let resp = self
244 .http
245 .post(&url)
246 .bearer_auth(&self.access_token)
247 .json(&body)
248 .send()
249 .await?;
250
251 if !resp.status().is_success() {
252 let status = resp.status();
253 let body = resp.text().await.unwrap_or_default();
254 return Err(RepoError::Pds(format!(
255 "putRecord failed ({}): {}",
256 status, body
257 )));
258 }
259
260 let json: serde_json::Value = resp.json().await?;
261
262 Ok(StrongRef {
263 uri: json["uri"].as_str().unwrap_or_default().to_string(),
264 cid: json["cid"].as_str().unwrap_or_default().to_string(),
265 })
266 }
267
268 pub async fn delete_record(&self, uri: &AtUri) -> Result<(), RepoError> {
272 let url = format!("{}/xrpc/com.atproto.repo.deleteRecord", self.pds_endpoint);
273
274 debug!(%uri, "deleting record");
275
276 let body = serde_json::json!({
277 "repo": uri.authority,
278 "collection": uri.collection,
279 "rkey": uri.rkey,
280 });
281
282 let resp = self
283 .http
284 .post(&url)
285 .bearer_auth(&self.access_token)
286 .json(&body)
287 .send()
288 .await?;
289
290 if !resp.status().is_success() {
291 let status = resp.status();
292 let body = resp.text().await.unwrap_or_default();
293 if status.as_u16() == 404 {
294 return Err(RepoError::NotFound);
295 }
296 return Err(RepoError::Pds(format!(
297 "deleteRecord failed ({}): {}",
298 status, body
299 )));
300 }
301
302 Ok(())
303 }
304
305 pub async fn upload_blob(&self, data: Vec<u8>, mime_type: &str) -> Result<BlobRef, RepoError> {
309 blob::upload_blob(
310 &self.http,
311 &self.pds_endpoint,
312 &self.access_token,
313 data,
314 mime_type,
315 )
316 .await
317 }
318
319 pub fn new_tid() -> Tid {
321 Tid::now()
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{body_json, header, method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // --- Construction and accessors (no network) ---------------------------

    #[test]
    fn test_repo_new_trims_trailing_slash() {
        let http = reqwest::Client::new();
        let repo = Repo::new(&http, "https://pds.example.com/", "tok", "did:plc:abc");
        assert_eq!(repo.pds_endpoint(), "https://pds.example.com");
    }

    #[test]
    fn test_repo_new_no_trailing_slash_unchanged() {
        let http = reqwest::Client::new();
        let repo = Repo::new(&http, "https://pds.example.com", "tok", "did:plc:abc");
        assert_eq!(repo.pds_endpoint(), "https://pds.example.com");
    }

    #[test]
    fn test_repo_did() {
        let http = reqwest::Client::new();
        let repo = Repo::new(&http, "https://pds.example.com", "tok", "did:plc:abc");
        assert_eq!(repo.did(), "did:plc:abc");
    }

    #[test]
    fn test_from_session() {
        use atrg_auth::{AtrgSession, AuthSource};

        let session = AtrgSession {
            did: "did:plc:session123".to_string(),
            handle: "alice.test".to_string(),
            access_token: "access_tok_xyz".to_string(),
            refresh_token: Some("ref_tok".to_string()),
            expires_at: 9999999999,
            source: AuthSource::Atrg,
        };

        let http = reqwest::Client::new();
        let repo = Repo::from_session(&http, &session, "https://pds.example.com/");

        assert_eq!(repo.did(), "did:plc:session123");
        assert_eq!(repo.pds_endpoint(), "https://pds.example.com");
    }

    #[test]
    fn test_from_session_atproto_jwt_source() {
        use atrg_auth::{AtrgSession, AuthSource};

        let session = AtrgSession {
            did: "did:web:bob.test".to_string(),
            handle: "bob.test".to_string(),
            access_token: "jwt_token".to_string(),
            refresh_token: None,
            expires_at: 9999999999,
            source: AuthSource::AtprotoJwt,
        };

        let http = reqwest::Client::new();
        let repo = Repo::from_session(&http, &session, "https://other-pds.example.com");

        assert_eq!(repo.did(), "did:web:bob.test");
        assert_eq!(repo.pds_endpoint(), "https://other-pds.example.com");
    }

    // --- TID generation ----------------------------------------------------

    #[test]
    fn test_new_tid_returns_valid() {
        let tid = Repo::new_tid();
        assert_eq!(tid.as_str().len(), 13);
    }

    #[test]
    fn test_new_tid_parses_back() {
        let tid = Repo::new_tid();
        let parsed = Tid::parse(tid.as_str());
        assert!(parsed.is_ok(), "generated TID should parse successfully");
        assert_eq!(parsed.unwrap().as_str(), tid.as_str());
    }

    #[test]
    fn test_new_tid_successive_are_distinct() {
        let a = Repo::new_tid();
        // Pause briefly so the two TIDs are generated at different clock readings.
        std::thread::sleep(std::time::Duration::from_millis(2));
        let b = Repo::new_tid();
        assert_ne!(a.as_str(), b.as_str());
    }

    // --- HTTP tests against a local wiremock server ------------------------
    //
    // wiremock answers 404 to any request that matches no mounted mock, so a
    // request with the wrong method, path, query, header or body fails the
    // test instead of passing silently.

    /// Builds a `Repo` pointed at `server`, authenticated as
    /// `did:plc:testuser` with bearer token `test_token`.
    fn mock_repo(server: &MockServer) -> Repo {
        let http = reqwest::Client::new();
        Repo::new(&http, &server.uri(), "test_token", "did:plc:testuser")
    }

    // get_record

    #[tokio::test]
    async fn get_record_success() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/xrpc/com.atproto.repo.getRecord"))
            .and(query_param("repo", "did:plc:testuser"))
            .and(query_param("collection", "app.bsky.feed.post"))
            .and(query_param("rkey", "3k2la"))
            .and(header("Authorization", "Bearer test_token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "uri": "at://did:plc:testuser/app.bsky.feed.post/3k2la",
                "cid": "bafyabc",
                "value": { "text": "hello world" }
            })))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let uri = AtUri::parse("at://did:plc:testuser/app.bsky.feed.post/3k2la").unwrap();
        let record: Record<serde_json::Value> = repo.get_record(&uri).await.unwrap();

        assert_eq!(record.uri, "at://did:plc:testuser/app.bsky.feed.post/3k2la");
        assert_eq!(record.cid, "bafyabc");
        assert_eq!(record.value["text"], "hello world");
    }

    #[tokio::test]
    async fn get_record_not_found() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/xrpc/com.atproto.repo.getRecord"))
            .respond_with(ResponseTemplate::new(404).set_body_json(serde_json::json!({
                "error": "RecordNotFound",
                "message": "not found"
            })))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let uri = AtUri::parse("at://did:plc:testuser/app.bsky.feed.post/missing").unwrap();
        let result: Result<Record<serde_json::Value>, _> = repo.get_record(&uri).await;

        assert!(matches!(result, Err(RepoError::NotFound)));
    }

    #[tokio::test]
    async fn get_record_pds_error() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/xrpc/com.atproto.repo.getRecord"))
            .respond_with(ResponseTemplate::new(500).set_body_string("internal"))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let uri = AtUri::parse("at://did:plc:testuser/app.bsky.feed.post/rk").unwrap();
        let result: Result<Record<serde_json::Value>, _> = repo.get_record(&uri).await;

        match result {
            Err(RepoError::Pds(msg)) => assert!(msg.contains("500")),
            other => panic!("expected Pds error, got {:?}", other),
        }
    }

    // list_records

    #[tokio::test]
    async fn list_records_success() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/xrpc/com.atproto.repo.listRecords"))
            .and(query_param("repo", "did:plc:testuser"))
            .and(query_param("collection", "app.bsky.feed.post"))
            .and(header("Authorization", "Bearer test_token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "records": [
                    { "uri": "at://did:plc:testuser/app.bsky.feed.post/1", "cid": "cid1", "value": { "text": "a" } },
                    { "uri": "at://did:plc:testuser/app.bsky.feed.post/2", "cid": "cid2", "value": { "text": "b" } }
                ],
                "cursor": "next123"
            })))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let page: Page<Record<serde_json::Value>> = repo
            .list_records("app.bsky.feed.post", None, None)
            .await
            .unwrap();

        assert_eq!(page.records.len(), 2);
        assert_eq!(page.records[0].value["text"], "a");
        assert_eq!(page.cursor.as_deref(), Some("next123"));
    }

    #[tokio::test]
    async fn list_records_with_cursor_and_limit() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/xrpc/com.atproto.repo.listRecords"))
            .and(query_param("cursor", "abc"))
            .and(query_param("limit", "5"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "records": [],
                "cursor": null
            })))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let page: Page<Record<serde_json::Value>> = repo
            .list_records("app.bsky.feed.post", Some("abc"), Some(5))
            .await
            .unwrap();

        assert!(page.records.is_empty());
        assert!(page.cursor.is_none());
    }

    #[tokio::test]
    async fn list_records_pds_error() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/xrpc/com.atproto.repo.listRecords"))
            .respond_with(ResponseTemplate::new(403).set_body_string("forbidden"))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let result: Result<Page<Record<serde_json::Value>>, _> =
            repo.list_records("col", None, None).await;

        assert!(matches!(result, Err(RepoError::Pds(_))));
    }

    // create_record

    #[tokio::test]
    async fn create_record_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/xrpc/com.atproto.repo.createRecord"))
            .and(header("Authorization", "Bearer test_token"))
            .and(body_json(serde_json::json!({
                "repo": "did:plc:testuser",
                "collection": "app.bsky.feed.post",
                "record": { "text": "new post" }
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "uri": "at://did:plc:testuser/app.bsky.feed.post/newrkey",
                "cid": "bafynew"
            })))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let record = serde_json::json!({ "text": "new post" });
        let strong = repo
            .create_record("app.bsky.feed.post", &record)
            .await
            .unwrap();

        assert_eq!(
            strong.uri,
            "at://did:plc:testuser/app.bsky.feed.post/newrkey"
        );
        assert_eq!(strong.cid, "bafynew");
    }

    #[tokio::test]
    async fn create_record_pds_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/xrpc/com.atproto.repo.createRecord"))
            .respond_with(ResponseTemplate::new(400).set_body_string("bad request"))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let result = repo.create_record("col", &serde_json::json!({})).await;

        match result {
            Err(RepoError::Pds(msg)) => assert!(msg.contains("400")),
            other => panic!("expected Pds error, got {:?}", other),
        }
    }

    // put_record

    #[tokio::test]
    async fn put_record_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/xrpc/com.atproto.repo.putRecord"))
            .and(header("Authorization", "Bearer test_token"))
            .and(body_json(serde_json::json!({
                "repo": "did:plc:testuser",
                "collection": "app.bsky.actor.profile",
                "rkey": "self",
                "record": { "displayName": "Alice" }
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "uri": "at://did:plc:testuser/app.bsky.actor.profile/self",
                "cid": "bafyput"
            })))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let record = serde_json::json!({ "displayName": "Alice" });
        let strong = repo
            .put_record("app.bsky.actor.profile", "self", &record)
            .await
            .unwrap();

        assert_eq!(
            strong.uri,
            "at://did:plc:testuser/app.bsky.actor.profile/self"
        );
        assert_eq!(strong.cid, "bafyput");
    }

    #[tokio::test]
    async fn put_record_pds_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/xrpc/com.atproto.repo.putRecord"))
            .respond_with(ResponseTemplate::new(502).set_body_string("bad gateway"))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let result = repo.put_record("col", "rk", &serde_json::json!({})).await;

        assert!(matches!(result, Err(RepoError::Pds(_))));
    }

    // delete_record

    #[tokio::test]
    async fn delete_record_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/xrpc/com.atproto.repo.deleteRecord"))
            .and(header("Authorization", "Bearer test_token"))
            .and(body_json(serde_json::json!({
                "repo": "did:plc:testuser",
                "collection": "app.bsky.feed.post",
                "rkey": "3k2la"
            })))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let uri = AtUri::parse("at://did:plc:testuser/app.bsky.feed.post/3k2la").unwrap();
        repo.delete_record(&uri).await.unwrap();
    }

    #[tokio::test]
    async fn delete_record_not_found() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/xrpc/com.atproto.repo.deleteRecord"))
            .respond_with(ResponseTemplate::new(404).set_body_string("not found"))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let uri = AtUri::parse("at://did:plc:testuser/app.bsky.feed.post/gone").unwrap();
        let result = repo.delete_record(&uri).await;

        assert!(matches!(result, Err(RepoError::NotFound)));
    }

    #[tokio::test]
    async fn delete_record_pds_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/xrpc/com.atproto.repo.deleteRecord"))
            .respond_with(ResponseTemplate::new(500).set_body_string("error"))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let uri = AtUri::parse("at://did:plc:testuser/app.test/rk").unwrap();
        let result = repo.delete_record(&uri).await;

        assert!(matches!(result, Err(RepoError::Pds(_))));
    }

    // upload_blob

    #[tokio::test]
    async fn upload_blob_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/xrpc/com.atproto.repo.uploadBlob"))
            .and(header("Authorization", "Bearer test_token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "blob": {
                    "ref": { "$link": "bafyblob123" },
                    "mimeType": "image/png",
                    "size": 2048
                }
            })))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let blob_ref = repo.upload_blob(vec![0u8; 100], "image/png").await.unwrap();

        assert_eq!(blob_ref.reference.link, "bafyblob123");
        assert_eq!(blob_ref.mime_type, "image/png");
        assert_eq!(blob_ref.size, 2048);
    }

    #[tokio::test]
    async fn upload_blob_pds_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/xrpc/com.atproto.repo.uploadBlob"))
            .respond_with(ResponseTemplate::new(413).set_body_string("too large"))
            .mount(&server)
            .await;

        let repo = mock_repo(&server);
        let result = repo.upload_blob(vec![0u8; 100], "image/png").await;

        assert!(matches!(result, Err(RepoError::Pds(_))));
    }
}