1use std::path::Path;
8
9use axum::http;
10use serde::Deserialize;
11use url::Url;
12
13#[derive(Debug, Clone, Deserialize)]
19pub struct Config {
20 pub app: AppConfig,
22
23 #[serde(default)]
25 pub auth: AuthConfig,
26
27 #[serde(default)]
29 pub database: DatabaseConfig,
30
31 pub jetstream: Option<JetstreamConfig>,
33
34 pub firehose: Option<FirehoseConfig>,
36
37 pub feed_generator: Option<FeedGeneratorConfig>,
39
40 pub labeler: Option<LabelerConfig>,
42
43 pub rate_limit: Option<RateLimitTomlConfig>,
45}
46
47#[derive(Debug, Clone, Deserialize)]
53pub struct AppConfig {
54 pub name: String,
56
57 #[serde(default = "default_host")]
59 pub host: String,
60
61 #[serde(default = "default_port")]
63 pub port: u16,
64
65 pub secret_key: String,
68
69 #[serde(default)]
72 pub cors_origins: Vec<String>,
73
74 #[serde(default = "default_environment")]
77 pub environment: String,
78
79 #[serde(default)]
82 pub admin_dids: Vec<String>,
83}
84
85impl Default for AppConfig {
86 fn default() -> Self {
87 Self {
88 name: String::new(),
89 host: default_host(),
90 port: default_port(),
91 secret_key: String::new(),
92 cors_origins: Vec::new(),
93 environment: default_environment(),
94 admin_dids: Vec::new(),
95 }
96 }
97}
98
/// Serde fallback for `AppConfig::host`: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}
102
/// Serde fallback for `AppConfig::port`.
fn default_port() -> u16 {
    const PORT: u16 = 3000;
    PORT
}
106
/// Serde fallback for `AppConfig::environment`.
fn default_environment() -> String {
    String::from("development")
}
110
111#[derive(Debug, Clone, Deserialize)]
117pub struct AuthConfig {
118 #[serde(default = "default_client_id")]
120 pub client_id: String,
121
122 #[serde(default = "default_redirect_uri")]
124 pub redirect_uri: String,
125
126 #[serde(default = "default_scope")]
128 pub scope: String,
129
130 #[serde(default = "default_post_login_redirect")]
134 pub post_login_redirect: String,
135}
136
137impl Default for AuthConfig {
138 fn default() -> Self {
139 Self {
140 client_id: default_client_id(),
141 redirect_uri: default_redirect_uri(),
142 scope: default_scope(),
143 post_login_redirect: default_post_login_redirect(),
144 }
145 }
146}
147
/// Serde fallback for `AuthConfig::client_id`.
fn default_client_id() -> String {
    String::from("http://localhost:3000/client-metadata.json")
}
151
/// Serde fallback for `AuthConfig::redirect_uri`.
fn default_redirect_uri() -> String {
    String::from("http://localhost:3000/auth/callback")
}
155
/// Serde fallback for `AuthConfig::scope`.
fn default_scope() -> String {
    String::from("atproto transition:generic")
}
159
/// Serde fallback for `AuthConfig::post_login_redirect`: the site root.
fn default_post_login_redirect() -> String {
    String::from("/")
}
163
164#[derive(Debug, Clone, Deserialize)]
170pub struct DatabaseConfig {
171 #[serde(default = "default_database_url")]
173 pub url: String,
174}
175
176impl Default for DatabaseConfig {
177 fn default() -> Self {
178 Self {
179 url: default_database_url(),
180 }
181 }
182}
183
/// Serde fallback for `DatabaseConfig::url`.
fn default_database_url() -> String {
    String::from("sqlite://atrg.db")
}
187
188#[derive(Debug, Clone, Deserialize)]
195pub struct JetstreamConfig {
196 pub host: String,
198
199 pub collections: Vec<String>,
201
202 pub zstd_dict: Option<String>,
204
205 #[serde(default = "default_channel_capacity")]
207 pub channel_capacity: usize,
208
209 #[serde(default = "default_max_lag_events")]
211 pub max_lag_events: usize,
212}
213
/// Serde fallback for `JetstreamConfig::channel_capacity`.
fn default_channel_capacity() -> usize {
    const CAPACITY: usize = 1024;
    CAPACITY
}
217
/// Serde fallback for `JetstreamConfig::max_lag_events`.
fn default_max_lag_events() -> usize {
    const MAX_LAG_EVENTS: usize = 10_000;
    MAX_LAG_EVENTS
}
221
222#[derive(Debug, Clone, Deserialize)]
229pub struct FirehoseConfig {
230 pub relay: String,
232
233 pub cursor: Option<i64>,
235
236 #[serde(default = "default_firehose_channel_capacity")]
238 pub channel_capacity: usize,
239}
240
/// Serde fallback for `FirehoseConfig::channel_capacity`.
fn default_firehose_channel_capacity() -> usize {
    const CAPACITY: usize = 1024;
    CAPACITY
}
244
245#[derive(Debug, Clone, Deserialize)]
252pub struct FeedGeneratorConfig {
253 pub did: String,
255}
256
257#[derive(Debug, Clone, Deserialize)]
264pub struct LabelerConfig {
265 pub did: String,
267
268 pub signing_key_path: Option<String>,
270
271 pub signing_key_base64: Option<String>,
273}
274
275#[derive(Debug, Clone, Deserialize)]
281pub struct RateLimitTomlConfig {
282 #[serde(default = "default_rps")]
284 pub requests_per_second: f64,
285
286 #[serde(default = "default_burst")]
288 pub burst: u32,
289
290 #[serde(default = "default_rate_limit_enabled")]
292 pub enabled: bool,
293}
294
/// Serde fallback for `RateLimitTomlConfig::requests_per_second`.
fn default_rps() -> f64 {
    const REQUESTS_PER_SECOND: f64 = 10.0;
    REQUESTS_PER_SECOND
}
298
/// Serde fallback for `RateLimitTomlConfig::burst`.
fn default_burst() -> u32 {
    const BURST: u32 = 50;
    BURST
}
302
/// Serde fallback for `RateLimitTomlConfig::enabled`: on unless disabled.
fn default_rate_limit_enabled() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
306
307impl Config {
312 pub fn load(path: impl AsRef<Path>) -> anyhow::Result<Self> {
319 let path = path.as_ref();
320 let contents = std::fs::read_to_string(path).map_err(|e| {
321 anyhow::anyhow!(
322 "Failed to read config file '{}': {}. \
323 Make sure you're running from a directory that contains atrg.toml.",
324 path.display(),
325 e
326 )
327 })?;
328 Self::parse_toml(&contents)
329 }
330
331 pub fn parse_toml(toml_str: &str) -> anyhow::Result<Self> {
335 let config: Config = toml::from_str(toml_str).map_err(|e| {
336 let msg = e.to_string();
338 if msg.contains("missing field `app`") {
339 anyhow::anyhow!(
340 "Config error: the [app] section is required in atrg.toml. \
341 At minimum you need:\n\n\
342 [app]\n\
343 name = \"my-app\"\n\
344 secret_key = \"some-secret-key\"\n\n\
345 Full error: {e}"
346 )
347 } else {
348 anyhow::anyhow!("Failed to parse atrg.toml: {e}")
349 }
350 })?;
351
352 config.validate()?;
353 Ok(config)
354 }
355
356 fn validate(&self) -> anyhow::Result<()> {
358 if self.app.name.trim().is_empty() {
361 anyhow::bail!("Config error: app.name must not be empty");
362 }
363
364 if self.app.secret_key.trim().is_empty() {
365 anyhow::bail!("Config error: app.secret_key must not be empty");
366 }
367
368 if Url::parse(&self.auth.redirect_uri).is_err() {
370 anyhow::bail!(
371 "Config error: auth.redirect_uri '{}' is not a valid URL",
372 self.auth.redirect_uri
373 );
374 }
375
376 if Url::parse(&self.auth.client_id).is_err() {
378 anyhow::bail!(
379 "Config error: auth.client_id '{}' is not a valid URL",
380 self.auth.client_id
381 );
382 }
383
384 for origin in &self.app.cors_origins {
386 if origin == "*" {
387 continue; }
389 if origin.parse::<http::HeaderValue>().is_err() {
390 anyhow::bail!(
391 "Config error: cors_origins entry '{}' is not a valid origin",
392 origin
393 );
394 }
395 }
396
397 if self.app.secret_key.len() < 32 {
400 tracing::warn!(
401 "app.secret_key is only {} characters — use at least 32 for production",
402 self.app.secret_key.len()
403 );
404 }
405
406 let is_local = self.app.host == "localhost" || self.app.host == "127.0.0.1";
407 if self.app.secret_key == "CHANGE_ME_IN_PRODUCTION" && !is_local {
408 tracing::warn!(
409 "app.secret_key is the scaffold default and host is '{}' — \
410 change it before deploying!",
411 self.app.host
412 );
413 }
414
415 Ok(())
416 }
417}
418
419pub fn load_app_config<T: serde::de::DeserializeOwned>(section_name: &str) -> anyhow::Result<T> {
444 load_app_config_from_path::<T>(section_name, "atrg.toml")
445}
446
447pub fn load_app_config_from_path<T: serde::de::DeserializeOwned>(
449 section_name: &str,
450 path: &str,
451) -> anyhow::Result<T> {
452 let toml_str = std::fs::read_to_string(path)
453 .map_err(|e| anyhow::anyhow!("Failed to read {}: {}", path, e))?;
454 let toml_val: toml::Value = toml::from_str(&toml_str)
455 .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", path, e))?;
456 let section = toml_val
457 .get(section_name)
458 .ok_or_else(|| anyhow::anyhow!("Missing [{}] section in {}", section_name, path))?;
459 let config: T = section.clone().try_into().map_err(|e| {
460 anyhow::anyhow!(
461 "Invalid [{}] configuration in {}: {}",
462 section_name,
463 path,
464 e
465 )
466 })?;
467 Ok(config)
468}
469
/// Unit tests for configuration parsing, defaults and validation. All tests go
/// through `Config::parse_toml`, so no files are touched.
#[cfg(test)]
mod tests {
    use super::*;

