amazon-ads-mcp 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. amazon_ads_mcp/__init__.py +11 -0
  2. amazon_ads_mcp/auth/__init__.py +33 -0
  3. amazon_ads_mcp/auth/base.py +211 -0
  4. amazon_ads_mcp/auth/hooks.py +172 -0
  5. amazon_ads_mcp/auth/manager.py +791 -0
  6. amazon_ads_mcp/auth/oauth_state_store.py +277 -0
  7. amazon_ads_mcp/auth/providers/__init__.py +14 -0
  8. amazon_ads_mcp/auth/providers/direct.py +393 -0
  9. amazon_ads_mcp/auth/providers/example_auth0.py.example +216 -0
  10. amazon_ads_mcp/auth/providers/openbridge.py +512 -0
  11. amazon_ads_mcp/auth/registry.py +146 -0
  12. amazon_ads_mcp/auth/secure_token_store.py +297 -0
  13. amazon_ads_mcp/auth/token_store.py +723 -0
  14. amazon_ads_mcp/config/__init__.py +5 -0
  15. amazon_ads_mcp/config/sampling.py +111 -0
  16. amazon_ads_mcp/config/settings.py +366 -0
  17. amazon_ads_mcp/exceptions.py +314 -0
  18. amazon_ads_mcp/middleware/__init__.py +11 -0
  19. amazon_ads_mcp/middleware/authentication.py +1474 -0
  20. amazon_ads_mcp/middleware/caching.py +177 -0
  21. amazon_ads_mcp/middleware/oauth.py +175 -0
  22. amazon_ads_mcp/middleware/sampling.py +112 -0
  23. amazon_ads_mcp/models/__init__.py +320 -0
  24. amazon_ads_mcp/models/amc_models.py +837 -0
  25. amazon_ads_mcp/models/api_responses.py +847 -0
  26. amazon_ads_mcp/models/base_models.py +215 -0
  27. amazon_ads_mcp/models/builtin_responses.py +496 -0
  28. amazon_ads_mcp/models/dsp_models.py +556 -0
  29. amazon_ads_mcp/models/stores_brands.py +610 -0
  30. amazon_ads_mcp/server/__init__.py +6 -0
  31. amazon_ads_mcp/server/__main__.py +6 -0
  32. amazon_ads_mcp/server/builtin_prompts.py +269 -0
  33. amazon_ads_mcp/server/builtin_tools.py +962 -0
  34. amazon_ads_mcp/server/file_routes.py +547 -0
  35. amazon_ads_mcp/server/html_templates.py +149 -0
  36. amazon_ads_mcp/server/mcp_server.py +327 -0
  37. amazon_ads_mcp/server/openapi_utils.py +158 -0
  38. amazon_ads_mcp/server/sampling_handler.py +251 -0
  39. amazon_ads_mcp/server/server_builder.py +751 -0
  40. amazon_ads_mcp/server/sidecar_loader.py +178 -0
  41. amazon_ads_mcp/server/transform_executor.py +827 -0
  42. amazon_ads_mcp/tools/__init__.py +22 -0
  43. amazon_ads_mcp/tools/cache_management.py +105 -0
  44. amazon_ads_mcp/tools/download_tools.py +267 -0
  45. amazon_ads_mcp/tools/identity.py +236 -0
  46. amazon_ads_mcp/tools/oauth.py +598 -0
  47. amazon_ads_mcp/tools/profile.py +150 -0
  48. amazon_ads_mcp/tools/profile_listing.py +285 -0
  49. amazon_ads_mcp/tools/region.py +320 -0
  50. amazon_ads_mcp/tools/region_identity.py +175 -0
  51. amazon_ads_mcp/utils/__init__.py +6 -0
  52. amazon_ads_mcp/utils/async_compat.py +215 -0
  53. amazon_ads_mcp/utils/errors.py +452 -0
  54. amazon_ads_mcp/utils/export_content_type_resolver.py +249 -0
  55. amazon_ads_mcp/utils/export_download_handler.py +579 -0
  56. amazon_ads_mcp/utils/header_resolver.py +81 -0
  57. amazon_ads_mcp/utils/http/__init__.py +56 -0
  58. amazon_ads_mcp/utils/http/circuit_breaker.py +127 -0
  59. amazon_ads_mcp/utils/http/client_manager.py +329 -0
  60. amazon_ads_mcp/utils/http/request.py +207 -0
  61. amazon_ads_mcp/utils/http/resilience.py +512 -0
  62. amazon_ads_mcp/utils/http/resilient_client.py +195 -0
  63. amazon_ads_mcp/utils/http/retry.py +76 -0
  64. amazon_ads_mcp/utils/http_client.py +873 -0
  65. amazon_ads_mcp/utils/media/__init__.py +21 -0
  66. amazon_ads_mcp/utils/media/negotiator.py +243 -0
  67. amazon_ads_mcp/utils/media/types.py +199 -0
  68. amazon_ads_mcp/utils/openapi/__init__.py +16 -0
  69. amazon_ads_mcp/utils/openapi/json.py +55 -0
  70. amazon_ads_mcp/utils/openapi/loader.py +263 -0
  71. amazon_ads_mcp/utils/openapi/refs.py +46 -0
  72. amazon_ads_mcp/utils/region_config.py +200 -0
  73. amazon_ads_mcp/utils/response_wrapper.py +171 -0
  74. amazon_ads_mcp/utils/sampling_helpers.py +156 -0
  75. amazon_ads_mcp/utils/sampling_wrapper.py +173 -0
  76. amazon_ads_mcp/utils/security.py +630 -0
  77. amazon_ads_mcp/utils/tool_naming.py +137 -0
  78. amazon_ads_mcp-0.2.7.dist-info/METADATA +664 -0
  79. amazon_ads_mcp-0.2.7.dist-info/RECORD +82 -0
  80. amazon_ads_mcp-0.2.7.dist-info/WHEEL +4 -0
  81. amazon_ads_mcp-0.2.7.dist-info/entry_points.txt +3 -0
  82. amazon_ads_mcp-0.2.7.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1474 @@
1
+ """Reusable FastMCP Authentication Middleware - Version 5.0 (Production Ready).
2
+
3
+ This module provides a comprehensive authentication middleware system for
4
+ FastMCP servers with support for JWT validation, refresh token conversion,
5
+ and context-safe token sharing between middleware components.
6
+
7
+ The module provides:
8
+ - AuthConfig: Configuration management for authentication settings
9
+ - JWTCache: Thread-safe JWT caching with automatic cleanup
10
+ - RefreshTokenMiddleware: Converts refresh tokens to JWT tokens
11
+ - JWTAuthenticationMiddleware: Validates JWT tokens with comprehensive error handling
12
+ - Utility functions for accessing JWT data and creating middleware chains
13
+ - Pre-configured configurations for common providers (OpenBridge, Auth0, JSON:API)
14
+
15
+ Key Features:
16
+ - FastMCP-compliant middleware patterns with proper hooks
17
+ - Context-safe JWT storage using contextvars for async safety
18
+ - JWT caching to reduce API calls and improve performance
19
+ - Comprehensive error handling and detailed logging
20
+ - OpenBridge-specific validation (user_id, account_id claims)
21
+ - Environment variable configuration for operator control
22
+ - Client disconnection handling and timeout management
23
+ - Support for multiple authentication providers
24
+
25
+ Examples:
26
+ >>> from .middleware.authentication import create_auth_middleware
27
+ >>> middleware = create_auth_middleware() # Auto-configure from environment
28
+
29
+ >>> # Use with specific configuration
30
+ >>> config = AuthConfig()
31
+ >>> config.load_from_env()
32
+ >>> middleware = create_auth_middleware(config)
33
+
34
+ >>> # Access JWT data in other parts of the application
35
+ >>> from .middleware.authentication import get_current_claims
36
+ >>> claims = get_current_claims()
37
+ >>> print(f"User ID: {claims.get('user_id')}")
38
+ """
39
+
40
+ import logging
41
+ import os
42
+ import threading
43
+ import time
44
+ from contextvars import ContextVar
45
+ from datetime import datetime, timedelta, timezone
46
+ from typing import Any, Callable, Dict, Optional, Tuple
47
+
48
+ import httpx
49
+ import jwt
50
+ from fastmcp.exceptions import ToolError
51
+ from fastmcp.server.middleware import Middleware, MiddlewareContext
52
+
53
+ from ..utils.http import get_http_client
54
+ from ..utils.security import sanitize_string
55
+
56
logger = logging.getLogger(__name__)

