1use std::collections::HashMap;
8use std::net::IpAddr;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12use axum::body::Body;
13use axum::http::StatusCode;
14use axum::response::IntoResponse;
15use axum::Json;
16use tokio::sync::Mutex;
17
/// Tunables for the per-IP rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained rate: tokens returned to each client's bucket per second.
    pub requests_per_second: f64,
    /// Bucket capacity: how many requests a client may issue back-to-back.
    pub burst: u32,
    /// Master switch; when `false` every request is let through.
    pub enabled: bool,
}

impl Default for RateLimitConfig {
    /// 10 requests/second sustained, bursts of up to 50, limiting switched on.
    fn default() -> Self {
        let (requests_per_second, burst, enabled) = (10.0, 50, true);
        Self {
            requests_per_second,
            burst,
            enabled,
        }
    }
}
42
/// A single token bucket: holds at most `max_tokens` tokens and regains them
/// continuously at `refill_rate` tokens per second.
struct TokenBucket {
    /// Tokens currently available; fractional amounts accumulate between calls.
    tokens: f64,
    /// When `tokens` was last brought up to date by `refill`.
    last_refill: Instant,
    /// Capacity of the bucket (the permitted burst size).
    max_tokens: f64,
    /// Tokens regained per second; never negative.
    refill_rate: f64,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    ///
    /// Negative or NaN arguments are clamped to `0.0`. Without this, a negative
    /// refill rate would drain the bucket over time and make `retry_after`
    /// report a negative wait.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        // `f64::max` returns the non-NaN operand, so NaN is clamped to 0.0 too.
        let max_tokens = max_tokens.max(0.0);
        Self {
            tokens: max_tokens,
            last_refill: Instant::now(),
            max_tokens,
            refill_rate: refill_rate.max(0.0),
        }
    }

    /// Credits the tokens earned since the previous refill, capped at `max_tokens`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_refill = now;
    }

    /// Refills, then takes one token if available.
    ///
    /// Returns `true` when the request is allowed and `false` when the bucket
    /// holds less than one whole token.
    fn try_consume(&mut self) -> bool {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Seconds until at least one whole token will be available.
    ///
    /// Returns `0.0` if a token is available now. Returns `f64::INFINITY` when
    /// the bucket can never hold a whole token: a capacity below one, or a zero
    /// refill rate. Callers therefore never receive a finite wait that would
    /// still be refused. The value is based on the token count as of the last
    /// `refill`.
    fn retry_after(&self) -> f64 {
        if self.tokens >= 1.0 {
            return 0.0;
        }
        if self.max_tokens < 1.0 || self.refill_rate <= 0.0 {
            return f64::INFINITY;
        }
        let deficit = 1.0 - self.tokens;
        deficit / self.refill_rate
    }
}
92
93#[derive(Clone)]
101pub struct RateLimiter {
102 buckets: Arc<Mutex<HashMap<IpAddr, TokenBucket>>>,
103 config: RateLimitConfig,
104}
105
106impl RateLimiter {
107 pub fn new(config: RateLimitConfig) -> Self {
109 Self {
110 buckets: Arc::new(Mutex::new(HashMap::new())),
111 config,
112 }
113 }
114
115 pub async fn check(&self, ip: IpAddr) -> Result<(), f64> {
120 if !self.config.enabled {
121 return Ok(());
122 }
123
124 let mut buckets = self.buckets.lock().await;
125 let bucket = buckets.entry(ip).or_insert_with(|| {
126 TokenBucket::new(
127 f64::from(self.config.burst),
128 self.config.requests_per_second,
129 )
130 });
131
132 if bucket.try_consume() {
133 Ok(())
134 } else {
135 Err(bucket.retry_after())
136 }
137 }
138
139 pub async fn cleanup(&self, max_age: Duration) {
144 let mut buckets = self.buckets.lock().await;
145 let cutoff = Instant::now() - max_age;
146 buckets.retain(|_ip, bucket| bucket.last_refill > cutoff);
147 }
148}
149
150pub fn rate_limit_response(retry_after_secs: f64) -> axum::response::Response<Body> {
156 let retry_after_ceil = retry_after_secs.ceil() as u64;
157
158 let body = serde_json::json!({
159 "error": "rate_limit_exceeded",
160 "message": format!(
161 "Rate limit exceeded. Retry after {} seconds.",
162 retry_after_ceil
163 ),
164 });
165
166 let mut response = (StatusCode::TOO_MANY_REQUESTS, Json(body)).into_response();
167 if let Ok(val) = axum::http::HeaderValue::from_str(&retry_after_ceil.to_string()) {
168 response.headers_mut().insert("Retry-After", val);
169 }
170 response
171}
172
#[cfg(test)]
mod tests {
    //! Unit tests for the token bucket, the per-IP limiter and the 429 response.

