isa-model 0.3.9__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. isa_model/__init__.py +1 -1
  2. isa_model/client.py +732 -565
  3. isa_model/core/cache/redis_cache.py +401 -0
  4. isa_model/core/config/config_manager.py +53 -10
  5. isa_model/core/config.py +1 -1
  6. isa_model/core/database/__init__.py +1 -0
  7. isa_model/core/database/migrations.py +277 -0
  8. isa_model/core/database/supabase_client.py +123 -0
  9. isa_model/core/models/__init__.py +37 -0
  10. isa_model/core/models/model_billing_tracker.py +60 -88
  11. isa_model/core/models/model_manager.py +36 -18
  12. isa_model/core/models/model_repo.py +44 -38
  13. isa_model/core/models/model_statistics_tracker.py +234 -0
  14. isa_model/core/models/model_storage.py +0 -1
  15. isa_model/core/models/model_version_manager.py +959 -0
  16. isa_model/core/pricing_manager.py +2 -249
  17. isa_model/core/resilience/circuit_breaker.py +366 -0
  18. isa_model/core/security/secrets.py +358 -0
  19. isa_model/core/services/__init__.py +2 -4
  20. isa_model/core/services/intelligent_model_selector.py +101 -370
  21. isa_model/core/storage/hf_storage.py +1 -1
  22. isa_model/core/types.py +7 -0
  23. isa_model/deployment/cloud/modal/isa_audio_chatTTS_service.py +520 -0
  24. isa_model/deployment/cloud/modal/isa_audio_fish_service.py +0 -0
  25. isa_model/deployment/cloud/modal/isa_audio_openvoice_service.py +758 -0
  26. isa_model/deployment/cloud/modal/isa_audio_service_v2.py +1044 -0
  27. isa_model/deployment/cloud/modal/isa_embed_rerank_service.py +296 -0
  28. isa_model/deployment/cloud/modal/isa_video_hunyuan_service.py +423 -0
  29. isa_model/deployment/cloud/modal/isa_vision_ocr_service.py +519 -0
  30. isa_model/deployment/cloud/modal/isa_vision_qwen25_service.py +709 -0
  31. isa_model/deployment/cloud/modal/isa_vision_table_service.py +467 -323
  32. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +607 -180
  33. isa_model/deployment/cloud/modal/isa_vision_ui_service_optimized.py +660 -0
  34. isa_model/deployment/core/deployment_manager.py +6 -4
  35. isa_model/deployment/services/auto_hf_modal_deployer.py +894 -0
  36. isa_model/eval/benchmarks/__init__.py +27 -0
  37. isa_model/eval/benchmarks/multimodal_datasets.py +460 -0
  38. isa_model/eval/benchmarks.py +244 -12
  39. isa_model/eval/evaluators/__init__.py +8 -2
  40. isa_model/eval/evaluators/audio_evaluator.py +727 -0
  41. isa_model/eval/evaluators/embedding_evaluator.py +742 -0
  42. isa_model/eval/evaluators/vision_evaluator.py +564 -0
  43. isa_model/eval/example_evaluation.py +395 -0
  44. isa_model/eval/factory.py +272 -5
  45. isa_model/eval/isa_benchmarks.py +700 -0
  46. isa_model/eval/isa_integration.py +582 -0
  47. isa_model/eval/metrics.py +159 -6
  48. isa_model/eval/tests/unit/test_basic.py +396 -0
  49. isa_model/inference/ai_factory.py +44 -8
  50. isa_model/inference/services/audio/__init__.py +21 -0
  51. isa_model/inference/services/audio/base_realtime_service.py +225 -0
  52. isa_model/inference/services/audio/isa_tts_service.py +0 -0
  53. isa_model/inference/services/audio/openai_realtime_service.py +320 -124
  54. isa_model/inference/services/audio/openai_stt_service.py +32 -6
  55. isa_model/inference/services/base_service.py +17 -1
  56. isa_model/inference/services/embedding/__init__.py +13 -0
  57. isa_model/inference/services/embedding/base_embed_service.py +111 -8
  58. isa_model/inference/services/embedding/isa_embed_service.py +305 -0
  59. isa_model/inference/services/embedding/openai_embed_service.py +2 -4
  60. isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
  61. isa_model/inference/services/img/__init__.py +2 -2
  62. isa_model/inference/services/img/base_image_gen_service.py +24 -7
  63. isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
  64. isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
  65. isa_model/inference/services/img/services/replicate_flux.py +226 -0
  66. isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
  67. isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
  68. isa_model/inference/services/img/tests/test_img_client.py +297 -0
  69. isa_model/inference/services/llm/base_llm_service.py +30 -6
  70. isa_model/inference/services/llm/helpers/llm_adapter.py +63 -9
  71. isa_model/inference/services/llm/ollama_llm_service.py +2 -1
  72. isa_model/inference/services/llm/openai_llm_service.py +652 -55
  73. isa_model/inference/services/llm/yyds_llm_service.py +2 -1
  74. isa_model/inference/services/vision/__init__.py +5 -5
  75. isa_model/inference/services/vision/base_vision_service.py +118 -185
  76. isa_model/inference/services/vision/helpers/image_utils.py +11 -5
  77. isa_model/inference/services/vision/isa_vision_service.py +573 -0
  78. isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
  79. isa_model/serving/api/fastapi_server.py +88 -16
  80. isa_model/serving/api/middleware/auth.py +311 -0
  81. isa_model/serving/api/middleware/security.py +278 -0
  82. isa_model/serving/api/routes/analytics.py +486 -0
  83. isa_model/serving/api/routes/deployments.py +339 -0
  84. isa_model/serving/api/routes/evaluations.py +579 -0
  85. isa_model/serving/api/routes/logs.py +430 -0
  86. isa_model/serving/api/routes/settings.py +582 -0
  87. isa_model/serving/api/routes/unified.py +324 -165
  88. isa_model/serving/api/startup.py +304 -0
  89. isa_model/serving/modal_proxy_server.py +249 -0
  90. isa_model/training/__init__.py +100 -6
  91. isa_model/training/core/__init__.py +4 -1
  92. isa_model/training/examples/intelligent_training_example.py +281 -0
  93. isa_model/training/intelligent/__init__.py +25 -0
  94. isa_model/training/intelligent/decision_engine.py +643 -0
  95. isa_model/training/intelligent/intelligent_factory.py +888 -0
  96. isa_model/training/intelligent/knowledge_base.py +751 -0
  97. isa_model/training/intelligent/resource_optimizer.py +839 -0
  98. isa_model/training/intelligent/task_classifier.py +576 -0
  99. isa_model/training/storage/__init__.py +24 -0
  100. isa_model/training/storage/core_integration.py +439 -0
  101. isa_model/training/storage/training_repository.py +552 -0
  102. isa_model/training/storage/training_storage.py +628 -0
  103. {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/METADATA +13 -1
  104. isa_model-0.4.0.dist-info/RECORD +182 -0
  105. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
  106. isa_model/deployment/cloud/modal/register_models.py +0 -321
  107. isa_model/inference/adapter/unified_api.py +0 -248
  108. isa_model/inference/services/helpers/stacked_config.py +0 -148
  109. isa_model/inference/services/img/flux_professional_service.py +0 -603
  110. isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
  111. isa_model/inference/services/others/table_transformer_service.py +0 -61
  112. isa_model/inference/services/vision/doc_analysis_service.py +0 -640
  113. isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
  114. isa_model/inference/services/vision/ui_analysis_service.py +0 -823
  115. isa_model/scripts/inference_tracker.py +0 -283
  116. isa_model/scripts/mlflow_manager.py +0 -379
  117. isa_model/scripts/model_registry.py +0 -465
  118. isa_model/scripts/register_models.py +0 -370
  119. isa_model/scripts/register_models_with_embeddings.py +0 -510
  120. isa_model/scripts/start_mlflow.py +0 -95
  121. isa_model/scripts/training_tracker.py +0 -257
  122. isa_model-0.3.9.dist-info/RECORD +0 -138
  123. {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/WHEEL +0 -0
  124. {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,311 @@
1
+ """
2
+ Optional Authentication Middleware
3
+
4
+ Provides optional API key authentication for the ISA Model Platform.
5
+ When authentication is disabled (default), all endpoints remain open.
6
+ When enabled, requires API keys for access.
7
+ """
8
+
9
+ from fastapi import HTTPException, status, Depends, Request
10
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials, APIKeyHeader, APIKeyQuery
11
+ from typing import Optional, Dict, List, Union
12
+ import hashlib
13
+ import secrets
14
+ import time
15
+ import logging
16
+ import os
17
+ import json
18
+ from pathlib import Path
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
# Authentication configuration, read once at import time.
# REQUIRE_API_KEYS=true (any letter case) explicitly switches authentication on.
AUTH_ENABLED = os.environ.get("REQUIRE_API_KEYS", "false").lower() == "true"
# Key store location: three directories above this module's directory, then
# deployment/dev/.api_keys.json.
API_KEYS_FILE = (
    Path(os.path.dirname(__file__)).parent.parent.parent
    / "deployment"
    / "dev"
    / ".api_keys.json"
)
25
+
26
# Credential extractors for the three supported locations: Authorization
# bearer token, X-API-Key header and api_key query parameter.
# auto_error=False is passed so a missing credential is left for the
# dependencies below to handle rather than rejected by the scheme itself.
bearer_scheme = HTTPBearer(auto_error=False)
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
30
+
31
class APIKeyManager:
    """Issue, validate, revoke and persist API keys backed by a JSON file.

    ``api_keys`` maps the SHA-256 hex digest of each key to its metadata
    (name, scopes, created_at, last_used, usage_count, active). The plaintext
    key is only ever returned by ``generate_api_key``; it is never stored.
    """

    def __init__(self):
        """Load persisted keys and decide whether authentication is active."""
        self.api_keys: Dict[str, Dict] = {}

        # Load API keys first: their presence can switch authentication on.
        self.load_api_keys()

        # Auth is enabled if explicitly requested OR if any API key exists.
        explicit_auth = AUTH_ENABLED
        has_keys = len(self.api_keys) > 0
        self.auth_enabled = explicit_auth or has_keys

        if self.auth_enabled:
            logger.info(f"API Key authentication is ENABLED ({'explicit' if explicit_auth else 'auto-detected from keys'})")
        else:
            logger.info("API Key authentication is DISABLED - all endpoints are open")

    def load_api_keys(self):
        """Load API keys from file.

        Any failure (unreadable file, invalid JSON, or JSON that is not an
        object) is logged and leaves the manager with an empty key set.
        """
        try:
            if API_KEYS_FILE.exists():
                with open(API_KEYS_FILE, 'r', encoding="utf-8") as f:
                    loaded = json.load(f)
                # Every other method treats the store as a hash -> metadata
                # mapping, so reject anything else here instead of failing
                # later on the first lookup.
                if not isinstance(loaded, dict):
                    raise ValueError("API keys file must contain a JSON object")
                self.api_keys = loaded
                logger.info(f"Loaded {len(self.api_keys)} API keys")
            else:
                self.api_keys = {}
                logger.info("No API keys file found - authentication will be disabled")
        except Exception as e:
            logger.error(f"Error loading API keys: {e}")
            self.api_keys = {}

    def save_api_keys(self):
        """Save API keys to file (best effort: failures are logged, not raised)."""
        try:
            API_KEYS_FILE.parent.mkdir(parents=True, exist_ok=True)
            with open(API_KEYS_FILE, 'w', encoding="utf-8") as f:
                json.dump(self.api_keys, f, indent=2)
            logger.info("API keys saved successfully")
        except Exception as e:
            logger.error(f"Error saving API keys: {e}")

    def create_default_keys(self):
        """Create default admin and development keys for initial setup.

        Returns:
            Dict with the plaintext keys under ``admin_key`` and ``dev_key``.
        """
        admin_key = self.generate_api_key("admin", scopes=["read", "write", "admin"])
        dev_key = self.generate_api_key("development", scopes=["read", "write"])

        # NOTE(review): the plaintext keys are written to the log here, which
        # is the only place they are surfaced besides the return value.
        logger.warning("=== CREATED DEFAULT API KEYS ===")
        logger.warning(f"Admin API Key: {admin_key}")
        logger.warning(f"Development API Key: {dev_key}")
        logger.warning("Please save these keys securely!")
        logger.warning("=====================================")

        return {"admin_key": admin_key, "dev_key": dev_key}

    def generate_api_key(self, name: str, scopes: Optional[List[str]] = None) -> str:
        """Generate, register and persist a new API key.

        Args:
            name: Human-readable label stored with the key.
            scopes: Scopes granted to the key; defaults to ``["read"]``.

        Returns:
            The plaintext key (``isa_`` prefix plus a URL-safe random token).
        """
        if scopes is None:
            scopes = ["read"]

        # Generate secure random key; only its hash is kept.
        key = f"isa_{secrets.token_urlsafe(32)}"
        key_hash = hashlib.sha256(key.encode()).hexdigest()

        self.api_keys[key_hash] = {
            "name": name,
            # Copy so later changes to the caller's list cannot alter the store.
            "scopes": list(scopes),
            "created_at": time.time(),
            "last_used": None,
            "usage_count": 0,
            "active": True
        }

        self.save_api_keys()
        return key

    def validate_api_key(self, api_key: str) -> Optional[Dict]:
        """Validate an API key and return its metadata.

        Returns:
            An anonymous full-access context when auth is disabled; ``None``
            for a missing, unknown or inactive key; otherwise a copy of the
            key's metadata with ``auth_enabled`` set to ``True``.
        """
        if not self.auth_enabled:
            # When auth is disabled, return a default user context
            return {
                "name": "anonymous",
                "scopes": ["read", "write", "admin"],
                "auth_enabled": False
            }

        if not api_key:
            return None

        key_hash = hashlib.sha256(api_key.encode()).hexdigest()
        key_data = self.api_keys.get(key_hash)

        if not key_data or not key_data.get("active", True):
            return None

        # Update usage statistics; this persists the store on every
        # successful validation.
        key_data["last_used"] = time.time()
        key_data["usage_count"] = key_data.get("usage_count", 0) + 1
        self.save_api_keys()

        # Hand out a copy: request-scoped flags (and anything callers add)
        # must not end up in the stored record or the key file.
        result = dict(key_data)
        result["auth_enabled"] = True
        return result

    def revoke_api_key(self, api_key: str) -> bool:
        """Mark a key inactive.

        Returns:
            ``True`` if the key was found and deactivated; ``False`` when auth
            is disabled or the key is unknown.
        """
        if not self.auth_enabled:
            return False

        key_hash = hashlib.sha256(api_key.encode()).hexdigest()
        if key_hash in self.api_keys:
            self.api_keys[key_hash]["active"] = False
            self.save_api_keys()
            return True
        return False

    def list_api_keys(self) -> List[Dict]:
        """List all API keys without revealing keys or full hashes.

        Returns an empty list when auth is disabled.
        """
        if not self.auth_enabled:
            return []

        return [
            {
                "key_hash": key_hash[:16] + "...",
                "name": data["name"],
                "scopes": data["scopes"],
                "created_at": data["created_at"],
                "last_used": data.get("last_used"),
                "usage_count": data.get("usage_count", 0),
                "active": data.get("active", True)
            }
            for key_hash, data in self.api_keys.items()
        ]

    def enable_auth(self):
        """Enable authentication.

        Returns:
            The freshly created default keys if no key existed, else ``None``.
        """
        self.auth_enabled = True
        if not self.api_keys:
            return self.create_default_keys()
        return None

    def disable_auth(self):
        """Disable authentication (stored keys are kept)."""
        self.auth_enabled = False
174
+
175
+ # Global API key manager instance
176
+ api_key_manager = APIKeyManager()
177
+
178
async def get_api_key_from_request(
    request: Request,
    bearer_token: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme),
    header_key: Optional[str] = Depends(api_key_header),
    query_key: Optional[str] = Depends(api_key_query)
) -> Optional[str]:
    """Pick the API key out of the request, or None if there is none.

    Sources are consulted in priority order: Authorization bearer token,
    X-API-Key header, then the api_key query parameter.
    """
    # With auth switched off no key is needed; callers treat None as anonymous.
    if not api_key_manager.auth_enabled:
        return None

    # Highest priority: bearer credentials.
    if bearer_token:
        return bearer_token.credentials

    # Otherwise the header wins over the query parameter; empty values
    # collapse to None.
    return header_key or query_key or None
203
+
204
async def authenticate_api_key(api_key: str = Depends(get_api_key_from_request)) -> Dict:
    """Authenticate the request's API key and return the user context.

    Args:
        api_key: Key extracted by ``get_api_key_from_request`` (may be None).

    Returns:
        An anonymous context with all scopes when auth is disabled; otherwise
        the key's metadata plus ``authenticated: True``.

    Raises:
        HTTPException: 401 when auth is enabled and the key is missing or is
            rejected by the key manager.
    """
    # When auth is disabled, always succeed with anonymous user
    if not api_key_manager.auth_enabled:
        return {
            "name": "anonymous",
            "scopes": ["read", "write", "admin"],
            "auth_enabled": False,
            "authenticated": False
        }

    # When auth is enabled, require valid API key
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key required. Provide via Authorization header, X-API-Key header, or api_key query parameter",
            headers={"WWW-Authenticate": "Bearer"}
        )

    key_data = api_key_manager.validate_api_key(api_key)

    if not key_data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired API key",
            headers={"WWW-Authenticate": "Bearer"}
        )

    # Build a new dict instead of writing into the one the manager returned,
    # so the per-request flag can never leak into the manager's stored record.
    return {**key_data, "authenticated": True}
