1use std::collections::HashMap;
4use std::sync::Mutex;
5
6#[derive(Debug, Clone)]
8pub struct RecordedCall {
9 pub method: String,
11 pub path: String,
13 pub body: Option<serde_json::Value>,
15}
16
17pub struct MockAtprotoClient {
21 calls: Mutex<Vec<RecordedCall>>,
22 responses: Mutex<HashMap<String, serde_json::Value>>,
23}
24
25impl MockAtprotoClient {
26 pub fn new() -> Self {
28 Self {
29 calls: Mutex::new(Vec::new()),
30 responses: Mutex::new(HashMap::new()),
31 }
32 }
33
34 pub fn when(&self, method: &str, path: &str) -> &Self {
36 self.responses
38 .lock()
39 .expect("mutex poisoned")
40 .insert(format!("{}:{}", method, path), serde_json::json!(null));
41 self
42 }
43
44 pub fn returns(&self, method: &str, path: &str, response: serde_json::Value) {
46 self.responses
47 .lock()
48 .expect("mutex poisoned")
49 .insert(format!("{}:{}", method, path), response);
50 }
51
52 pub fn get_record(&self, collection: &str, rkey: &str) -> anyhow::Result<serde_json::Value> {
54 let path = format!("{}/{}", collection, rkey);
55 self.record_call("get_record", &path, None);
56 self.get_response("get_record", &path)
57 }
58
59 pub fn put_record(
61 &self,
62 collection: &str,
63 record: &serde_json::Value,
64 ) -> anyhow::Result<serde_json::Value> {
65 self.record_call("put_record", collection, Some(record.clone()));
66 self.get_response("put_record", collection)
67 }
68
69 pub fn list_records(&self, collection: &str) -> anyhow::Result<serde_json::Value> {
71 self.record_call("list_records", collection, None);
72 self.get_response("list_records", collection)
73 }
74
75 pub fn delete_record(&self, collection: &str, rkey: &str) -> anyhow::Result<()> {
77 let path = format!("{}/{}", collection, rkey);
78 self.record_call("delete_record", &path, None);
79 Ok(())
80 }
81
82 pub fn calls(&self) -> Vec<RecordedCall> {
84 self.calls.lock().expect("mutex poisoned").clone()
85 }
86
87 pub fn assert_called(&self, method: &str, path: &str, n: usize) {
89 let calls = self.calls.lock().expect("mutex poisoned");
90 let count = calls
91 .iter()
92 .filter(|c| c.method == method && c.path == path)
93 .count();
94 assert_eq!(
95 count, n,
96 "expected {} calls to {}:{}, got {}",
97 n, method, path, count
98 );
99 }
100
101 pub fn assert_called_any(&self, method: &str) {
103 let calls = self.calls.lock().expect("mutex poisoned");
104 let count = calls.iter().filter(|c| c.method == method).count();
105 assert!(count > 0, "expected at least one call to {}, got 0", method);
106 }
107
108 fn record_call(&self, method: &str, path: &str, body: Option<serde_json::Value>) {
109 self.calls
110 .lock()
111 .expect("mutex poisoned")
112 .push(RecordedCall {
113 method: method.to_string(),
114 path: path.to_string(),
115 body,
116 });
117 }
118
119 fn get_response(&self, method: &str, path: &str) -> anyhow::Result<serde_json::Value> {
120 let key = format!("{}:{}", method, path);
121 let responses = self.responses.lock().expect("mutex poisoned");
122 match responses.get(&key) {
123 Some(v) => Ok(v.clone()),
124 None => Ok(serde_json::json!({})),
125 }
126 }
127}
128
129impl Default for MockAtprotoClient {
130 fn default() -> Self {
131 Self::new()
132 }
133}
134
135#[cfg(test)]
136mod tests {
137 use super::*;
138
139 #[test]
140 fn mock_records_calls() {
141 let mock = MockAtprotoClient::new();
142 mock.get_record("app.bsky.feed.post", "abc123").unwrap();
143 mock.assert_called("get_record", "app.bsky.feed.post/abc123", 1);
144 }
145
146 #[test]
147 fn mock_returns_scripted_response() {
148 let mock = MockAtprotoClient::new();
149 mock.returns(
150 "get_record",
151 "app.bsky.feed.post/abc",
152 serde_json::json!({"text": "hello"}),
153 );
154 let resp = mock.get_record("app.bsky.feed.post", "abc").unwrap();
155 assert_eq!(resp["text"], "hello");
156 }
157
158 #[test]
159 fn mock_put_record() {
160 let mock = MockAtprotoClient::new();
161 let record = serde_json::json!({"text": "new post"});
162 mock.put_record("app.bsky.feed.post", &record).unwrap();
163 mock.assert_called_any("put_record");
164 }
165
166 #[test]
167 fn mock_delete_record() {
168 let mock = MockAtprotoClient::new();
169 mock.delete_record("app.bsky.feed.post", "abc").unwrap();
170 mock.assert_called("delete_record", "app.bsky.feed.post/abc", 1);
171 }
172
173 #[test]
174 fn mock_default_response_is_empty_object() {
175 let mock = MockAtprotoClient::new();
176 let resp = mock.get_record("unknown.collection", "key").unwrap();
177 assert_eq!(resp, serde_json::json!({}));
178 }
179
180 #[test]
181 fn mock_when_sets_null_initially() {
182 let mock = MockAtprotoClient::new();
183 mock.when("get_record", "app.bsky.feed.post/test");
184 let resp = mock.get_record("app.bsky.feed.post", "test").unwrap();
185 assert!(resp.is_null());
186 }
187
188 #[test]
189 fn mock_list_records() {
190 let mock = MockAtprotoClient::new();
191 mock.returns(
192 "list_records",
193 "app.bsky.feed.post",
194 serde_json::json!({"records": []}),
195 );
196 let resp = mock.list_records("app.bsky.feed.post").unwrap();
197 assert!(resp["records"].is_array());
198 }
199
200 #[test]
201 fn mock_calls_returns_all_calls() {
202 let mock = MockAtprotoClient::new();
203 mock.get_record("col1", "a").unwrap();
204 mock.get_record("col2", "b").unwrap();
205 mock.delete_record("col1", "a").unwrap();
206 let calls = mock.calls();
207 assert_eq!(calls.len(), 3);
208 assert_eq!(calls[0].method, "get_record");
209 assert_eq!(calls[1].method, "get_record");
210 assert_eq!(calls[2].method, "delete_record");
211 }
212
213 #[test]
214 fn mock_put_record_captures_body() {
215 let mock = MockAtprotoClient::new();
216 let record = serde_json::json!({"text": "captured"});
217 mock.put_record("app.bsky.feed.post", &record).unwrap();
218 let calls = mock.calls();
219 assert_eq!(calls.len(), 1);
220 assert_eq!(calls[0].body.as_ref().unwrap()["text"], "captured");
221 }
222}