1use std::path::Path;
8
9use axum::http;
10use serde::Deserialize;
11use url::Url;
12
13#[derive(Debug, Clone, Deserialize)]
19pub struct Config {
20 pub app: AppConfig,
22
23 #[serde(default)]
25 pub auth: AuthConfig,
26
27 #[serde(default)]
29 pub database: DatabaseConfig,
30
31 pub jetstream: Option<JetstreamConfig>,
33
34 pub firehose: Option<FirehoseConfig>,
36
37 pub feed_generator: Option<FeedGeneratorConfig>,
39
40 pub labeler: Option<LabelerConfig>,
42
43 pub rate_limit: Option<RateLimitTomlConfig>,
45}
46
47#[derive(Debug, Clone, Deserialize)]
53pub struct AppConfig {
54 pub name: String,
56
57 #[serde(default = "default_host")]
59 pub host: String,
60
61 #[serde(default = "default_port")]
63 pub port: u16,
64
65 pub secret_key: String,
68
69 #[serde(default)]
72 pub cors_origins: Vec<String>,
73
74 #[serde(default = "default_environment")]
77 pub environment: String,
78}
79
80impl Default for AppConfig {
81 fn default() -> Self {
82 Self {
83 name: String::new(),
84 host: default_host(),
85 port: default_port(),
86 secret_key: String::new(),
87 cors_origins: Vec::new(),
88 environment: default_environment(),
89 }
90 }
91}
92
/// Fallback for `app.host`: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Fallback for `app.port`.
fn default_port() -> u16 {
    3_000
}

/// Fallback for `app.environment`.
fn default_environment() -> String {
    "development".into()
}
104
105#[derive(Debug, Clone, Deserialize)]
111pub struct AuthConfig {
112 #[serde(default = "default_client_id")]
114 pub client_id: String,
115
116 #[serde(default = "default_redirect_uri")]
118 pub redirect_uri: String,
119
120 #[serde(default = "default_scope")]
122 pub scope: String,
123}
124
125impl Default for AuthConfig {
126 fn default() -> Self {
127 Self {
128 client_id: default_client_id(),
129 redirect_uri: default_redirect_uri(),
130 scope: default_scope(),
131 }
132 }
133}
134
/// Fallback for `auth.client_id`: a client-metadata URL on localhost port 3000.
fn default_client_id() -> String {
    String::from("http://localhost:3000/client-metadata.json")
}

/// Fallback for `auth.redirect_uri`: a callback URL on localhost port 3000.
fn default_redirect_uri() -> String {
    String::from("http://localhost:3000/auth/callback")
}

/// Fallback for `auth.scope`.
fn default_scope() -> String {
    String::from("atproto transition:generic")
}
146
147#[derive(Debug, Clone, Deserialize)]
153pub struct DatabaseConfig {
154 #[serde(default = "default_database_url")]
156 pub url: String,
157}
158
159impl Default for DatabaseConfig {
160 fn default() -> Self {
161 Self {
162 url: default_database_url(),
163 }
164 }
165}
166
/// Fallback for `database.url`: a SQLite database named `atrg.db`.
fn default_database_url() -> String {
    String::from("sqlite://atrg.db")
}
170
171#[derive(Debug, Clone, Deserialize)]
178pub struct JetstreamConfig {
179 pub host: String,
181
182 pub collections: Vec<String>,
184
185 pub zstd_dict: Option<String>,
187
188 #[serde(default = "default_channel_capacity")]
190 pub channel_capacity: usize,
191
192 #[serde(default = "default_max_lag_events")]
194 pub max_lag_events: usize,
195}
196
/// Fallback for `jetstream.channel_capacity`.
fn default_channel_capacity() -> usize {
    1_024
}