235
+
236
def require_scope(required_scope: str):
    """Create a dependency that requires a specific scope.

    This is a plain (non-async) factory on purpose: ``Depends`` needs the
    callable it returns, e.g. ``Depends(require_scope("write"))``. Declared
    ``async``, the call would produce a coroutine object, which cannot be
    used as a dependency.

    Args:
        required_scope: Scope the authenticated key must hold. Keys holding
            ``admin`` pass regardless.

    Returns:
        An async dependency that yields the current user context, or raises
        HTTPException 403 when the scope is missing.
    """
    async def check_scope(current_user: Dict = Depends(authenticate_api_key)) -> Dict:
        # When auth is disabled, always allow
        if not current_user.get("auth_enabled", True):
            return current_user

        user_scopes = current_user.get("scopes", [])

        if required_scope not in user_scopes and "admin" not in user_scopes:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Insufficient permissions. Required scope: {required_scope}"
            )

        return current_user

    return check_scope
254
+
255
+ # Convenience dependencies for common scopes
256
async def require_read_access(current_user: Dict = Depends(authenticate_api_key)) -> Dict:
    """Dependency: let the request through only with read-level access.

    Any of the read, write or admin scopes qualifies. A context flagged
    ``auth_enabled: False`` is always accepted. Raises HTTPException 403
    otherwise.
    """
    auth_active = current_user.get("auth_enabled", True)
    if auth_active:
        granted = current_user.get("scopes", [])
        can_read = "read" in granted or "write" in granted or "admin" in granted
        if not can_read:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Read access required"
            )
    return current_user
