Skip to main content

atrg_core/
config.rs

//! Configuration types and loader for `atrg.toml`.
//!
//! The [`Config`] struct is the single source of truth for all framework
//! configuration. It is loaded once at startup by [`Config::load`] and then
//! wrapped in an `Arc` inside [`AppState`](crate::state::AppState).
6
7use std::path::Path;
8
9use axum::http;
10use serde::Deserialize;
11use url::Url;
12
13// ---------------------------------------------------------------------------
14// Top-level config
15// ---------------------------------------------------------------------------
16
17/// Root configuration, deserialized from `atrg.toml`.
18#[derive(Debug, Clone, Deserialize)]
19pub struct Config {
20    /// Application-level settings.
21    pub app: AppConfig,
22
23    /// OAuth / authentication settings.
24    #[serde(default)]
25    pub auth: AuthConfig,
26
27    /// Database connection settings.
28    #[serde(default)]
29    pub database: DatabaseConfig,
30
31    /// Optional Jetstream real-time event consumer settings.
32    pub jetstream: Option<JetstreamConfig>,
33
34    /// Optional relay firehose consumer settings.
35    pub firehose: Option<FirehoseConfig>,
36
37    /// Optional feed generator settings.
38    pub feed_generator: Option<FeedGeneratorConfig>,
39
40    /// Optional labeler settings.
41    pub labeler: Option<LabelerConfig>,
42
43    /// Optional rate limiting settings.
44    pub rate_limit: Option<RateLimitTomlConfig>,
45}
46
47// ---------------------------------------------------------------------------
48// AppConfig
49// ---------------------------------------------------------------------------
50
51/// `[app]` section of `atrg.toml`.
52#[derive(Debug, Clone, Deserialize)]
53pub struct AppConfig {
54    /// Human-readable application name. Must be non-empty.
55    pub name: String,
56
57    /// Bind address for the HTTP server.
58    #[serde(default = "default_host")]
59    pub host: String,
60
61    /// Bind port for the HTTP server.
62    #[serde(default = "default_port")]
63    pub port: u16,
64
65    /// Secret key used for session signing. Should be ≥ 32 characters in
66    /// production.
67    pub secret_key: String,
68
69    /// Allowed CORS origins. An empty list means same-origin only. A single
70    /// `"*"` entry enables the permissive wildcard.
71    #[serde(default)]
72    pub cors_origins: Vec<String>,
73
74    /// `"development"` or `"production"`. Affects cookie flags and security
75    /// headers.
76    #[serde(default = "default_environment")]
77    pub environment: String,
78}
79
80impl Default for AppConfig {
81    fn default() -> Self {
82        Self {
83            name: String::new(),
84            host: default_host(),
85            port: default_port(),
86            secret_key: String::new(),
87            cors_origins: Vec::new(),
88            environment: default_environment(),
89        }
90    }
91}
92
/// Fallback for `app.host` when the key is omitted.
fn default_host() -> String {
    String::from("127.0.0.1")
}
96
/// Fallback for `app.port` when the key is omitted.
fn default_port() -> u16 {
    const FALLBACK_PORT: u16 = 3000;
    FALLBACK_PORT
}
100
/// Fallback for `app.environment` when the key is omitted.
fn default_environment() -> String {
    "development".to_owned()
}
104
105// ---------------------------------------------------------------------------
106// AuthConfig
107// ---------------------------------------------------------------------------
108
109/// `[auth]` section of `atrg.toml`.
110#[derive(Debug, Clone, Deserialize)]
111pub struct AuthConfig {
112    /// AT Protocol OAuth client ID (must be a valid URL).
113    #[serde(default = "default_client_id")]
114    pub client_id: String,
115
116    /// OAuth redirect URI (must be a valid URL).
117    #[serde(default = "default_redirect_uri")]
118    pub redirect_uri: String,
119
120    /// OAuth scope string.
121    #[serde(default = "default_scope")]
122    pub scope: String,
123}
124
125impl Default for AuthConfig {
126    fn default() -> Self {
127        Self {
128            client_id: default_client_id(),
129            redirect_uri: default_redirect_uri(),
130            scope: default_scope(),
131        }
132    }
133}
134
/// Fallback for `auth.client_id` when the key is omitted.
fn default_client_id() -> String {
    String::from("http://localhost:3000/client-metadata.json")
}
138
/// Fallback for `auth.redirect_uri` when the key is omitted.
fn default_redirect_uri() -> String {
    "http://localhost:3000/auth/callback".to_owned()
}
142
/// Fallback for `auth.scope` when the key is omitted.
fn default_scope() -> String {
    let scope: &str = "atproto transition:generic";
    scope.into()
}
146
147// ---------------------------------------------------------------------------
148// DatabaseConfig
149// ---------------------------------------------------------------------------
150
151/// `[database]` section of `atrg.toml`.
152#[derive(Debug, Clone, Deserialize)]
153pub struct DatabaseConfig {
154    /// SQLite connection URL.
155    #[serde(default = "default_database_url")]
156    pub url: String,
157}
158
159impl Default for DatabaseConfig {
160    fn default() -> Self {
161        Self {
162            url: default_database_url(),
163        }
164    }
165}
166
/// Fallback for `database.url` when the key is omitted.
fn default_database_url() -> String {
    String::from("sqlite://atrg.db")
}
170
171// ---------------------------------------------------------------------------
172// JetstreamConfig
173// ---------------------------------------------------------------------------
174
175/// `[jetstream]` section of `atrg.toml`. Only present when Jetstream
176/// consumption is enabled.
177#[derive(Debug, Clone, Deserialize)]
178pub struct JetstreamConfig {
179    /// Jetstream relay host, e.g. `"jetstream1.us-east.bsky.network"`.
180    pub host: String,
181
182    /// NSID collections to subscribe to, e.g. `["app.bsky.feed.post"]`.
183    pub collections: Vec<String>,
184
185    /// Optional path or URL to a ZSTD dictionary for decompression.
186    pub zstd_dict: Option<String>,
187
188    /// Bounded back-pressure channel size.
189    #[serde(default = "default_channel_capacity")]
190    pub channel_capacity: usize,
191
192    /// Event lag threshold before shedding/warning.
193    #[serde(default = "default_max_lag_events")]
194    pub max_lag_events: usize,
195}
196
/// Fallback for `jetstream.channel_capacity` when the key is omitted.
fn default_channel_capacity() -> usize {
    const FALLBACK_CAPACITY: usize = 1_024;
    FALLBACK_CAPACITY
}
200
/// Fallback for `jetstream.max_lag_events` when the key is omitted.
fn default_max_lag_events() -> usize {
    const FALLBACK_MAX_LAG: usize = 10_000;
    FALLBACK_MAX_LAG
}
204
205// ---------------------------------------------------------------------------
206// FirehoseConfig
207// ---------------------------------------------------------------------------
208
209/// `[firehose]` section of `atrg.toml`. Present when relay firehose
210/// consumption is enabled (full `com.atproto.sync.subscribeRepos`).
211#[derive(Debug, Clone, Deserialize)]
212pub struct FirehoseConfig {
213    /// Relay WebSocket URL, e.g. `"wss://bsky.network"`.
214    pub relay: String,
215
216    /// Sequence number to resume from. `None` means start from head.
217    pub cursor: Option<i64>,
218
219    /// Bounded back-pressure channel capacity.
220    #[serde(default = "default_firehose_channel_capacity")]
221    pub channel_capacity: usize,
222}
223
/// Fallback for `firehose.channel_capacity` when the key is omitted.
fn default_firehose_channel_capacity() -> usize {
    let capacity: usize = 1_024;
    capacity
}
227
228// ---------------------------------------------------------------------------
229// FeedGeneratorConfig
230// ---------------------------------------------------------------------------
231
232/// `[feed_generator]` section of `atrg.toml`. Present when the server
233/// acts as an AT Protocol feed generator.
234#[derive(Debug, Clone, Deserialize)]
235pub struct FeedGeneratorConfig {
236    /// DID of the feed generator service (typically `did:web:<hostname>`).
237    pub did: String,
238}
239
240// ---------------------------------------------------------------------------
241// LabelerConfig
242// ---------------------------------------------------------------------------
243
244/// `[labeler]` section of `atrg.toml`. Present when the server acts as
245/// an AT Protocol labeler.
246#[derive(Debug, Clone, Deserialize)]
247pub struct LabelerConfig {
248    /// DID of the labeler service.
249    pub did: String,
250
251    /// Path to the signing key file (PEM format).
252    pub signing_key_path: Option<String>,
253
254    /// Inline signing key (base64-encoded, for env var injection).
255    pub signing_key_base64: Option<String>,
256}
257
258// ---------------------------------------------------------------------------
259// RateLimitConfig (TOML)
260// ---------------------------------------------------------------------------
261
262/// `[rate_limit]` section of `atrg.toml`.
263#[derive(Debug, Clone, Deserialize)]
264pub struct RateLimitTomlConfig {
265    /// Maximum sustained requests per second.
266    #[serde(default = "default_rps")]
267    pub requests_per_second: f64,
268
269    /// Maximum burst size.
270    #[serde(default = "default_burst")]
271    pub burst: u32,
272
273    /// Whether rate limiting is enabled (default: true in production).
274    #[serde(default = "default_rate_limit_enabled")]
275    pub enabled: bool,
276}
277
/// Fallback for `rate_limit.requests_per_second` when the key is omitted.
fn default_rps() -> f64 {
    10_f64
}
281
/// Fallback for `rate_limit.burst` when the key is omitted.
fn default_burst() -> u32 {
    50_u32
}
285
/// Fallback for `rate_limit.enabled` when the key is omitted.
fn default_rate_limit_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
289
290// ---------------------------------------------------------------------------
291// Loading & validation
292// ---------------------------------------------------------------------------
293
294impl Config {
295    /// Load and validate a [`Config`] from the TOML file at `path`.
296    ///
297    /// # Errors
298    ///
299    /// Returns an error if the file cannot be read, the TOML is malformed, or
300    /// mandatory validation checks fail (e.g. empty `app.name`).
301    pub fn load(path: impl AsRef<Path>) -> anyhow::Result<Self> {
302        let path = path.as_ref();
303        let contents = std::fs::read_to_string(path).map_err(|e| {
304            anyhow::anyhow!(
305                "Failed to read config file '{}': {}. \
306                 Make sure you're running from a directory that contains atrg.toml.",
307                path.display(),
308                e
309            )
310        })?;
311        Self::parse_toml(&contents)
312    }
313
314    /// Parse and validate a [`Config`] from a TOML string.
315    ///
316    /// This is the inner implementation shared by [`Config::load`] and tests.
317    pub fn parse_toml(toml_str: &str) -> anyhow::Result<Self> {
318        let config: Config = toml::from_str(toml_str).map_err(|e| {
319            // Provide a friendlier message when a required section is missing.
320            let msg = e.to_string();
321            if msg.contains("missing field `app`") {
322                anyhow::anyhow!(
323                    "Config error: the [app] section is required in atrg.toml. \
324                     At minimum you need:\n\n\
325                     [app]\n\
326                     name = \"my-app\"\n\
327                     secret_key = \"some-secret-key\"\n\n\
328                     Full error: {e}"
329                )
330            } else {
331                anyhow::anyhow!("Failed to parse atrg.toml: {e}")
332            }
333        })?;
334
335        config.validate()?;
336        Ok(config)
337    }
338
339    /// Run all validation checks and emit warnings.
340    fn validate(&self) -> anyhow::Result<()> {
341        // -- hard errors ------------------------------------------------
342
343        if self.app.name.trim().is_empty() {
344            anyhow::bail!("Config error: app.name must not be empty");
345        }
346
347        if self.app.secret_key.trim().is_empty() {
348            anyhow::bail!("Config error: app.secret_key must not be empty");
349        }
350
351        // Validate redirect_uri is a proper URL.
352        if Url::parse(&self.auth.redirect_uri).is_err() {
353            anyhow::bail!(
354                "Config error: auth.redirect_uri '{}' is not a valid URL",
355                self.auth.redirect_uri
356            );
357        }
358
359        // Validate client_id is a proper URL.
360        if Url::parse(&self.auth.client_id).is_err() {
361            anyhow::bail!(
362                "Config error: auth.client_id '{}' is not a valid URL",
363                self.auth.client_id
364            );
365        }
366
367        // Validate each CORS origin entry.
368        for origin in &self.app.cors_origins {
369            if origin == "*" {
370                continue; // wildcard is fine
371            }
372            if origin.parse::<http::HeaderValue>().is_err() {
373                anyhow::bail!(
374                    "Config error: cors_origins entry '{}' is not a valid origin",
375                    origin
376                );
377            }
378        }
379
380        // -- soft warnings ---------------------------------------------
381
382        if self.app.secret_key.len() < 32 {
383            tracing::warn!(
384                "app.secret_key is only {} characters — use at least 32 for production",
385                self.app.secret_key.len()
386            );
387        }
388
389        let is_local = self.app.host == "localhost" || self.app.host == "127.0.0.1";
390        if self.app.secret_key == "CHANGE_ME_IN_PRODUCTION" && !is_local {
391            tracing::warn!(
392                "app.secret_key is the scaffold default and host is '{}' — \
393                 change it before deploying!",
394                self.app.host
395            );
396        }
397
398        Ok(())
399    }
400}
401
402// ---------------------------------------------------------------------------
403// Tests
404// ---------------------------------------------------------------------------
405
#[cfg(test)]
mod tests {
    use super::*;