/// Fallback for `jetstream.max_lag_events`.
fn default_max_lag_events() -> usize {
    10_000
}
204
205#[derive(Debug, Clone, Deserialize)]
212pub struct FirehoseConfig {
213 pub relay: String,
215
216 pub cursor: Option<i64>,
218
219 #[serde(default = "default_firehose_channel_capacity")]
221 pub channel_capacity: usize,
222}
223
/// Fallback for `firehose.channel_capacity`.
fn default_firehose_channel_capacity() -> usize {
    1_024
}
227
228#[derive(Debug, Clone, Deserialize)]
235pub struct FeedGeneratorConfig {
236 pub did: String,
238}
239
240#[derive(Debug, Clone, Deserialize)]
247pub struct LabelerConfig {
248 pub did: String,
250
251 pub signing_key_path: Option<String>,
253
254 pub signing_key_base64: Option<String>,
256}
257
258#[derive(Debug, Clone, Deserialize)]
264pub struct RateLimitTomlConfig {
265 #[serde(default = "default_rps")]
267 pub requests_per_second: f64,
268
269 #[serde(default = "default_burst")]
271 pub burst: u32,
272
273 #[serde(default = "default_rate_limit_enabled")]
275 pub enabled: bool,
276}
277
/// Fallback for `rate_limit.requests_per_second`.
fn default_rps() -> f64 {
    10_f64
}

/// Fallback for `rate_limit.burst`.
fn default_burst() -> u32 {
    50_u32
}

