1use std::path::PathBuf;
8
9use sha2::{Digest, Sha256};
10
11pub async fn load_dictionary(source: &str, http: &reqwest::Client) -> anyhow::Result<Vec<u8>> {
17 if source.starts_with("http://") || source.starts_with("https://") {
18 load_from_url(source, http).await
19 } else {
20 load_from_file(source).await
21 }
22}
23
24async fn load_from_file(path: &str) -> anyhow::Result<Vec<u8>> {
25 let data = tokio::fs::read(path).await?;
26 tracing::info!(path = %path, size = data.len(), "loaded ZSTD dictionary from file");
27 Ok(data)
28}
29
30async fn load_from_url(url: &str, http: &reqwest::Client) -> anyhow::Result<Vec<u8>> {
31 let cache_path = cache_path_for_url(url);
32
33 if cache_path.exists() {
35 let data = tokio::fs::read(&cache_path).await?;
36 tracing::info!(
37 path = %cache_path.display(),
38 size = data.len(),
39 "loaded ZSTD dictionary from cache"
40 );
41 return Ok(data);
42 }
43
44 tracing::info!(url = %url, "downloading ZSTD dictionary");
46 let resp = http.get(url).send().await?;
47 if !resp.status().is_success() {
48 anyhow::bail!("failed to download ZSTD dictionary: HTTP {}", resp.status());
49 }
50 let data = resp.bytes().await?.to_vec();
51
52 if let Some(parent) = cache_path.parent() {
54 tokio::fs::create_dir_all(parent).await?;
55 }
56 tokio::fs::write(&cache_path, &data).await?;
57 tracing::info!(
58 path = %cache_path.display(),
59 size = data.len(),
60 "cached ZSTD dictionary"
61 );
62
63 Ok(data)
64}
65
66pub fn cache_path_for_url(url: &str) -> PathBuf {
70 let hash = hex_encode(&Sha256::digest(url.as_bytes()));
71 let cache_dir = dirs::cache_dir()
72 .unwrap_or_else(|| PathBuf::from(".cache"))
73 .join("atrg");
74 cache_dir.join(format!("jetstream-dict-{}.bin", &hash[..16]))
75}
76
/// Encodes `data` as a lowercase hexadecimal string, two characters per byte.
///
/// An empty slice yields an empty string.
fn hex_encode(data: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    // One allocation sized up front, instead of a temporary `String` per byte as
    // `format!` + `collect` would produce.
    let mut out = String::with_capacity(data.len() * 2);
    for &byte in data {
        out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// URL used by the tests that only need one stable input.
    const DICT_URL: &str = "https://example.com/dict.bin";

    /// The same URL must map to the same cache path on every call.
    #[test]
    fn cache_path_is_deterministic() {
        let first = cache_path_for_url(DICT_URL);
        let second = cache_path_for_url(DICT_URL);
        assert_eq!(first, second);
    }

    /// Two distinct URLs must not map to the same cache path.
    #[test]
    fn different_urls_different_paths() {
        let path_a = cache_path_for_url("https://example.com/a.bin");
        let path_b = cache_path_for_url("https://example.com/b.bin");
        assert_ne!(path_a, path_b);
    }

    /// The cache path must contain the `atrg` directory and the file-name prefix.
    #[test]
    fn cache_path_under_atrg_dir() {
        let s = cache_path_for_url(DICT_URL).to_string_lossy().into_owned();
        assert!(s.contains("atrg"), "expected 'atrg' in path: {s}");
        assert!(
            s.contains("jetstream-dict-"),
            "expected 'jetstream-dict-' in path: {s}"
        );
    }

    /// Bytes are rendered as lowercase, zero-padded, two-digit hex.
    #[test]
    fn hex_encode_works() {
        let deadbeef: [u8; 4] = [0xde, 0xad, 0xbe, 0xef];
        assert_eq!(hex_encode(&deadbeef), "deadbeef");

        let extremes: [u8; 2] = [0x00, 0xff];
        assert_eq!(hex_encode(&extremes), "00ff");
    }
}