zen-ai-pentest 2.2.0__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- api/auth.py +61 -7
- api/csrf_protection.py +286 -0
- api/main.py +77 -11
- api/rate_limiter.py +317 -0
- api/rate_limiter_v2.py +586 -0
- autonomous/ki_analysis_agent.py +1033 -0
- benchmarks/__init__.py +12 -142
- benchmarks/agent_performance.py +374 -0
- benchmarks/api_performance.py +479 -0
- benchmarks/scan_performance.py +272 -0
- modules/agent_coordinator.py +255 -0
- modules/api_key_manager.py +501 -0
- modules/benchmark.py +706 -0
- modules/cve_updater.py +303 -0
- modules/false_positive_filter.py +149 -0
- modules/output_formats.py +1088 -0
- modules/risk_scoring.py +206 -0
- {zen_ai_pentest-2.2.0.dist-info → zen_ai_pentest-2.3.0.dist-info}/METADATA +134 -289
- {zen_ai_pentest-2.2.0.dist-info → zen_ai_pentest-2.3.0.dist-info}/RECORD +23 -9
- {zen_ai_pentest-2.2.0.dist-info → zen_ai_pentest-2.3.0.dist-info}/WHEEL +0 -0
- {zen_ai_pentest-2.2.0.dist-info → zen_ai_pentest-2.3.0.dist-info}/entry_points.txt +0 -0
- {zen_ai_pentest-2.2.0.dist-info → zen_ai_pentest-2.3.0.dist-info}/licenses/LICENSE +0 -0
- {zen_ai_pentest-2.2.0.dist-info → zen_ai_pentest-2.3.0.dist-info}/top_level.txt +0 -0
api/rate_limiter_v2.py
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
1
|
+
"""
User-based rate limiting for the Zen-AI-Pentest API v2.

Features:
- IP-based limiting for anonymous users
- Account-based limiting for authenticated users (NOTE(review): the JWT
  lookup in get_user_from_request is still a placeholder, so callers are
  currently keyed by IP until it is wired to the real auth system)
- Different limits per user tier (anonymous, user, premium, admin)
- Redis support for distributed deployments
- Rate limit headers (Retry-After, X-RateLimit-*) on 429 responses
"""
|
|
11
|
+
|
|
12
|
+
import hashlib
import json
import logging
import math
import os
import time
from dataclasses import dataclass
from functools import wraps
from typing import Callable, Dict, Literal, Optional