    /// Document that sets every `[app]`, `[auth]`, `[database]` and
    /// `[jetstream]` key the first test asserts on.
    const FULL_CONFIG: &str = r#"
[app]
name = "my-app"
host = "0.0.0.0"
port = 8080
secret_key = "super-secret-key-that-is-long-enough"
cors_origins = ["http://localhost:5173", "https://example.com"]
environment = "production"

[auth]
client_id = "https://myapp.example.com/client-metadata.json"
redirect_uri = "https://myapp.example.com/auth/callback"
scope = "atproto transition:generic"

[database]
url = "sqlite://prod.db"

[jetstream]
host = "jetstream1.us-east.bsky.network"
collections = ["app.bsky.feed.post", "app.bsky.feed.like"]
zstd_dict = "/tmp/dict.bin"
channel_capacity = 2048
max_lag_events = 20000
"#;

    /// Smallest document that passes validation: only the two required
    /// `[app]` keys.
    const MINIMAL_CONFIG: &str = r#"
[app]
name = "tiny"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;

    /// Explicit values in the document override every default.
    #[test]
    fn parse_full_config() {
        let cfg = Config::parse_toml(FULL_CONFIG).expect("should parse full config");

        assert_eq!(cfg.app.name, "my-app");
        assert_eq!(cfg.app.host, "0.0.0.0");
        assert_eq!(cfg.app.port, 8080);
        assert_eq!(cfg.app.environment, "production");
        assert_eq!(cfg.app.cors_origins.len(), 2);

        assert_eq!(
            cfg.auth.client_id,
            "https://myapp.example.com/client-metadata.json"
        );
        assert_eq!(
            cfg.auth.redirect_uri,
            "https://myapp.example.com/auth/callback"
        );
        assert_eq!(cfg.auth.scope, "atproto transition:generic");

        assert_eq!(cfg.database.url, "sqlite://prod.db");

        let js = cfg.jetstream.expect("jetstream should be present");
        assert_eq!(js.host, "jetstream1.us-east.bsky.network");
        assert_eq!(js.collections.len(), 2);
        assert_eq!(js.zstd_dict.as_deref(), Some("/tmp/dict.bin"));
        assert_eq!(js.channel_capacity, 2048);
        assert_eq!(js.max_lag_events, 20000);
    }