    /// A fixture that sets every key of `[app]`, `[auth]`, `[database]` and
    /// `[jetstream]`. The remaining optional sections are covered by
    /// `parse_config_with_firehose_and_feeds`.
    const FULL_CONFIG: &str = r#"
[app]
name = "my-app"
host = "0.0.0.0"
port = 8080
secret_key = "super-secret-key-that-is-long-enough"
cors_origins = ["http://localhost:5173", "https://example.com"]
environment = "production"

[auth]
client_id = "https://myapp.example.com/client-metadata.json"
redirect_uri = "https://myapp.example.com/auth/callback"
scope = "atproto transition:generic"

[database]
url = "sqlite://prod.db"

[jetstream]
host = "jetstream1.us-east.bsky.network"
collections = ["app.bsky.feed.post", "app.bsky.feed.like"]
zstd_dict = "/tmp/dict.bin"
channel_capacity = 2048
max_lag_events = 20000
"#;

    /// Minimal config — only the required fields.
    const MINIMAL_CONFIG: &str = r#"
[app]
name = "tiny"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;

    #[test]
    fn parse_full_config() {
        let cfg = Config::parse_toml(FULL_CONFIG).expect("should parse full config");

        assert_eq!(cfg.app.name, "my-app");
        assert_eq!(cfg.app.host, "0.0.0.0");
        assert_eq!(cfg.app.port, 8080);
        assert_eq!(cfg.app.environment, "production");
        assert_eq!(cfg.app.cors_origins.len(), 2);

        assert_eq!(
            cfg.auth.client_id,
            "https://myapp.example.com/client-metadata.json"
        );
        assert_eq!(
            cfg.auth.redirect_uri,
            "https://myapp.example.com/auth/callback"
        );
        assert_eq!(cfg.auth.scope, "atproto transition:generic");

        assert_eq!(cfg.database.url, "sqlite://prod.db");

        let js = cfg.jetstream.expect("jetstream should be present");
        assert_eq!(js.host, "jetstream1.us-east.bsky.network");
        assert_eq!(js.collections.len(), 2);
        assert_eq!(js.zstd_dict.as_deref(), Some("/tmp/dict.bin"));
        assert_eq!(js.channel_capacity, 2048);
        assert_eq!(js.max_lag_events, 20000);
    }

