zen-ai-pentest 2.2.0__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
api/rate_limiter.py ADDED
@@ -0,0 +1,317 @@
1
+ """
2
+ Rate Limiting für Zen-AI-Pentest API
3
+
4
+ Schützt API vor:
5
+ - Brute Force Angriffen
6
+ - DoS Attacken
7
+ - API Missbrauch
8
+ """
9
+
10
import logging
import math
import os
import time
from functools import wraps
from typing import Callable, Dict, Optional

from fastapi import HTTPException, Request, status
16
+
17
logger = logging.getLogger(__name__)

# =============================================================================
# Configuration
# =============================================================================


def _int_from_env(name: str, default: int) -> int:
    """
    Read a strictly positive integer setting from the environment.

    Unset or blank variables yield ``default``. Invalid values fail fast at
    import time with a message naming the variable, instead of surfacing
    later as broken rate limits while requests are being served.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset or blank.

    Returns:
        The parsed positive integer.

    Raises:
        ValueError: If the variable is set but is not a positive integer.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        value = int(raw)
    except ValueError:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from None
    if value <= 0:
        raise ValueError(
            f"Environment variable {name} must be a positive integer, got {value}"
        )
    return value


# Rate limits from environment variables
RATE_LIMIT_REQUESTS_PER_MINUTE = _int_from_env("RATE_LIMIT_REQUESTS_PER_MINUTE", 60)
RATE_LIMIT_BURST_SIZE = _int_from_env("RATE_LIMIT_BURST_SIZE", 10)

# Stricter limit for auth endpoints (attempts per minute).
# NOTE(review): AuthRateLimiter in this module does not read this value;
# confirm where (or whether) it is consumed.
AUTH_RATE_LIMIT = _int_from_env("AUTH_RATE_LIMIT", 5)
29
+
30
+
31
+ # =============================================================================
32
+ # Token Bucket Rate Limiter
33
+ # =============================================================================
34
+
35
class TokenBucket:
    """
    Token-bucket rate limiter.

    Allows bursts of up to ``burst_size`` requests while enforcing an average
    rate of ``rate`` requests per minute. No locking is performed; callers
    that share a bucket across threads must synchronize access themselves.
    """

    def __init__(self, rate: int, burst_size: int):
        """
        Args:
            rate: Tokens added per minute (must be > 0).
            burst_size: Maximum number of tokens the bucket holds (must be >= 1).

        Raises:
            ValueError: If ``rate`` or ``burst_size`` is not positive.
        """
        if rate <= 0:
            raise ValueError(f"rate must be positive, got {rate!r}")
        if burst_size < 1:
            raise ValueError(f"burst_size must be at least 1, got {burst_size!r}")
        self.rate = rate
        self.burst_size = burst_size
        # Start full so a fresh client can immediately use its burst.
        self.tokens = burst_size
        # Monotonic clock: wall-clock jumps (NTP, manual changes) must not
        # drain or instantly refill the bucket.
        self.last_update = time.monotonic()

    def _add_tokens(self) -> None:
        """Refill the bucket for the time elapsed since the last update, capped at burst_size."""
        now = time.monotonic()
        elapsed = now - self.last_update
        self.tokens = min(self.burst_size, self.tokens + (elapsed / 60) * self.rate)
        self.last_update = now

    def consume(self, tokens: int = 1) -> bool:
        """
        Try to take ``tokens`` tokens from the bucket.

        Returns:
            True if enough tokens were available (they are deducted),
            False if the rate limit is exceeded (nothing is deducted).
        """
        self._add_tokens()
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

    def get_wait_time(self, tokens: int = 1) -> float:
        """
        Return the number of seconds until ``tokens`` tokens will be available.

        The bucket is refilled first, so the answer is current even when this
        method is called on its own.
        """
        self._add_tokens()
        missing = tokens - self.tokens
        if missing <= 0:
            return 0.0
        return (missing / self.rate) * 60
+
87
+
88
+ # =============================================================================
89
+ # Rate Limit Storage
90
+ # =============================================================================
91
+
92
+ class RateLimitStorage:
93
+ """
94
+ Speichert Rate Limit Buckets pro Client.
95
+
96
+ In Produktion: Redis verwenden für verteilte Systeme!
97
+ """
98
+
99
+ def __init__(self):
100
+ self.buckets: Dict[str, TokenBucket] = {}
101
+ self.last_access: Dict[str, float] = {}
102
+
103
+ def get_bucket(self, key: str, rate: int, burst_size: int) -> TokenBucket:
104
+ """Holt oder erstellt Bucket für Client"""
105
+ if key not in self.buckets:
106
+ self.buckets[key] = TokenBucket(rate, burst_size)
107
+
108
+ self.last_access[key] = time.time()
109
+ return self.buckets[key]
110
+
111
+ def cleanup_old_buckets(self, max_age: float = 3600):
112
+ """Entfernt alte Buckets (Housekeeping)"""
113
+ now = time.time()
114
+ to_remove = [
115
+ key for key, last in self.last_access.items()
116
+ if now - last > max_age
117
+ ]
118
+ for key in to_remove:
119
+ del self.buckets[key]
120
+ del self.last_access[key]
121
+
122
+
123
+ # Global storage
124
+ rate_limit_storage = RateLimitStorage()
125
+
126
+
127
+ # =============================================================================
128
+ # Rate Limiting Decorator
129
+ # =============================================================================
130
+
131
def rate_limit(requests_per_minute: Optional[int] = None, burst_size: Optional[int] = None):
    """
    Decorator that applies a token-bucket rate limit to an async endpoint.

    Each decorated endpoint gets its own bucket per client IP, so endpoints
    with different limits do not overwrite each other's settings.

    The endpoint must declare a ``request: Request`` parameter. FastAPI calls
    endpoints with keyword arguments, so the Request is looked up in both
    ``args`` and ``kwargs``; if none is found the call is not rate limited.

    Args:
        requests_per_minute: Sustained rate; falsy values fall back to
            RATE_LIMIT_REQUESTS_PER_MINUTE.
        burst_size: Bucket capacity; falsy values fall back to
            RATE_LIMIT_BURST_SIZE.

    Raises:
        HTTPException: 429 with a ``Retry-After`` header when the limit is exceeded.

    Usage:
        @app.get("/api/data")
        @rate_limit(requests_per_minute=30)
        async def get_data(request: Request):
            return {"data": "value"}
    """
    rpm = requests_per_minute or RATE_LIMIT_REQUESTS_PER_MINUTE
    burst = burst_size or RATE_LIMIT_BURST_SIZE

    def decorator(func: Callable) -> Callable:
        # Per-endpoint bucket namespace: otherwise the first endpoint a client
        # hits would fix the rate for every other decorated endpoint.
        endpoint_id = (
            f"{getattr(func, '__module__', '')}."
            f"{getattr(func, '__qualname__', repr(func))}"
        )

        @wraps(func)
        async def wrapper(*args, **kwargs):
            request = next(
                (arg for arg in (*args, *kwargs.values()) if isinstance(arg, Request)),
                None,
            )
            if request is None:
                # No Request object available - rate limit not applicable
                return await func(*args, **kwargs)

            client_ip = request.client.host if request.client else "unknown"
            # Key on the IP only: client-controlled headers such as User-Agent
            # would let a client bypass the limit (and grow the bucket store)
            # simply by rotating the header value.
            key = f"{endpoint_id}:{client_ip}"

            bucket = rate_limit_storage.get_bucket(key, rpm, burst)
            if not bucket.consume():
                # Round up: truncating e.g. 0.9 s to 0 invites an immediate retry.
                retry_after = max(1, math.ceil(bucket.get_wait_time()))
                logger.warning("Rate limit exceeded for %s", client_ip)
                raise HTTPException(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    detail=f"Rate limit exceeded. Try again in {retry_after} seconds.",
                    headers={"Retry-After": str(retry_after)},
                )

            return await func(*args, **kwargs)

        return wrapper

    return decorator
179
+
180
+
181
+ # =============================================================================
182
+ # Middleware für globales Rate Limiting
183
+ # =============================================================================
184
+
185
class RateLimitMiddleware:
    """
    ASGI middleware applying one token-bucket limit per client IP to all HTTP requests.

    Non-HTTP scopes are passed through untouched. Rejected requests get a
    429 response with a ``Retry-After`` header (whole seconds, rounded up).

    Usage:
        app.add_middleware(RateLimitMiddleware)
    """

    def __init__(self, app, requests_per_minute: Optional[int] = None, burst_size: Optional[int] = None):
        """
        Args:
            app: The wrapped ASGI application.
            requests_per_minute: Sustained rate; falsy values fall back to
                RATE_LIMIT_REQUESTS_PER_MINUTE.
            burst_size: Bucket capacity; falsy values fall back to
                RATE_LIMIT_BURST_SIZE.
        """
        self.app = app
        self.rpm = requests_per_minute or RATE_LIMIT_REQUESTS_PER_MINUTE
        self.burst = burst_size or RATE_LIMIT_BURST_SIZE

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Identify the client by the peer address reported by the ASGI server.
        client = scope.get("client")
        client_ip = client[0] if client else "unknown"

        bucket = rate_limit_storage.get_bucket(client_ip, self.rpm, self.burst)
        if bucket.consume():
            await self.app(scope, receive, send)
            return

        # Round up: truncating a sub-second wait to 0 would tell the client
        # to retry immediately and hit the limit again.
        retry_after = max(1, math.ceil(bucket.get_wait_time()))
        logger.warning("Global rate limit exceeded for %s", client_ip)

        body = f'"Rate limit exceeded. Retry after {retry_after} seconds."'.encode()
        await send({
            "type": "http.response.start",
            "status": 429,
            "headers": [
                [b"content-type", b"application/json"],
                [b"content-length", str(len(body)).encode()],
                [b"retry-after", str(retry_after).encode()],
            ],
        })
        await send({
            "type": "http.response.body",
            "body": body,
        })
230
+
231
+
232
+ # =============================================================================
233
+ # Auth-spezifisches Rate Limiting
234
+ # =============================================================================
235
+
236
class AuthRateLimiter:
    """
    Stricter rate limiting for authentication endpoints (brute-force protection).

    Once a client accumulates ``max_attempts`` failed attempts within
    ``window`` seconds, it is locked out for ``lockout_duration`` seconds,
    counted from the failure that reached the threshold. A successful
    authentication clears the client's history. State is kept in process
    memory only.
    """

    def __init__(
        self,
        max_attempts: int = 5,
        window: float = 60,
        lockout_duration: float = 300,
        clock: Callable[[], float] = time.monotonic,
    ):
        """
        Args:
            max_attempts: Failures within ``window`` that trigger a lockout.
            window: Sliding window, in seconds, in which failures are counted.
            lockout_duration: Length of a lockout in seconds.
            clock: Time source returning seconds; injectable for tests.
        """
        # NOTE(review): the module-level AUTH_RATE_LIMIT setting is not applied
        # here automatically; pass it as ``max_attempts`` if it should take effect.
        self.failed_attempts: Dict[str, list] = {}  # IP -> failure timestamps inside the window
        self.locked_until: Dict[str, float] = {}  # IP -> clock() value at which the lockout ends
        self.lockout_duration = lockout_duration
        self.max_attempts = max_attempts
        self.window = window
        self._clock = clock

    def _recent_failures(self, client_ip: str, now: float) -> list:
        """Drop failures older than the window for ``client_ip`` and return the rest."""
        recent = [t for t in self.failed_attempts.get(client_ip, ()) if now - t < self.window]
        if recent:
            self.failed_attempts[client_ip] = recent
        else:
            # Forget clients without recent failures so the dict does not grow forever.
            self.failed_attempts.pop(client_ip, None)
        return recent

    def is_allowed(self, client_ip: str) -> tuple[bool, Optional[int]]:
        """
        Check whether an authentication attempt from ``client_ip`` is allowed.

        Returns:
            ``(True, None)`` if allowed, otherwise ``(False, seconds)`` where
            ``seconds`` is the remaining lockout, rounded up (at least 1).
        """
        now = self._clock()

        until = self.locked_until.get(client_ip)
        if until is not None:
            if until > now:
                return False, max(1, math.ceil(until - now))
            # Lockout expired: start over with a clean slate.
            del self.locked_until[client_ip]
            self.failed_attempts.pop(client_ip, None)

        if len(self._recent_failures(client_ip, now)) >= self.max_attempts:
            # Threshold reached without an active lockout (e.g. max_attempts was
            # lowered at runtime): lock the client now.
            self.locked_until[client_ip] = now + self.lockout_duration
            return False, max(1, math.ceil(self.lockout_duration))

        return True, None

    def record_failure(self, client_ip: str):
        """Record a failed attempt; reaching ``max_attempts`` starts a lockout."""
        now = self._clock()
        recent = self._recent_failures(client_ip, now)
        recent.append(now)
        self.failed_attempts[client_ip] = recent
        if len(recent) >= self.max_attempts:
            self.locked_until[client_ip] = now + self.lockout_duration

    def record_success(self, client_ip: str):
        """Clear failed attempts and any lockout for ``client_ip`` after a successful login."""
        self.failed_attempts.pop(client_ip, None)
        self.locked_until.pop(client_ip, None)


# Global auth rate limiter used by the module-level helper functions
auth_rate_limiter = AuthRateLimiter()
293
+
294
+
295
def check_auth_rate_limit(client_ip: str):
    """
    Enforce the authentication rate limit for ``client_ip``.

    Returns None when the attempt may proceed.

    Raises:
        HTTPException: 429 with a ``Retry-After`` header while the client is locked out.
    """
    allowed, lockout = auth_rate_limiter.is_allowed(client_ip)
    if allowed:
        return

    logger.warning(f"Auth rate limit exceeded for {client_ip}")
    message = f"Too many login attempts. Please try again in {lockout} seconds."
    retry_headers = {"Retry-After": str(lockout)}
    raise HTTPException(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        detail=message,
        headers=retry_headers,
    )
308
+
309
+
310
def record_auth_failure(client_ip: str):
    """Register one failed authentication attempt for ``client_ip`` with the shared limiter."""
    limiter = auth_rate_limiter
    limiter.record_failure(client_ip)
313
+
314
+
315
def record_auth_success(client_ip: str):
    """Tell the shared limiter that ``client_ip`` authenticated successfully."""
    limiter = auth_rate_limiter
    limiter.record_success(client_ip)