268
+
269
async def require_write_access(current_user: Dict = Depends(authenticate_api_key)) -> Dict:
    """Dependency: let the request through only with write-level access.

    Either the write or the admin scope qualifies. A context flagged
    ``auth_enabled: False`` is always accepted. Raises HTTPException 403
    otherwise.
    """
    auth_active = current_user.get("auth_enabled", True)
    if auth_active:
        granted = current_user.get("scopes", [])
        if "write" not in granted and "admin" not in granted:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Write access required"
            )
    return current_user
281
+
282
async def require_admin_access(current_user: Dict = Depends(authenticate_api_key)) -> Dict:
    """Dependency: let the request through only with the admin scope.

    A context flagged ``auth_enabled: False`` is always accepted. Raises
    HTTPException 403 otherwise.
    """
    auth_active = current_user.get("auth_enabled", True)
    if auth_active and "admin" not in current_user.get("scopes", []):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin access required"
        )
    return current_user
294
+
295
+ # Optional authentication (always returns user info, never fails)
296
async def optional_auth(api_key: str = Depends(get_api_key_from_request)) -> Dict:
    """Best-effort authentication that never rejects the request.

    Returns whatever the key manager reports for the key; if the key is not
    accepted, or validation raises, an anonymous context with no scopes is
    returned instead.
    """
    def _anonymous() -> Dict:
        # Built on demand so auth_enabled reflects the manager's current state.
        return {
            "name": "anonymous",
            "scopes": [],
            "auth_enabled": api_key_manager.auth_enabled,
            "authenticated": False
        }

    try:
        user = api_key_manager.validate_api_key(api_key)
        return user if user else _anonymous()
    except Exception:
        return _anonymous()