/// Fallback for `rate_limit.enabled`: limiting stays on unless explicitly disabled.
fn default_rate_limit_enabled() -> bool {
    true
}
289
290impl Config {
295 pub fn load(path: impl AsRef<Path>) -> anyhow::Result<Self> {
302 let path = path.as_ref();
303 let contents = std::fs::read_to_string(path).map_err(|e| {
304 anyhow::anyhow!(
305 "Failed to read config file '{}': {}. \
306 Make sure you're running from a directory that contains atrg.toml.",
307 path.display(),
308 e
309 )
310 })?;
311 Self::parse_toml(&contents)
312 }
313
314 pub fn parse_toml(toml_str: &str) -> anyhow::Result<Self> {
318 let config: Config = toml::from_str(toml_str).map_err(|e| {
319 let msg = e.to_string();
321 if msg.contains("missing field `app`") {
322 anyhow::anyhow!(
323 "Config error: the [app] section is required in atrg.toml. \
324 At minimum you need:\n\n\
325 [app]\n\
326 name = \"my-app\"\n\
327 secret_key = \"some-secret-key\"\n\n\
328 Full error: {e}"
329 )
330 } else {
331 anyhow::anyhow!("Failed to parse atrg.toml: {e}")
332 }
333 })?;
334
335 config.validate()?;
336 Ok(config)
337 }
338
339 fn validate(&self) -> anyhow::Result<()> {
341 if self.app.name.trim().is_empty() {
344 anyhow::bail!("Config error: app.name must not be empty");
345 }
346
347 if self.app.secret_key.trim().is_empty() {
348 anyhow::bail!("Config error: app.secret_key must not be empty");
349 }
350
351 if Url::parse(&self.auth.redirect_uri).is_err() {
353 anyhow::bail!(
354 "Config error: auth.redirect_uri '{}' is not a valid URL",
355 self.auth.redirect_uri
356 );
357 }
358
359 if Url::parse(&self.auth.client_id).is_err() {
361 anyhow::bail!(
362 "Config error: auth.client_id '{}' is not a valid URL",
363 self.auth.client_id
364 );
365 }
366
367 for origin in &self.app.cors_origins {
369 if origin == "*" {
370 continue; }
372 if origin.parse::<http::HeaderValue>().is_err() {
373 anyhow::bail!(
374 "Config error: cors_origins entry '{}' is not a valid origin",
375 origin
376 );
377 }
378 }
379
380 if self.app.secret_key.len() < 32 {
383 tracing::warn!(
384 "app.secret_key is only {} characters — use at least 32 for production",
385 self.app.secret_key.len()
386 );
387 }
388
389 let is_local = self.app.host == "localhost" || self.app.host == "127.0.0.1";
390 if self.app.secret_key == "CHANGE_ME_IN_PRODUCTION" && !is_local {
391 tracing::warn!(
392 "app.secret_key is the scaffold default and host is '{}' — \
393 change it before deploying!",
394 self.app.host
395 );
396 }
397
398 Ok(())
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// The core sections, each populated with non-default values.
    const FULL_CONFIG: &str = r#"
[app]
name = "my-app"
host = "0.0.0.0"
port = 8080
secret_key = "super-secret-key-that-is-long-enough"
cors_origins = ["http://localhost:5173", "https://example.com"]
environment = "production"

[auth]
client_id = "https://myapp.example.com/client-metadata.json"
redirect_uri = "https://myapp.example.com/auth/callback"
scope = "atproto transition:generic"

[database]
url = "sqlite://prod.db"

[jetstream]
host = "jetstream1.us-east.bsky.network"
collections = ["app.bsky.feed.post", "app.bsky.feed.like"]
zstd_dict = "/tmp/dict.bin"
channel_capacity = 2048
max_lag_events = 20000
"#;

    /// Only the required `[app]` keys.
    const MINIMAL_CONFIG: &str = r#"
[app]
name = "tiny"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;

    #[test]
    fn parse_full_config() {
        let cfg = Config::parse_toml(FULL_CONFIG).expect("should parse full config");

        assert_eq!(cfg.app.name, "my-app");
        assert_eq!(cfg.app.host, "0.0.0.0");
        assert_eq!(cfg.app.port, 8080);
        assert_eq!(cfg.app.environment, "production");
        assert_eq!(cfg.app.cors_origins.len(), 2);

        assert_eq!(
            cfg.auth.client_id,
            "https://myapp.example.com/client-metadata.json"
        );
        assert_eq!(
            cfg.auth.redirect_uri,
            "https://myapp.example.com/auth/callback"
        );
        assert_eq!(cfg.auth.scope, "atproto transition:generic");

        assert_eq!(cfg.database.url, "sqlite://prod.db");

        let js = cfg.jetstream.expect("jetstream should be present");
        assert_eq!(js.host, "jetstream1.us-east.bsky.network");
        assert_eq!(js.collections.len(), 2);
        assert_eq!(js.zstd_dict.as_deref(), Some("/tmp/dict.bin"));
        assert_eq!(js.channel_capacity, 2048);
        assert_eq!(js.max_lag_events, 20000);
    }

    #[test]
    fn parse_minimal_config_defaults_applied() {
        let cfg = Config::parse_toml(MINIMAL_CONFIG).expect("should parse minimal config");

        assert_eq!(cfg.app.name, "tiny");

        assert_eq!(cfg.app.host, "127.0.0.1");
        assert_eq!(cfg.app.port, 3000);
        assert_eq!(cfg.app.environment, "development");
        assert!(cfg.app.cors_origins.is_empty());

        assert_eq!(
            cfg.auth.client_id,
            "http://localhost:3000/client-metadata.json"
        );
        assert_eq!(cfg.auth.redirect_uri, "http://localhost:3000/auth/callback");
        assert_eq!(cfg.auth.scope, "atproto transition:generic");

        assert_eq!(cfg.database.url, "sqlite://atrg.db");
        assert!(cfg.jetstream.is_none());
    }

    #[test]
    fn missing_app_section_gives_friendly_error() {
        let toml = r#"
[database]
url = "sqlite://test.db"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("[app] section is required"),
            "expected friendly error, got: {msg}"
        );
    }

    #[test]
    fn missing_config_file_reports_path() {
        let err = Config::load("this/path/does/not/exist/atrg.toml").unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Failed to read config file"), "got: {msg}");
        assert!(msg.contains("does/not/exist/atrg.toml"), "got: {msg}");
    }

    #[test]
    fn empty_name_is_rejected() {
        let toml = r#"
[app]
name = ""
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        assert!(
            err.to_string().contains("app.name must not be empty"),
            "got: {}",
            err
        );
    }

    #[test]
    fn empty_secret_key_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = ""
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        assert!(
            err.to_string().contains("app.secret_key must not be empty"),
            "got: {}",
            err
        );
    }

    #[test]
    fn invalid_redirect_uri_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[auth]
