Skip to main content

atrg_core/
state.rs

1//! Application state shared across all Axum handlers.
2
3use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use atrg_db::DbPool;
8
9use crate::config::Config;
10use atrg_identity::IdentityResolver;
11
12// ---------------------------------------------------------------------------
13// Extensions — a type-erased map for app-specific state
14// ---------------------------------------------------------------------------
15
/// Type-keyed storage for state that belongs to the application, not the framework.
///
/// `Extensions` lets an application hang arbitrary typed values off
/// [`AppState`] without the framework needing to know about them. The concrete
/// type of a value is its key, so at most one value per type can be stored.
///
/// # Examples
///
/// ```rust
/// use atrg_core::Extensions;
///
/// struct S3Client { bucket: String }
/// struct SmtpConfig { host: String }
///
/// let mut ext = Extensions::new();
/// ext.insert(S3Client { bucket: "my-blobs".into() });
/// ext.insert(SmtpConfig { host: "smtp.example.com".into() });
///
/// assert_eq!(ext.get::<S3Client>().expect("registered").bucket, "my-blobs");
/// assert_eq!(ext.get::<SmtpConfig>().expect("registered").host, "smtp.example.com");
/// assert!(ext.get::<u64>().is_none());
/// ```
#[derive(Default)]
pub struct Extensions {
    // Keyed by `TypeId::of::<T>()` of the boxed value, which makes every
    // downcast in the accessors below succeed for a present key.
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an `Extensions` map that holds no values yet.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under its own type.
    ///
    /// Returns the previously stored value of type `T`, if there was one;
    /// that old value is replaced by `value`.
    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) -> Option<T> {
        let key = TypeId::of::<T>();
        let previous = self.map.insert(key, Box::new(value))?;
        // The slot was keyed by `TypeId::of::<T>()`, so the old box holds a `T`.
        match previous.downcast::<T>() {
            Ok(old) => Some(*old),
            Err(_) => None,
        }
    }

    /// Looks up the stored value of type `T`.
    ///
    /// Returns `None` when nothing of that type has been inserted.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let boxed = self.map.get(&TypeId::of::<T>())?;
        boxed.downcast_ref::<T>()
    }

    /// Reports whether a value of type `T` is present.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Number of distinct types currently stored.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// `true` when no value of any type is stored.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}

