1use std::collections::HashMap;
8use std::sync::Arc;
9
10use axum::extract::{Query, State};
11use axum::routing::get;
12use axum::{Json, Router};
13
14use crate::handler::{FeedHandler, FeedRequest};
15use crate::types::{DescribeFeedGeneratorResponse, FeedConfig, FeedDescription, FeedSkeleton};
16use atrg_auth::AuthUser;
17use atrg_core::AppState;
18use atrg_xrpc::XrpcError;
19
20#[derive(Clone)]
22pub(crate) struct FeedState {
23 pub(crate) did: String,
25 pub(crate) feeds: Arc<HashMap<String, (FeedConfig, FeedHandler)>>,
27}
28
29pub fn build_router(
32 did: String,
33 feeds: HashMap<String, (FeedConfig, FeedHandler)>,
34) -> Router<AppState> {
35 let feed_state = FeedState {
36 did,
37 feeds: Arc::new(feeds),
38 };
39
40 Router::new()
41 .route(
42 "/xrpc/app.bsky.feed.describeFeedGenerator",
43 get(describe_feed_generator),
44 )
45 .route(
46 "/xrpc/app.bsky.feed.getFeedSkeleton",
47 get(get_feed_skeleton),
48 )
49 .layer(axum::Extension(feed_state))
50}
51
52async fn describe_feed_generator(
56 axum::Extension(feed_state): axum::Extension<FeedState>,
57) -> Result<Json<DescribeFeedGeneratorResponse>, XrpcError> {
58 let feeds = feed_state
59 .feeds
60 .iter()
61 .map(|(id, (_config, _handler))| FeedDescription {
62 uri: format!("at://{}/app.bsky.feed.generator/{}", feed_state.did, id),
63 cid: None,
64 })
65 .collect();
66
67 Ok(Json(DescribeFeedGeneratorResponse {
68 did: feed_state.did.clone(),
69 feeds,
70 }))
71}
72
73#[derive(serde::Deserialize)]
75struct GetSkeletonParams {
76 feed: String,
78 #[serde(default = "default_limit")]
80 limit: usize,
81 cursor: Option<String>,
83}
84
/// Returns the page size used when a request does not specify `limit`.
fn default_limit() -> usize {
    const DEFAULT_LIMIT: usize = 50;
    DEFAULT_LIMIT
}
89
90async fn get_feed_skeleton(
95 State(app_state): State<AppState>,
96 axum::Extension(feed_state): axum::Extension<FeedState>,
97 AuthUser(user): AuthUser,
98 Query(params): Query<GetSkeletonParams>,
99) -> Result<Json<FeedSkeleton>, XrpcError> {
100 let feed_id = params
103 .feed
104 .rsplit('/')
105 .next()
106 .ok_or_else(|| atrg_xrpc::xrpc_invalid_request("invalid feed URI"))?;
107
108 let (_config, handler) = feed_state
109 .feeds
110 .get(feed_id)
111 .ok_or_else(|| atrg_xrpc::xrpc_not_found(format!("feed '{}' not found", feed_id)))?;
112
113 let limit = params.limit.clamp(1, 100);
115
116 let request = FeedRequest {
117 feed: params.feed,
118 cursor: params.cursor,
119 limit,
120 requester_did: user.map(|u| u.did),
121 };
122
123 let skeleton = handler(request, app_state).await?;
124 Ok(Json(skeleton))
125}
126
// Router-level tests: each test builds a router with mock feeds, sends one GET request
// through it and inspects the JSON response.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::FeedGenerator;
    use crate::types::SkeletonItem;
    use atrg_core::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
    use axum::body::Body;
    use http_body_util::BodyExt;
    use hyper::Request;
    use std::sync::Arc;
    use tower::ServiceExt;

    /// Builds an `AppState` backed by an in-memory SQLite database with the internal
    /// migrations applied and every optional subsystem left unconfigured.
    async fn test_state() -> AppState {
        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();

        let config = Config {
            app: AppConfig {
                name: "test".into(),
                host: "127.0.0.1".into(),
                port: 3000,
                secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
                cors_origins: vec![],
                environment: "development".into(),
            },
            auth: AuthConfig::default(),
            database: DatabaseConfig::default(),
            jetstream: None,
            firehose: None,
            feed_generator: None,
            labeler: None,
            rate_limit: None,
        };

        AppState {
            config: Arc::new(config),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
        }
    }

    /// Feed handler stub: returns exactly `req.limit` posts numbered from 0 and a fixed
    /// cursor, which lets tests observe the limit the router passed in.
    async fn mock_handler(req: FeedRequest, _state: AppState) -> Result<FeedSkeleton, XrpcError> {
        let items: Vec<SkeletonItem> = (0..req.limit)
            .map(|i| SkeletonItem::new(format!("at://did:plc:test/app.bsky.feed.post/{}", i)))
            .collect();
        Ok(FeedSkeleton {
            feed: items,
            cursor: Some("next-cursor".to_string()),
        })
    }

    /// Builds a router with a single `test-feed` feed served by [`mock_handler`].
    async fn test_app() -> (axum::Router, AppState) {
        let state = test_state().await;
        let router = FeedGenerator::new("did:web:feeds.test")
            .feed("test-feed", "Test Feed", Some("A test feed"), mock_handler)
            .into_router()
            .with_state(state.clone());
        (router, state)
    }

    /// Sends a body-less `GET` for `uri` through `app` and returns the numeric status code
    /// together with the response body parsed as JSON.
    ///
    /// Panics if the request cannot be built or served, or if the body is not valid JSON;
    /// the status code is included in that last panic message to ease debugging.
    async fn get_json(app: axum::Router, uri: &str) -> (u16, serde_json::Value) {
        let request = Request::get(uri).body(Body::empty()).unwrap();
        let response = app.oneshot(request).await.unwrap();
        let status = response.status().as_u16();
        let bytes = Body::new(response.into_body())
            .collect()
            .await
            .unwrap()
            .to_bytes();
        let body = serde_json::from_slice(&bytes).unwrap_or_else(|err| {
            panic!("status {}: response body is not JSON: {}", status, err)
        });
        (status, body)
    }

    #[tokio::test]
    async fn describe_returns_registered_feeds() {
        let (app, _state) = test_app().await;
        let (status, body) = get_json(app, "/xrpc/app.bsky.feed.describeFeedGenerator").await;

        assert_eq!(status, 200);
        assert_eq!(body["did"], "did:web:feeds.test");
        let feeds = body["feeds"].as_array().unwrap();
        assert_eq!(feeds.len(), 1);
        assert_eq!(
            feeds[0]["uri"],
            "at://did:web:feeds.test/app.bsky.feed.generator/test-feed"
        );
    }

    #[tokio::test]
    async fn get_skeleton_returns_feed_items() {
        let (app, _state) = test_app().await;
        let uri = "/xrpc/app.bsky.feed.getFeedSkeleton?feed=at://did:web:feeds.test/app.bsky.feed.generator/test-feed&limit=3";
        let (status, body) = get_json(app, uri).await;

        assert_eq!(status, 200);
        let items = body["feed"].as_array().unwrap();
        assert_eq!(items.len(), 3);
        assert_eq!(items[0]["post"], "at://did:plc:test/app.bsky.feed.post/0");
        assert_eq!(body["cursor"], "next-cursor");
    }

    #[tokio::test]
    async fn get_skeleton_unknown_feed_returns_404() {
        let (app, _state) = test_app().await;
        let uri = "/xrpc/app.bsky.feed.getFeedSkeleton?feed=at://did:web:feeds.test/app.bsky.feed.generator/nonexistent";
        let (status, body) = get_json(app, uri).await;

        assert_eq!(status, 404);
        assert_eq!(body["error"], "NotFound");
    }

    #[tokio::test]
    async fn get_skeleton_clamps_limit_to_max_100() {
        let (app, _state) = test_app().await;
        let uri = "/xrpc/app.bsky.feed.getFeedSkeleton?feed=at://did:web:feeds.test/app.bsky.feed.generator/test-feed&limit=999";
        let (status, body) = get_json(app, uri).await;

        assert_eq!(status, 200);
        let items = body["feed"].as_array().unwrap();
        assert_eq!(items.len(), 100);
    }

    #[tokio::test]
    async fn get_skeleton_clamps_limit_to_min_1() {
        let (app, _state) = test_app().await;
        let uri = "/xrpc/app.bsky.feed.getFeedSkeleton?feed=at://did:web:feeds.test/app.bsky.feed.generator/test-feed&limit=0";
        let (status, body) = get_json(app, uri).await;

        assert_eq!(status, 200);
        let items = body["feed"].as_array().unwrap();
        assert_eq!(items.len(), 1);
    }

    #[tokio::test]
    async fn describe_with_multiple_feeds() {
        let state = test_state().await;
        let router = FeedGenerator::new("did:web:feeds.test")
            .feed("feed-a", "Feed A", None, mock_handler)
            .feed("feed-b", "Feed B", Some("Second feed"), mock_handler)
            .into_router()
            .with_state(state);

        let (status, body) = get_json(router, "/xrpc/app.bsky.feed.describeFeedGenerator").await;

        assert_eq!(status, 200);
        let feeds = body["feeds"].as_array().unwrap();
        assert_eq!(feeds.len(), 2);
    }
}