1use axum::http::StatusCode;
7use axum::response::{IntoResponse, Response};
8use axum::Json;
9
10use atrg_core::error::AtrgError;
11
/// XRPC error names that this crate can put in an error response.
///
/// Each variant has a wire name ([`XrpcErrorName::as_str`]) and an HTTP
/// status code ([`XrpcErrorName::status_code`]). `Hash` is derived alongside
/// `Eq` so names can be used as map/set keys (e.g. for per-error metrics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XrpcErrorName {
    /// Rendered with HTTP 400 Bad Request.
    InvalidRequest,
    /// Rendered with HTTP 401 Unauthorized.
    AuthRequired,
    /// Rendered with HTTP 403 Forbidden.
    Forbidden,
    /// Rendered with HTTP 404 Not Found.
    NotFound,
    /// Rendered with HTTP 429 Too Many Requests.
    RateLimitExceeded,
    /// Rendered with HTTP 500 Internal Server Error.
    InternalServerError,
    /// Rendered with HTTP 501 Not Implemented.
    MethodNotImplemented,
}
30
31impl XrpcErrorName {
32 pub fn as_str(&self) -> &'static str {
34 match self {
35 Self::InvalidRequest => "InvalidRequest",
36 Self::AuthRequired => "AuthRequired",
37 Self::Forbidden => "Forbidden",
38 Self::NotFound => "NotFound",
39 Self::RateLimitExceeded => "RateLimitExceeded",
40 Self::InternalServerError => "InternalServerError",
41 Self::MethodNotImplemented => "MethodNotImplemented",
42 }
43 }
44
45 pub fn status_code(&self) -> StatusCode {
47 match self {
48 Self::InvalidRequest => StatusCode::BAD_REQUEST,
49 Self::AuthRequired => StatusCode::UNAUTHORIZED,
50 Self::Forbidden => StatusCode::FORBIDDEN,
51 Self::NotFound => StatusCode::NOT_FOUND,
52 Self::RateLimitExceeded => StatusCode::TOO_MANY_REQUESTS,
53 Self::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR,
54 Self::MethodNotImplemented => StatusCode::NOT_IMPLEMENTED,
55 }
56 }
57}
58
59#[derive(Debug)]
69pub struct XrpcError {
70 pub name: XrpcErrorName,
72 pub message: String,
74}
75
76impl std::fmt::Display for XrpcError {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 write!(f, "{}: {}", self.name.as_str(), self.message)
79 }
80}
81
// Lets `XrpcError` be used wherever a `std::error::Error` is expected
// (e.g. boxed as `dyn Error`). It relies on the derived `Debug` and the
// `Display` impl. All trait methods use their defaults, so `source()`
// returns `None`.
impl std::error::Error for XrpcError {}
83
84impl IntoResponse for XrpcError {
85 fn into_response(self) -> Response {
86 let status = self.name.status_code();
87 let body = serde_json::json!({
88 "error": self.name.as_str(),
89 "message": self.message,
90 });
91 (status, Json(body)).into_response()
92 }
93}
94
95impl From<AtrgError> for XrpcError {
96 fn from(err: AtrgError) -> Self {
97 match err {
98 AtrgError::NotFound => XrpcError {
99 name: XrpcErrorName::NotFound,
100 message: "Not found".to_string(),
101 },
102 AtrgError::Auth(msg) => XrpcError {
103 name: XrpcErrorName::AuthRequired,
104 message: msg,
105 },
106 AtrgError::BadRequest(msg) => XrpcError {
107 name: XrpcErrorName::InvalidRequest,
108 message: msg,
109 },
110 AtrgError::Database(_) => XrpcError {
111 name: XrpcErrorName::InternalServerError,
112 message: "Internal server error".to_string(),
113 },
114 AtrgError::Internal(_) => XrpcError {
115 name: XrpcErrorName::InternalServerError,
116 message: "Internal server error".to_string(),
117 },
118 }
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use http_body_util::BodyExt;

    /// Every `XrpcErrorName` variant; keep in sync when adding variants.
    const ALL_NAMES: [XrpcErrorName; 7] = [
        XrpcErrorName::InvalidRequest,
        XrpcErrorName::AuthRequired,
        XrpcErrorName::Forbidden,
        XrpcErrorName::NotFound,
        XrpcErrorName::RateLimitExceeded,
        XrpcErrorName::InternalServerError,
        XrpcErrorName::MethodNotImplemented,
    ];

    /// Renders `err` into a response and returns its status and parsed JSON body.
    async fn error_to_parts(err: XrpcError) -> (StatusCode, serde_json::Value) {
        let response = err.into_response();
        let status = response.status();
        let bytes = Body::new(response.into_body())
            .collect()
            .await
            .unwrap()
            .to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        (status, json)
    }

    #[tokio::test]
    async fn invalid_request_400() {
        let (status, body) = error_to_parts(XrpcError {
            name: XrpcErrorName::InvalidRequest,
            message: "bad input".into(),
        })
        .await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body["error"], "InvalidRequest");
        assert_eq!(body["message"], "bad input");
    }

    #[tokio::test]
    async fn auth_required_401() {
        let (status, body) = error_to_parts(XrpcError {
            name: XrpcErrorName::AuthRequired,
            message: "login needed".into(),
        })
        .await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body["error"], "AuthRequired");
    }

    #[tokio::test]
    async fn forbidden_403() {
        let (status, _) = error_to_parts(XrpcError {
            name: XrpcErrorName::Forbidden,
            message: "nope".into(),
        })
        .await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn not_found_404() {
        let (status, _) = error_to_parts(XrpcError {
            name: XrpcErrorName::NotFound,
            message: "gone".into(),
        })
        .await;
        assert_eq!(status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn rate_limit_429() {
        let (status, _) = error_to_parts(XrpcError {
            name: XrpcErrorName::RateLimitExceeded,
            message: "slow down".into(),
        })
        .await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn internal_500() {
        let (status, _) = error_to_parts(XrpcError {
            name: XrpcErrorName::InternalServerError,
            message: "oops".into(),
        })
        .await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn not_implemented_501() {
        let (status, body) = error_to_parts(XrpcError {
            name: XrpcErrorName::MethodNotImplemented,
            message: "not here".into(),
        })
        .await;
        assert_eq!(status, StatusCode::NOT_IMPLEMENTED);
        assert_eq!(body["error"], "MethodNotImplemented");
    }

    /// Every variant's rendered response agrees with `status_code`/`as_str`
    /// and carries the message through unchanged.
    #[tokio::test]
    async fn every_name_renders_consistently() {
        for &name in ALL_NAMES.iter() {
            let (status, body) = error_to_parts(XrpcError {
                name,
                message: "msg".into(),
            })
            .await;
            assert_eq!(status, name.status_code(), "status for {name:?}");
            assert_eq!(body["error"], name.as_str(), "error for {name:?}");
            assert_eq!(body["message"], "msg", "message for {name:?}");
        }
    }

    #[test]
    fn display_is_name_colon_message() {
        let err = XrpcError {
            name: XrpcErrorName::NotFound,
            message: "gone".into(),
        };
        assert_eq!(err.to_string(), "NotFound: gone");
    }

    // The conversion tests are synchronous: `From` needs no runtime.
    #[test]
    fn from_atrg_error_not_found() {
        let xrpc: XrpcError = AtrgError::NotFound.into();
        assert_eq!(xrpc.name, XrpcErrorName::NotFound);
        assert_eq!(xrpc.message, "Not found");
    }

    #[test]
    fn from_atrg_error_auth() {
        let xrpc: XrpcError = AtrgError::Auth("no".into()).into();
        assert_eq!(xrpc.name, XrpcErrorName::AuthRequired);
        assert_eq!(xrpc.message, "no");
    }

    #[test]
    fn from_atrg_error_bad_request() {
        let xrpc: XrpcError = AtrgError::BadRequest("bad".into()).into();
        assert_eq!(xrpc.name, XrpcErrorName::InvalidRequest);
        assert_eq!(xrpc.message, "bad");
    }

    #[tokio::test]
    async fn xrpc_router_fallback_returns_501() {
        use atrg_core::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
        use atrg_core::state::AppState;
        use hyper::Request;
        use std::sync::Arc;
        use tower::ServiceExt;

        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();
        let state = AppState {
            config: Arc::new(Config {
                app: AppConfig {
                    name: "test".into(),
                    host: "127.0.0.1".into(),
                    port: 3000,
                    secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
                    cors_origins: vec![],
                    environment: "development".into(),
                },
                auth: AuthConfig::default(),
                database: DatabaseConfig::default(),
                jetstream: None,
                firehose: None,
                feed_generator: None,
                labeler: None,
                rate_limit: None,
            }),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
        };

        let app = crate::xrpc_router::<AppState>().with_state(state);

        let resp = app
            .oneshot(
                Request::get("/xrpc/com.nonexistent.method")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED);
        let bytes = Body::new(resp.into_body())
            .collect()
            .await
            .unwrap()
            .to_bytes();
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(body["error"], "MethodNotImplemented");
    }
}