    /// Omitted keys and tables fall back to the documented defaults.
    #[test]
    fn parse_minimal_config_defaults_applied() {
        let cfg = Config::parse_toml(MINIMAL_CONFIG).expect("should parse minimal config");

        assert_eq!(cfg.app.name, "tiny");

        assert_eq!(cfg.app.host, "127.0.0.1");
        assert_eq!(cfg.app.port, 3000);
        assert_eq!(cfg.app.environment, "development");
        assert!(cfg.app.cors_origins.is_empty());

        assert_eq!(
            cfg.auth.client_id,
            "http://localhost:3000/client-metadata.json"
        );
        assert_eq!(cfg.auth.redirect_uri, "http://localhost:3000/auth/callback");
        assert_eq!(cfg.auth.scope, "atproto transition:generic");

        assert_eq!(cfg.database.url, "sqlite://atrg.db");
        assert!(cfg.jetstream.is_none());
    }

    /// A document without `[app]` yields the hand-written hint, not the raw
    /// serde message alone.
    #[test]
    fn missing_app_section_gives_friendly_error() {
        let toml = r#"
[database]
url = "sqlite://test.db"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("[app] section is required"),
            "expected friendly error, got: {msg}"
        );
    }

    #[test]
    fn empty_name_is_rejected() {
        let toml = r#"
[app]
name = ""
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        assert!(
            err.to_string().contains("app.name must not be empty"),
            "got: {}",
            err
        );
    }

    #[test]
    fn empty_secret_key_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = ""
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        assert!(
            err.to_string().contains("app.secret_key must not be empty"),
            "got: {}",
            err
        );
    }

    #[test]
    fn invalid_redirect_uri_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[auth]
