1use std::path::Path;
18use std::sync::Arc;
19
20use axum::response::IntoResponse;
21use axum::routing::any;
22use axum::Router;
23use futures::future::BoxFuture;
24use sqlx::SqlitePool;
25use tower_http::trace::TraceLayer;
26
27use crate::config::Config;
28use crate::cors::build_cors_layer;
29use crate::error::AtrgError;
30use crate::state::AppState;
31
32type CleanupFn = Box<dyn FnOnce(SqlitePool) + Send>;
35
36pub struct AtrgApp {
39 router: Router<AppState>,
40 builtin_router: Option<Router<AppState>>,
42 cleanup_fn: Option<CleanupFn>,
44 event_handler: Option<atrg_stream::EventHandler<AppState>>,
46 #[cfg(feature = "firehose")]
48 firehose_handler: Option<atrg_firehose::FirehoseHandler<AppState>>,
49}
50
51impl AtrgApp {
52 pub fn new() -> Self {
54 Self {
55 router: Router::new(),
56 builtin_router: None,
57 cleanup_fn: None,
58 event_handler: None,
59 #[cfg(feature = "firehose")]
60 firehose_handler: None,
61 }
62 }
63
64 pub fn mount(mut self, router: Router<AppState>) -> Self {
68 self.router = self.router.merge(router);
69 self
70 }
71
72 pub fn with_auth_routes(mut self, router: Router<AppState>) -> Self {
88 self.builtin_router = Some(router);
89 self
90 }
91
92 pub fn with_cleanup_task<F>(mut self, f: F) -> Self
98 where
99 F: FnOnce(SqlitePool) + Send + 'static,
100 {
101 self.cleanup_fn = Some(Box::new(f));
102 self
103 }
104
105 pub fn on_event<F, Fut>(mut self, handler: F) -> Self
111 where
112 F: Fn(atrg_stream::JetstreamEvent, AppState) -> Fut + Send + Sync + 'static,
113 Fut: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
114 {
115 self.event_handler = Some(Arc::new(move |event, state| {
116 Box::pin(handler(event, state)) as BoxFuture<'static, anyhow::Result<()>>
117 }));
118 self
119 }
120
121 #[cfg(feature = "firehose")]
128 pub fn on_firehose_event<F, Fut>(mut self, handler: F) -> Self
129 where
130 F: Fn(atrg_firehose::FirehoseEvent, AppState) -> Fut + Send + Sync + 'static,
131 Fut: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
132 {
133 self.firehose_handler = Some(std::sync::Arc::new(move |event, state| {
134 Box::pin(handler(event, state)) as BoxFuture<'static, anyhow::Result<()>>
135 }));
136 self
137 }
138
139 pub fn with_feed_generator(self, feed_router: Router<AppState>) -> Self {
153 self.mount(feed_router)
154 }
155
156 pub fn with_labeler(self, labeler_router: Router<AppState>) -> Self {
170 self.mount(labeler_router)
171 }
172
173 pub async fn run(self) -> anyhow::Result<()> {
185 let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
187 .unwrap_or_else(|_| {
188 tracing_subscriber::EnvFilter::new(
189 "info,atrg_core=debug,atrg_db=debug,atrg_auth=debug,atrg_cli=debug,tower_http=debug",
190 )
191 });
192
193 let _ = tracing_subscriber::fmt()
196 .with_env_filter(env_filter)
197 .try_init();
198
199 let config_path = std::env::var("ATRG_CONFIG").unwrap_or_else(|_| "./atrg.toml".into());
201 tracing::info!(path = %config_path, "loading configuration");
202 let config = Config::load(&config_path)?;
203 let config = Arc::new(config);
204
205 let db = atrg_db::connect(&config.database.url).await?;
207 atrg_db::run_internal_migrations(&db).await?;
208
209 let user_migrations = Path::new("./migrations");
210 if user_migrations.is_dir() {
211 atrg_db::run_user_migrations(&db, user_migrations).await?;
212 }
213
214 let http = reqwest::Client::builder()
216 .user_agent(format!("atrg/{}", crate::version()))
217 .build()?;
218
219 let identity = Arc::new(atrg_identity::IdentityResolver::with_defaults(http.clone()));
221
222 let state = AppState {
224 config: config.clone(),
225 db,
226 http,
227 identity,
228 };
229
230 let cors = build_cors_layer(&config.app.cors_origins);
232
233 let mut router = Router::new();
235
236 router = router
238 .route("/healthz", axum::routing::get(crate::health::healthz))
239 .route("/readyz", axum::routing::get(crate::health::readyz));
240
241 if let Some(builtin) = self.builtin_router {
243 router = router.merge(builtin);
244 }
245
246 let mut router = router
247 .merge(self.router)
249 .fallback(any(fallback_not_found))
251 .with_state(state.clone())
252 .layer(cors)
253 .layer(axum::middleware::from_fn(
254 crate::request_id::request_id_middleware,
255 ))
256 .layer(TraceLayer::new_for_http());
257
258 if config.app.environment != "development" {
260 router = router.layer(axum::middleware::from_fn(
261 crate::security::security_headers_middleware,
262 ));
263 }
264
265 if let Some(ref js_config) = config.jetstream {
267 if let Some(handler) = self.event_handler {
268 let stream_config = atrg_stream::StreamConfig {
269 host: js_config.host.clone(),
270 collections: js_config.collections.clone(),
271 zstd_dict: js_config.zstd_dict.clone(),
272 channel_capacity: js_config.channel_capacity,
273 max_lag_events: js_config.max_lag_events,
274 };
275 atrg_stream::spawn_consumer(&stream_config, state.clone(), handler).await?;
276 } else {
277 tracing::warn!("jetstream configured but no on_event handler registered");
278 }
279 }
280
281 #[cfg(feature = "firehose")]
283 if let Some(ref fh_config) = config.firehose {
284 if let Some(handler) = self.firehose_handler {
285 let firehose_config = atrg_firehose::FirehoseConfig {
286 relay: fh_config.relay.clone(),
287 cursor: fh_config.cursor,
288 channel_capacity: fh_config.channel_capacity,
289 };
290 atrg_firehose::spawn_firehose(&firehose_config, state.clone(), handler).await?;
291 } else {
292 tracing::warn!("firehose configured but no on_firehose_event handler registered");
293 }
294 }
295
296 if let Some(cleanup) = self.cleanup_fn {
298 cleanup(state.db.clone());
299 }
300
301 let addr = format!("{}:{}", config.app.host, config.app.port);
303 tracing::info!(addr = %addr, name = %config.app.name, "at-rust-go API serving");
304 let listener = tokio::net::TcpListener::bind(&addr).await?;
305 axum::serve(listener, router).await?;
306
307 Ok(())
308 }
309}
310
311impl Default for AtrgApp {
312 fn default() -> Self {
313 Self::new()
314 }
315}
316
317async fn fallback_not_found() -> impl IntoResponse {
319 AtrgError::NotFound
320}
321
/// Test helper: builds a router containing only `user_router`, with no auth routes.
#[cfg(test)]
pub(crate) fn build_test_router(user_router: Router<AppState>, state: AppState) -> Router {
    let no_auth: Option<Router<AppState>> = None;
    build_test_router_with_auth(no_auth, user_router, state)
}
330
/// Test helper: a cut-down version of the router assembled in [`AtrgApp::run`].
///
/// What it includes:
/// - the optional auth routes, then the user routes;
/// - the JSON 404 fallback and the application state;
/// - the CORS and trace layers.
///
/// What it leaves out: the health routes, the request-id middleware and the
/// security headers.
#[cfg(test)]
pub(crate) fn build_test_router_with_auth(
    auth_router: Option<Router<AppState>>,
    user_router: Router<AppState>,
    state: AppState,
) -> Router {
    // Computed before `state` is moved into the router.
    let cors = build_cors_layer(&state.config.app.cors_origins);

    let base = match auth_router {
        Some(auth) => Router::new().merge(auth),
        None => Router::new(),
    };

    let routed = base.merge(user_router).fallback(any(fallback_not_found));

    routed
        .with_state(state)
        .layer(cors)
        .layer(TraceLayer::new_for_http())
}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
    use axum::body::Body;
    use axum::routing::get;
    use axum::Json;
    use http_body_util::BodyExt;
    use hyper::Request;
    use tower::ServiceExt;

    /// Builds an `AppState` backed by an in-memory SQLite database with the
    /// internal migrations applied and a minimal "development" config.
    async fn test_state() -> AppState {
        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();

        let config = Config {
            app: AppConfig {
                name: "test-app".into(),
                host: "127.0.0.1".into(),
                port: 3000,
                secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
                cors_origins: vec![],
                environment: "development".into(),
            },
            auth: AuthConfig {
                client_id: "http://localhost:3000/client-metadata.json".into(),
                redirect_uri: "http://localhost:3000/auth/callback".into(),
                scope: "atproto transition:generic".into(),
            },
            database: DatabaseConfig {
                url: "sqlite::memory:".into(),
            },
            jetstream: None,
            firehose: None,
            feed_generator: None,
            labeler: None,
            rate_limit: None,
        };

        AppState {
            config: Arc::new(config),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
        }
    }

    /// Sends an empty-bodied GET request for `uri` through `app`.
    async fn send_get(app: Router, uri: &str) -> axum::response::Response {
        app.oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap())
            .await
            .unwrap()
    }

    /// Returns the `content-type` header of `response`, panicking if absent.
    fn content_type(response: &axum::response::Response) -> &str {
        response
            .headers()
            .get("content-type")
            .expect("response has a content-type header")
            .to_str()
            .expect("content-type header is visible ASCII")
    }

    /// Collects the whole response body and parses it as JSON.
    async fn body_json(response: axum::response::Response) -> serde_json::Value {
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[test]
    fn atrg_app_default_is_new() {
        let app = AtrgApp::default();
        assert!(app.builtin_router.is_none());
        assert!(app.cleanup_fn.is_none());
        assert!(app.event_handler.is_none());
        #[cfg(feature = "firehose")]
        assert!(app.firehose_handler.is_none());
    }

    #[test]
    fn on_event_sets_handler() {
        let app = AtrgApp::new().on_event(|_event, _state| async { Ok(()) });
        assert!(app.event_handler.is_some());
    }

    #[test]
    fn with_cleanup_task_sets_fn() {
        let app = AtrgApp::new().with_cleanup_task(|_pool| {});
        assert!(app.cleanup_fn.is_some());
    }

    #[test]
    fn with_auth_routes_sets_builtin_router() {
        let app = AtrgApp::new().with_auth_routes(Router::new());
        assert!(app.builtin_router.is_some());
    }

    #[tokio::test]
    async fn mount_ping_returns_200_json() {
        let state = test_state().await;

        let user_router: Router<AppState> = Router::new().route(
            "/ping",
            get(|| async { Json(serde_json::json!({"pong": true})) }),
        );
        let app = build_test_router(user_router, state);

        let response = send_get(app, "/ping").await;
        assert_eq!(response.status(), 200);

        let ct = content_type(&response);
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );

        let body = body_json(response).await;
        assert_eq!(body["pong"], true);
    }

    #[tokio::test]
    async fn unknown_route_returns_404_json() {
        let state = test_state().await;
        let app = build_test_router(Router::new(), state);

        let response = send_get(app, "/does-not-exist").await;
        assert_eq!(response.status(), 404);

        let ct = content_type(&response);
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );

        let body = body_json(response).await;
        assert_eq!(body["error"], "not_found");
        assert_eq!(body["message"], "Not found");
    }

    #[tokio::test]
    async fn multiple_mounts_accumulate_routes() {
        let state = test_state().await;

        let r1: Router<AppState> = Router::new().route(
            "/a",
            get(|| async { Json(serde_json::json!({"route": "a"})) }),
        );
        let r2: Router<AppState> = Router::new().route(
            "/b",
            get(|| async { Json(serde_json::json!({"route": "b"})) }),
        );

        // Go through the real builder so `mount` itself is what is under test.
        let mounted = AtrgApp::new().mount(r1).mount(r2);
        let app = build_test_router(mounted.router, state);

        for (uri, expected) in [("/a", "a"), ("/b", "b")] {
            let response = send_get(app.clone(), uri).await;
            assert_eq!(response.status(), 200, "GET {uri}");
            let body = body_json(response).await;
            assert_eq!(body["route"], expected, "GET {uri}");
        }
    }

    #[tokio::test]
    async fn with_auth_routes_merges_builtin() {
        let state = test_state().await;

        // Stand-in for the real auth routes.
        let auth_router: Router<AppState> = Router::new().route(
            "/auth/test",
            get(|| async { Json(serde_json::json!({"auth": true})) }),
        );
        let user_router: Router<AppState> = Router::new().route(
            "/ping",
            get(|| async { Json(serde_json::json!({"pong": true})) }),
        );

        let built = AtrgApp::new()
            .with_auth_routes(auth_router)
            .mount(user_router);
        let app = build_test_router_with_auth(built.builtin_router, built.router, state);

        for uri in ["/auth/test", "/ping"] {
            let response = send_get(app.clone(), uri).await;
            assert_eq!(response.status(), 200, "GET {uri}");
        }
    }
}