# Per-request values shared between the refresh-token middleware (writer) and
# the JWT authentication middleware (reader). ContextVar is used instead of
# threading.local() so each asyncio task sees its own value.
jwt_token_var: ContextVar[Optional[str]] = ContextVar("jwt_token", default=None)
jwt_claims_var: ContextVar[Optional[dict]] = ContextVar("jwt_claims", default=None)
62
+
63
+
64
class AuthConfig:
    """Configuration for authentication middleware.

    This class manages all configuration settings for the authentication
    middleware system, including JWT validation, refresh token conversion,
    and caching settings. It supports loading configuration from environment
    variables and provides validation methods.

    The class handles:
    - JWT validation settings (issuer, audience, signature verification)
    - Refresh token conversion settings and handlers
    - Caching configuration and TTL settings
    - Environment variable loading and validation
    - Provider-specific configurations

    Examples:
        >>> config = AuthConfig()
        >>> config.load_from_env()
        >>> if config.validate():
        ...     print("Configuration is valid")

        >>> # Configure refresh token handlers
        >>> config.set_refresh_token_handlers(
        ...     request_builder=lambda token: {"token": token},
        ...     response_parser=lambda data: data.get("jwt")
        ... )
    """

    def __init__(self):
        # General settings (all features off until configured)
        self.enabled = False
        self.jwt_validation_enabled = False
        self.refresh_token_enabled = False

        # JWT validation settings
        self.jwt_issuer: Optional[str] = None
        self.jwt_audience: Optional[str] = None
        self.jwt_jwks_uri: Optional[str] = None
        self.jwt_public_key: Optional[str] = None
        self.jwt_verify_signature = True
        self.jwt_verify_iss = True
        self.jwt_verify_aud = True
        self.jwt_required_claims: list[str] = []

        # Refresh token settings
        self.refresh_token_endpoint: Optional[str] = None
        self.refresh_token_request_builder: Optional[Callable[[str], dict]] = None
        self.refresh_token_response_parser: Optional[
            Callable[[dict], Optional[str]]
        ] = None
        self.refresh_token_pattern: Optional[Callable[[str], bool]] = None

        # Caching settings
        self.jwt_cache_ttl = 3000  # 50 minutes (OpenBridge JWTs typically last 1 hour)
        self.cache_cleanup_interval = 300  # 5 minutes

    def load_from_env(self) -> None:
        """Load configuration from environment variables.

        Boolean variables are true only when their lowercased value is
        exactly ``"true"``. Unset optional variables leave the corresponding
        attribute as ``None`` (string settings) or unchanged (required
        claims, cache TTL, and an already-configured refresh endpoint).

        Environment Variables:
        - AUTH_ENABLED: Enable/disable authentication (default: false)
        - JWT_VALIDATION_ENABLED: Enable JWT validation (default: true)
        - REFRESH_TOKEN_ENABLED: Enable refresh token conversion (default: false)
        - JWT_ISSUER: JWT issuer for validation
        - JWT_AUDIENCE: JWT audience for validation
        - JWT_JWKS_URI: JWKS endpoint for public key retrieval
        - JWT_PUBLIC_KEY: Static public key for JWT validation
        - JWT_VERIFY_SIGNATURE: Enable signature verification (default: true)
        - JWT_VERIFY_ISS: Enable issuer verification (default: true)
        - JWT_VERIFY_AUD: Enable audience verification (default: true)
        - JWT_REQUIRED_CLAIMS: Comma-separated list of required claims;
          surrounding whitespace is stripped and blank entries are ignored
        - REFRESH_TOKEN_ENDPOINT: Endpoint for refresh token conversion
          (only used if no endpoint was configured programmatically)
        - JWT_CACHE_TTL: JWT cache TTL in seconds (default: 3000); a
          non-integer value logs a warning and keeps the current TTL

        Examples:
            >>> config = AuthConfig()
            >>> config.load_from_env()
            >>> print(f"Authentication enabled: {config.enabled}")
        """
        self.enabled = os.getenv("AUTH_ENABLED", "false").lower() == "true"
        self.jwt_validation_enabled = (
            os.getenv("JWT_VALIDATION_ENABLED", "true").lower() == "true"
        )
        self.refresh_token_enabled = (
            os.getenv("REFRESH_TOKEN_ENABLED", "false").lower() == "true"
        )

        # JWT settings
        self.jwt_issuer = os.getenv("JWT_ISSUER")
        self.jwt_audience = os.getenv("JWT_AUDIENCE")
        self.jwt_jwks_uri = os.getenv("JWT_JWKS_URI")
        self.jwt_public_key = os.getenv("JWT_PUBLIC_KEY")
        self.jwt_verify_signature = (
            os.getenv("JWT_VERIFY_SIGNATURE", "true").lower() == "true"
        )
        self.jwt_verify_iss = os.getenv("JWT_VERIFY_ISS", "true").lower() == "true"
        self.jwt_verify_aud = os.getenv("JWT_VERIFY_AUD", "true").lower() == "true"

        raw_claims = os.getenv("JWT_REQUIRED_CLAIMS")
        if raw_claims:
            # Drop blank entries so "a,,b", a trailing comma, or whitespace-only
            # input never registers an empty-string claim name as required.
            self.jwt_required_claims = [
                claim.strip() for claim in raw_claims.split(",") if claim.strip()
            ]

        # Refresh token settings - only set if not already configured
        if not self.refresh_token_endpoint:
            self.refresh_token_endpoint = os.getenv("REFRESH_TOKEN_ENDPOINT")

        # Cache settings
        cache_ttl = os.getenv("JWT_CACHE_TTL")
        if cache_ttl:
            try:
                self.jwt_cache_ttl = int(cache_ttl)
            except ValueError:
                logger.warning(f"Invalid JWT_CACHE_TTL: {cache_ttl}, using default")

    def set_refresh_token_handlers(
        self,
        request_builder: Callable[[str], dict],
        response_parser: Callable[[dict], Optional[str]],
        pattern_detector: Optional[Callable[[str], bool]] = None,
    ):
        """Set handlers for refresh token conversion.

        These handlers define how to build the request sent to the refresh
        token endpoint and how to extract the JWT from its response.

        Args:
            request_builder: Function that takes a refresh token and returns
                the JSON payload for the refresh token endpoint.
            response_parser: Function that takes the decoded response body and
                returns the JWT token, or None if parsing fails.
            pattern_detector: Optional function that takes a token and returns
                True if it should be treated as a refresh token. It is stored
                as given; when None, no detector is stored by this method.
                NOTE(review): RefreshTokenMiddleware in this module skips JWT
                conversion when no detector is set — confirm whether a
                default is supplied elsewhere (e.g. auto-configuration).

        Examples:
            >>> def build_request(token):
            ...     return {"refresh_token": token}

            >>> def parse_response(data):
            ...     return data.get("access_token")

            >>> def detect_pattern(token):
            ...     return ":" in token and len(token) > 20

            >>> config.set_refresh_token_handlers(
            ...     build_request, parse_response, detect_pattern
            ... )
        """
        self.refresh_token_request_builder = request_builder
        self.refresh_token_response_parser = response_parser
        self.refresh_token_pattern = pattern_detector

    def validate(self) -> bool:
        """Validate configuration.

        Returns:
            True if the configuration is usable, False otherwise. Disabled
            authentication is always considered valid.

        Validation Rules:
        - If JWT validation is enabled with neither signature verification
          nor required claims configured, a warning is logged (the
          configuration is still accepted).
        - If refresh token conversion is enabled, the endpoint must be
          configured; otherwise an error is logged and False is returned.
        - Missing refresh token handlers are only reported at info level;
          they are expected to be auto-configured elsewhere.

        Examples:
            >>> config = AuthConfig()
            >>> config.load_from_env()
            >>> if config.validate():
            ...     print("Configuration is valid")
            ... else:
            ...     print("Configuration has issues")
        """
        if not self.enabled:
            return True

        if self.jwt_validation_enabled:
            if not self.jwt_verify_signature and not self.jwt_required_claims:
                logger.warning(
                    "JWT validation enabled but no signature verification or required claims configured"
                )

        if self.refresh_token_enabled:
            if not self.refresh_token_endpoint:
                logger.error("Refresh token enabled but no endpoint configured")
                return False
            # Allow auto-configuration to handle missing handlers
            if (
                not self.refresh_token_request_builder
                or not self.refresh_token_response_parser
            ):
                logger.info(
                    "Refresh token handlers not configured - will be auto-configured"
                )

        return True
