Skip to main content

atrg_stream/
consumer.rs

1//! Jetstream WebSocket consumer with bounded backpressure.
2//!
3//! The consumer connects to a Jetstream relay over WebSocket, reads events,
4//! and dispatches them through a bounded `mpsc` channel to a user-supplied
5//! handler. When the channel fills up, events are dropped and metrics are
6//! updated.
7
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10
11use futures::StreamExt;
12use tokio::sync::mpsc;
13use tokio_tungstenite::tungstenite::Message;
14
15use atrg_db::DbPool;
16
17use crate::backoff::Backoff;
18use crate::event::JetstreamEvent;
19use crate::metrics::MetricsCounter;
20use crate::EventHandler;
21use crate::StreamConfig;
22
23/// Spawn the Jetstream consumer as a pair of background tasks.
24///
25/// Returns a join handle for the reader task. The consumer architecture:
26///
27/// 1. **Reader task** — connects to the Jetstream WebSocket, deserializes
28///    incoming messages into [`JetstreamEvent`]s, and sends them into a
29///    bounded `mpsc` channel. Reconnects with exponential backoff on error.
30/// 2. **Dispatcher task** — reads events from the channel and invokes the
31///    user-supplied handler for each one.
32///
33/// Backpressure: when the channel is full, the reader drops events and
34/// increments the `events_dropped` metric counter.
35///
36/// The `state` parameter is an arbitrary `Clone + Send + 'static` value
37/// that is forwarded to the handler on every event. In a typical atrg app
38/// this is `AppState`, but the consumer itself does not depend on
39/// `atrg-core` to avoid a cyclic dependency.
40pub async fn spawn_consumer<S>(
41    config: &StreamConfig,
42    state: S,
43    handler: EventHandler<S>,
44) -> anyhow::Result<tokio::task::JoinHandle<()>>
45where
46    S: Clone + Send + Sync + 'static,
47{
48    let metrics = MetricsCounter::new();
49    let channel_capacity = config.channel_capacity;
50    let max_lag = config.max_lag_events;
51
52    // Build the WebSocket URL with collection filters.
53    let url = build_ws_url(&config.host, &config.collections, None);
54
55    tracing::info!(
56        url = %url,
57        channel_capacity = channel_capacity,
58        max_lag = max_lag,
59        "starting Jetstream consumer"
60    );
61
62    let (tx, rx) = mpsc::channel::<JetstreamEvent>(channel_capacity);
63
64    // Spawn the dispatcher task.
65    spawn_dispatcher(rx, handler, state, metrics.clone());
66
67    // Spawn the reader task.
68    let handle = spawn_reader(url, tx, metrics, max_lag);
69
70    Ok(handle)
71}
72
73/// Spawn a Jetstream consumer with cursor persistence.
74///
75/// Like [`spawn_consumer`], but loads the initial cursor from the database
76/// and periodically saves the latest processed event timestamp. This allows
77/// the consumer to resume from where it left off after a restart.
78///
79/// The `consumer_id` is a stable identifier for this consumer instance,
80/// used as the key in the cursor persistence table. Use a meaningful name
81/// like `"my-app-aggregator"` so multiple consumers can coexist.
82///
83/// Cursor behaviour is controlled by [`StreamConfig::cursor`]:
84/// - `None` or `"live"` — always start from now (no cursor in the URL).
85/// - `"auto"` — resume from the last stored cursor in the database.
86/// - A numeric string — use that value as the initial cursor timestamp.
87pub async fn spawn_consumer_with_cursor<S>(
88    config: &StreamConfig,
89    pool: &DbPool,
90    consumer_id: &str,
91    state: S,
92    handler: EventHandler<S>,
93) -> anyhow::Result<tokio::task::JoinHandle<()>>
94where
95    S: Clone + Send + Sync + 'static,
96{
97    // Ensure cursor table exists.
98    crate::cursor::ensure_cursor_table(pool).await?;
99
100    // Load stored cursor.
101    let stored_cursor = crate::cursor::load_cursor(pool, consumer_id).await?;
102
103    let initial_cursor = match config.cursor.as_deref() {
104        Some("live") | None => None,
105        Some("auto") => stored_cursor,
106        Some(numeric) => numeric.parse::<i64>().ok(),
107    };
108
109    if let Some(cursor) = initial_cursor {
110        tracing::info!(
111            cursor = cursor,
112            consumer_id = consumer_id,
113            "resuming Jetstream from stored cursor"
114        );
115    } else {
116        tracing::info!(
117            consumer_id = consumer_id,
118            "starting Jetstream from live (no cursor)"
119        );
120    }
121
122    // Build URL with cursor.
123    let url = build_ws_url(&config.host, &config.collections, initial_cursor);
124
125    let metrics = MetricsCounter::new();
126    let channel_capacity = config.channel_capacity;
127    let max_lag = config.max_lag_events;
128
129    tracing::info!(
130        url = %url,
131        channel_capacity = channel_capacity,
132        max_lag = max_lag,
133        "starting Jetstream consumer with cursor persistence"
134    );
135
136    let (tx, rx) = mpsc::channel::<JetstreamEvent>(channel_capacity);
137
138    // Spawn the cursor-persisting dispatcher task.
139    let pool_clone = pool.clone();
140    let cid = consumer_id.to_string();
141    spawn_cursor_dispatcher(rx, handler, state, metrics.clone(), pool_clone, cid);
142
143    // Spawn the reader task.
144    let handle = spawn_reader(url, tx, metrics, max_lag);
145
146    Ok(handle)
147}
148
/// Build the Jetstream WebSocket subscription URL.
///
/// * `host` — either a bare host (`"jetstream1.example.com"`, optionally with
///   a port), which is prefixed with `wss://`, or a host that already carries
///   a `ws://` / `wss://` scheme, which is used as given. The latter makes it
///   possible to target a non-TLS relay such as `ws://localhost:6008`.
///   Trailing slashes are ignored so the path never becomes `//subscribe`.
/// * `collections` — each entry becomes one `wantedCollections=` parameter,
///   in order. Values are inserted verbatim (no percent-encoding).
/// * `cursor` — if provided, appended last as `cursor=<value>` so that the
///   relay replays events from that timestamp onward.
///
/// Returns the URL without a `?` when there are no query parameters.
fn build_ws_url(host: &str, collections: &[String], cursor: Option<i64>) -> String {
    let host = host.trim_end_matches('/');

    // Respect an explicit WebSocket scheme; default to TLS otherwise.
    let base = if host.starts_with("ws://") || host.starts_with("wss://") {
        format!("{}/subscribe", host)
    } else {
        format!("wss://{}/subscribe", host)
    };

    // `Option` is a zero-or-one element iterator, so the cursor simply
    // chains onto the collection filters.
    let query = collections
        .iter()
        .map(|c| format!("wantedCollections={}", c))
        .chain(cursor.map(|cursor_us| format!("cursor={}", cursor_us)))
        .collect::<Vec<_>>()
        .join("&");

    if query.is_empty() {
        base
    } else {
        format!("{}?{}", base, query)
    }
}
169
170/// Spawn the dispatcher task that reads from the channel and calls the handler.
171fn spawn_dispatcher<S>(
172    mut rx: mpsc::Receiver<JetstreamEvent>,
173    handler: EventHandler<S>,
174    state: S,
175    metrics: Arc<MetricsCounter>,
176) where
177    S: Clone + Send + Sync + 'static,
178{
179    tokio::spawn(async move {
180        while let Some(event) = rx.recv().await {
181            if let Err(e) = handler(event, state.clone()).await {
182                tracing::error!(error = %e, "Jetstream event handler error");
183                metrics.errors.fetch_add(1, Ordering::Relaxed);
184            }
185        }
186        tracing::info!("Jetstream dispatcher task exiting");
187    });
188}
189
190/// Interval (in number of events) between cursor saves.
191///
192/// Saving on every event would be too expensive for high-throughput streams.
193/// 100 events provides a good balance between resume precision and I/O load.
194const CURSOR_SAVE_INTERVAL: u64 = 100;
195
196/// Dispatcher that persists the cursor every [`CURSOR_SAVE_INTERVAL`] events.
197///
198/// On graceful shutdown (channel closed), the final cursor is saved so that
199/// no more than one interval's worth of events needs to be replayed.
200fn spawn_cursor_dispatcher<S>(
201    mut rx: mpsc::Receiver<JetstreamEvent>,
202    handler: EventHandler<S>,
203    state: S,
204    metrics: Arc<MetricsCounter>,
205    pool: DbPool,
206    consumer_id: String,
207) where
208    S: Clone + Send + Sync + 'static,
209{
210    tokio::spawn(async move {
211        let mut event_count: u64 = 0;
212        let mut last_time_us: Option<i64> = None;
213
214        while let Some(event) = rx.recv().await {
215            let time_us = event.time_us;
216
217            if let Err(e) = handler(event, state.clone()).await {
218                tracing::error!(error = %e, "Jetstream event handler error");
219                metrics.errors.fetch_add(1, Ordering::Relaxed);
220            }
221
222            // Only track non-zero timestamps (zero means the field was absent).
223            if time_us > 0 {
224                last_time_us = Some(time_us);
225            }
226            event_count += 1;
227
228            // Periodically persist the cursor.
229            if event_count % CURSOR_SAVE_INTERVAL == 0 {
230                if let Some(cursor) = last_time_us {
231                    if let Err(e) = crate::cursor::save_cursor(&pool, &consumer_id, cursor).await {
232                        tracing::warn!(error = %e, "failed to save Jetstream cursor");
233                    }
234                }
235            }
236        }
237
238        // Save final cursor on shutdown.
239        if let Some(cursor) = last_time_us {
240            if let Err(e) = crate::cursor::save_cursor(&pool, &consumer_id, cursor).await {
241                tracing::warn!(error = %e, "failed to save final Jetstream cursor");
242            } else {
243                tracing::info!(cursor = cursor, "saved final Jetstream cursor on shutdown");
244            }
245        }
246
247        tracing::info!("Jetstream cursor dispatcher task exiting");
248    });
249}
250
251/// Spawn the reader task that connects to the WebSocket and feeds the channel.
252fn spawn_reader(
253    url: String,
254    tx: mpsc::Sender<JetstreamEvent>,
255    metrics: Arc<MetricsCounter>,
256    max_lag: usize,
257) -> tokio::task::JoinHandle<()> {
258    tokio::spawn(async move {
259        let mut backoff = Backoff::new();
260
261        loop {
262            match connect_and_read(&url, &tx, &metrics, max_lag).await {
263                Ok(()) => {
264                    tracing::info!("Jetstream WebSocket closed cleanly");
265                }
266                Err(e) => {
267                    metrics.reconnects.fetch_add(1, Ordering::Relaxed);
268                    tracing::warn!(error = %e, "Jetstream connection error, will reconnect");
269                }
270            }
271
272            let delay = backoff.next_delay();
273            metrics
274                .current_backoff_ms
275                .store(delay.as_millis() as u64, Ordering::Relaxed);
276            tracing::info!(delay_ms = %delay.as_millis(), "reconnecting to Jetstream");
277            tokio::time::sleep(delay).await;
278        }
279    })
280}
281
282/// Connect to the WebSocket and read events until the connection drops.
283///
284/// On a successful connection the backoff counter is reset (via metrics).
285/// Returns `Ok(())` on a clean close, or an error on disconnect/failure.
286async fn connect_and_read(
287    url: &str,
288    tx: &mpsc::Sender<JetstreamEvent>,
289    metrics: &Arc<MetricsCounter>,
290    max_lag: usize,
291) -> anyhow::Result<()> {
292    let (ws_stream, _response) = tokio_tungstenite::connect_async(url).await?;
293    tracing::info!(url = %url, "connected to Jetstream");
294
295    // Reset backoff on successful connection.
296    metrics.current_backoff_ms.store(0, Ordering::Relaxed);
297
298    let (_write, mut read) = ws_stream.split();
299
300    while let Some(msg_result) = read.next().await {
301        let msg = msg_result?;
302        match msg {
303            Message::Text(text) => {
304                handle_text_message(&text, tx, metrics, max_lag);
305            }
306            Message::Close(_) => {
307                tracing::info!("Jetstream WebSocket closed by server");
308                break;
309            }
310            // Ping/Pong are handled automatically by tungstenite.
311            // Binary frames are not expected from Jetstream.
312            _ => {}
313        }
314    }
315
316    Ok(())
317}
318
319/// Parse and dispatch a single text message from the WebSocket.
320fn handle_text_message(
321    text: &str,
322    tx: &mpsc::Sender<JetstreamEvent>,
323    metrics: &Arc<MetricsCounter>,
324    max_lag: usize,
325) {
326    metrics.events_received.fetch_add(1, Ordering::Relaxed);
327    update_last_event_timestamp(metrics);
328
329    let event = match serde_json::from_str::<JetstreamEvent>(text) {
330        Ok(ev) => ev,
331        Err(e) => {
332            tracing::debug!(error = %e, "failed to parse Jetstream event");
333            metrics.errors.fetch_add(1, Ordering::Relaxed);
334            return;
335        }
336    };
337
338    // Lag detection: if remaining capacity is zero and the channel is large
339    // enough that we've hit the lag threshold, drop the event.
340    let remaining = tx.capacity();
341    if remaining == 0 {
342        metrics.events_dropped.fetch_add(1, Ordering::Relaxed);
343        if tx.max_capacity() >= max_lag {
344            tracing::warn!(
345                max_lag = max_lag,
346                "Jetstream consumer lagging beyond threshold, dropping event"
347            );
348        }
349        return;
350    }
351
352    // Try non-blocking send. If it fails (shouldn't after the capacity check,
353    // but races are possible), drop the event.
354    if tx.try_send(event).is_err() {
355        metrics.events_dropped.fetch_add(1, Ordering::Relaxed);
356        tracing::debug!("Jetstream channel full on try_send, dropping event");
357    }
358}
359
360/// Record the current wall-clock time as the last-event timestamp.
361fn update_last_event_timestamp(metrics: &Arc<MetricsCounter>) {
362    let now_ms = std::time::SystemTime::now()
363        .duration_since(std::time::UNIX_EPOCH)
364        .unwrap_or_default()
365        .as_millis() as u64;
366    metrics.last_event_at.store(now_ms, Ordering::Relaxed);
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Relay host shared by all URL tests.
    const HOST: &str = "jetstream1.example.com";

    /// Cursor value shared by the cursor tests.
    const CURSOR_US: i64 = 1_700_000_000_000_000;

    /// Turn string literals into the owned collection list `build_ws_url` takes.
    fn collections(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| String::from(*name)).collect()
    }

    // No filters and no cursor: the URL has no query string at all.
    #[test]
    fn build_ws_url_no_collections() {
        let expected = "wss://jetstream1.example.com/subscribe";
        assert_eq!(build_ws_url(HOST, &[], None), expected);
    }

    // One filter becomes one `wantedCollections` parameter.
    #[test]
    fn build_ws_url_single_collection() {
        let wanted = collections(&["app.bsky.feed.post"]);
        let expected =
            "wss://jetstream1.example.com/subscribe?wantedCollections=app.bsky.feed.post";
        assert_eq!(build_ws_url(HOST, &wanted, None), expected);
    }

    // Several filters repeat the parameter, in input order, joined by `&`.
    #[test]
    fn build_ws_url_multiple_collections() {
        let wanted = collections(&["app.bsky.feed.post", "app.bsky.feed.like"]);
        let expected = "wss://jetstream1.example.com/subscribe?wantedCollections=app.bsky.feed.post&wantedCollections=app.bsky.feed.like";
        assert_eq!(build_ws_url(HOST, &wanted, None), expected);
    }

    // A cursor alone is the only query parameter.
    #[test]
    fn build_ws_url_with_cursor_no_collections() {
        let expected = "wss://jetstream1.example.com/subscribe?cursor=1700000000000000";
        assert_eq!(build_ws_url(HOST, &[], Some(CURSOR_US)), expected);
    }

    // With filters present, the cursor comes after them.
    #[test]
    fn build_ws_url_with_cursor_and_collections() {
        let wanted = collections(&["app.bsky.feed.post"]);
        let expected = "wss://jetstream1.example.com/subscribe?wantedCollections=app.bsky.feed.post&cursor=1700000000000000";
        assert_eq!(build_ws_url(HOST, &wanted, Some(CURSOR_US)), expected);
    }
}