Skip to main content

atrg_core/
app.rs

1//! The `AtrgApp` builder — the main entry point for assembling and running an atrg server.
2//!
3//! A minimal application looks like this:
4//!
5//! ```rust,no_run
6//! use atrg_core::AtrgApp;
7//!
8//! #[tokio::main]
9//! async fn main() -> anyhow::Result<()> {
10//!     AtrgApp::new()
11//!         .mount(axum::Router::new())
12//!         .run()
13//!         .await
14//! }
15//! ```
16
17use std::path::Path;
18use std::sync::Arc;
19
20use atrg_db::DbPool;
21use axum::response::IntoResponse;
22use axum::routing::any;
23use axum::Router;
24use futures::future::BoxFuture;
25use tower_http::trace::TraceLayer;
26
27use crate::config::Config;
28use crate::cors::build_cors_layer;
29use crate::error::AtrgError;
30use crate::state::{AppState, Extensions};
31
32/// A cleanup task function that receives a database pool and spawns
33/// background maintenance work (e.g. expired session cleanup).
34type CleanupFn = Box<dyn FnOnce(DbPool) + Send>;
35
36/// The application builder. Accumulates user routers and configuration,
37/// then boots the full server when [`AtrgApp::run`] is called.
38pub struct AtrgApp {
39    router: Router<AppState>,
40    /// Built-in routes (auth, well-known, etc.) merged before user routes.
41    builtin_router: Option<Router<AppState>>,
42    /// Optional cleanup task spawner (e.g. session/oauth-state cleanup).
43    cleanup_fn: Option<CleanupFn>,
44    /// Optional caller-supplied database pool. When set, [`AtrgApp::run`]
45    /// uses this pool instead of opening one from `[database] url`.
46    user_db_pool: Option<DbPool>,
47    /// Jetstream event handler registered via [`AtrgApp::on_event`].
48    event_handler: Option<atrg_stream::EventHandler<AppState>>,
49    /// Firehose event handler (registered via [`AtrgApp::on_firehose_event`]).
50    #[cfg(feature = "firehose")]
51    firehose_handler: Option<atrg_firehose::FirehoseHandler<AppState>>,
52    /// App-specific extensions collected during build and passed into AppState.
53    extensions: Extensions,
54}
55
56impl AtrgApp {
57    /// Create a new, empty application builder.
58    pub fn new() -> Self {
59        Self {
60            router: Router::new(),
61            builtin_router: None,
62            cleanup_fn: None,
63            user_db_pool: None,
64            event_handler: None,
65            #[cfg(feature = "firehose")]
66            firehose_handler: None,
67            extensions: Extensions::new(),
68        }
69    }
70
71    /// Mount an additional [`axum::Router`] into the application.
72    ///
73    /// Routes are merged, so multiple calls to `mount` accumulate routes.
74    pub fn mount(mut self, router: Router<AppState>) -> Self {
75        self.router = self.router.merge(router);
76        self
77    }
78
79    /// Register built-in auth routes (OAuth login/callback/logout, client-metadata, well-known).
80    ///
81    /// The supplied router is merged **before** user routes so that atrg's
82    /// built-in endpoints take precedence.
83    ///
84    /// # Example
85    ///
86    /// ```rust,ignore
87    /// use atrg_core::AtrgApp;
88    ///
89    /// AtrgApp::new()
90    ///     .with_auth_routes(atrg_auth::routes::auth_router())
91    ///     // ...
92    /// # ;
93    /// ```
94    pub fn with_auth_routes(mut self, router: Router<AppState>) -> Self {
95        self.builtin_router = Some(router);
96        self
97    }
98
99    /// Register a background cleanup task that is spawned after the server
100    /// starts. Typically used for periodic session / OAuth-state expiry.
101    ///
102    /// The callback receives the [`DbPool`] and is expected to call
103    /// `tokio::spawn` internally.
104    pub fn with_cleanup_task<F>(mut self, f: F) -> Self
105    where
106        F: FnOnce(DbPool) + Send + 'static,
107    {
108        self.cleanup_fn = Some(Box::new(f));
109        self
110    }
111
112    /// Use a caller-provided database pool instead of opening a fresh one
113    /// from `[database] url`.
114    ///
115    /// This is the recommended way to integrate atrg into an existing
116    /// application that already manages its own connection pool — for
117    /// example, a service that uses PostgreSQL for its business data and
118    /// wants atrg's internal tables (sessions, OAuth state) to live in the
119    /// same database:
120    ///
121    /// ```rust,ignore
122    /// let pool = sqlx::PgPool::connect(&db_url).await?;
123    ///
124    /// AtrgApp::new()
125    ///     .with_db_pool(pool.into())   // accepts SqlitePool, PgPool, or DbPool
126    ///     .mount(routes::api())
127    ///     .run()
128    ///     .await
129    /// ```
130    ///
131    /// When a pool is provided this way, `[database] url` from `atrg.toml`
132    /// is ignored. atrg's internal migrations are still applied to the
133    /// supplied pool on startup.
134    pub fn with_db_pool(mut self, pool: impl Into<DbPool>) -> Self {
135        self.user_db_pool = Some(pool.into());
136        self
137    }
138
139    /// Register an app-specific extension value.
140    ///
141    /// Extensions are type-erased values accessible from any handler via
142    /// [`AppState::extension::<T>()`](crate::state::AppState::extension) or
143    /// [`AppState::try_extension::<T>()`](crate::state::AppState::try_extension).
144    ///
145    /// Each type can appear at most once — inserting a second value of the
146    /// same type replaces the first.
147    ///
148    /// # Examples
149    ///
150    /// ```rust,ignore
151    /// struct S3Client { bucket: String }
152    /// struct SmtpConfig { host: String }
153    ///
154    /// AtrgApp::new()
155    ///     .with_extension(S3Client { bucket: "my-blobs".into() })
156    ///     .with_extension(SmtpConfig { host: "smtp.example.com".into() })
157    ///     .mount(routes())
158    ///     .run()
159    ///     .await
160    /// ```
161    pub fn with_extension<T: Send + Sync + 'static>(mut self, value: T) -> Self {
162        self.extensions.insert(value);
163        self
164    }
165
166    /// Register a Jetstream event handler.
167    ///
168    /// The handler is called for every event received from the Jetstream
169    /// firehose. It is spawned as a background task inside [`AtrgApp::run`]
170    /// when `[jetstream]` is present in `atrg.toml`.
171    pub fn on_event<F, Fut>(mut self, handler: F) -> Self
172    where
173        F: Fn(atrg_stream::JetstreamEvent, AppState) -> Fut + Send + Sync + 'static,
174        Fut: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
175    {
176        self.event_handler = Some(Arc::new(move |event, state| {
177            Box::pin(handler(event, state)) as BoxFuture<'static, anyhow::Result<()>>
178        }));
179        self
180    }
181
182    /// Register a firehose event handler.
183    ///
184    /// The handler is called for every event received from the AT Protocol
185    /// relay firehose (`com.atproto.sync.subscribeRepos`). It is spawned as
186    /// a background task inside [`AtrgApp::run`] when `[firehose]` is present
187    /// in `atrg.toml`.
188    #[cfg(feature = "firehose")]
189    pub fn on_firehose_event<F, Fut>(mut self, handler: F) -> Self
190    where
191        F: Fn(atrg_firehose::FirehoseEvent, AppState) -> Fut + Send + Sync + 'static,
192        Fut: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
193    {
194        self.firehose_handler = Some(std::sync::Arc::new(move |event, state| {
195            Box::pin(handler(event, state)) as BoxFuture<'static, anyhow::Result<()>>
196        }));
197        self
198    }
199
200    /// Mount a feed generator's routes.
201    ///
202    /// Pass the router produced by `FeedGenerator::into_router()` (from the
203    /// `atrg-feed` crate).
204    /// This is a semantic alias for [`mount`](Self::mount) that makes the
205    /// builder read more clearly.
206    ///
207    /// # Example
208    ///
209    /// ```rust,ignore
210    /// AtrgApp::new()
211    ///     .with_feed_generator(feed_gen.into_router())
212    /// ```
213    pub fn with_feed_generator(self, feed_router: Router<AppState>) -> Self {
214        self.mount(feed_router)
215    }
216
217    /// Mount a labeler service's routes.
218    ///
219    /// Pass the router produced by `labeler_routes()` (from the `atrg-label`
220    /// crate).
221    /// This is a semantic alias for [`mount`](Self::mount) that makes the
222    /// builder read more clearly.
223    ///
224    /// # Example
225    ///
226    /// ```rust,ignore
227    /// AtrgApp::new()
228    ///     .with_labeler(atrg_label::routes::labeler_routes(service))
229    /// ```
230    pub fn with_labeler(self, labeler_router: Router<AppState>) -> Self {
231        self.mount(labeler_router)
232    }
233
234    /// Boot the server.
235    ///
236    /// This is the **only** async entry point. It:
237    ///
238    /// 1. Initialises tracing (respects `RUST_LOG`).
239    /// 2. Loads `atrg.toml` (or `$ATRG_CONFIG`).
240    /// 3. Connects to SQLite and runs migrations.
241    /// 4. Builds [`AppState`] (including the identity resolver).
242    /// 5. Assembles the Axum router with CORS, tracing, and a JSON 404 fallback.
243    /// 6. Spawns optional cleanup tasks.
244    /// 7. Binds a TCP listener and serves.
245    pub async fn run(self) -> anyhow::Result<()> {
246        // 1. Init tracing -------------------------------------------------------
247        let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
248            .unwrap_or_else(|_| {
249                tracing_subscriber::EnvFilter::new(
250                    "info,atrg_core=debug,atrg_db=debug,atrg_auth=debug,atrg_cli=debug,tower_http=debug",
251                )
252            });
253
254        // If another test or binary already initialised the subscriber, silently
255        // ignore the error rather than panicking.
256        let _ = tracing_subscriber::fmt()
257            .with_env_filter(env_filter)
258            .try_init();
259
260        // 2. Load config --------------------------------------------------------
261        let config_path = std::env::var("ATRG_CONFIG").unwrap_or_else(|_| "./atrg.toml".into());
262        tracing::info!(path = %config_path, "loading configuration");
263        let config = Config::load(&config_path)?;
264        let config = Arc::new(config);
265
266        // 3. Connect DB + migrations --------------------------------------------
267        let db = match self.user_db_pool {
268            Some(pool) => {
269                tracing::info!(
270                    backend = pool.backend(),
271                    "using caller-supplied database pool (bypassing [database] url)"
272                );
273                pool
274            }
275            None => atrg_db::connect(&config.database.url).await?,
276        };
277        atrg_db::run_internal_migrations(&db).await?;
278
279        let user_migrations = Path::new("./migrations");
280        if user_migrations.is_dir() {
281            atrg_db::run_isolated_migrations(&db, user_migrations, "_app_migrations").await?;
282        }
283
284        // 4. Build HTTP client --------------------------------------------------
285        let http = reqwest::Client::builder()
286            .user_agent(format!("atrg/{}", crate::version()))
287            .build()?;
288
289        // 4b. Build identity resolver -------------------------------------------
290        let identity = Arc::new(atrg_identity::IdentityResolver::with_defaults(http.clone()));
291
292        // 5. Assemble AppState --------------------------------------------------
293        let state = AppState {
294            config: config.clone(),
295            db,
296            http,
297            identity,
298            extensions: Arc::new(self.extensions),
299        };
300
301        // 5b. Admin bootstrap ---------------------------------------------------
302        if !config.app.admin_dids.is_empty() {
303            for did in &config.app.admin_dids {
304                let now = std::time::SystemTime::now()
305                    .duration_since(std::time::UNIX_EPOCH)
306                    .unwrap_or_default()
307                    .as_secs()
308                    .to_string();
309                let result: Result<(), sqlx::Error> = match &state.db {
310                    #[cfg(feature = "sqlite")]
311                    atrg_db::DbPool::Sqlite(p) => {
312                        sqlx::query(
313                            "INSERT OR IGNORE INTO atrg_roles (did, role, granted_by, granted_at) VALUES (?1, 'admin', 'system:bootstrap', ?2)"
314                        ).bind(did).bind(&now).execute(p).await.map(|_| ())
315                    }
316                    #[cfg(feature = "postgres")]
317                    atrg_db::DbPool::Postgres(p) => {
318                        sqlx::query(
319                            "INSERT INTO atrg_roles (did, role, granted_by, granted_at) VALUES ($1, 'admin', 'system:bootstrap', $2) ON CONFLICT DO NOTHING"
320                        ).bind(did).bind(&now).execute(p).await.map(|_| ())
321                    }
322                    #[allow(unreachable_patterns)]
323                    _ => Ok(()),
324                };
325                match result {
326                    Ok(_) => tracing::info!(did = %did, "auto-provisioned admin DID"),
327                    Err(e) => {
328                        tracing::warn!(did = %did, error = %e, "failed to bootstrap admin DID (table may not exist yet)")
329                    }
330                }
331            }
332        }
333
334        // 6. Build CORS layer ---------------------------------------------------
335        let cors = build_cors_layer(&config.app.cors_origins);
336
337        // 7. Build router -------------------------------------------------------
338        let mut router = Router::new();
339
340        // Built-in health endpoints
341        router = router
342            .route("/healthz", axum::routing::get(crate::health::healthz))
343            .route("/readyz", axum::routing::get(crate::health::readyz));
344
345        // Merge built-in auth routes (if registered via with_auth_routes)
346        if let Some(builtin) = self.builtin_router {
347            router = router.merge(builtin);
348        }
349
350        let mut router = router
351            // User routes
352            .merge(self.router)
353            // JSON 404 fallback for any unmatched path
354            .fallback(any(fallback_not_found))
355            .with_state(state.clone())
356            .layer(cors)
357            .layer(axum::middleware::from_fn(
358                crate::request_id::request_id_middleware,
359            ))
360            .layer(TraceLayer::new_for_http());
361
362        // Apply security headers in non-development mode
363        if config.app.environment != "development" {
364            router = router.layer(axum::middleware::from_fn(
365                crate::security::security_headers_middleware,
366            ));
367        }
368
369        // 8a. Rate limiting (if configured) -------------------------------------
370        if let Some(ref rl_config) = config.rate_limit {
371            if rl_config.enabled {
372                let limiter =
373                    crate::rate_limit::RateLimiter::new(crate::rate_limit::RateLimitConfig {
374                        requests_per_second: rl_config.requests_per_second,
375                        burst: rl_config.burst,
376                        enabled: true,
377                    });
378
379                // Spawn periodic cleanup task (every 5 minutes, remove entries older than 10 min)
380                let limiter_cleanup = limiter.clone();
381                tokio::spawn(async move {
382                    let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
383                    loop {
384                        interval.tick().await;
385                        limiter_cleanup
386                            .cleanup(std::time::Duration::from_secs(600))
387                            .await;
388                    }
389                });
390
391                router = router.layer(axum::middleware::from_fn(
392                    move |req: axum::extract::Request, next: axum::middleware::Next| {
393                        let limiter = limiter.clone();
394                        async move {
395                            // Extract client IP from connection info or X-Forwarded-For
396                            let ip = req
397                                .extensions()
398                                .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
399                                .map(|ci| ci.0.ip())
400                                .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
401
402                            match limiter.check(ip).await {
403                                Ok(()) => next.run(req).await,
404                                Err(retry_after) => {
405                                    crate::rate_limit::rate_limit_response(retry_after)
406                                }
407                            }
408                        }
409                    },
410                ));
411
412                tracing::info!(
413                    rps = rl_config.requests_per_second,
414                    burst = rl_config.burst,
415                    "rate limiting enabled"
416                );
417            }
418        }
419
420        // 8. Jetstream ----------------------------------------------------------
421        if let Some(ref js_config) = config.jetstream {
422            if let Some(handler) = self.event_handler {
423                let stream_config = atrg_stream::StreamConfig {
424                    host: js_config.host.clone(),
425                    collections: js_config.collections.clone(),
426                    zstd_dict: js_config.zstd_dict.clone(),
427                    channel_capacity: js_config.channel_capacity,
428                    max_lag_events: js_config.max_lag_events,
429                    cursor: None,
430                };
431                atrg_stream::spawn_consumer(&stream_config, state.clone(), handler).await?;
432            } else {
433                tracing::warn!("jetstream configured but no on_event handler registered");
434            }
435        }
436
437        // 8c. Firehose consumer --------------------------------------------------
438        #[cfg(feature = "firehose")]
439        if let Some(ref fh_config) = config.firehose {
440            if let Some(handler) = self.firehose_handler {
441                let firehose_config = atrg_firehose::FirehoseConfig {
442                    relay: fh_config.relay.clone(),
443                    cursor: fh_config.cursor,
444                    channel_capacity: fh_config.channel_capacity,
445                };
446                atrg_firehose::spawn_firehose(&firehose_config, state.clone(), handler).await?;
447            } else {
448                tracing::warn!("firehose configured but no on_firehose_event handler registered");
449            }
450        }
451
452        // 8b. Spawn cleanup task (if registered) --------------------------------
453        if let Some(cleanup) = self.cleanup_fn {
454            cleanup(state.db.clone());
455        }
456
457        // 9. Serve --------------------------------------------------------------
458        let addr = format!("{}:{}", config.app.host, config.app.port);
459        tracing::info!(addr = %addr, name = %config.app.name, "at-rust-go API serving");
460        let listener = tokio::net::TcpListener::bind(&addr).await?;
461        axum::serve(listener, router).await?;
462
463        Ok(())
464    }
465}
466
467impl Default for AtrgApp {
468    fn default() -> Self {
469        Self::new()
470    }
471}
472
473/// Global fallback handler — returns a JSON 404 for any unmatched route.
474async fn fallback_not_found() -> impl IntoResponse {
475    AtrgError::NotFound
476}
477
/// Build a [`Router`] for tests without binding a TCP listener.
///
/// This is **not** part of the public API. It exists so integration tests can
/// drive routes with `tower::ServiceExt::oneshot`. It delegates to
/// [`build_test_router_with_auth`] with no auth router, so it gets that
/// function's reduced stack: the JSON 404 fallback, CORS and HTTP tracing.
///
/// It is *not* the full stack that [`AtrgApp::run`] installs. The health
/// endpoints, request-ID middleware, security headers and rate limiting are
/// absent, so behaviour that depends on them must be tested separately.
#[cfg(test)]
pub(crate) fn build_test_router(user_router: Router<AppState>, state: AppState) -> Router {
    build_test_router_with_auth(None, user_router, state)
}
486
/// Test-only router assembly with optional built-in auth routes.
///
/// When `auth_router` is given, it is merged first and `user_router` after
/// it, the same order [`AtrgApp::run`] uses for built-in versus user routes.
/// The result gets the JSON 404 fallback, the CORS layer built from
/// `state.config.app.cors_origins`, and HTTP tracing.
#[cfg(test)]
pub(crate) fn build_test_router_with_auth(
    auth_router: Option<Router<AppState>>,
    user_router: Router<AppState>,
    state: AppState,
) -> Router {
    let cors_layer = build_cors_layer(&state.config.app.cors_origins);

    let base: Router<AppState> = Router::new();
    let with_builtin = match auth_router {
        Some(auth) => base.merge(auth),
        None => base,
    };

    let stateful = with_builtin
        .merge(user_router)
        .fallback(any(fallback_not_found))
        .with_state(state);

    stateful.layer(cors_layer).layer(TraceLayer::new_for_http())
}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
    use axum::body::Body;
    use axum::routing::get;
    use axum::Json;
    use http_body_util::BodyExt;
    use hyper::Request;
    use tower::ServiceExt;

    /// Build an [`AppState`] backed by an in-memory SQLite database.
    async fn test_state() -> AppState {
        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();

        let config = Config {
            app: AppConfig {
                name: "test-app".into(),
                host: "127.0.0.1".into(),
                port: 3000,
                secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
                cors_origins: vec![],
                environment: "development".into(),
                admin_dids: vec![],
            },
            auth: AuthConfig {
                client_id: "http://localhost:3000/client-metadata.json".into(),
                redirect_uri: "http://localhost:3000/auth/callback".into(),
                scope: "atproto transition:generic".into(),
                post_login_redirect: "/".into(),
            },
            database: DatabaseConfig {
                url: "sqlite::memory:".into(),
            },
            jetstream: None,
            firehose: None,
            feed_generator: None,
            labeler: None,
            rate_limit: None,
        };

        AppState {
            config: Arc::new(config),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
            extensions: Arc::new(Extensions::new()),
        }
    }

    /// Helper: extract the full body bytes from a response.
    async fn body_bytes(response: axum::response::Response) -> Vec<u8> {
        response
            .into_body()
            .collect()
            .await
            .unwrap()
            .to_bytes()
            .to_vec()
    }

    #[test]
    fn atrg_app_default_is_new() {
        // Just ensure Default compiles and doesn't panic.
        let _app = AtrgApp::default();
    }

    #[test]
    fn on_event_sets_handler() {
        let app = AtrgApp::new().on_event(|_event, _state| async { Ok(()) });
        assert!(app.event_handler.is_some());
    }

    #[cfg(feature = "firehose")]
    #[test]
    fn on_firehose_event_sets_handler() {
        let app = AtrgApp::new().on_firehose_event(|_event, _state| async { Ok(()) });
        assert!(app.firehose_handler.is_some());
    }

    #[tokio::test]
    async fn with_db_pool_stores_caller_pool() {
        // Caller-supplied pool should override the [database] url path.
        let pool = atrg_db::connect("sqlite::memory:").await.unwrap();
        let app = AtrgApp::new().with_db_pool(pool);
        let stored = app
            .user_db_pool
            .as_ref()
            .expect("with_db_pool should store the caller's pool");
        assert_eq!(stored.backend(), "sqlite");
    }

    #[tokio::test]
    async fn readyz_reports_backend_kind() {
        // Readiness response should expose the backend identifier so ops
        // can confirm the right driver is in use.
        let state = test_state().await;
        let app: Router = Router::new()
            .route("/readyz", get(crate::health::readyz))
            .with_state(state);

        let req = Request::builder()
            .uri("/readyz")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 200);

        let bytes = body_bytes(resp).await;
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(body["database_backend"], "sqlite");
    }

    #[tokio::test]
    async fn mount_ping_returns_200_json() {
        let state = test_state().await;

        let user_router: Router<AppState> = Router::new().route(
            "/ping",
            get(|| async { Json(serde_json::json!({"pong": true})) }),
        );

        let app = build_test_router(user_router, state);

        let request = Request::builder().uri("/ping").body(Body::empty()).unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), 200);

        let ct = response
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );

        let bytes = body_bytes(response).await;
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(body["pong"], true);
    }

    #[tokio::test]
    async fn unknown_route_returns_404_json() {
        let state = test_state().await;
        let app = build_test_router(Router::new(), state);

        let request = Request::builder()
            .uri("/does-not-exist")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), 404);

        let ct = response
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );

        let bytes = body_bytes(response).await;
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(body["error"], "not_found");
        assert_eq!(body["message"], "Not found");
    }

    #[tokio::test]
    async fn multiple_mounts_accumulate_routes() {
        let state = test_state().await;

        let r1: Router<AppState> = Router::new().route(
            "/a",
            get(|| async { Json(serde_json::json!({"route": "a"})) }),
        );
        let r2: Router<AppState> = Router::new().route(
            "/b",
            get(|| async { Json(serde_json::json!({"route": "b"})) }),
        );

        let app = build_test_router(r1.merge(r2), state);

        let resp_a = app
            .clone()
            .oneshot(Request::builder().uri("/a").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp_a.status(), 200);

        let resp_b = app
            .oneshot(Request::builder().uri("/b").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp_b.status(), 200);
    }

    #[tokio::test]
    async fn with_auth_routes_merges_builtin() {
        let state = test_state().await;

        // Simulate an auth router with a test endpoint
        let auth_router: Router<AppState> = Router::new().route(
            "/auth/test",
            get(|| async { Json(serde_json::json!({"auth": true})) }),
        );

        let user_router: Router<AppState> = Router::new().route(
            "/ping",
            get(|| async { Json(serde_json::json!({"pong": true})) }),
        );

        let app = build_test_router_with_auth(Some(auth_router), user_router, state);

        // Auth route works
        let resp = app
            .clone()
            .oneshot(
                Request::builder()
                    .uri("/auth/test")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);

        // User route also works
        let resp = app
            .oneshot(Request::builder().uri("/ping").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn with_extension_is_accessible_from_state() {
        struct MyConfig {
            magic_number: u64,
        }

        let base = test_state().await;

        // `run()` binds a TCP port, so instead of booting the whole server we
        // register the extension through the real builder method and then move
        // the builder's collected extensions into an AppState the same way
        // `run()` does (`Arc::new(self.extensions)`).
        let app = AtrgApp::new().with_extension(MyConfig { magic_number: 42 });
        let state_with_ext = AppState {
            config: base.config.clone(),
            db: base.db.clone(),
            http: base.http.clone(),
            identity: base.identity.clone(),
            extensions: Arc::new(app.extensions),
        };

        let router: Router = Router::new()
            .route(
                "/magic",
                get(
                    |axum::extract::State(s): axum::extract::State<AppState>| async move {
                        let cfg = s.extension::<MyConfig>();
                        Json(serde_json::json!({ "magic": cfg.magic_number }))
                    },
                ),
            )
            .with_state(state_with_ext);

        let resp = router
            .oneshot(
                Request::builder()
                    .uri("/magic")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);

        let body = body_bytes(resp).await;
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["magic"], 42);
    }

    #[test]
    fn with_extension_builder_accumulates_values() {
        struct Foo(u32);
        struct Bar(String);

        let app = AtrgApp::new()
            .with_extension(Foo(7))
            .with_extension(Bar("baz".into()));

        assert_eq!(app.extensions.get::<Foo>().unwrap().0, 7);
        assert_eq!(app.extensions.get::<Bar>().unwrap().0, "baz");
    }
}