@@ -0,0 +1,278 @@
1
+ """
2
+ Security middleware for production deployment
3
+
4
+ Provides comprehensive security features including:
5
+ - Rate limiting with Redis backend
6
+ - Security headers
7
+ - Request size limits
8
+ - Input validation and sanitization
9
+ - CORS protection
10
+ """
11
+
12
+ import time
13
+ import logging
14
+ import os
15
+ import redis
16
+ import structlog
17
+ from typing import Dict, Any, Optional, Callable
18
+ from fastapi import FastAPI, Request, Response, HTTPException, status
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+ from fastapi.middleware.trustedhost import TrustedHostMiddleware
21
+ from slowapi import Limiter, _rate_limit_exceeded_handler
22
+ from slowapi.util import get_remote_address
23
+ from slowapi.errors import RateLimitExceeded
24
+ from starlette.middleware.base import BaseHTTPMiddleware
25
+ from starlette.responses import JSONResponse
26
+ import html
27
+
28
+ # Configure structured logging
29
+ logger = structlog.get_logger(__name__)
30
+
31
# Settings below are read from the environment once, at import time.
_BYTES_PER_MB = 1024 * 1024
MAX_REQUEST_SIZE = int(os.getenv("MAX_REQUEST_SIZE_MB", "50")) * _BYTES_PER_MB  # defaults to 50MB
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
ENABLE_RATE_LIMITING = os.getenv("ENABLE_RATE_LIMITING", "true").lower() == "true"
# Both limits are kept as strings; they are interpolated into limit specs.
RATE_LIMIT_PER_MINUTE = os.getenv("RATE_LIMIT_PER_MINUTE", "100")
RATE_LIMIT_PER_HOUR = os.getenv("RATE_LIMIT_PER_HOUR", "1000")

