Skip to main content

atrg_label/
routes.rs

1//! Axum routes for labeler XRPC endpoints.
2
3use std::sync::Arc;
4
5use axum::extract::Query;
6use axum::routing::get;
7use axum::{Extension, Json, Router};
8use serde::{Deserialize, Serialize};
9
10use crate::label::LabelService;
11use crate::types::SignedLabel;
12use atrg_core::AppState;
13use atrg_xrpc::{XrpcError, XrpcErrorName};
14
15/// Build the labeler router with the label query endpoint.
16///
17/// Registers:
18/// - `GET /xrpc/com.atproto.label.queryLabels`
19///
20/// The returned router must be merged into the application router.
21/// The `label_service` is injected via an Axum [`Extension`] layer so that
22/// handlers can access it without polluting [`AppState`].
23pub fn labeler_routes(label_service: Arc<LabelService>) -> Router<AppState> {
24    Router::new()
25        .route(
26            "/xrpc/com.atproto.label.queryLabels",
27            get(query_labels_handler),
28        )
29        .layer(Extension(label_service))
30}
31
32/// Query parameters for `com.atproto.label.queryLabels`.
33#[derive(Debug, Deserialize)]
34#[serde(rename_all = "camelCase")]
35struct QueryLabelsParams {
36    /// URI patterns to match against label subjects.
37    /// Multiple values can be provided as repeated query parameters.
38    #[serde(default)]
39    uri_patterns: Vec<String>,
40    /// Filter by label source DIDs.
41    #[serde(default)]
42    sources: Vec<String>,
43    /// Cursor for pagination (opaque string representing the last seen row id).
44    cursor: Option<String>,
45    /// Maximum number of labels to return (default 50, max 250).
46    limit: Option<i64>,
47}
48
49/// Response body for `com.atproto.label.queryLabels`.
50#[derive(Debug, Serialize)]
51struct QueryLabelsResponse {
52    /// The matching labels.
53    labels: Vec<SignedLabel>,
54    /// Cursor for the next page, if more results are available.
55    #[serde(skip_serializing_if = "Option::is_none")]
56    cursor: Option<String>,
57}
58
59/// Handler for `GET /xrpc/com.atproto.label.queryLabels`.
60///
61/// Supports filtering by URI patterns and source DIDs. When a cursor is
62/// provided, returns labels created after that cursor position. Results
63/// are ordered by creation time (ascending).
64async fn query_labels_handler(
65    Extension(service): Extension<Arc<LabelService>>,
66    Query(params): Query<QueryLabelsParams>,
67) -> Result<Json<QueryLabelsResponse>, XrpcError> {
68    let limit = params.limit.unwrap_or(50).clamp(1, 250);
69
70    // If a cursor is provided, use cursor-based pagination via query_since.
71    if let Some(ref cursor_str) = params.cursor {
72        let cursor_id: i64 = cursor_str.parse().map_err(|_| XrpcError {
73            name: XrpcErrorName::InvalidRequest,
74            message: "Invalid cursor value".to_string(),
75        })?;
76
77        let results = service.query_since(cursor_id, limit).await.map_err(|e| {
78            tracing::error!(error = %e, "Failed to query labels since cursor");
79            XrpcError {
80                name: XrpcErrorName::InternalServerError,
81                message: "Failed to query labels".to_string(),
82            }
83        })?;
84
85        let next_cursor = results.last().map(|(id, _)| id.to_string());
86        let mut labels: Vec<SignedLabel> = results.into_iter().map(|(_, label)| label).collect();
87
88        // Apply source filter if provided.
89        if !params.sources.is_empty() {
90            labels.retain(|l| params.sources.contains(&l.label.src));
91        }
92
93        return Ok(Json(QueryLabelsResponse {
94            labels,
95            cursor: next_cursor,
96        }));
97    }
98
99    // No cursor — query by URI patterns.
100    if params.uri_patterns.is_empty() {
101        return Err(XrpcError {
102            name: XrpcErrorName::InvalidRequest,
103            message: "At least one uriPatterns value or a cursor is required".to_string(),
104        });
105    }
106
107    let mut all_labels = Vec::new();
108    for pattern in &params.uri_patterns {
109        let labels = service.query_labels(pattern).await.map_err(|e| {
110            tracing::error!(error = %e, uri = %pattern, "Failed to query labels by URI");
111            XrpcError {
112                name: XrpcErrorName::InternalServerError,
113                message: "Failed to query labels".to_string(),
114            }
115        })?;
116        all_labels.extend(labels);
117    }
118
119    // Apply source filter if provided.
120    if !params.sources.is_empty() {
121        all_labels.retain(|l| params.sources.contains(&l.label.src));
122    }
123
124    // Truncate to limit.
125    all_labels.truncate(limit as usize);
126
127    Ok(Json(QueryLabelsResponse {
128        labels: all_labels,
129        cursor: None,
130    }))
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::signing::LabelSigner;
    use crate::types::LabelValue;
    use atrg_core::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
    use axum::body::Body;
    use http_body_util::BodyExt;
    use hyper::Request;
    use tower::ServiceExt;

    /// Path of the endpoint under test.
    const ENDPOINT: &str = "/xrpc/com.atproto.label.queryLabels";
    /// Subject URI that every seeded label is attached to.
    const SUBJECT: &str = "at://did:plc:user/post/1";

    /// Build an [`AppState`] backed by a fresh in-memory database with the
    /// internal migrations applied and all optional subsystems disabled.
    async fn test_state() -> AppState {
        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();

        let app = AppConfig {
            name: "test".into(),
            host: "127.0.0.1".into(),
            port: 3000,
            secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
            cors_origins: vec![],
            environment: "development".into(),
        };
        let config = Config {
            app,
            auth: AuthConfig::default(),
            database: DatabaseConfig::default(),
            jetstream: None,
            firehose: None,
            feed_generator: None,
            labeler: None,
            rate_limit: None,
        };
        let identity = atrg_identity::IdentityResolver::with_defaults(reqwest::Client::new());

        AppState {
            config: Arc::new(config),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(identity),
        }
    }

    /// Create a [`LabelService`] on its own in-memory pool, schema migrated.
    async fn setup_service() -> Arc<LabelService> {
        let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
        let service = LabelService::new(
            pool,
            LabelSigner::new(b"test-key".to_vec()),
            "did:plc:test-labeler".to_string(),
        );
        service.migrate().await.unwrap();
        Arc::new(service)
    }

    /// Attach `count` custom labels to [`SUBJECT`], valued `{prefix}0`,
    /// `{prefix}1`, … in creation order.
    async fn seed_custom(svc: &LabelService, prefix: &str, count: usize) {
        for n in 0..count {
            let value = LabelValue::Custom(format!("{}{}", prefix, n));
            svc.create_label(SUBJECT, value, None).await.unwrap();
        }
    }

    /// Mount the labeler routes over a fresh test state.
    async fn build_app(service: Arc<LabelService>) -> axum::Router {
        labeler_routes(service).with_state(test_state().await)
    }

    /// Send `GET {ENDPOINT}{query}` to `app`; return the status code and the
    /// body parsed as JSON.
    async fn get_labels(app: axum::Router, query: &str) -> (u16, serde_json::Value) {
        let uri = format!("{}{}", ENDPOINT, query);
        let request = Request::get(&uri).body(Body::empty()).unwrap();
        let response = app.oneshot(request).await.unwrap();
        let status = response.status().as_u16();
        let collected = Body::new(response.into_body()).collect().await.unwrap();
        let json = serde_json::from_slice(&collected.to_bytes()).unwrap();
        (status, json)
    }

    /// The `labels` array of a response body.
    fn labels_of(body: &serde_json::Value) -> &[serde_json::Value] {
        body["labels"].as_array().unwrap()
    }

    // --- Cursor-based pagination (the primary path) ---

    #[tokio::test]
    async fn test_query_labels_returns_labels() {
        let svc = setup_service().await;
        for value in [LabelValue::Spam, LabelValue::Porn] {
            svc.create_label(SUBJECT, value, None).await.unwrap();
        }

        let (status, body) = get_labels(build_app(svc).await, "?cursor=0&limit=10").await;

        assert_eq!(status, 200);
        let labels = labels_of(&body);
        assert_eq!(labels.len(), 2);
        assert_eq!(labels[0]["val"], "spam");
        assert_eq!(labels[1]["val"], "porn");
        // A non-empty page carries a cursor.
        assert!(body["cursor"].as_str().is_some());
    }

    #[tokio::test]
    async fn test_query_labels_with_cursor() {
        let svc = setup_service().await;
        seed_custom(&svc, "val-", 5).await;

        // Page 1: the first three of five labels.
        let (status, first) =
            get_labels(build_app(Arc::clone(&svc)).await, "?cursor=0&limit=3").await;
        assert_eq!(status, 200);
        let page = labels_of(&first);
        assert_eq!(page.len(), 3);
        assert_eq!(page[0]["val"], "val-0");
        assert_eq!(page[2]["val"], "val-2");
        let next = first["cursor"].as_str().unwrap();

        // Page 2: resume from the returned cursor to get the remaining two.
        let query = format!("?cursor={}&limit=3", next);
        let (status, second) = get_labels(build_app(svc).await, &query).await;
        assert_eq!(status, 200);
        let page = labels_of(&second);
        assert_eq!(page.len(), 2);
        assert_eq!(page[0]["val"], "val-3");
        assert_eq!(page[1]["val"], "val-4");
    }

    #[tokio::test]
    async fn test_query_labels_empty() {
        let app = build_app(setup_service().await).await;

        // Nothing seeded: the cursor query yields an empty page...
        let (status, body) = get_labels(app, "?cursor=0&limit=10").await;

        assert_eq!(status, 200);
        assert!(labels_of(&body).is_empty());
        // ...and no cursor, whether the key is absent or null.
        assert!(body.get("cursor").map_or(true, serde_json::Value::is_null));
    }

    #[tokio::test]
    async fn test_query_labels_default_limit() {
        let svc = setup_service().await;
        // Seed more labels (60) than the default page size (50).
        seed_custom(&svc, "v", 60).await;

        // No explicit limit, so the default of 50 applies.
        let (status, body) = get_labels(build_app(svc).await, "?cursor=0").await;

        assert_eq!(status, 200);
        assert_eq!(labels_of(&body).len(), 50);
    }

    // --- Error cases ---

    #[tokio::test]
    async fn test_query_labels_no_patterns_no_cursor_returns_error() {
        let app = build_app(setup_service().await).await;

        let (status, body) = get_labels(app, "").await;

        assert_eq!(status, 400);
        assert_eq!(body["error"], "InvalidRequest");
    }

    #[tokio::test]
    async fn test_query_labels_invalid_cursor_returns_error() {
        let app = build_app(setup_service().await).await;

        let (status, body) = get_labels(app, "?cursor=not-a-number").await;

        assert_eq!(status, 400);
        assert_eq!(body["error"], "InvalidRequest");
        let message = body["message"].as_str().unwrap();
        assert!(message.contains("Invalid cursor"));
    }

    // --- Limit clamping ---

    #[tokio::test]
    async fn test_query_labels_limit_clamped_to_max_250() {
        let svc = setup_service().await;
        seed_custom(&svc, "v", 260).await;

        let (status, body) = get_labels(build_app(svc).await, "?cursor=0&limit=999").await;

        assert_eq!(status, 200);
        // limit=999 is clamped down to 250.
        assert_eq!(labels_of(&body).len(), 250);
    }

    #[tokio::test]
    async fn test_query_labels_limit_clamped_to_min_1() {
        let svc = setup_service().await;
        svc.create_label(SUBJECT, LabelValue::Spam, None).await.unwrap();

        let (status, body) = get_labels(build_app(svc).await, "?cursor=0&limit=0").await;

        assert_eq!(status, 200);
        // limit=0 is clamped up to 1.
        assert_eq!(labels_of(&body).len(), 1);
    }

    // --- Response shape ---

    #[tokio::test]
    async fn test_query_labels_response_contains_label_fields() {
        let svc = setup_service().await;
        svc.create_label(SUBJECT, LabelValue::Spam, None).await.unwrap();

        let (status, body) = get_labels(build_app(svc).await, "?cursor=0&limit=1").await;

        assert_eq!(status, 200);
        let label = &body["labels"][0];
        assert_eq!(label["src"], "did:plc:test-labeler");
        assert_eq!(label["uri"], SUBJECT);
        assert_eq!(label["val"], "spam");
        assert_eq!(label["neg"], false);
        assert_eq!(label["ver"], 1);
        // The signature is present and non-empty.
        assert!(!label["sig"].as_str().unwrap().is_empty());
        // The creation timestamp is present.
        assert!(label["cts"].as_str().is_some());
    }
}