zen-ai-pentest 2.2.0__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
api/rate_limiter_v2.py ADDED
@@ -0,0 +1,586 @@
1
+ """
2
+ User-based Rate Limiting für Zen-AI-Pentest API v2
3
+
4
+ Features:
5
+ - IP-basiertes Limiting für anonyme User
6
+ - Account-basiertes Limiting für authentifizierte User
7
+ - Verschiedene Limits je nach User-Tier (anonymous, user, admin)
8
+ - Redis-Unterstützung für verteilte Systeme
9
+ - Detaillierte Rate Limit Headers
10
+ """
11
+
12
import hashlib
import json
import logging
import math
import os
import time
from dataclasses import dataclass
from functools import wraps
from typing import Callable, Dict, Literal, Optional

from fastapi import Depends, HTTPException, Request, status
21
+
22
logger = logging.getLogger(__name__)

# =============================================================================
# Configuration
# =============================================================================


def _env_int(name: str, default: str) -> int:
    """Return environment variable *name* parsed as an int (``default`` if unset)."""
    return int(os.getenv(name, default))


# Token-bucket parameters per user tier: ``requests_per_minute`` is the
# sustained refill rate, ``burst_size`` the bucket capacity. Every number can
# be overridden through the environment variable named next to it.
RATE_LIMITS = {
    "anonymous": {
        "requests_per_minute": _env_int("RATE_LIMIT_ANON_RPM", "30"),
        "burst_size": _env_int("RATE_LIMIT_ANON_BURST", "5"),
        "description": "Unauthenticated users",
    },
    "user": {
        "requests_per_minute": _env_int("RATE_LIMIT_USER_RPM", "60"),
        "burst_size": _env_int("RATE_LIMIT_USER_BURST", "10"),
        "description": "Standard authenticated users",
    },
    "premium": {
        "requests_per_minute": _env_int("RATE_LIMIT_PREMIUM_RPM", "120"),
        "burst_size": _env_int("RATE_LIMIT_PREMIUM_BURST", "20"),
        "description": "Premium users",
    },
    "admin": {
        "requests_per_minute": _env_int("RATE_LIMIT_ADMIN_RPM", "300"),
        "burst_size": _env_int("RATE_LIMIT_ADMIN_BURST", "50"),
        "description": "Administrators",
    },
}

# Stricter settings for authentication endpoints.
AUTH_RATE_LIMIT = _env_int("AUTH_RATE_LIMIT", "5")
AUTH_LOCKOUT_DURATION = _env_int("AUTH_LOCKOUT_DURATION", "300")  # seconds (5 minutes)