# Response headers added to every response by the security headers middleware.
SECURITY_HEADERS = {
    "X-Content-Type-Options": "nosniff",
    "X-Frame-Options": "DENY",
    "X-XSS-Protection": "1; mode=block",
    "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
    # One directive per source line; adjacent literals join into one policy.
    "Content-Security-Policy": (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' https://unpkg.com https://cdn.jsdelivr.net; "
        "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; "
        "font-src 'self' https://fonts.gstatic.com; "
        "connect-src 'self'"
    ),
    "Referrer-Policy": "strict-origin-when-cross-origin",
    "Permissions-Policy": "geolocation=(), microphone=(), camera=()"
}
48
+
49
# Connect to Redis at import time. redis_client stays None when the server
# cannot be reached, which the code below uses to fall back to in-memory
# rate limiting.
try:
    redis_client = redis.from_url(REDIS_URL, decode_responses=True)
    redis_client.ping()  # raises if the server is unreachable
    logger.info("Redis connection established for rate limiting")
except Exception as exc:
    logger.warning(f"Redis connection failed, using in-memory rate limiting: {exc}")
    redis_client = None
57
+
58
+ # Initialize rate limiter
59
def get_remote_address_with_proxy(request: Request):
    """Resolve the client IP, preferring proxy-supplied headers.

    Order of precedence: first entry of X-Forwarded-For, then X-Real-IP,
    then the address reported by slowapi's ``get_remote_address``.
    """
    headers = request.headers

    forwarded_chain = headers.get("X-Forwarded-For")
    if forwarded_chain:
        # The header is a comma-separated list; only the first entry is used.
        first_hop, _, _ = forwarded_chain.partition(",")
        return first_hop.strip()

    return headers.get("X-Real-IP") or get_remote_address(request)
