amazon-ads-mcp 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amazon_ads_mcp/__init__.py +11 -0
- amazon_ads_mcp/auth/__init__.py +33 -0
- amazon_ads_mcp/auth/base.py +211 -0
- amazon_ads_mcp/auth/hooks.py +172 -0
- amazon_ads_mcp/auth/manager.py +791 -0
- amazon_ads_mcp/auth/oauth_state_store.py +277 -0
- amazon_ads_mcp/auth/providers/__init__.py +14 -0
- amazon_ads_mcp/auth/providers/direct.py +393 -0
- amazon_ads_mcp/auth/providers/example_auth0.py.example +216 -0
- amazon_ads_mcp/auth/providers/openbridge.py +512 -0
- amazon_ads_mcp/auth/registry.py +146 -0
- amazon_ads_mcp/auth/secure_token_store.py +297 -0
- amazon_ads_mcp/auth/token_store.py +723 -0
- amazon_ads_mcp/config/__init__.py +5 -0
- amazon_ads_mcp/config/sampling.py +111 -0
- amazon_ads_mcp/config/settings.py +366 -0
- amazon_ads_mcp/exceptions.py +314 -0
- amazon_ads_mcp/middleware/__init__.py +11 -0
- amazon_ads_mcp/middleware/authentication.py +1474 -0
- amazon_ads_mcp/middleware/caching.py +177 -0
- amazon_ads_mcp/middleware/oauth.py +175 -0
- amazon_ads_mcp/middleware/sampling.py +112 -0
- amazon_ads_mcp/models/__init__.py +320 -0
- amazon_ads_mcp/models/amc_models.py +837 -0
- amazon_ads_mcp/models/api_responses.py +847 -0
- amazon_ads_mcp/models/base_models.py +215 -0
- amazon_ads_mcp/models/builtin_responses.py +496 -0
- amazon_ads_mcp/models/dsp_models.py +556 -0
- amazon_ads_mcp/models/stores_brands.py +610 -0
- amazon_ads_mcp/server/__init__.py +6 -0
- amazon_ads_mcp/server/__main__.py +6 -0
- amazon_ads_mcp/server/builtin_prompts.py +269 -0
- amazon_ads_mcp/server/builtin_tools.py +962 -0
- amazon_ads_mcp/server/file_routes.py +547 -0
- amazon_ads_mcp/server/html_templates.py +149 -0
- amazon_ads_mcp/server/mcp_server.py +327 -0
- amazon_ads_mcp/server/openapi_utils.py +158 -0
- amazon_ads_mcp/server/sampling_handler.py +251 -0
- amazon_ads_mcp/server/server_builder.py +751 -0
- amazon_ads_mcp/server/sidecar_loader.py +178 -0
- amazon_ads_mcp/server/transform_executor.py +827 -0
- amazon_ads_mcp/tools/__init__.py +22 -0
- amazon_ads_mcp/tools/cache_management.py +105 -0
- amazon_ads_mcp/tools/download_tools.py +267 -0
- amazon_ads_mcp/tools/identity.py +236 -0
- amazon_ads_mcp/tools/oauth.py +598 -0
- amazon_ads_mcp/tools/profile.py +150 -0
- amazon_ads_mcp/tools/profile_listing.py +285 -0
- amazon_ads_mcp/tools/region.py +320 -0
- amazon_ads_mcp/tools/region_identity.py +175 -0
- amazon_ads_mcp/utils/__init__.py +6 -0
- amazon_ads_mcp/utils/async_compat.py +215 -0
- amazon_ads_mcp/utils/errors.py +452 -0
- amazon_ads_mcp/utils/export_content_type_resolver.py +249 -0
- amazon_ads_mcp/utils/export_download_handler.py +579 -0
- amazon_ads_mcp/utils/header_resolver.py +81 -0
- amazon_ads_mcp/utils/http/__init__.py +56 -0
- amazon_ads_mcp/utils/http/circuit_breaker.py +127 -0
- amazon_ads_mcp/utils/http/client_manager.py +329 -0
- amazon_ads_mcp/utils/http/request.py +207 -0
- amazon_ads_mcp/utils/http/resilience.py +512 -0
- amazon_ads_mcp/utils/http/resilient_client.py +195 -0
- amazon_ads_mcp/utils/http/retry.py +76 -0
- amazon_ads_mcp/utils/http_client.py +873 -0
- amazon_ads_mcp/utils/media/__init__.py +21 -0
- amazon_ads_mcp/utils/media/negotiator.py +243 -0
- amazon_ads_mcp/utils/media/types.py +199 -0
- amazon_ads_mcp/utils/openapi/__init__.py +16 -0
- amazon_ads_mcp/utils/openapi/json.py +55 -0
- amazon_ads_mcp/utils/openapi/loader.py +263 -0
- amazon_ads_mcp/utils/openapi/refs.py +46 -0
- amazon_ads_mcp/utils/region_config.py +200 -0
- amazon_ads_mcp/utils/response_wrapper.py +171 -0
- amazon_ads_mcp/utils/sampling_helpers.py +156 -0
- amazon_ads_mcp/utils/sampling_wrapper.py +173 -0
- amazon_ads_mcp/utils/security.py +630 -0
- amazon_ads_mcp/utils/tool_naming.py +137 -0
- amazon_ads_mcp-0.2.7.dist-info/METADATA +664 -0
- amazon_ads_mcp-0.2.7.dist-info/RECORD +82 -0
- amazon_ads_mcp-0.2.7.dist-info/WHEEL +4 -0
- amazon_ads_mcp-0.2.7.dist-info/entry_points.txt +3 -0
- amazon_ads_mcp-0.2.7.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1474 @@
|
|
|
1
|
+
"""Reusable FastMCP Authentication Middleware - Version 5.0 (Production Ready).
|
|
2
|
+
|
|
3
|
+
This module provides a comprehensive authentication middleware system for
|
|
4
|
+
FastMCP servers with support for JWT validation, refresh token conversion,
|
|
5
|
+
and context-safe token sharing between middleware components.
|
|
6
|
+
|
|
7
|
+
The module provides:
|
|
8
|
+
- AuthConfig: Configuration management for authentication settings
|
|
9
|
+
- JWTCache: Thread-safe JWT caching with automatic cleanup
|
|
10
|
+
- RefreshTokenMiddleware: Converts refresh tokens to JWT tokens
|
|
11
|
+
- JWTAuthenticationMiddleware: Validates JWT tokens with comprehensive error handling
|
|
12
|
+
- Utility functions for accessing JWT data and creating middleware chains
|
|
13
|
+
- Pre-configured configurations for common providers (OpenBridge, Auth0, JSON:API)
|
|
14
|
+
|
|
15
|
+
Key Features:
|
|
16
|
+
- FastMCP-compliant middleware patterns with proper hooks
|
|
17
|
+
- Context-safe JWT storage using contextvars for async safety
|
|
18
|
+
- JWT caching to reduce API calls and improve performance
|
|
19
|
+
- Comprehensive error handling and detailed logging
|
|
20
|
+
- OpenBridge-specific validation (user_id, account_id claims)
|
|
21
|
+
- Environment variable configuration for operator control
|
|
22
|
+
- Client disconnection handling and timeout management
|
|
23
|
+
- Support for multiple authentication providers
|
|
24
|
+
|
|
25
|
+
Examples:
|
|
26
|
+
>>> from .middleware.authentication import create_auth_middleware
|
|
27
|
+
>>> middleware = create_auth_middleware() # Auto-configure from environment
|
|
28
|
+
|
|
29
|
+
>>> # Use with specific configuration
|
|
30
|
+
>>> config = AuthConfig()
|
|
31
|
+
>>> config.load_from_env()
|
|
32
|
+
>>> middleware = create_auth_middleware(config)
|
|
33
|
+
|
|
34
|
+
>>> # Access JWT data in other parts of the application
|
|
35
|
+
>>> from .middleware.authentication import get_current_claims
|
|
36
|
+
>>> claims = get_current_claims()
|
|
37
|
+
>>> print(f"User ID: {claims.get('user_id')}")
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
import logging
|
|
41
|
+
import os
|
|
42
|
+
import threading
|
|
43
|
+
import time
|
|
44
|
+
from contextvars import ContextVar
|
|
45
|
+
from datetime import datetime, timedelta, timezone
|
|
46
|
+
from typing import Any, Callable, Dict, Optional, Tuple
|
|
47
|
+
|
|
48
|
+
import httpx
|
|
49
|
+
import jwt
|
|
50
|
+
from fastmcp.exceptions import ToolError
|
|
51
|
+
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
|
52
|
+
|
|
53
|
+
from ..utils.http import get_http_client
|
|
54
|
+
from ..utils.security import sanitize_string
|
|
55
|
+
|
|
56
|
+
logger = logging.getLogger(__name__)

# Storage shared between the refresh-token middleware (writer) and the JWT
# middleware / helper accessors (readers). A ContextVar is used instead of
# threading.local() because asyncio tasks each run in their own copy of the
# context, whereas thread-locals would be shared by every task on a thread.
jwt_token_var: ContextVar[Optional[str]] = ContextVar("jwt_token", default=None)
jwt_claims_var: ContextVar[Optional[dict]] = ContextVar("jwt_claims", default=None)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class AuthConfig:
    """Configuration for authentication middleware.

    Holds every setting used by the authentication middleware: JWT
    validation, refresh-token conversion and JWT caching. Values start with
    the defaults set in ``__init__``. They can be populated from environment
    variables via :meth:`load_from_env` and sanity-checked with
    :meth:`validate`.

    The class handles:
    - JWT validation settings (issuer, audience, signature verification)
    - Refresh token conversion settings and handlers
    - Caching configuration and TTL settings
    - Environment variable loading and validation

    Examples:
        >>> config = AuthConfig()
        >>> config.load_from_env()
        >>> if config.validate():
        ...     print("Configuration is valid")

        >>> # Configure refresh token handlers
        >>> config.set_refresh_token_handlers(
        ...     request_builder=lambda token: {"token": token},
        ...     response_parser=lambda data: data.get("jwt")
        ... )
    """

    def __init__(self):
        # General settings (all off until explicitly enabled)
        self.enabled = False
        self.jwt_validation_enabled = False
        self.refresh_token_enabled = False

        # JWT validation settings
        self.jwt_issuer: Optional[str] = None
        self.jwt_audience: Optional[str] = None
        self.jwt_jwks_uri: Optional[str] = None
        self.jwt_public_key: Optional[str] = None
        self.jwt_verify_signature = True
        self.jwt_verify_iss = True
        self.jwt_verify_aud = True
        self.jwt_required_claims: list[str] = []

        # Refresh token settings
        self.refresh_token_endpoint: Optional[str] = None
        self.refresh_token_request_builder: Optional[Callable[[str], dict]] = None
        self.refresh_token_response_parser: Optional[
            Callable[[dict], Optional[str]]
        ] = None
        self.refresh_token_pattern: Optional[Callable[[str], bool]] = None

        # Caching settings
        self.jwt_cache_ttl = 3000  # 50 minutes (OpenBridge JWTs typically last 1 hour)
        self.cache_cleanup_interval = 300  # 5 minutes

    @staticmethod
    def _env_flag(name: str, default: str) -> bool:
        """Return True when env var ``name`` (or ``default`` if unset) equals "true".

        The comparison is case-insensitive and ignores surrounding whitespace,
        so values such as ``" TRUE "`` from hand-edited ``.env`` files are
        honoured. Any other value is treated as False.
        """
        return os.getenv(name, default).strip().lower() == "true"

    def load_from_env(self) -> None:
        """Load configuration from environment variables.

        Overwrites the flags and JWT settings with values from the
        environment. ``refresh_token_endpoint`` is only taken from the
        environment when it has not already been set programmatically.

        Boolean variables are true only when their value is ``"true"``
        (case-insensitive, surrounding whitespace ignored).

        Environment Variables:
            - AUTH_ENABLED: Enable/disable authentication (default: false)
            - JWT_VALIDATION_ENABLED: Enable JWT validation (default: true)
            - REFRESH_TOKEN_ENABLED: Enable refresh token conversion (default: false)
            - JWT_ISSUER: JWT issuer for validation
            - JWT_AUDIENCE: JWT audience for validation
            - JWT_JWKS_URI: JWKS endpoint for public key retrieval
            - JWT_PUBLIC_KEY: Static public key for JWT validation
            - JWT_VERIFY_SIGNATURE: Enable signature verification (default: true)
            - JWT_VERIFY_ISS: Enable issuer verification (default: true)
            - JWT_VERIFY_AUD: Enable audience verification (default: true)
            - JWT_REQUIRED_CLAIMS: Comma-separated list of required claims;
              blank entries (e.g. from a trailing comma) are ignored
            - REFRESH_TOKEN_ENDPOINT: Endpoint for refresh token conversion
            - JWT_CACHE_TTL: JWT cache TTL in seconds (default: 3000); an
              unparsable value is logged and the current value is kept

        Examples:
            >>> config = AuthConfig()
            >>> config.load_from_env()
            >>> print(f"Authentication enabled: {config.enabled}")
        """
        self.enabled = self._env_flag("AUTH_ENABLED", "false")
        self.jwt_validation_enabled = self._env_flag("JWT_VALIDATION_ENABLED", "true")
        self.refresh_token_enabled = self._env_flag("REFRESH_TOKEN_ENABLED", "false")

        # JWT settings
        self.jwt_issuer = os.getenv("JWT_ISSUER")
        self.jwt_audience = os.getenv("JWT_AUDIENCE")
        self.jwt_jwks_uri = os.getenv("JWT_JWKS_URI")
        self.jwt_public_key = os.getenv("JWT_PUBLIC_KEY")
        self.jwt_verify_signature = self._env_flag("JWT_VERIFY_SIGNATURE", "true")
        self.jwt_verify_iss = self._env_flag("JWT_VERIFY_ISS", "true")
        self.jwt_verify_aud = self._env_flag("JWT_VERIFY_AUD", "true")

        raw_claims = os.getenv("JWT_REQUIRED_CLAIMS")
        if raw_claims:
            # Drop blank entries so "a,b," does not require a claim named "".
            self.jwt_required_claims = [
                claim.strip() for claim in raw_claims.split(",") if claim.strip()
            ]

        # Refresh token settings - only set if not already configured
        if not self.refresh_token_endpoint:
            self.refresh_token_endpoint = os.getenv("REFRESH_TOKEN_ENDPOINT")

        # Cache settings
        cache_ttl = os.getenv("JWT_CACHE_TTL")
        if cache_ttl:
            try:
                self.jwt_cache_ttl = int(cache_ttl)
            except ValueError:
                logger.warning(f"Invalid JWT_CACHE_TTL: {cache_ttl}, using default")

    def set_refresh_token_handlers(
        self,
        request_builder: Callable[[str], dict],
        response_parser: Callable[[dict], Optional[str]],
        pattern_detector: Optional[Callable[[str], bool]] = None,
    ):
        """Set handlers for refresh token conversion.

        The handlers define how to build the request sent to the refresh
        token endpoint and how to extract the JWT from its response.

        Args:
            request_builder: Function that takes a refresh token and returns
                the JSON payload for the refresh token endpoint.
            response_parser: Function that takes the decoded response data and
                returns the JWT token, or None if parsing fails.
            pattern_detector: Optional function that takes a token and returns
                True if it is a refresh token that should be converted. The
                value is stored as-is. While it is None, ``RefreshTokenMiddleware``
                treats no token as a refresh token and skips conversion.
                NOTE(review): auto-configuration elsewhere may install a
                default detector; confirm before relying on None.

        Examples:
            >>> def build_request(token):
            ...     return {"refresh_token": token}

            >>> def parse_response(data):
            ...     return data.get("access_token")

            >>> def detect_pattern(token):
            ...     return ":" in token and len(token) > 20

            >>> config.set_refresh_token_handlers(
            ...     build_request, parse_response, detect_pattern
            ... )
        """
        self.refresh_token_request_builder = request_builder
        self.refresh_token_response_parser = response_parser
        self.refresh_token_pattern = pattern_detector

    def validate(self) -> bool:
        """Validate configuration.

        Checks that the enabled features have the settings they need.

        Returns:
            True if the configuration is usable, False otherwise. A disabled
            configuration is always valid.

        Validation Rules:
            - If JWT validation is enabled and neither signature verification
              nor required claims are configured, a warning is logged. This
              does NOT fail validation.
            - If refresh token conversion is enabled, the endpoint must be
              configured; otherwise validation fails.
            - Missing refresh token handlers only produce an info log, since
              they can be auto-configured later.

        Examples:
            >>> config = AuthConfig()
            >>> config.load_from_env()
            >>> if config.validate():
            ...     print("Configuration is valid")
            ... else:
            ...     print("Configuration has issues")
        """
        if not self.enabled:
            return True

        if self.jwt_validation_enabled:
            if not self.jwt_verify_signature and not self.jwt_required_claims:
                logger.warning(
                    "JWT validation enabled but no signature verification or required claims configured"
                )

        if self.refresh_token_enabled:
            if not self.refresh_token_endpoint:
                logger.error("Refresh token enabled but no endpoint configured")
                return False
            # Allow auto-configuration to handle missing handlers
            if (
                not self.refresh_token_request_builder
                or not self.refresh_token_response_parser
            ):
                logger.info(
                    "Refresh token handlers not configured - will be auto-configured"
                )

        return True
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class JWTCache:
    """Thread-safe, in-process JWT cache with periodic cleanup.

    Entries are stored as ``key -> (jwt, expiry_epoch_seconds)`` and guarded
    by a ``threading.Lock``. Every ``get`` first purges expired entries if
    more than ``cleanup_interval`` seconds have passed since the last purge.

    When an ``auth_manager`` exposing ``set_token`` is supplied, :meth:`set`
    also schedules a best-effort write-through to that manager. This only
    happens when it is called from a running event loop. Reads are always
    served from the local cache.

    Cache keys may be raw refresh tokens (``RefreshTokenMiddleware`` passes the
    refresh token itself as the key), so keys are treated as secrets and are
    never written to logs.

    Examples:
        >>> cache = JWTCache(ttl=3000, cleanup_interval=300)
        >>> cache.set("user123", "jwt_token_here")
        >>> token = cache.get("user123")
        >>> print(f"Retrieved token: {token is not None}")
    """

    def __init__(self, ttl: int = 3000, cleanup_interval: int = 300, auth_manager=None):
        self._cache: Dict[str, Tuple[str, float]] = {}
        self._lock = threading.Lock()
        self._ttl = ttl
        self._cleanup_interval = cleanup_interval
        self._last_cleanup = 0
        # Optional AuthManager for TokenStore integration (write-through only).
        self._auth_manager = auth_manager
        # Strong references to in-flight TokenStore writes. The event loop only
        # keeps weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._pending_tasks = set()

    def get(self, key: str) -> Optional[str]:
        """Return the cached JWT for ``key`` if present and not expired.

        Runs the periodic purge of expired entries when it is due. An expired
        entry found for ``key`` is removed eagerly.

        Args:
            key: The cache key for the JWT token.

        Returns:
            The JWT token if found and not expired, None otherwise.

        Examples:
            >>> cache = JWTCache()
            >>> token = cache.get("user123")
            >>> if token:
            ...     print("Token found and valid")
            ... else:
            ...     print("Token not found or expired")
        """
        now = time.time()
        purged = 0
        found: Optional[str] = None
        dropped_expired = False

        # The purge timestamp is checked and updated under the same lock as
        # the cache, so concurrent callers do not race on _last_cleanup.
        with self._lock:
            if now - self._last_cleanup > self._cleanup_interval:
                purged = self._purge_expired_locked(now)
                self._last_cleanup = now

            entry = self._cache.get(key)
            if entry is not None:
                jwt_token, expiry_time = entry
                if now < expiry_time:
                    found = jwt_token
                else:
                    del self._cache[key]
                    dropped_expired = True

        # Log outside the lock; never include the key (it may be a secret).
        if purged:
            logger.info(f"Cleaned up {purged} expired JWT cache entries")
        if dropped_expired:
            logger.debug("Removed expired JWT cache entry")
        return found

    def set(self, key: str, jwt_token: str) -> None:
        """Cache ``jwt_token`` under ``key`` for ``ttl`` seconds.

        When an AuthManager with ``set_token`` is configured, a best-effort
        TokenStore write is also scheduled (see
        :meth:`_schedule_token_store_write`). Failures there never prevent
        the local cache update.

        Args:
            key: The cache key for the JWT token.
            jwt_token: The JWT token to cache.

        Examples:
            >>> cache = JWTCache(ttl=3000)
            >>> cache.set("user123", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...")
        """
        now = time.time()
        expiry_time = now + self._ttl

        if self._auth_manager and hasattr(self._auth_manager, "set_token"):
            self._schedule_token_store_write(key, jwt_token)

        # Always store locally for fast access.
        with self._lock:
            self._cache[key] = (jwt_token, expiry_time)

    def _schedule_token_store_write(self, key: str, jwt_token: str) -> None:
        """Best-effort write-through of ``jwt_token`` to the AuthManager's TokenStore.

        ``key`` is split on ``":"`` into ``provider:identity[:region]``. Keys
        with fewer than two parts are not written. The write is only
        scheduled when an event loop is running in the current thread. All
        failures are logged at debug level and never raised.
        """
        import asyncio

        parts = key.split(":")
        if len(parts) < 2:
            return

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            logger.debug("Token store update deferred - not in async context")
            return

        provider_type = parts[0]
        identity_id = parts[1]
        region = parts[2] if len(parts) > 2 else None
        expires_at = datetime.now(timezone.utc) + timedelta(seconds=self._ttl)

        try:
            from ..auth.token_store import TokenKind

            task = asyncio.create_task(
                self._auth_manager.set_token(
                    provider_type=provider_type,
                    identity_id=identity_id,
                    token_kind=TokenKind.PROVIDER_JWT,
                    token=jwt_token,
                    expires_at=expires_at,
                    metadata={"cache_key": key},
                    region=region,
                )
            )
        except Exception as e:
            logger.debug(f"Failed to store token in TokenStore: {e}")
            return

        self._pending_tasks.add(task)
        task.add_done_callback(self._on_token_store_write_done)

    def _on_token_store_write_done(self, task) -> None:
        """Release a finished TokenStore task and log (not raise) its failure."""
        self._pending_tasks.discard(task)
        if task.cancelled():
            return
        # Retrieving the exception also suppresses asyncio's
        # "Task exception was never retrieved" warning.
        exc = task.exception()
        if exc is not None:
            logger.debug(f"Failed to store token in TokenStore: {exc}")

    def _purge_expired_locked(self, now: float) -> int:
        """Delete entries whose expiry is at or before ``now``; caller holds the lock.

        Returns:
            The number of entries removed.
        """
        expired_keys = [
            cache_key
            for cache_key, (_, expiry_time) in self._cache.items()
            if now >= expiry_time
        ]
        for cache_key in expired_keys:
            del self._cache[cache_key]
        return len(expired_keys)

    def _cleanup(self, now: float) -> None:
        """Remove all cache entries that have expired as of ``now``.

        ``get`` runs this purge automatically every ``cleanup_interval``
        seconds. It can also be called manually.

        Args:
            now: The current timestamp (``time.time()``) for expiry comparison.

        Examples:
            >>> cache = JWTCache()
            >>> cache._cleanup(time.time())  # Manual cleanup if needed
        """
        with self._lock:
            removed = self._purge_expired_locked(now)

        if removed:
            logger.info(f"Cleaned up {removed} expired JWT cache entries")
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
class RefreshTokenMiddleware(Middleware):
|
|
469
|
+
"""Middleware to convert refresh tokens to JWT tokens with caching.
|
|
470
|
+
|
|
471
|
+
This middleware intercepts requests containing refresh tokens and
|
|
472
|
+
converts them to JWT tokens using configured endpoints. It provides
|
|
473
|
+
caching to reduce API calls and improve performance.
|
|
474
|
+
|
|
475
|
+
The middleware provides:
|
|
476
|
+
- Refresh token pattern detection
|
|
477
|
+
- Token conversion using configured endpoints
|
|
478
|
+
- JWT caching for performance optimization
|
|
479
|
+
- Context-safe JWT storage for other middleware
|
|
480
|
+
- Comprehensive error handling and logging
|
|
481
|
+
|
|
482
|
+
Key Features:
|
|
483
|
+
- Automatic refresh token detection using pattern matching
|
|
484
|
+
- Configurable request building and response parsing
|
|
485
|
+
- JWT caching to reduce API calls
|
|
486
|
+
- Context-safe token sharing with JWT middleware
|
|
487
|
+
- Comprehensive error handling and logging
|
|
488
|
+
|
|
489
|
+
Examples:
|
|
490
|
+
>>> config = AuthConfig()
|
|
491
|
+
>>> config.refresh_token_enabled = True
|
|
492
|
+
>>> config.refresh_token_endpoint = "https://api.example.com/refresh"
|
|
493
|
+
>>> middleware = RefreshTokenMiddleware(config)
|
|
494
|
+
"""
|
|
495
|
+
|
|
496
|
+
def __init__(self, config: AuthConfig, auth_manager: Optional[Any] = None):
    """Create the refresh-token middleware.

    Args:
        config: Authentication settings. ``jwt_cache_ttl`` and
            ``cache_cleanup_interval`` size the internal JWT cache.
        auth_manager: Optional manager. Its ``provider`` (if any) receives
            incoming refresh tokens, and it is also handed to the JWT cache
            for TokenStore integration.
    """
    super().__init__()
    self.logger = logging.getLogger(f"{__name__}.RefreshTokenMiddleware")
    self.config = config
    self.auth_manager = auth_manager
    self._jwt_cache = JWTCache(
        ttl=config.jwt_cache_ttl,
        cleanup_interval=config.cache_cleanup_interval,
        auth_manager=auth_manager,
    )
|
|
504
|
+
|
|
505
|
+
async def on_request(self, context: MiddlewareContext, call_next):
    """Convert a bearer refresh token into a JWT before the request proceeds.

    Steps:
    1. Extract the bearer token from the HTTP ``Authorization`` header, when
       the FastMCP context exposes an HTTP request. An empty credential is
       ignored.
    2. Pass the token to ``auth_manager.provider.set_refresh_token`` (when
       both exist). This happens even if ``config.enabled`` is False, so
       provider-backed tools can authenticate.
    3. If authentication and refresh-token conversion are enabled and the
       configured pattern matches, obtain a JWT (cached or freshly converted)
       and publish it in ``jwt_token_var`` for downstream middleware.

    A JWT published here is reset once ``call_next`` returns (or raises), so
    it does not outlive this request in the current context. ``ToolError`` is
    re-raised. Any other error during steps 1-3 is logged and the request
    continues without a published JWT.

    Args:
        context: The FastMCP middleware context.
        call_next: The next middleware in the chain.

    Returns:
        The result of the next middleware in the chain.
    """
    jwt_var_reset = None
    try:
        token = self._extract_bearer_token(context)
        if token:
            self._share_refresh_token_with_provider(token)
            jwt_var_reset = await self._publish_refresh_jwt(token)
    except ToolError:
        # Let ToolError propagate - it's handled by FastMCP
        raise
    except Exception as e:
        self.logger.error(f"RefreshTokenMiddleware error: {e}")

    try:
        return await call_next(context)
    finally:
        if jwt_var_reset is not None:
            jwt_token_var.reset(jwt_var_reset)


def _extract_bearer_token(self, context: MiddlewareContext) -> Optional[str]:
    """Return the bearer credential from the HTTP Authorization header, or None.

    The scheme is matched case-insensitively, and whitespace around the
    credential is stripped. Returns None when there is no HTTP request, no
    header, a non-bearer scheme, or an empty credential.
    """
    fastmcp_ctx = context.fastmcp_context
    if not (fastmcp_ctx and fastmcp_ctx.request_context):
        return None
    request = fastmcp_ctx.request_context.request
    if not (request and hasattr(request, "headers")):
        return None

    auth_header = request.headers.get("authorization", "")
    scheme, separator, credential = auth_header.partition(" ")
    if not separator or scheme.lower() != "bearer":
        return None
    token = credential.strip()
    return token or None


def _share_refresh_token_with_provider(self, token: str) -> None:
    """Pass ``token`` to the auth provider's ``set_refresh_token``, if available.

    Runs regardless of ``config.enabled``, because provider-backed tools
    (e.g. OpenBridge) need the refresh token to authenticate API calls.
    """
    if not (self.auth_manager and hasattr(self.auth_manager, "provider")):
        return
    provider = self.auth_manager.provider
    if hasattr(provider, "set_refresh_token"):
        self.logger.debug(
            "Setting refresh token in OpenBridge provider from Authorization header"
        )
        provider.set_refresh_token(token)


async def _publish_refresh_jwt(self, token: str):
    """Convert ``token`` to a JWT when enabled and publish it in ``jwt_token_var``.

    Returns:
        The reset token returned by ``jwt_token_var.set`` when a JWT was
        published, so the caller can restore the previous value. Returns
        None when conversion is disabled, the token does not match the
        refresh-token pattern, or conversion fails.
    """
    if not (self.config.enabled and self.config.refresh_token_enabled):
        return None

    detector = self.config.refresh_token_pattern
    if not (detector and detector(token)):
        self.logger.debug(
            "Token does not match refresh token pattern - skipping JWT conversion"
        )
        return None

    self.logger.debug("Detected refresh token format, checking cache...")
    jwt_token = await self._get_cached_or_convert_jwt(token)
    if not jwt_token:
        self.logger.error("Failed to convert refresh token to JWT")
        return None

    self.logger.debug("JWT token ready (cached or converted)")
    # Store the JWT in context-safe storage for the JWT middleware to use.
    return jwt_token_var.set(jwt_token)
|
|
576
|
+
|
|
577
|
+
async def _get_cached_or_convert_jwt(self, refresh_token: str) -> Optional[str]:
    """Return a JWT for ``refresh_token``, preferring the local JWT cache.

    On a cache miss the token is exchanged via ``_convert_refresh_to_jwt``.
    A truthy result is stored in the cache before being returned. A falsy
    result is returned unchanged and is not cached.

    Args:
        refresh_token: The refresh token to look up or convert. It is also
            used as the cache key.

    Returns:
        The JWT token if lookup/conversion succeeded, otherwise the falsy
        value produced by the converter (normally None).

    Examples:
        >>> jwt_token = await middleware._get_cached_or_convert_jwt("refresh_token_here")
        >>> if jwt_token:
        ...     print("JWT token obtained successfully")
    """
    hit = self._jwt_cache.get(refresh_token)
    if hit:
        self.logger.debug("Using cached JWT token")
        return hit

    self.logger.info("Converting refresh token to JWT (cache miss)...")
    fresh = await self._convert_refresh_to_jwt(refresh_token)
    if not fresh:
        return fresh

    self._jwt_cache.set(refresh_token, fresh)
    self.logger.info("Cached new JWT token")
    return fresh
|
|
612
|
+
|
|
613
|
+
async def _convert_refresh_to_jwt(self, refresh_token: str) -> Optional[str]:
    """Exchange a refresh token for a JWT via the configured endpoint.

    The request payload is built with
    ``self.config.refresh_token_request_builder``. It is POSTed as JSON to
    ``self.config.refresh_token_endpoint`` through the shared HTTP client,
    with a 10 second timeout. The JWT is then extracted from the decoded
    JSON response with ``self.config.refresh_token_response_parser``.

    This method does not raise. The following are all logged and reported
    as ``None``:

    - transport errors and non-2xx responses (``httpx.HTTPError``);
    - non-JSON bodies;
    - exceptions from the configured builder or parser.

    Args:
        refresh_token: The refresh token to convert.

    Returns:
        The JWT produced by the response parser, or ``None`` if the exchange
        failed or the parser returned a falsy value.

    Examples:
        >>> jwt_token = await middleware._convert_refresh_to_jwt("refresh_token_here")
        >>> if jwt_token:
        ...     print("Token converted successfully")
    """
    try:
        # Build request using configured builder
        payload = self.config.refresh_token_request_builder(refresh_token)

        # Get shared HTTP client
        client = await get_http_client()

        response = await client.post(
            self.config.refresh_token_endpoint,
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=10.0,
        )
        response.raise_for_status()

        # Parse response using configured parser
        response_data = response.json()
        jwt_token = self.config.refresh_token_response_parser(response_data)
    except httpx.HTTPError as e:
        self.logger.error(f"Failed to convert refresh token: {e}")
        return None
    except Exception as e:
        # Covers JSON decoding failures and errors raised by the configured
        # builder/parser callables.
        self.logger.error(f"Error converting refresh token: {e}")
        return None

    if not jwt_token:
        self.logger.error("Response parser returned no JWT token")
        return None

    self.logger.debug("Successfully converted refresh token to JWT")
    return jwt_token
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
class JWTAuthenticationMiddleware(Middleware):
    """Middleware to validate JWT authentication with comprehensive error handling.

    This middleware validates JWT tokens in incoming requests and provides
    comprehensive error handling and logging. It supports both signature
    verification and claim-based validation modes.

    The middleware provides:
    - JWT token extraction from Authorization headers
    - Signature verification using a configured key or a JWKS endpoint
    - Claim validation (issuer, audience, expiration)
    - OpenBridge-specific validation (user_id, account_id) in claim-only mode
    - Context-safe claim storage for application use
    - Error handling with detailed logging

    Key Features:
    - Support for both signature verification and claim-only validation
    - JWKS integration for dynamic public key retrieval (RSA and EC keys)
    - OpenBridge-specific claim validation
    - Claim storage in ``jwt_claims_var`` for downstream code
    - Normalization of stray "Bearer " prefixes, spaces and newlines

    Examples:
        >>> config = AuthConfig()
        >>> config.jwt_validation_enabled = True
        >>> config.jwt_verify_signature = True
        >>> middleware = JWTAuthenticationMiddleware(config)
    """

    def __init__(self, config: AuthConfig):
        super().__init__()
        self.config = config
        self.logger = logging.getLogger(f"{__name__}.JWTAuthenticationMiddleware")

    async def on_request(self, context: MiddlewareContext, call_next):
        """Validate JWT authentication for all requests.

        Where the token comes from:

        - ``jwt_token_var``, when RefreshTokenMiddleware has placed a JWT
          there. The slot is cleared after it is read.
        - Otherwise, the ``Authorization: Bearer <token>`` request header.

        On success the validated claims are stored in ``jwt_claims_var`` and
        the request continues down the chain. Exceptions raised by downstream
        handlers propagate unchanged; they are not reported as
        authentication failures.

        Args:
            context: The FastMCP middleware context.
            call_next: The next middleware in the chain.

        Returns:
            The result of the next middleware in the chain.

        Raises:
            ToolError: If authentication is missing, malformed or invalid.
        """
        if not self.config.enabled or not self.config.jwt_validation_enabled:
            self.logger.debug("Authentication disabled, skipping validation")
            return await call_next(context)

        try:
            # Check for JWT set by RefreshTokenMiddleware in context-safe storage
            jwt_token = jwt_token_var.get()
            if jwt_token:
                token = jwt_token
                self.logger.info("Using JWT from context-safe storage")
                # Clear the token from context after use
                jwt_token_var.set(None)
            else:
                # Get token from headers via context
                auth_header = ""
                if context.fastmcp_context and context.fastmcp_context.request_context:
                    request = context.fastmcp_context.request_context.request
                    if request and hasattr(request, "headers"):
                        auth_header = request.headers.get("authorization", "")

                self.logger.info(
                    f"JWT middleware - Authorization header present: {bool(auth_header)}"
                )

                if not auth_header:
                    self.logger.warning(
                        "JWT middleware - Missing Authorization header, rejecting request"
                    )
                    raise ToolError(
                        "Authentication required: Missing Authorization header"
                    )

                self.logger.debug(
                    f"Authorization header: {sanitize_string(auth_header)}"
                )
                # The scheme is compared case-insensitively; surrounding
                # whitespace on the credential is stripped.
                parts = auth_header.split(" ", 1)
                if len(parts) != 2 or parts[0].lower() != "bearer":
                    self.logger.warning(
                        f"JWT middleware - Invalid Authorization header format: {sanitize_string(auth_header)}"
                    )
                    raise ToolError(
                        "Authentication required: Invalid Authorization header format"
                    )

                token = parts[1].strip()
                self.logger.debug(f"Extracted token (length: {len(token)}, type: JWT)")

            claims = await self._validate_jwt_token(token)
        except ToolError:
            raise
        except Exception as e:
            self.logger.error(f"Authentication error: {e}")
            raise ToolError("Authentication failed") from e

        if not claims:
            self.logger.error("JWT token validation failed")
            raise ToolError("Invalid authentication token")

        self.logger.info("JWT token validation successful")
        # Store validated claims in context for other parts of the application
        jwt_claims_var.set(claims)
        # Called outside the try block so downstream errors are not
        # misreported as authentication failures.
        return await call_next(context)

    async def _validate_jwt_token(self, token: str) -> Optional[dict]:
        """Validate a JWT and return its claims if valid.

        Dispatches to signature verification or claim-only validation
        depending on ``self.config.jwt_verify_signature``.

        Args:
            token: The JWT token to validate.

        Returns:
            The decoded claims if validation succeeds, None otherwise.
        """
        try:
            if self.config.jwt_verify_signature:
                return await self._validate_jwt_with_signature(token)
            else:
                return await self._validate_jwt_without_signature(token)
        except Exception as e:
            self.logger.warning(f"JWT decode failed: {e}")
            return None

    async def _validate_jwt_with_signature(self, token: str) -> Optional[dict]:
        """Validate a JWT with signature verification and return its claims.

        The verification key comes from ``_get_public_key``. Signature and
        expiry are always checked.

        Audience is checked when ``jwt_verify_aud`` is set and an audience is
        configured. Issuer is checked when ``jwt_verify_iss`` is set and the
        config provides a ``jwt_issuer`` value. After decoding, every claim
        in ``config.jwt_required_claims`` must be present.

        All failures (expired token, bad issuer/audience/signature, missing
        key or claims, other errors) are logged and reported as ``None``;
        this method does not raise.

        Args:
            token: The JWT token to validate with signature verification.

        Returns:
            The decoded claims if verification succeeds, None otherwise.
        """
        try:
            # Get public key
            public_key = await self._get_public_key(token)
            if not public_key:
                self.logger.warning("No public key available for JWT validation")
                return None

            # Prepare decode options
            decode_options = {
                "verify_signature": True,
                "verify_exp": True,
                "verify_iss": self.config.jwt_verify_iss,
                "verify_aud": self.config.jwt_verify_aud,
            }

            # NOTE(review): HS256 sits alongside asymmetric algorithms here;
            # confirm the installed PyJWT rejects PEM/SSH keys as HMAC secrets.
            decode_kwargs = {
                "key": public_key,
                "algorithms": ["RS256", "ES256", "HS256"],
                "options": decode_options,
            }

            # Add audience only if configured and verification is enabled
            if self.config.jwt_audience and self.config.jwt_verify_aud:
                decode_kwargs["audience"] = self.config.jwt_audience
            elif self.config.jwt_verify_aud:
                # If audience verification is enabled but no audience is configured, disable it
                decode_options["verify_aud"] = False
                self.logger.debug(
                    "Audience verification disabled - no audience configured"
                )

            # PyJWT only compares ``iss`` when an expected issuer is supplied.
            # ``jwt_issuer`` is read defensively in case AuthConfig lacks it.
            expected_issuer = getattr(self.config, "jwt_issuer", None)
            if expected_issuer and self.config.jwt_verify_iss:
                decode_kwargs["issuer"] = expected_issuer

            # PyJWT's first parameter is named ``jwt`` (not ``token``), so the
            # token must be passed positionally.
            payload = jwt.decode(token, **decode_kwargs)

            # Validate required claims
            for claim in self.config.jwt_required_claims:
                if claim not in payload:
                    self.logger.warning(f"Missing required claim: {claim}")
                    return None

            return payload

        except jwt.ExpiredSignatureError:
            self.logger.warning("JWT token is expired")
            return None
        except jwt.InvalidIssuerError:
            self.logger.warning("JWT issuer validation failed")
            return None
        except jwt.InvalidAudienceError:
            self.logger.warning("JWT audience validation failed")
            return None
        except jwt.InvalidSignatureError:
            self.logger.warning("JWT signature validation failed")
            return None
        except Exception as e:
            self.logger.warning(f"JWT validation error: {e}")
            return None

    async def _validate_jwt_without_signature(self, token: str) -> Optional[dict]:
        """Validate a JWT without signature verification (OpenBridge style).

        Steps:

        1. Normalize the token by removing a stray "Bearer " prefix, spaces
           and newlines.
        2. Require three dot-separated segments.
        3. Decode without verifying the signature, while still enforcing the
           standard ``exp`` claim when present.
        4. Reject the token when the OpenBridge ``expires_at`` timestamp is
           in the past.
        5. Require truthy ``user_id`` and ``account_id`` claims.

        Args:
            token: The JWT token to validate without signature verification.

        Returns:
            The decoded claims if validation succeeds, None otherwise.
        """
        try:
            # Clean the token - remove any whitespace or encoding issues
            token = token.strip()
            self.logger.debug(f"Validating JWT token (length: {len(token)})")

            # Check for common corruption patterns
            if token.startswith("Bearer "):
                self.logger.error(
                    "Token still has 'Bearer ' prefix - extraction failed!"
                )
                token = token[7:].strip()

            if " " in token:
                self.logger.error("Token contains whitespace - may be corrupted")
                self.logger.error(
                    f"Whitespace positions: {[i for i, c in enumerate(token) if c == ' ']}"
                )
                token = token.replace(" ", "")

            if "\n" in token or "\r" in token:
                self.logger.error("Token contains newlines - may be corrupted")
                token = token.replace("\n", "").replace("\r", "")

            parts = token.split(".")
            if len(parts) != 3:
                self.logger.error(
                    f"Token doesn't have 3 parts! Found {len(parts)} parts"
                )
                return None
            self.logger.debug(
                f"Token has 3 parts: header({len(parts[0])}), payload({len(parts[1])}), signature({len(parts[2])})"
            )

            # With signature verification off, PyJWT 2 skips every registered
            # claim check by default; re-enable ``exp`` (a no-op if absent).
            payload = jwt.decode(
                token, options={"verify_signature": False, "verify_exp": True}
            )

            # Check if token is expired using expires_at (OpenBridge format)
            expires_at = payload.get("expires_at")
            if expires_at:
                try:
                    exp_time = datetime.fromtimestamp(
                        float(expires_at), tz=timezone.utc
                    )
                    if datetime.now(timezone.utc) > exp_time:
                        self.logger.warning("JWT token is expired")
                        return None
                except (ValueError, TypeError):
                    self.logger.warning(f"Invalid expires_at format: {expires_at}")
                    return None

            # Validate OpenBridge-specific claims (user_id, account_id)
            user_id = payload.get("user_id")
            account_id = payload.get("account_id")

            if not user_id or not account_id:
                self.logger.warning(
                    "JWT missing required OpenBridge fields (user_id, account_id)"
                )
                return None

            self.logger.info(
                f"OpenBridge token validated - User: {user_id}, Account: {account_id}"
            )
            return payload

        except jwt.ExpiredSignatureError:
            self.logger.warning("JWT token is expired")
            return None
        except jwt.DecodeError as e:
            self.logger.warning(f"JWT decode error: {e}")
            # Diagnostics only. The token was already normalized above, and a
            # "cleaned retry" could never differ; it would also have bypassed
            # the expiry and claim checks, so there is deliberately no retry.
            if "Invalid header padding" in str(e):
                self.logger.error(
                    "JWT header padding error - token may be corrupted or improperly formatted"
                )
                self.logger.error(f"Original token: {sanitize_string(token)}")
                try:
                    segments = token.split(".")
                    self.logger.error(f"Header part: {sanitize_string(segments[0])}")
                    if len(segments) >= 2:
                        self.logger.error(
                            f"Payload part: {sanitize_string(segments[1])}"
                        )
                except Exception as analyze_error:
                    self.logger.error(
                        f"Failed to analyze token structure: {analyze_error}"
                    )
            return None
        except Exception as e:
            self.logger.warning(f"JWT validation error: {e}")
            return None

    async def _get_public_key(self, token: str) -> Optional[str]:
        """Return the key used to verify ``token``'s signature.

        ``config.jwt_public_key`` wins when set. Otherwise, if
        ``config.jwt_jwks_uri`` is set, the JWKS document is fetched on every
        call. The entry whose ``kid`` matches the token header's ``kid`` is
        then converted to PEM.

        Args:
            token: The JWT whose (unverified) header supplies the key ID.

        Returns:
            The key (PEM for JWKS-derived keys), or None if no key is
            available, the header has no ``kid``, the ``kid`` is unknown, or
            fetching/processing the JWKS fails.
        """
        if self.config.jwt_public_key:
            return self.config.jwt_public_key

        if self.config.jwt_jwks_uri:
            try:
                # Decode header to get key ID
                header = jwt.get_unverified_header(token)
                kid = header.get("kid")

                if not kid:
                    self.logger.warning("No key ID in JWT header")
                    return None

                # Get shared HTTP client
                client = await get_http_client()

                # Fetch JWKS
                response = await client.get(self.config.jwt_jwks_uri, timeout=10.0)
                response.raise_for_status()
                jwks = response.json()

                # Find the key
                for key in jwks.get("keys", []):
                    if key.get("kid") == kid:
                        return self._jwk_to_pem(key)

                self.logger.warning(f"Key ID {kid} not found in JWKS")
                return None

            except httpx.HTTPError as e:
                self.logger.error(f"Failed to fetch JWKS: {e}")
                return None
            except Exception as e:
                self.logger.error(f"Error processing JWKS: {e}")
                return None

        return None

    def _jwk_to_pem(self, jwk_key: dict) -> Optional[str]:
        """Convert an RSA or EC JSON Web Key to a PEM SubjectPublicKeyInfo string.

        EC keys on P-256, P-384 and P-521 are supported, so keys for the
        ``ES256`` algorithm accepted by ``_validate_jwt_with_signature`` can
        be loaded from a JWKS.

        Args:
            jwk_key: The JWK dictionary (``kty`` plus ``n``/``e`` for RSA or
                ``crv``/``x``/``y`` for EC).

        Returns:
            The PEM-encoded public key, or None for unsupported key types or
            on any conversion error (logged).
        """
        try:
            from cryptography.hazmat.primitives import serialization

            key_type = jwk_key.get("kty")
            if key_type == "RSA":
                from cryptography.hazmat.primitives.asymmetric import rsa

                n = int.from_bytes(jwt.utils.base64url_decode(jwk_key["n"]), "big")
                e = int.from_bytes(jwt.utils.base64url_decode(jwk_key["e"]), "big")
                public_key = rsa.RSAPublicNumbers(e, n).public_key()
            elif key_type == "EC":
                from cryptography.hazmat.primitives.asymmetric import ec

                curves = {
                    "P-256": ec.SECP256R1,
                    "P-384": ec.SECP384R1,
                    "P-521": ec.SECP521R1,
                }
                curve_cls = curves.get(jwk_key.get("crv"))
                if curve_cls is None:
                    self.logger.warning(
                        f"Unsupported EC curve in JWK: {jwk_key.get('crv')}"
                    )
                    return None
                x = int.from_bytes(jwt.utils.base64url_decode(jwk_key["x"]), "big")
                y = int.from_bytes(jwt.utils.base64url_decode(jwk_key["y"]), "big")
                public_key = ec.EllipticCurvePublicNumbers(
                    x, y, curve_cls()
                ).public_key()
            else:
                self.logger.warning(f"Unsupported JWK key type: {key_type}")
                return None

            pem = public_key.public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo,
            )
            return pem.decode("utf-8")

        except Exception as e:
            self.logger.error(f"Error converting JWK to PEM: {e}")
            return None
|
|
1140
|
+
|
|
1141
|
+
|
|
1142
|
+
# Utility functions for accessing JWT data from context
|
|
1143
|
+
def get_current_jwt() -> Optional[str]:
    """Return the JWT held in context-safe storage for the active request.

    The value is read from the ``jwt_token_var`` context variable, which
    RefreshTokenMiddleware fills after converting a refresh token.
    ``JWTAuthenticationMiddleware.on_request`` resets that variable to
    ``None`` once it has consumed the token, so code running after JWT
    validation may see ``None`` here.

    :return: The JWT for the current request context, or None if unavailable
    :rtype: Optional[str]

    .. example::
        >>> token = get_current_jwt()
        >>> if token:
        ...     print("JWT token available in current context")
    """
    current = jwt_token_var.get()
    return current
|
|
1159
|
+
|
|
1160
|
+
|
|
1161
|
+
def get_current_claims() -> Optional[dict]:
    """Return the validated JWT claims for the active request context.

    The claims are read from the ``jwt_claims_var`` context variable.
    JWTAuthenticationMiddleware stores the decoded claims there after a
    token passes validation.

    :return: The claims dictionary for the current request context, or None if unavailable
    :rtype: Optional[dict]

    .. example::
        >>> claims = get_current_claims()
        >>> if claims:
        ...     print(f"User ID: {claims.get('user_id')}")
        ...     print(f"Account ID: {claims.get('account_id')}")
    """
    claims = jwt_claims_var.get()
    return claims
|
|
1178
|
+
|
|
1179
|
+
|
|
1180
|
+
def create_auth_middleware(
    config: Optional[AuthConfig] = None,
    refresh_token_middleware: bool = True,
    jwt_middleware: bool = True,
    auth_manager: Optional[Any] = None,
) -> list:
    """Build the authentication middleware chain.

    Config resolution:

    - With no ``config``, a fresh AuthConfig is created and loaded from the
      environment.
    - When refresh tokens are enabled and an endpoint is set, the config may
      be replaced by an auto-configured one, followed by ``load_from_env()``:
      - If the endpoint host is ``openbridge.io`` or one of its subdomains,
        the result of ``create_openbridge_config()`` is used.
      - Otherwise, if no refresh-token request builder is set, a generic
        JSON:API config is built for that endpoint.
      In both cases the passed ``config`` object is discarded; only values
      re-read from the environment (and, for JSON:API, the endpoint,
      required claims and signature flag) carry over.

    An invalid configuration yields an empty list.

    :param config: Optional AuthConfig instance. If None, one is created and loaded from environment variables
    :type config: Optional[AuthConfig]
    :param refresh_token_middleware: Whether to include refresh token conversion middleware
    :type refresh_token_middleware: bool
    :param jwt_middleware: Whether to include JWT validation middleware
    :type jwt_middleware: bool
    :param auth_manager: Optional AuthManager instance for token store integration
    :type auth_manager: Optional[Any]
    :return: Middleware instances in execution order (refresh first, then JWT)
    :rtype: list

    .. example::
        >>> chain = create_auth_middleware()
        >>> chain = create_auth_middleware(refresh_token_middleware=False)
    """
    if not config:
        config = AuthConfig()
        config.load_from_env()

    wants_refresh = config.refresh_token_enabled and config.refresh_token_endpoint
    if wants_refresh:
        from urllib.parse import urlparse

        host = (urlparse(config.refresh_token_endpoint).hostname or "").lower()
        # Exact-suffix match so look-alike hosts (e.g. evilopenbridge.io) don't qualify.
        if host == "openbridge.io" or host.endswith(".openbridge.io"):
            logger.info("Auto-configuring OpenBridge authentication")
            config = create_openbridge_config()
            config.load_from_env()
        elif not config.refresh_token_request_builder:
            logger.info("Auto-configuring generic JSON:API authentication")
            endpoint = config.refresh_token_endpoint
            type_name = os.getenv("JSON_API_TOKEN_TYPE_NAME", "APIAuth")
            claims = config.jwt_required_claims or ["user_id"]
            check_signature = config.jwt_verify_signature
            config = create_json_api_refresh_token_config(
                endpoint_url=endpoint,
                token_type_name=type_name,
                required_claims=claims,
                verify_signature=check_signature,
            )
            config.load_from_env()

    if not config.validate():
        logger.error("Invalid authentication configuration")
        return []

    chain = []

    # Refresh-token conversion must run before JWT validation.
    if refresh_token_middleware and config.refresh_token_enabled:
        chain.append(RefreshTokenMiddleware(config, auth_manager))
        logger.info("Added RefreshTokenMiddleware")

    if jwt_middleware and config.jwt_validation_enabled:
        chain.append(JWTAuthenticationMiddleware(config))
        logger.info("Added JWTAuthenticationMiddleware")

    logger.info(f"Created {len(chain)} middleware components")
    return chain
|
|
1268
|
+
|
|
1269
|
+
|
|
1270
|
+
def create_json_api_refresh_token_config(
    endpoint_url: str,
    token_type_name: str,
    required_claims: list[str],
    verify_signature: bool = True,
) -> AuthConfig:
    """Create configuration for JSON:API style refresh token endpoints.

    The returned AuthConfig:

    - enables authentication, refresh-token conversion and JWT validation;
    - points at ``endpoint_url``;
    - installs three handlers:
      - a request builder producing
        ``{"data": {"type": token_type_name, "attributes": {"refresh_token": ...}}}``;
      - a response parser reading ``data.attributes.token``;
      - a heuristic pattern detector for refresh tokens.

    :param endpoint_url: The refresh token endpoint URL
    :type endpoint_url: str
    :param token_type_name: The JSON:API resource type name for tokens
    :type token_type_name: str
    :param required_claims: JWT claims that must be present after validation (copied)
    :type required_claims: list[str]
    :param verify_signature: Whether to verify JWT signatures
    :type verify_signature: bool
    :return: A configured AuthConfig instance for JSON:API refresh tokens
    :rtype: AuthConfig

    .. example::
        >>> config = create_json_api_refresh_token_config(
        ...     endpoint_url="https://api.example.com/auth/refresh",
        ...     token_type_name="APIAuth",
        ...     required_claims=["user_id", "account_id"]
        ... )
        >>> middleware = create_auth_middleware(config)
    """
    config = AuthConfig()
    config.enabled = True
    config.refresh_token_enabled = True
    config.jwt_validation_enabled = True
    config.refresh_token_endpoint = endpoint_url
    # Copy so later mutation of the caller's list cannot alter this config.
    config.jwt_required_claims = (
        None if required_claims is None else list(required_claims)
    )
    config.jwt_verify_signature = verify_signature

    # JSON:API request builder
    def request_builder(refresh_token: str) -> dict:
        return {
            "data": {
                "type": token_type_name,
                "attributes": {"refresh_token": refresh_token},
            }
        }

    # JSON:API response parser: only a non-empty string at data.attributes.token
    # counts as a JWT. Any other shape yields None.
    def response_parser(response_data: dict) -> Optional[str]:
        if not isinstance(response_data, dict):
            return None
        data = response_data.get("data")
        if not isinstance(data, dict):
            return None
        attributes = data.get("attributes")
        if not isinstance(attributes, dict):
            return None
        token = attributes.get("token")
        if isinstance(token, str) and token:
            return token
        return None

    # Pattern detector for refresh tokens (simple heuristic)
    def pattern_detector(token: str) -> bool:
        # OpenBridge refresh tokens typically contain a colon
        return isinstance(token, str) and ":" in token and len(token) > 20

    config.set_refresh_token_handlers(
        request_builder, response_parser, pattern_detector
    )
    return config
|
|
1338
|
+
|
|
1339
|
+
|
|
1340
|
+
def create_openbridge_config() -> AuthConfig:
    """Create configuration for OpenBridge authentication.

    Builds an AuthConfig for OpenBridge by delegating to
    ``create_json_api_refresh_token_config`` with the OpenBridge refresh
    endpoint, the ``APIAuth`` JSON:API token type and the
    ``user_id``/``account_id`` required claims. It then disables issuer and
    audience verification and finally calls ``config.load_from_env()``.

    The refresh endpoint is resolved as follows:

    1. ``REFRESH_TOKEN_ENDPOINT``, if set to a non-blank value;
    2. otherwise ``{OPENBRIDGE_AUTH_BASE_URL}/auth/api/refresh``. The base
       URL falls back to ``https://authentication.api.openbridge.io`` when
       the variable is unset or blank. Surrounding whitespace and trailing
       slashes are removed.

    :return: A configured AuthConfig instance for OpenBridge authentication
    :rtype: AuthConfig

    .. example::
        >>> config = create_openbridge_config()
        >>> middleware = create_auth_middleware(config)

    .. note::
        The configuration includes:

        - OpenBridge refresh token endpoint
        - APIAuth token type
        - user_id and account_id required claims
        - Signature verification disabled (tokens trusted from API)
    """
    default_auth_base = "https://authentication.api.openbridge.io"

    # Treat blank env values as unset. Plain os.getenv(name, default) only
    # falls back when the variable is missing. An empty
    # OPENBRIDGE_AUTH_BASE_URL would then yield the relative path
    # "/auth/api/refresh", and an empty REFRESH_TOKEN_ENDPOINT would yield
    # an empty endpoint URL.
    auth_base = (
        (os.getenv("OPENBRIDGE_AUTH_BASE_URL") or "").strip().rstrip("/")
        or default_auth_base
    )
    # An explicit REFRESH_TOKEN_ENDPOINT takes precedence over the derived URL.
    endpoint_url = (
        (os.getenv("REFRESH_TOKEN_ENDPOINT") or "").strip()
        or f"{auth_base}/auth/api/refresh"
    )

    config = create_json_api_refresh_token_config(
        endpoint_url=endpoint_url,
        token_type_name="APIAuth",
        required_claims=["user_id", "account_id"],
        verify_signature=False,  # OpenBridge JWTs are trusted from the API, no public key available
    )

    # OpenBridge-specific settings
    config.jwt_verify_iss = False  # Don't validate issuer for OpenBridge
    config.jwt_verify_aud = False  # Don't validate audience for OpenBridge

    # Respect AUTH_ENABLED environment variable.
    # NOTE(review): this runs after the OpenBridge-specific assignments
    # above. If load_from_env() also re-reads issuer/audience/signature
    # flags or required claims from the environment, it could override
    # them. Confirm against AuthConfig.load_from_env.
    config.load_from_env()

    return config
|
|
1387
|
+
|
|
1388
|
+
|
|
1389
|
+
def create_auth0_config(domain: str, audience: str) -> AuthConfig:
    """Create Auth0 configuration.

    Creates an AuthConfig, populates it from the environment via
    ``load_from_env()``, and then sets three Auth0-specific values:

    - ``jwt_issuer`` to ``https://{domain}/``
    - ``jwt_audience`` to ``audience``
    - ``jwt_jwks_uri`` to ``https://{domain}/.well-known/jwks.json``

    ``domain`` is normally a bare host such as ``"example.auth0.com"``. A
    pasted URL form (``"https://example.auth0.com/"``) is also accepted.
    The scheme and trailing slashes are stripped so the derived URLs stay
    well-formed, and the derived URLs always use ``https``.

    This function does not itself set enablement or verification flags
    (e.g. ``enabled``, ``jwt_validation_enabled``, ``jwt_verify_signature``).
    Those keep whatever values ``load_from_env()`` produced.

    :param domain: The Auth0 domain (e.g., "example.auth0.com")
    :type domain: str
    :param audience: The Auth0 API audience identifier
    :type audience: str
    :return: A configured AuthConfig instance for Auth0 authentication
    :rtype: AuthConfig

    .. example::
        >>> config = create_auth0_config(
        ...     domain="example.auth0.com",
        ...     audience="https://api.example.com"
        ... )
        >>> middleware = create_auth_middleware(config)
    """
    # Normalise the host so that e.g. "https://example.auth0.com/" does not
    # produce "https://https://example.auth0.com//" as the issuer.
    host = domain.strip()
    for scheme in ("https://", "http://"):
        if host.lower().startswith(scheme):
            host = host[len(scheme):]
            break
    host = host.rstrip("/")

    config = AuthConfig()
    config.load_from_env()
    # Auth0's ``iss`` claim is the tenant URL including the trailing slash.
    config.jwt_issuer = f"https://{host}/"
    config.jwt_audience = audience
    config.jwt_jwks_uri = f"https://{host}/.well-known/jwks.json"
    return config
|
|
1427
|
+
|
|
1428
|
+
|
|
1429
|
+
async def get_auth_info() -> Dict[str, Any]:
    """Return a snapshot of the environment-derived authentication settings.

    A fresh ``AuthConfig`` is instantiated and populated through
    ``load_from_env()``. A fixed set of its attributes is then copied into
    a plain dictionary, which is handy for debugging or status reporting.

    :return: Mapping of setting name to its current value, with keys (in
        order) ``enabled``, ``jwt_validation_enabled``,
        ``refresh_token_enabled``, ``jwt_issuer``, ``jwt_audience``,
        ``jwt_verify_signature``, ``jwt_verify_iss``, ``jwt_verify_aud``
        and ``jwt_required_claims``.
    :rtype: Dict[str, Any]

    .. example::
        >>> auth_info = await get_auth_info()
        >>> print(f"Authentication enabled: {auth_info['enabled']}")
        >>> print(f"JWT validation enabled: {auth_info['jwt_validation_enabled']}")
        >>> print(f"Required claims: {auth_info['jwt_required_claims']}")
    """
    # Each name is both the output key and the AuthConfig attribute read.
    reported_settings = (
        "enabled",
        "jwt_validation_enabled",
        "refresh_token_enabled",
        "jwt_issuer",
        "jwt_audience",
        "jwt_verify_signature",
        "jwt_verify_iss",
        "jwt_verify_aud",
        "jwt_required_claims",
    )

    env_config = AuthConfig()
    env_config.load_from_env()

    return {name: getattr(env_config, name) for name in reported_settings}
|