redirect_uri = "not a url at all"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("auth.redirect_uri") && msg.contains("not a valid URL"),
            "expected redirect_uri error, got: {msg}"
        );
    }

    #[test]
    fn invalid_client_id_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[auth]
client_id = "not a url"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("auth.client_id") && msg.contains("not a valid URL"),
            "expected client_id error, got: {msg}"
        );
    }

    #[test]
    fn invalid_cors_origin_is_rejected() {
        // `\n` is a legal TOML 1.0 escape, so the file parses and the control character
        // reaches origin validation. A `\x00` escape is not TOML 1.0; it would fail at parse
        // time instead, and the parse error's source snippet would satisfy a loose
        // `contains("cors_origins")` check without ever exercising the validator.
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
cors_origins = ["http://ok.example.com", "bad\norigin"]
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("cors_origins entry") && msg.contains("is not a valid origin"),
            "expected cors origin validation error, got: {msg}"
        );
    }

    #[test]
    fn wildcard_cors_origin_is_accepted() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
cors_origins = ["*"]
"#;
        Config::parse_toml(toml).expect("wildcard should be accepted");
    }

    #[test]
    fn parse_config_with_firehose_and_feeds() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[firehose]
relay = "wss://bsky.network"

[feed_generator]
did = "did:web:feeds.example.com"

[labeler]
did = "did:web:labels.example.com"
signing_key_path = "/etc/keys/labeler.pem"

[rate_limit]
requests_per_second = 20.0
burst = 100
enabled = true
"#;
        let cfg = Config::parse_toml(toml).unwrap();
        let fh = cfg.firehose.unwrap();
        assert_eq!(fh.relay, "wss://bsky.network");
        assert!(fh.cursor.is_none());
        assert_eq!(fh.channel_capacity, 1024);

        let fg = cfg.feed_generator.unwrap();
        assert_eq!(fg.did, "did:web:feeds.example.com");

        let lb = cfg.labeler.unwrap();
        assert_eq!(lb.did, "did:web:labels.example.com");
        assert_eq!(lb.signing_key_path.unwrap(), "/etc/keys/labeler.pem");
        assert!(lb.signing_key_base64.is_none());

        let rl = cfg.rate_limit.unwrap();
        assert!((rl.requests_per_second - 20.0).abs() < f64::EPSILON);
        assert_eq!(rl.burst, 100);
        assert!(rl.enabled);
    }

    #[test]
    fn rate_limit_defaults_applied() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[rate_limit]
"#;
        let rl = Config::parse_toml(toml)
            .unwrap()
            .rate_limit
            .expect("rate_limit should be present");
        assert!((rl.requests_per_second - 10.0).abs() < f64::EPSILON);
        assert_eq!(rl.burst, 50);
        assert!(rl.enabled);
    }

    #[test]
    fn new_sections_are_all_optional() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;
        let cfg = Config::parse_toml(toml).unwrap();
        assert!(cfg.firehose.is_none());
        assert!(cfg.feed_generator.is_none());
        assert!(cfg.labeler.is_none());
        assert!(cfg.rate_limit.is_none());
    }

    #[test]
    fn jetstream_defaults_applied() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[jetstream]
host = "jetstream1.us-east.bsky.network"
collections = ["app.bsky.feed.post"]
"#;
        let cfg = Config::parse_toml(toml).unwrap();
        let js = cfg.jetstream.unwrap();
        assert_eq!(js.channel_capacity, 1024);
        assert_eq!(js.max_lag_events, 10_000);
        assert!(js.zstd_dict.is_none());
    }
}