70
+
71
# Build the shared limiter. The Redis storage URI is only supplied when the
# connection test above succeeded; without it no storage_uri is passed.
_limiter_options = {
    "key_func": get_remote_address_with_proxy,
    "strategy": "fixed-window",
}
if redis_client:
    _limiter_options["storage_uri"] = REDIS_URL
limiter = Limiter(**_limiter_options)
83
+
84
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Stamp every outgoing response with the configured security headers."""

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the downstream handler, then decorate its response.

        If anything in this method raises, a JSON 500 response carrying the
        security headers is returned instead.
        """
        try:
            outgoing = await call_next(request)

            # Apply each configured security header to the response.
            for name in SECURITY_HEADERS:
                outgoing.headers[name] = SECURITY_HEADERS[name]

            # Report elapsed time when an earlier middleware recorded a start.
            if hasattr(request.state, 'start_time'):
                elapsed = time.time() - request.state.start_time
                outgoing.headers["X-Process-Time"] = str(elapsed)

            return outgoing

        except Exception as exc:
            logger.error("Error in security headers middleware", error=str(exc))
            return JSONResponse(
                status_code=500,
                content={"error": "Internal server error"},
                headers=SECURITY_HEADERS
            )
109
+
110
class RequestValidationMiddleware(BaseHTTPMiddleware):
    """Enforce the request size limit, flag suspicious queries and log traffic.

    Rejections are returned as JSON responses instead of being raised: an
    ``HTTPException`` raised from middleware is not seen by FastAPI's
    exception handlers, so the client would receive a generic 500 rather than
    the intended status code.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Validate the request, forward it, and log its outcome.

        Returns:
            A 400 response for a malformed Content-Length header, a 413
            response when the declared size exceeds ``MAX_REQUEST_SIZE``,
            otherwise the downstream response.

        Raises:
            Whatever the downstream handler raises, after logging it, so the
            application's own error handling sees the original exception.
        """
        # Record start time for performance monitoring
        request.state.start_time = time.time()
        client_ip = get_remote_address_with_proxy(request)

        # Check the declared request size
        content_length = request.headers.get("content-length")
        if content_length:
            try:
                declared_size = int(content_length)
            except ValueError:
                logger.warning(
                    "Invalid Content-Length header",
                    content_length=content_length,
                    client_ip=client_ip
                )
                return JSONResponse(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    content={"detail": "Invalid Content-Length header"},
                    headers=SECURITY_HEADERS
                )

            if declared_size > MAX_REQUEST_SIZE:
                logger.warning(
                    "Request too large",
                    content_length=content_length,
                    max_size=MAX_REQUEST_SIZE,
                    client_ip=client_ip
                )
                return JSONResponse(
                    status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                    content={"detail": f"Request too large. Maximum size: {MAX_REQUEST_SIZE // (1024*1024)}MB"},
                    headers=SECURITY_HEADERS
                )

        # Flag (but do not block or rewrite) query strings that contain
        # characters HTML escaping would change.
        if request.url.query:
            sanitized_query = html.escape(request.url.query)
            if sanitized_query != request.url.query:
                logger.warning(
                    "Potentially malicious query parameters detected",
                    original=request.url.query,
                    sanitized=sanitized_query,
                    client_ip=client_ip
                )

        # Log request details for monitoring
        logger.info(
            "Request received",
            method=request.method,
            path=request.url.path,
            client_ip=client_ip,
            user_agent=request.headers.get("user-agent", "unknown")
        )

        try:
            response = await call_next(request)
        except Exception as e:
            logger.error(
                "Error in request validation middleware",
                error=str(e),
                path=request.url.path,
                method=request.method,
                client_ip=client_ip
            )
            # Re-raise the original error rather than masking it.
            raise

        # Log response details
        process_time = time.time() - request.state.start_time
        logger.info(
            "Request completed",
            method=request.method,
            path=request.url.path,
            status_code=response.status_code,
            process_time=process_time,
            client_ip=client_ip
        )

        return response
