Skip to main content

atrg_auth/
jwt.rs

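//! AT Protocol JWT verification helpers.
//!
//! Provides a structural JWT check, unverified claim decoding used to dispatch
//! bearer tokens, and `exp`/`aud` claim checks for PDS-issued JWTs. Resolving
//! the issuer's signing key via the identity resolver and verifying the
//! signature are not performed by the functions in this file.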
5
6use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
7use serde::{Deserialize, Serialize};
8
/// Claims extracted from an AT Protocol JWT.
///
/// Only `iss` and `sub` are mandatory when deserializing; each `Option` field
/// is `None` when the corresponding claim is absent. Claims not listed here are
/// ignored during deserialization (the type does not use `deny_unknown_fields`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Issuer — the PDS DID.
    pub iss: String,
    /// Subject — the user's DID.
    pub sub: String,
    /// Audience — should match this server's host.
    ///
    /// NOTE(review): RFC 7519 also allows `aud` to be a JSON array of strings;
    /// a token using that form presumably fails to deserialize into this
    /// single-string field — confirm whether any issuer emits it.
    pub aud: Option<String>,
    /// Expiration time (Unix timestamp, seconds).
    pub exp: Option<u64>,
    /// Not before (Unix timestamp).
    pub nbf: Option<u64>,
    /// Scope string.
    pub scope: Option<String>,
}
25
/// Errors from JWT verification.
///
/// Each variant's `Display` text comes from its `#[error(...)]` attribute
/// (derived via `thiserror`).
#[derive(Debug, thiserror::Error)]
pub enum JwtError {
    /// The token is not structurally valid.
    ///
    /// Carries the underlying decoder's error message.
    #[error("malformed JWT: {0}")]
    Malformed(String),
    /// The token has expired.
    #[error("JWT expired")]
    Expired,
    /// The audience claim doesn't match.
    #[error("JWT audience mismatch: expected {expected}, got {actual}")]
    AudienceMismatch {
        /// Expected audience (the host the caller asked to match).
        expected: String,
        /// Actual audience in the token, verbatim.
        actual: String,
    },
    /// The issuer could not be resolved.
    ///
    /// NOTE(review): not constructed anywhere in this file; presumably
    /// produced by signature-verification code elsewhere — confirm.
    #[error("could not resolve JWT issuer: {0}")]
    IssuerResolution(String),
    /// Signature verification failed.
    ///
    /// NOTE(review): not constructed anywhere in this file; presumably
    /// produced by signature-verification code elsewhere — confirm.
    #[error("JWT signature verification failed: {0}")]
    SignatureInvalid(String),
}
50
/// Cheap structural test: does `token` have the compact-JWT shape?
///
/// Returns `true` when the token consists of exactly three non-empty segments
/// separated by `.`. The segment contents are NOT checked against the
/// base64url alphabet; this is only a dispatch heuristic, not validation.
pub fn looks_like_jwt(token: &str) -> bool {
    let mut segment_count = 0usize;
    for segment in token.split('.') {
        // Any empty segment ("a..b", ".a.b", "a.b.", "") rules the token out.
        if segment.is_empty() {
            return false;
        }
        segment_count += 1;
    }
    segment_count == 3
}
56
57/// Decode JWT claims WITHOUT verifying the signature.
58///
59/// This is used for the initial dispatch to determine if a bearer token
60/// is a JWT or an atrg session token.
61pub fn decode_claims_unverified(token: &str) -> Result<JwtClaims, JwtError> {
62    let header = decode_header(token).map_err(|e| JwtError::Malformed(e.to_string()))?;
63
64    // Decode payload without verification
65    let mut validation = Validation::new(header.alg);
66    validation.insecure_disable_signature_validation();
67    validation.validate_exp = false;
68    validation.validate_nbf = false;
69    validation.validate_aud = false;
70    validation.required_spec_claims.clear();
71
72    let token_data = decode::<JwtClaims>(token, &DecodingKey::from_secret(b""), &validation)
73        .map_err(|e| JwtError::Malformed(e.to_string()))?;
74
75    Ok(token_data.claims)
76}
77
78/// Verify a JWT's expiration claim.
79pub fn verify_expiration(claims: &JwtClaims) -> Result<(), JwtError> {
80    if let Some(exp) = claims.exp {
81        let now = std::time::SystemTime::now()
82            .duration_since(std::time::UNIX_EPOCH)
83            .unwrap_or_default()
84            .as_secs();
85        if now > exp {
86            return Err(JwtError::Expired);
87        }
88    }
89    Ok(())
90}
91
92/// Verify the audience claim matches the expected value.
93pub fn verify_audience(claims: &JwtClaims, expected_host: &str) -> Result<(), JwtError> {
94    if let Some(ref aud) = claims.aud {
95        if !aud.contains(expected_host) {
96            return Err(JwtError::AudienceMismatch {
97                expected: expected_host.to_string(),
98                actual: aud.clone(),
99            });
100        }
101    }
102    Ok(())
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds claims with fixed `iss`/`sub` and the given audience and expiry.
    fn claims(aud: Option<&str>, exp: Option<u64>) -> JwtClaims {
        JwtClaims {
            iss: "test".into(),
            sub: "test".into(),
            aud: aud.map(String::from),
            exp,
            nbf: None,
            scope: None,
        }
    }

    /// Base64url-encodes (without padding) one JWT segment.
    fn base64_url_encode(data: &[u8]) -> String {
        use base64::Engine;
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
    }

    #[test]
    fn looks_like_jwt_valid() {
        let token = "eyJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJ0ZXN0In0.sig";
        assert!(looks_like_jwt(token));
    }

    #[test]
    fn looks_like_jwt_not_jwt() {
        for candidate in ["just-a-session-token", "two.parts", "", "a..b"] {
            assert!(
                !looks_like_jwt(candidate),
                "{candidate:?} should not look like a JWT"
            );
        }
    }

    #[test]
    fn decode_unverified_valid() {
        // A minimal JWT with a bogus signature; the signature is never checked.
        let header = base64_url_encode(br#"{"alg":"HS256"}"#);
        let payload = base64_url_encode(br#"{"iss":"did:plc:test","sub":"did:plc:user"}"#);
        let token = [header.as_str(), payload.as_str(), "fakesig"].join(".");

        let decoded = decode_claims_unverified(&token).unwrap();
        assert_eq!(decoded.iss, "did:plc:test");
        assert_eq!(decoded.sub, "did:plc:user");
    }

    #[test]
    fn decode_unverified_malformed() {
        assert!(decode_claims_unverified("not-a-jwt").is_err());
    }

    #[test]
    fn verify_expiration_valid() {
        assert!(verify_expiration(&claims(None, Some(u64::MAX))).is_ok());
    }

    #[test]
    fn verify_expiration_expired() {
        let outcome = verify_expiration(&claims(None, Some(0)));
        assert!(matches!(outcome, Err(JwtError::Expired)));
    }

    #[test]
    fn verify_expiration_none_is_ok() {
        assert!(verify_expiration(&claims(None, None)).is_ok());
    }

    #[test]
    fn verify_audience_match() {
        let token_claims = claims(Some("https://myapp.example.com"), None);
        assert!(verify_audience(&token_claims, "myapp.example.com").is_ok());
    }

    #[test]
    fn verify_audience_mismatch() {
        let token_claims = claims(Some("https://other.example.com"), None);
        let outcome = verify_audience(&token_claims, "myapp.example.com");
        assert!(matches!(outcome, Err(JwtError::AudienceMismatch { .. })));
    }

    #[test]
    fn verify_audience_none_is_ok() {
        assert!(verify_audience(&claims(None, None), "myapp.example.com").is_ok());
    }
}