276
+
277
+
278
class JWTCache:
    """Thread-safe JWT cache with automatic cleanup.

    Stores JWT strings under an arbitrary string key, each with an absolute
    expiry computed as ``now + ttl`` at insertion time. Expired entries are
    dropped when looked up and swept in bulk at most once per
    ``cleanup_interval`` seconds. A ``threading.Lock`` guards the dict.

    When an ``auth_manager`` exposing ``set_token`` is supplied, ``set()``
    additionally schedules a best-effort write to it, but only when called
    from inside a running asyncio event loop. Lookups are always served from
    the local dict.

    Examples:
        >>> cache = JWTCache(ttl=3000, cleanup_interval=300)
        >>> cache.set("user123", "jwt_token_here")
        >>> token = cache.get("user123")
        >>> print(f"Retrieved token: {token is not None}")
    """

    def __init__(self, ttl: int = 3000, cleanup_interval: int = 300, auth_manager=None):
        self._cache: Dict[str, Tuple[str, float]] = {}
        self._lock = threading.Lock()
        self._ttl = ttl
        self._cleanup_interval = cleanup_interval
        self._last_cleanup = 0
        # Optional AuthManager for TokenStore integration (write-through only)
        self._auth_manager = auth_manager
        # Strong references to in-flight TokenStore writes. asyncio keeps only
        # weak references to tasks, so an unreferenced fire-and-forget task
        # can be garbage-collected before it finishes.
        self._pending_tasks: set = set()

    def get(self, key: str) -> Optional[str]:
        """Return the cached JWT for ``key`` if present and not expired.

        Triggers a bulk sweep of expired entries when more than
        ``cleanup_interval`` seconds have passed since the last sweep. An
        expired entry for ``key`` is removed as a side effect.

        Args:
            key: The cache key for the JWT token.

        Returns:
            The JWT token if found and not expired, None otherwise.

        Examples:
            >>> cache = JWTCache()
            >>> token = cache.get("user123")
            >>> if token:
            ...     print("Token found and valid")
            ... else:
            ...     print("Token not found or expired")
        """
        now = time.time()

        # Periodic sweep of expired entries
        if now - self._last_cleanup > self._cleanup_interval:
            self._cleanup(now)
            self._last_cleanup = now

        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None
            jwt_token, expiry_time = entry
            if now < expiry_time:
                return jwt_token
            del self._cache[key]

        # The key is deliberately not logged: callers may use a raw refresh
        # token as the key, and even a prefix of it is sensitive.
        logger.debug("Removed expired JWT cache entry")
        return None

    def set(self, key: str, jwt_token: str) -> None:
        """Cache ``jwt_token`` under ``key`` for ``ttl`` seconds.

        If an AuthManager with ``set_token`` was provided, a best-effort
        write-through is also scheduled (see ``_schedule_token_store_write``).

        Args:
            key: The cache key for the JWT token.
            jwt_token: The JWT token to cache.

        Examples:
            >>> cache = JWTCache(ttl=3000)
            >>> cache.set("user123", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...")
            >>> print("Token cached successfully")
        """
        now = time.time()
        expiry_time = now + self._ttl

        if self._auth_manager and hasattr(self._auth_manager, "set_token"):
            self._schedule_token_store_write(key, jwt_token)

        # Also store in local cache for fast access
        with self._lock:
            self._cache[key] = (jwt_token, expiry_time)

    def _schedule_token_store_write(self, key: str, jwt_token: str) -> None:
        """Schedule a fire-and-forget write of ``jwt_token`` to the AuthManager.

        The key is split on ``":"`` and read as ``provider:identity[:region]``;
        keys with fewer than two parts are not forwarded. Outside a running
        event loop the write is skipped with a debug log. Any failure while
        scheduling is logged at debug level and never propagates.

        NOTE(review): RefreshTokenMiddleware in this module uses the raw
        refresh token as the cache key, so provider/identity fields and
        ``metadata["cache_key"]`` may carry token material — confirm intended.
        """
        import asyncio

        from ..auth.token_store import TokenKind

        parts = key.split(":")
        if len(parts) < 2:
            return

        provider_type = parts[0]
        identity_id = parts[1]
        region = parts[2] if len(parts) > 2 else None
        expires_at = datetime.now(timezone.utc) + timedelta(seconds=self._ttl)

        try:
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # Not in async context - defer the update
                logger.debug("Token store update deferred - not in async context")
                return

            task = asyncio.create_task(
                self._auth_manager.set_token(
                    provider_type=provider_type,
                    identity_id=identity_id,
                    token_kind=TokenKind.PROVIDER_JWT,
                    token=jwt_token,
                    expires_at=expires_at,
                    metadata={"cache_key": key},
                    region=region,
                )
            )
            # Keep the task alive until it completes, then forget it.
            self._pending_tasks.add(task)
            task.add_done_callback(self._pending_tasks.discard)
        except Exception as e:
            logger.debug(f"Failed to store token in TokenStore: {e}")

    def _cleanup(self, now: float) -> None:
        """Remove every entry whose expiry time is at or before ``now``.

        Called automatically from ``get()``; safe to call manually.

        Args:
            now: The current timestamp (``time.time()`` scale) to compare against.

        Examples:
            >>> cache = JWTCache()
            >>> cache._cleanup(time.time())  # Manual cleanup if needed
        """
        with self._lock:
            expired_keys = [
                key
                for key, (_, expiry_time) in self._cache.items()
                if now >= expiry_time
            ]
            for key in expired_keys:
                del self._cache[key]

        if expired_keys:
            logger.info(f"Cleaned up {len(expired_keys)} expired JWT cache entries")
