1use std::path::Path;
18use std::sync::Arc;
19
20use atrg_db::DbPool;
21use axum::response::IntoResponse;
22use axum::routing::any;
23use axum::Router;
24use futures::future::BoxFuture;
25use tower_http::trace::TraceLayer;
26
27use crate::config::Config;
28use crate::cors::build_cors_layer;
29use crate::error::AtrgError;
30use crate::state::{AppState, Extensions};
31
32type CleanupFn = Box<dyn FnOnce(DbPool) + Send>;
35
/// Builder for an atrg HTTP application.
///
/// Collects user routes, optional built-in (auth) routes, background event handlers, a
/// startup cleanup hook, an optional caller-supplied database pool, and typed extensions.
/// `AtrgApp::run` consumes the builder and serves the API.
pub struct AtrgApp {
    /// User routes accumulated by `mount` (also used by `with_feed_generator` /
    /// `with_labeler`).
    router: Router<AppState>,
    /// Router set by `with_auth_routes`; merged with the health endpoints and user routes
    /// in `run`.
    builtin_router: Option<Router<AppState>>,
    /// Callback set by `with_cleanup_task`; `run` calls it once with a clone of the pool.
    cleanup_fn: Option<CleanupFn>,
    /// Pool set by `with_db_pool`; when present, `run` uses it instead of connecting to the
    /// configured database URL.
    user_db_pool: Option<DbPool>,
    /// Jetstream handler set by `on_event`; only started when jetstream is configured.
    event_handler: Option<atrg_stream::EventHandler<AppState>>,
    /// Firehose handler set by `on_firehose_event`; only started when firehose is
    /// configured.
    #[cfg(feature = "firehose")]
    firehose_handler: Option<atrg_firehose::FirehoseHandler<AppState>>,
    /// Typed values added by `with_extension`; moved into `AppState::extensions` by `run`.
    extensions: Extensions,
}
55
56impl AtrgApp {
57 pub fn new() -> Self {
59 Self {
60 router: Router::new(),
61 builtin_router: None,
62 cleanup_fn: None,
63 user_db_pool: None,
64 event_handler: None,
65 #[cfg(feature = "firehose")]
66 firehose_handler: None,
67 extensions: Extensions::new(),
68 }
69 }
70
71 pub fn mount(mut self, router: Router<AppState>) -> Self {
75 self.router = self.router.merge(router);
76 self
77 }
78
79 pub fn with_auth_routes(mut self, router: Router<AppState>) -> Self {
95 self.builtin_router = Some(router);
96 self
97 }
98
99 pub fn with_cleanup_task<F>(mut self, f: F) -> Self
105 where
106 F: FnOnce(DbPool) + Send + 'static,
107 {
108 self.cleanup_fn = Some(Box::new(f));
109 self
110 }
111
112 pub fn with_db_pool(mut self, pool: impl Into<DbPool>) -> Self {
135 self.user_db_pool = Some(pool.into());
136 self
137 }
138
139 pub fn with_extension<T: Send + Sync + 'static>(mut self, value: T) -> Self {
162 self.extensions.insert(value);
163 self
164 }
165
166 pub fn on_event<F, Fut>(mut self, handler: F) -> Self
172 where
173 F: Fn(atrg_stream::JetstreamEvent, AppState) -> Fut + Send + Sync + 'static,
174 Fut: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
175 {
176 self.event_handler = Some(Arc::new(move |event, state| {
177 Box::pin(handler(event, state)) as BoxFuture<'static, anyhow::Result<()>>
178 }));
179 self
180 }
181
182 #[cfg(feature = "firehose")]
189 pub fn on_firehose_event<F, Fut>(mut self, handler: F) -> Self
190 where
191 F: Fn(atrg_firehose::FirehoseEvent, AppState) -> Fut + Send + Sync + 'static,
192 Fut: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
193 {
194 self.firehose_handler = Some(std::sync::Arc::new(move |event, state| {
195 Box::pin(handler(event, state)) as BoxFuture<'static, anyhow::Result<()>>
196 }));
197 self
198 }
199
200 pub fn with_feed_generator(self, feed_router: Router<AppState>) -> Self {
214 self.mount(feed_router)
215 }
216
217 pub fn with_labeler(self, labeler_router: Router<AppState>) -> Self {
231 self.mount(labeler_router)
232 }
233
234 pub async fn run(self) -> anyhow::Result<()> {
246 let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
248 .unwrap_or_else(|_| {
249 tracing_subscriber::EnvFilter::new(
250 "info,atrg_core=debug,atrg_db=debug,atrg_auth=debug,atrg_cli=debug,tower_http=debug",
251 )
252 });
253
254 let _ = tracing_subscriber::fmt()
257 .with_env_filter(env_filter)
258 .try_init();
259
260 let config_path = std::env::var("ATRG_CONFIG").unwrap_or_else(|_| "./atrg.toml".into());
262 tracing::info!(path = %config_path, "loading configuration");
263 let config = Config::load(&config_path)?;
264 let config = Arc::new(config);
265
266 let db = match self.user_db_pool {
268 Some(pool) => {
269 tracing::info!(
270 backend = pool.backend(),
271 "using caller-supplied database pool (bypassing [database] url)"
272 );
273 pool
274 }
275 None => atrg_db::connect(&config.database.url).await?,
276 };
277 atrg_db::run_internal_migrations(&db).await?;
278
279 let user_migrations = Path::new("./migrations");
280 if user_migrations.is_dir() {
281 atrg_db::run_isolated_migrations(&db, user_migrations, "_app_migrations").await?;
282 }
283
284 let http = reqwest::Client::builder()
286 .user_agent(format!("atrg/{}", crate::version()))
287 .build()?;
288
289 let identity = Arc::new(atrg_identity::IdentityResolver::with_defaults(http.clone()));
291
292 let state = AppState {
294 config: config.clone(),
295 db,
296 http,
297 identity,
298 extensions: Arc::new(self.extensions),
299 };
300
301 if !config.app.admin_dids.is_empty() {
303 for did in &config.app.admin_dids {
304 let now = std::time::SystemTime::now()
305 .duration_since(std::time::UNIX_EPOCH)
306 .unwrap_or_default()
307 .as_secs()
308 .to_string();
309 let result: Result<(), sqlx::Error> = match &state.db {
310 #[cfg(feature = "sqlite")]
311 atrg_db::DbPool::Sqlite(p) => {
312 sqlx::query(
313 "INSERT OR IGNORE INTO atrg_roles (did, role, granted_by, granted_at) VALUES (?1, 'admin', 'system:bootstrap', ?2)"
314 ).bind(did).bind(&now).execute(p).await.map(|_| ())
315 }
316 #[cfg(feature = "postgres")]
317 atrg_db::DbPool::Postgres(p) => {
318 sqlx::query(
319 "INSERT INTO atrg_roles (did, role, granted_by, granted_at) VALUES ($1, 'admin', 'system:bootstrap', $2) ON CONFLICT DO NOTHING"
320 ).bind(did).bind(&now).execute(p).await.map(|_| ())
321 }
322 #[allow(unreachable_patterns)]
323 _ => Ok(()),
324 };
325 match result {
326 Ok(_) => tracing::info!(did = %did, "auto-provisioned admin DID"),
327 Err(e) => {
328 tracing::warn!(did = %did, error = %e, "failed to bootstrap admin DID (table may not exist yet)")
329 }
330 }
331 }
332 }
333
334 let cors = build_cors_layer(&config.app.cors_origins);
336
337 let mut router = Router::new();
339
340 router = router
342 .route("/healthz", axum::routing::get(crate::health::healthz))
343 .route("/readyz", axum::routing::get(crate::health::readyz));
344
345 if let Some(builtin) = self.builtin_router {
347 router = router.merge(builtin);
348 }
349
350 let mut router = router
351 .merge(self.router)
353 .fallback(any(fallback_not_found))
355 .with_state(state.clone())
356 .layer(cors)
357 .layer(axum::middleware::from_fn(
358 crate::request_id::request_id_middleware,
359 ))
360 .layer(TraceLayer::new_for_http());
361
362 if config.app.environment != "development" {
364 router = router.layer(axum::middleware::from_fn(
365 crate::security::security_headers_middleware,
366 ));
367 }
368
369 if let Some(ref rl_config) = config.rate_limit {
371 if rl_config.enabled {
372 let limiter =
373 crate::rate_limit::RateLimiter::new(crate::rate_limit::RateLimitConfig {
374 requests_per_second: rl_config.requests_per_second,
375 burst: rl_config.burst,
376 enabled: true,
377 });
378
379 let limiter_cleanup = limiter.clone();
381 tokio::spawn(async move {
382 let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
383 loop {
384 interval.tick().await;
385 limiter_cleanup
386 .cleanup(std::time::Duration::from_secs(600))
387 .await;
388 }
389 });
390
391 router = router.layer(axum::middleware::from_fn(
392 move |req: axum::extract::Request, next: axum::middleware::Next| {
393 let limiter = limiter.clone();
394 async move {
395 let ip = req
397 .extensions()
398 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
399 .map(|ci| ci.0.ip())
400 .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
401
402 match limiter.check(ip).await {
403 Ok(()) => next.run(req).await,
404 Err(retry_after) => {
405 crate::rate_limit::rate_limit_response(retry_after)
406 }
407 }
408 }
409 },
410 ));
411
412 tracing::info!(
413 rps = rl_config.requests_per_second,
414 burst = rl_config.burst,
415 "rate limiting enabled"
416 );
417 }
418 }
419
420 if let Some(ref js_config) = config.jetstream {
422 if let Some(handler) = self.event_handler {
423 let stream_config = atrg_stream::StreamConfig {
424 host: js_config.host.clone(),
425 collections: js_config.collections.clone(),
426 zstd_dict: js_config.zstd_dict.clone(),
427 channel_capacity: js_config.channel_capacity,
428 max_lag_events: js_config.max_lag_events,
429 cursor: None,
430 };
431 atrg_stream::spawn_consumer(&stream_config, state.clone(), handler).await?;
432 } else {
433 tracing::warn!("jetstream configured but no on_event handler registered");
434 }
435 }
436
437 #[cfg(feature = "firehose")]
439 if let Some(ref fh_config) = config.firehose {
440 if let Some(handler) = self.firehose_handler {
441 let firehose_config = atrg_firehose::FirehoseConfig {
442 relay: fh_config.relay.clone(),
443 cursor: fh_config.cursor,
444 channel_capacity: fh_config.channel_capacity,
445 };
446 atrg_firehose::spawn_firehose(&firehose_config, state.clone(), handler).await?;
447 } else {
448 tracing::warn!("firehose configured but no on_firehose_event handler registered");
449 }
450 }
451
452 if let Some(cleanup) = self.cleanup_fn {
454 cleanup(state.db.clone());
455 }
456
457 let addr = format!("{}:{}", config.app.host, config.app.port);
459 tracing::info!(addr = %addr, name = %config.app.name, "at-rust-go API serving");
460 let listener = tokio::net::TcpListener::bind(&addr).await?;
461 axum::serve(listener, router).await?;
462
463 Ok(())
464 }
465}
466
467impl Default for AtrgApp {
468 fn default() -> Self {
469 Self::new()
470 }
471}
472
473async fn fallback_not_found() -> impl IntoResponse {
475 AtrgError::NotFound
476}
477
/// Test-only router builder for user routes without any built-in auth routes.
#[cfg(test)]
pub(crate) fn build_test_router(user_router: Router<AppState>, state: AppState) -> Router {
    let no_auth: Option<Router<AppState>> = None;
    build_test_router_with_auth(no_auth, user_router, state)
}
486
/// Test-only router builder that approximates the production router.
///
/// Includes: optional auth routes, user routes, the JSON 404 fallback, CORS, and HTTP
/// tracing. Omits: the health endpoints, request-id, security-header, and rate-limit
/// layers.
#[cfg(test)]
pub(crate) fn build_test_router_with_auth(
    auth_router: Option<Router<AppState>>,
    user_router: Router<AppState>,
    state: AppState,
) -> Router {
    let cors_layer = build_cors_layer(&state.config.app.cors_origins);

    let routes: Router<AppState> = match auth_router {
        Some(auth) => Router::new().merge(auth),
        None => Router::new(),
    };

    routes
        .merge(user_router)
        .fallback(any(fallback_not_found))
        .with_state(state)
        .layer(cors_layer)
        .layer(TraceLayer::new_for_http())
}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
    use axum::body::Body;
    use axum::routing::get;
    use axum::Json;
    use http_body_util::BodyExt;
    use hyper::Request;
    use tower::ServiceExt;

    /// Builds an `AppState` for tests.
    ///
    /// Uses an in-memory SQLite database with internal migrations applied, and a
    /// development-mode config with jetstream, firehose, feed generator, labeler, and rate
    /// limiting all unset.
    async fn test_state() -> AppState {
        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();

        let config = Config {
            app: AppConfig {
                name: "test-app".into(),
                host: "127.0.0.1".into(),
                port: 3000,
                secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
                cors_origins: vec![],
                environment: "development".into(),
                admin_dids: vec![],
            },
            auth: AuthConfig {
                client_id: "http://localhost:3000/client-metadata.json".into(),
                redirect_uri: "http://localhost:3000/auth/callback".into(),
                scope: "atproto transition:generic".into(),
                post_login_redirect: "/".into(),
            },
            database: DatabaseConfig {
                url: "sqlite::memory:".into(),
            },
            jetstream: None,
            firehose: None,
            feed_generator: None,
            labeler: None,
            rate_limit: None,
        };

        AppState {
            config: Arc::new(config),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
            extensions: Arc::new(Extensions::new()),
        }
    }

    /// Collects a response body into a byte vector.
    async fn body_bytes(response: axum::response::Response) -> Vec<u8> {
        response
            .into_body()
            .collect()
            .await
            .unwrap()
            .to_bytes()
            .to_vec()
    }

    #[test]
    fn atrg_app_default_is_new() {
        let app = AtrgApp::default();
        // A default builder must carry no optional configuration at all.
        assert!(app.builtin_router.is_none());
        assert!(app.cleanup_fn.is_none());
        assert!(app.user_db_pool.is_none());
        assert!(app.event_handler.is_none());
        assert!(app.extensions.get::<u32>().is_none());
    }

    #[test]
    fn on_event_sets_handler() {
        let app = AtrgApp::new().on_event(|_event, _state| async { Ok(()) });
        assert!(app.event_handler.is_some());
    }

    #[tokio::test]
    async fn with_db_pool_stores_caller_pool() {
        let pool = atrg_db::connect("sqlite::memory:").await.unwrap();
        let app = AtrgApp::new().with_db_pool(pool);
        let stored = app
            .user_db_pool
            .as_ref()
            .expect("with_db_pool should store the caller's pool");
        assert_eq!(stored.backend(), "sqlite");
    }

    #[tokio::test]
    async fn readyz_reports_backend_kind() {
        let state = test_state().await;
        let app: Router = Router::new()
            .route("/readyz", get(crate::health::readyz))
            .with_state(state);

        let req = Request::builder()
            .uri("/readyz")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 200);

        let bytes = body_bytes(resp).await;
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(body["database_backend"], "sqlite");
    }

    #[tokio::test]
    async fn mount_ping_returns_200_json() {
        let state = test_state().await;

        let user_router: Router<AppState> = Router::new().route(
            "/ping",
            get(|| async { Json(serde_json::json!({"pong": true})) }),
        );

        let app = build_test_router(user_router, state);

        let request = Request::builder().uri("/ping").body(Body::empty()).unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), 200);

        let ct = response
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );

        let bytes = body_bytes(response).await;
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(body["pong"], true);
    }

    #[tokio::test]
    async fn unknown_route_returns_404_json() {
        let state = test_state().await;
        let app = build_test_router(Router::new(), state);

        let request = Request::builder()
            .uri("/does-not-exist")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), 404);

        let ct = response
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );

        let bytes = body_bytes(response).await;
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(body["error"], "not_found");
        assert_eq!(body["message"], "Not found");
    }

    #[tokio::test]
    async fn multiple_mounts_accumulate_routes() {
        let state = test_state().await;

        let r1: Router<AppState> = Router::new().route(
            "/a",
            get(|| async { Json(serde_json::json!({"route": "a"})) }),
        );
        let r2: Router<AppState> = Router::new().route(
            "/b",
            get(|| async { Json(serde_json::json!({"route": "b"})) }),
        );

        let app = build_test_router(r1.merge(r2), state);

        let resp_a = app
            .clone()
            .oneshot(Request::builder().uri("/a").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp_a.status(), 200);

        let resp_b = app
            .oneshot(Request::builder().uri("/b").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp_b.status(), 200);
    }

    #[tokio::test]
    async fn with_auth_routes_merges_builtin() {
        let state = test_state().await;

        let auth_router: Router<AppState> = Router::new().route(
            "/auth/test",
            get(|| async { Json(serde_json::json!({"auth": true})) }),
        );

        let user_router: Router<AppState> = Router::new().route(
            "/ping",
            get(|| async { Json(serde_json::json!({"pong": true})) }),
        );

        let app = build_test_router_with_auth(Some(auth_router), user_router, state);

        // Both the built-in route and the user route must be reachable.
        let resp = app
            .clone()
            .oneshot(
                Request::builder()
                    .uri("/auth/test")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);

        let resp = app
            .oneshot(Request::builder().uri("/ping").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn with_extension_is_accessible_from_state() {
        struct MyConfig {
            magic_number: u64,
        }

        let state = test_state().await;
        let mut ext = Extensions::new();
        ext.insert(MyConfig { magic_number: 42 });

        let state_with_ext = AppState {
            config: state.config.clone(),
            db: state.db.clone(),
            http: state.http.clone(),
            identity: state.identity.clone(),
            extensions: Arc::new(ext),
        };

        let app: Router = Router::new()
            .route(
                "/magic",
                get(
                    |axum::extract::State(s): axum::extract::State<AppState>| async move {
                        let cfg = s.extension::<MyConfig>();
                        Json(serde_json::json!({ "magic": cfg.magic_number }))
                    },
                ),
            )
            .with_state(state_with_ext);

        let resp = app
            .oneshot(
                Request::builder()
                    .uri("/magic")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);

        let body = body_bytes(resp).await;
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["magic"], 42);
    }

    #[test]
    fn with_extension_builder_accumulates_values() {
        struct Foo(u32);
        struct Bar(String);

        let app = AtrgApp::new()
            .with_extension(Foo(7))
            .with_extension(Bar("baz".into()));

        assert_eq!(app.extensions.get::<Foo>().unwrap().0, 7);
        assert_eq!(app.extensions.get::<Bar>().unwrap().0, "baz");
    }
}