181
+
182
def _split_env_list(raw: str) -> list:
    """Split a comma-separated setting into trimmed, non-empty entries."""
    return [item.strip() for item in raw.split(",") if item.strip()]


def setup_security_middleware(app: FastAPI):
    """Setup all security middleware for the FastAPI application.

    Reads ``ALLOWED_HOSTS`` and ``CORS_ORIGINS`` (comma-separated, default
    ``*``) from the environment. Entries are trimmed, so
    ``"a.example.com, b.example.com"`` works as expected.

    Args:
        app: Application to attach the rate limiter and middleware to.
    """
    # Rate limiting setup
    if ENABLE_RATE_LIMITING:
        app.state.limiter = limiter
        app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
        logger.info("Rate limiting enabled", redis_backend=redis_client is not None)

    # Trusted hosts (production should specify allowed hosts). A wildcard
    # entry allows every host, so the middleware is only added without one.
    allowed_hosts = _split_env_list(os.getenv("ALLOWED_HOSTS", "*")) or ["*"]
    if "*" not in allowed_hosts:
        app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed_hosts)
        logger.info("Trusted hosts middleware enabled", allowed_hosts=allowed_hosts)

    # CORS configuration. Credentials are only allowed for an explicit origin
    # list: combining them with a wildcard would let any site make
    # credentialed cross-origin requests.
    cors_origins = _split_env_list(os.getenv("CORS_ORIGINS", "*")) or ["*"]
    allow_credentials = "*" not in cors_origins
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        allow_credentials=allow_credentials,
        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        allow_headers=["*"],
        expose_headers=["X-Process-Time"]
    )
    logger.info("CORS middleware enabled", origins=cors_origins)

    # Custom security middleware. Middleware added later wraps middleware
    # added earlier, so request validation runs first and records the start
    # time that the headers middleware reads.
    app.add_middleware(SecurityHeadersMiddleware)
    app.add_middleware(RequestValidationMiddleware)

    logger.info("Security middleware setup completed")
