mdb-engine 0.1.6__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +104 -11
- mdb_engine/auth/ARCHITECTURE.md +112 -0
- mdb_engine/auth/README.md +648 -11
- mdb_engine/auth/__init__.py +136 -29
- mdb_engine/auth/audit.py +592 -0
- mdb_engine/auth/base.py +252 -0
- mdb_engine/auth/casbin_factory.py +264 -69
- mdb_engine/auth/config_helpers.py +7 -6
- mdb_engine/auth/cookie_utils.py +3 -7
- mdb_engine/auth/csrf.py +373 -0
- mdb_engine/auth/decorators.py +3 -10
- mdb_engine/auth/dependencies.py +47 -50
- mdb_engine/auth/helpers.py +3 -3
- mdb_engine/auth/integration.py +53 -80
- mdb_engine/auth/jwt.py +2 -6
- mdb_engine/auth/middleware.py +77 -34
- mdb_engine/auth/oso_factory.py +18 -38
- mdb_engine/auth/provider.py +270 -171
- mdb_engine/auth/rate_limiter.py +504 -0
- mdb_engine/auth/restrictions.py +8 -24
- mdb_engine/auth/session_manager.py +14 -29
- mdb_engine/auth/shared_middleware.py +600 -0
- mdb_engine/auth/shared_users.py +759 -0
- mdb_engine/auth/token_store.py +14 -28
- mdb_engine/auth/users.py +54 -113
- mdb_engine/auth/utils.py +213 -15
- mdb_engine/cli/commands/generate.py +545 -9
- mdb_engine/cli/commands/validate.py +3 -7
- mdb_engine/cli/utils.py +3 -3
- mdb_engine/config.py +7 -21
- mdb_engine/constants.py +65 -0
- mdb_engine/core/README.md +117 -6
- mdb_engine/core/__init__.py +39 -7
- mdb_engine/core/app_registration.py +22 -41
- mdb_engine/core/app_secrets.py +290 -0
- mdb_engine/core/connection.py +18 -9
- mdb_engine/core/encryption.py +223 -0
- mdb_engine/core/engine.py +1057 -93
- mdb_engine/core/index_management.py +12 -16
- mdb_engine/core/manifest.py +459 -150
- mdb_engine/core/ray_integration.py +435 -0
- mdb_engine/core/seeding.py +10 -18
- mdb_engine/core/service_initialization.py +12 -23
- mdb_engine/core/types.py +2 -5
- mdb_engine/database/README.md +140 -17
- mdb_engine/database/__init__.py +17 -6
- mdb_engine/database/abstraction.py +25 -37
- mdb_engine/database/connection.py +11 -18
- mdb_engine/database/query_validator.py +367 -0
- mdb_engine/database/resource_limiter.py +204 -0
- mdb_engine/database/scoped_wrapper.py +713 -196
- mdb_engine/dependencies.py +426 -0
- mdb_engine/di/__init__.py +34 -0
- mdb_engine/di/container.py +248 -0
- mdb_engine/di/providers.py +205 -0
- mdb_engine/di/scopes.py +139 -0
- mdb_engine/embeddings/README.md +54 -24
- mdb_engine/embeddings/__init__.py +31 -24
- mdb_engine/embeddings/dependencies.py +37 -154
- mdb_engine/embeddings/service.py +11 -25
- mdb_engine/exceptions.py +92 -0
- mdb_engine/indexes/README.md +30 -13
- mdb_engine/indexes/__init__.py +1 -0
- mdb_engine/indexes/helpers.py +1 -1
- mdb_engine/indexes/manager.py +50 -114
- mdb_engine/memory/README.md +2 -2
- mdb_engine/memory/__init__.py +1 -2
- mdb_engine/memory/service.py +30 -87
- mdb_engine/observability/README.md +4 -2
- mdb_engine/observability/__init__.py +26 -9
- mdb_engine/observability/health.py +8 -9
- mdb_engine/observability/metrics.py +32 -12
- mdb_engine/repositories/__init__.py +34 -0
- mdb_engine/repositories/base.py +325 -0
- mdb_engine/repositories/mongo.py +233 -0
- mdb_engine/repositories/unit_of_work.py +166 -0
- mdb_engine/routing/README.md +1 -1
- mdb_engine/routing/__init__.py +1 -3
- mdb_engine/routing/websockets.py +25 -60
- mdb_engine-0.2.0.dist-info/METADATA +313 -0
- mdb_engine-0.2.0.dist-info/RECORD +96 -0
- mdb_engine-0.1.6.dist-info/METADATA +0 -213
- mdb_engine-0.1.6.dist-info/RECORD +0 -75
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.2.0.dist-info}/WHEEL +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.2.0.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,504 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Rate Limiting for Authentication Endpoints
|
|
3
|
+
|
|
4
|
+
Provides rate limiting middleware to protect auth endpoints from brute-force attacks.
|
|
5
|
+
Supports both in-memory storage (single instance) and MongoDB-backed storage (distributed).
|
|
6
|
+
|
|
7
|
+
This module is part of MDB_ENGINE - MongoDB Engine.
|
|
8
|
+
|
|
9
|
+
Features:
|
|
10
|
+
- Sliding window rate limiting algorithm
|
|
11
|
+
- Per-endpoint configurable limits via manifest
|
|
12
|
+
- IP + optional email-based tracking
|
|
13
|
+
- In-memory (default) or MongoDB storage
|
|
14
|
+
- 429 Too Many Requests with Retry-After header
|
|
15
|
+
|
|
16
|
+
Usage:
|
|
17
|
+
# Via middleware (recommended for shared auth)
|
|
18
|
+
app.add_middleware(
|
|
19
|
+
AuthRateLimitMiddleware,
|
|
20
|
+
limits={
|
|
21
|
+
"/login": RateLimit(max_attempts=5, window_seconds=300),
|
|
22
|
+
"/register": RateLimit(max_attempts=3, window_seconds=3600),
|
|
23
|
+
}
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# Via decorator (for specific endpoints)
|
|
27
|
+
@app.post("/login")
|
|
28
|
+
@rate_limit(max_attempts=5, window_seconds=300)
|
|
29
|
+
async def login(request: Request):
|
|
30
|
+
...
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
import json
import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from functools import wraps
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union

from pymongo.errors import OperationFailure
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
|
|
45
|
+
|
|
46
|
+
logger = logging.getLogger(__name__)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class RateLimit:
    """Limit for one endpoint: at most ``max_attempts`` within ``window_seconds``."""

    max_attempts: int = 5
    window_seconds: int = 300  # 5 minutes

    def to_dict(self) -> Dict[str, int]:
        """Return this configuration as a plain ``{field_name: value}`` dict."""
        return dict(
            max_attempts=self.max_attempts,
            window_seconds=self.window_seconds,
        )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Built-in per-path limits, written as RateLimit(max_attempts, window_seconds).
# Used whenever no custom limits are supplied.
DEFAULT_AUTH_RATE_LIMITS: Dict[str, RateLimit] = {
    "/login": RateLimit(5, 300),  # 5 attempts per 5 minutes
    "/register": RateLimit(3, 3600),  # 3 attempts per hour
    "/logout": RateLimit(10, 60),  # 10 attempts per minute
}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class InMemoryRateLimitStore:
    """
    In-memory rate limit storage using a sliding window algorithm.

    Suitable for single-instance deployments: counters live in this process
    only and are not shared between workers or hosts. For distributed
    systems, use MongoDBRateLimitStore instead.

    Identifiers whose attempts have all aged out are swept automatically from
    ``record_attempt`` (at most once per ``cleanup_interval_seconds``). This
    keeps memory from growing without bound as new identifiers (for example
    new client IPs) keep arriving. ``cleanup`` can still be called manually.
    """

    def __init__(self, cleanup_interval_seconds: float = 60.0):
        """
        Initialize an empty store.

        Args:
            cleanup_interval_seconds: Minimum time between automatic sweeps of
                stale identifiers. 0 sweeps on every ``record_attempt`` call.
        """
        # Structure: {identifier: [(timestamp, count), ...]}
        self._storage: Dict[str, List[Tuple[float, int]]] = defaultdict(list)
        self._cleanup_interval = cleanup_interval_seconds
        self._last_cleanup = time.time()
        # Largest window ever passed to record_attempt/get_count. Entries older
        # than this lie outside every window used so far, so the automatic
        # sweep can drop them without changing the result of any such call.
        self._max_window_seconds = 0

    def _note_window(self, window_seconds: int) -> None:
        """Remember the largest window used; it bounds the automatic sweep."""
        self._max_window_seconds = max(self._max_window_seconds, window_seconds)

    def _maybe_cleanup(self, now: float) -> None:
        """Sweep stale identifiers if the cleanup interval has elapsed."""
        if now - self._last_cleanup < self._cleanup_interval:
            return
        self._last_cleanup = now
        self.cleanup(max_age_seconds=self._max_window_seconds)

    async def record_attempt(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """
        Record an attempt and return current count in window.

        Args:
            identifier: Unique identifier (e.g., "login:192.168.1.1:user@example.com")
            window_seconds: Time window in seconds

        Returns:
            Number of attempts in the current window (including this one)
        """
        now = time.time()
        self._note_window(window_seconds)
        # Sweep before fetching this identifier's list, so the sweep can never
        # delete the list we are about to append to.
        self._maybe_cleanup(now)

        cutoff = now - window_seconds
        entries = self._storage[identifier]
        # Drop entries that fell out of the window, then add this attempt.
        entries[:] = [(ts, c) for ts, c in entries if ts > cutoff]
        entries.append((now, 1))

        return sum(c for _, c in entries)

    async def get_count(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """Get current attempt count without recording."""
        self._note_window(window_seconds)
        cutoff = time.time() - window_seconds
        # .get() avoids creating an empty entry in the defaultdict.
        entries = self._storage.get(identifier, [])
        return sum(c for ts, c in entries if ts > cutoff)

    async def reset(self, identifier: str) -> None:
        """Reset rate limit for an identifier (e.g., after successful login)."""
        self._storage.pop(identifier, None)

    def cleanup(self, max_age_seconds: int = 7200) -> int:
        """
        Clean up old entries to prevent memory growth.

        Args:
            max_age_seconds: Remove entries older than this (default: 2 hours)

        Returns:
            Number of identifiers cleaned up
        """
        cutoff = time.time() - max_age_seconds

        stale = []
        for identifier, entries in self._storage.items():
            # In-place slice assignment: the dict's size does not change
            # while it is being iterated.
            entries[:] = [(ts, c) for ts, c in entries if ts > cutoff]
            if not entries:
                stale.append(identifier)

        for identifier in stale:
            del self._storage[identifier]

        return len(stale)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class MongoDBRateLimitStore:
    """
    MongoDB-backed rate limit storage for distributed deployments.

    Each attempt is stored as one document. A TTL index on ``expires_at``
    lets MongoDB delete documents once their window has passed.
    """

    COLLECTION = "_mdb_engine_rate_limits"

    # After a failed index creation, wait this many seconds before retrying.
    # This stops a persistent failure (e.g. missing index privileges) from
    # adding two create_index round trips and a warning to every request.
    INDEX_RETRY_SECONDS = 300.0

    def __init__(self, db):
        """
        Initialize MongoDB rate limit store.

        Args:
            db: MongoDB database instance (Motor AsyncIOMotorDatabase)
        """
        self._db = db
        self._collection = db[self.COLLECTION]
        self._indexes_created = False
        # time.monotonic() value before which index creation is not retried.
        self._next_index_attempt = 0.0

    async def ensure_indexes(self) -> None:
        """Create necessary indexes (once; retried with backoff on failure)."""
        if self._indexes_created:
            return
        if time.monotonic() < self._next_index_attempt:
            return

        try:
            # Compound index for lookups
            await self._collection.create_index(
                [("identifier", 1), ("timestamp", -1)], name="identifier_timestamp_idx"
            )
            # TTL index for cleanup: documents expire at their own expires_at.
            await self._collection.create_index(
                "expires_at", expireAfterSeconds=0, name="expires_at_ttl_idx"
            )
            self._indexes_created = True
            logger.info("Rate limit indexes ensured")
        except OperationFailure as e:
            self._next_index_attempt = time.monotonic() + self.INDEX_RETRY_SECONDS
            logger.warning("Failed to create rate limit indexes: %s", e)

    async def record_attempt(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """Record an attempt and return current count in window (including this one)."""
        await self.ensure_indexes()

        # Timezone-aware UTC; datetime.utcnow() is deprecated and naive.
        now = datetime.now(timezone.utc)
        expires_at = now + timedelta(seconds=window_seconds)
        cutoff = now - timedelta(seconds=window_seconds)

        # Insert attempt
        await self._collection.insert_one(
            {
                "identifier": identifier,
                "timestamp": now,
                "expires_at": expires_at,
            }
        )

        # Count attempts in window
        return await self._collection.count_documents(
            {
                "identifier": identifier,
                "timestamp": {"$gte": cutoff},
            }
        )

    async def get_count(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """Get current attempt count without recording."""
        await self.ensure_indexes()

        cutoff = datetime.now(timezone.utc) - timedelta(seconds=window_seconds)

        return await self._collection.count_documents(
            {
                "identifier": identifier,
                "timestamp": {"$gte": cutoff},
            }
        )

    async def reset(self, identifier: str) -> None:
        """Reset rate limit for an identifier."""
        await self._collection.delete_many({"identifier": identifier})
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
# Process-wide in-memory store shared by AuthRateLimitMiddleware (when no
# ``store`` is given) and by the ``rate_limit`` decorator. Counters are shared
# within this process only, not across worker processes or hosts.
_default_store = InMemoryRateLimitStore()
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
class AuthRateLimitMiddleware(BaseHTTPMiddleware):
    """
    ASGI middleware for rate limiting authentication endpoints.

    Rate limits POST requests whose path exactly matches a key of ``limits``
    (by default /login, /register and /logout) to slow down brute-force
    attacks. All other requests pass through untouched.

    Features:
    - Configurable per-endpoint limits
    - IP + (normalized) email tracking for JSON login attempts
    - 429 responses with Retry-After header
    - Counter for an identifier is reset after a POST /login answered with
      a status code below 300

    Usage:
        # Basic usage with defaults
        app.add_middleware(AuthRateLimitMiddleware)

        # Custom limits
        app.add_middleware(
            AuthRateLimitMiddleware,
            limits={"/login": RateLimit(max_attempts=3, window_seconds=60)}
        )

        # With MongoDB storage for distributed deployments
        app.add_middleware(
            AuthRateLimitMiddleware,
            store=MongoDBRateLimitStore(db)
        )
    """

    # Commonly cited maximum email address length; longer values are ignored
    # so an attacker cannot inflate rate-limit keys with huge strings.
    _MAX_EMAIL_LENGTH = 254

    def __init__(
        self,
        app: Callable,
        limits: Optional[Dict[str, RateLimit]] = None,
        store: Optional[Union[InMemoryRateLimitStore, MongoDBRateLimitStore]] = None,
        include_email_in_key: bool = True,
        trust_proxy_headers: bool = True,
    ):
        """
        Initialize rate limit middleware.

        Args:
            app: ASGI application
            limits: Dict of path -> RateLimit config. Defaults to DEFAULT_AUTH_RATE_LIMITS
                (an empty dict also falls back to the defaults).
            store: Rate limit storage backend. Defaults to the process-wide in-memory store.
            include_email_in_key: Include email in rate limit key for more granular limits.
            trust_proxy_headers: Derive the client IP from X-Forwarded-For / X-Real-IP
                (default, previous behavior). These headers are client-controlled
                unless a trusted proxy overwrites them, so a client can rotate them
                to evade IP-based limits. Set False when not behind such a proxy.
        """
        super().__init__(app)
        self._limits = limits or DEFAULT_AUTH_RATE_LIMITS
        self._store = store if store is not None else _default_store
        self._include_email_in_key = include_email_in_key
        self._trust_proxy_headers = trust_proxy_headers

        logger.info(
            "AuthRateLimitMiddleware initialized with limits for: %s",
            list(self._limits.keys()),
        )

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Process request through rate limiter."""
        path = request.url.path

        # Only rate limit POST requests to configured endpoints
        if request.method != "POST" or path not in self._limits:
            return await call_next(request)

        limit = self._limits[path]
        identifier = await self._build_identifier(request, path)

        # Check current count (before recording this attempt). Blocked
        # attempts are not recorded, so every counted entry expires within
        # window_seconds and Retry-After is an upper bound on the wait.
        # NOTE: check and record are separate store calls; with
        # MongoDBRateLimitStore, concurrent requests for one identifier can
        # both pass the check, so a burst may slightly exceed max_attempts.
        current_count = await self._store.get_count(identifier, limit.window_seconds)

        if current_count >= limit.max_attempts:
            logger.warning(
                "Rate limit exceeded: %s (%s/%s in %ss)",
                identifier,
                current_count,
                limit.max_attempts,
                limit.window_seconds,
            )
            return self._rate_limit_response(limit.window_seconds)

        await self._store.record_attempt(identifier, limit.window_seconds)

        response = await call_next(request)

        # Reset on successful login (status < 300)
        if response.status_code < 300 and path == "/login":
            await self._store.reset(identifier)

        return response

    async def _build_identifier(self, request: Request, path: str) -> str:
        """Build the rate limit key ``path:ip[:email]`` from the request."""
        parts = [path, self._get_client_ip(request)]

        # Optionally add email for more granular rate limiting
        if self._include_email_in_key:
            email = await self._extract_email(request)
            if email:
                parts.append(email)

        return ":".join(parts)

    def _get_client_ip(self, request: Request) -> str:
        """Get client IP, optionally respecting proxy headers."""
        if self._trust_proxy_headers:
            # X-Forwarded-For: the first entry is the original client.
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                first_hop = forwarded_for.split(",")[0].strip()
                if first_hop:
                    return first_hop

            real_ip = (request.headers.get("X-Real-IP") or "").strip()
            if real_ip:
                return real_ip

        # Fall back to direct client IP
        if request.client:
            return request.client.host

        return "unknown"

    async def _extract_email(self, request: Request) -> Optional[str]:
        """
        Try to extract a normalized email from a JSON request body.

        Returns the stripped, lower-cased ``email`` field of a JSON object body,
        or None when the body is not JSON, not an object, has no string email,
        or the email is empty or longer than ``_MAX_EMAIL_LENGTH``. Lower-casing
        makes case variants of one address share a single counter.
        """
        content_type = request.headers.get("content-type", "")
        if "application/json" not in content_type:
            return None

        try:
            # NOTE(review): reading the body here relies on Starlette handing
            # the consumed body to the downstream app; older Starlette
            # versions did not. Confirm against the pinned version.
            body = await request.body()
            if not body:
                return None
            data = json.loads(body)
        except (ValueError, UnicodeDecodeError, RecursionError):
            # Malformed (or pathologically nested) JSON: key on IP only.
            return None

        # A JSON array/string/number body has no .get(); a non-string email
        # would break ":".join() in _build_identifier.
        if not isinstance(data, dict):
            return None
        email = data.get("email")
        if not isinstance(email, str):
            return None

        email = email.strip().lower()
        if not email or len(email) > self._MAX_EMAIL_LENGTH:
            return None
        return email

    @staticmethod
    def _rate_limit_response(retry_after: int) -> JSONResponse:
        """Return 429 Too Many Requests response."""
        return JSONResponse(
            status_code=429,
            content={
                "detail": f"Too many attempts. Please try again in {retry_after} seconds.",
                "error": "rate_limit_exceeded",
                "retry_after": retry_after,
            },
            headers={"Retry-After": str(retry_after)},
        )
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def create_rate_limit_middleware(
    manifest_auth: Dict[str, Any],
    store: Optional[Union[InMemoryRateLimitStore, MongoDBRateLimitStore]] = None,
    include_email_in_key: bool = True,
) -> type:
    """
    Factory function to create rate limit middleware from manifest config.

    Args:
        manifest_auth: Auth section from manifest (None is treated as empty)
        store: Optional storage backend
        include_email_in_key: Passed through to AuthRateLimitMiddleware

    Returns:
        Configured middleware class

    A per-path config may be a dict (missing keys default to
    max_attempts=5, window_seconds=300), ``null`` (all defaults) or a
    RateLimit instance. If no limits are configured (key absent, ``null`` or
    empty), DEFAULT_AUTH_RATE_LIMITS is used.

    Manifest format:
        {
            "auth": {
                "rate_limits": {
                    "/login": {"max_attempts": 5, "window_seconds": 300},
                    "/register": {"max_attempts": 3, "window_seconds": 3600}
                }
            }
        }
    """
    # `or {}` tolerates an explicit null for the section or for rate_limits.
    rate_limits_config = (manifest_auth or {}).get("rate_limits") or {}

    limits: Dict[str, RateLimit] = {}
    for path, config in rate_limits_config.items():
        if isinstance(config, RateLimit):
            limits[path] = config
            continue
        config = config or {}
        limits[path] = RateLimit(
            max_attempts=config.get("max_attempts", 5),
            window_seconds=config.get("window_seconds", 300),
        )

    # Use defaults if no config provided
    if not limits:
        limits = DEFAULT_AUTH_RATE_LIMITS.copy()

    class ConfiguredRateLimitMiddleware(AuthRateLimitMiddleware):
        """AuthRateLimitMiddleware pre-bound to the manifest's limits and store."""

        def __init__(self, app: Callable):
            super().__init__(
                app,
                limits=limits,
                store=store,
                include_email_in_key=include_email_in_key,
            )

    return ConfiguredRateLimitMiddleware
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
def rate_limit(
    max_attempts: int = 5,
    window_seconds: int = 300,
    key_func: Optional[Callable[[Request], str]] = None,
    store: Optional[Union[InMemoryRateLimitStore, MongoDBRateLimitStore]] = None,
):
    """
    Decorator for rate limiting individual endpoints.

    Every call is recorded first (including calls that end up rejected), and
    the call is rejected with 429 once more than ``max_attempts`` fall within
    ``window_seconds``.

    Usage:
        @app.post("/login")
        @rate_limit(max_attempts=5, window_seconds=300)
        async def login(request: Request):
            ...

    Args:
        max_attempts: Maximum attempts in window
        window_seconds: Time window in seconds
        key_func: Optional function to generate rate limit key from request
            (default: ``"<function name>:<client ip>"``)
        store: Optional storage backend; defaults to the process-wide
            in-memory store (looked up at call time)
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            # Build identifier
            if key_func:
                identifier = key_func(request)
            else:
                client_ip = request.client.host if request.client else "unknown"
                identifier = f"{func.__name__}:{client_ip}"

            backend = store if store is not None else _default_store

            # Check rate limit
            count = await backend.record_attempt(identifier, window_seconds)

            if count > max_attempts:
                logger.warning(
                    "Rate limit exceeded: %s (%s/%s)", identifier, count, max_attempts
                )
                return JSONResponse(
                    status_code=429,
                    content={
                        "detail": f"Too many attempts. Try again in {window_seconds} seconds.",
                        "error": "rate_limit_exceeded",
                        # Same field the middleware's 429 body carries.
                        "retry_after": window_seconds,
                    },
                    headers={"Retry-After": str(window_seconds)},
                )

            return await func(request, *args, **kwargs)

        return wrapper

    return decorator
|
mdb_engine/auth/restrictions.py
CHANGED
|
@@ -27,9 +27,7 @@ from .users import get_app_user
|
|
|
27
27
|
logger = logging.getLogger(__name__)
|
|
28
28
|
|
|
29
29
|
|
|
30
|
-
def is_demo_user(
|
|
31
|
-
user: Optional[Dict[str, Any]] = None, email: Optional[str] = None
|
|
32
|
-
) -> bool:
|
|
30
|
+
def is_demo_user(user: Optional[Dict[str, Any]] = None, email: Optional[str] = None) -> bool:
|
|
33
31
|
"""
|
|
34
32
|
Check if a user is a demo user.
|
|
35
33
|
|
|
@@ -85,9 +83,7 @@ async def _get_sub_auth_user(
|
|
|
85
83
|
if not (config and users_config.get("enabled", False)):
|
|
86
84
|
return None
|
|
87
85
|
|
|
88
|
-
app_user = await get_app_user(
|
|
89
|
-
request, slug_id, db, config, allow_demo_fallback=False
|
|
90
|
-
)
|
|
86
|
+
app_user = await get_app_user(request, slug_id, db, config, allow_demo_fallback=False)
|
|
91
87
|
if not app_user:
|
|
92
88
|
return None
|
|
93
89
|
|
|
@@ -119,9 +115,7 @@ async def _get_authenticated_user(
|
|
|
119
115
|
|
|
120
116
|
# Try sub-auth if platform auth didn't work
|
|
121
117
|
if get_app_db_func and get_app_config_func:
|
|
122
|
-
return await _get_sub_auth_user(
|
|
123
|
-
request, slug_id, get_app_config_func, get_app_db_func
|
|
124
|
-
)
|
|
118
|
+
return await _get_sub_auth_user(request, slug_id, get_app_config_func, get_app_db_func)
|
|
125
119
|
|
|
126
120
|
return None
|
|
127
121
|
|
|
@@ -158,9 +152,7 @@ def _validate_dependencies(
|
|
|
158
152
|
|
|
159
153
|
async def require_non_demo_user(
|
|
160
154
|
request: Request,
|
|
161
|
-
get_app_config_func: Optional[
|
|
162
|
-
Callable[[Request, str, Dict], Awaitable[Dict]]
|
|
163
|
-
] = None,
|
|
155
|
+
get_app_config_func: Optional[Callable[[Request, str, Dict], Awaitable[Dict]]] = None,
|
|
164
156
|
get_app_db_func: Optional[Callable[[Request], Awaitable[Any]]] = None,
|
|
165
157
|
) -> Dict[str, Any]:
|
|
166
158
|
"""
|
|
@@ -189,9 +181,7 @@ async def require_non_demo_user(
|
|
|
189
181
|
if get_app_db_func and get_app_config_func:
|
|
190
182
|
_validate_dependencies(get_app_config_func, get_app_db_func)
|
|
191
183
|
|
|
192
|
-
user = await _get_authenticated_user(
|
|
193
|
-
request, slug_id, get_app_config_func, get_app_db_func
|
|
194
|
-
)
|
|
184
|
+
user = await _get_authenticated_user(request, slug_id, get_app_config_func, get_app_db_func)
|
|
195
185
|
|
|
196
186
|
# Check if user is demo
|
|
197
187
|
if user and is_demo_user(user):
|
|
@@ -215,9 +205,7 @@ async def require_non_demo_user(
|
|
|
215
205
|
|
|
216
206
|
async def block_demo_users(
|
|
217
207
|
request: Request,
|
|
218
|
-
get_app_config_func: Optional[
|
|
219
|
-
Callable[[Request, str, Dict], Awaitable[Dict]]
|
|
220
|
-
] = None,
|
|
208
|
+
get_app_config_func: Optional[Callable[[Request, str, Dict], Awaitable[Dict]]] = None,
|
|
221
209
|
get_app_db_func: Optional[Callable[[Request], Awaitable[Any]]] = None,
|
|
222
210
|
):
|
|
223
211
|
"""
|
|
@@ -253,15 +241,11 @@ async def block_demo_users(
|
|
|
253
241
|
|
|
254
242
|
# Check sub-auth if platform auth didn't work
|
|
255
243
|
if not user and get_app_db_func and get_app_config_func and slug_id:
|
|
256
|
-
user = await _get_sub_auth_user(
|
|
257
|
-
request, slug_id, get_app_config_func, get_app_db_func
|
|
258
|
-
)
|
|
244
|
+
user = await _get_sub_auth_user(request, slug_id, get_app_config_func, get_app_db_func)
|
|
259
245
|
|
|
260
246
|
# Block if demo user (only if user exists)
|
|
261
247
|
if user and is_demo_user(user):
|
|
262
|
-
logger.info(
|
|
263
|
-
f"Demo user '{user.get('email')}' blocked from accessing: {request.url.path}"
|
|
264
|
-
)
|
|
248
|
+
logger.info(f"Demo user '{user.get('email')}' blocked from accessing: {request.url.path}")
|
|
265
249
|
raise HTTPException(
|
|
266
250
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
267
251
|
detail="Demo users cannot access this endpoint. Demo mode is read-only.",
|
|
@@ -13,8 +13,11 @@ from typing import Any, Dict, List, Optional
|
|
|
13
13
|
from bson.objectid import ObjectId
|
|
14
14
|
|
|
15
15
|
try:
|
|
16
|
-
from pymongo.errors import (
|
|
17
|
-
|
|
16
|
+
from pymongo.errors import (
|
|
17
|
+
ConnectionFailure,
|
|
18
|
+
OperationFailure,
|
|
19
|
+
ServerSelectionTimeoutError,
|
|
20
|
+
)
|
|
18
21
|
except ImportError:
|
|
19
22
|
ConnectionFailure = Exception
|
|
20
23
|
OperationFailure = Exception
|
|
@@ -120,9 +123,7 @@ class SessionManager:
|
|
|
120
123
|
# Remove oldest inactive session
|
|
121
124
|
await self.cleanup_inactive_sessions(user_id)
|
|
122
125
|
# Check again
|
|
123
|
-
active_sessions = await self.get_user_sessions(
|
|
124
|
-
user_id, active_only=True
|
|
125
|
-
)
|
|
126
|
+
active_sessions = await self.get_user_sessions(user_id, active_only=True)
|
|
126
127
|
if len(active_sessions) >= self.max_sessions:
|
|
127
128
|
# Force remove oldest session
|
|
128
129
|
if active_sessions:
|
|
@@ -168,9 +169,7 @@ class SessionManager:
|
|
|
168
169
|
ValueError,
|
|
169
170
|
TypeError,
|
|
170
171
|
) as e:
|
|
171
|
-
logger.error(
|
|
172
|
-
f"Error creating session for user {user_id}: {e}", exc_info=True
|
|
173
|
-
)
|
|
172
|
+
logger.error(f"Error creating session for user {user_id}: {e}", exc_info=True)
|
|
174
173
|
return None
|
|
175
174
|
|
|
176
175
|
async def update_session_activity(
|
|
@@ -205,14 +204,10 @@ class SessionManager:
|
|
|
205
204
|
ValueError,
|
|
206
205
|
TypeError,
|
|
207
206
|
) as e:
|
|
208
|
-
logger.error(
|
|
209
|
-
f"Error updating session activity for {refresh_jti}: {e}", exc_info=True
|
|
210
|
-
)
|
|
207
|
+
logger.error(f"Error updating session activity for {refresh_jti}: {e}", exc_info=True)
|
|
211
208
|
return False
|
|
212
209
|
|
|
213
|
-
async def get_session_by_refresh_token(
|
|
214
|
-
self, refresh_jti: str
|
|
215
|
-
) -> Optional[Dict[str, Any]]:
|
|
210
|
+
async def get_session_by_refresh_token(self, refresh_jti: str) -> Optional[Dict[str, Any]]:
|
|
216
211
|
"""
|
|
217
212
|
Get session by refresh token JWT ID.
|
|
218
213
|
|
|
@@ -223,9 +218,7 @@ class SessionManager:
|
|
|
223
218
|
Session document or None if not found
|
|
224
219
|
"""
|
|
225
220
|
try:
|
|
226
|
-
session = await self.collection.find_one(
|
|
227
|
-
{"refresh_jti": refresh_jti, "active": True}
|
|
228
|
-
)
|
|
221
|
+
session = await self.collection.find_one({"refresh_jti": refresh_jti, "active": True})
|
|
229
222
|
return session
|
|
230
223
|
except (
|
|
231
224
|
OperationFailure,
|
|
@@ -267,9 +260,7 @@ class SessionManager:
|
|
|
267
260
|
stored_fingerprint = session.get("session_fingerprint")
|
|
268
261
|
|
|
269
262
|
if not stored_fingerprint:
|
|
270
|
-
strict_mode =
|
|
271
|
-
strict if strict is not None else self.fingerprinting_strict
|
|
272
|
-
)
|
|
263
|
+
strict_mode = strict if strict is not None else self.fingerprinting_strict
|
|
273
264
|
return not strict_mode
|
|
274
265
|
|
|
275
266
|
return stored_fingerprint == current_fingerprint
|
|
@@ -313,9 +304,7 @@ class SessionManager:
|
|
|
313
304
|
if active_only:
|
|
314
305
|
query["active"] = True
|
|
315
306
|
|
|
316
|
-
sessions = (
|
|
317
|
-
await self.collection.find(query).sort("last_seen", -1).to_list(None)
|
|
318
|
-
)
|
|
307
|
+
sessions = await self.collection.find(query).sort("last_seen", -1).to_list(None)
|
|
319
308
|
return sessions
|
|
320
309
|
except (
|
|
321
310
|
OperationFailure,
|
|
@@ -324,9 +313,7 @@ class SessionManager:
|
|
|
324
313
|
ValueError,
|
|
325
314
|
TypeError,
|
|
326
315
|
) as e:
|
|
327
|
-
logger.error(
|
|
328
|
-
f"Error getting sessions for user {user_id}: {e}", exc_info=True
|
|
329
|
-
)
|
|
316
|
+
logger.error(f"Error getting sessions for user {user_id}: {e}", exc_info=True)
|
|
330
317
|
return []
|
|
331
318
|
|
|
332
319
|
async def revoke_session(self, session_id: Any) -> bool:
|
|
@@ -400,9 +387,7 @@ class SessionManager:
|
|
|
400
387
|
ValueError,
|
|
401
388
|
TypeError,
|
|
402
389
|
) as e:
|
|
403
|
-
logger.error(
|
|
404
|
-
f"Error revoking sessions for user {user_id}: {e}", exc_info=True
|
|
405
|
-
)
|
|
390
|
+
logger.error(f"Error revoking sessions for user {user_id}: {e}", exc_info=True)
|
|
406
391
|
return 0
|
|
407
392
|
|
|
408
393
|
async def cleanup_inactive_sessions(self, user_id: Optional[str] = None) -> int:
|