    use super::*;
    use std::net::Ipv4Addr;

    #[test]
    fn token_bucket_allows_within_burst() {
        let mut bucket = TokenBucket::new(5.0, 1.0);
        for _ in 0..5 {
            assert!(bucket.try_consume(), "should allow requests within burst");
        }
        assert!(!bucket.try_consume(), "should deny after burst exhausted");
    }

    #[test]
    fn token_bucket_retry_after_positive_when_empty() {
        let mut bucket = TokenBucket::new(1.0, 10.0);
        assert!(bucket.try_consume());
        assert!(!bucket.try_consume());
        let retry = bucket.retry_after();
        assert!(
            retry > 0.0,
            "retry_after should be positive when empty, got {}",
            retry
        );
        // At 10 tokens/s a whole token takes at most 0.1 s; the extra 0.05 s
        // is slack for time elapsed between the calls above.
        assert!(
            retry <= 0.15,
            "retry_after should be small at high refill rate, got {}",
            retry
        );
    }

    #[tokio::test]
    async fn rate_limiter_allows_burst() {
        let config = RateLimitConfig {
            requests_per_second: 1.0,
            burst: 3,
            enabled: true,
        };
        let limiter = RateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));

        for i in 0..3 {
            assert!(
                limiter.check(ip).await.is_ok(),
                "request {} should be allowed within burst",
                i
            );
        }
        assert!(
            limiter.check(ip).await.is_err(),
            "request beyond burst should be denied"
        );
    }

    #[tokio::test]
    async fn rate_limiter_disabled_allows_all() {
        let config = RateLimitConfig {
            requests_per_second: 1.0,
            burst: 1,
            enabled: false,
        };
        let limiter = RateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

        // Far more requests than `burst` would permit if limiting were on.
        for _ in 0..100 {
            assert!(limiter.check(ip).await.is_ok());
        }
    }

    #[tokio::test]
    async fn cleanup_removes_old_entries() {
        let config = RateLimitConfig {
            requests_per_second: 10.0,
            burst: 10,
            enabled: true,
        };
        let limiter = RateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1));

        // Creates the bucket for `ip`.
        let _ = limiter.check(ip).await;

        // A zero max age treats every existing bucket as stale.
        limiter.cleanup(Duration::from_secs(0)).await;

        let buckets = limiter.buckets.lock().await;
        assert!(
            buckets.is_empty(),
            "cleanup should have removed the stale entry"
        );
    }

    #[test]
    fn default_config_values() {
        let cfg = RateLimitConfig::default();
        assert!((cfg.requests_per_second - 10.0).abs() < f64::EPSILON);
        assert_eq!(cfg.burst, 50);
        assert!(cfg.enabled);
    }

    #[tokio::test]
    async fn rate_limit_response_returns_429() {
        let response = rate_limit_response(1.5);
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);

        // 1.5 s is rounded up to whole seconds.
        let retry_after = response
            .headers()
            .get("retry-after")
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(retry_after, "2");

        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["error"], "rate_limit_exceeded");
    }
}