Skip to main content

atrg_feed/
routes.rs

1//! Axum route handlers for feed generator XRPC endpoints.
2//!
3//! Provides handlers for:
4//! - `app.bsky.feed.describeFeedGenerator` — lists available feeds
5//! - `app.bsky.feed.getFeedSkeleton` — returns a feed skeleton for a given feed
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use axum::extract::{Query, State};
11use axum::routing::get;
12use axum::{Json, Router};
13
14use crate::handler::{FeedHandler, FeedRequest};
15use crate::types::{DescribeFeedGeneratorResponse, FeedConfig, FeedDescription, FeedSkeleton};
16use atrg_auth::AuthUser;
17use atrg_core::AppState;
18use atrg_xrpc::XrpcError;
19
20/// Shared state for feed routes, injected via Axum extension.
21#[derive(Clone)]
22pub(crate) struct FeedState {
23    /// DID of the feed generator service.
24    pub(crate) did: String,
25    /// Registered feeds: id -> (config, handler).
26    pub(crate) feeds: Arc<HashMap<String, (FeedConfig, FeedHandler)>>,
27}
28
29/// Build the feed generator router with `describeFeedGenerator` and
30/// `getFeedSkeleton` XRPC endpoints.
31pub fn build_router(
32    did: String,
33    feeds: HashMap<String, (FeedConfig, FeedHandler)>,
34) -> Router<AppState> {
35    let feed_state = FeedState {
36        did,
37        feeds: Arc::new(feeds),
38    };
39
40    Router::new()
41        .route(
42            "/xrpc/app.bsky.feed.describeFeedGenerator",
43            get(describe_feed_generator),
44        )
45        .route(
46            "/xrpc/app.bsky.feed.getFeedSkeleton",
47            get(get_feed_skeleton),
48        )
49        .layer(axum::Extension(feed_state))
50}
51
52/// Handler for `app.bsky.feed.describeFeedGenerator`.
53///
54/// Returns a list of all feeds served by this generator.
55async fn describe_feed_generator(
56    axum::Extension(feed_state): axum::Extension<FeedState>,
57) -> Result<Json<DescribeFeedGeneratorResponse>, XrpcError> {
58    let feeds = feed_state
59        .feeds
60        .iter()
61        .map(|(id, (_config, _handler))| FeedDescription {
62            uri: format!("at://{}/app.bsky.feed.generator/{}", feed_state.did, id),
63            cid: None,
64        })
65        .collect();
66
67    Ok(Json(DescribeFeedGeneratorResponse {
68        did: feed_state.did.clone(),
69        feeds,
70    }))
71}
72
73/// Query parameters for `getFeedSkeleton`.
74#[derive(serde::Deserialize)]
75struct GetSkeletonParams {
76    /// AT-URI of the feed being requested.
77    feed: String,
78    /// Maximum number of items to return (default 50).
79    #[serde(default = "default_limit")]
80    limit: usize,
81    /// Pagination cursor.
82    cursor: Option<String>,
83}
84
/// Number of skeleton items requested when the client omits `limit`.
const DEFAULT_SKELETON_LIMIT: usize = 50;