466
+
467
+
468
class RefreshTokenMiddleware(Middleware):
    """Middleware to convert refresh tokens to JWT tokens with caching.

    For every request carrying a non-empty ``Authorization: Bearer <token>``
    header, this middleware:

    1. Hands the token to the AuthManager's provider via
       ``provider.set_refresh_token`` when that method exists, regardless of
       whether authentication is enabled.
    2. When ``config.enabled`` and ``config.refresh_token_enabled`` are both
       set and ``config.refresh_token_pattern`` matches the token, obtains a
       JWT (from ``JWTCache`` or by calling the refresh endpoint) and stores
       it in ``jwt_token_var`` for the JWT authentication middleware.

    Unexpected errors are logged and the request continues down the chain;
    ``ToolError`` is re-raised.

    Examples:
        >>> config = AuthConfig()
        >>> config.refresh_token_enabled = True
        >>> config.refresh_token_endpoint = "https://api.example.com/refresh"
        >>> middleware = RefreshTokenMiddleware(config)
    """

    def __init__(self, config: AuthConfig, auth_manager: Optional[Any] = None):
        super().__init__()
        self.config = config
        self.auth_manager = auth_manager
        self.logger = logging.getLogger(f"{__name__}.RefreshTokenMiddleware")
        self._jwt_cache = JWTCache(
            config.jwt_cache_ttl, config.cache_cleanup_interval, auth_manager
        )

    @staticmethod
    def _extract_bearer_token(context: MiddlewareContext) -> Optional[str]:
        """Return the stripped Bearer token from the request headers, or None.

        Returns None when there is no request, no Authorization header, the
        scheme is not ``bearer`` (case-insensitive), or the token is empty.
        """
        auth_header = ""
        fastmcp_context = context.fastmcp_context
        if fastmcp_context and fastmcp_context.request_context:
            request = fastmcp_context.request_context.request
            if request and hasattr(request, "headers"):
                auth_header = request.headers.get("authorization", "")

        if not auth_header:
            return None

        parts = auth_header.split(" ", 1)
        if len(parts) != 2 or parts[0].lower() != "bearer":
            return None

        # Remove any leading/trailing whitespace; "Bearer " alone yields "".
        token = parts[1].strip()
        return token or None

    async def on_request(self, context: MiddlewareContext, call_next):
        """Convert refresh tokens to JWT tokens if needed.

        Args:
            context: The FastMCP middleware context.
            call_next: The next middleware in the chain.

        Returns:
            The result of the next middleware in the chain.

        Raises:
            ToolError: Propagated unchanged if raised while processing.
        """
        try:
            token = self._extract_bearer_token(context)

            # An empty or missing token must not overwrite the provider's
            # refresh token or trigger a conversion attempt.
            if token:
                # CRITICAL: Always set refresh token in provider if available (for OpenBridge)
                # This MUST happen even if config.enabled is False, so tools can use the token
                if self.auth_manager and hasattr(self.auth_manager, "provider"):
                    provider = self.auth_manager.provider
                    if hasattr(provider, "set_refresh_token"):
                        self.logger.debug(
                            "Setting refresh token in OpenBridge provider from Authorization header"
                        )
                        provider.set_refresh_token(token)

                # JWT conversion processing (only if enabled)
                if self.config.enabled and self.config.refresh_token_enabled:
                    pattern = self.config.refresh_token_pattern
                    if pattern and pattern(token):
                        self.logger.debug("Detected refresh token format, checking cache...")

                        jwt_token = await self._get_cached_or_convert_jwt(token)
                        if jwt_token:
                            self.logger.debug("JWT token ready (cached or converted)")
                            # Hand the JWT to the JWT middleware via context-safe storage
                            jwt_token_var.set(jwt_token)
                        else:
                            self.logger.error("Failed to convert refresh token to JWT")
                    else:
                        self.logger.debug(
                            "Token does not match refresh token pattern - skipping JWT conversion"
                        )

        except ToolError:
            # Let ToolError propagate - it's handled by FastMCP
            raise
        except Exception as e:
            # Best-effort: conversion problems must not block the request.
            self.logger.error(f"RefreshTokenMiddleware error: {e}")

        return await call_next(context)

    async def _get_cached_or_convert_jwt(self, refresh_token: str) -> Optional[str]:
        """Get JWT from cache or convert refresh token to JWT.

        The refresh token itself is used as the cache key. On a cache miss the
        token is converted and a successful result is cached.

        Args:
            refresh_token: The refresh token to convert or lookup.

        Returns:
            The JWT token if lookup/conversion succeeds, None otherwise.
        """
        cached_jwt = self._jwt_cache.get(refresh_token)
        if cached_jwt:
            self.logger.debug("Using cached JWT token")
            return cached_jwt

        self.logger.info("Converting refresh token to JWT (cache miss)...")
        jwt_token = await self._convert_refresh_to_jwt(refresh_token)

        if jwt_token:
            self._jwt_cache.set(refresh_token, jwt_token)
            self.logger.info("Cached new JWT token")

        return jwt_token

    async def _convert_refresh_to_jwt(self, refresh_token: str) -> Optional[str]:
        """Convert refresh token to JWT using the configured endpoint.

        POSTs ``config.refresh_token_request_builder(refresh_token)`` as JSON
        to ``config.refresh_token_endpoint`` (10 second timeout) using the
        shared HTTP client, then extracts the JWT with
        ``config.refresh_token_response_parser``.

        Args:
            refresh_token: The refresh token to convert.

        Returns:
            The JWT token on success. None when the endpoint or handlers are
            not configured, the HTTP call fails, or the parser yields nothing.
            This method logs failures and does not raise.
        """
        endpoint = self.config.refresh_token_endpoint
        request_builder = self.config.refresh_token_request_builder
        response_parser = self.config.refresh_token_response_parser
        if not endpoint or request_builder is None or response_parser is None:
            self.logger.error(
                "Refresh token conversion is not fully configured "
                "(endpoint, request builder and response parser are required)"
            )
            return None

        try:
            payload = request_builder(refresh_token)

            # Get shared HTTP client
            client = await get_http_client()

            response = await client.post(
                endpoint,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=10.0,
            )
            response.raise_for_status()

            jwt_token = response_parser(response.json())

            if jwt_token:
                self.logger.debug("Successfully converted refresh token to JWT")
                return jwt_token

            self.logger.error("Response parser returned no JWT token")
            return None

        except httpx.HTTPError as e:
            self.logger.error(f"Failed to convert refresh token: {e}")
            return None
        except Exception as e:
            self.logger.error(f"Error converting refresh token: {e}")
            return None
