mcp-proxy-adapter 4.1.0__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcp_proxy_adapter/__main__.py +12 -0
- mcp_proxy_adapter/api/app.py +138 -11
- mcp_proxy_adapter/api/handlers.py +16 -1
- mcp_proxy_adapter/api/middleware/__init__.py +30 -29
- mcp_proxy_adapter/api/middleware/auth_adapter.py +235 -0
- mcp_proxy_adapter/api/middleware/error_handling.py +9 -0
- mcp_proxy_adapter/api/middleware/factory.py +219 -0
- mcp_proxy_adapter/api/middleware/logging.py +32 -6
- mcp_proxy_adapter/api/middleware/mtls_adapter.py +305 -0
- mcp_proxy_adapter/api/middleware/mtls_middleware.py +296 -0
- mcp_proxy_adapter/api/middleware/protocol_middleware.py +135 -0
- mcp_proxy_adapter/api/middleware/rate_limit_adapter.py +241 -0
- mcp_proxy_adapter/api/middleware/roles_adapter.py +365 -0
- mcp_proxy_adapter/api/middleware/roles_middleware.py +381 -0
- mcp_proxy_adapter/api/middleware/security.py +376 -0
- mcp_proxy_adapter/api/middleware/token_auth_middleware.py +261 -0
- mcp_proxy_adapter/api/middleware/transport_middleware.py +122 -0
- mcp_proxy_adapter/commands/__init__.py +13 -4
- mcp_proxy_adapter/commands/auth_validation_command.py +408 -0
- mcp_proxy_adapter/commands/base.py +61 -30
- mcp_proxy_adapter/commands/builtin_commands.py +89 -0
- mcp_proxy_adapter/commands/catalog_manager.py +838 -0
- mcp_proxy_adapter/commands/cert_monitor_command.py +620 -0
- mcp_proxy_adapter/commands/certificate_management_command.py +608 -0
- mcp_proxy_adapter/commands/command_registry.py +705 -345
- mcp_proxy_adapter/commands/dependency_manager.py +245 -0
- mcp_proxy_adapter/commands/health_command.py +7 -0
- mcp_proxy_adapter/commands/hooks.py +200 -167
- mcp_proxy_adapter/commands/key_management_command.py +506 -0
- mcp_proxy_adapter/commands/load_command.py +176 -0
- mcp_proxy_adapter/commands/plugins_command.py +235 -0
- mcp_proxy_adapter/commands/protocol_management_command.py +232 -0
- mcp_proxy_adapter/commands/proxy_registration_command.py +268 -0
- mcp_proxy_adapter/commands/reload_command.py +48 -50
- mcp_proxy_adapter/commands/result.py +1 -0
- mcp_proxy_adapter/commands/roles_management_command.py +697 -0
- mcp_proxy_adapter/commands/ssl_setup_command.py +483 -0
- mcp_proxy_adapter/commands/token_management_command.py +529 -0
- mcp_proxy_adapter/commands/transport_management_command.py +144 -0
- mcp_proxy_adapter/commands/unload_command.py +158 -0
- mcp_proxy_adapter/config.py +99 -2
- mcp_proxy_adapter/core/auth_validator.py +606 -0
- mcp_proxy_adapter/core/certificate_utils.py +827 -0
- mcp_proxy_adapter/core/config_converter.py +405 -0
- mcp_proxy_adapter/core/config_validator.py +218 -0
- mcp_proxy_adapter/core/logging.py +11 -0
- mcp_proxy_adapter/core/protocol_manager.py +226 -0
- mcp_proxy_adapter/core/proxy_registration.py +270 -0
- mcp_proxy_adapter/core/role_utils.py +426 -0
- mcp_proxy_adapter/core/security_adapter.py +373 -0
- mcp_proxy_adapter/core/security_factory.py +239 -0
- mcp_proxy_adapter/core/settings.py +1 -0
- mcp_proxy_adapter/core/ssl_utils.py +233 -0
- mcp_proxy_adapter/core/transport_manager.py +292 -0
- mcp_proxy_adapter/custom_openapi.py +22 -11
- mcp_proxy_adapter/examples/basic_server/config.json +58 -23
- mcp_proxy_adapter/examples/basic_server/config_all_protocols.json +54 -0
- mcp_proxy_adapter/examples/basic_server/config_http.json +70 -0
- mcp_proxy_adapter/examples/basic_server/config_http_only.json +52 -0
- mcp_proxy_adapter/examples/basic_server/config_https.json +58 -0
- mcp_proxy_adapter/examples/basic_server/config_mtls.json +58 -0
- mcp_proxy_adapter/examples/basic_server/config_ssl.json +46 -0
- mcp_proxy_adapter/examples/basic_server/server.py +17 -1
- mcp_proxy_adapter/examples/custom_commands/__init__.py +1 -1
- mcp_proxy_adapter/examples/custom_commands/advanced_hooks.py +339 -23
- mcp_proxy_adapter/examples/custom_commands/auto_commands/test_command.py +105 -0
- mcp_proxy_adapter/examples/custom_commands/catalog/commands/test_command.py +129 -0
- mcp_proxy_adapter/examples/custom_commands/config.json +97 -41
- mcp_proxy_adapter/examples/custom_commands/config_all_protocols.json +46 -0
- mcp_proxy_adapter/examples/custom_commands/config_https_only.json +46 -0
- mcp_proxy_adapter/examples/custom_commands/config_https_transport.json +33 -0
- mcp_proxy_adapter/examples/custom_commands/config_mtls_only.json +46 -0
- mcp_proxy_adapter/examples/custom_commands/config_mtls_transport.json +33 -0
- mcp_proxy_adapter/examples/custom_commands/config_single_transport.json +33 -0
- mcp_proxy_adapter/examples/custom_commands/full_help_response.json +1 -0
- mcp_proxy_adapter/examples/custom_commands/generated_openapi.json +629 -0
- mcp_proxy_adapter/examples/custom_commands/get_openapi.py +103 -0
- mcp_proxy_adapter/examples/custom_commands/loadable_commands/test_ignored.py +129 -0
- mcp_proxy_adapter/examples/custom_commands/proxy_connection_manager.py +278 -0
- mcp_proxy_adapter/examples/custom_commands/server.py +92 -63
- mcp_proxy_adapter/examples/custom_commands/simple_openapi_server.py +75 -0
- mcp_proxy_adapter/examples/custom_commands/start_server_with_proxy_manager.py +299 -0
- mcp_proxy_adapter/examples/custom_commands/start_server_with_registration.py +278 -0
- mcp_proxy_adapter/examples/custom_commands/test_openapi.py +27 -0
- mcp_proxy_adapter/examples/custom_commands/test_registry.py +23 -0
- mcp_proxy_adapter/examples/custom_commands/test_simple.py +19 -0
- mcp_proxy_adapter/examples/custom_project_example/README.md +103 -0
- mcp_proxy_adapter/examples/custom_project_example/README_EN.md +103 -0
- mcp_proxy_adapter/examples/simple_custom_commands/README.md +149 -0
- mcp_proxy_adapter/examples/simple_custom_commands/README_EN.md +149 -0
- mcp_proxy_adapter/main.py +175 -0
- mcp_proxy_adapter/schemas/roles_schema.json +162 -0
- mcp_proxy_adapter/tests/unit/test_config.py +53 -0
- mcp_proxy_adapter/version.py +1 -1
- {mcp_proxy_adapter-4.1.0.dist-info → mcp_proxy_adapter-6.0.0.dist-info}/METADATA +2 -1
- mcp_proxy_adapter-6.0.0.dist-info/RECORD +179 -0
- mcp_proxy_adapter/commands/reload_settings_command.py +0 -125
- mcp_proxy_adapter-4.1.0.dist-info/RECORD +0 -110
- {mcp_proxy_adapter-4.1.0.dist-info → mcp_proxy_adapter-6.0.0.dist-info}/WHEEL +0 -0
- {mcp_proxy_adapter-4.1.0.dist-info → mcp_proxy_adapter-6.0.0.dist-info}/licenses/LICENSE +0 -0
- {mcp_proxy_adapter-4.1.0.dist-info → mcp_proxy_adapter-6.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,376 @@
|
|
1
|
+
"""
|
2
|
+
Unified Security Middleware for mcp_security_framework integration.
|
3
|
+
|
4
|
+
This module provides a single middleware that replaces AuthMiddleware, RateLimitMiddleware,
|
5
|
+
MTLSMiddleware, and RolesMiddleware with a unified security solution.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import json
|
9
|
+
import logging
|
10
|
+
from typing import Callable, Awaitable, Dict, Any, Optional
|
11
|
+
|
12
|
+
from fastapi import Request, Response
|
13
|
+
from starlette.responses import JSONResponse
|
14
|
+
|
15
|
+
from mcp_proxy_adapter.core.security_factory import SecurityFactory
|
16
|
+
from mcp_proxy_adapter.core.logging import logger
|
17
|
+
from .base import BaseMiddleware
|
18
|
+
|
19
|
+
|
20
|
+
class SecurityMiddleware(BaseMiddleware):
    """
    Unified security middleware based on mcp_security_framework.

    Replaces AuthMiddleware, RateLimitMiddleware, MTLSMiddleware, and RolesMiddleware
    with a single, comprehensive security solution.

    This middleware runs for every request whose path is not public.
    It builds a plain dictionary describing the request: method, path, headers,
    query params, client IP and, for POST/PUT/PATCH, the decoded body. That
    dictionary is passed to the adapter built by
    ``SecurityFactory.create_security_adapter``.

    On success, the outcome is stored on ``request.state``. On failure, a JSON-RPC 2.0
    error response is returned and the downstream handler is not called.

    Recognised keys of ``config["security"]``:
        enabled: a falsy value skips all validation (default True).
        public_paths: extra public paths (a list, or a single string).
        trust_proxy_headers: derive the client IP from X-Forwarded-For / X-Real-IP
            (default True). These headers are client-controlled unless a trusted
            reverse proxy overwrites them, so disable this when the service is
            reachable directly.
    """

    # Paths that never require security validation. Each entry matches itself
    # and anything below it (e.g. "/docs" also matches "/docs/oauth2-redirect").
    DEFAULT_PUBLIC_PATHS = (
        "/docs",
        "/redoc",
        "/openapi.json",
        "/health",
        "/favicon.ico",
    )

    # Security (JSON-RPC) error code -> HTTP status code; unknown codes map to 500.
    ERROR_CODE_STATUS = {
        -32000: 401,  # Authentication failed
        -32001: 401,  # Authentication disabled
        -32002: 500,  # Invalid configuration
        -32003: 401,  # Certificate validation failed
        -32004: 401,  # Token validation failed
        -32005: 401,  # MTLS validation failed
        -32006: 401,  # SSL validation failed
        -32007: 403,  # Role validation failed
        -32008: 401,  # Certificate expired
        -32009: 401,  # Certificate not found
        -32010: 401,  # Token expired
        -32011: 401,  # Token not found
        -32603: 500,  # Internal error
    }

    def __init__(self, app, config: Dict[str, Any]):
        """
        Initialize unified security middleware.

        Args:
            app: FastAPI application
            config: mcp_proxy_adapter configuration dictionary
        """
        super().__init__(app)
        self.config = config
        # Tolerate an explicit ``"security": null`` in the configuration.
        self.security_config = config.get("security") or {}

        # Create security adapter
        self.security_adapter = SecurityFactory.create_security_adapter(config)

        # Whether forwarded-for style headers are trusted when resolving the client IP.
        self.trust_proxy_headers = self.security_config.get("trust_proxy_headers", True)

        # Public paths that don't require security validation
        self.public_paths = list(self.DEFAULT_PUBLIC_PATHS)

        # Add custom public paths from config
        custom_public_paths = self.security_config.get("public_paths") or []
        if isinstance(custom_public_paths, str):
            # list.extend() on a bare string would add one entry per character,
            # including "/" -- which would make every path public.
            custom_public_paths = [custom_public_paths]
        self.public_paths.extend(custom_public_paths)

        logger.info(f"Security middleware initialized with {len(self.public_paths)} public paths")

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """
        Process request and handle security validation.

        Only the security checks are guarded here. Exceptions raised by the
        downstream handler propagate unchanged, so they are not misreported as
        security failures.

        Args:
            request: Request object
            call_next: Next handler

        Returns:
            Response object: the downstream response, or a JSON-RPC error response
            when validation fails.
        """
        try:
            # Process request before calling the main handler
            await self.before_request(request)
        except SecurityValidationError as e:
            # Handle security validation errors
            return await self.handle_error(request, e)
        except Exception as e:
            # Handle other errors raised while checking the request
            logger.error(f"Unexpected error in security middleware: {e}")
            return await self.handle_error(request, e)

        # Call the next middleware or main handler
        return await call_next(request)

    async def before_request(self, request: Request) -> None:
        """
        Validate the request through the security adapter.

        On success, sets ``request.state.security_result``, ``user_roles``,
        ``user_id`` and ``security_validated``. Returns early without validating
        when security is disabled or the path is public.

        Args:
            request: FastAPI request object

        Raises:
            SecurityValidationError: validation failed. The adapter's error code is
                used, or -32603 for unexpected internal errors.
        """
        # Check if security is enabled
        if not self.security_config.get("enabled", True):
            logger.debug("Security middleware disabled, skipping validation")
            return

        # Check if path is public
        path = request.url.path
        if self._is_public_path(path):
            logger.debug(f"Public path accessed: {path}")
            return

        try:
            # Prepare request data for validation
            request_data = await self._prepare_request_data_async(request)

            # Log only method and path: headers and body may carry credentials.
            logger.debug(f"Validating request: {request_data['method']} {request_data['path']}")
            validation_result = self.security_adapter.validate_request(request_data)
            logger.debug(f"Validation result: {validation_result}")

            if not validation_result.get("is_valid", False):
                error_message = validation_result.get("error_message", "Security validation failed")
                error_code = validation_result.get("error_code", -32000)
                raise SecurityValidationError(error_message, error_code)

            # Store validation results in request state
            request.state.security_result = validation_result
            request.state.user_roles = validation_result.get("roles", [])
            request.state.user_id = validation_result.get("user_id")
            request.state.security_validated = True

            logger.debug(f"Security validation successful for {request.state.user_id} "
                         f"with roles: {request.state.user_roles}")

        except SecurityValidationError:
            # Re-raise security validation errors unchanged
            raise
        except Exception as e:
            logger.error(f"Unexpected error in security validation: {e}")
            raise SecurityValidationError(f"Internal security error: {str(e)}", -32603) from e

    async def _prepare_request_data_async(self, request: Request) -> Dict[str, Any]:
        """
        Prepare request data for security validation (async version).

        Extends the synchronous data with the request body for POST/PUT/PATCH.
        A JSON body is parsed; any other body is stored as a string. Failures
        while reading the body are logged, and the body stays ``{}``.

        Args:
            request: FastAPI request object

        Returns:
            Dictionary with request data for validation
        """
        request_data = self._prepare_request_data(request)

        # Extract request body for POST/PUT/PATCH requests
        if request.method in ("POST", "PUT", "PATCH"):
            try:
                body = await request.body()
                if body:
                    request_data["body"] = self._decode_body(body)
            except Exception as e:
                logger.warning(f"Failed to extract request body: {e}")

        return request_data

    @staticmethod
    def _decode_body(body: bytes) -> Any:
        """
        Decode a raw request body.

        Returns the parsed JSON value when the body is valid UTF-8 JSON.
        Otherwise returns the body as a UTF-8 string with undecodable bytes dropped.
        """
        try:
            return json.loads(body.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError):
            # If not (UTF-8) JSON, store as string
            return body.decode("utf-8", errors="ignore")

    def _prepare_request_data(self, request: Request) -> Dict[str, Any]:
        """
        Prepare request data for security validation (sync version, no body).

        Args:
            request: FastAPI request object

        Returns:
            Dictionary with method, path, headers, query_params, client_ip and an
            empty body.
        """
        return {
            "method": request.method,
            "path": request.url.path,
            "headers": dict(request.headers),
            "query_params": dict(request.query_params),
            "client_ip": self._get_client_ip(request),
            "body": {},
        }

    def _get_client_ip(self, request: Request) -> str:
        """
        Get client IP address from request.

        When ``trust_proxy_headers`` is enabled, the IP comes from the first
        X-Forwarded-For entry, then X-Real-IP. Otherwise, or if neither header
        is set, the socket peer address is used.

        Args:
            request: FastAPI request object

        Returns:
            Client IP address string, or "unknown" when it cannot be determined
        """
        if self.trust_proxy_headers:
            # Check for forwarded headers first
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                return forwarded_for.split(",")[0].strip()

            # Check for real IP header
            real_ip = request.headers.get("X-Real-IP")
            if real_ip:
                return real_ip

        # Fallback to client host
        if request.client:
            return request.client.host

        return "unknown"

    def _is_public_path(self, path: str) -> bool:
        """
        Check if the path is public (doesn't require security validation).

        A public entry matches the identical path or any path below it. For example,
        "/health" matches "/health" and "/health/live" but not "/healthcheck-admin".
        A plain prefix test would let unrelated routes that merely share a prefix
        with a public path bypass validation. Empty entries are ignored.

        Args:
            path: Request path

        Returns:
            True if path is public, False otherwise
        """
        for public_path in self.public_paths:
            if not public_path:
                continue
            if path == public_path:
                return True
            subtree = public_path if public_path.endswith("/") else public_path + "/"
            if path.startswith(subtree):
                return True
        return False

    async def handle_error(self, request: Request, exception: Exception) -> Response:
        """
        Build the JSON-RPC 2.0 error response for a failed request.

        Args:
            request: FastAPI request object
            exception: Exception that occurred

        Returns:
            Error response. A SecurityValidationError keeps its code and message,
            and its HTTP status comes from ``_get_status_code_for_error``. Any
            other exception becomes a generic 500 / -32603.
        """
        if isinstance(exception, SecurityValidationError):
            status_code = self._get_status_code_for_error(exception.error_code)
            error_code = exception.error_code
            error_message = exception.message
        else:
            status_code = 500
            error_code = -32603
            error_message = "Internal server error"

        logger.warning(f"Security validation failed: {error_message} | "
                       f"Path: {request.url.path} | Code: {error_code}")

        return JSONResponse(
            status_code=status_code,
            content={
                "jsonrpc": "2.0",
                "error": {
                    "code": error_code,
                    "message": error_message,
                    "data": {
                        "validation_type": "security",
                        "path": request.url.path,
                        "method": request.method,
                        "client_ip": self._get_client_ip(request)
                    }
                },
                "id": None
            }
        )

    def _get_status_code_for_error(self, error_code: int) -> int:
        """
        Map security error codes to HTTP status codes.

        Args:
            error_code: Security error code

        Returns:
            HTTP status code (500 for unknown codes)
        """
        return self.ERROR_CODE_STATUS.get(error_code, 500)

    def get_user_roles(self, request: Request) -> list:
        """
        Get user roles from request state.

        Args:
            request: FastAPI request object

        Returns:
            List of user roles (empty if validation did not store any)
        """
        return getattr(request.state, 'user_roles', [])

    def get_user_id(self, request: Request) -> Optional[str]:
        """
        Get user ID from request state.

        Args:
            request: FastAPI request object

        Returns:
            User ID or None
        """
        return getattr(request.state, 'user_id', None)

    def is_security_validated(self, request: Request) -> bool:
        """
        Check if security validation passed for the request.

        Args:
            request: FastAPI request object

        Returns:
            True if security validation passed
        """
        return getattr(request.state, 'security_validated', False)

    def has_role(self, request: Request, required_role: str) -> bool:
        """
        Check if user has required role ("*" grants every role).

        Args:
            request: FastAPI request object
            required_role: Required role to check

        Returns:
            True if user has required role
        """
        user_roles = self.get_user_roles(request)
        return required_role in user_roles or "*" in user_roles

    def has_any_role(self, request: Request, required_roles: list) -> bool:
        """
        Check if user has any of the required roles ("*" grants every role).

        Args:
            request: FastAPI request object
            required_roles: List of required roles

        Returns:
            True if user has any of the required roles
        """
        user_roles = self.get_user_roles(request)
        return any(role in user_roles for role in required_roles) or "*" in user_roles
|
359
|
+
|
360
|
+
|
361
|
+
class SecurityValidationError(Exception):
    """
    Raised when a request fails security validation.

    Attributes:
        message: Human-readable reason; also the exception's single argument.
        error_code: JSON-RPC error code to report to the client.
    """

    def __init__(self, message: str, error_code: int):
        """
        Create the error.

        Args:
            message: Error message
            error_code: JSON-RPC error code
        """
        self.error_code = error_code
        self.message = message
        super().__init__(message)
|
@@ -0,0 +1,261 @@
|
|
1
|
+
"""
|
2
|
+
Token Authentication Middleware
|
3
|
+
|
4
|
+
This module provides middleware for token-based authentication using JWT and API tokens.
|
5
|
+
Supports extraction of tokens from headers and validation using AuthValidator.
|
6
|
+
|
7
|
+
Author: MCP Proxy Adapter Team
|
8
|
+
Version: 1.0.0
|
9
|
+
"""
|
10
|
+
|
11
|
+
import json
|
12
|
+
import logging
|
13
|
+
from typing import Dict, Any, Optional, List
|
14
|
+
from pathlib import Path
|
15
|
+
|
16
|
+
from fastapi import Request, HTTPException
|
17
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
18
|
+
from starlette.responses import JSONResponse
|
19
|
+
|
20
|
+
from ...core.auth_validator import AuthValidator, AuthValidationResult
|
21
|
+
from ...core.logging import logger
|
22
|
+
|
23
|
+
|
24
|
+
class TokenAuthMiddleware(BaseHTTPMiddleware):
    """
    Token authentication middleware.

    Validates JWT and API tokens from request headers.
    Integrates with AuthValidator for token validation.

    Tokens with three dot-separated segments are treated as JWTs and passed to
    ``AuthValidator.validate_token(token, "jwt")``. Any other token is first
    looked up in the tokens file. If it is not listed there, it is passed to
    ``AuthValidator.validate_token(token, "api")``.
    """

    def __init__(self, app, token_config: Dict[str, Any]):
        """
        Initialize token authentication middleware.

        Args:
            app: FastAPI application
            token_config: Token configuration dictionary
        """
        super().__init__(app)
        self.token_config = token_config
        self.auth_validator = AuthValidator()
        self.logger = logging.getLogger(__name__)

        # Load configuration
        self.enabled = token_config.get("enabled", False)
        self.header_name = token_config.get("header_name", "Authorization")
        # Auth scheme expected before the token; empty means the header is the raw token.
        self.token_prefix = token_config.get("token_prefix", "Bearer")
        self.tokens_file = token_config.get("tokens_file", "tokens.json")
        # NOTE(review): token_expiry, jwt_secret and jwt_algorithm are stored but are
        # not passed to AuthValidator in this class -- confirm how it gets JWT settings.
        self.token_expiry = token_config.get("token_expiry", 3600)
        self.jwt_secret = token_config.get("jwt_secret", "")
        self.jwt_algorithm = token_config.get("jwt_algorithm", "HS256")

        # Load tokens if file exists
        self.tokens = self._load_tokens()

        self.logger.info(f"TokenAuthMiddleware initialized. Enabled: {self.enabled}")

    async def dispatch(self, request: Request, call_next):
        """
        Process request through token authentication middleware.

        Only the authentication step is guarded by the catch-all handler.
        Exceptions raised by downstream handlers propagate unchanged, so they
        are not reported as authentication failures.

        Args:
            request: FastAPI request object
            call_next: Next middleware or endpoint

        Returns:
            Response from next middleware or endpoint, or a 401/500 JSON error
        """
        if not self.enabled:
            return await call_next(request)

        try:
            # Extract token from header
            auth_header = request.headers.get(self.header_name)
            if not auth_header:
                return self._create_auth_error("Authorization header required", 401)

            # Validate token
            if not self._validate_token(auth_header):
                return self._create_auth_error("Invalid or expired token", 401)
        except Exception as e:
            self.logger.error(f"Token authentication error: {e}")
            return self._create_auth_error("Token authentication failed", 500)

        # Continue to next middleware/endpoint
        return await call_next(request)

    def _extract_token(self, auth_header: str) -> Optional[str]:
        """
        Extract the token from an authorization header value.

        The header must be ``"<token_prefix> <token>"``. The prefix is compared
        case-insensitively, because HTTP auth schemes are case-insensitive
        (RFC 7235 section 2.1). If ``token_prefix`` is empty, the whole header
        value is the token.

        Args:
            auth_header: Authorization header value

        Returns:
            The stripped token, or None when the header is malformed or empty
        """
        if not auth_header:
            return None

        if not self.token_prefix:
            return auth_header.strip() or None

        expected = f"{self.token_prefix} "
        if auth_header[:len(expected)].lower() != expected.lower():
            return None

        return auth_header[len(expected):].strip() or None

    def _validate_token(self, auth_header: str) -> bool:
        """
        Validate token from authorization header.

        Args:
            auth_header: Authorization header value

        Returns:
            True if token is valid, False otherwise (including on internal errors)
        """
        try:
            token = self._extract_token(auth_header)
            if token is None:
                return False

            # Determine token type and validate
            if self._is_jwt_token(token):
                return self._validate_jwt_token(token)
            return self._validate_api_token(token)

        except Exception as e:
            self.logger.error(f"Token validation error: {e}")
            return False

    def _is_jwt_token(self, token: str) -> bool:
        """
        Check if token is JWT format.

        Args:
            token: Token string

        Returns:
            True if token appears to be JWT (header.payload.signature), False otherwise
        """
        # Basic JWT format check (header.payload.signature)
        return len(token.split('.')) == 3

    def _validate_jwt_token(self, token: str) -> bool:
        """
        Validate JWT token.

        Args:
            token: JWT token string

        Returns:
            True if JWT token is valid, False otherwise
        """
        try:
            # Use AuthValidator for JWT validation
            result = self.auth_validator.validate_token(token, "jwt")
            return result.is_valid

        except Exception as e:
            self.logger.error(f"JWT validation error: {e}")
            return False

    def _validate_api_token(self, token: str) -> bool:
        """
        Validate API token.

        A token listed in the tokens file is valid unless it is marked
        ``"active": false`` or its ``expires_at`` (compared against
        ``time.time()``) is in the past. Unlisted tokens are delegated to
        AuthValidator.

        Args:
            token: API token string

        Returns:
            True if API token is valid, False otherwise
        """
        import time

        try:
            # Check if token exists in loaded tokens
            if token in self.tokens:
                token_data = self.tokens[token]

                # Check if token is active
                if not token_data.get("active", True):
                    return False

                # Check if token has expired
                if "expires_at" in token_data and time.time() > token_data["expires_at"]:
                    return False

                return True

            # Use AuthValidator for API token validation
            result = self.auth_validator.validate_token(token, "api")
            return result.is_valid

        except Exception as e:
            self.logger.error(f"API token validation error: {e}")
            return False

    def _load_tokens(self) -> Dict[str, Any]:
        """
        Load tokens from configuration file.

        Returns:
            Dictionary of tokens and their metadata. Returns an empty dict if the
            file is missing or unreadable, or if it does not contain a JSON object.
        """
        try:
            if not self.tokens_file or not Path(self.tokens_file).exists():
                return {}

            with open(self.tokens_file, 'r', encoding='utf-8') as f:
                tokens_data = json.load(f)

            # Lookups index tokens by value, so anything but a mapping is unusable.
            if not isinstance(tokens_data, dict):
                self.logger.error(
                    f"Tokens file {self.tokens_file} must contain a JSON object, "
                    f"got {type(tokens_data).__name__}"
                )
                return {}

            self.logger.info(f"Loaded {len(tokens_data)} tokens from {self.tokens_file}")
            return tokens_data

        except Exception as e:
            self.logger.error(f"Failed to load tokens from {self.tokens_file}: {e}")
            return {}

    def _create_auth_error(self, message: str, status_code: int) -> JSONResponse:
        """
        Create authentication error response.

        For 401 responses, a ``WWW-Authenticate`` header naming the configured
        scheme is added, as RFC 7235 requires for that status.

        Args:
            message: Error message
            status_code: HTTP status code

        Returns:
            JSONResponse with error details
        """
        error_data = {
            "error": {
                "code": -32004,  # Token validation failed
                "message": message,
                "type": "token_authentication_error"
            }
        }

        headers = None
        if status_code == 401 and self.token_prefix:
            headers = {"WWW-Authenticate": self.token_prefix}

        return JSONResponse(
            status_code=status_code,
            content=error_data,
            headers=headers
        )

    def get_roles_from_token(self, auth_header: str) -> List[str]:
        """
        Extract roles from token.

        NOTE(review): roles always come from AuthValidator. Metadata stored for
        the token in the tokens file is not consulted here.

        Args:
            auth_header: Authorization header value

        Returns:
            List of roles extracted from token (empty if invalid)
        """
        try:
            token = self._extract_token(auth_header)
            if token is None:
                return []

            # Use AuthValidator to extract roles
            token_type = "jwt" if self._is_jwt_token(token) else "api"
            result = self.auth_validator.validate_token(token, token_type)

            return result.roles if result.is_valid else []

        except Exception as e:
            self.logger.error(f"Failed to extract roles from token: {e}")
            return []
|