from fastapi import Depends, HTTPException, Request, status
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
# =============================================================================
|
|
25
|
+
# Configuration
|
|
26
|
+
# =============================================================================
|
|
27
|
+
|
|
28
|
+
# Rate limits by user tier (requests per minute)
|
|
29
|
+
RATE_LIMITS = {
|
|
30
|
+
"anonymous": {
|
|
31
|
+
"requests_per_minute": int(os.getenv("RATE_LIMIT_ANON_RPM", "30")),
|
|
32
|
+
"burst_size": int(os.getenv("RATE_LIMIT_ANON_BURST", "5")),
|
|
33
|
+
"description": "Unauthenticated users"
|
|
34
|
+
},
|
|
35
|
+
"user": {
|
|
36
|
+
"requests_per_minute": int(os.getenv("RATE_LIMIT_USER_RPM", "60")),
|
|
37
|
+
"burst_size": int(os.getenv("RATE_LIMIT_USER_BURST", "10")),
|
|
38
|
+
"description": "Standard authenticated users"
|
|
39
|
+
},
|
|
40
|
+
"premium": {
|
|
41
|
+
"requests_per_minute": int(os.getenv("RATE_LIMIT_PREMIUM_RPM", "120")),
|
|
42
|
+
"burst_size": int(os.getenv("RATE_LIMIT_PREMIUM_BURST", "20")),
|
|
43
|
+
"description": "Premium users"
|
|
44
|
+
},
|
|
45
|
+
"admin": {
|
|
46
|
+
"requests_per_minute": int(os.getenv("RATE_LIMIT_ADMIN_RPM", "300")),
|
|
47
|
+
"burst_size": int(os.getenv("RATE_LIMIT_ADMIN_BURST", "50")),
|
|
48
|
+
"description": "Administrators"
|
|
49
|
+
}
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
# Auth endpoints - stricter limits
|
|
53
|
+
AUTH_RATE_LIMIT = int(os.getenv("AUTH_RATE_LIMIT", "5"))
|
|
54
|
+
AUTH_LOCKOUT_DURATION = int(os.getenv("AUTH_LOCKOUT_DURATION", "300")) # 5 minutes
|
|
55
|
+
|
|
56
|
+
# User tier detection (customize based on your auth system)
|
|
57
|
+
UserTier = Literal["anonymous", "user", "premium", "admin"]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# =============================================================================
# Token Bucket (no internal locking)
# =============================================================================

@dataclass
class TokenBucket:
    """Token bucket for rate limiting.

    Holds up to ``burst_size`` tokens and refills at ``rate`` tokens per
    second. Refill is computed lazily from wall-clock time (``time.time``)
    whenever the bucket is accessed. Wall-clock time keeps serialized state
    (``to_dict``/``from_dict``) meaningful across processes.

    There is no lock: a single bucket must not be mutated concurrently from
    multiple threads.
    """
    rate: float  # Tokens per second
    burst_size: int
    tokens: float = 0
    last_update: float = 0

    def __post_init__(self):
        # A fresh bucket (no last_update supplied) starts full.
        if self.last_update == 0:
            self.last_update = time.time()
            self.tokens = self.burst_size

    def _add_tokens(self):
        """Credit the tokens earned since ``last_update``, capped at ``burst_size``."""
        now = time.time()
        # Clamp to zero: if the clock stepped backwards, or last_update was
        # written by a host whose clock runs ahead, a negative interval would
        # drain the bucket (even below zero) and cause spurious 429s.
        time_passed = max(0.0, now - self.last_update)
        self.tokens = min(self.burst_size, self.tokens + time_passed * self.rate)
        self.last_update = now

    def consume(self, tokens: int = 1) -> bool:
        """Take *tokens* from the bucket; return False (taking nothing) if too few remain."""
        self._add_tokens()
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

    def get_wait_time(self, tokens: int = 1) -> float:
        """Return the seconds until *tokens* are available (0 if available now).

        Raises ZeroDivisionError if ``rate`` is 0 and tokens are missing.
        """
        self._add_tokens()
        if self.tokens >= tokens:
            return 0
        tokens_needed = tokens - self.tokens
        return tokens_needed / self.rate

    def to_dict(self) -> dict:
        """Serialize the bucket (after refilling) to a JSON-compatible dict."""
        self._add_tokens()
        return {
            "tokens": self.tokens,
            "burst_size": self.burst_size,
            "rate": self.rate,
            "last_update": self.last_update
        }

    @classmethod
    def from_dict(cls, data: dict) -> "TokenBucket":
        """Rebuild a bucket from ``to_dict`` output; raises KeyError on missing fields."""
        return cls(
            rate=data["rate"],
            burst_size=data["burst_size"],
            tokens=data["tokens"],
            last_update=data["last_update"]
        )
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# =============================================================================
|
|
119
|
+
# Storage Backends
|
|
120
|
+
# =============================================================================
|
|
121
|
+
|
|
122
|
+
class RateLimitStorage:
    """Interface for rate-limit bucket storage backends.

    Subclasses must implement ``get_bucket`` and ``save_bucket``.
    ``cleanup`` is an optional maintenance hook that does nothing here.
    """

    def get_bucket(self, key: str, rate: float, burst_size: int) -> TokenBucket:
        """Return the bucket stored under *key*.

        Subclasses create one with the given limits when absent.
        """
        raise NotImplementedError

    def save_bucket(self, key: str, bucket: TokenBucket):
        """Persist *bucket* under *key*."""
        raise NotImplementedError

    def cleanup(self):
        """Drop stale entries; the base implementation is a no-op."""
        return None
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class MemoryStorage(RateLimitStorage):
    """In-memory storage (single process only).

    Buckets live in plain dicts, so limits are per process and are lost on
    restart. Entries are only evicted when ``cleanup`` is called.
    """

    def __init__(self):
        self.buckets: Dict[str, TokenBucket] = {}
        self.last_access: Dict[str, float] = {}
        self.metadata: Dict[str, dict] = {}  # Store tier info

    def get_bucket(self, key: str, rate: float, burst_size: int) -> TokenBucket:
        """Return the bucket for *key*, creating it or applying the requested limits."""
        bucket = self.buckets.get(key)
        if bucket is None:
            bucket = TokenBucket(rate=rate, burst_size=burst_size)
            self.buckets[key] = bucket
        elif bucket.rate != rate or bucket.burst_size != burst_size:
            # The same key can be requested with different limits (tier change,
            # endpoint-specific limits). Previously the first limits stuck until
            # eviction. Settle tokens earned at the old rate, then switch.
            bucket._add_tokens()
            bucket.rate = rate
            bucket.burst_size = burst_size
            bucket.tokens = min(bucket.tokens, burst_size)

        self.last_access[key] = time.time()
        return bucket

    def save_bucket(self, key: str, bucket: TokenBucket):
        """Store *bucket* under *key* and refresh its last-access time."""
        self.buckets[key] = bucket
        self.last_access[key] = time.time()

    def set_metadata(self, key: str, metadata: dict):
        """Attach informational metadata (e.g. tier) to *key*."""
        self.metadata[key] = metadata

    def get_metadata(self, key: str) -> Optional[dict]:
        """Return the metadata for *key*, or None."""
        return self.metadata.get(key)

    def cleanup(self, max_age: float = 3600):
        """Evict buckets (and their metadata) not accessed for *max_age* seconds."""
        now = time.time()
        stale = [key for key, last in self.last_access.items() if now - last > max_age]
        for key in stale:
            self.buckets.pop(key, None)
            del self.last_access[key]
            self.metadata.pop(key, None)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class RedisStorage(RateLimitStorage):
    """Redis storage for distributed systems.

    Buckets are stored as JSON under ``rate_limit:<key>`` with a TTL. If
    Redis is unreachable at startup, or a Redis call fails at runtime, the
    storage uses a long-lived in-memory fallback. Previously it built a fresh
    MemoryStorage per call and never saved, which disabled rate limiting
    entirely.

    NOTE: get_bucket/save_bucket is a read-modify-write with no lock or Lua
    script. Concurrent requests for the same key on different workers can
    race and slightly over-admit.
    """

    def __init__(self, redis_url: Optional[str] = None, ttl: int = 3600):
        """
        Args:
            redis_url: connection URL; defaults to ``REDIS_URL`` or a local instance.
            ttl: seconds an idle bucket is kept in Redis (default 1 hour).
        """
        self.redis_url = redis_url or os.getenv("REDIS_URL", "redis://localhost:6379/0")
        self.ttl = ttl
        self._redis = None
        # Must persist across calls; otherwise every request gets a full bucket.
        self._fallback = MemoryStorage()
        self._init_redis()

    def _init_redis(self):
        """Connect and ping Redis; leave ``_redis`` as None on any failure."""
        try:
            import redis
            self._redis = redis.from_url(self.redis_url, decode_responses=True)
            self._redis.ping()
            logger.info("Redis connection established for rate limiting")
        except Exception as e:
            logger.warning(f"Redis not available, falling back to memory: {e}")
            self._redis = None

    def _get_key(self, key: str) -> str:
        """Namespace *key* for Redis."""
        return f"rate_limit:{key}"

    def get_bucket(self, key: str, rate: float, burst_size: int) -> TokenBucket:
        """Load the bucket for *key* (a new full bucket if none or unreadable)."""
        if not self._redis:
            return self._fallback.get_bucket(key, rate, burst_size)

        try:
            data = self._redis.get(self._get_key(key))
        except Exception as e:
            logger.warning(f"Redis read failed, using memory fallback for rate limiting: {e}")
            return self._fallback.get_bucket(key, rate, burst_size)

        if data:
            try:
                bucket = TokenBucket.from_dict(json.loads(data))
            except (json.JSONDecodeError, KeyError, TypeError):
                logger.warning(f"Discarding unreadable rate limit bucket for {key}")
            else:
                if bucket.rate != rate or bucket.burst_size != burst_size:
                    # Apply the currently requested limits (tier or endpoint
                    # changes) after settling tokens earned at the old rate.
                    bucket._add_tokens()
                    bucket.rate = rate
                    bucket.burst_size = burst_size
                    bucket.tokens = min(bucket.tokens, burst_size)
                return bucket

        return TokenBucket(rate=rate, burst_size=burst_size)

    def save_bucket(self, key: str, bucket: TokenBucket):
        """Store *bucket* in Redis with the configured TTL (or in the fallback)."""
        if not self._redis:
            self._fallback.save_bucket(key, bucket)
            return

        try:
            self._redis.setex(self._get_key(key), self.ttl, json.dumps(bucket.to_dict()))
        except Exception as e:
            logger.warning(f"Redis write failed, using memory fallback for rate limiting: {e}")
            self._fallback.save_bucket(key, bucket)

    def cleanup(self):
        """Evict stale fallback entries (Redis keys expire via their TTL)."""
        self._fallback.cleanup()
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# Global storage instance, chosen via RATE_LIMIT_STORAGE ("memory" by default, or "redis")
storage_backend = os.getenv("RATE_LIMIT_STORAGE", "memory")
if storage_backend != "redis":
    rate_limit_storage = MemoryStorage()
else:
    try:
        rate_limit_storage = RedisStorage()
    except Exception as e:
        logger.warning(f"Failed to initialize Redis storage: {e}")
        rate_limit_storage = MemoryStorage()
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# =============================================================================
|
|
234
|
+
# User Context
|
|
235
|
+
# =============================================================================
|
|
236
|
+
|
|
237
|
+
@dataclass
class UserContext:
    """Caller identity used to select a rate-limit bucket and its limits."""
    user_id: Optional[str] = None
    username: Optional[str] = None
    tier: UserTier = "anonymous"
    ip_address: str = "unknown"

    def get_rate_limit_key(self) -> str:
        """Return the bucket key: ``user:<id>`` if a user ID is set, else ``anon:<ip hash>``."""
        if not self.user_id:
            # Anonymous callers are keyed by the first 16 hex chars of the
            # SHA-256 of their IP address (the user agent is not included).
            ip_digest = hashlib.sha256(self.ip_address.encode()).hexdigest()
            return f"anon:{ip_digest[:16]}"
        return f"user:{self.user_id}"

    def get_limits(self) -> dict:
        """Return the RATE_LIMITS entry for this tier, defaulting to the anonymous one."""
        anonymous_limits = RATE_LIMITS["anonymous"]
        return RATE_LIMITS.get(self.tier, anonymous_limits)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def get_user_from_request(request: Request) -> UserContext:
    """
    Build the rate-limiting UserContext for an incoming request.

    Customize this based on your auth system! As written, no user identity
    is extracted because the JWT decoding is still a placeholder. Every
    caller is therefore keyed by client IP.

    The tier is read from the ``X-User-Tier`` header only when
    ``RATE_LIMIT_TRUST_TIER_HEADER`` is "1", "true" or "yes". That header is
    client-controlled, so trust it only when a gateway in front of the API
    sets or strips it. Otherwise any anonymous client could claim the
    "admin" tier and receive 10x the anonymous limits.
    """
    # NOTE(review): behind a reverse proxy this is the proxy's address.
    client_ip = request.client.host if request.client else "unknown"

    user_id: Optional[str] = None
    username: Optional[str] = None
    tier: UserTier = "anonymous"

    auth_header = request.headers.get("authorization", "")
    if auth_header.startswith("Bearer "):
        # TODO: verify the token and derive identity/tier from its claims, e.g.
        #   user_id = decode_jwt(auth_header[7:]).get("sub")
        #   tier = get_user_tier(user_id)
        pass

    trust_tier_header = os.getenv("RATE_LIMIT_TRUST_TIER_HEADER", "").strip().lower() in (
        "1", "true", "yes"
    )
    if trust_tier_header:
        user_tier_header = request.headers.get("x-user-tier")
        if user_tier_header in RATE_LIMITS:
            tier = user_tier_header  # type: ignore

    return UserContext(
        user_id=user_id,
        username=username,
        tier=tier,
        ip_address=client_ip
    )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
# =============================================================================
|
|
298
|
+
# Rate Limiting Decorator
|
|
299
|
+
# =============================================================================
|
|
300
|
+
|
|
301
|
+
def rate_limit(
    requests_per_minute: Optional[int] = None,
    burst_size: Optional[int] = None,
    tier: Optional[UserTier] = None
):
    """
    Rate limiting decorator with user-based limits (``async def`` endpoints only).

    Limits are chosen in this order:
    1. the explicit ``tier``'s limits;
    2. a custom ``requests_per_minute``, with ``burst_size`` defaulting to 10;
    3. the caller's own tier, as reported by ``get_user_from_request``.

    The bucket key is per user/IP, not per endpoint, so decorated endpoints
    share one bucket per caller.

    The endpoint must declare a ``Request`` parameter. If none is passed, the
    call goes through without limiting and a warning is logged.

    Raises:
        HTTPException: 429 with ``Retry-After`` when the bucket is empty.

    Usage:
        @app.get("/api/data")
        @rate_limit()  # Uses user's tier limits
        async def get_data(request: Request):
            return {"data": "value"}

        @app.get("/api/admin")
        @rate_limit(requests_per_minute=600)  # Custom limit
        async def admin_endpoint(request: Request):
            return {"admin": "data"}
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # FastAPI calls endpoints with keyword arguments only. Scanning just
            # *args never found the Request, which silently disabled limiting.
            request = next(
                (arg for arg in (*args, *kwargs.values()) if isinstance(arg, Request)),
                None,
            )
            if request is None:
                logger.warning(
                    "rate_limit: no Request argument passed to %s; rate limit not applied",
                    getattr(func, "__qualname__", func),
                )
                return await func(*args, **kwargs)

            user = get_user_from_request(request)

            # Determine limits
            if tier:
                limits = RATE_LIMITS[tier]
            elif requests_per_minute:
                limits = {
                    "requests_per_minute": requests_per_minute,
                    "burst_size": burst_size or 10
                }
            else:
                limits = user.get_limits()

            key = user.get_rate_limit_key()
            rate_per_second = limits["requests_per_minute"] / 60
            bucket = rate_limit_storage.get_bucket(
                key, rate_per_second, limits["burst_size"]
            )

            # Record tier info for get_rate_limit_stats (memory backend only)
            if isinstance(rate_limit_storage, MemoryStorage):
                rate_limit_storage.set_metadata(key, {
                    "tier": user.tier,
                    "user_id": user.user_id,
                    "ip": user.ip_address
                })

            if not bucket.consume():
                # Round up: truncating advertised "Retry-After: 0" while the
                # bucket still needed a fraction of a second to refill.
                retry_after = max(1, math.ceil(bucket.get_wait_time()))
                logger.warning(
                    f"Rate limit exceeded for {user.tier} user {user.user_id or user.ip_address}"
                )

                raise HTTPException(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    detail={
                        "error": "Rate limit exceeded",
                        "retry_after": retry_after,
                        "tier": user.tier,
                        "limit": limits["requests_per_minute"]
                    },
                    headers={
                        "Retry-After": str(retry_after),
                        "X-RateLimit-Limit": str(limits["requests_per_minute"]),
                        "X-RateLimit-Remaining": "0",
                        "X-RateLimit-Tier": user.tier
                    }
                )

            # Persist the consumed token before running the endpoint
            rate_limit_storage.save_bucket(key, bucket)

            return await func(*args, **kwargs)

        return wrapper
    return decorator
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
# =============================================================================
|
|
398
|
+
# Middleware für globales Rate Limiting
|
|
399
|
+
# =============================================================================
|
|
400
|
+
|
|
401
|
+
class UserRateLimitMiddleware:
    """
    ASGI middleware for per-client rate limiting.

    Each request is counted against a bucket keyed by tier and client IP.
    The tier comes from the ``X-User-Tier`` header only when trusted, i.e.
    ``trust_tier_header=True`` or ``RATE_LIMIT_TRUST_TIER_HEADER`` is set to
    "1", "true" or "yes". The header is client-controlled, so by default every
    request gets anonymous limits. This prevents clients from claiming
    "admin" limits, or from cycling tiers to obtain extra buckets.

    Usage:
        app.add_middleware(UserRateLimitMiddleware)
    """

    def __init__(self, app, trust_tier_header: Optional[bool] = None):
        self.app = app
        if trust_tier_header is None:
            trust_tier_header = os.getenv(
                "RATE_LIMIT_TRUST_TIER_HEADER", ""
            ).strip().lower() in ("1", "true", "yes")
        self.trust_tier_header = trust_tier_header

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        client = scope.get("client")
        client_ip = client[0] if client else "unknown"

        user_tier = "anonymous"
        if self.trust_tier_header:
            headers = dict(scope.get("headers", []))
            # ASGI header values are bytes; latin-1 decodes any byte sequence,
            # whereas .decode() (UTF-8) raised on odd input and produced a 500.
            claimed_tier = headers.get(b"x-user-tier", b"anonymous").decode("latin-1")
            if claimed_tier in RATE_LIMITS:
                user_tier = claimed_tier

        limits = RATE_LIMITS[user_tier]

        key = f"middleware:{user_tier}:{client_ip}"
        rate_per_second = limits["requests_per_minute"] / 60

        bucket = rate_limit_storage.get_bucket(
            key, rate_per_second, limits["burst_size"]
        )

        if not bucket.consume():
            # Round up so clients are never told to retry after 0 seconds.
            retry_after = max(1, math.ceil(bucket.get_wait_time()))
            body = json.dumps({
                "error": "Rate limit exceeded",
                "retry_after": retry_after,
                "tier": user_tier
            }).encode()

            await send({
                "type": "http.response.start",
                "status": 429,
                "headers": [
                    [b"content-type", b"application/json"],
                    [b"content-length", str(len(body)).encode()],
                    [b"retry-after", str(retry_after).encode()],
                    [b"x-ratelimit-limit", str(limits["requests_per_minute"]).encode()],
                    [b"x-ratelimit-remaining", b"0"],
                    [b"x-ratelimit-tier", user_tier.encode()]
                ]
            })
            await send({
                "type": "http.response.body",
                "body": body
            })
            return

        rate_limit_storage.save_bucket(key, bucket)
        await self.app(scope, receive, send)
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
# =============================================================================
|
|
465
|
+
# Auth Rate Limiting (mit User-ID Tracking)
|
|
466
|
+
# =============================================================================
|
|
467
|
+
|
|
468
|
+
class UserAuthRateLimiter:
    """
    Auth rate limiting that tracks failed attempts per IP and per user ID.

    Tracking both covers one IP attacking many accounts and one account
    attacked from many IPs. A key that collects ``max_attempts`` failures
    within ``window`` seconds is locked out until ``lockout_duration``
    seconds after the oldest of those failures.

    State is held in plain dicts, so it is per process and is not shared
    across workers.
    """

    def __init__(
        self,
        max_attempts: Optional[int] = None,
        lockout_duration: Optional[int] = None,
        window: float = 60,
    ):
        """
        Args:
            max_attempts: failures within ``window`` that trigger a lockout.
                Defaults to the ``AUTH_RATE_LIMIT`` setting; previously it was
                hard-coded to 5 and the setting was ignored.
            lockout_duration: lockout length in seconds, measured from the
                oldest counted failure. Defaults to ``AUTH_LOCKOUT_DURATION``.
            window: sliding window in seconds in which failures are counted.
        """
        self.ip_attempts: Dict[str, list] = {}
        self.user_attempts: Dict[str, list] = {}
        # Absolute lockout deadlines, kept apart from the attempt lists so a
        # lockout outlasts the counting window. Previously it silently ended
        # once the attempts aged out of the 60 s window.
        self.ip_lockouts: Dict[str, float] = {}
        self.user_lockouts: Dict[str, float] = {}
        self.lockout_duration = (
            AUTH_LOCKOUT_DURATION if lockout_duration is None else lockout_duration
        )
        self.max_attempts = AUTH_RATE_LIMIT if max_attempts is None else max_attempts
        self.window = window
        # Time source (replaceable in tests).
        self._clock = time.time

    def _cleanup_old(self, attempts: list, window: Optional[float] = None) -> list:
        """Return the attempt timestamps that fall inside the counting window."""
        if window is None:
            window = self.window
        now = self._clock()
        return [t for t in attempts if now - t < window]

    def _check_key(
        self,
        key: str,
        attempts: Dict[str, list],
        lockouts: Dict[str, float],
        now: float,
    ) -> Optional[int]:
        """Return the remaining lockout seconds (rounded up) for *key*, or None if not locked."""
        locked_until = lockouts.get(key)
        if locked_until is not None:
            if now < locked_until:
                return math.ceil(locked_until - now)
            del lockouts[key]

        recent = [t for t in attempts.get(key, []) if now - t < self.window]
        if len(recent) >= self.max_attempts:
            locked_until = min(recent) + self.lockout_duration
            if now < locked_until:
                lockouts[key] = locked_until
                # Counting restarts from zero once the lockout expires.
                attempts.pop(key, None)
                return math.ceil(locked_until - now)
            recent = []

        # Drop empty lists so every IP that ever hit login is not kept forever.
        if recent:
            attempts[key] = recent
        else:
            attempts.pop(key, None)
        return None

    def is_allowed(self, client_ip: str, user_id: Optional[str] = None) -> tuple[bool, Optional[int], str]:
        """
        Check whether an auth attempt is allowed.

        Returns: (allowed, lockout_seconds, reason).
        ``reason`` is "ok", "ip_blocked" or "user_blocked"; ``lockout_seconds``
        is None when allowed. The IP is checked before the user ID.
        """
        now = self._clock()

        lockout = self._check_key(client_ip, self.ip_attempts, self.ip_lockouts, now)
        if lockout is not None:
            return False, lockout, "ip_blocked"

        if user_id:
            lockout = self._check_key(user_id, self.user_attempts, self.user_lockouts, now)
            if lockout is not None:
                return False, lockout, "user_blocked"

        return True, None, "ok"

    def record_failure(self, client_ip: str, user_id: Optional[str] = None):
        """Record a failed auth attempt for the IP and, if given, the user ID."""
        now = self._clock()
        self.ip_attempts.setdefault(client_ip, []).append(now)
        if user_id:
            self.user_attempts.setdefault(user_id, []).append(now)

    def record_success(self, client_ip: str, user_id: Optional[str] = None):
        """Clear failure history and any lockout for the IP and, if given, the user ID."""
        self.ip_attempts.pop(client_ip, None)
        self.ip_lockouts.pop(client_ip, None)
        if user_id:
            self.user_attempts.pop(user_id, None)
            self.user_lockouts.pop(user_id, None)
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
# Global instance used by check_user_auth_rate_limit(); its attempt state
# lives in the limiter's in-process dicts (not shared across workers).
user_auth_rate_limiter = UserAuthRateLimiter()
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def check_user_auth_rate_limit(client_ip: str, user_id: Optional[str] = None):
    """
    Raise HTTP 429 if the IP or user is currently locked out of authentication.

    The decision is delegated to the module-level ``user_auth_rate_limiter``.
    Returns None when the attempt may proceed.
    """
    allowed, lockout, reason = user_auth_rate_limiter.is_allowed(client_ip, user_id)
    if allowed:
        return

    logger.warning(f"Auth rate limit exceeded: {reason} for {client_ip}/{user_id}")
    error_detail = {
        "error": "Too many login attempts",
        "retry_after": lockout,
        "reason": reason
    }
    raise HTTPException(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        detail=error_detail,
        headers={"Retry-After": str(lockout)},
    )
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
# =============================================================================
|
|
561
|
+
# Stats und Monitoring
|
|
562
|
+
# =============================================================================
|
|
563
|
+
|
|
564
|
+
def get_rate_limit_stats() -> dict:
    """Return current rate-limiting statistics (memory storage only).

    Tier counts come from the metadata recorded by the ``rate_limit``
    decorator. Buckets created by the middleware therefore appear in
    ``total_buckets`` but not in ``by_tier``. Other backends get an error dict.
    """
    if isinstance(rate_limit_storage, MemoryStorage):
        tier_counts = dict.fromkeys(("anonymous", "user", "premium", "admin", "unknown"), 0)
        for meta in rate_limit_storage.metadata.values():
            tier_name = meta.get("tier", "unknown")
            tier_counts[tier_name] = tier_counts.get(tier_name, 0) + 1
        return {
            "total_buckets": len(rate_limit_storage.buckets),
            "by_tier": tier_counts,
        }
    return {"error": "Stats only available with memory storage"}
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
# Cleanup job (optional): call periodically to bound memory use
def cleanup_rate_limits():
    """Evict stale buckets from in-memory storage; other backends are left untouched."""
    if not isinstance(rate_limit_storage, MemoryStorage):
        return
    rate_limit_storage.cleanup()
    logger.info("Rate limit storage cleaned up")
|