// `dyn Any` has no `Debug` impl, so only the entry count is printed.
impl std::fmt::Debug for Extensions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Extensions");
        builder.field("len", &self.len());
        builder.finish_non_exhaustive()
    }
}
91
92// ---------------------------------------------------------------------------
93// AppState
94// ---------------------------------------------------------------------------
95
96/// Shared application state passed to every Axum handler.
97///
98/// This is the central state object that every route handler receives via
99/// `axum::extract::State<AppState>`. It holds the parsed configuration,
100/// database connection pool, and a shared HTTP client for outbound requests.
101///
102/// `AppState` is cheaply cloneable — all inner fields are either `Arc`-wrapped
103/// or already use internal reference counting (e.g. sqlx pools, `reqwest::Client`).
104#[derive(Clone)]
105pub struct AppState {
106    /// Parsed configuration from `atrg.toml`.
107    pub config: Arc<Config>,
108    /// Database connection pool. May be SQLite or PostgreSQL depending on
109    /// the `[database] url` scheme in `atrg.toml` (and which features are
110    /// compiled in to `atrg-db`).
111    pub db: DbPool,
112    /// Shared HTTP client for outbound requests.
113    pub http: reqwest::Client,
114    /// DID/handle resolver with TTL-backed in-memory cache.
115    pub identity: Arc<IdentityResolver>,
116    /// Type-erased container for app-specific state (S3 clients, SMTP config,
117    /// domain-specific services, etc.). Access via [`AppState::extension`] or
118    /// [`AppState::try_extension`].
119    pub extensions: Arc<Extensions>,
120}
121
122impl AppState {
123    /// Retrieve a reference to an app-specific extension by type.
124    ///
125    /// # Panics
126    ///
127    /// Panics if the extension has not been registered. Use
128    /// [`try_extension`](Self::try_extension) for a non-panicking variant.
129    ///
130    /// # Examples
131    ///
132    /// ```rust,ignore
133    /// struct MyService { url: String }
134    ///
135    /// // In a handler:
136    /// async fn my_handler(State(state): State<AppState>) -> impl IntoResponse {
137    ///     let svc = state.extension::<MyService>();
138    ///     Json(json!({ "url": svc.url }))
139    /// }
140    /// ```
141    pub fn extension<T: Send + Sync + 'static>(&self) -> &T {
142        self.extensions.get::<T>().unwrap_or_else(|| {
143            panic!(
144                "AppState::extension::<{}>() called but no value of that type was registered. \
145                 Did you forget to call `AtrgApp::with_extension(value)` during app setup?",
146                std::any::type_name::<T>()
147            )
148        })
149    }
150
151    /// Retrieve a reference to an app-specific extension by type, returning
152    /// `None` if the type was never registered.
153    ///
154    /// # Examples
155    ///
156    /// ```rust,ignore
157    /// if let Some(metrics) = state.try_extension::<MetricsCollector>() {
158    ///     metrics.record_request();
159    /// }
160    /// ```
161    pub fn try_extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
162        self.extensions.get::<T>()
163    }
164
165    /// Returns `true` if an extension of type `T` has been registered.
166    pub fn has_extension<T: Send + Sync + 'static>(&self) -> bool {
167        self.extensions.contains::<T>()
168    }
169}
170
171// ---------------------------------------------------------------------------
172// FromRef implementations — allow Axum sub-extractors to pull individual
173// fields out of AppState without the handler needing to destructure manually.
174// ---------------------------------------------------------------------------
175
176impl axum::extract::FromRef<AppState> for DbPool {
177    fn from_ref(state: &AppState) -> Self {
178        state.db.clone()
179    }
180}
181
182impl axum::extract::FromRef<AppState> for Arc<Config> {
183    fn from_ref(state: &AppState) -> Self {
184        state.config.clone()
185    }
186}
187
188impl axum::extract::FromRef<AppState> for Arc<IdentityResolver> {
189    fn from_ref(state: &AppState) -> Self {
190        state.identity.clone()
191    }
192}
193
194impl axum::extract::FromRef<AppState> for Arc<Extensions> {
195    fn from_ref(state: &AppState) -> Self {
196        state.extensions.clone()
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time assertion helper.
    fn _assert_send_sync_clone<T: Send + Sync + Clone>() {}

    /// Minimal development configuration shared by the `AppState` tests.
    ///
    /// Kept in one place so a new `Config` field only needs to be added once.
    fn test_config() -> crate::config::Config {
        crate::config::Config {
            app: crate::config::AppConfig {
                name: "test".into(),
                host: "127.0.0.1".into(),
                port: 3000,
                secret_key: "secret".into(),
                cors_origins: vec![],
                environment: "development".into(),
                admin_dids: vec![],
            },
            auth: crate::config::AuthConfig {
                client_id: "http://localhost/client-metadata.json".into(),
                redirect_uri: "http://localhost/auth/callback".into(),
                scope: "atproto transition:generic".into(),
                post_login_redirect: "/".into(),
            },
            database: crate::config::DatabaseConfig {
                url: "sqlite::memory:".into(),
            },
            jetstream: None,
            firehose: None,
            feed_generator: None,
            labeler: None,
            rate_limit: None,
        }
    }

    /// Builds an `AppState` over an in-memory SQLite pool carrying `extensions`,
    /// so each test only spells out what it actually varies.
    async fn test_state(extensions: Extensions) -> AppState {
        let db = atrg_db::connect("sqlite::memory:")
            .await
            .expect("in-memory SQLite pool should open");
        AppState {
            config: Arc::new(test_config()),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
            extensions: Arc::new(extensions),
        }
    }

    #[test]
    fn app_state_is_send_sync_clone() {
        _assert_send_sync_clone::<AppState>();
    }

    // -- Extensions unit tests ------------------------------------------------

    #[test]
    fn extensions_insert_and_get() {
        struct Foo(u32);
        struct Bar(String);

        let mut ext = Extensions::new();
        ext.insert(Foo(42));
        ext.insert(Bar("hello".into()));

        assert_eq!(ext.get::<Foo>().unwrap().0, 42);
        assert_eq!(ext.get::<Bar>().unwrap().0, "hello");
    }

    #[test]
    fn extensions_get_missing_returns_none() {
        let ext = Extensions::new();
        assert!(ext.get::<u32>().is_none());
    }

    #[test]
    fn extensions_insert_replaces_and_returns_old() {
        struct Config(String);

        let mut ext = Extensions::new();
        let old = ext.insert(Config("v1".into()));
        assert!(old.is_none());

        let old = ext.insert(Config("v2".into()));
        assert_eq!(old.unwrap().0, "v1");
        assert_eq!(ext.get::<Config>().unwrap().0, "v2");
    }

    #[test]
    fn extensions_contains() {
        struct Present;

        let mut ext = Extensions::new();
        assert!(!ext.contains::<Present>());
        ext.insert(Present);
        assert!(ext.contains::<Present>());
    }

    #[test]
    fn extensions_len_and_is_empty() {
        struct A;
        struct B;

        let mut ext = Extensions::new();
        assert!(ext.is_empty());
        assert_eq!(ext.len(), 0);

        ext.insert(A);
        assert!(!ext.is_empty());
        assert_eq!(ext.len(), 1);

        ext.insert(B);
        assert_eq!(ext.len(), 2);
    }

    #[test]
    fn extensions_debug_shows_len() {
        let mut ext = Extensions::new();
        ext.insert(42u32);
        let dbg = format!("{:?}", ext);
        assert!(dbg.contains("Extensions"));
        assert!(dbg.contains("len"));
        // Pin the actual count, not just the field name.
        assert!(dbg.contains("len: 1"));
    }

    // -- AppState extension accessors -----------------------------------------

    #[tokio::test]
    async fn app_state_extension_returns_value() {
        struct MyService {
            name: String,
        }

        let mut ext = Extensions::new();
        ext.insert(MyService {
            name: "test".into(),
        });

        let state = test_state(ext).await;

        assert_eq!(state.extension::<MyService>().name, "test");
    }

    #[tokio::test]
    async fn app_state_try_extension_returns_none_when_missing() {
        struct NotRegistered;

        let state = test_state(Extensions::new()).await;

        assert!(state.try_extension::<NotRegistered>().is_none());
        assert!(!state.has_extension::<NotRegistered>());
    }

    #[tokio::test]
    #[should_panic(expected = "no value of that type was registered")]
    async fn app_state_extension_panics_when_missing() {
        struct NotRegistered;

        let state = test_state(Extensions::new()).await;

        let _ = state.extension::<NotRegistered>();
    }
}