// atrg_core/rate_limit.rs
1//! Token-bucket rate limiting middleware.
2//!
3//! Provides a per-IP token-bucket rate limiter that can be used as Axum
4//! middleware or checked manually in handlers. Disabled by default when
5//! [`RateLimitConfig::enabled`] is `false`.
6
7use std::collections::HashMap;
8use std::net::IpAddr;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12use axum::body::Body;
13use axum::http::StatusCode;
14use axum::response::IntoResponse;
15use axum::Json;
16use tokio::sync::Mutex;
17
18// ---------------------------------------------------------------------------
19// Config
20// ---------------------------------------------------------------------------
21
/// Configuration for the token-bucket rate limiter.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Sustained request rate: tokens credited to each bucket per second.
    pub requests_per_second: f64,
    /// Bucket capacity, i.e. the largest burst a single IP may send at once.
    pub burst: u32,
    /// Master switch; with `false` every request is allowed through.
    pub enabled: bool,
}

impl Default for RateLimitConfig {
    /// Defaults: 10 requests/second sustained, bursts up to 50, limiting on.
    fn default() -> Self {
        Self { requests_per_second: 10.0, burst: 50, enabled: true }
    }
}
42
43// ---------------------------------------------------------------------------
44// Token bucket (internal)
45// ---------------------------------------------------------------------------
46
/// Internal token-bucket state for one client.
///
/// Tokens accumulate continuously at `refill_rate` per second up to
/// `max_tokens`; each allowed request consumes exactly one token.
struct TokenBucket {
    /// Tokens currently available (fractional between refills).
    tokens: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
    /// Capacity of the bucket (burst size).
    max_tokens: f64,
    /// Tokens credited per second of elapsed time.
    refill_rate: f64,
}