# Tier names used as keys of RATE_LIMITS (adapt tier detection to your auth system).
UserTier = Literal["anonymous", "user", "premium", "admin"]
58
+
59
+
60
+ # =============================================================================
61
+ # Token Bucket (no internal locking — not thread-safe)
62
+ # =============================================================================
63
+
64
@dataclass
class TokenBucket:
    """Token bucket used for rate limiting.

    The bucket holds at most ``burst_size`` tokens and refills continuously at
    ``rate`` tokens per second; every request consumes tokens. Timestamps come
    from ``time.time()`` (wall clock) so a bucket can be serialized with
    ``to_dict``/``from_dict`` and shared between processes.

    The class performs no locking of its own.
    """

    rate: float  # refill rate in tokens per second
    burst_size: int  # bucket capacity (maximum number of tokens)
    tokens: float = 0  # tokens currently available
    last_update: float = 0  # time.time() of the last refill; 0 means "new bucket"

    def __post_init__(self):
        # A bucket created without a timestamp starts out full.
        if self.last_update == 0:
            self.last_update = time.time()
            self.tokens = self.burst_size

    def _add_tokens(self):
        """Refill the bucket for the time elapsed since ``last_update``."""
        now = time.time()
        # Clamp at zero: if the wall clock stepped backwards, or the bucket was
        # last written by another host whose clock runs ahead, a negative
        # interval would otherwise *remove* tokens from the bucket.
        elapsed = max(0.0, now - self.last_update)
        self.tokens = min(self.burst_size, self.tokens + elapsed * self.rate)
        self.last_update = now

    def consume(self, tokens: int = 1) -> bool:
        """Take *tokens* from the bucket; return False (and take nothing) if too few."""
        self._add_tokens()
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

    def get_wait_time(self, tokens: int = 1) -> float:
        """Return the seconds until *tokens* will be available (0 if available now)."""
        self._add_tokens()
        if self.tokens >= tokens:
            return 0
        tokens_needed = tokens - self.tokens
        return tokens_needed / self.rate

    def to_dict(self) -> dict:
        """Serialize the bucket (after refilling) into a JSON-compatible dict."""
        self._add_tokens()
        return {
            "tokens": self.tokens,
            "burst_size": self.burst_size,
            "rate": self.rate,
            "last_update": self.last_update,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "TokenBucket":
        """Rebuild a bucket from the output of :meth:`to_dict`."""
        return cls(
            rate=data["rate"],
            burst_size=data["burst_size"],
            tokens=data["tokens"],
            last_update=data["last_update"],
        )
116
+
117
+
118
+ # =============================================================================
119
+ # Storage Backends
120
+ # =============================================================================
121
+
122
class RateLimitStorage:
    """Interface for token-bucket stores; concrete backends override the hooks."""

    def get_bucket(self, key: str, rate: float, burst_size: int) -> TokenBucket:
        """Return the bucket stored under *key* (must be overridden)."""
        raise NotImplementedError

    def save_bucket(self, key: str, bucket: TokenBucket):
        """Persist *bucket* under *key* (must be overridden)."""
        raise NotImplementedError

    def cleanup(self):
        """Drop stale state; does nothing unless a backend overrides it."""
        return None
133
+
134
+
135
class MemoryStorage(RateLimitStorage):
    """In-memory bucket store; state lives only in this object (single instance only)."""

    def __init__(self):
        self.buckets: Dict[str, TokenBucket] = {}
        # key -> time.time() of the last get_bucket/save_bucket for that key
        self.last_access: Dict[str, float] = {}
        # key -> arbitrary info dict (e.g. tier) attached via set_metadata
        self.metadata: Dict[str, dict] = {}

    def get_bucket(self, key: str, rate: float, burst_size: int) -> TokenBucket:
        """Return the bucket for *key*, creating a full one if it does not exist.

        An existing bucket is re-configured to the requested ``rate`` and
        ``burst_size``. Without this, whichever limit first created the bucket
        would apply forever (e.g. a user whose tier changed kept the old limit).
        """
        bucket = self.buckets.get(key)
        if bucket is None:
            bucket = TokenBucket(rate=rate, burst_size=burst_size)
            self.buckets[key] = bucket
        elif bucket.rate != rate or bucket.burst_size != burst_size:
            bucket.rate = rate
            bucket.burst_size = burst_size
            bucket.tokens = min(bucket.tokens, burst_size)

        self.last_access[key] = time.time()
        return bucket

    def save_bucket(self, key: str, bucket: TokenBucket):
        """Store *bucket* under *key* and mark the key as recently used."""
        self.buckets[key] = bucket
        self.last_access[key] = time.time()

    def set_metadata(self, key: str, metadata: dict):
        """Attach *metadata* to *key* (used by get_rate_limit_stats)."""
        self.metadata[key] = metadata

    def get_metadata(self, key: str) -> Optional[dict]:
        """Return the metadata attached to *key*, or None."""
        return self.metadata.get(key)

    def cleanup(self, max_age: float = 3600):
        """Forget every key that has not been accessed for more than *max_age* seconds."""
        cutoff = time.time() - max_age
        stale = [key for key, last in self.last_access.items() if last < cutoff]
        for key in stale:
            self.buckets.pop(key, None)
            del self.last_access[key]
            self.metadata.pop(key, None)
171
+
172
+
173
class RedisStorage(RateLimitStorage):
    """Redis-backed bucket store for distributed systems.

    If Redis cannot be reached at start-up, or a Redis call fails later, the
    store falls back to a single persistent in-process MemoryStorage. Limits
    therefore keep being enforced per process instead of being disabled.

    NOTE(review): the read -> modify -> setex sequence is not atomic, so
    concurrent requests for the same key on different workers can overwrite
    each other's updates. Strict accuracy would need a Lua script or a
    WATCH/MULTI transaction.
    """

    def __init__(self, redis_url: Optional[str] = None):
        self.redis_url = redis_url or os.getenv("REDIS_URL", "redis://localhost:6379/0")
        self._redis = None
        # Exception types treated as "Redis unavailable"; set once the redis
        # package is importable (an empty tuple catches nothing).
        self._redis_errors: tuple = ()
        # Created once and reused. The previous per-call MemoryStorage() handed
        # out a brand-new full bucket on every request, i.e. no limiting at all.
        self._fallback = MemoryStorage()
        self._init_redis()

    def _init_redis(self):
        """Connect to Redis; on any failure stay in memory-fallback mode."""
        try:
            import redis
            self._redis_errors = (redis.RedisError, OSError)
            self._redis = redis.from_url(self.redis_url, decode_responses=True)
            self._redis.ping()
            logger.info("Redis connection established for rate limiting")
        except Exception as e:
            logger.warning(f"Redis not available, falling back to memory: {e}")
            self._redis = None

    def _get_key(self, key: str) -> str:
        """Namespace *key* for storage in Redis."""
        return f"rate_limit:{key}"

    def get_bucket(self, key: str, rate: float, burst_size: int) -> TokenBucket:
        """Load the bucket for *key* (a fresh full one if missing or unreadable).

        The requested ``rate``/``burst_size`` always apply, even to a stored
        bucket, so changed limits take effect immediately.
        """
        if not self._redis:
            return self._fallback.get_bucket(key, rate, burst_size)

        try:
            data = self._redis.get(self._get_key(key))
        except self._redis_errors as e:
            logger.warning(f"Redis read failed, using in-memory rate limit: {e}")
            return self._fallback.get_bucket(key, rate, burst_size)

        if data:
            try:
                bucket = TokenBucket.from_dict(json.loads(data))
            except (json.JSONDecodeError, KeyError, TypeError):
                bucket = None  # corrupt entry: start over with a fresh bucket
            if bucket is not None:
                bucket.rate = rate
                bucket.burst_size = burst_size
                bucket.tokens = min(bucket.tokens, burst_size)
                return bucket

        return TokenBucket(rate=rate, burst_size=burst_size)

    def save_bucket(self, key: str, bucket: TokenBucket):
        """Persist *bucket* with a 1-hour TTL (in memory when Redis is unavailable)."""
        if not self._redis:
            self._fallback.save_bucket(key, bucket)
            return

        data = json.dumps(bucket.to_dict())
        try:
            # TTL: 1 hour
            self._redis.setex(self._get_key(key), 3600, data)
        except self._redis_errors as e:
            logger.warning(f"Redis write failed, using in-memory rate limit: {e}")
            self._fallback.save_bucket(key, bucket)
219
+
220
+
221
# Global storage instance: RATE_LIMIT_STORAGE=redis selects Redis; any other
# value (default "memory") keeps buckets inside this process.
storage_backend = os.getenv("RATE_LIMIT_STORAGE", "memory")
if storage_backend != "redis":
    rate_limit_storage = MemoryStorage()
else:
    try:
        rate_limit_storage = RedisStorage()
    except Exception as e:
        logger.warning(f"Failed to initialize Redis storage: {e}")
        rate_limit_storage = MemoryStorage()
231
+
232
+
233
+ # =============================================================================
234
+ # User Context
235
+ # =============================================================================
236
+
237
@dataclass
class UserContext:
    """User information for rate limiting."""

    user_id: Optional[str] = None
    username: Optional[str] = None
    tier: UserTier = "anonymous"
    ip_address: str = "unknown"

    def get_rate_limit_key(self) -> str:
        """Return the bucket key: per account when ``user_id`` is set, else per IP.

        Anonymous callers are keyed by the first 16 hex digits of the SHA-256
        of their IP address only (no user agent or other request data). The
        raw IP therefore never appears in the storage key.
        """
        if self.user_id:
            return f"user:{self.user_id}"
        ip_hash = hashlib.sha256(self.ip_address.encode()).hexdigest()[:16]
        return f"anon:{ip_hash}"

    def get_limits(self) -> dict:
        """Return RATE_LIMITS for this user's tier (anonymous limits for unknown tiers)."""
        return RATE_LIMITS.get(self.tier, RATE_LIMITS["anonymous"])
256
+
257
+
258
def get_user_from_request(request: Request) -> UserContext:
    """
    Extract user information from request.

    Customize this based on your auth system! Authentication is not wired in
    yet, so ``user_id``/``username`` are always ``None`` and callers are
    keyed by client IP.

    The ``x-user-tier`` header is supplied by the client, so it is honoured
    only when the environment variable ``RATE_LIMIT_TRUST_TIER_HEADER`` is
    ``1``/``true``/``yes``. Enable that only behind a proxy that sets the
    header itself and strips any client-supplied value. Otherwise any caller
    could send ``x-user-tier: admin`` and raise its own limit tenfold.
    """
    client_ip = request.client.host if request.client else "unknown"

    user_id: Optional[str] = None
    username: Optional[str] = None
    tier: UserTier = "anonymous"

    # TODO: integrate JWT/session auth here. For a verified
    # "Authorization: Bearer <token>" header, for example:
    #     user_id = decode_jwt(token).get("sub")
    #     tier = get_user_tier(user_id)

    trust_header = os.getenv("RATE_LIMIT_TRUST_TIER_HEADER", "").strip().lower()
    if trust_header in ("1", "true", "yes"):
        claimed_tier = request.headers.get("x-user-tier")
        if claimed_tier in RATE_LIMITS:
            tier = claimed_tier  # type: ignore

    return UserContext(
        user_id=user_id,
        username=username,
        tier=tier,
        ip_address=client_ip,
    )
295
+
296
+
297
+ # =============================================================================
298
+ # Rate Limiting Decorator
299
+ # =============================================================================
300
+
301
def rate_limit(
    requests_per_minute: Optional[int] = None,
    burst_size: Optional[int] = None,
    tier: Optional[UserTier] = None
):
    """
    Rate limiting decorator with user-based limits.

    Limit selection (first match wins):
      * ``tier`` given -> that tier's entry in RATE_LIMITS;
      * ``requests_per_minute`` given -> that rate with ``burst_size``
        (default 10);
      * otherwise -> the limits of the caller's own tier.

    Without an override, all decorated endpoints share one bucket per caller.
    An endpoint with an override gets its own bucket per caller, so
    differently configured endpoints do not fight over a single bucket.

    Over-limit calls raise ``HTTPException(429)`` carrying ``Retry-After`` and
    ``X-RateLimit-*`` headers. Calls without a ``Request`` argument are not
    limited.

    Usage:
        @app.get("/api/data")
        @rate_limit()  # Uses user's tier limits
        async def get_data(request: Request):
            return {"data": "value"}

        @app.get("/api/admin")
        @rate_limit(requests_per_minute=600)  # Custom limit
        async def admin_endpoint(request: Request):
            return {"admin": "data"}
    """
    has_override = bool(tier or requests_per_minute)

    def decorator(func: Callable) -> Callable:
        endpoint_id = (
            f"{getattr(func, '__module__', '')}."
            f"{getattr(func, '__qualname__', repr(func))}"
        )

        @wraps(func)
        async def wrapper(*args, **kwargs):
            # FastAPI calls endpoint functions with keyword arguments, so the
            # Request must also be searched in kwargs. Scanning only positional
            # args never finds it and silently disables rate limiting.
            request = next(
                (arg for arg in (*args, *kwargs.values()) if isinstance(arg, Request)),
                None,
            )
            if request is None:
                return await func(*args, **kwargs)

            user = get_user_from_request(request)

            if tier:
                limits = RATE_LIMITS[tier]
            elif requests_per_minute:
                limits = {
                    "requests_per_minute": requests_per_minute,
                    "burst_size": burst_size or 10,
                }
            else:
                limits = user.get_limits()

            key = user.get_rate_limit_key()
            if has_override:
                key = f"{key}:{endpoint_id}"
            rate_per_second = limits["requests_per_minute"] / 60
            bucket = rate_limit_storage.get_bucket(
                key, rate_per_second, limits["burst_size"]
            )

            # Metadata is only kept (and reported by stats) for the memory backend.
            if isinstance(rate_limit_storage, MemoryStorage):
                rate_limit_storage.set_metadata(key, {
                    "tier": user.tier,
                    "user_id": user.user_id,
                    "ip": user.ip_address,
                })

            if not bucket.consume():
                # Round up (at least 1 s): int() truncation could advertise
                # "Retry-After: 0" while the caller is still limited.
                retry_after = max(1, math.ceil(bucket.get_wait_time()))
                logger.warning(
                    f"Rate limit exceeded for {user.tier} user {user.user_id or user.ip_address}"
                )
                raise HTTPException(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    detail={
                        "error": "Rate limit exceeded",
                        "retry_after": retry_after,
                        "tier": user.tier,
                        "limit": limits["requests_per_minute"],
                    },
                    headers={
                        "Retry-After": str(retry_after),
                        "X-RateLimit-Limit": str(limits["requests_per_minute"]),
                        "X-RateLimit-Remaining": "0",
                        "X-RateLimit-Tier": user.tier,
                    },
                )

            rate_limit_storage.save_bucket(key, bucket)

            # NOTE: X-RateLimit-* headers are only attached to 429 responses;
            # successful responses are returned unchanged.
            return await func(*args, **kwargs)

        return wrapper
    return decorator
395
+
396
+
397
+ # =============================================================================
398
+ # Middleware für globales Rate Limiting
399
+ # =============================================================================
400
+
401
class UserRateLimitMiddleware:
    """
    ASGI middleware applying a per-client-IP, per-tier token bucket to HTTP requests.

    Usage:
        app.add_middleware(UserRateLimitMiddleware)

    The ``x-user-tier`` request header is honoured only when the environment
    variable ``RATE_LIMIT_TRUST_TIER_HEADER`` is ``1``/``true``/``yes``;
    otherwise every request is limited as ``anonymous``. Trusting the header
    unconditionally let clients claim ``admin``, or rotate tier values to get
    several buckets. Enable it only behind a proxy that sets the header and
    strips client-supplied values.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        # Only "http" scopes are limited; every other scope type passes through.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        client = scope.get("client")
        client_ip = client[0] if client else "unknown"

        user_tier = "anonymous"
        trust_header = os.getenv("RATE_LIMIT_TRUST_TIER_HEADER", "").strip().lower()
        if trust_header in ("1", "true", "yes"):
            headers = dict(scope.get("headers", []))
            # Header values are raw bytes; latin-1 decodes any byte sequence,
            # so a malformed header cannot raise UnicodeDecodeError here.
            claimed = headers.get(b"x-user-tier", b"anonymous").decode("latin-1")
            if claimed in RATE_LIMITS:
                user_tier = claimed

        limits = RATE_LIMITS[user_tier]
        key = f"middleware:{user_tier}:{client_ip}"
        rate_per_second = limits["requests_per_minute"] / 60
        bucket = rate_limit_storage.get_bucket(
            key, rate_per_second, limits["burst_size"]
        )

        if not bucket.consume():
            # Round up (at least 1 s) so clients never see "Retry-After: 0"
            # while they are still limited.
            retry_after = max(1, math.ceil(bucket.get_wait_time()))
            await send({
                "type": "http.response.start",
                "status": 429,
                "headers": [
                    [b"content-type", b"application/json"],
                    [b"retry-after", str(retry_after).encode()],
                    [b"x-ratelimit-tier", user_tier.encode()],
                ],
            })
            await send({
                "type": "http.response.body",
                "body": json.dumps({
                    "error": "Rate limit exceeded",
                    "retry_after": retry_after,
                    "tier": user_tier,
                }).encode(),
            })
            return

        rate_limit_storage.save_bucket(key, bucket)
        await self.app(scope, receive, send)
462
+
463
+
464
+ # =============================================================================
465
+ # Auth Rate Limiting (mit User-ID Tracking)
466
+ # =============================================================================
467
+
468
class UserAuthRateLimiter:
    """
    Brute-force protection for authentication endpoints.

    Failed attempts are tracked separately per client IP and per user ID.
    Once ``max_attempts`` failures fall within ``window`` seconds, that IP or
    user is locked out for ``lockout_duration`` seconds, counted from the most
    recent failure. A successful login clears the state for that IP/user.

    Previously attempts were pruned after 60 s, so the advertised lockout
    ended after at most 60 s. ``AUTH_RATE_LIMIT`` was also ignored in favour
    of a hard-coded 5.

    Args:
        max_attempts: failures allowed per window (default: AUTH_RATE_LIMIT).
        lockout_duration: lockout length in seconds (default: AUTH_LOCKOUT_DURATION).
        window: sliding window, in seconds, in which failures are counted.
    """

    def __init__(
        self,
        max_attempts: Optional[int] = None,
        lockout_duration: Optional[int] = None,
        window: int = 60,
    ):
        self.ip_attempts: Dict[str, list] = {}
        self.user_attempts: Dict[str, list] = {}
        # key -> time.time() at which that key's lockout ends
        self.ip_lockouts: Dict[str, float] = {}
        self.user_lockouts: Dict[str, float] = {}
        self.lockout_duration = (
            AUTH_LOCKOUT_DURATION if lockout_duration is None else lockout_duration
        )
        self.max_attempts = AUTH_RATE_LIMIT if max_attempts is None else max_attempts
        self.window = window

    def _cleanup_old(self, attempts: list, window: int = 60) -> list:
        """Return the timestamps in *attempts* that are younger than *window* seconds."""
        now = time.time()
        return [t for t in attempts if now - t < window]

    def _remaining_lockout(
        self, attempts: Dict[str, list], lockouts: Dict[str, float], key: str, now: float
    ) -> Optional[float]:
        """Return the seconds left in *key*'s lockout, or None if it is not locked out.

        Expired attempts and lockouts for *key* are removed, so keys without
        recent failures do not accumulate in memory.
        """
        until = lockouts.get(key)
        if until is not None:
            if until > now:
                return until - now
            # Lockout is over: start again from a clean slate.
            del lockouts[key]
            attempts.pop(key, None)

        recent = self._cleanup_old(attempts.get(key, []), self.window)
        if not recent:
            attempts.pop(key, None)
            return None
        attempts[key] = recent

        if len(recent) >= self.max_attempts:
            until = max(recent) + self.lockout_duration
            if until > now:
                lockouts[key] = until
                return until - now
            # lockout_duration is shorter than the window and already over.
            attempts.pop(key, None)
        return None

    def is_allowed(self, client_ip: str, user_id: Optional[str] = None) -> tuple[bool, Optional[int], str]:
        """
        Check whether an authentication attempt may proceed.

        Returns: (allowed, lockout_seconds, reason). ``reason`` is "ok",
        "ip_blocked" or "user_blocked". ``lockout_seconds`` is rounded up to
        whole seconds (at least 1) and is None when allowed.
        """
        now = time.time()

        remaining = self._remaining_lockout(
            self.ip_attempts, self.ip_lockouts, client_ip, now
        )
        if remaining is not None:
            return False, max(1, math.ceil(remaining)), "ip_blocked"

        if user_id:
            remaining = self._remaining_lockout(
                self.user_attempts, self.user_lockouts, user_id, now
            )
            if remaining is not None:
                return False, max(1, math.ceil(remaining)), "user_blocked"

        return True, None, "ok"

    def record_failure(self, client_ip: str, user_id: Optional[str] = None):
        """Record a failed login for *client_ip* and, if given, *user_id*."""
        now = time.time()
        self.ip_attempts.setdefault(client_ip, []).append(now)
        if user_id:
            self.user_attempts.setdefault(user_id, []).append(now)

    def record_success(self, client_ip: str, user_id: Optional[str] = None):
        """Clear failures and lockouts for *client_ip* and, if given, *user_id*."""
        self.ip_attempts.pop(client_ip, None)
        self.ip_lockouts.pop(client_ip, None)
        if user_id:
            self.user_attempts.pop(user_id, None)
            self.user_lockouts.pop(user_id, None)
537
+
538
+
539
# Module-wide limiter shared by every caller of check_user_auth_rate_limit.
user_auth_rate_limiter = UserAuthRateLimiter()


def check_user_auth_rate_limit(client_ip: str, user_id: Optional[str] = None):
    """Raise HTTP 429 if *client_ip* or *user_id* is currently locked out; else return None."""
    allowed, lockout, reason = user_auth_rate_limiter.is_allowed(client_ip, user_id)
    if allowed:
        return

    logger.warning(f"Auth rate limit exceeded: {reason} for {client_ip}/{user_id}")
    error_detail = {
        "error": "Too many login attempts",
        "retry_after": lockout,
        "reason": reason,
    }
    raise HTTPException(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        detail=error_detail,
        headers={"Retry-After": str(lockout)},
    )
558
+
559
+
560
+ # =============================================================================
561
+ # Stats und Monitoring
562
+ # =============================================================================
563
+
564
def get_rate_limit_stats() -> dict:
    """Return bucket counts, grouped by tier metadata, for the in-memory backend.

    Any other backend yields an error dict instead.
    """
    if isinstance(rate_limit_storage, MemoryStorage):
        by_tier = dict.fromkeys(("anonymous", "user", "premium", "admin", "unknown"), 0)
        for meta in rate_limit_storage.metadata.values():
            label = meta.get("tier", "unknown")
            by_tier[label] = by_tier.get(label, 0) + 1
        return {"total_buckets": len(rate_limit_storage.buckets), "by_tier": by_tier}

    return {"error": "Stats only available with memory storage"}
579
+
580
+
581
# Cleanup job (optional)
def cleanup_rate_limits():
    """Evict stale buckets from the in-memory backend; other backends are left alone."""
    if not isinstance(rate_limit_storage, MemoryStorage):
        return
    rate_limit_storage.cleanup()
    logger.info("Rate limit storage cleaned up")