    #[test]
    fn parse_minimal_config_defaults_applied() {
        let cfg = Config::parse_toml(MINIMAL_CONFIG).expect("should parse minimal config");

        // Explicit values
        assert_eq!(cfg.app.name, "tiny");

        // Defaults
        assert_eq!(cfg.app.host, "127.0.0.1");
        assert_eq!(cfg.app.port, 3000);
        assert_eq!(cfg.app.environment, "development");
        assert!(cfg.app.cors_origins.is_empty());

        assert_eq!(
            cfg.auth.client_id,
            "http://localhost:3000/client-metadata.json"
        );
        assert_eq!(cfg.auth.redirect_uri, "http://localhost:3000/auth/callback");
        assert_eq!(cfg.auth.scope, "atproto transition:generic");

        assert_eq!(cfg.database.url, "sqlite://atrg.db");
        assert!(cfg.jetstream.is_none());
    }

    #[test]
    fn missing_app_section_gives_friendly_error() {
        let toml = r#"
[database]
url = "sqlite://test.db"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("[app] section is required"),
            "expected friendly error, got: {msg}"
        );
    }

    #[test]
    fn empty_name_is_rejected() {
        let toml = r#"
[app]
name = ""
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        assert!(
            err.to_string().contains("app.name must not be empty"),
            "got: {}",
            err
        );
    }

    #[test]
    fn empty_secret_key_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = ""
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        assert!(
            err.to_string().contains("app.secret_key must not be empty"),
            "got: {}",
            err
        );
    }

    #[test]
    fn invalid_redirect_uri_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[auth]
