Skip to main content

atrg_label/
store.rs

1//! SQLite-backed label storage.
2
3use sqlx::SqlitePool;
4
5use crate::types::{Label, SignedLabel};
6
7/// Persistent label store backed by SQLite.
8pub struct LabelStore {
9    db: SqlitePool,
10}
11
12impl LabelStore {
13    /// Create a new label store using the given database pool.
14    pub fn new(db: SqlitePool) -> Self {
15        Self { db }
16    }
17
18    /// Run the label store migrations (creates the `atrg_labels` table).
19    pub async fn migrate(&self) -> anyhow::Result<()> {
20        sqlx::query(
21            "CREATE TABLE IF NOT EXISTS atrg_labels (
22                id INTEGER PRIMARY KEY AUTOINCREMENT,
23                src TEXT NOT NULL,
24                uri TEXT NOT NULL,
25                cid TEXT,
26                val TEXT NOT NULL,
27                neg INTEGER NOT NULL DEFAULT 0,
28                cts TEXT NOT NULL,
29                exp TEXT,
30                sig TEXT NOT NULL,
31                created_at INTEGER NOT NULL DEFAULT (unixepoch())
32            )",
33        )
34        .execute(&self.db)
35        .await?;
36
37        sqlx::query("CREATE INDEX IF NOT EXISTS idx_atrg_labels_uri ON atrg_labels(uri)")
38            .execute(&self.db)
39            .await?;
40
41        sqlx::query("CREATE INDEX IF NOT EXISTS idx_atrg_labels_src ON atrg_labels(src)")
42            .execute(&self.db)
43            .await?;
44
45        sqlx::query("CREATE INDEX IF NOT EXISTS idx_atrg_labels_val ON atrg_labels(val)")
46            .execute(&self.db)
47            .await?;
48
49        Ok(())
50    }
51
52    /// Insert a signed label into the store.
53    ///
54    /// Returns the auto-generated row ID of the inserted label.
55    pub async fn insert(&self, label: &SignedLabel) -> anyhow::Result<i64> {
56        let result = sqlx::query(
57            "INSERT INTO atrg_labels (src, uri, cid, val, neg, cts, exp, sig)
58             VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
59        )
60        .bind(&label.label.src)
61        .bind(&label.label.uri)
62        .bind(&label.label.cid)
63        .bind(&label.label.val)
64        .bind(label.label.neg as i32)
65        .bind(&label.label.cts)
66        .bind(&label.label.exp)
67        .bind(&label.sig)
68        .execute(&self.db)
69        .await?;
70
71        Ok(result.last_insert_rowid())
72    }
73
74    /// Query labels for a given subject URI.
75    pub async fn query_by_uri(&self, uri: &str) -> anyhow::Result<Vec<SignedLabel>> {
76        let rows = sqlx::query_as::<_, LabelRow>(
77            "SELECT src, uri, cid, val, neg, cts, exp, sig
78             FROM atrg_labels WHERE uri = ? ORDER BY id",
79        )
80        .bind(uri)
81        .fetch_all(&self.db)
82        .await?;
83
84        Ok(rows.into_iter().map(|r| r.into_signed_label()).collect())
85    }
86
87    /// Query labels since a given cursor (row id), for subscription streaming.
88    ///
89    /// Returns pairs of `(id, SignedLabel)` so callers can use the id as the
90    /// next cursor value.
91    pub async fn query_since(
92        &self,
93        cursor: i64,
94        limit: i64,
95    ) -> anyhow::Result<Vec<(i64, SignedLabel)>> {
96        let rows = sqlx::query_as::<_, LabelRowWithId>(
97            "SELECT id, src, uri, cid, val, neg, cts, exp, sig
98             FROM atrg_labels WHERE id > ? ORDER BY id LIMIT ?",
99        )
100        .bind(cursor)
101        .bind(limit)
102        .fetch_all(&self.db)
103        .await?;
104
105        Ok(rows
106            .into_iter()
107            .map(|r| {
108                let id = r.id;
109                (id, r.into_signed_label())
110            })
111            .collect())
112    }
113}
114
115/// Internal row type for SQLx mapping.
116#[derive(sqlx::FromRow)]
117struct LabelRow {
118    src: String,
119    uri: String,
120    cid: Option<String>,
121    val: String,
122    neg: i32,
123    cts: String,
124    exp: Option<String>,
125    sig: String,
126}
127
128impl LabelRow {
129    fn into_signed_label(self) -> SignedLabel {
130        SignedLabel {
131            label: Label {
132                ver: 1,
133                src: self.src,
134                uri: self.uri,
135                cid: self.cid,
136                val: self.val,
137                neg: self.neg != 0,
138                cts: self.cts,
139                exp: self.exp,
140            },
141            sig: self.sig,
142        }
143    }
144}
145
146/// Internal row type that includes the auto-generated id for cursor support.
147#[derive(sqlx::FromRow)]
148struct LabelRowWithId {
149    id: i64,
150    src: String,
151    uri: String,
152    cid: Option<String>,
153    val: String,
154    neg: i32,
155    cts: String,
156    exp: Option<String>,
157    sig: String,
158}
159
160impl LabelRowWithId {
161    fn into_signed_label(self) -> SignedLabel {
162        SignedLabel {
163            label: Label {
164                ver: 1,
165                src: self.src,
166                uri: self.uri,
167                cid: self.cid,
168                val: self.val,
169                neg: self.neg != 0,
170                cts: self.cts,
171                exp: self.exp,
172            },
173            sig: self.sig,
174        }
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::SqlitePool;

    /// Build a non-negated label with no `cid`/`exp` and a fixed signature.
    fn make_signed_label(src: &str, uri: &str, val: &str) -> SignedLabel {
        SignedLabel {
            label: Label {
                ver: 1,
                src: src.to_string(),
                uri: uri.to_string(),
                cid: None,
                val: val.to_string(),
                neg: false,
                cts: "2024-01-01T00:00:00Z".to_string(),
                exp: None,
            },
            sig: "test-sig".to_string(),
        }
    }

    /// Fresh in-memory store with the schema applied.
    async fn setup_store() -> LabelStore {
        let db = SqlitePool::connect("sqlite::memory:").await.unwrap();
        let store = LabelStore::new(db);
        store.migrate().await.unwrap();
        store
    }

    #[tokio::test]
    async fn test_migrate_creates_table() {
        let db = SqlitePool::connect("sqlite::memory:").await.unwrap();
        let store = LabelStore::new(db.clone());
        store.migrate().await.unwrap();

        // Verify the table exists by querying it.
        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM atrg_labels")
            .fetch_one(&db)
            .await
            .unwrap();
        assert_eq!(row.0, 0);
    }

    #[tokio::test]
    async fn test_migrate_is_idempotent() {
        let store = setup_store().await;
        // Every DDL statement uses IF NOT EXISTS, so a second run must succeed.
        store.migrate().await.unwrap();
    }

    #[tokio::test]
    async fn test_insert_returns_positive_id() {
        let store = setup_store().await;
        let label = make_signed_label("did:plc:labeler", "at://did:plc:user/post/1", "spam");
        let id = store.insert(&label).await.unwrap();
        assert!(id > 0);
    }

    #[tokio::test]
    async fn test_roundtrip_preserves_optional_fields_and_neg() {
        let store = setup_store().await;
        let uri = "at://did:plc:user/post/x";

        let mut label = make_signed_label("did:plc:l", uri, "spam");
        label.label.cid = Some("bafy-test-cid".to_string());
        label.label.exp = Some("2025-01-01T00:00:00Z".to_string());
        label.label.neg = true;
        store.insert(&label).await.unwrap();

        let got = store.query_by_uri(uri).await.unwrap();
        assert_eq!(got.len(), 1);
        let stored = &got[0].label;
        assert_eq!(stored.ver, 1);
        assert_eq!(stored.src, "did:plc:l");
        assert_eq!(stored.uri, uri);
        assert_eq!(stored.val, "spam");
        assert_eq!(stored.cid.as_deref(), Some("bafy-test-cid"));
        assert_eq!(stored.exp.as_deref(), Some("2025-01-01T00:00:00Z"));
        assert_eq!(stored.cts, "2024-01-01T00:00:00Z");
        assert!(stored.neg);
        assert_eq!(got[0].sig, "test-sig");
    }

    #[tokio::test]
    async fn test_query_by_uri() {
        let store = setup_store().await;

        let uri_a = "at://did:plc:user/post/a";
        let uri_b = "at://did:plc:user/post/b";

        store
            .insert(&make_signed_label("did:plc:l", uri_a, "spam"))
            .await
            .unwrap();
        store
            .insert(&make_signed_label("did:plc:l", uri_a, "porn"))
            .await
            .unwrap();
        store
            .insert(&make_signed_label("did:plc:l", uri_b, "nudity"))
            .await
            .unwrap();

        let results_a = store.query_by_uri(uri_a).await.unwrap();
        assert_eq!(results_a.len(), 2);
        assert_eq!(results_a[0].label.val, "spam");
        assert_eq!(results_a[1].label.val, "porn");
        assert!(!results_a[0].label.neg);

        let results_b = store.query_by_uri(uri_b).await.unwrap();
        assert_eq!(results_b.len(), 1);
        assert_eq!(results_b[0].label.val, "nudity");
    }

    #[tokio::test]
    async fn test_query_by_uri_empty() {
        let store = setup_store().await;
        let results = store.query_by_uri("at://nonexistent").await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_query_since_with_cursor() {
        let store = setup_store().await;
        let uri = "at://did:plc:user/post/1";
        for i in 0..5 {
            store
                .insert(&make_signed_label("did:plc:l", uri, &format!("val-{}", i)))
                .await
                .unwrap();
        }

        // First page: 3 results starting from cursor 0.
        let page1 = store.query_since(0, 3).await.unwrap();
        assert_eq!(page1.len(), 3);
        assert_eq!(page1[0].1.label.val, "val-0");
        assert_eq!(page1[2].1.label.val, "val-2");

        // Second page: use last id as cursor.
        let last_cursor = page1.last().unwrap().0;
        let page2 = store.query_since(last_cursor, 3).await.unwrap();
        assert_eq!(page2.len(), 2);
        assert_eq!(page2[0].1.label.val, "val-3");
        assert_eq!(page2[1].1.label.val, "val-4");
    }

    #[tokio::test]
    async fn test_query_since_empty_and_exhausted() {
        let store = setup_store().await;

        // Nothing stored yet.
        assert!(store.query_since(0, 10).await.unwrap().is_empty());

        let id = store
            .insert(&make_signed_label("did:plc:l", "at://did:plc:user/post/z", "spam"))
            .await
            .unwrap();

        // The cursor id reported by query_since matches the id from insert.
        let page = store.query_since(0, 10).await.unwrap();
        assert_eq!(page.len(), 1);
        assert_eq!(page[0].0, id);

        // A cursor at the newest row yields nothing further.
        assert!(store.query_since(id, 10).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_query_since_respects_limit() {
        let store = setup_store().await;
        let uri = "at://did:plc:user/post/1";
        for i in 0..10 {
            store
                .insert(&make_signed_label("did:plc:l", uri, &format!("v{}", i)))
                .await
                .unwrap();
        }

        let results = store.query_since(0, 5).await.unwrap();
        assert_eq!(results.len(), 5);
    }
}