Skip to main content

atrg_testing/
mock_client.rs

1//! Mock AT Protocol client for testing.
2
3use std::collections::HashMap;
4use std::sync::Mutex;
5
6/// A recorded call to the mock client.
7#[derive(Debug, Clone)]
8pub struct RecordedCall {
9    /// The method name (e.g. "get_record", "put_record").
10    pub method: String,
11    /// The path or key.
12    pub path: String,
13    /// The request body (if any).
14    pub body: Option<serde_json::Value>,
15}
16
17/// Mock AT Protocol client that records calls and returns scripted responses.
18///
19/// Use this in tests instead of making real network calls.
20pub struct MockAtprotoClient {
21    calls: Mutex<Vec<RecordedCall>>,
22    responses: Mutex<HashMap<String, serde_json::Value>>,
23}
24
25impl MockAtprotoClient {
26    /// Create a new empty mock client.
27    pub fn new() -> Self {
28        Self {
29            calls: Mutex::new(Vec::new()),
30            responses: Mutex::new(HashMap::new()),
31        }
32    }
33
34    /// Script a response for a given method+path combination.
35    pub fn when(&self, method: &str, path: &str) -> &Self {
36        // Store the key for later; the response is set via `returns`
37        self.responses
38            .lock()
39            .expect("mutex poisoned")
40            .insert(format!("{}:{}", method, path), serde_json::json!(null));
41        self
42    }
43
44    /// Set the response for the most recently configured `when`.
45    pub fn returns(&self, method: &str, path: &str, response: serde_json::Value) {
46        self.responses
47            .lock()
48            .expect("mutex poisoned")
49            .insert(format!("{}:{}", method, path), response);
50    }
51
52    /// Simulate a get_record call.
53    pub fn get_record(&self, collection: &str, rkey: &str) -> anyhow::Result<serde_json::Value> {
54        let path = format!("{}/{}", collection, rkey);
55        self.record_call("get_record", &path, None);
56        self.get_response("get_record", &path)
57    }
58
59    /// Simulate a put_record call.
60    pub fn put_record(
61        &self,
62        collection: &str,
63        record: &serde_json::Value,
64    ) -> anyhow::Result<serde_json::Value> {
65        self.record_call("put_record", collection, Some(record.clone()));
66        self.get_response("put_record", collection)
67    }
68
69    /// Simulate a list_records call.
70    pub fn list_records(&self, collection: &str) -> anyhow::Result<serde_json::Value> {
71        self.record_call("list_records", collection, None);
72        self.get_response("list_records", collection)
73    }
74
75    /// Simulate a delete_record call.
76    pub fn delete_record(&self, collection: &str, rkey: &str) -> anyhow::Result<()> {
77        let path = format!("{}/{}", collection, rkey);
78        self.record_call("delete_record", &path, None);
79        Ok(())
80    }
81
82    /// Get all recorded calls.
83    pub fn calls(&self) -> Vec<RecordedCall> {
84        self.calls.lock().expect("mutex poisoned").clone()
85    }
86
87    /// Assert that a specific method+path was called exactly `n` times.
88    pub fn assert_called(&self, method: &str, path: &str, n: usize) {
89        let calls = self.calls.lock().expect("mutex poisoned");
90        let count = calls
91            .iter()
92            .filter(|c| c.method == method && c.path == path)
93            .count();
94        assert_eq!(
95            count, n,
96            "expected {} calls to {}:{}, got {}",
97            n, method, path, count
98        );
99    }
100
101    /// Assert that a method was called at least once (any path).
102    pub fn assert_called_any(&self, method: &str) {
103        let calls = self.calls.lock().expect("mutex poisoned");
104        let count = calls.iter().filter(|c| c.method == method).count();
105        assert!(count > 0, "expected at least one call to {}, got 0", method);
106    }
107
108    fn record_call(&self, method: &str, path: &str, body: Option<serde_json::Value>) {
109        self.calls
110            .lock()
111            .expect("mutex poisoned")
112            .push(RecordedCall {
113                method: method.to_string(),
114                path: path.to_string(),
115                body,
116            });
117    }
118
119    fn get_response(&self, method: &str, path: &str) -> anyhow::Result<serde_json::Value> {
120        let key = format!("{}:{}", method, path);
121        let responses = self.responses.lock().expect("mutex poisoned");
122        match responses.get(&key) {
123            Some(v) => Ok(v.clone()),
124            None => Ok(serde_json::json!({})),
125        }
126    }
127}
128
129impl Default for MockAtprotoClient {
130    fn default() -> Self {
131        Self::new()
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mock_records_calls() {
        let mock = MockAtprotoClient::new();
        mock.get_record("app.bsky.feed.post", "abc123").unwrap();
        mock.assert_called("get_record", "app.bsky.feed.post/abc123", 1);
    }

    #[test]
    fn mock_returns_scripted_response() {
        let mock = MockAtprotoClient::new();
        mock.returns(
            "get_record",
            "app.bsky.feed.post/abc",
            serde_json::json!({"text": "hello"}),
        );
        let resp = mock.get_record("app.bsky.feed.post", "abc").unwrap();
        assert_eq!(resp["text"], "hello");
    }

    #[test]
    fn mock_put_record() {
        let mock = MockAtprotoClient::new();
        let record = serde_json::json!({"text": "new post"});
        mock.put_record("app.bsky.feed.post", &record).unwrap();
        mock.assert_called_any("put_record");
    }

    #[test]
    fn mock_delete_record() {
        let mock = MockAtprotoClient::new();
        mock.delete_record("app.bsky.feed.post", "abc").unwrap();
        mock.assert_called("delete_record", "app.bsky.feed.post/abc", 1);
    }

    #[test]
    fn mock_default_response_is_empty_object() {
        let mock = MockAtprotoClient::new();
        let resp = mock.get_record("unknown.collection", "key").unwrap();
        assert_eq!(resp, serde_json::json!({}));
    }

    #[test]
    fn mock_when_sets_null_initially() {
        let mock = MockAtprotoClient::new();
        mock.when("get_record", "app.bsky.feed.post/test");
        let resp = mock.get_record("app.bsky.feed.post", "test").unwrap();
        assert!(resp.is_null());
    }

    #[test]
    fn mock_when_then_returns_chain() {
        let mock = MockAtprotoClient::new();
        mock.when("get_record", "app.bsky.feed.post/abc").returns(
            "get_record",
            "app.bsky.feed.post/abc",
            serde_json::json!({"text": "chained"}),
        );
        let resp = mock.get_record("app.bsky.feed.post", "abc").unwrap();
        assert_eq!(resp["text"], "chained");
    }

    #[test]
    fn mock_list_records() {
        let mock = MockAtprotoClient::new();
        mock.returns(
            "list_records",
            "app.bsky.feed.post",
            serde_json::json!({"records": []}),
        );
        let resp = mock.list_records("app.bsky.feed.post").unwrap();
        assert!(resp["records"].is_array());
    }

    #[test]
    fn mock_put_record_scripted_by_collection() {
        let mock = MockAtprotoClient::new();
        mock.returns(
            "put_record",
            "app.bsky.feed.post",
            serde_json::json!({"uri": "at://example"}),
        );
        let resp = mock
            .put_record("app.bsky.feed.post", &serde_json::json!({"text": "x"}))
            .unwrap();
        assert_eq!(resp["uri"], "at://example");
    }

    #[test]
    fn mock_calls_returns_all_calls() {
        let mock = MockAtprotoClient::new();
        mock.get_record("col1", "a").unwrap();
        mock.get_record("col2", "b").unwrap();
        mock.delete_record("col1", "a").unwrap();
        let calls = mock.calls();
        assert_eq!(calls.len(), 3);
        assert_eq!(calls[0].method, "get_record");
        assert_eq!(calls[1].method, "get_record");
        assert_eq!(calls[2].method, "delete_record");
    }

    #[test]
    fn mock_put_record_captures_body() {
        let mock = MockAtprotoClient::new();
        let record = serde_json::json!({"text": "captured"});
        mock.put_record("app.bsky.feed.post", &record).unwrap();
        let calls = mock.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].body.as_ref().unwrap()["text"], "captured");
    }

    #[test]
    #[should_panic(expected = "expected 2 calls to get_record:col/a, got 1")]
    fn assert_called_panics_on_count_mismatch() {
        let mock = MockAtprotoClient::new();
        mock.get_record("col", "a").unwrap();
        mock.assert_called("get_record", "col/a", 2);
    }

    #[test]
    #[should_panic(expected = "expected at least one call to put_record, got 0")]
    fn assert_called_any_panics_when_not_called() {
        let mock = MockAtprotoClient::new();
        mock.get_record("col", "a").unwrap();
        mock.assert_called_any("put_record");
    }
}