redirect_uri = "not a url at all"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("auth.redirect_uri") && msg.contains("not a valid URL"),
            "expected redirect_uri error, got: {msg}"
        );
    }

    #[test]
    fn invalid_client_id_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[auth]
client_id = "not a url"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("auth.client_id") && msg.contains("not a valid URL"),
            "expected client_id error, got: {msg}"
        );
    }

    #[test]
    fn invalid_cors_origin_is_rejected() {
        // The fixture is a raw Rust string, so the escape is interpreted by
        // the TOML parser. `\u0000` is a standard TOML escape that decodes to
        // a NUL character; a `\x00` escape is not part of TOML 1.0 and may be
        // rejected by the parser before the CORS validation ever runs.
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
cors_origins = ["http://ok.example.com", "\u0000bad"]
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        // Pin the validator's own message so a TOML parse error (whose
        // snippet would also mention `cors_origins`) cannot satisfy the test.
        assert!(
            msg.contains("cors_origins") && msg.contains("is not a valid origin"),
            "expected cors origin error, got: {msg}"
        );
    }

    #[test]
    fn wildcard_cors_origin_is_accepted() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
cors_origins = ["*"]
"#;
        Config::parse_toml(toml).expect("wildcard should be accepted");
    }

    #[test]
    fn parse_config_with_firehose_and_feeds() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[firehose]
