1use std::sync::atomic::Ordering;
9use std::sync::Arc;
10
11use futures::StreamExt;
12use tokio::sync::mpsc;
13use tokio_tungstenite::tungstenite::Message;
14
15use atrg_db::DbPool;
16
17use crate::backoff::Backoff;
18use crate::event::JetstreamEvent;
19use crate::metrics::MetricsCounter;
20use crate::EventHandler;
21use crate::StreamConfig;
22
23pub async fn spawn_consumer<S>(
41 config: &StreamConfig,
42 state: S,
43 handler: EventHandler<S>,
44) -> anyhow::Result<tokio::task::JoinHandle<()>>
45where
46 S: Clone + Send + Sync + 'static,
47{
48 let metrics = MetricsCounter::new();
49 let channel_capacity = config.channel_capacity;
50 let max_lag = config.max_lag_events;
51
52 let url = build_ws_url(&config.host, &config.collections, None);
54
55 tracing::info!(
56 url = %url,
57 channel_capacity = channel_capacity,
58 max_lag = max_lag,
59 "starting Jetstream consumer"
60 );
61
62 let (tx, rx) = mpsc::channel::<JetstreamEvent>(channel_capacity);
63
64 spawn_dispatcher(rx, handler, state, metrics.clone());
66
67 let handle = spawn_reader(url, tx, metrics, max_lag);
69
70 Ok(handle)
71}
72
73pub async fn spawn_consumer_with_cursor<S>(
88 config: &StreamConfig,
89 pool: &DbPool,
90 consumer_id: &str,
91 state: S,
92 handler: EventHandler<S>,
93) -> anyhow::Result<tokio::task::JoinHandle<()>>
94where
95 S: Clone + Send + Sync + 'static,
96{
97 crate::cursor::ensure_cursor_table(pool).await?;
99
100 let stored_cursor = crate::cursor::load_cursor(pool, consumer_id).await?;
102
103 let initial_cursor = match config.cursor.as_deref() {
104 Some("live") | None => None,
105 Some("auto") => stored_cursor,
106 Some(numeric) => numeric.parse::<i64>().ok(),
107 };
108
109 if let Some(cursor) = initial_cursor {
110 tracing::info!(
111 cursor = cursor,
112 consumer_id = consumer_id,
113 "resuming Jetstream from stored cursor"
114 );
115 } else {
116 tracing::info!(
117 consumer_id = consumer_id,
118 "starting Jetstream from live (no cursor)"
119 );
120 }
121
122 let url = build_ws_url(&config.host, &config.collections, initial_cursor);
124
125 let metrics = MetricsCounter::new();
126 let channel_capacity = config.channel_capacity;
127 let max_lag = config.max_lag_events;
128
129 tracing::info!(
130 url = %url,
131 channel_capacity = channel_capacity,
132 max_lag = max_lag,
133 "starting Jetstream consumer with cursor persistence"
134 );
135
136 let (tx, rx) = mpsc::channel::<JetstreamEvent>(channel_capacity);
137
138 let pool_clone = pool.clone();
140 let cid = consumer_id.to_string();
141 spawn_cursor_dispatcher(rx, handler, state, metrics.clone(), pool_clone, cid);
142
143 let handle = spawn_reader(url, tx, metrics, max_lag);
145
146 Ok(handle)
147}
148
/// Builds the Jetstream `subscribe` WebSocket URL for `host`.
///
/// Emits one `wantedCollections=<collection>` query parameter per entry of
/// `collections`, in order. If `cursor` is `Some`, a final `cursor=<value>`
/// parameter follows. Values are inserted verbatim, without percent-encoding.
/// With no parameters at all, the URL carries no `?`.
fn build_ws_url(host: &str, collections: &[String], cursor: Option<i64>) -> String {
    let cursor_param = cursor.map(|value| format!("cursor={}", value));

    // Every parameter is non-empty, so the joined query is empty only when
    // there are no parameters at all.
    let query = collections
        .iter()
        .map(|collection| format!("wantedCollections={}", collection))
        .chain(cursor_param)
        .collect::<Vec<String>>()
        .join("&");

    let mut url = format!("wss://{}/subscribe", host);
    if !query.is_empty() {
        url.push('?');
        url.push_str(&query);
    }
    url
}
169
170fn spawn_dispatcher<S>(
172 mut rx: mpsc::Receiver<JetstreamEvent>,
173 handler: EventHandler<S>,
174 state: S,
175 metrics: Arc<MetricsCounter>,
176) where
177 S: Clone + Send + Sync + 'static,
178{
179 tokio::spawn(async move {
180 while let Some(event) = rx.recv().await {
181 if let Err(e) = handler(event, state.clone()).await {
182 tracing::error!(error = %e, "Jetstream event handler error");
183 metrics.errors.fetch_add(1, Ordering::Relaxed);
184 }
185 }
186 tracing::info!("Jetstream dispatcher task exiting");
187 });
188}
189
190const CURSOR_SAVE_INTERVAL: u64 = 100;
195
196fn spawn_cursor_dispatcher<S>(
201 mut rx: mpsc::Receiver<JetstreamEvent>,
202 handler: EventHandler<S>,
203 state: S,
204 metrics: Arc<MetricsCounter>,
205 pool: DbPool,
206 consumer_id: String,
207) where
208 S: Clone + Send + Sync + 'static,
209{
210 tokio::spawn(async move {
211 let mut event_count: u64 = 0;
212 let mut last_time_us: Option<i64> = None;
213
214 while let Some(event) = rx.recv().await {
215 let time_us = event.time_us;
216
217 if let Err(e) = handler(event, state.clone()).await {
218 tracing::error!(error = %e, "Jetstream event handler error");
219 metrics.errors.fetch_add(1, Ordering::Relaxed);
220 }
221
222 if time_us > 0 {
224 last_time_us = Some(time_us);
225 }
226 event_count += 1;
227
228 if event_count % CURSOR_SAVE_INTERVAL == 0 {
230 if let Some(cursor) = last_time_us {
231 if let Err(e) = crate::cursor::save_cursor(&pool, &consumer_id, cursor).await {
232 tracing::warn!(error = %e, "failed to save Jetstream cursor");
233 }
234 }
235 }
236 }
237
238 if let Some(cursor) = last_time_us {
240 if let Err(e) = crate::cursor::save_cursor(&pool, &consumer_id, cursor).await {
241 tracing::warn!(error = %e, "failed to save final Jetstream cursor");
242 } else {
243 tracing::info!(cursor = cursor, "saved final Jetstream cursor on shutdown");
244 }
245 }
246
247 tracing::info!("Jetstream cursor dispatcher task exiting");
248 });
249}
250
251fn spawn_reader(
253 url: String,
254 tx: mpsc::Sender<JetstreamEvent>,
255 metrics: Arc<MetricsCounter>,
256 max_lag: usize,
257) -> tokio::task::JoinHandle<()> {
258 tokio::spawn(async move {
259 let mut backoff = Backoff::new();
260
261 loop {
262 match connect_and_read(&url, &tx, &metrics, max_lag).await {
263 Ok(()) => {
264 tracing::info!("Jetstream WebSocket closed cleanly");
265 }
266 Err(e) => {
267 metrics.reconnects.fetch_add(1, Ordering::Relaxed);
268 tracing::warn!(error = %e, "Jetstream connection error, will reconnect");
269 }
270 }
271
272 let delay = backoff.next_delay();
273 metrics
274 .current_backoff_ms
275 .store(delay.as_millis() as u64, Ordering::Relaxed);
276 tracing::info!(delay_ms = %delay.as_millis(), "reconnecting to Jetstream");
277 tokio::time::sleep(delay).await;
278 }
279 })
280}
281
282async fn connect_and_read(
287 url: &str,
288 tx: &mpsc::Sender<JetstreamEvent>,
289 metrics: &Arc<MetricsCounter>,
290 max_lag: usize,
291) -> anyhow::Result<()> {
292 let (ws_stream, _response) = tokio_tungstenite::connect_async(url).await?;
293 tracing::info!(url = %url, "connected to Jetstream");
294
295 metrics.current_backoff_ms.store(0, Ordering::Relaxed);
297
298 let (_write, mut read) = ws_stream.split();
299
300 while let Some(msg_result) = read.next().await {
301 let msg = msg_result?;
302 match msg {
303 Message::Text(text) => {
304 handle_text_message(&text, tx, metrics, max_lag);
305 }
306 Message::Close(_) => {
307 tracing::info!("Jetstream WebSocket closed by server");
308 break;
309 }
310 _ => {}
313 }
314 }
315
316 Ok(())
317}
318
319fn handle_text_message(
321 text: &str,
322 tx: &mpsc::Sender<JetstreamEvent>,
323 metrics: &Arc<MetricsCounter>,
324 max_lag: usize,
325) {
326 metrics.events_received.fetch_add(1, Ordering::Relaxed);
327 update_last_event_timestamp(metrics);
328
329 let event = match serde_json::from_str::<JetstreamEvent>(text) {
330 Ok(ev) => ev,
331 Err(e) => {
332 tracing::debug!(error = %e, "failed to parse Jetstream event");
333 metrics.errors.fetch_add(1, Ordering::Relaxed);
334 return;
335 }
336 };
337
338 let remaining = tx.capacity();
341 if remaining == 0 {
342 metrics.events_dropped.fetch_add(1, Ordering::Relaxed);
343 if tx.max_capacity() >= max_lag {
344 tracing::warn!(
345 max_lag = max_lag,
346 "Jetstream consumer lagging beyond threshold, dropping event"
347 );
348 }
349 return;
350 }
351
352 if tx.try_send(event).is_err() {
355 metrics.events_dropped.fetch_add(1, Ordering::Relaxed);
356 tracing::debug!("Jetstream channel full on try_send, dropping event");
357 }
358}
359
360fn update_last_event_timestamp(metrics: &Arc<MetricsCounter>) {
362 let now_ms = std::time::SystemTime::now()
363 .duration_since(std::time::UNIX_EPOCH)
364 .unwrap_or_default()
365 .as_millis() as u64;
366 metrics.last_event_at.store(now_ms, Ordering::Relaxed);
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Host shared by every URL-building test.
    const HOST: &str = "jetstream1.example.com";

    /// Turns borrowed collection names into the owned `String`s `build_ws_url` takes.
    fn owned(collections: &[&str]) -> Vec<String> {
        collections.iter().map(|c| c.to_string()).collect()
    }

    #[test]
    fn build_ws_url_no_collections() {
        assert_eq!(
            build_ws_url(HOST, &[], None),
            "wss://jetstream1.example.com/subscribe"
        );
    }

    #[test]
    fn build_ws_url_single_collection() {
        let collections = owned(&["app.bsky.feed.post"]);
        assert_eq!(
            build_ws_url(HOST, &collections, None),
            "wss://jetstream1.example.com/subscribe?wantedCollections=app.bsky.feed.post"
        );
    }

    #[test]
    fn build_ws_url_multiple_collections() {
        let collections = owned(&["app.bsky.feed.post", "app.bsky.feed.like"]);
        assert_eq!(
            build_ws_url(HOST, &collections, None),
            "wss://jetstream1.example.com/subscribe?wantedCollections=app.bsky.feed.post&wantedCollections=app.bsky.feed.like"
        );
    }

    #[test]
    fn build_ws_url_with_cursor_no_collections() {
        assert_eq!(
            build_ws_url(HOST, &[], Some(1700000000000000)),
            "wss://jetstream1.example.com/subscribe?cursor=1700000000000000"
        );
    }

    #[test]
    fn build_ws_url_with_cursor_and_collections() {
        let collections = owned(&["app.bsky.feed.post"]);
        assert_eq!(
            build_ws_url(HOST, &collections, Some(1700000000000000)),
            "wss://jetstream1.example.com/subscribe?wantedCollections=app.bsky.feed.post&cursor=1700000000000000"
        );
    }
}