Skip to main content

atrg_xrpc/
error.rs

//! AT Protocol XRPC error envelope.
//!
//! Every `/xrpc/*` failure must be returned as an [`XrpcError`] so that
//! responses conform to the AT Protocol error format.
5
6use axum::http::StatusCode;
7use axum::response::{IntoResponse, Response};
8use axum::Json;
9
10use atrg_core::error::AtrgError;
11
/// XRPC error name variants per the AT Protocol spec.
///
/// Each variant maps to a fixed wire name and HTTP status code. `Hash` is
/// derived so the name can be used as a map/set key (e.g. for per-error
/// counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XrpcErrorName {
    /// The request was malformed. HTTP 400.
    InvalidRequest,
    /// Authentication is required. HTTP 401.
    AuthRequired,
    /// The authenticated user is not allowed to perform this action. HTTP 403.
    Forbidden,
    /// The requested resource was not found. HTTP 404.
    NotFound,
    /// Too many requests. HTTP 429.
    RateLimitExceeded,
    /// An unexpected server error occurred. HTTP 500.
    InternalServerError,
    /// The XRPC method is not implemented. HTTP 501.
    MethodNotImplemented,
}
30
31impl XrpcErrorName {
32    /// The string representation used in the JSON error envelope.
33    pub fn as_str(&self) -> &'static str {
34        match self {
35            Self::InvalidRequest => "InvalidRequest",
36            Self::AuthRequired => "AuthRequired",
37            Self::Forbidden => "Forbidden",
38            Self::NotFound => "NotFound",
39            Self::RateLimitExceeded => "RateLimitExceeded",
40            Self::InternalServerError => "InternalServerError",
41            Self::MethodNotImplemented => "MethodNotImplemented",
42        }
43    }
44
45    /// The HTTP status code for this error.
46    pub fn status_code(&self) -> StatusCode {
47        match self {
48            Self::InvalidRequest => StatusCode::BAD_REQUEST,
49            Self::AuthRequired => StatusCode::UNAUTHORIZED,
50            Self::Forbidden => StatusCode::FORBIDDEN,
51            Self::NotFound => StatusCode::NOT_FOUND,
52            Self::RateLimitExceeded => StatusCode::TOO_MANY_REQUESTS,
53            Self::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR,
54            Self::MethodNotImplemented => StatusCode::NOT_IMPLEMENTED,
55        }
56    }
57}
58
59/// AT Protocol XRPC error envelope.
60///
61/// Use this as the error type for all `/xrpc/*` handlers:
62///
63/// ```rust,ignore
64/// async fn get_posts() -> Result<Json<Posts>, XrpcError> {
65///     // ...
66/// }
67/// ```
68#[derive(Debug)]
69pub struct XrpcError {
70    /// The error category.
71    pub name: XrpcErrorName,
72    /// Human-readable error message.
73    pub message: String,
74}
75
76impl std::fmt::Display for XrpcError {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        write!(f, "{}: {}", self.name.as_str(), self.message)
79    }
80}
81
82impl std::error::Error for XrpcError {}
83
84impl IntoResponse for XrpcError {
85    fn into_response(self) -> Response {
86        let status = self.name.status_code();
87        let body = serde_json::json!({
88            "error": self.name.as_str(),
89            "message": self.message,
90        });
91        (status, Json(body)).into_response()
92    }
93}
94
95impl From<AtrgError> for XrpcError {
96    fn from(err: AtrgError) -> Self {
97        match err {
98            AtrgError::NotFound => XrpcError {
99                name: XrpcErrorName::NotFound,
100                message: "Not found".to_string(),
101            },
102            AtrgError::Auth(msg) => XrpcError {
103                name: XrpcErrorName::AuthRequired,
104                message: msg,
105            },
106            AtrgError::BadRequest(msg) => XrpcError {
107                name: XrpcErrorName::InvalidRequest,
108                message: msg,
109            },
110            AtrgError::Database(_) => XrpcError {
111                name: XrpcErrorName::InternalServerError,
112                message: "Internal server error".to_string(),
113            },
114            AtrgError::Internal(_) => XrpcError {
115                name: XrpcErrorName::InternalServerError,
116                message: "Internal server error".to_string(),
117            },
118        }
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use http_body_util::BodyExt;

    /// Renders `err` as an HTTP response and returns its status and parsed JSON body.
    async fn error_to_parts(err: XrpcError) -> (StatusCode, serde_json::Value) {
        let response = err.into_response();
        let status = response.status();
        let bytes = Body::new(response.into_body())
            .collect()
            .await
            .unwrap()
            .to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        (status, json)
    }

    #[tokio::test]
    async fn invalid_request_400() {
        let (status, body) = error_to_parts(XrpcError {
            name: XrpcErrorName::InvalidRequest,
            message: "bad input".into(),
        })
        .await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body["error"], "InvalidRequest");
        assert_eq!(body["message"], "bad input");
    }

    #[tokio::test]
    async fn auth_required_401() {
        let (status, body) = error_to_parts(XrpcError {
            name: XrpcErrorName::AuthRequired,
            message: "login needed".into(),
        })
        .await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body["error"], "AuthRequired");
    }

    #[tokio::test]
    async fn forbidden_403() {
        let (status, _) = error_to_parts(XrpcError {
            name: XrpcErrorName::Forbidden,
            message: "nope".into(),
        })
        .await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn not_found_404() {
        let (status, _) = error_to_parts(XrpcError {
            name: XrpcErrorName::NotFound,
            message: "gone".into(),
        })
        .await;
        assert_eq!(status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn rate_limit_429() {
        let (status, _) = error_to_parts(XrpcError {
            name: XrpcErrorName::RateLimitExceeded,
            message: "slow down".into(),
        })
        .await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn internal_500() {
        let (status, _) = error_to_parts(XrpcError {
            name: XrpcErrorName::InternalServerError,
            message: "oops".into(),
        })
        .await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn not_implemented_501() {
        let (status, body) = error_to_parts(XrpcError {
            name: XrpcErrorName::MethodNotImplemented,
            message: "not here".into(),
        })
        .await;
        assert_eq!(status, StatusCode::NOT_IMPLEMENTED);
        assert_eq!(body["error"], "MethodNotImplemented");
    }

    /// The wire name of every variant must be exactly its identifier.
    #[test]
    fn error_names_match_variant_identifiers() {
        let all = [
            XrpcErrorName::InvalidRequest,
            XrpcErrorName::AuthRequired,
            XrpcErrorName::Forbidden,
            XrpcErrorName::NotFound,
            XrpcErrorName::RateLimitExceeded,
            XrpcErrorName::InternalServerError,
            XrpcErrorName::MethodNotImplemented,
        ];
        for name in &all {
            assert_eq!(name.as_str(), format!("{name:?}"));
        }
    }

    #[test]
    fn display_is_name_colon_message() {
        let err = XrpcError {
            name: XrpcErrorName::NotFound,
            message: "gone".into(),
        };
        assert_eq!(err.to_string(), "NotFound: gone");
    }

    #[test]
    fn response_content_type_is_json() {
        let response = XrpcError {
            name: XrpcErrorName::InvalidRequest,
            message: "bad input".into(),
        }
        .into_response();
        assert_eq!(
            response.headers()[axum::http::header::CONTENT_TYPE],
            "application/json"
        );
    }

    // The conversion tests below are synchronous: no runtime is needed.
    #[test]
    fn from_atrg_error_not_found() {
        let xrpc: XrpcError = AtrgError::NotFound.into();
        assert_eq!(xrpc.name, XrpcErrorName::NotFound);
        assert_eq!(xrpc.message, "Not found");
    }

    #[test]
    fn from_atrg_error_auth() {
        let xrpc: XrpcError = AtrgError::Auth("no".into()).into();
        assert_eq!(xrpc.name, XrpcErrorName::AuthRequired);
        assert_eq!(xrpc.message, "no");
    }

    #[test]
    fn from_atrg_error_bad_request() {
        let xrpc: XrpcError = AtrgError::BadRequest("bad".into()).into();
        assert_eq!(xrpc.name, XrpcErrorName::InvalidRequest);
        assert_eq!(xrpc.message, "bad");
    }

    #[tokio::test]
    async fn xrpc_router_fallback_returns_501() {
        use atrg_core::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
        use atrg_core::state::AppState;
        use hyper::Request;
        use std::sync::Arc;
        use tower::ServiceExt;

        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();
        let state = AppState {
            config: Arc::new(Config {
                app: AppConfig {
                    name: "test".into(),
                    host: "127.0.0.1".into(),
                    port: 3000,
                    secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
                    cors_origins: vec![],
                    environment: "development".into(),
                },
                auth: AuthConfig::default(),
                database: DatabaseConfig::default(),
                jetstream: None,
                firehose: None,
                feed_generator: None,
                labeler: None,
                rate_limit: None,
            }),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
        };

        let app = crate::xrpc_router::<AppState>().with_state(state);

        let resp = app
            .oneshot(
                Request::get("/xrpc/com.nonexistent.method")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED);
        let bytes = Body::new(resp.into_body())
            .collect()
            .await
            .unwrap()
            .to_bytes();
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(body["error"], "MethodNotImplemented");
    }
}