relay = "wss://bsky.network"

[feed_generator]
did = "did:web:feeds.example.com"

[labeler]
did = "did:web:labels.example.com"
signing_key_path = "/etc/keys/labeler.pem"

[rate_limit]
requests_per_second = 20.0
burst = 100
enabled = true
"#;
        let cfg = Config::parse_toml(toml).unwrap();
        let fh = cfg.firehose.unwrap();
        assert_eq!(fh.relay, "wss://bsky.network");
        assert!(fh.cursor.is_none());
        assert_eq!(fh.channel_capacity, 1024);

        let fg = cfg.feed_generator.unwrap();
        assert_eq!(fg.did, "did:web:feeds.example.com");

        let lb = cfg.labeler.unwrap();
        assert_eq!(lb.did, "did:web:labels.example.com");
        assert_eq!(lb.signing_key_path.unwrap(), "/etc/keys/labeler.pem");

        let rl = cfg.rate_limit.unwrap();
        assert!((rl.requests_per_second - 20.0).abs() < f64::EPSILON);
        assert_eq!(rl.burst, 100);
        assert!(rl.enabled);
    }

    #[test]
    fn new_sections_are_all_optional() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;
        let cfg = Config::parse_toml(toml).unwrap();
        assert!(cfg.firehose.is_none());
        assert!(cfg.feed_generator.is_none());
        assert!(cfg.labeler.is_none());
        assert!(cfg.rate_limit.is_none());
    }

    #[test]
    fn jetstream_defaults_applied() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[jetstream]
host = "jetstream1.us-east.bsky.network"
collections = ["app.bsky.feed.post"]
"#;
        let cfg = Config::parse_toml(toml).unwrap();
        let js = cfg.jetstream.unwrap();
        assert_eq!(js.channel_capacity, 1024);
        assert_eq!(js.max_lag_events, 10_000);
        assert!(js.zstd_dict.is_none());
    }
}