impl TokenBucket {
    /// Create a bucket that starts full.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            last_refill: Instant::now(),
            max_tokens,
            refill_rate,
        }
    }

    /// Credit tokens for the time elapsed since the previous refill,
    /// capped at capacity, and advance the refill timestamp.
    fn refill(&mut self) {
        let now = Instant::now();
        let secs = now.duration_since(self.last_refill).as_secs_f64();
        let credited = self.tokens + secs * self.refill_rate;
        self.tokens = credited.min(self.max_tokens);
        self.last_refill = now;
    }

    /// Attempt to take one token; `true` means the request is allowed.
    fn try_consume(&mut self) -> bool {
        self.refill();
        let allowed = self.tokens >= 1.0;
        if allowed {
            self.tokens -= 1.0;
        }
        allowed
    }

    /// Seconds until a whole token is available (`0.0` if one already is).
    fn retry_after(&self) -> f64 {
        if self.tokens >= 1.0 {
            0.0
        } else {
            (1.0 - self.tokens) / self.refill_rate
        }
    }
}
92
93// ---------------------------------------------------------------------------
94// Rate limiter
95// ---------------------------------------------------------------------------
96
/// Per-IP token-bucket rate limiter.
///
/// Thread-safe and cheaply cloneable (inner state is `Arc<Mutex<_>>`).
#[derive(Clone)]
pub struct RateLimiter {
    // One bucket per client IP, created lazily on first request; shared
    // across clones so all handlers see the same counters.
    buckets: Arc<Mutex<HashMap<IpAddr, TokenBucket>>>,
    // Limits applied when creating each new bucket (rate, burst, enabled).
    config: RateLimitConfig,
}
105
106impl RateLimiter {
107    /// Create a new rate limiter with the given configuration.
108    pub fn new(config: RateLimitConfig) -> Self {
109        Self {
110            buckets: Arc::new(Mutex::new(HashMap::new())),
111            config,
112        }
113    }
114
115    /// Check whether a request from `ip` is allowed.
116    ///
117    /// Returns `Ok(())` if the request is within limits, or `Err(retry_after)`
118    /// with the number of seconds the client should wait before retrying.
119    pub async fn check(&self, ip: IpAddr) -> Result<(), f64> {
120        if !self.config.enabled {
121            return Ok(());
122        }
123
124        let mut buckets = self.buckets.lock().await;
125        let bucket = buckets.entry(ip).or_insert_with(|| {
126            TokenBucket::new(
127                f64::from(self.config.burst),
128                self.config.requests_per_second,
129            )
130        });
131
132        if bucket.try_consume() {
133            Ok(())
134        } else {
135            Err(bucket.retry_after())
136        }
137    }
138
139    /// Remove buckets that have not been used for longer than `max_age`.
140    ///
141    /// Call this periodically (e.g. every few minutes) to prevent unbounded
142    /// memory growth from unique IP addresses.
143    pub async fn cleanup(&self, max_age: Duration) {
144        let mut buckets = self.buckets.lock().await;
145        let cutoff = Instant::now() - max_age;
146        buckets.retain(|_ip, bucket| bucket.last_refill > cutoff);
147    }
148}
149
150// ---------------------------------------------------------------------------
151// HTTP response helper
152// ---------------------------------------------------------------------------
153
154/// Build a `429 Too Many Requests` response with AT-Protocol-style JSON body.
155pub fn rate_limit_response(retry_after_secs: f64) -> axum::response::Response<Body> {
156    let retry_after_ceil = retry_after_secs.ceil() as u64;
157
158    let body = serde_json::json!({
159        "error": "rate_limit_exceeded",
160        "message": format!(
161            "Rate limit exceeded. Retry after {} seconds.",
162            retry_after_ceil
163        ),
164    });
165
166    let mut response = (StatusCode::TOO_MANY_REQUESTS, Json(body)).into_response();
167    if let Ok(val) = axum::http::HeaderValue::from_str(&retry_after_ceil.to_string()) {
168        response.headers_mut().insert("Retry-After", val);
169    }
170    response
171}
172
173// ---------------------------------------------------------------------------
174// Tests
175// ---------------------------------------------------------------------------
176
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    #[test]
    fn token_bucket_allows_within_burst() {
        // Capacity 5 with a slow refill: exactly five consumes succeed.
        let mut tb = TokenBucket::new(5.0, 1.0);
        for _ in 0..5 {
            assert!(tb.try_consume(), "should allow requests within burst");
        }
        assert!(!tb.try_consume(), "should deny after burst exhausted");
    }

    #[test]
    fn token_bucket_retry_after_positive_when_empty() {
        let mut tb = TokenBucket::new(1.0, 10.0);
        assert!(tb.try_consume());
        assert!(!tb.try_consume());
        let retry = tb.retry_after();
        assert!(
            retry > 0.0,
            "retry_after should be positive when empty, got {}",
            retry
        );
        // At 10 tokens/sec, retry should be <= 0.1s
        assert!(
            retry <= 0.15,
            "retry_after should be small at high refill rate, got {}",
            retry
        );
    }

    #[tokio::test]
    async fn rate_limiter_allows_burst() {
        let rl = RateLimiter::new(RateLimitConfig {
            requests_per_second: 1.0,
            burst: 3,
            enabled: true,
        });
        let addr = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));

        // The full burst passes, then the very next request is rejected.
        for i in 0..3 {
            assert!(
                rl.check(addr).await.is_ok(),
                "request {} should be allowed within burst",
                i
            );
        }
        assert!(
            rl.check(addr).await.is_err(),
            "request beyond burst should be denied"
        );
    }

    #[tokio::test]
    async fn rate_limiter_disabled_allows_all() {
        let rl = RateLimiter::new(RateLimitConfig {
            requests_per_second: 1.0,
            burst: 1,
            enabled: false,
        });
        let addr = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

        // Even 100 requests should be fine when disabled.
        for _ in 0..100 {
            assert!(rl.check(addr).await.is_ok());
        }
    }

    #[tokio::test]
    async fn cleanup_removes_old_entries() {
        let rl = RateLimiter::new(RateLimitConfig {
            requests_per_second: 10.0,
            burst: 10,
            enabled: true,
        });
        let addr = IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1));

        // Touch the limiter once so a bucket exists for this IP.
        let _ = rl.check(addr).await;

        // A zero max_age makes every bucket stale, so all are purged.
        rl.cleanup(Duration::from_secs(0)).await;

        let map = rl.buckets.lock().await;
        assert!(
            map.is_empty(),
            "cleanup should have removed the stale entry"
        );
    }

    #[test]
    fn default_config_values() {
        let defaults = RateLimitConfig::default();
        assert!((defaults.requests_per_second - 10.0).abs() < f64::EPSILON);
        assert_eq!(defaults.burst, 50);
        assert!(defaults.enabled);
    }

    #[tokio::test]
    async fn rate_limit_response_returns_429() {
        let resp = rate_limit_response(1.5);
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);

        let retry_after = resp
            .headers()
            .get("retry-after")
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(retry_after, "2"); // ceil(1.5) = 2

        // The JSON body must carry the machine-readable error code.
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(json["error"], "rate_limit_exceeded");
    }
}