1use axum::http::StatusCode;
10use axum::response::{IntoResponse, Response};
11use axum::Json;
12
13pub type AtrgResult<T> = Result<T, AtrgError>;
15
16#[derive(Debug, thiserror::Error)]
20pub enum AtrgError {
21 #[error("database error: {0}")]
24 Database(#[from] sqlx::Error),
25
26 #[error("unauthorized: {0}")]
29 Auth(String),
30
31 #[error("not found")]
33 NotFound,
34
35 #[error("bad request: {0}")]
38 BadRequest(String),
39
40 #[error("internal error: {0}")]
43 Internal(anyhow::Error),
44}
45
46impl From<anyhow::Error> for AtrgError {
47 fn from(err: anyhow::Error) -> Self {
48 AtrgError::Internal(err)
49 }
50}
51
52impl IntoResponse for AtrgError {
53 fn into_response(self) -> Response {
54 let (status, code, message) = match &self {
55 AtrgError::NotFound => (StatusCode::NOT_FOUND, "not_found", "Not found".to_string()),
56 AtrgError::Auth(m) => (StatusCode::UNAUTHORIZED, "unauthorized", m.clone()),
57 AtrgError::BadRequest(m) => (StatusCode::BAD_REQUEST, "bad_request", m.clone()),
58 AtrgError::Database(e) => {
59 tracing::error!(error = %e, "database error");
60 (
61 StatusCode::INTERNAL_SERVER_ERROR,
62 "database_error",
63 "Database error".to_string(),
64 )
65 }
66 AtrgError::Internal(e) => {
67 tracing::error!(error = %e, "internal error");
68 (
69 StatusCode::INTERNAL_SERVER_ERROR,
70 "internal_error",
71 "Internal server error".to_string(),
72 )
73 }
74 };
75
76 (
77 status,
78 Json(serde_json::json!({
79 "error": code,
80 "message": message,
81 })),
82 )
83 .into_response()
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::BodyExt;

    /// Renders `err` into a response and returns its status and parsed JSON body.
    async fn error_to_parts(err: AtrgError) -> (StatusCode, serde_json::Value) {
        let response = err.into_response();
        let status = response.status();
        // `Response::into_body` already yields an `axum::body::Body`, so it can
        // be collected directly without re-wrapping.
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        (status, json)
    }

    #[tokio::test]
    async fn not_found_returns_404() {
        let (status, body) = error_to_parts(AtrgError::NotFound).await;
        assert_eq!(status, StatusCode::NOT_FOUND);
        assert_eq!(body["error"], "not_found");
        assert_eq!(body["message"], "Not found");
    }

    #[tokio::test]
    async fn auth_returns_401() {
        let (status, body) = error_to_parts(AtrgError::Auth("bad token".into())).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body["error"], "unauthorized");
        assert_eq!(body["message"], "bad token");
    }

    #[tokio::test]
    async fn bad_request_returns_400() {
        let (status, body) = error_to_parts(AtrgError::BadRequest("missing field".into())).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body["error"], "bad_request");
        assert_eq!(body["message"], "missing field");
    }

    #[tokio::test]
    async fn database_error_returns_500() {
        let err = AtrgError::Database(sqlx::Error::RowNotFound);
        let (status, body) = error_to_parts(err).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body["error"], "database_error");
        assert_eq!(body["message"], "Database error");
    }

    #[tokio::test]
    async fn internal_error_returns_500() {
        let err = AtrgError::Internal(anyhow::anyhow!("something broke"));
        let (status, body) = error_to_parts(err).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body["error"], "internal_error");
        assert_eq!(body["message"], "Internal server error");
    }

    #[tokio::test]
    async fn response_content_type_is_json() {
        let response = AtrgError::NotFound.into_response();
        let content_type = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(
            content_type.contains("application/json"),
            "expected application/json, got: {content_type}"
        );
    }

    #[test]
    fn from_sqlx_error() {
        let err: AtrgError = sqlx::Error::RowNotFound.into();
        assert!(matches!(err, AtrgError::Database(_)));
    }

    #[test]
    fn from_anyhow_error() {
        let err: AtrgError = anyhow::anyhow!("boom").into();
        assert!(matches!(err, AtrgError::Internal(_)));
    }

    /// Every variant must produce the same two-key body shape, not just `NotFound`.
    #[tokio::test]
    async fn response_body_has_exactly_two_keys() {
        let errors = [
            AtrgError::NotFound,
            AtrgError::Auth("bad token".into()),
            AtrgError::BadRequest("missing field".into()),
            AtrgError::Database(sqlx::Error::RowNotFound),
            AtrgError::Internal(anyhow::anyhow!("something broke")),
        ];
        for err in errors {
            let label = err.to_string();
            let (_, body) = error_to_parts(err).await;
            let obj = body.as_object().unwrap();
            assert_eq!(
                obj.len(),
                2,
                "expected exactly 'error' and 'message' keys for {label}"
            );
            assert!(obj.contains_key("error"), "missing 'error' for {label}");
            assert!(obj.contains_key("message"), "missing 'message' for {label}");
        }
    }
}