Skip to main content

atrg_stream/
consumer.rs

//! Jetstream WebSocket consumer with bounded backpressure.
//!
//! The consumer connects to a Jetstream relay over WebSocket, reads events,
//! and dispatches them through a bounded `mpsc` channel to a user-supplied
//! handler. When the channel fills up, incoming events are dropped and the
//! drop metrics are updated.
7
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10
11use futures::StreamExt;
12use tokio::sync::mpsc;
13use tokio_tungstenite::tungstenite::Message;
14
15use crate::backoff::Backoff;
16use crate::event::JetstreamEvent;
17use crate::metrics::MetricsCounter;
18use crate::EventHandler;
19use crate::StreamConfig;
20
21/// Spawn the Jetstream consumer as a pair of background tasks.
22///
23/// Returns a join handle for the reader task. The consumer architecture:
24///
25/// 1. **Reader task** — connects to the Jetstream WebSocket, deserializes
26///    incoming messages into [`JetstreamEvent`]s, and sends them into a
27///    bounded `mpsc` channel. Reconnects with exponential backoff on error.
28/// 2. **Dispatcher task** — reads events from the channel and invokes the
29///    user-supplied handler for each one.
30///
31/// Backpressure: when the channel is full, the reader drops events and
32/// increments the `events_dropped` metric counter.
33///
34/// The `state` parameter is an arbitrary `Clone + Send + 'static` value
35/// that is forwarded to the handler on every event. In a typical atrg app
36/// this is `AppState`, but the consumer itself does not depend on
37/// `atrg-core` to avoid a cyclic dependency.
38pub async fn spawn_consumer<S>(
39    config: &StreamConfig,
40    state: S,
41    handler: EventHandler<S>,
42) -> anyhow::Result<tokio::task::JoinHandle<()>>
43where
44    S: Clone + Send + Sync + 'static,
45{
46    let metrics = MetricsCounter::new();
47    let channel_capacity = config.channel_capacity;
48    let max_lag = config.max_lag_events;
49
50    // Build the WebSocket URL with collection filters.
51    let url = build_ws_url(&config.host, &config.collections);
52
53    tracing::info!(
54        url = %url,
55        channel_capacity = channel_capacity,
56        max_lag = max_lag,
57        "starting Jetstream consumer"
58    );
59
60    let (tx, rx) = mpsc::channel::<JetstreamEvent>(channel_capacity);
61
62    // Spawn the dispatcher task.
63    spawn_dispatcher(rx, handler, state, metrics.clone());
64
65    // Spawn the reader task.
66    let handle = spawn_reader(url, tx, metrics, max_lag);
67
68    Ok(handle)
69}
70
/// Build the Jetstream WebSocket subscription URL for `host`.
///
/// Each entry of `collections` becomes one `wantedCollections=` query
/// parameter, in order. With no collections the bare `/subscribe` endpoint
/// is returned (no query string at all).
fn build_ws_url(host: &str, collections: &[String]) -> String {
    let mut url = format!("wss://{}/subscribe", host);
    for (position, collection) in collections.iter().enumerate() {
        // First parameter opens the query string, the rest are joined by '&'.
        url.push(if position == 0 { '?' } else { '&' });
        url.push_str("wantedCollections=");
        url.push_str(collection);
    }
    url
}
84
85/// Spawn the dispatcher task that reads from the channel and calls the handler.
86fn spawn_dispatcher<S>(
87    mut rx: mpsc::Receiver<JetstreamEvent>,
88    handler: EventHandler<S>,
89    state: S,
90    metrics: Arc<MetricsCounter>,
91) where
92    S: Clone + Send + Sync + 'static,
93{
94    tokio::spawn(async move {
95        while let Some(event) = rx.recv().await {
96            if let Err(e) = handler(event, state.clone()).await {
97                tracing::error!(error = %e, "Jetstream event handler error");
98                metrics.errors.fetch_add(1, Ordering::Relaxed);
99            }
100        }
101        tracing::info!("Jetstream dispatcher task exiting");
102    });
103}
104
105/// Spawn the reader task that connects to the WebSocket and feeds the channel.
106fn spawn_reader(
107    url: String,
108    tx: mpsc::Sender<JetstreamEvent>,
109    metrics: Arc<MetricsCounter>,
110    max_lag: usize,
111) -> tokio::task::JoinHandle<()> {
112    tokio::spawn(async move {
113        let mut backoff = Backoff::new();
114
115        loop {
116            match connect_and_read(&url, &tx, &metrics, max_lag).await {
117                Ok(()) => {
118                    tracing::info!("Jetstream WebSocket closed cleanly");
119                }
120                Err(e) => {
121                    metrics.reconnects.fetch_add(1, Ordering::Relaxed);
122                    tracing::warn!(error = %e, "Jetstream connection error, will reconnect");
123                }
124            }
125
126            let delay = backoff.next_delay();
127            metrics
128                .current_backoff_ms
129                .store(delay.as_millis() as u64, Ordering::Relaxed);
130            tracing::info!(delay_ms = %delay.as_millis(), "reconnecting to Jetstream");
131            tokio::time::sleep(delay).await;
132        }
133    })
134}
135
136/// Connect to the WebSocket and read events until the connection drops.
137///
138/// On a successful connection the backoff counter is reset (via metrics).
139/// Returns `Ok(())` on a clean close, or an error on disconnect/failure.
140async fn connect_and_read(
141    url: &str,
142    tx: &mpsc::Sender<JetstreamEvent>,
143    metrics: &Arc<MetricsCounter>,
144    max_lag: usize,
145) -> anyhow::Result<()> {
146    let (ws_stream, _response) = tokio_tungstenite::connect_async(url).await?;
147    tracing::info!(url = %url, "connected to Jetstream");
148
149    // Reset backoff on successful connection.
150    metrics.current_backoff_ms.store(0, Ordering::Relaxed);
151
152    let (_write, mut read) = ws_stream.split();
153
154    while let Some(msg_result) = read.next().await {
155        let msg = msg_result?;
156        match msg {
157            Message::Text(text) => {
158                handle_text_message(&text, tx, metrics, max_lag);
159            }
160            Message::Close(_) => {
161                tracing::info!("Jetstream WebSocket closed by server");
162                break;
163            }
164            // Ping/Pong are handled automatically by tungstenite.
165            // Binary frames are not expected from Jetstream.
166            _ => {}
167        }
168    }
169
170    Ok(())
171}
172
173/// Parse and dispatch a single text message from the WebSocket.
174fn handle_text_message(
175    text: &str,
176    tx: &mpsc::Sender<JetstreamEvent>,
177    metrics: &Arc<MetricsCounter>,
178    max_lag: usize,
179) {
180    metrics.events_received.fetch_add(1, Ordering::Relaxed);
181    update_last_event_timestamp(metrics);
182
183    let event = match serde_json::from_str::<JetstreamEvent>(text) {
184        Ok(ev) => ev,
185        Err(e) => {
186            tracing::debug!(error = %e, "failed to parse Jetstream event");
187            metrics.errors.fetch_add(1, Ordering::Relaxed);
188            return;
189        }
190    };
191
192    // Lag detection: if remaining capacity is zero and the channel is large
193    // enough that we've hit the lag threshold, drop the event.
194    let remaining = tx.capacity();
195    if remaining == 0 {
196        metrics.events_dropped.fetch_add(1, Ordering::Relaxed);
197        if tx.max_capacity() >= max_lag {
198            tracing::warn!(
199                max_lag = max_lag,
200                "Jetstream consumer lagging beyond threshold, dropping event"
201            );
202        }
203        return;
204    }
205
206    // Try non-blocking send. If it fails (shouldn't after the capacity check,
207    // but races are possible), drop the event.
208    if tx.try_send(event).is_err() {
209        metrics.events_dropped.fetch_add(1, Ordering::Relaxed);
210        tracing::debug!("Jetstream channel full on try_send, dropping event");
211    }
212}
213
214/// Record the current wall-clock time as the last-event timestamp.
215fn update_last_event_timestamp(metrics: &Arc<MetricsCounter>) {
216    let now_ms = std::time::SystemTime::now()
217        .duration_since(std::time::UNIX_EPOCH)
218        .unwrap_or_default()
219        .as_millis() as u64;
220    metrics.last_event_at.store(now_ms, Ordering::Relaxed);
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Relay host shared by every URL-building test.
    const HOST: &str = "jetstream1.example.com";

    /// Turn string literals into the owned `String`s `build_ws_url` expects.
    fn owned(collections: &[&str]) -> Vec<String> {
        collections.iter().map(|c| c.to_string()).collect()
    }

    #[test]
    fn build_ws_url_no_collections() {
        let expected = "wss://jetstream1.example.com/subscribe";
        assert_eq!(build_ws_url(HOST, &[]), expected);
    }

    #[test]
    fn build_ws_url_single_collection() {
        let expected =
            "wss://jetstream1.example.com/subscribe?wantedCollections=app.bsky.feed.post";
        assert_eq!(build_ws_url(HOST, &owned(&["app.bsky.feed.post"])), expected);
    }

    #[test]
    fn build_ws_url_multiple_collections() {
        let collections = owned(&["app.bsky.feed.post", "app.bsky.feed.like"]);
        let expected = "wss://jetstream1.example.com/subscribe?wantedCollections=app.bsky.feed.post&wantedCollections=app.bsky.feed.like";
        assert_eq!(build_ws_url(HOST, &collections), expected);
    }
}