1use sqlx::SqlitePool;
4
5use crate::types::{Label, SignedLabel};
6
7pub struct LabelStore {
9 db: SqlitePool,
10}
11
12impl LabelStore {
13 pub fn new(db: SqlitePool) -> Self {
15 Self { db }
16 }
17
18 pub async fn migrate(&self) -> anyhow::Result<()> {
20 sqlx::query(
21 "CREATE TABLE IF NOT EXISTS atrg_labels (
22 id INTEGER PRIMARY KEY AUTOINCREMENT,
23 src TEXT NOT NULL,
24 uri TEXT NOT NULL,
25 cid TEXT,
26 val TEXT NOT NULL,
27 neg INTEGER NOT NULL DEFAULT 0,
28 cts TEXT NOT NULL,
29 exp TEXT,
30 sig TEXT NOT NULL,
31 created_at INTEGER NOT NULL DEFAULT (unixepoch())
32 )",
33 )
34 .execute(&self.db)
35 .await?;
36
37 sqlx::query("CREATE INDEX IF NOT EXISTS idx_atrg_labels_uri ON atrg_labels(uri)")
38 .execute(&self.db)
39 .await?;
40
41 sqlx::query("CREATE INDEX IF NOT EXISTS idx_atrg_labels_src ON atrg_labels(src)")
42 .execute(&self.db)
43 .await?;
44
45 sqlx::query("CREATE INDEX IF NOT EXISTS idx_atrg_labels_val ON atrg_labels(val)")
46 .execute(&self.db)
47 .await?;
48
49 Ok(())
50 }
51
52 pub async fn insert(&self, label: &SignedLabel) -> anyhow::Result<i64> {
56 let result = sqlx::query(
57 "INSERT INTO atrg_labels (src, uri, cid, val, neg, cts, exp, sig)
58 VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
59 )
60 .bind(&label.label.src)
61 .bind(&label.label.uri)
62 .bind(&label.label.cid)
63 .bind(&label.label.val)
64 .bind(label.label.neg as i32)
65 .bind(&label.label.cts)
66 .bind(&label.label.exp)
67 .bind(&label.sig)
68 .execute(&self.db)
69 .await?;
70
71 Ok(result.last_insert_rowid())
72 }
73
74 pub async fn query_by_uri(&self, uri: &str) -> anyhow::Result<Vec<SignedLabel>> {
76 let rows = sqlx::query_as::<_, LabelRow>(
77 "SELECT src, uri, cid, val, neg, cts, exp, sig
78 FROM atrg_labels WHERE uri = ? ORDER BY id",
79 )
80 .bind(uri)
81 .fetch_all(&self.db)
82 .await?;
83
84 Ok(rows.into_iter().map(|r| r.into_signed_label()).collect())
85 }
86
87 pub async fn query_since(
92 &self,
93 cursor: i64,
94 limit: i64,
95 ) -> anyhow::Result<Vec<(i64, SignedLabel)>> {
96 let rows = sqlx::query_as::<_, LabelRowWithId>(
97 "SELECT id, src, uri, cid, val, neg, cts, exp, sig
98 FROM atrg_labels WHERE id > ? ORDER BY id LIMIT ?",
99 )
100 .bind(cursor)
101 .bind(limit)
102 .fetch_all(&self.db)
103 .await?;
104
105 Ok(rows
106 .into_iter()
107 .map(|r| {
108 let id = r.id;
109 (id, r.into_signed_label())
110 })
111 .collect())
112 }
113}
114
115#[derive(sqlx::FromRow)]
117struct LabelRow {
118 src: String,
119 uri: String,
120 cid: Option<String>,
121 val: String,
122 neg: i32,
123 cts: String,
124 exp: Option<String>,
125 sig: String,
126}
127
128impl LabelRow {
129 fn into_signed_label(self) -> SignedLabel {
130 SignedLabel {
131 label: Label {
132 ver: 1,
133 src: self.src,
134 uri: self.uri,
135 cid: self.cid,
136 val: self.val,
137 neg: self.neg != 0,
138 cts: self.cts,
139 exp: self.exp,
140 },
141 sig: self.sig,
142 }
143 }
144}
145
146#[derive(sqlx::FromRow)]
148struct LabelRowWithId {
149 id: i64,
150 src: String,
151 uri: String,
152 cid: Option<String>,
153 val: String,
154 neg: i32,
155 cts: String,
156 exp: Option<String>,
157 sig: String,
158}
159
160impl LabelRowWithId {
161 fn into_signed_label(self) -> SignedLabel {
162 SignedLabel {
163 label: Label {
164 ver: 1,
165 src: self.src,
166 uri: self.uri,
167 cid: self.cid,
168 val: self.val,
169 neg: self.neg != 0,
170 cts: self.cts,
171 exp: self.exp,
172 },
173 sig: self.sig,
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::SqlitePool;

    /// Builds a non-negated label with no CID and no expiry, using a fixed
    /// creation timestamp and a placeholder signature.
    fn make_signed_label(src: &str, uri: &str, val: &str) -> SignedLabel {
        SignedLabel {
            label: Label {
                ver: 1,
                src: src.to_string(),
                uri: uri.to_string(),
                cid: None,
                val: val.to_string(),
                neg: false,
                cts: "2024-01-01T00:00:00Z".to_string(),
                exp: None,
            },
            sig: "test-sig".to_string(),
        }
    }

    /// Returns a store over a fresh in-memory database with the schema applied.
    async fn setup_store() -> LabelStore {
        let db = SqlitePool::connect("sqlite::memory:").await.unwrap();
        let store = LabelStore::new(db);
        store.migrate().await.unwrap();
        store
    }

    #[tokio::test]
    async fn test_migrate_creates_table() {
        let db = SqlitePool::connect("sqlite::memory:").await.unwrap();
        let store = LabelStore::new(db.clone());
        store.migrate().await.unwrap();

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM atrg_labels")
            .fetch_one(&db)
            .await
            .unwrap();
        assert_eq!(row.0, 0);
    }

    #[tokio::test]
    async fn test_migrate_is_idempotent() {
        let store = setup_store().await;
        let id = store
            .insert(&make_signed_label(
                "did:plc:l",
                "at://did:plc:user/post/1",
                "spam",
            ))
            .await
            .unwrap();

        // Re-running the migration must neither fail nor drop existing rows.
        store.migrate().await.unwrap();

        let rows = store.query_since(0, 10).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].0, id);
    }

    #[tokio::test]
    async fn test_insert_returns_positive_id() {
        let store = setup_store().await;
        let label = make_signed_label("did:plc:labeler", "at://did:plc:user/post/1", "spam");
        let id = store.insert(&label).await.unwrap();
        assert!(id > 0);
    }

    #[tokio::test]
    async fn test_roundtrip_preserves_all_fields() {
        let store = setup_store().await;
        let uri = "at://did:plc:user/post/1";

        let mut label = make_signed_label("did:plc:l", uri, "!hide");
        label.label.cid = Some("bafyreiexamplecid".to_string());
        label.label.neg = true;
        label.label.exp = Some("2025-01-01T00:00:00Z".to_string());
        label.sig = "sig-bytes".to_string();

        store.insert(&label).await.unwrap();

        let results = store.query_by_uri(uri).await.unwrap();
        assert_eq!(results.len(), 1);
        let got = &results[0];
        assert_eq!(got.label.src, "did:plc:l");
        assert_eq!(got.label.uri, uri);
        assert_eq!(got.label.cid.as_deref(), Some("bafyreiexamplecid"));
        assert_eq!(got.label.val, "!hide");
        assert!(got.label.neg);
        assert_eq!(got.label.cts, "2024-01-01T00:00:00Z");
        assert_eq!(got.label.exp.as_deref(), Some("2025-01-01T00:00:00Z"));
        assert_eq!(got.sig, "sig-bytes");
    }

    #[tokio::test]
    async fn test_query_by_uri() {
        let store = setup_store().await;

        let uri_a = "at://did:plc:user/post/a";
        let uri_b = "at://did:plc:user/post/b";

        store
            .insert(&make_signed_label("did:plc:l", uri_a, "spam"))
            .await
            .unwrap();
        store
            .insert(&make_signed_label("did:plc:l", uri_a, "porn"))
            .await
            .unwrap();
        store
            .insert(&make_signed_label("did:plc:l", uri_b, "nudity"))
            .await
            .unwrap();

        let results_a = store.query_by_uri(uri_a).await.unwrap();
        assert_eq!(results_a.len(), 2);
        assert_eq!(results_a[0].label.val, "spam");
        assert_eq!(results_a[1].label.val, "porn");

        let results_b = store.query_by_uri(uri_b).await.unwrap();
        assert_eq!(results_b.len(), 1);
        assert_eq!(results_b[0].label.val, "nudity");
    }

    #[tokio::test]
    async fn test_query_by_uri_empty() {
        let store = setup_store().await;
        let results = store.query_by_uri("at://nonexistent").await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_query_since_with_cursor() {
        let store = setup_store().await;
        let uri = "at://did:plc:user/post/1";
        for i in 0..5 {
            store
                .insert(&make_signed_label("did:plc:l", uri, &format!("val-{}", i)))
                .await
                .unwrap();
        }

        let page1 = store.query_since(0, 3).await.unwrap();
        assert_eq!(page1.len(), 3);
        assert_eq!(page1[0].1.label.val, "val-0");
        assert_eq!(page1[2].1.label.val, "val-2");

        let last_cursor = page1.last().unwrap().0;
        let page2 = store.query_since(last_cursor, 3).await.unwrap();
        assert_eq!(page2.len(), 2);
        assert_eq!(page2[0].1.label.val, "val-3");
        assert_eq!(page2[1].1.label.val, "val-4");
    }

    #[tokio::test]
    async fn test_query_since_past_end_is_empty() {
        let store = setup_store().await;
        let id = store
            .insert(&make_signed_label(
                "did:plc:l",
                "at://did:plc:user/post/1",
                "spam",
            ))
            .await
            .unwrap();

        // The cursor is exclusive, so asking for rows after the newest id yields nothing.
        let results = store.query_since(id, 10).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_query_since_respects_limit() {
        let store = setup_store().await;
        let uri = "at://did:plc:user/post/1";
        for i in 0..10 {
            store
                .insert(&make_signed_label("did:plc:l", uri, &format!("v{}", i)))
                .await
                .unwrap();
        }

        let results = store.query_since(0, 5).await.unwrap();
        assert_eq!(results.len(), 5);
    }
}