workspace-mcp 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
auth/oauth21/jwt.py ADDED
@@ -0,0 +1,438 @@
1
+ """
2
+ JWT Handler
3
+
4
+ Specialized JWT parsing and validation functionality with JWKS support.
5
+ Complements the token validator with JWT-specific features.
6
+ """
7
+
8
+ import logging
9
+ from typing import Dict, Any, List, Optional, Union
10
+ from datetime import datetime, timezone
11
+ from cryptography.hazmat.primitives.asymmetric import rsa, ec
12
+ from cryptography.hazmat.backends import default_backend
13
+
14
+ import aiohttp
15
+ import jwt
16
+ from cachetools import TTLCache
17
+
18
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
19
+
20
+
21
class JWTHandler:
    """Handles JWT parsing and validation with JWKS support.

    Signature verification looks up the token's key in a JWKS document that is
    fetched over HTTP and cached per URI for ``jwks_cache_ttl`` seconds.
    """

    # JWK "kty" required by each JWS algorithm family (RFC 7518, section 3.1).
    # Matching on it prevents algorithm confusion, e.g. an RSA public key being
    # offered for an HS256 token, or a symmetric key for an RS256 token.
    _ALG_FAMILY_KTY = {"HS": "oct", "RS": "RSA", "PS": "RSA", "ES": "EC"}

    def __init__(
        self,
        jwks_cache_ttl: int = 3600,  # 1 hour
        max_jwks_cache_size: int = 50,
    ):
        """
        Initialize the JWT handler.

        Args:
            jwks_cache_ttl: JWKS cache TTL in seconds
            max_jwks_cache_size: Maximum number of cached JWKS entries
        """
        self.jwks_cache = TTLCache(maxsize=max_jwks_cache_size, ttl=jwks_cache_ttl)
        self._session: Optional[aiohttp.ClientSession] = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it if missing or closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=30),
                headers={"User-Agent": "MCP-JWT-Handler/1.0"},
            )
        return self._session

    async def close(self) -> None:
        """Close the HTTP session if one is open."""
        if self._session and not self._session.closed:
            await self._session.close()

    def decode_jwt_header(self, token: str) -> Dict[str, Any]:
        """
        Decode JWT header without verification.

        Args:
            token: JWT token

        Returns:
            JWT header dictionary

        Raises:
            jwt.InvalidTokenError: If token format is invalid
        """
        try:
            return jwt.get_unverified_header(token)
        except Exception as e:
            logger.error(f"Failed to decode JWT header: {e}")
            raise jwt.InvalidTokenError(f"Invalid JWT header: {str(e)}") from e

    def decode_jwt_payload(self, token: str) -> Dict[str, Any]:
        """
        Decode JWT payload without verification.

        Args:
            token: JWT token

        Returns:
            JWT payload dictionary

        Raises:
            jwt.InvalidTokenError: If token format is invalid
        """
        try:
            return jwt.decode(token, options={"verify_signature": False})
        except Exception as e:
            logger.error(f"Failed to decode JWT payload: {e}")
            raise jwt.InvalidTokenError(f"Invalid JWT payload: {str(e)}") from e

    async def decode_jwt(
        self,
        token: str,
        jwks_uri: Optional[str] = None,
        audience: Optional[Union[str, List[str]]] = None,
        issuer: Optional[str] = None,
        algorithms: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Decode and verify JWT signature with JWKS.

        Args:
            token: JWT token to decode
            jwks_uri: JWKS endpoint URI; without it no key can be found and
                verification fails
            audience: Expected audience(s); audience is only checked when given
            issuer: Expected issuer; issuer is only checked when given
            algorithms: Allowed signing algorithms
                (default ``["RS256", "ES256", "HS256"]``)

        Returns:
            Verified JWT payload

        Raises:
            jwt.InvalidTokenError: If JWT verification fails. PyJWT's specific
                subclasses (expired, bad audience/issuer/signature, immature,
                malformed, ...) are re-raised with their original type.
            aiohttp.ClientError: If the JWKS cannot be fetched
        """
        if algorithms is None:
            # NOTE(review): HS256 can only succeed when the JWKS publishes an
            # "oct" (symmetric) key; consider dropping it for public JWKS endpoints.
            algorithms = ["RS256", "ES256", "HS256"]

        # Get JWT header to find key ID
        header = self.decode_jwt_header(token)
        kid = header.get("kid")
        alg = header.get("alg")

        # Enforce the caller's allow-list before touching any key material.
        if alg not in algorithms:
            raise jwt.InvalidTokenError(f"Algorithm {alg} not allowed")

        # Fetch JWKS if URI provided.
        # NOTE(review): a kid missing from a still-cached JWKS (e.g. right after
        # key rotation) is not refetched until the cache entry expires.
        verification_key = None
        if jwks_uri:
            jwks = await self.fetch_jwks(jwks_uri)
            verification_key = self._find_key_in_jwks(jwks, kid, alg)

        if verification_key is None:
            raise jwt.InvalidTokenError("No valid verification key found")

        # Verify and decode JWT
        try:
            payload = jwt.decode(
                token,
                key=verification_key,
                # Pin verification to the header algorithm already allow-listed above.
                algorithms=[alg] if alg else algorithms,
                audience=audience,
                issuer=issuer,
                options={
                    "verify_signature": True,
                    "verify_exp": True,
                    "verify_aud": audience is not None,
                    "verify_iss": issuer is not None,
                },
            )

            logger.debug("Successfully decoded and verified JWT")
            return payload

        except jwt.ExpiredSignatureError:
            logger.warning("JWT token has expired")
            raise
        except jwt.InvalidAudienceError:
            logger.warning("JWT audience validation failed")
            raise
        except jwt.InvalidIssuerError:
            logger.warning("JWT issuer validation failed")
            raise
        except jwt.InvalidSignatureError:
            logger.warning("JWT signature verification failed")
            raise
        except jwt.InvalidTokenError as e:
            # Other PyJWT validation failures (e.g. ImmatureSignatureError,
            # DecodeError) keep their precise type; all subclass InvalidTokenError.
            logger.warning(f"JWT verification failed: {e}")
            raise
        except Exception as e:
            logger.error(f"JWT verification failed: {e}")
            raise jwt.InvalidTokenError(f"JWT verification failed: {str(e)}") from e

    async def fetch_jwks(self, jwks_uri: str) -> Dict[str, Any]:
        """
        Fetch and cache JWKS from URI.

        Args:
            jwks_uri: JWKS endpoint URI

        Returns:
            JWKS dictionary (guaranteed to have a ``"keys"`` list)

        Raises:
            aiohttp.ClientError: If JWKS cannot be fetched, is not JSON, or is
                not a valid JWK Set
        """
        # Single lookup: a separate `in` check followed by `[]` could race with
        # TTL expiry between the two calls.
        cached = self.jwks_cache.get(jwks_uri)
        if cached is not None:
            logger.debug(f"Using cached JWKS for {jwks_uri}")
            return cached

        session = await self._get_session()

        try:
            logger.debug(f"Fetching JWKS from {jwks_uri}")
            async with session.get(jwks_uri) as response:
                if response.status != 200:
                    raise aiohttp.ClientError(f"JWKS fetch failed: {response.status}")

                # content_type=None disables aiohttp's "application/json" check:
                # JWKS endpoints may legitimately serve "application/jwk-set+json"
                # (RFC 7517, section 8.5).
                jwks = await response.json(content_type=None)

            # Validate JWKS format: "keys" must be an array of JWKs (RFC 7517, section 5).
            if not isinstance(jwks, dict) or not isinstance(jwks.get("keys"), list):
                raise ValueError("Invalid JWKS format")

            self.jwks_cache[jwks_uri] = jwks
            logger.info(f"Successfully fetched and cached JWKS from {jwks_uri}")
            return jwks

        except aiohttp.ClientError as e:
            logger.error(f"Failed to fetch JWKS from {jwks_uri}: {e}")
            raise
        except Exception as e:
            logger.error(f"Failed to fetch JWKS from {jwks_uri}: {e}")
            raise aiohttp.ClientError(f"JWKS fetch failed: {str(e)}") from e

    @classmethod
    def _expected_kty(cls, alg: Optional[str]) -> Optional[str]:
        """Return the JWK "kty" required by `alg`, or None if absent or unrecognized."""
        if not isinstance(alg, str):
            return None
        return cls._ALG_FAMILY_KTY.get(alg[:2])

    def _find_key_in_jwks(
        self,
        jwks: Dict[str, Any],
        kid: Optional[str] = None,
        alg: Optional[str] = None,
    ) -> Optional[Any]:
        """
        Find appropriate key in JWKS for token verification.

        A key matches when its "kid" equals `kid` (if given), its "alg" (if
        present) equals `alg`, and its "kty" fits the algorithm family of `alg`.
        The first matching key that converts successfully is returned.

        Args:
            jwks: JWKS dictionary
            kid: Key ID from JWT header
            alg: Algorithm from JWT header

        Returns:
            Verification key or None if not found
        """
        keys = jwks.get("keys", [])
        if not isinstance(keys, list):
            logger.warning("JWKS 'keys' member is not a list")
            keys = []

        expected_kty = self._expected_kty(alg)

        for key_data in keys:
            # Skip malformed entries instead of failing the whole lookup.
            if not isinstance(key_data, dict):
                continue

            # Match by key ID if provided
            if kid and key_data.get("kid") != kid:
                continue

            # Match by algorithm if provided
            if alg and key_data.get("alg") and key_data.get("alg") != alg:
                continue

            # Key type must fit the algorithm family (algorithm-confusion guard).
            if expected_kty and key_data.get("kty") != expected_kty:
                continue

            # Convert JWK to key object
            try:
                key = self._jwk_to_key(key_data)
                if key is not None:
                    logger.debug(f"Found matching key in JWKS: kid={key_data.get('kid')}")
                    return key
            except Exception as e:
                logger.warning(f"Failed to convert JWK to key: {e}")
                continue

        logger.warning(f"No matching key found in JWKS for kid={kid}, alg={alg}")
        return None

    def _jwk_to_key(self, jwk: Dict[str, Any]) -> Optional[Any]:
        """
        Convert JWK (JSON Web Key) to cryptographic key object.

        Args:
            jwk: JWK dictionary

        Returns:
            Key object for verification, or None if the key is not a signing
            key, has an unsupported type, or cannot be converted
        """
        kty = jwk.get("kty")
        use = jwk.get("use")

        # Skip keys not for signature verification
        if use and use != "sig":
            return None

        try:
            if kty == "RSA":
                return self._jwk_to_rsa_key(jwk)
            elif kty == "EC":
                return self._jwk_to_ec_key(jwk)
            elif kty == "oct":
                return self._jwk_to_symmetric_key(jwk)
            else:
                logger.warning(f"Unsupported key type: {kty}")
                return None
        except Exception as e:
            logger.error(f"Failed to convert {kty} JWK to key: {e}")
            return None

    @staticmethod
    def _b64url_decode(value: str) -> bytes:
        """Decode unpadded base64url (RFC 7515, section 2), restoring "=" padding."""
        import base64

        return base64.urlsafe_b64decode(value + "=" * (-len(value) % 4))

    @classmethod
    def _b64url_to_int(cls, value: str) -> int:
        """Decode a base64url big-endian unsigned integer (RFC 7518 "Base64urlUInt")."""
        return int.from_bytes(cls._b64url_decode(value), byteorder="big")

    def _jwk_to_rsa_key(self, jwk: Dict[str, Any]) -> rsa.RSAPublicKey:
        """Convert RSA JWK to RSA public key."""
        modulus_b64 = jwk.get("n")
        exponent_b64 = jwk.get("e")

        if not modulus_b64 or not exponent_b64:
            raise ValueError("RSA JWK missing n or e parameter")

        n_int = self._b64url_to_int(modulus_b64)
        e_int = self._b64url_to_int(exponent_b64)

        return rsa.RSAPublicNumbers(e_int, n_int).public_key(default_backend())

    def _jwk_to_ec_key(self, jwk: Dict[str, Any]) -> ec.EllipticCurvePublicKey:
        """Convert EC JWK to EC public key."""
        crv = jwk.get("crv")
        x = jwk.get("x")
        y = jwk.get("y")

        if not all([crv, x, y]):
            raise ValueError("EC JWK missing required parameters")

        # Map JWK curve names to cryptography curve instances
        curve_map = {
            "P-256": ec.SECP256R1(),
            "P-384": ec.SECP384R1(),
            "P-521": ec.SECP521R1(),
        }

        curve = curve_map.get(crv)
        if not curve:
            raise ValueError(f"Unsupported EC curve: {crv}")

        x_int = self._b64url_to_int(x)
        y_int = self._b64url_to_int(y)

        return ec.EllipticCurvePublicNumbers(x_int, y_int, curve).public_key(default_backend())

    def _jwk_to_symmetric_key(self, jwk: Dict[str, Any]) -> bytes:
        """Convert symmetric JWK to key bytes."""
        k = jwk.get("k")
        if not k:
            raise ValueError("Symmetric JWK missing k parameter")

        return self._b64url_decode(k)

    def extract_claims(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extract and normalize standard JWT claims.

        Registered claims are renamed (e.g. "exp" -> "expires_at") and the
        time claims are converted to timezone-aware UTC datetimes.

        Args:
            payload: JWT payload

        Returns:
            Dictionary of normalized claims
        """
        claims = {}

        # Standard claims with normalization
        claim_mapping = {
            "iss": "issuer",
            "sub": "subject",
            "aud": "audience",
            "exp": "expires_at",
            "nbf": "not_before",
            "iat": "issued_at",
            "jti": "jwt_id",
        }

        for jwt_claim, normalized_name in claim_mapping.items():
            if jwt_claim in payload:
                claims[normalized_name] = payload[jwt_claim]

        # Convert timestamps
        for time_claim in ["expires_at", "not_before", "issued_at"]:
            if time_claim in claims:
                claims[time_claim] = self._timestamp_to_datetime(claims[time_claim])

        # Additional common claims
        for claim in ["email", "email_verified", "name", "preferred_username", "scope", "scp"]:
            if claim in payload:
                claims[claim] = payload[claim]

        return claims

    def _timestamp_to_datetime(self, timestamp: Union[int, float]) -> datetime:
        """Convert Unix timestamp to a UTC datetime.

        Falls back to the current time if the value is not a usable timestamp.
        This includes non-numbers, NaN, infinity, and values outside the
        platform's range.
        """
        try:
            return datetime.fromtimestamp(timestamp, tz=timezone.utc)
        except (ValueError, TypeError, OverflowError, OSError) as e:
            logger.warning(f"Invalid timestamp: {timestamp}: {e}")
            return datetime.now(timezone.utc)

    def is_jwt_expired(self, payload: Dict[str, Any]) -> bool:
        """
        Check if JWT is expired based on exp claim.

        Args:
            payload: JWT payload

        Returns:
            True if JWT is expired; False if there is no exp claim. An
            unparseable exp is treated as expired.
        """
        exp = payload.get("exp")
        # Compare against None so that exp == 0 (the epoch) counts as expired.
        if exp is None:
            return False

        exp_time = self._timestamp_to_datetime(exp)
        return datetime.now(timezone.utc) >= exp_time

    def get_jwt_info(self, token: str) -> Dict[str, Any]:
        """
        Extract JWT information without verification.

        Args:
            token: JWT token

        Returns:
            Dictionary with JWT information, or ``{"error": <message>}`` if the
            token cannot be decoded
        """
        try:
            header = self.decode_jwt_header(token)
            payload = self.decode_jwt_payload(token)

            return {
                "header": header,
                "payload": payload,
                "claims": self.extract_claims(payload),
                "expired": self.is_jwt_expired(payload),
                "algorithm": header.get("alg"),
                "key_id": header.get("kid"),
                "token_type": header.get("typ", "JWT"),
            }
        except Exception as e:
            logger.error(f"Failed to extract JWT info: {e}")
            return {"error": str(e)}