668
+
669
+
670
+ class JWTAuthenticationMiddleware(Middleware):
671
+ """Middleware to validate JWT authentication with comprehensive error handling.
672
+
673
+ This middleware validates JWT tokens in incoming requests and provides
674
+ comprehensive error handling and logging. It supports both signature
675
+ verification and claim-based validation modes.
676
+
677
+ The middleware provides:
678
+ - JWT token extraction from Authorization headers
679
+ - Signature verification using public keys or JWKS
680
+ - Claim validation (issuer, audience, expiration)
681
+ - OpenBridge-specific validation (user_id, account_id)
682
+ - Context-safe claim storage for application use
683
+ - Comprehensive error handling and detailed logging
684
+
685
+ Key Features:
686
+ - Support for both signature verification and claim-only validation
687
+ - JWKS integration for dynamic public key retrieval
688
+ - OpenBridge-specific claim validation
689
+ - Context-safe claim storage using contextvars
690
+ - Comprehensive error handling with detailed logging
691
+ - Token corruption detection and cleanup attempts
692
+
693
+ Examples:
694
+ >>> config = AuthConfig()
695
+ >>> config.jwt_validation_enabled = True
696
+ >>> config.jwt_verify_signature = True
697
+ >>> middleware = JWTAuthenticationMiddleware(config)
698
+ """
699
+
700
+ def __init__(self, config: AuthConfig):
701
+ super().__init__()
702
+ self.config = config
703
+ self.logger = logging.getLogger(f"{__name__}.JWTAuthenticationMiddleware")
704
+
705
+ async def on_request(self, context: MiddlewareContext, call_next):
706
+ """Validate JWT authentication for all requests.
707
+
708
+ This method intercepts incoming requests and validates JWT tokens
709
+ from either context-safe storage (set by RefreshTokenMiddleware)
710
+ or Authorization headers. It stores validated claims in context
711
+ for use by other parts of the application.
712
+
713
+ Args:
714
+ context: The FastMCP middleware context.
715
+ call_next: The next middleware in the chain.
716
+
717
+ Returns:
718
+ The result of the next middleware in the chain.
719
+
720
+ Raises:
721
+ ToolError: If authentication fails or is missing.
722
+
723
+ Examples:
724
+ >>> # This method is called automatically by FastMCP
725
+ >>> # when requests are processed through the middleware chain
726
+ """
727
+ if not self.config.enabled or not self.config.jwt_validation_enabled:
728
+ self.logger.debug("Authentication disabled, skipping validation")
729
+ return await call_next(context)
730
+
731
+ try:
732
+ # Check for JWT set by RefreshTokenMiddleware in context-safe storage
733
+ jwt_token = jwt_token_var.get()
734
+ if jwt_token:
735
+ token = jwt_token
736
+ self.logger.info("Using JWT from context-safe storage")
737
+ # Clear the token from context after use
738
+ jwt_token_var.set(None)
739
+ else:
740
+ # Get token from headers via context
741
+ auth_header = ""
742
+ if context.fastmcp_context and context.fastmcp_context.request_context:
743
+ request = context.fastmcp_context.request_context.request
744
+ if request and hasattr(request, "headers"):
745
+ auth_header = request.headers.get("authorization", "")
746
+
747
+ self.logger.info(
748
+ f"JWT middleware - Authorization header present: {bool(auth_header)}"
749
+ )
750
+
751
+ if not auth_header:
752
+ self.logger.warning(
753
+ "JWT middleware - Missing Authorization header, rejecting request"
754
+ )
755
+ raise ToolError(
756
+ "Authentication required: Missing Authorization header"
757
+ )
758
+
759
+ # Extract token more robustly - handle case variations and extra whitespace
760
+ self.logger.debug(
761
+ f"Authorization header: {sanitize_string(auth_header)}"
762
+ )
763
+ parts = auth_header.split(" ", 1)
764
+ if len(parts) != 2 or parts[0].lower() != "bearer":
765
+ self.logger.warning(
766
+ f"JWT middleware - Invalid Authorization header format: {sanitize_string(auth_header)}"
767
+ )
768
+ raise ToolError(
769
+ "Authentication required: Invalid Authorization header format"
770
+ )
771
+
772
+ token = parts[1].strip() # Remove any leading/trailing whitespace
773
+ self.logger.debug(f"Extracted token (length: {len(token)}, type: JWT)")
774
+
775
+ # Validate the JWT token and store claims in context
776
+ claims = await self._validate_jwt_token(token)
777
+ if claims:
778
+ self.logger.info("JWT token validation successful")
779
+ # Store validated claims in context for other parts of the application
780
+ jwt_claims_var.set(claims)
781
+ return await call_next(context)
782
+ else:
783
+ self.logger.error("JWT token validation failed")
784
+ raise ToolError("Invalid authentication token")
785
+
786
+ except ToolError:
787
+ raise
788
+ except Exception as e:
789
+ self.logger.error(f"Authentication error: {e}")
790
+ raise ToolError("Authentication failed")
791
+
792
+ async def _validate_jwt_token(self, token: str) -> Optional[dict]:
793
+ """Validate JWT token and return claims if valid.
794
+
795
+ This method validates a JWT token using either signature verification
796
+ or claim-only validation based on the configuration. It handles
797
+ different validation modes and provides comprehensive error handling.
798
+
799
+ Args:
800
+ token: The JWT token to validate.
801
+
802
+ Returns:
803
+ The decoded claims if validation is successful, None otherwise.
804
+
805
+ Examples:
806
+ >>> claims = await middleware._validate_jwt_token("jwt_token_here")
807
+ >>> if claims:
808
+ ... print(f"User ID: {claims.get('user_id')}")
809
+ """
810
+ try:
811
+ if self.config.jwt_verify_signature:
812
+ return await self._validate_jwt_with_signature(token)
813
+ else:
814
+ return await self._validate_jwt_without_signature(token)
815
+ except Exception as e:
816
+ self.logger.warning(f"JWT decode failed: {e}")
817
+ return None
818
+
819
+ async def _validate_jwt_with_signature(self, token: str) -> Optional[dict]:
820
+ """Validate JWT token with signature verification and return claims.
821
+
822
+ This method validates a JWT token using cryptographic signature
823
+ verification. It retrieves the public key from configured sources
824
+ and validates the token's signature, issuer, audience, and expiration.
825
+
826
+ Args:
827
+ token: The JWT token to validate with signature verification.
828
+
829
+ Returns:
830
+ The decoded claims if signature verification is successful, None otherwise.
831
+
832
+ Raises:
833
+ jwt.ExpiredSignatureError: If the token has expired.
834
+ jwt.InvalidIssuerError: If the issuer validation fails.
835
+ jwt.InvalidAudienceError: If the audience validation fails.
836
+ jwt.InvalidSignatureError: If the signature validation fails.
837
+
838
+ Examples:
839
+ >>> claims = await middleware._validate_jwt_with_signature("jwt_token_here")
840
+ >>> if claims:
841
+ ... print("Signature verification successful")
842
+ """
843
+ try:
844
+ # Get public key
845
+ public_key = await self._get_public_key(token)
846
+ if not public_key:
847
+ self.logger.warning("No public key available for JWT validation")
848
+ return None
849
+
850
+ # Prepare decode options
851
+ decode_options = {
852
+ "verify_signature": True,
853
+ "verify_exp": True,
854
+ "verify_iss": self.config.jwt_verify_iss,
855
+ "verify_aud": self.config.jwt_verify_aud,
856
+ }
857
+
858
+ # Only validate audience if explicitly configured (like FastMCP BearerAuthProvider)
859
+ decode_kwargs = {
860
+ "token": token,
861
+ "key": public_key,
862
+ "algorithms": ["RS256", "ES256", "HS256"],
863
+ "options": decode_options,
864
+ }
865
+
866
+ # Add audience only if configured and verification is enabled
867
+ if self.config.jwt_audience and self.config.jwt_verify_aud:
868
+ decode_kwargs["audience"] = self.config.jwt_audience
869
+ elif self.config.jwt_verify_aud:
870
+ # If audience verification is enabled but no audience is configured, disable it
871
+ decode_options["verify_aud"] = False
872
+ self.logger.debug(
873
+ "Audience verification disabled - no audience configured"
874
+ )
875
+
876
+ # Decode with signature verification
877
+ payload = jwt.decode(**decode_kwargs)
878
+
879
+ # Validate required claims
880
+ for claim in self.config.jwt_required_claims:
881
+ if claim not in payload:
882
+ self.logger.warning(f"Missing required claim: {claim}")
883
+ return None
884
+
885
+ return payload
886
+
887
+ except jwt.ExpiredSignatureError:
888
+ self.logger.warning("JWT token is expired")
889
+ return None
890
+ except jwt.InvalidIssuerError:
891
+ self.logger.warning("JWT issuer validation failed")
892
+ return None
893
+ except jwt.InvalidAudienceError:
894
+ self.logger.warning("JWT audience validation failed")
895
+ return None
896
+ except jwt.InvalidSignatureError:
897
+ self.logger.warning("JWT signature validation failed")
898
+ return None
899
+ except Exception as e:
900
+ self.logger.warning(f"JWT validation error: {e}")
901
+ return None
902
+
903
+ async def _validate_jwt_without_signature(self, token: str) -> Optional[dict]:
904
+ """Validate JWT token without signature verification (OpenBridge style) and return claims.
905
+
906
+ This method validates a JWT token without cryptographic signature
907
+ verification, focusing on claim validation and OpenBridge-specific
908
+ requirements. It handles token corruption detection and cleanup attempts.
909
+
910
+ Args:
911
+ token: The JWT token to validate without signature verification.
912
+
913
+ Returns:
914
+ The decoded claims if validation is successful, None otherwise.
915
+
916
+ Examples:
917
+ >>> claims = await middleware._validate_jwt_without_signature("jwt_token_here")
918
+ >>> if claims:
919
+ ... print(f"OpenBridge validation successful - User: {claims.get('user_id')}")
920
+ """
921
+ try:
922
+ # Clean the token - remove any whitespace or encoding issues
923
+ token = token.strip()
924
+
925
+ # Debug token format - add more detailed logging
926
+ self.logger.debug(f"Validating JWT token (length: {len(token)})")
927
+
928
+ # Check for common corruption patterns
929
+ if token.startswith("Bearer "):
930
+ self.logger.error(
931
+ "Token still has 'Bearer ' prefix - extraction failed!"
932
+ )
933
+ token = token[7:].strip()
934
+
935
+ if " " in token:
936
+ self.logger.error("Token contains whitespace - may be corrupted")
937
+ self.logger.error(
938
+ f"Whitespace positions: {[i for i, c in enumerate(token) if c == ' ']}"
939
+ )
940
+ token = token.replace(" ", "")
941
+
942
+ if "\n" in token or "\r" in token:
943
+ self.logger.error("Token contains newlines - may be corrupted")
944
+ token = token.replace("\n", "").replace("\r", "")
945
+
946
+ # Try to decode the token parts first to understand the structure
947
+ try:
948
+ parts = token.split(".")
949
+ if len(parts) != 3:
950
+ self.logger.error(
951
+ f"Token doesn't have 3 parts! Found {len(parts)} parts"
952
+ )
953
+ return None
954
+
955
+ self.logger.debug(
956
+ f"Token has 3 parts: header({len(parts[0])}), payload({len(parts[1])}), signature({len(parts[2])})"
957
+ )
958
+ except Exception as e:
959
+ self.logger.error(f"Failed to split token: {e}")
960
+ return None
961
+
962
+ # Decode without signature verification
963
+ payload = jwt.decode(token, options={"verify_signature": False})
964
+
965
+ # Check if token is expired using expires_at (OpenBridge format)
966
+ expires_at = payload.get("expires_at")
967
+ if expires_at:
968
+ try:
969
+ exp_time = datetime.fromtimestamp(
970
+ float(expires_at), tz=timezone.utc
971
+ )
972
+ if datetime.now(timezone.utc) > exp_time:
973
+ self.logger.warning("JWT token is expired")
974
+ return None
975
+ except (ValueError, TypeError):
976
+ self.logger.warning(f"Invalid expires_at format: {expires_at}")
977
+ return None
978
+
979
+ # Validate OpenBridge-specific claims (user_id, account_id)
980
+ user_id = payload.get("user_id")
981
+ account_id = payload.get("account_id")
982
+
983
+ if not user_id or not account_id:
984
+ self.logger.warning(
985
+ "JWT missing required OpenBridge fields (user_id, account_id)"
986
+ )
987
+ return None
988
+
989
+ self.logger.info(
990
+ f"OpenBridge token validated - User: {user_id}, Account: {account_id}"
991
+ )
992
+ return payload
993
+
994
+ except jwt.DecodeError as e:
995
+ self.logger.warning(f"JWT decode error: {e}")
996
+ # Try to provide more specific error information
997
+ if "Invalid header padding" in str(e):
998
+ self.logger.error(
999
+ "JWT header padding error - token may be corrupted or improperly formatted"
1000
+ )
1001
+ self.logger.error(f"Original token: {sanitize_string(token)}")
1002
+
1003
+ # Try to clean the token and retry
1004
+ try:
1005
+ # Remove any potential encoding issues
1006
+ cleaned_token = (
1007
+ token.strip()
1008
+ .replace(" ", "")
1009
+ .replace("\n", "")
1010
+ .replace("\r", "")
1011
+ )
1012
+ if cleaned_token != token:
1013
+ self.logger.info("Attempting to decode cleaned token")
1014
+ self.logger.info(
1015
+ f"Cleaned token: {sanitize_string(cleaned_token)}"
1016
+ )
1017
+ payload = jwt.decode(
1018
+ cleaned_token, options={"verify_signature": False}
1019
+ )
1020
+ self.logger.info("Cleaned token decoded successfully")
1021
+ return payload
1022
+ except Exception as cleanup_error:
1023
+ self.logger.error(
1024
+ f"Failed to clean and decode token: {cleanup_error}"
1025
+ )
1026
+
1027
+ # Try to analyze the token structure
1028
+ try:
1029
+ parts = token.split(".")
1030
+ if len(parts) >= 1:
1031
+ self.logger.error(f"Header part: {sanitize_string(parts[0])}")
1032
+ if len(parts) >= 2:
1033
+ self.logger.error(
1034
+ f"Payload part: {sanitize_string(parts[1])}"
1035
+ )
1036
+ except Exception as analyze_error:
1037
+ self.logger.error(
1038
+ f"Failed to analyze token structure: {analyze_error}"
1039
+ )
1040
+
1041
+ return None
1042
+ except Exception as e:
1043
+ self.logger.warning(f"JWT validation error: {e}")
1044
+ return None
1045
+
1046
+ async def _get_public_key(self, token: str) -> Optional[str]:
1047
+ """Get public key for JWT validation.
1048
+
1049
+ This method retrieves the public key needed for JWT signature
1050
+ verification. It supports both static public keys and dynamic
1051
+ key retrieval from JWKS endpoints.
1052
+
1053
+ Args:
1054
+ token: The JWT token to extract key information from.
1055
+
1056
+ Returns:
1057
+ The public key in PEM format if available, None otherwise.
1058
+
1059
+ Examples:
1060
+ >>> public_key = await middleware._get_public_key("jwt_token_here")
1061
+ >>> if public_key:
1062
+ ... print("Public key retrieved successfully")
1063
+ """
1064
+ if self.config.jwt_public_key:
1065
+ return self.config.jwt_public_key
1066
+
1067
+ if self.config.jwt_jwks_uri:
1068
+ try:
1069
+ # Decode header to get key ID
1070
+ header = jwt.get_unverified_header(token)
1071
+ kid = header.get("kid")
1072
+
1073
+ if not kid:
1074
+ self.logger.warning("No key ID in JWT header")
1075
+ return None
1076
+
1077
+ # Get shared HTTP client
1078
+ client = await get_http_client()
1079
+
1080
+ # Fetch JWKS
1081
+ response = await client.get(self.config.jwt_jwks_uri, timeout=10.0)
1082
+ response.raise_for_status()
1083
+ jwks = response.json()
1084
+
1085
+ # Find the key
1086
+ for key in jwks.get("keys", []):
1087
+ if key.get("kid") == kid:
1088
+ return self._jwk_to_pem(key)
1089
+
1090
+ self.logger.warning(f"Key ID {kid} not found in JWKS")
1091
+ return None
1092
+
1093
+ except httpx.HTTPError as e:
1094
+ self.logger.error(f"Failed to fetch JWKS: {e}")
1095
+ return None
1096
+ except Exception as e:
1097
+ self.logger.error(f"Error processing JWKS: {e}")
1098
+ return None
1099
+
1100
+ return None
1101
+
1102
+ def _jwk_to_pem(self, jwk_key: dict) -> Optional[str]:
1103
+ """Convert JWK to PEM format.
1104
+
1105
+ This method converts a JSON Web Key (JWK) to PEM format for
1106
+ use in JWT signature verification. It currently supports RSA keys.
1107
+
1108
+ Args:
1109
+ jwk_key: The JWK key dictionary to convert.
1110
+
1111
+ Returns:
1112
+ The public key in PEM format if conversion is successful, None otherwise.
1113
+
1114
+ Examples:
1115
+ >>> jwk = {"kty": "RSA", "n": "...", "e": "..."}
1116
+ >>> pem_key = middleware._jwk_to_pem(jwk)
1117
+ >>> if pem_key:
1118
+ ... print("JWK converted to PEM successfully")
1119
+ """
1120
+ try:
1121
+ # This is a simplified conversion - in production you'd want a proper JWK library
1122
+ if jwk_key.get("kty") == "RSA":
1123
+ from cryptography.hazmat.primitives import serialization
1124
+ from cryptography.hazmat.primitives.asymmetric import rsa
1125
+
1126
+ n = int.from_bytes(jwt.utils.base64url_decode(jwk_key["n"]), "big")
1127
+ e = int.from_bytes(jwt.utils.base64url_decode(jwk_key["e"]), "big")
1128
+
1129
+ public_key = rsa.RSAPublicNumbers(e, n).public_key()
1130
+ pem = public_key.public_bytes(
1131
+ encoding=serialization.Encoding.PEM,
1132
+ format=serialization.PublicFormat.SubjectPublicKeyInfo,
1133
+ )
1134
+ return pem.decode("utf-8")
1135
+
1136
+ except Exception as e:
1137
+ self.logger.error(f"Error converting JWK to PEM: {e}")
1138
+
1139
+ return None
1140
+
1141
+
1142
+ # Utility functions for accessing JWT data from context
1143
+ def get_current_jwt() -> Optional[str]:
1144
+ """Get JWT token for current request context.
1145
+
1146
+ This function retrieves the JWT token from the current request context
1147
+ using context-safe storage. It provides access to the JWT token
1148
+ for other parts of the application that need it.
1149
+
1150
+ :return: The JWT token for the current request context, or None if not available
1151
+ :rtype: Optional[str]
1152
+
1153
+ .. example::
1154
+ >>> jwt_token = get_current_jwt()
1155
+ >>> if jwt_token:
1156
+ ... print("JWT token available in current context")
1157
+ """
1158
+ return jwt_token_var.get()
1159
+
1160
+
1161
+ def get_current_claims() -> Optional[dict]:
1162
+ """Get JWT claims for current request context.
1163
+
1164
+ This function retrieves the validated JWT claims from the current
1165
+ request context using context-safe storage. It provides access to
1166
+ user information and other claims for application logic.
1167
+
1168
+ :return: The JWT claims dictionary for the current request context, or None if not available
1169
+ :rtype: Optional[dict]
1170
+
1171
+ .. example::
1172
+ >>> claims = get_current_claims()
1173
+ >>> if claims:
1174
+ ... print(f"User ID: {claims.get('user_id')}")
1175
+ ... print(f"Account ID: {claims.get('account_id')}")
1176
+ """
1177
+ return jwt_claims_var.get()
1178
+
1179
+
1180
+ def create_auth_middleware(
1181
+ config: Optional[AuthConfig] = None,
1182
+ refresh_token_middleware: bool = True,
1183
+ jwt_middleware: bool = True,
1184
+ auth_manager: Optional[Any] = None,
1185
+ ) -> list:
1186
+ """Create authentication middleware chain.
1187
+
1188
+ This function creates a complete authentication middleware chain
1189
+ with optional refresh token conversion and JWT validation. It
1190
+ supports auto-configuration for common providers like OpenBridge.
1191
+
1192
+ The function automatically detects provider types and configures
1193
+ appropriate handlers for refresh token conversion and JWT validation.
1194
+ It supports OpenBridge, generic JSON:API, and custom configurations.
1195
+
1196
+ :param config: Optional AuthConfig instance. If None, creates a new config and loads settings from environment variables
1197
+ :type config: Optional[AuthConfig]
1198
+ :param refresh_token_middleware: Whether to include refresh token conversion middleware
1199
+ :type refresh_token_middleware: bool
1200
+ :param jwt_middleware: Whether to include JWT validation middleware
1201
+ :type jwt_middleware: bool
1202
+ :param auth_manager: Optional AuthManager instance for token store integration
1203
+ :type auth_manager: Optional[Any]
1204
+ :return: A list of middleware instances ready for use with FastMCP
1205
+ :rtype: list
1206
+
1207
+ .. example::
1208
+ >>> # Auto-configure from environment
1209
+ >>> middleware = create_auth_middleware()
1210
+
1211
+ >>> # Use specific configuration
1212
+ >>> config = AuthConfig()
1213
+ >>> config.load_from_env()
1214
+ >>> middleware = create_auth_middleware(config)
1215
+
1216
+ >>> # JWT validation only
1217
+ >>> middleware = create_auth_middleware(
1218
+ ... refresh_token_middleware=False, jwt_middleware=True
1219
+ ... )
1220
+ """
1221
+ if not config:
1222
+ config = AuthConfig()
1223
+ config.load_from_env()
1224
+
1225
+ # Auto-configure OpenBridge if the endpoint is detected
1226
+ if config.refresh_token_enabled and config.refresh_token_endpoint:
1227
+ # Validate hostname is legitimate OpenBridge domain
1228
+ from urllib.parse import urlparse
1229
+ parsed_endpoint = urlparse(config.refresh_token_endpoint)
1230
+ endpoint_host = (parsed_endpoint.hostname or "").lower()
1231
+ is_openbridge = endpoint_host.endswith(".openbridge.io") or endpoint_host == "openbridge.io"
1232
+ if is_openbridge:
1233
+ logger.info("Auto-configuring OpenBridge authentication")
1234
+ # Use OpenBridge-specific configuration
1235
+ config = create_openbridge_config()
1236
+ # Override with environment variables but preserve auto-configured handlers
1237
+ config.load_from_env()
1238
+ elif not config.refresh_token_request_builder:
1239
+ # Auto-configure generic JSON:API if no handlers are set
1240
+ logger.info("Auto-configuring generic JSON:API authentication")
1241
+ config = create_json_api_refresh_token_config(
1242
+ endpoint_url=config.refresh_token_endpoint,
1243
+ token_type_name=os.getenv("JSON_API_TOKEN_TYPE_NAME", "APIAuth"),
1244
+ required_claims=config.jwt_required_claims or ["user_id"],
1245
+ verify_signature=config.jwt_verify_signature,
1246
+ )
1247
+ # Override with environment variables but preserve auto-configured handlers
1248
+ config.load_from_env()
1249
+
1250
+ if not config.validate():
1251
+ logger.error("Invalid authentication configuration")
1252
+ return []
1253
+
1254
+ middleware = []
1255
+
1256
+ # Add refresh token middleware first (converts refresh tokens to JWTs)
1257
+ if refresh_token_middleware and config.refresh_token_enabled:
1258
+ middleware.append(RefreshTokenMiddleware(config, auth_manager))
1259
+ logger.info("Added RefreshTokenMiddleware")
1260
+
1261
+ # Add JWT authentication middleware (validates JWTs)
1262
+ if jwt_middleware and config.jwt_validation_enabled:
1263
+ middleware.append(JWTAuthenticationMiddleware(config))
1264
+ logger.info("Added JWTAuthenticationMiddleware")
1265
+
1266
+ logger.info(f"Created {len(middleware)} middleware components")
1267
+ return middleware
1268
+
1269
+
1270
+ def create_json_api_refresh_token_config(
1271
+ endpoint_url: str,
1272
+ token_type_name: str,
1273
+ required_claims: list[str],
1274
+ verify_signature: bool = True,
1275
+ ) -> AuthConfig:
1276
+ """Create configuration for JSON:API style refresh token endpoints.
1277
+
1278
+ This function creates a pre-configured AuthConfig for JSON:API
1279
+ style refresh token endpoints with standard request/response
1280
+ patterns and handlers.
1281
+
1282
+ The configuration includes standard JSON:API request builders,
1283
+ response parsers, and token pattern detection for automatic
1284
+ refresh token handling.
1285
+
1286
+ :param endpoint_url: The refresh token endpoint URL
1287
+ :type endpoint_url: str
1288
+ :param token_type_name: The JSON:API resource type name for tokens
1289
+ :type token_type_name: str
1290
+ :param required_claims: List of required JWT claims for validation
1291
+ :type required_claims: list[str]
1292
+ :param verify_signature: Whether to verify JWT signatures
1293
+ :type verify_signature: bool
1294
+ :return: A configured AuthConfig instance for JSON:API refresh tokens
1295
+ :rtype: AuthConfig
1296
+
1297
+ .. example::
1298
+ >>> config = create_json_api_refresh_token_config(
1299
+ ... endpoint_url="https://api.example.com/auth/refresh",
1300
+ ... token_type_name="APIAuth",
1301
+ ... required_claims=["user_id", "account_id"]
1302
+ ... )
1303
+ >>> middleware = create_auth_middleware(config)
1304
+ """
1305
+ config = AuthConfig()
1306
+ config.enabled = True
1307
+ config.refresh_token_enabled = True
1308
+ config.jwt_validation_enabled = True
1309
+ config.refresh_token_endpoint = endpoint_url
1310
+ config.jwt_required_claims = required_claims
1311
+ config.jwt_verify_signature = verify_signature
1312
+
1313
+ # JSON:API request builder
1314
+ def request_builder(refresh_token: str) -> dict:
1315
+ return {
1316
+ "data": {
1317
+ "type": token_type_name,
1318
+ "attributes": {"refresh_token": refresh_token},
1319
+ }
1320
+ }
1321
+
1322
+ # JSON:API response parser
1323
+ def response_parser(response_data: dict) -> Optional[str]:
1324
+ try:
1325
+ return response_data.get("data", {}).get("attributes", {}).get("token")
1326
+ except Exception:
1327
+ return None
1328
+
1329
+ # Pattern detector for refresh tokens (simple heuristic)
1330
+ def pattern_detector(token: str) -> bool:
1331
+ # OpenBridge refresh tokens typically contain a colon
1332
+ return ":" in token and len(token) > 20
1333
+
1334
+ config.set_refresh_token_handlers(
1335
+ request_builder, response_parser, pattern_detector
1336
+ )
1337
+ return config
1338
+
1339
+
1340
+ def create_openbridge_config() -> AuthConfig:
1341
+ """Create configuration for OpenBridge authentication.
1342
+
1343
+ This function creates a pre-configured AuthConfig specifically
1344
+ for OpenBridge authentication with the correct endpoint, token
1345
+ type, and validation settings.
1346
+
1347
+ The configuration automatically detects the OpenBridge authentication
1348
+ base URL from environment variables and sets up appropriate handlers
1349
+ for refresh token conversion and JWT validation.
1350
+
1351
+ :return: A configured AuthConfig instance for OpenBridge authentication
1352
+ :rtype: AuthConfig
1353
+
1354
+ .. example::
1355
+ >>> config = create_openbridge_config()
1356
+ >>> middleware = create_auth_middleware(config)
1357
+
1358
+ .. note::
1359
+ The configuration includes:
1360
+
1361
+ - OpenBridge refresh token endpoint
1362
+ - APIAuth token type
1363
+ - user_id and account_id required claims
1364
+ - Signature verification disabled (tokens trusted from API)
1365
+ """
1366
+ # Build endpoint from env, with explicit REFRESH_TOKEN_ENDPOINT taking precedence
1367
+ auth_base = os.getenv(
1368
+ "OPENBRIDGE_AUTH_BASE_URL", "https://authentication.api.openbridge.io"
1369
+ ).rstrip("/")
1370
+ endpoint_url = os.getenv("REFRESH_TOKEN_ENDPOINT", f"{auth_base}/auth/api/refresh")
1371
+
1372
+ config = create_json_api_refresh_token_config(
1373
+ endpoint_url=endpoint_url,
1374
+ token_type_name="APIAuth",
1375
+ required_claims=["user_id", "account_id"],
1376
+ verify_signature=False, # OpenBridge JWTs are trusted from the API, no public key available
1377
+ )
1378
+
1379
+ # OpenBridge-specific settings
1380
+ config.jwt_verify_iss = False # Don't validate issuer for OpenBridge
1381
+ config.jwt_verify_aud = False # Don't validate audience for OpenBridge
1382
+
1383
+ # Respect AUTH_ENABLED environment variable
1384
+ config.load_from_env()
1385
+
1386
+ return config
1387
+
1388
+
1389
+ def create_auth0_config(domain: str, audience: str) -> AuthConfig:
1390
+ """Create Auth0 configuration.
1391
+
1392
+ This function creates a pre-configured AuthConfig for Auth0
1393
+ authentication with the correct issuer, audience, and JWKS
1394
+ endpoint settings.
1395
+
1396
+ The configuration sets up standard Auth0 JWT validation with
1397
+ public key retrieval from the JWKS endpoint and proper issuer
1398
+ and audience validation.
1399
+
1400
+ :param domain: The Auth0 domain (e.g., "example.auth0.com")
1401
+ :type domain: str
1402
+ :param audience: The Auth0 API audience identifier
1403
+ :type audience: str
1404
+ :return: A configured AuthConfig instance for Auth0 authentication
1405
+ :rtype: AuthConfig
1406
+
1407
+ .. example::
1408
+ >>> config = create_auth0_config(
1409
+ ... domain="example.auth0.com",
1410
+ ... audience="https://api.example.com"
1411
+ ... )
1412
+ >>> middleware = create_auth_middleware(config)
1413
+
1414
+ .. note::
1415
+ The configuration includes:
1416
+
1417
+ - Auth0 issuer URL
1418
+ - JWKS endpoint for public key retrieval
1419
+ - Audience validation
1420
+ """
1421
+ config = AuthConfig()
1422
+ config.load_from_env()
1423
+ config.jwt_issuer = f"https://{domain}/"
1424
+ config.jwt_audience = audience
1425
+ config.jwt_jwks_uri = f"https://{domain}/.well-known/jwks.json"
1426
+ return config
1427
+
1428
+
1429
+ async def get_auth_info() -> Dict[str, Any]:
1430
+ """Get authentication information.
1431
+
1432
+ This function retrieves comprehensive information about the current
1433
+ authentication configuration, including enabled features and settings.
1434
+ Useful for debugging and monitoring authentication status.
1435
+
1436
+ The returned information includes all key configuration settings
1437
+ loaded from environment variables and provides insight into the
1438
+ current authentication state.
1439
+
1440
+ :return: A dictionary containing authentication configuration information
1441
+ :rtype: Dict[str, Any]
1442
+
1443
+ Dictionary Keys:
1444
+
1445
+ - enabled: Whether authentication is enabled
1446
+ - jwt_validation_enabled: Whether JWT validation is enabled
1447
+ - refresh_token_enabled: Whether refresh token conversion is enabled
1448
+ - jwt_issuer: Configured JWT issuer
1449
+ - jwt_audience: Configured JWT audience
1450
+ - jwt_verify_signature: Whether signature verification is enabled
1451
+ - jwt_verify_iss: Whether issuer verification is enabled
1452
+ - jwt_verify_aud: Whether audience verification is enabled
1453
+ - jwt_required_claims: List of required JWT claims
1454
+
1455
+ .. example::
1456
+ >>> auth_info = await get_auth_info()
1457
+ >>> print(f"Authentication enabled: {auth_info['enabled']}")
1458
+ >>> print(f"JWT validation enabled: {auth_info['jwt_validation_enabled']}")
1459
+ >>> print(f"Required claims: {auth_info['jwt_required_claims']}")
1460
+ """
1461
+ config = AuthConfig()
1462
+ config.load_from_env()
1463
+
1464
+ return {
1465
+ "enabled": config.enabled,
1466
+ "jwt_validation_enabled": config.jwt_validation_enabled,
1467
+ "refresh_token_enabled": config.refresh_token_enabled,
1468
+ "jwt_issuer": config.jwt_issuer,
1469
+ "jwt_audience": config.jwt_audience,
1470
+ "jwt_verify_signature": config.jwt_verify_signature,
1471
+ "jwt_verify_iss": config.jwt_verify_iss,
1472
+ "jwt_verify_aud": config.jwt_verify_aud,
1473
+ "jwt_required_claims": config.jwt_required_claims,
1474
+ }