214
+
215
def get_rate_limiter():
    """Return the module-wide limiter instance."""
    return limiter
218
+
219
+ # Rate limiting decorators for different use cases
220
def rate_limit_standard():
    """Build the default per-minute limit decorator for general endpoints."""
    limit_spec = f"{RATE_LIMIT_PER_MINUTE}/minute"
    return limiter.limit(limit_spec)
223
+
224
def rate_limit_heavy():
    """Heavy rate limit for resource-intensive operations.

    Uses 20% of the standard per-minute limit, but never less than one
    request per minute: with a standard limit below 5 the integer division
    alone would yield a limit of zero.
    """
    heavy_limit = max(1, int(RATE_LIMIT_PER_MINUTE) // 5)
    return limiter.limit(f"{heavy_limit}/minute")
228
+
229
def rate_limit_auth():
    """Build the fixed 10-per-minute limit decorator for authentication endpoints."""
    limit_spec = "10/minute"
    return limiter.limit(limit_spec)
232
+
233
+ # Security utilities
234
def sanitize_input(text: str) -> str:
    """HTML-escape a string; any non-string value is handed back unchanged."""
    return html.escape(text) if isinstance(text, str) else text
239
+
240
def validate_api_key_format(api_key: str) -> bool:
    """Check that a value looks like an API key.

    A well-formed key is a string that starts with ``isa_`` and is at least
    25 characters long. Only the shape is checked, not whether the key exists.
    """
    return (
        isinstance(api_key, str)
        and api_key.startswith("isa_")
        and len(api_key) >= 25
    )
254
+
255
def get_client_info(request: Request) -> Dict[str, Any]:
    """Collect the request attributes used for logging and monitoring.

    Returns a dict with the resolved client IP, selected request headers
    (None when absent, except user_agent which falls back to "unknown"),
    and the request's method, path and query string.
    """
    headers = request.headers
    url = request.url

    info: Dict[str, Any] = {}
    info["ip"] = get_remote_address_with_proxy(request)
    info["user_agent"] = headers.get("user-agent", "unknown")
    info["referer"] = headers.get("referer")
    info["forwarded_for"] = headers.get("x-forwarded-for")
    info["real_ip"] = headers.get("x-real-ip")
    info["method"] = request.method
    info["path"] = url.path
    info["query"] = url.query
    return info
267
+
268
+ # Health check for Redis connection
269
async def check_redis_health() -> Dict[str, Any]:
    """Report whether the rate-limiting Redis connection is usable.

    Returns a status dict: "disabled" when no client was created,
    "connected" when a ping succeeds, "error" (with the message) otherwise.
    """
    if not redis_client:
        return {"redis": "disabled", "status": "ok"}

    try:
        redis_client.ping()
    except Exception as exc:
        return {"redis": "error", "status": "error", "error": str(exc)}

    return {"redis": "connected", "status": "ok"}