redirect_uri = "not a url at all"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("auth.redirect_uri") && msg.contains("not a valid URL"),
            "expected redirect_uri error, got: {msg}"
        );
    }

    #[test]
    fn invalid_client_id_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[auth]
client_id = "not a url"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("auth.client_id") && msg.contains("not a valid URL"),
            "expected client_id error, got: {msg}"
        );
    }

    #[test]
    fn invalid_cors_origin_is_rejected() {
        // `\u0000` is a standard TOML escape, so the document itself parses
        // and the NUL character reaches origin validation. Asserting on the
        // validation wording ensures the test cannot pass on a TOML syntax
        // error whose source snippet merely mentions `cors_origins`.
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
cors_origins = ["http://ok.example.com", "\u0000bad"]
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("cors_origins") && msg.contains("is not a valid origin"),
            "expected cors origin error, got: {msg}"
        );
    }

    #[test]
    fn wildcard_cors_origin_is_accepted() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
cors_origins = ["*"]
"#;
        Config::parse_toml(toml).expect("wildcard should be accepted");
    }

    /// The optional sections deserialize when present, with their own
    /// per-field defaults applied.
    #[test]
    fn parse_config_with_firehose_and_feeds() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[firehose]
relay = "wss://bsky.network"

[feed_generator]
did = "did:web:feeds.example.com"

[labeler]
did = "did:web:labels.example.com"
signing_key_path = "/etc/keys/labeler.pem"

[rate_limit]
requests_per_second = 20.0
burst = 100
enabled = true
"#;
        let cfg = Config::parse_toml(toml).unwrap();
        let fh = cfg.firehose.unwrap();
        assert_eq!(fh.relay, "wss://bsky.network");
        assert!(fh.cursor.is_none());
        assert_eq!(fh.channel_capacity, 1024);

        let fg = cfg.feed_generator.unwrap();
        assert_eq!(fg.did, "did:web:feeds.example.com");

        let lb = cfg.labeler.unwrap();
        assert_eq!(lb.did, "did:web:labels.example.com");
        assert_eq!(lb.signing_key_path.unwrap(), "/etc/keys/labeler.pem");

        let rl = cfg.rate_limit.unwrap();
        assert!((rl.requests_per_second - 20.0).abs() < f64::EPSILON);
        assert_eq!(rl.burst, 100);
        assert!(rl.enabled);
    }

    #[test]
    fn new_sections_are_all_optional() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;
        let cfg = Config::parse_toml(toml).unwrap();
        assert!(cfg.firehose.is_none());
        assert!(cfg.feed_generator.is_none());
        assert!(cfg.labeler.is_none());
        assert!(cfg.rate_limit.is_none());
    }

    /// A `[jetstream]` table with only its required keys picks up the
    /// capacity and lag defaults.
    #[test]
    fn jetstream_defaults_applied() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[jetstream]
host = "jetstream1.us-east.bsky.network"
collections = ["app.bsky.feed.post"]
"#;
        let cfg = Config::parse_toml(toml).unwrap();
        let js = cfg.jetstream.unwrap();
        assert_eq!(js.channel_capacity, 1024);
        assert_eq!(js.max_lag_events, 10_000);
        assert!(js.zstd_dict.is_none());
    }
}