1use std::sync::Arc;
4
5use axum::extract::Query;
6use axum::routing::get;
7use axum::{Extension, Json, Router};
8use serde::{Deserialize, Serialize};
9
10use crate::label::LabelService;
11use crate::types::SignedLabel;
12use atrg_core::AppState;
13use atrg_xrpc::{XrpcError, XrpcErrorName};
14
15pub fn labeler_routes(label_service: Arc<LabelService>) -> Router<AppState> {
24 Router::new()
25 .route(
26 "/xrpc/com.atproto.label.queryLabels",
27 get(query_labels_handler),
28 )
29 .layer(Extension(label_service))
30}
31
32#[derive(Debug, Deserialize)]
34#[serde(rename_all = "camelCase")]
35struct QueryLabelsParams {
36 #[serde(default)]
39 uri_patterns: Vec<String>,
40 #[serde(default)]
42 sources: Vec<String>,
43 cursor: Option<String>,
45 limit: Option<i64>,
47}
48
49#[derive(Debug, Serialize)]
51struct QueryLabelsResponse {
52 labels: Vec<SignedLabel>,
54 #[serde(skip_serializing_if = "Option::is_none")]
56 cursor: Option<String>,
57}
58
59async fn query_labels_handler(
65 Extension(service): Extension<Arc<LabelService>>,
66 Query(params): Query<QueryLabelsParams>,
67) -> Result<Json<QueryLabelsResponse>, XrpcError> {
68 let limit = params.limit.unwrap_or(50).clamp(1, 250);
69
70 if let Some(ref cursor_str) = params.cursor {
72 let cursor_id: i64 = cursor_str.parse().map_err(|_| XrpcError {
73 name: XrpcErrorName::InvalidRequest,
74 message: "Invalid cursor value".to_string(),
75 })?;
76
77 let results = service.query_since(cursor_id, limit).await.map_err(|e| {
78 tracing::error!(error = %e, "Failed to query labels since cursor");
79 XrpcError {
80 name: XrpcErrorName::InternalServerError,
81 message: "Failed to query labels".to_string(),
82 }
83 })?;
84
85 let next_cursor = results.last().map(|(id, _)| id.to_string());
86 let mut labels: Vec<SignedLabel> = results.into_iter().map(|(_, label)| label).collect();
87
88 if !params.sources.is_empty() {
90 labels.retain(|l| params.sources.contains(&l.label.src));
91 }
92
93 return Ok(Json(QueryLabelsResponse {
94 labels,
95 cursor: next_cursor,
96 }));
97 }
98
99 if params.uri_patterns.is_empty() {
101 return Err(XrpcError {
102 name: XrpcErrorName::InvalidRequest,
103 message: "At least one uriPatterns value or a cursor is required".to_string(),
104 });
105 }
106
107 let mut all_labels = Vec::new();
108 for pattern in ¶ms.uri_patterns {
109 let labels = service.query_labels(pattern).await.map_err(|e| {
110 tracing::error!(error = %e, uri = %pattern, "Failed to query labels by URI");
111 XrpcError {
112 name: XrpcErrorName::InternalServerError,
113 message: "Failed to query labels".to_string(),
114 }
115 })?;
116 all_labels.extend(labels);
117 }
118
119 if !params.sources.is_empty() {
121 all_labels.retain(|l| params.sources.contains(&l.label.src));
122 }
123
124 all_labels.truncate(limit as usize);
126
127 Ok(Json(QueryLabelsResponse {
128 labels: all_labels,
129 cursor: None,
130 }))
131}
132
// End-to-end tests for the `queryLabels` route: each test builds the router,
// sends one GET request through it with `tower::ServiceExt::oneshot`, and
// inspects the HTTP status and JSON body.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::signing::LabelSigner;
    use crate::types::LabelValue;
    use atrg_core::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
    use axum::body::Body;
    use http_body_util::BodyExt;
    use hyper::Request;
    use tower::ServiceExt;

    // Builds a minimal `AppState`: an in-memory SQLite database with the
    // internal migrations applied, and a development config with every optional
    // subsystem (jetstream, firehose, feed generator, labeler, rate limit) off.
    // NOTE(review): the queryLabels handler reads only the `LabelService`
    // extension, so this state appears to exist just to satisfy
    // `Router<AppState>`.
    async fn test_state() -> AppState {
        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();
        AppState {
            config: Arc::new(Config {
                app: AppConfig {
                    name: "test".into(),
                    host: "127.0.0.1".into(),
                    port: 3000,
                    secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
                    cors_origins: vec![],
                    environment: "development".into(),
                },
                auth: AuthConfig::default(),
                database: DatabaseConfig::default(),
                jetstream: None,
                firehose: None,
                feed_generator: None,
                labeler: None,
                rate_limit: None,
            }),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
        }
    }

    // Creates a `LabelService` backed by its own in-memory SQLite pool (separate
    // from the `AppState` database). It signs with a fixed test key and uses
    // `did:plc:test-labeler` as its labeler DID; `migrate()` creates its tables.
    async fn setup_service() -> Arc<LabelService> {
        let db = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
        let signer = LabelSigner::new(b"test-key".to_vec());
        let svc = LabelService::new(db, signer, "did:plc:test-labeler".to_string());
        svc.migrate().await.unwrap();
        Arc::new(svc)
    }

    // Mounts `labeler_routes` for `service` on a fresh test `AppState`.
    async fn build_app(service: Arc<LabelService>) -> axum::Router {
        let state = test_state().await;
        labeler_routes(service).with_state(state)
    }

    // Parses a response body as JSON; panics if it is not valid JSON.
    fn parse_body(bytes: &[u8]) -> serde_json::Value {
        serde_json::from_slice(bytes).unwrap()
    }

    // Sends `GET /xrpc/com.atproto.label.queryLabels{query}` (where `query`
    // includes its leading `?`, or is empty) and returns the status code and
    // the JSON-decoded body.
    async fn get_labels(app: axum::Router, query: &str) -> (u16, serde_json::Value) {
        let uri = format!("/xrpc/com.atproto.label.queryLabels{}", query);
        let resp = app
            .oneshot(Request::get(&uri).body(Body::empty()).unwrap())
            .await
            .unwrap();
        let status = resp.status().as_u16();
        let bytes = Body::new(resp.into_body())
            .collect()
            .await
            .unwrap()
            .to_bytes();
        (status, parse_body(&bytes))
    }

    // Cursor mode from the beginning (`cursor=0`) returns labels in creation
    // order and includes a next-page cursor.
    #[tokio::test]
    async fn test_query_labels_returns_labels() {
        let svc = setup_service().await;
        svc.create_label("at://did:plc:user/post/1", LabelValue::Spam, None)
            .await
            .unwrap();
        svc.create_label("at://did:plc:user/post/1", LabelValue::Porn, None)
            .await
            .unwrap();

        let app = build_app(svc).await;
        let (status, body) = get_labels(app, "?cursor=0&limit=10").await;

        assert_eq!(status, 200);
        let labels = body["labels"].as_array().unwrap();
        assert_eq!(labels.len(), 2);
        assert_eq!(labels[0]["val"], "spam");
        assert_eq!(labels[1]["val"], "porn");
        assert!(body["cursor"].as_str().is_some());
    }

    // Paging: the first request returns 3 of 5 labels; feeding its cursor back
    // returns the remaining 2, with no overlap.
    #[tokio::test]
    async fn test_query_labels_with_cursor() {
        let svc = setup_service().await;
        for i in 0..5 {
            svc.create_label(
                "at://did:plc:user/post/1",
                LabelValue::Custom(format!("val-{}", i)),
                None,
            )
            .await
            .unwrap();
        }

        let app = build_app(Arc::clone(&svc)).await;
        let (status, body) = get_labels(app, "?cursor=0&limit=3").await;

        assert_eq!(status, 200);
        let labels = body["labels"].as_array().unwrap();
        assert_eq!(labels.len(), 3);
        assert_eq!(labels[0]["val"], "val-0");
        assert_eq!(labels[2]["val"], "val-2");
        let cursor = body["cursor"].as_str().unwrap();

        // A router is consumed by `oneshot`, so build a second one over the
        // same service for the follow-up request.
        let app2 = build_app(svc).await;
        let (status2, body2) = get_labels(app2, &format!("?cursor={}&limit=3", cursor)).await;

        assert_eq!(status2, 200);
        let labels2 = body2["labels"].as_array().unwrap();
        assert_eq!(labels2.len(), 2);
        assert_eq!(labels2[0]["val"], "val-3");
        assert_eq!(labels2[1]["val"], "val-4");
    }

    // With no labels stored, the response has an empty array and no cursor
    // (the field is either omitted or null).
    #[tokio::test]
    async fn test_query_labels_empty() {
        let svc = setup_service().await;
        let app = build_app(svc).await;

        let (status, body) = get_labels(app, "?cursor=0&limit=10").await;

        assert_eq!(status, 200);
        let labels = body["labels"].as_array().unwrap();
        assert!(labels.is_empty());
        assert!(body.get("cursor").is_none() || body["cursor"].is_null());
    }

    // Omitting `limit` applies the default page size of 50.
    #[tokio::test]
    async fn test_query_labels_default_limit() {
        let svc = setup_service().await;
        for i in 0..60 {
            svc.create_label(
                "at://did:plc:user/post/1",
                LabelValue::Custom(format!("v{}", i)),
                None,
            )
            .await
            .unwrap();
        }

        let app = build_app(svc).await;
        let (status, body) = get_labels(app, "?cursor=0").await;

        assert_eq!(status, 200);
        let labels = body["labels"].as_array().unwrap();
        assert_eq!(labels.len(), 50);
    }

    // A request with neither a cursor nor any uriPatterns is rejected with an
    // XRPC `InvalidRequest` error body.
    #[tokio::test]
    async fn test_query_labels_no_patterns_no_cursor_returns_error() {
        let svc = setup_service().await;
        let app = build_app(svc).await;

        let (status, body) = get_labels(app, "").await;

        assert_eq!(status, 400);
        assert_eq!(body["error"], "InvalidRequest");
    }

    // A non-integer cursor is rejected with `InvalidRequest` and a message that
    // names the cursor.
    #[tokio::test]
    async fn test_query_labels_invalid_cursor_returns_error() {
        let svc = setup_service().await;
        let app = build_app(svc).await;

        let (status, body) = get_labels(app, "?cursor=not-a-number").await;

        assert_eq!(status, 400);
        assert_eq!(body["error"], "InvalidRequest");
        assert!(body["message"].as_str().unwrap().contains("Invalid cursor"));
    }

    // Requesting more than the maximum (limit=999) returns at most 250 labels,
    // even though 260 are stored.
    #[tokio::test]
    async fn test_query_labels_limit_clamped_to_max_250() {
        let svc = setup_service().await;
        for i in 0..260 {
            svc.create_label(
                "at://did:plc:user/post/1",
                LabelValue::Custom(format!("v{}", i)),
                None,
            )
            .await
            .unwrap();
        }

        let app = build_app(svc).await;
        let (status, body) = get_labels(app, "?cursor=0&limit=999").await;

        assert_eq!(status, 200);
        let labels = body["labels"].as_array().unwrap();
        assert_eq!(labels.len(), 250);
    }

    // limit=0 is raised to the minimum of 1 rather than returning nothing.
    #[tokio::test]
    async fn test_query_labels_limit_clamped_to_min_1() {
        let svc = setup_service().await;
        svc.create_label("at://did:plc:user/post/1", LabelValue::Spam, None)
            .await
            .unwrap();

        let app = build_app(svc).await;
        let (status, body) = get_labels(app, "?cursor=0&limit=0").await;

        assert_eq!(status, 200);
        let labels = body["labels"].as_array().unwrap();
        assert_eq!(labels.len(), 1);
    }

    // The serialized label exposes these flat top-level fields: `src`
    // (the service's labeler DID), `uri`, `val`, `neg`, `ver`, a non-empty
    // `sig`, and a string `cts`.
    #[tokio::test]
    async fn test_query_labels_response_contains_label_fields() {
        let svc = setup_service().await;
        svc.create_label("at://did:plc:user/post/1", LabelValue::Spam, None)
            .await
            .unwrap();

        let app = build_app(svc).await;
        let (status, body) = get_labels(app, "?cursor=0&limit=1").await;

        assert_eq!(status, 200);
        let label = &body["labels"][0];
        assert_eq!(label["src"], "did:plc:test-labeler");
        assert_eq!(label["uri"], "at://did:plc:user/post/1");
        assert_eq!(label["val"], "spam");
        assert_eq!(label["neg"], false);
        assert_eq!(label["ver"], 1);
        assert!(!label["sig"].as_str().unwrap().is_empty());
        assert!(label["cts"].as_str().is_some());
    }
}