/// Serde default for the `limit` query parameter of `getFeedSkeleton`.
fn default_limit() -> usize {
    DEFAULT_SKELETON_LIMIT
}
89
90/// Handler for `app.bsky.feed.getFeedSkeleton`.
91///
92/// Extracts the feed ID from the AT-URI, looks up the registered handler,
93/// and delegates skeleton generation.
94async fn get_feed_skeleton(
95    State(app_state): State<AppState>,
96    axum::Extension(feed_state): axum::Extension<FeedState>,
97    AuthUser(user): AuthUser,
98    Query(params): Query<GetSkeletonParams>,
99) -> Result<Json<FeedSkeleton>, XrpcError> {
100    // Extract the feed ID from the AT-URI.
101    // Expected format: at://did:xxx/app.bsky.feed.generator/feed-id
102    let feed_id = params
103        .feed
104        .rsplit('/')
105        .next()
106        .ok_or_else(|| atrg_xrpc::xrpc_invalid_request("invalid feed URI"))?;
107
108    let (_config, handler) = feed_state
109        .feeds
110        .get(feed_id)
111        .ok_or_else(|| atrg_xrpc::xrpc_not_found(format!("feed '{}' not found", feed_id)))?;
112
113    // Clamp limit to [1, 100]
114    let limit = params.limit.clamp(1, 100);
115
116    let request = FeedRequest {
117        feed: params.feed,
118        cursor: params.cursor,
119        limit,
120        requester_did: user.map(|u| u.did),
121    };
122
123    let skeleton = handler(request, app_state).await?;
124    Ok(Json(skeleton))
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::FeedGenerator;
    use crate::types::SkeletonItem;
    use atrg_core::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
    use axum::body::Body;
    use axum::response::Response;
    use http_body_util::BodyExt;
    use hyper::Request;
    use std::sync::Arc;
    use tower::ServiceExt;

    /// Prefix of every `getFeedSkeleton` request URI in these tests. Append a
    /// feed ID and any further query parameters.
    const SKELETON_URI_PREFIX: &str =
        "/xrpc/app.bsky.feed.getFeedSkeleton?feed=at://did:web:feeds.test/app.bsky.feed.generator/";

    /// Build a test `AppState` with in-memory SQLite.
    async fn test_state() -> AppState {
        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();
        AppState {
            config: Arc::new(Config {
                app: AppConfig {
                    name: "test".into(),
                    host: "127.0.0.1".into(),
                    port: 3000,
                    secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
                    cors_origins: vec![],
                    environment: "development".into(),
                },
                auth: AuthConfig::default(),
                database: DatabaseConfig::default(),
                jetstream: None,
                firehose: None,
                feed_generator: None,
                labeler: None,
                rate_limit: None,
            }),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
        }
    }

    /// A simple test feed handler that returns `req.limit` hardcoded posts.
    async fn mock_handler(req: FeedRequest, _state: AppState) -> Result<FeedSkeleton, XrpcError> {
        let items: Vec<SkeletonItem> = (0..req.limit)
            .map(|i| SkeletonItem::new(format!("at://did:plc:test/app.bsky.feed.post/{}", i)))
            .collect();
        Ok(FeedSkeleton {
            feed: items,
            cursor: Some("next-cursor".to_string()),
        })
    }

    /// Build a test app with one registered feed (`test-feed`).
    async fn test_app() -> axum::Router {
        let state = test_state().await;
        FeedGenerator::new("did:web:feeds.test")
            .feed("test-feed", "Test Feed", Some("A test feed"), mock_handler)
            .into_router()
            .with_state(state)
    }

    /// Send a GET request for `uri` through `app` and return the response.
    async fn send_get(app: axum::Router, uri: &str) -> Response {
        app.oneshot(Request::get(uri).body(Body::empty()).unwrap())
            .await
            .unwrap()
    }

    /// Collect a response body and parse it as JSON.
    async fn json_body(resp: Response) -> serde_json::Value {
        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn describe_returns_registered_feeds() {
        let resp = send_get(test_app().await, "/xrpc/app.bsky.feed.describeFeedGenerator").await;
        assert_eq!(resp.status(), 200);

        let body = json_body(resp).await;
        assert_eq!(body["did"], "did:web:feeds.test");
        let feeds = body["feeds"].as_array().unwrap();
        assert_eq!(feeds.len(), 1);
        assert_eq!(
            feeds[0]["uri"],
            "at://did:web:feeds.test/app.bsky.feed.generator/test-feed"
        );
    }

    #[tokio::test]
    async fn get_skeleton_returns_feed_items() {
        let uri = format!("{}test-feed&limit=3", SKELETON_URI_PREFIX);
        let resp = send_get(test_app().await, &uri).await;
        assert_eq!(resp.status(), 200);

        let body = json_body(resp).await;
        let items = body["feed"].as_array().unwrap();
        assert_eq!(items.len(), 3);
        assert_eq!(items[0]["post"], "at://did:plc:test/app.bsky.feed.post/0");
        assert_eq!(body["cursor"], "next-cursor");
    }

    #[tokio::test]
    async fn get_skeleton_unknown_feed_returns_404() {
        let uri = format!("{}nonexistent", SKELETON_URI_PREFIX);
        let resp = send_get(test_app().await, &uri).await;
        assert_eq!(resp.status(), 404);

        let body = json_body(resp).await;
        assert_eq!(body["error"], "NotFound");
    }

    #[tokio::test]
    async fn get_skeleton_clamps_limit_to_max_100() {
        let uri = format!("{}test-feed&limit=999", SKELETON_URI_PREFIX);
        let resp = send_get(test_app().await, &uri).await;
        assert_eq!(resp.status(), 200);

        let body = json_body(resp).await;
        let items = body["feed"].as_array().unwrap();
        assert_eq!(items.len(), 100);
    }

    #[tokio::test]
    async fn get_skeleton_clamps_limit_to_min_1() {
        let uri = format!("{}test-feed&limit=0", SKELETON_URI_PREFIX);
        let resp = send_get(test_app().await, &uri).await;
        assert_eq!(resp.status(), 200);

        let body = json_body(resp).await;
        let items = body["feed"].as_array().unwrap();
        assert_eq!(items.len(), 1);
    }

    #[tokio::test]
    async fn describe_with_multiple_feeds() {
        let state = test_state().await;
        let router = FeedGenerator::new("did:web:feeds.test")
            .feed("feed-a", "Feed A", None, mock_handler)
            .feed("feed-b", "Feed B", Some("Second feed"), mock_handler)
            .into_router()
            .with_state(state);

        let resp = send_get(router, "/xrpc/app.bsky.feed.describeFeedGenerator").await;
        assert_eq!(resp.status(), 200);

        let body = json_body(resp).await;
        let feeds = body["feeds"].as_array().unwrap();
        assert_eq!(feeds.len(), 2);
    }
}