Skip to main content

atrg_core/
app.rs

//! The `AtrgApp` builder — the main entry point for assembling and running an atrg server.
//!
//! A minimal application looks like this:
//!
//! ```rust,no_run
//! use atrg_core::AtrgApp;
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//!     AtrgApp::new()
//!         .mount(axum::Router::new())
//!         .run()
//!         .await
//! }
//! ```
16
17use std::path::Path;
18use std::sync::Arc;
19
20use axum::response::IntoResponse;
21use axum::routing::any;
22use axum::Router;
23use futures::future::BoxFuture;
24use sqlx::SqlitePool;
25use tower_http::trace::TraceLayer;
26
27use crate::config::Config;
28use crate::cors::build_cors_layer;
29use crate::error::AtrgError;
30use crate::state::AppState;
31
32/// A cleanup task function that receives a database pool and spawns
33/// background maintenance work (e.g. expired session cleanup).
34type CleanupFn = Box<dyn FnOnce(SqlitePool) + Send>;
35
36/// The application builder. Accumulates user routers and configuration,
37/// then boots the full server when [`AtrgApp::run`] is called.
38pub struct AtrgApp {
39    router: Router<AppState>,
40    /// Built-in routes (auth, well-known, etc.) merged before user routes.
41    builtin_router: Option<Router<AppState>>,
42    /// Optional cleanup task spawner (e.g. session/oauth-state cleanup).
43    cleanup_fn: Option<CleanupFn>,
44    /// Jetstream event handler registered via [`AtrgApp::on_event`].
45    event_handler: Option<atrg_stream::EventHandler<AppState>>,
46    /// Firehose event handler (registered via [`AtrgApp::on_firehose_event`]).
47    #[cfg(feature = "firehose")]
48    firehose_handler: Option<atrg_firehose::FirehoseHandler<AppState>>,
49}
50
51impl AtrgApp {
52    /// Create a new, empty application builder.
53    pub fn new() -> Self {
54        Self {
55            router: Router::new(),
56            builtin_router: None,
57            cleanup_fn: None,
58            event_handler: None,
59            #[cfg(feature = "firehose")]
60            firehose_handler: None,
61        }
62    }
63
64    /// Mount an additional [`axum::Router`] into the application.
65    ///
66    /// Routes are merged, so multiple calls to `mount` accumulate routes.
67    pub fn mount(mut self, router: Router<AppState>) -> Self {
68        self.router = self.router.merge(router);
69        self
70    }
71
72    /// Register built-in auth routes (OAuth login/callback/logout, client-metadata, well-known).
73    ///
74    /// The supplied router is merged **before** user routes so that atrg's
75    /// built-in endpoints take precedence.
76    ///
77    /// # Example
78    ///
79    /// ```rust,ignore
80    /// use atrg_core::AtrgApp;
81    ///
82    /// AtrgApp::new()
83    ///     .with_auth_routes(atrg_auth::routes::auth_router())
84    ///     // ...
85    /// # ;
86    /// ```
87    pub fn with_auth_routes(mut self, router: Router<AppState>) -> Self {
88        self.builtin_router = Some(router);
89        self
90    }
91
92    /// Register a background cleanup task that is spawned after the server
93    /// starts. Typically used for periodic session / OAuth-state expiry.
94    ///
95    /// The callback receives the [`SqlitePool`] and is expected to call
96    /// `tokio::spawn` internally.
97    pub fn with_cleanup_task<F>(mut self, f: F) -> Self
98    where
99        F: FnOnce(SqlitePool) + Send + 'static,
100    {
101        self.cleanup_fn = Some(Box::new(f));
102        self
103    }
104
105    /// Register a Jetstream event handler.
106    ///
107    /// The handler is called for every event received from the Jetstream
108    /// firehose. It is spawned as a background task inside [`AtrgApp::run`]
109    /// when `[jetstream]` is present in `atrg.toml`.
110    pub fn on_event<F, Fut>(mut self, handler: F) -> Self
111    where
112        F: Fn(atrg_stream::JetstreamEvent, AppState) -> Fut + Send + Sync + 'static,
113        Fut: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
114    {
115        self.event_handler = Some(Arc::new(move |event, state| {
116            Box::pin(handler(event, state)) as BoxFuture<'static, anyhow::Result<()>>
117        }));
118        self
119    }
120
121    /// Register a firehose event handler.
122    ///
123    /// The handler is called for every event received from the AT Protocol
124    /// relay firehose (`com.atproto.sync.subscribeRepos`). It is spawned as
125    /// a background task inside [`AtrgApp::run`] when `[firehose]` is present
126    /// in `atrg.toml`.
127    #[cfg(feature = "firehose")]
128    pub fn on_firehose_event<F, Fut>(mut self, handler: F) -> Self
129    where
130        F: Fn(atrg_firehose::FirehoseEvent, AppState) -> Fut + Send + Sync + 'static,
131        Fut: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
132    {
133        self.firehose_handler = Some(std::sync::Arc::new(move |event, state| {
134            Box::pin(handler(event, state)) as BoxFuture<'static, anyhow::Result<()>>
135        }));
136        self
137    }
138
139    /// Mount a feed generator's routes.
140    ///
141    /// Pass the router produced by `FeedGenerator::into_router()` (from the
142    /// `atrg-feed` crate).
143    /// This is a semantic alias for [`mount`](Self::mount) that makes the
144    /// builder read more clearly.
145    ///
146    /// # Example
147    ///
148    /// ```rust,ignore
149    /// AtrgApp::new()
150    ///     .with_feed_generator(feed_gen.into_router())
151    /// ```
152    pub fn with_feed_generator(self, feed_router: Router<AppState>) -> Self {
153        self.mount(feed_router)
154    }
155
156    /// Mount a labeler service's routes.
157    ///
158    /// Pass the router produced by `labeler_routes()` (from the `atrg-label`
159    /// crate).
160    /// This is a semantic alias for [`mount`](Self::mount) that makes the
161    /// builder read more clearly.
162    ///
163    /// # Example
164    ///
165    /// ```rust,ignore
166    /// AtrgApp::new()
167    ///     .with_labeler(atrg_label::routes::labeler_routes(service))
168    /// ```
169    pub fn with_labeler(self, labeler_router: Router<AppState>) -> Self {
170        self.mount(labeler_router)
171    }
172
173    /// Boot the server.
174    ///
175    /// This is the **only** async entry point. It:
176    ///
177    /// 1. Initialises tracing (respects `RUST_LOG`).
178    /// 2. Loads `atrg.toml` (or `$ATRG_CONFIG`).
179    /// 3. Connects to SQLite and runs migrations.
180    /// 4. Builds [`AppState`] (including the identity resolver).
181    /// 5. Assembles the Axum router with CORS, tracing, and a JSON 404 fallback.
182    /// 6. Spawns optional cleanup tasks.
183    /// 7. Binds a TCP listener and serves.
184    pub async fn run(self) -> anyhow::Result<()> {
185        // 1. Init tracing -------------------------------------------------------
186        let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
187            .unwrap_or_else(|_| {
188                tracing_subscriber::EnvFilter::new(
189                    "info,atrg_core=debug,atrg_db=debug,atrg_auth=debug,atrg_cli=debug,tower_http=debug",
190                )
191            });
192
193        // If another test or binary already initialised the subscriber, silently
194        // ignore the error rather than panicking.
195        let _ = tracing_subscriber::fmt()
196            .with_env_filter(env_filter)
197            .try_init();
198
199        // 2. Load config --------------------------------------------------------
200        let config_path = std::env::var("ATRG_CONFIG").unwrap_or_else(|_| "./atrg.toml".into());
201        tracing::info!(path = %config_path, "loading configuration");
202        let config = Config::load(&config_path)?;
203        let config = Arc::new(config);
204
205        // 3. Connect DB + migrations --------------------------------------------
206        let db = atrg_db::connect(&config.database.url).await?;
207        atrg_db::run_internal_migrations(&db).await?;
208
209        let user_migrations = Path::new("./migrations");
210        if user_migrations.is_dir() {
211            atrg_db::run_user_migrations(&db, user_migrations).await?;
212        }
213
214        // 4. Build HTTP client --------------------------------------------------
215        let http = reqwest::Client::builder()
216            .user_agent(format!("atrg/{}", crate::version()))
217            .build()?;
218
219        // 4b. Build identity resolver -------------------------------------------
220        let identity = Arc::new(atrg_identity::IdentityResolver::with_defaults(http.clone()));
221
222        // 5. Assemble AppState --------------------------------------------------
223        let state = AppState {
224            config: config.clone(),
225            db,
226            http,
227            identity,
228        };
229
230        // 6. Build CORS layer ---------------------------------------------------
231        let cors = build_cors_layer(&config.app.cors_origins);
232
233        // 7. Build router -------------------------------------------------------
234        let mut router = Router::new();
235
236        // Built-in health endpoints
237        router = router
238            .route("/healthz", axum::routing::get(crate::health::healthz))
239            .route("/readyz", axum::routing::get(crate::health::readyz));
240
241        // Merge built-in auth routes (if registered via with_auth_routes)
242        if let Some(builtin) = self.builtin_router {
243            router = router.merge(builtin);
244        }
245
246        let mut router = router
247            // User routes
248            .merge(self.router)
249            // JSON 404 fallback for any unmatched path
250            .fallback(any(fallback_not_found))
251            .with_state(state.clone())
252            .layer(cors)
253            .layer(axum::middleware::from_fn(
254                crate::request_id::request_id_middleware,
255            ))
256            .layer(TraceLayer::new_for_http());
257
258        // Apply security headers in non-development mode
259        if config.app.environment != "development" {
260            router = router.layer(axum::middleware::from_fn(
261                crate::security::security_headers_middleware,
262            ));
263        }
264
265        // 8. Jetstream ----------------------------------------------------------
266        if let Some(ref js_config) = config.jetstream {
267            if let Some(handler) = self.event_handler {
268                let stream_config = atrg_stream::StreamConfig {
269                    host: js_config.host.clone(),
270                    collections: js_config.collections.clone(),
271                    zstd_dict: js_config.zstd_dict.clone(),
272                    channel_capacity: js_config.channel_capacity,
273                    max_lag_events: js_config.max_lag_events,
274                };
275                atrg_stream::spawn_consumer(&stream_config, state.clone(), handler).await?;
276            } else {
277                tracing::warn!("jetstream configured but no on_event handler registered");
278            }
279        }
280
281        // 8c. Firehose consumer --------------------------------------------------
282        #[cfg(feature = "firehose")]
283        if let Some(ref fh_config) = config.firehose {
284            if let Some(handler) = self.firehose_handler {
285                let firehose_config = atrg_firehose::FirehoseConfig {
286                    relay: fh_config.relay.clone(),
287                    cursor: fh_config.cursor,
288                    channel_capacity: fh_config.channel_capacity,
289                };
290                atrg_firehose::spawn_firehose(&firehose_config, state.clone(), handler).await?;
291            } else {
292                tracing::warn!("firehose configured but no on_firehose_event handler registered");
293            }
294        }
295
296        // 8b. Spawn cleanup task (if registered) --------------------------------
297        if let Some(cleanup) = self.cleanup_fn {
298            cleanup(state.db.clone());
299        }
300
301        // 9. Serve --------------------------------------------------------------
302        let addr = format!("{}:{}", config.app.host, config.app.port);
303        tracing::info!(addr = %addr, name = %config.app.name, "at-rust-go API serving");
304        let listener = tokio::net::TcpListener::bind(&addr).await?;
305        axum::serve(listener, router).await?;
306
307        Ok(())
308    }
309}
310
311impl Default for AtrgApp {
312    fn default() -> Self {
313        Self::new()
314    }
315}
316
317/// Global fallback handler — returns a JSON 404 for any unmatched route.
318async fn fallback_not_found() -> impl IntoResponse {
319    AtrgError::NotFound
320}
321
/// Test-only helper: wire `user_router` into a ready-to-call [`Router`]
/// (no TCP listener) via [`build_test_router_with_auth`], with no auth routes.
///
/// This is **not** part of the public API; integration tests drive the result
/// with `tower::ServiceExt::oneshot`.
#[cfg(test)]
pub(crate) fn build_test_router(user_router: Router<AppState>, state: AppState) -> Router {
    let no_auth_routes: Option<Router<AppState>> = None;
    build_test_router_with_auth(no_auth_routes, user_router, state)
}
330
/// Like [`build_test_router`], but also merges optional auth routes.
///
/// The router is layered the same way [`AtrgApp::run`] layers its router:
/// auth routes, then user routes, the JSON 404 fallback, and the CORS,
/// request-id and tracing middleware, with the security-headers middleware
/// added whenever `state.config.app.environment` is not `"development"`.
/// The `/healthz` and `/readyz` endpoints are not mounted here.
#[cfg(test)]
pub(crate) fn build_test_router_with_auth(
    auth_router: Option<Router<AppState>>,
    user_router: Router<AppState>,
    state: AppState,
) -> Router {
    let cors = build_cors_layer(&state.config.app.cors_origins);
    // Mirror `run`: security headers are skipped only in development.
    let security_headers = state.config.app.environment != "development";

    let mut router = Router::new();
    if let Some(auth) = auth_router {
        router = router.merge(auth);
    }

    let router = router
        .merge(user_router)
        .fallback(any(fallback_not_found))
        .with_state(state)
        .layer(cors)
        .layer(axum::middleware::from_fn(
            crate::request_id::request_id_middleware,
        ))
        .layer(TraceLayer::new_for_http());

    if security_headers {
        router.layer(axum::middleware::from_fn(
            crate::security::security_headers_middleware,
        ))
    } else {
        router
    }
}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
    use axum::body::Body;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Json;
    use http_body_util::BodyExt;
    use hyper::Request;
    use tower::ServiceExt;

    /// Build an [`AppState`] backed by an in-memory SQLite database.
    async fn test_state() -> AppState {
        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();

        let config = Config {
            app: AppConfig {
                name: "test-app".into(),
                host: "127.0.0.1".into(),
                port: 3000,
                secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
                cors_origins: vec![],
                environment: "development".into(),
            },
            auth: AuthConfig {
                client_id: "http://localhost:3000/client-metadata.json".into(),
                redirect_uri: "http://localhost:3000/auth/callback".into(),
                scope: "atproto transition:generic".into(),
            },
            database: DatabaseConfig {
                url: "sqlite::memory:".into(),
            },
            jetstream: None,
            firehose: None,
            feed_generator: None,
            labeler: None,
            rate_limit: None,
        };

        AppState {
            config: Arc::new(config),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
        }
    }

    /// Helper: extract the full body bytes from a response.
    async fn body_bytes(response: Response) -> Vec<u8> {
        response
            .into_body()
            .collect()
            .await
            .unwrap()
            .to_bytes()
            .to_vec()
    }

    /// Helper: parse the full response body as JSON.
    async fn body_json(response: Response) -> serde_json::Value {
        let bytes = body_bytes(response).await;
        serde_json::from_slice(&bytes).expect("response body is not valid JSON")
    }

    /// Helper: assert that the response declares an `application/json` body.
    fn assert_json_content_type(response: &Response) {
        let ct = response
            .headers()
            .get("content-type")
            .expect("response has no content-type header")
            .to_str()
            .expect("content-type header is not valid UTF-8");
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );
    }

    /// Helper: send an empty-bodied `method` request for `uri` through `app`.
    async fn send(app: Router, method: &str, uri: &str) -> Response {
        let request = Request::builder()
            .method(method)
            .uri(uri)
            .body(Body::empty())
            .unwrap();
        app.oneshot(request).await.unwrap()
    }

    #[test]
    fn atrg_app_default_is_new() {
        // Default must produce the same empty builder as `new`.
        let app = AtrgApp::default();
        assert!(app.builtin_router.is_none());
        assert!(app.cleanup_fn.is_none());
        assert!(app.event_handler.is_none());
    }

    #[test]
    fn on_event_sets_handler() {
        let app = AtrgApp::new().on_event(|_event, _state| async { Ok(()) });
        assert!(app.event_handler.is_some());
    }

    #[cfg(feature = "firehose")]
    #[test]
    fn on_firehose_event_sets_handler() {
        let app = AtrgApp::new().on_firehose_event(|_event, _state| async { Ok(()) });
        assert!(app.firehose_handler.is_some());
    }

    #[test]
    fn with_auth_routes_stores_builtin_router() {
        let app = AtrgApp::new().with_auth_routes(Router::new());
        assert!(app.builtin_router.is_some());
    }

    #[test]
    fn with_cleanup_task_stores_callback() {
        let app = AtrgApp::new().with_cleanup_task(|_pool| {});
        assert!(app.cleanup_fn.is_some());
    }

    #[tokio::test]
    async fn mount_ping_returns_200_json() {
        let state = test_state().await;

        let user_router: Router<AppState> = Router::new().route(
            "/ping",
            get(|| async { Json(serde_json::json!({"pong": true})) }),
        );

        let app = build_test_router(user_router, state);

        let response = send(app, "GET", "/ping").await;
        assert_eq!(response.status(), 200);
        assert_json_content_type(&response);

        let body = body_json(response).await;
        assert_eq!(body["pong"], true);
    }

    #[tokio::test]
    async fn unknown_route_returns_404_json() {
        let state = test_state().await;
        let app = build_test_router(Router::new(), state);

        let response = send(app, "GET", "/does-not-exist").await;
        assert_eq!(response.status(), 404);
        assert_json_content_type(&response);

        let body = body_json(response).await;
        assert_eq!(body["error"], "not_found");
        assert_eq!(body["message"], "Not found");
    }

    #[tokio::test]
    async fn fallback_applies_to_every_method() {
        // The fallback is wired with `any`, so non-GET requests get JSON too.
        let state = test_state().await;
        let app = build_test_router(Router::new(), state);

        let response = send(app, "POST", "/does-not-exist").await;
        assert_eq!(response.status(), 404);
        assert_json_content_type(&response);

        let body = body_json(response).await;
        assert_eq!(body["error"], "not_found");
    }

    #[tokio::test]
    async fn multiple_mounts_accumulate_routes() {
        let state = test_state().await;

        let r1: Router<AppState> = Router::new().route(
            "/a",
            get(|| async { Json(serde_json::json!({"route": "a"})) }),
        );
        let r2: Router<AppState> = Router::new().route(
            "/b",
            get(|| async { Json(serde_json::json!({"route": "b"})) }),
        );

        let app = build_test_router(r1.merge(r2), state);

        let resp_a = send(app.clone(), "GET", "/a").await;
        assert_eq!(resp_a.status(), 200);

        let resp_b = send(app, "GET", "/b").await;
        assert_eq!(resp_b.status(), 200);
    }

    #[tokio::test]
    async fn with_auth_routes_merges_builtin() {
        let state = test_state().await;

        // Simulate an auth router with a test endpoint
        let auth_router: Router<AppState> = Router::new().route(
            "/auth/test",
            get(|| async { Json(serde_json::json!({"auth": true})) }),
        );

        let user_router: Router<AppState> = Router::new().route(
            "/ping",
            get(|| async { Json(serde_json::json!({"pong": true})) }),
        );

        let app = build_test_router_with_auth(Some(auth_router), user_router, state);

        // Auth route works
        let resp = send(app.clone(), "GET", "/auth/test").await;
        assert_eq!(resp.status(), 200);

        // User route also works
        let resp = send(app, "GET", "/ping").await;
        assert_eq!(resp.status(), 200);
    }
}