mcp-proxy-adapter 4.1.1__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. mcp_proxy_adapter/__main__.py +12 -0
  2. mcp_proxy_adapter/api/app.py +138 -11
  3. mcp_proxy_adapter/api/handlers.py +16 -1
  4. mcp_proxy_adapter/api/middleware/__init__.py +30 -29
  5. mcp_proxy_adapter/api/middleware/auth_adapter.py +235 -0
  6. mcp_proxy_adapter/api/middleware/error_handling.py +9 -0
  7. mcp_proxy_adapter/api/middleware/factory.py +219 -0
  8. mcp_proxy_adapter/api/middleware/logging.py +32 -6
  9. mcp_proxy_adapter/api/middleware/mtls_adapter.py +305 -0
  10. mcp_proxy_adapter/api/middleware/mtls_middleware.py +296 -0
  11. mcp_proxy_adapter/api/middleware/protocol_middleware.py +135 -0
  12. mcp_proxy_adapter/api/middleware/rate_limit_adapter.py +241 -0
  13. mcp_proxy_adapter/api/middleware/roles_adapter.py +365 -0
  14. mcp_proxy_adapter/api/middleware/roles_middleware.py +381 -0
  15. mcp_proxy_adapter/api/middleware/security.py +376 -0
  16. mcp_proxy_adapter/api/middleware/token_auth_middleware.py +261 -0
  17. mcp_proxy_adapter/api/middleware/transport_middleware.py +122 -0
  18. mcp_proxy_adapter/commands/__init__.py +13 -4
  19. mcp_proxy_adapter/commands/auth_validation_command.py +408 -0
  20. mcp_proxy_adapter/commands/base.py +61 -30
  21. mcp_proxy_adapter/commands/builtin_commands.py +89 -0
  22. mcp_proxy_adapter/commands/catalog_manager.py +838 -0
  23. mcp_proxy_adapter/commands/cert_monitor_command.py +620 -0
  24. mcp_proxy_adapter/commands/certificate_management_command.py +608 -0
  25. mcp_proxy_adapter/commands/command_registry.py +703 -354
  26. mcp_proxy_adapter/commands/dependency_manager.py +245 -0
  27. mcp_proxy_adapter/commands/health_command.py +7 -0
  28. mcp_proxy_adapter/commands/hooks.py +200 -167
  29. mcp_proxy_adapter/commands/key_management_command.py +506 -0
  30. mcp_proxy_adapter/commands/load_command.py +176 -0
  31. mcp_proxy_adapter/commands/plugins_command.py +235 -0
  32. mcp_proxy_adapter/commands/protocol_management_command.py +232 -0
  33. mcp_proxy_adapter/commands/proxy_registration_command.py +268 -0
  34. mcp_proxy_adapter/commands/reload_command.py +48 -50
  35. mcp_proxy_adapter/commands/result.py +1 -0
  36. mcp_proxy_adapter/commands/roles_management_command.py +697 -0
  37. mcp_proxy_adapter/commands/ssl_setup_command.py +483 -0
  38. mcp_proxy_adapter/commands/token_management_command.py +529 -0
  39. mcp_proxy_adapter/commands/transport_management_command.py +144 -0
  40. mcp_proxy_adapter/commands/unload_command.py +158 -0
  41. mcp_proxy_adapter/config.py +99 -2
  42. mcp_proxy_adapter/core/auth_validator.py +606 -0
  43. mcp_proxy_adapter/core/certificate_utils.py +827 -0
  44. mcp_proxy_adapter/core/config_converter.py +405 -0
  45. mcp_proxy_adapter/core/config_validator.py +218 -0
  46. mcp_proxy_adapter/core/logging.py +11 -0
  47. mcp_proxy_adapter/core/protocol_manager.py +226 -0
  48. mcp_proxy_adapter/core/proxy_registration.py +270 -0
  49. mcp_proxy_adapter/core/role_utils.py +426 -0
  50. mcp_proxy_adapter/core/security_adapter.py +373 -0
  51. mcp_proxy_adapter/core/security_factory.py +239 -0
  52. mcp_proxy_adapter/core/settings.py +1 -0
  53. mcp_proxy_adapter/core/ssl_utils.py +233 -0
  54. mcp_proxy_adapter/core/transport_manager.py +292 -0
  55. mcp_proxy_adapter/custom_openapi.py +22 -11
  56. mcp_proxy_adapter/examples/basic_server/config.json +58 -23
  57. mcp_proxy_adapter/examples/basic_server/config_all_protocols.json +54 -0
  58. mcp_proxy_adapter/examples/basic_server/config_http.json +70 -0
  59. mcp_proxy_adapter/examples/basic_server/config_http_only.json +52 -0
  60. mcp_proxy_adapter/examples/basic_server/config_https.json +58 -0
  61. mcp_proxy_adapter/examples/basic_server/config_mtls.json +58 -0
  62. mcp_proxy_adapter/examples/basic_server/config_ssl.json +46 -0
  63. mcp_proxy_adapter/examples/basic_server/server.py +12 -1
  64. mcp_proxy_adapter/examples/custom_commands/__init__.py +1 -1
  65. mcp_proxy_adapter/examples/custom_commands/advanced_hooks.py +339 -23
  66. mcp_proxy_adapter/examples/custom_commands/auto_commands/test_command.py +105 -0
  67. mcp_proxy_adapter/examples/custom_commands/catalog/commands/test_command.py +129 -0
  68. mcp_proxy_adapter/examples/custom_commands/config.json +101 -18
  69. mcp_proxy_adapter/examples/custom_commands/config_all_protocols.json +46 -0
  70. mcp_proxy_adapter/examples/custom_commands/config_https_only.json +46 -0
  71. mcp_proxy_adapter/examples/custom_commands/config_https_transport.json +33 -0
  72. mcp_proxy_adapter/examples/custom_commands/config_mtls_only.json +46 -0
  73. mcp_proxy_adapter/examples/custom_commands/config_mtls_transport.json +33 -0
  74. mcp_proxy_adapter/examples/custom_commands/config_single_transport.json +33 -0
  75. mcp_proxy_adapter/examples/custom_commands/full_help_response.json +1 -0
  76. mcp_proxy_adapter/examples/custom_commands/generated_openapi.json +629 -0
  77. mcp_proxy_adapter/examples/custom_commands/get_openapi.py +103 -0
  78. mcp_proxy_adapter/examples/custom_commands/loadable_commands/test_ignored.py +129 -0
  79. mcp_proxy_adapter/examples/custom_commands/proxy_connection_manager.py +278 -0
  80. mcp_proxy_adapter/examples/custom_commands/server.py +92 -68
  81. mcp_proxy_adapter/examples/custom_commands/simple_openapi_server.py +75 -0
  82. mcp_proxy_adapter/examples/custom_commands/start_server_with_proxy_manager.py +299 -0
  83. mcp_proxy_adapter/examples/custom_commands/start_server_with_registration.py +278 -0
  84. mcp_proxy_adapter/examples/custom_commands/test_openapi.py +27 -0
  85. mcp_proxy_adapter/examples/custom_commands/test_registry.py +23 -0
  86. mcp_proxy_adapter/examples/custom_commands/test_simple.py +19 -0
  87. mcp_proxy_adapter/examples/custom_project_example/README.md +103 -0
  88. mcp_proxy_adapter/examples/custom_project_example/README_EN.md +103 -0
  89. mcp_proxy_adapter/examples/simple_custom_commands/README.md +149 -0
  90. mcp_proxy_adapter/examples/simple_custom_commands/README_EN.md +149 -0
  91. mcp_proxy_adapter/main.py +175 -0
  92. mcp_proxy_adapter/schemas/roles_schema.json +162 -0
  93. mcp_proxy_adapter/tests/unit/test_config.py +53 -0
  94. mcp_proxy_adapter/version.py +1 -1
  95. {mcp_proxy_adapter-4.1.1.dist-info → mcp_proxy_adapter-6.0.0.dist-info}/METADATA +2 -1
  96. mcp_proxy_adapter-6.0.0.dist-info/RECORD +179 -0
  97. mcp_proxy_adapter/commands/reload_settings_command.py +0 -125
  98. mcp_proxy_adapter-4.1.1.dist-info/RECORD +0 -110
  99. {mcp_proxy_adapter-4.1.1.dist-info → mcp_proxy_adapter-6.0.0.dist-info}/WHEEL +0 -0
  100. {mcp_proxy_adapter-4.1.1.dist-info → mcp_proxy_adapter-6.0.0.dist-info}/licenses/LICENSE +0 -0
  101. {mcp_proxy_adapter-4.1.1.dist-info → mcp_proxy_adapter-6.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,296 @@
1
+ """
2
+ mTLS Middleware
3
+
4
+ This module provides middleware for mutual TLS (mTLS) authentication.
5
+ Extracts and validates client certificates, extracts roles, and validates access.
6
+
7
+ Author: MCP Proxy Adapter Team
8
+ Version: 1.0.0
9
+ """
10
+
11
+ import logging
12
+ from typing import Dict, List, Optional, Any
13
+ from cryptography import x509
14
+ from cryptography.hazmat.primitives import serialization
15
+
16
+ from fastapi import Request, Response
17
+ from starlette.middleware.base import BaseHTTPMiddleware
18
+
19
+ from ...core.auth_validator import AuthValidator
20
+ from ...core.role_utils import RoleUtils
21
+ from ...core.certificate_utils import CertificateUtils
22
+ from .base import BaseMiddleware
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class MTLSMiddleware(BaseMiddleware):
    """
    Middleware for mTLS authentication.

    For every request (when enabled) the middleware extracts the client
    certificate, validates it, extracts roles from it and, if role checks
    are required, verifies the roles against the configured allow-list.
    On success the certificate, roles and common name are stored on
    ``request.state``.
    """

    def __init__(self, app, mtls_config: Dict[str, Any]):
        """
        Initialize mTLS middleware.

        Args:
            app: FastAPI application
            mtls_config: mTLS configuration dictionary. Keys read here:
                ``enabled`` (default False), ``ca_cert``,
                ``verify_client`` (default True),
                ``client_cert_required`` (default True),
                ``allowed_roles`` (default []), ``require_roles``
                (default False).
        """
        super().__init__(app)
        self.mtls_config = mtls_config
        self.auth_validator = AuthValidator()
        self.role_utils = RoleUtils()
        self.certificate_utils = CertificateUtils()

        # Extract configuration
        self.enabled = mtls_config.get("enabled", False)
        self.ca_cert_path = mtls_config.get("ca_cert")
        self.verify_client = mtls_config.get("verify_client", True)
        self.client_cert_required = mtls_config.get("client_cert_required", True)
        self.allowed_roles = mtls_config.get("allowed_roles", [])
        self.require_roles = mtls_config.get("require_roles", False)

        logger.info(f"mTLS middleware initialized: enabled={self.enabled}, "
                    f"verify_client={self.verify_client}, "
                    f"client_cert_required={self.client_cert_required}")

    async def before_request(self, request: Request) -> None:
        """
        Authenticate the request before the main handler runs.

        Does nothing when the middleware is disabled, or when no
        certificate is presented and certificates are optional.

        Args:
            request: FastAPI request object

        Raises:
            ValueError: If a required certificate is missing, the
                certificate fails validation, or the roles are not allowed.
                The message text is what ``handle_error`` keys on.
        """
        if not self.enabled:
            return

        try:
            # Extract client certificate
            client_cert = self._extract_client_certificate(request)

            if client_cert is None:
                if self.client_cert_required:
                    raise ValueError("Client certificate is required but not provided")
                return

            # Validate client certificate
            if not self._validate_client_certificate(client_cert):
                raise ValueError("Client certificate validation failed")

            # Extract roles from certificate
            roles = self._extract_roles_from_certificate(client_cert)

            # Validate access based on roles
            if self.require_roles and not self._validate_access(roles):
                raise ValueError("Access denied: insufficient roles")

            # Store certificate and roles in request state
            request.state.client_certificate = client_cert
            request.state.client_roles = roles
            request.state.client_common_name = self._get_common_name(client_cert)

            logger.debug(f"mTLS authentication successful for {request.state.client_common_name} "
                         f"with roles: {roles}")

        except Exception as e:
            logger.error(f"mTLS authentication failed: {e}")
            raise

    def _extract_client_certificate(self, request: Request) -> Optional[x509.Certificate]:
        """
        Extract client certificate from request.

        Looks first at ``request.scope['ssl']`` (if it exposes
        ``getpeercert``), then at the ``ssl-client-cert`` header. The header
        may hold an armored PEM certificate or a base64 string; a base64
        string may decode to either PEM text or raw DER bytes.

        Args:
            request: FastAPI request object

        Returns:
            Client certificate object, or None if no certificate was found
            or it could not be parsed (the failure is logged).
        """
        try:
            # Check if client certificate is available in SSL context
            if hasattr(request, 'scope') and 'ssl' in request.scope:
                ssl_context = request.scope['ssl']
                if hasattr(ssl_context, 'getpeercert'):
                    cert_data = ssl_context.getpeercert(binary_form=True)
                    if cert_data:
                        return x509.load_der_x509_certificate(cert_data)

            # Check for certificate in headers (for proxy scenarios).
            # NOTE(review): this header is caller-controlled unless a trusted
            # proxy strips/overwrites it — confirm the deployment guarantees that.
            cert_header = request.headers.get('ssl-client-cert')
            if not cert_header:
                return None

            if cert_header.startswith('-----BEGIN CERTIFICATE-----'):
                return x509.load_pem_x509_certificate(cert_header.encode('utf-8'))

            # Otherwise assume base64. The decoded payload is PEM text when the
            # sender base64-wrapped a PEM file, and DER when the sender
            # forwarded only the certificate body, so pick the loader from
            # the decoded bytes instead of always using the PEM loader.
            import base64
            decoded = base64.b64decode(cert_header)
            if decoded.lstrip().startswith(b'-----BEGIN'):
                return x509.load_pem_x509_certificate(decoded)
            return x509.load_der_x509_certificate(decoded)

        except Exception as e:
            logger.error(f"Failed to extract client certificate: {e}")
            return None

    def _validate_client_certificate(self, cert: x509.Certificate) -> bool:
        """
        Validate client certificate.

        When ``verify_client`` is off the certificate is accepted as-is.
        Otherwise it is checked with ``AuthValidator`` and, if a CA path is
        configured, its chain is validated against that CA.

        Args:
            cert: Client certificate object

        Returns:
            True if certificate is valid, False otherwise (including when
            validation itself raises).
        """
        try:
            if not self.verify_client:
                return True

            # Convert certificate to PEM format for validation
            cert_pem = cert.public_bytes(serialization.Encoding.PEM)

            # Use AuthValidator to validate certificate
            result = self.auth_validator.validate_certificate_data(cert_pem)
            if not result.is_valid:
                logger.warning(f"Certificate validation failed: {result.error_message}")
                return False

            # Validate certificate chain if CA is provided.
            # The literal string "None" is also treated as "no CA configured".
            if self.ca_cert_path and self.ca_cert_path != "None":
                import tempfile
                import os

                # The chain validator takes file paths, so the certificate is
                # written to a temporary file that is removed afterwards.
                with tempfile.NamedTemporaryFile(mode='wb', suffix='.crt', delete=False) as f:
                    f.write(cert_pem)
                    temp_cert_path = f.name

                try:
                    chain_valid = self.certificate_utils.validate_certificate_chain(
                        temp_cert_path, self.ca_cert_path
                    )
                    if not chain_valid:
                        logger.warning("Certificate chain validation failed")
                        return False
                finally:
                    os.unlink(temp_cert_path)

            return True

        except Exception as e:
            logger.error(f"Failed to validate client certificate: {e}")
            return False

    def _extract_roles_from_certificate(self, cert: x509.Certificate) -> List[str]:
        """
        Extract roles from client certificate.

        Args:
            cert: Client certificate object

        Returns:
            List of roles extracted from certificate; empty list if
            extraction raises.
        """
        try:
            return self.certificate_utils.extract_roles_from_certificate_object(cert)
        except Exception as e:
            logger.error(f"Failed to extract roles from certificate: {e}")
            return []

    def _validate_access(self, roles: List[str]) -> bool:
        """
        Validate access based on roles.

        An empty ``allowed_roles`` list allows everyone. Otherwise at least
        one client role must match an allowed role according to
        ``RoleUtils.compare_roles``.

        Args:
            roles: List of roles from client certificate

        Returns:
            True if access is allowed, False otherwise (including when the
            comparison raises).
        """
        try:
            if not self.allowed_roles:
                return True

            if not roles:
                return False

            # Check if any of the client roles match allowed roles
            return any(
                self.role_utils.compare_roles(client_role, allowed_role)
                for client_role in roles
                for allowed_role in self.allowed_roles
            )

        except Exception as e:
            logger.error(f"Failed to validate access: {e}")
            return False

    def _get_common_name(self, cert: x509.Certificate) -> str:
        """
        Get common name from certificate.

        Args:
            cert: Certificate object

        Returns:
            Value of the first COMMON_NAME attribute in the subject, or an
            empty string if there is none or the lookup raises.
        """
        try:
            for name_attribute in cert.subject:
                if name_attribute.oid == x509.NameOID.COMMON_NAME:
                    return str(name_attribute.value)
            return ""
        except Exception as e:
            logger.error(f"Failed to get common name: {e}")
            return ""

    async def handle_error(self, request: Request, exception: Exception) -> Response:
        """
        Handle mTLS authentication errors.

        Maps the exception message onto an HTTP status and a JSON-RPC error
        code by substring match against the messages raised in
        ``before_request``; anything unrecognized becomes a 500.

        Args:
            request: FastAPI request object
            exception: Exception that occurred

        Returns:
            JSON-RPC 2.0 formatted error response
        """
        from fastapi.responses import JSONResponse

        error_message = str(exception)
        lowered = error_message.lower()

        if "certificate is required" in lowered:
            status_code = 401
            error_code = -32009  # Certificate not found
        elif "validation failed" in lowered:
            status_code = 401
            error_code = -32003  # Certificate validation failed
        elif "access denied" in lowered:
            status_code = 403
            error_code = -32007  # Role validation failed
        else:
            status_code = 500
            error_code = -32603  # Internal error

        return JSONResponse(
            status_code=status_code,
            content={
                "jsonrpc": "2.0",
                "error": {
                    "code": error_code,
                    "message": error_message,
                    "data": {
                        "validation_type": "mtls",
                        "request_id": getattr(request.state, 'request_id', None)
                    }
                },
                "id": None
            }
        )
@@ -0,0 +1,135 @@
1
+ """
2
+ Protocol middleware module.
3
+
4
+ This module provides middleware for validating protocol access based on configuration.
5
+ """
6
+
7
+ from typing import Callable
8
+ from fastapi import Request, Response
9
+ from starlette.middleware.base import BaseHTTPMiddleware
10
+ from starlette.responses import JSONResponse
11
+
12
+ from mcp_proxy_adapter.core.protocol_manager import protocol_manager
13
+ from mcp_proxy_adapter.core.logging import logger
14
+
15
+
16
class ProtocolMiddleware(BaseHTTPMiddleware):
    """
    Middleware for protocol validation.

    Classifies each incoming request as ``http``, ``https`` or ``mtls``
    and rejects it with HTTP 403 when the protocol manager does not allow
    that protocol.
    """

    def __init__(self, app, protocol_manager_instance=None):
        """
        Initialize protocol middleware.

        Args:
            app: FastAPI application
            protocol_manager_instance: Protocol manager instance (optional);
                the module-level ``protocol_manager`` is used when omitted.
        """
        super().__init__(app)
        self.protocol_manager = protocol_manager_instance or protocol_manager

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request through protocol middleware.

        Only the protocol check itself is guarded: a failure while
        determining or checking the protocol yields a 500
        "Protocol validation error" response. Exceptions raised by the
        downstream handler propagate unchanged, so they are not reported
        as protocol errors.

        Args:
            request: Incoming request
            call_next: Next middleware/endpoint function

        Returns:
            Response object
        """
        try:
            # Get protocol from request
            protocol = self._get_request_protocol(request)

            # Check if protocol is allowed
            if not self.protocol_manager.is_protocol_allowed(protocol):
                allowed_protocols = self.protocol_manager.get_allowed_protocols()
                logger.warning(f"Protocol '{protocol}' not allowed for request to {request.url.path}")
                return JSONResponse(
                    status_code=403,
                    content={
                        "error": "Protocol not allowed",
                        "message": f"Protocol '{protocol}' is not allowed. Allowed protocols: {allowed_protocols}",
                        "allowed_protocols": allowed_protocols
                    }
                )

        except Exception as e:
            logger.error(f"Protocol middleware error: {e}")
            return JSONResponse(
                status_code=500,
                content={
                    "error": "Protocol validation error",
                    "message": str(e)
                }
            )

        # Continue processing (outside the try so handler errors are not
        # mislabelled as protocol validation errors).
        return await call_next(request)

    def _get_request_protocol(self, request: Request) -> str:
        """
        Extract protocol from request.

        An HTTPS request counts as ``mtls`` when a peer certificate is
        available from ``request.scope['ssl']`` or when an
        ``ssl-client-cert`` / ``x-client-cert`` header is present.

        Args:
            request: FastAPI request object

        Returns:
            Protocol name (http, https, mtls), or the lower-cased URL scheme
            for any other scheme.
        """
        scheme = request.url.scheme
        if scheme:
            scheme = scheme.lower()

            if scheme != "https":
                return scheme

            # HTTPS: check for client certificate in SSL context
            if hasattr(request, 'scope') and 'ssl' in request.scope:
                ssl_context = request.scope.get('ssl')
                if ssl_context and hasattr(ssl_context, 'getpeercert'):
                    try:
                        if ssl_context.getpeercert():
                            return "mtls"
                    except Exception as e:
                        # Best effort: fall through to the header check.
                        logger.debug(f"Could not read peer certificate: {e}")

            # Check for client certificate in headers.
            # NOTE(review): these headers are caller-controlled unless a
            # trusted proxy sets them — confirm before relying on "mtls".
            if request.headers.get("ssl-client-cert") or request.headers.get("x-client-cert"):
                return "mtls"

            return "https"

        # Fallback to checking headers when the URL carries no scheme
        forwarded_proto = request.headers.get("x-forwarded-proto")
        if forwarded_proto:
            return forwarded_proto.lower()

        # Default to HTTP
        return "http"
117
+
118
+
119
def setup_protocol_middleware(app, protocol_manager_instance=None):
    """
    Register ProtocolMiddleware on a FastAPI application.

    The middleware is installed only when the protocol manager reports
    that protocol management is enabled; otherwise nothing is added.

    Args:
        app: FastAPI application
        protocol_manager_instance: Protocol manager instance (optional);
            the module-level ``protocol_manager`` is used when None.
    """
    manager = protocol_manager if protocol_manager_instance is None else protocol_manager_instance

    # Nothing to install when protocol management is switched off
    if not manager.enabled:
        logger.debug("Protocol management is disabled, skipping protocol middleware")
        return

    app.add_middleware(ProtocolMiddleware, protocol_manager_instance=manager)
    logger.info("Protocol middleware added to application")
@@ -0,0 +1,241 @@
1
+ """
2
+ Rate Limit Middleware Adapter for backward compatibility.
3
+
4
+ This module provides an adapter that maintains the same interface as RateLimitMiddleware
5
+ while using the new SecurityMiddleware internally.
6
+ """
7
+
8
+ import time
9
+ from typing import Dict, List, Callable, Awaitable, Any
10
+ from collections import defaultdict
11
+
12
+ from fastapi import Request, Response
13
+ from starlette.responses import JSONResponse
14
+
15
+ from mcp_proxy_adapter.core.logging import logger
16
+ from .base import BaseMiddleware
17
+ from .security import SecurityMiddleware
18
+
19
+
20
class RateLimitMiddlewareAdapter(BaseMiddleware):
    """
    Adapter for RateLimitMiddleware that uses SecurityMiddleware internally.

    Maintains the same interface as the original RateLimitMiddleware for
    backward compatibility, and keeps an in-memory per-IP / per-user
    request log that is used as a fallback when the internal
    SecurityMiddleware check raises.
    """

    def __init__(self, app, rate_limit: int = 100, time_window: int = 60,
                 by_ip: bool = True, by_user: bool = True,
                 public_paths: List[str] = None):
        """
        Initialize rate limit middleware adapter.

        Args:
            app: FastAPI application
            rate_limit: Maximum number of requests in the specified time period
            time_window: Time period in seconds
            by_ip: Limit requests by IP address
            by_user: Limit requests by user
            public_paths: List of path prefixes for which rate limiting is
                not applied; defaults to the docs, redoc, OpenAPI and
                health endpoints.
        """
        super().__init__(app)

        # Store original parameters for backward compatibility
        self.rate_limit = rate_limit
        self.time_window = time_window
        self.by_ip = by_ip
        self.by_user = by_user
        self.public_paths = public_paths or [
            "/docs",
            "/redoc",
            "/openapi.json",
            "/health"
        ]

        # Legacy storage for backward compatibility:
        # key -> list of request timestamps (seconds since epoch)
        self.ip_requests = defaultdict(list)
        self.user_requests = defaultdict(list)

        # Create internal security middleware
        self.security_middleware = self._create_security_middleware()

        logger.info(f"RateLimitMiddlewareAdapter initialized: rate_limit={rate_limit}, "
                    f"time_window={time_window}, by_ip={by_ip}, by_user={by_user}")

    def _create_security_middleware(self) -> SecurityMiddleware:
        """
        Create internal SecurityMiddleware with RateLimitMiddleware configuration.

        Only the rate-limit section is enabled; auth, ssl and permissions
        are switched off.

        Returns:
            SecurityMiddleware instance
        """
        # Convert RateLimitMiddleware config to SecurityMiddleware config
        security_config = {
            "security": {
                "enabled": True,
                "auth": {
                    "enabled": False
                },
                "ssl": {
                    "enabled": False
                },
                "permissions": {
                    "enabled": False
                },
                "rate_limit": {
                    "enabled": True,
                    "requests_per_minute": self.rate_limit,
                    "requests_per_hour": self.rate_limit * 60,
                    # NOTE(review): evaluates to 0 when rate_limit < 10 —
                    # confirm SecurityMiddleware handles a zero burst limit.
                    "burst_limit": self.rate_limit // 10,
                    "by_ip": self.by_ip,
                    "by_user": self.by_user
                },
                "public_paths": self.public_paths
            }
        }

        return SecurityMiddleware(self.app, security_config)

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """
        Process request using internal SecurityMiddleware with legacy fallback.

        Public paths bypass rate limiting. For other paths the internal
        SecurityMiddleware check runs first; if that check raises, the
        legacy in-memory limiter decides instead. The downstream handler
        is invoked exactly once and its exceptions propagate unchanged.

        Args:
            request: Request object
            call_next: Next handler

        Returns:
            Response object
        """
        # Check if path is public
        path = request.url.path
        if self._is_public_path(path):
            return await call_next(request)

        # Try to use SecurityMiddleware first. Only the check itself is
        # guarded: wrapping call_next as well would make a failing handler
        # trigger the fallback, which would run the handler a second time.
        try:
            await self.security_middleware.before_request(request)
        except Exception as e:
            # Fallback to legacy rate limiting if SecurityMiddleware fails.
            # NOTE(review): if SecurityMiddleware signals "limit exceeded" by
            # raising, this fallback may still admit the request — confirm.
            logger.warning(f"SecurityMiddleware rate limiting failed, using legacy fallback: {e}")
            return await self._legacy_rate_limit_check(request, call_next)

        return await call_next(request)

    async def _legacy_rate_limit_check(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """
        Legacy rate limiting implementation as fallback.

        Applies a sliding-window count per client IP and, if a username is
        present on ``request.state``, per user. Returns a 429 response as
        soon as either window is full.

        Args:
            request: Request object
            call_next: Next handler

        Returns:
            Response object
        """
        current_time = time.time()
        client_ip = request.client.host if request.client else "unknown"
        username = getattr(request.state, "username", None)

        # Check limit by IP
        if self.by_ip and client_ip != "unknown":
            ip_history = self.ip_requests[client_ip]
            self._clean_old_requests(ip_history, current_time)

            if len(ip_history) >= self.rate_limit:
                logger.warning(f"Rate limit exceeded for IP: {client_ip}")
                return self._create_error_response("Rate limit exceeded", 429)

            ip_history.append(current_time)

        # Check limit by user
        if self.by_user and username:
            user_history = self.user_requests[username]
            self._clean_old_requests(user_history, current_time)

            if len(user_history) >= self.rate_limit:
                logger.warning(f"Rate limit exceeded for user: {username}")
                return self._create_error_response("Rate limit exceeded", 429)

            user_history.append(current_time)

        return await call_next(request)

    def _clean_old_requests(self, requests_list: List[float], current_time: float) -> None:
        """
        Remove old requests from the list.

        The list is modified in place; entries at or before
        ``current_time - time_window`` are dropped.

        Args:
            requests_list: List of request timestamps
            current_time: Current time
        """
        cutoff_time = current_time - self.time_window
        requests_list[:] = [req_time for req_time in requests_list if req_time > cutoff_time]

    def _is_public_path(self, path: str) -> bool:
        """
        Check if the path is public (doesn't require rate limiting).

        Matching is by prefix against ``public_paths``.

        Args:
            path: Request path

        Returns:
            True if path is public, False otherwise
        """
        return any(path.startswith(public_path) for public_path in self.public_paths)

    def _create_error_response(self, message: str, status_code: int) -> JSONResponse:
        """
        Create error response in RateLimitMiddleware format.

        Args:
            message: Error message
            status_code: HTTP status code

        Returns:
            JSONResponse with a JSON-RPC 2.0 error body; the error code is
            -32008 for status 429 and -32603 otherwise.
        """
        return JSONResponse(
            status_code=status_code,
            content={
                "jsonrpc": "2.0",
                "error": {
                    "code": -32008 if status_code == 429 else -32603,
                    "message": message,
                    "data": {
                        "rate_limit": self.rate_limit,
                        "time_window": self.time_window,
                        "status_code": status_code
                    }
                },
                "id": None
            }
        )

    def get_rate_limit_info(self, request: Request) -> Dict[str, Any]:
        """
        Get rate limit information for the request (backward compatibility).

        Read-only: looking up an unseen IP or user does not create an entry
        in the legacy storage. Counts reflect the stored timestamps as they
        are; expired entries are not pruned here.

        Args:
            request: Request object

        Returns:
            Dictionary with rate limit information
        """
        client_ip = request.client.host if request.client else "unknown"
        username = getattr(request.state, "username", None)

        info = {
            "rate_limit": self.rate_limit,
            "time_window": self.time_window,
            "by_ip": self.by_ip,
            "by_user": self.by_user
        }

        if self.by_ip and client_ip != "unknown":
            # .get avoids inserting an empty list into the defaultdict
            ip_count = len(self.ip_requests.get(client_ip, ()))
            info["ip_requests"] = ip_count
            info["ip_remaining"] = max(0, self.rate_limit - ip_count)

        if self.by_user and username:
            user_count = len(self.user_requests.get(username, ()))
            info["user_requests"] = user_count
            info["user_remaining"] = max(0, self.rate_limit - user_count)

        return info