Skip to main content

atrg_stream/
zstd_dict.rs

//! ZSTD dictionary auto-fetch and caching.
//!
//! Supports loading a dictionary from a local file path or an HTTP(S) URL.
//! URL-sourced dictionaries are cached in an `atrg` subdirectory of the
//! platform cache directory (e.g. `~/.cache/atrg/` on Linux), keyed by a
//! truncated SHA-256 hash of the URL so that restarts don't re-download.
6
7use std::path::PathBuf;
8
9use sha2::{Digest, Sha256};
10
11/// Load a ZSTD dictionary from a local path or URL.
12///
13/// - If `source` is a local file path → load directly.
14/// - If `source` is an HTTP(S) URL → download and cache under `~/.cache/atrg/`.
15/// - Returns the raw bytes of the dictionary.
16pub async fn load_dictionary(source: &str, http: &reqwest::Client) -> anyhow::Result<Vec<u8>> {
17    if source.starts_with("http://") || source.starts_with("https://") {
18        load_from_url(source, http).await
19    } else {
20        load_from_file(source).await
21    }
22}
23
24async fn load_from_file(path: &str) -> anyhow::Result<Vec<u8>> {
25    let data = tokio::fs::read(path).await?;
26    tracing::info!(path = %path, size = data.len(), "loaded ZSTD dictionary from file");
27    Ok(data)
28}
29
30async fn load_from_url(url: &str, http: &reqwest::Client) -> anyhow::Result<Vec<u8>> {
31    let cache_path = cache_path_for_url(url);
32
33    // Try cached first
34    if cache_path.exists() {
35        let data = tokio::fs::read(&cache_path).await?;
36        tracing::info!(
37            path = %cache_path.display(),
38            size = data.len(),
39            "loaded ZSTD dictionary from cache"
40        );
41        return Ok(data);
42    }
43
44    // Download
45    tracing::info!(url = %url, "downloading ZSTD dictionary");
46    let resp = http.get(url).send().await?;
47    if !resp.status().is_success() {
48        anyhow::bail!("failed to download ZSTD dictionary: HTTP {}", resp.status());
49    }
50    let data = resp.bytes().await?.to_vec();
51
52    // Cache
53    if let Some(parent) = cache_path.parent() {
54        tokio::fs::create_dir_all(parent).await?;
55    }
56    tokio::fs::write(&cache_path, &data).await?;
57    tracing::info!(
58        path = %cache_path.display(),
59        size = data.len(),
60        "cached ZSTD dictionary"
61    );
62
63    Ok(data)
64}
65
66/// Compute the cache file path for a URL.
67///
68/// The path is deterministic: same URL always maps to the same file.
69pub fn cache_path_for_url(url: &str) -> PathBuf {
70    let hash = hex_encode(&Sha256::digest(url.as_bytes()));
71    let cache_dir = dirs::cache_dir()
72        .unwrap_or_else(|| PathBuf::from(".cache"))
73        .join("atrg");
74    cache_dir.join(format!("jetstream-dict-{}.bin", &hash[..16]))
75}
76
/// Encode `data` as a lowercase hexadecimal string, two digits per byte.
fn hex_encode(data: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    // Output length is known exactly, so allocate once instead of building a
    // temporary `String` per byte with `format!`.
    let mut out = String::with_capacity(data.len() * 2);
    for &byte in data {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
80
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cache_path_is_deterministic() {
        let url = "https://example.com/dict.bin";
        assert_eq!(cache_path_for_url(url), cache_path_for_url(url));
    }

    #[test]
    fn different_urls_different_paths() {
        assert_ne!(
            cache_path_for_url("https://example.com/a.bin"),
            cache_path_for_url("https://example.com/b.bin")
        );
    }

    #[test]
    fn cache_path_under_atrg_dir() {
        let p = cache_path_for_url("https://example.com/dict.bin");

        // Check the immediate parent directory rather than a substring, which
        // would also match e.g. a home directory that happens to contain "atrg".
        assert_eq!(
            p.parent().and_then(|dir| dir.file_name()),
            Some(std::ffi::OsStr::new("atrg")),
            "expected file directly under an 'atrg' dir: {}",
            p.display()
        );

        let name = p
            .file_name()
            .and_then(|n| n.to_str())
            .expect("cache file name should be valid UTF-8");
        let hash = name
            .strip_prefix("jetstream-dict-")
            .and_then(|rest| rest.strip_suffix(".bin"))
            .unwrap_or_else(|| panic!("unexpected cache file name: {name}"));
        assert_eq!(hash.len(), 16, "expected 16 hex digits in: {name}");
        assert!(
            hash.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f')),
            "expected lowercase hex digits in: {name}"
        );
    }

    #[test]
    fn hex_encode_works() {
        assert_eq!(hex_encode(&[0xde, 0xad, 0xbe, 0xef]), "deadbeef");
        assert_eq!(hex_encode(&[0x00, 0xff]), "00ff");
        assert_eq!(hex_encode(&[]), "");
    }
}