mdb-engine 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +38 -6
- mdb_engine/auth/README.md +534 -11
- mdb_engine/auth/__init__.py +129 -28
- mdb_engine/auth/audit.py +592 -0
- mdb_engine/auth/casbin_factory.py +10 -14
- mdb_engine/auth/config_helpers.py +7 -6
- mdb_engine/auth/cookie_utils.py +3 -7
- mdb_engine/auth/csrf.py +373 -0
- mdb_engine/auth/decorators.py +3 -10
- mdb_engine/auth/dependencies.py +37 -45
- mdb_engine/auth/helpers.py +3 -3
- mdb_engine/auth/integration.py +30 -73
- mdb_engine/auth/jwt.py +2 -6
- mdb_engine/auth/middleware.py +77 -34
- mdb_engine/auth/oso_factory.py +16 -36
- mdb_engine/auth/provider.py +17 -38
- mdb_engine/auth/rate_limiter.py +504 -0
- mdb_engine/auth/restrictions.py +8 -24
- mdb_engine/auth/session_manager.py +14 -29
- mdb_engine/auth/shared_middleware.py +600 -0
- mdb_engine/auth/shared_users.py +759 -0
- mdb_engine/auth/token_store.py +14 -28
- mdb_engine/auth/users.py +54 -113
- mdb_engine/auth/utils.py +213 -15
- mdb_engine/cli/commands/generate.py +545 -9
- mdb_engine/cli/commands/validate.py +3 -7
- mdb_engine/cli/utils.py +3 -3
- mdb_engine/config.py +7 -21
- mdb_engine/constants.py +65 -0
- mdb_engine/core/README.md +117 -6
- mdb_engine/core/__init__.py +39 -7
- mdb_engine/core/app_registration.py +22 -41
- mdb_engine/core/app_secrets.py +290 -0
- mdb_engine/core/connection.py +18 -9
- mdb_engine/core/encryption.py +223 -0
- mdb_engine/core/engine.py +758 -95
- mdb_engine/core/index_management.py +12 -16
- mdb_engine/core/manifest.py +424 -135
- mdb_engine/core/ray_integration.py +435 -0
- mdb_engine/core/seeding.py +10 -18
- mdb_engine/core/service_initialization.py +12 -23
- mdb_engine/core/types.py +2 -5
- mdb_engine/database/README.md +112 -16
- mdb_engine/database/__init__.py +17 -6
- mdb_engine/database/abstraction.py +25 -37
- mdb_engine/database/connection.py +11 -18
- mdb_engine/database/query_validator.py +367 -0
- mdb_engine/database/resource_limiter.py +204 -0
- mdb_engine/database/scoped_wrapper.py +713 -196
- mdb_engine/embeddings/__init__.py +17 -9
- mdb_engine/embeddings/dependencies.py +1 -3
- mdb_engine/embeddings/service.py +11 -25
- mdb_engine/exceptions.py +92 -0
- mdb_engine/indexes/README.md +30 -13
- mdb_engine/indexes/__init__.py +1 -0
- mdb_engine/indexes/helpers.py +1 -1
- mdb_engine/indexes/manager.py +50 -114
- mdb_engine/memory/README.md +2 -2
- mdb_engine/memory/__init__.py +1 -2
- mdb_engine/memory/service.py +30 -87
- mdb_engine/observability/README.md +4 -2
- mdb_engine/observability/__init__.py +26 -9
- mdb_engine/observability/health.py +8 -9
- mdb_engine/observability/metrics.py +32 -12
- mdb_engine/routing/README.md +1 -1
- mdb_engine/routing/__init__.py +1 -3
- mdb_engine/routing/websockets.py +25 -60
- mdb_engine-0.1.7.dist-info/METADATA +285 -0
- mdb_engine-0.1.7.dist-info/RECORD +85 -0
- mdb_engine-0.1.6.dist-info/METADATA +0 -213
- mdb_engine-0.1.6.dist-info/RECORD +0 -75
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/WHEEL +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/top_level.txt +0 -0
mdb_engine/auth/provider.py
CHANGED
|
@@ -6,13 +6,12 @@ Defines the pluggable Authorization (AuthZ) interface for the platform.
|
|
|
6
6
|
This module is part of MDB_ENGINE - MongoDB Engine.
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
|
-
from __future__ import
|
|
10
|
-
annotations # MUST be first import for string type hints
|
|
9
|
+
from __future__ import annotations # MUST be first import for string type hints
|
|
11
10
|
|
|
12
11
|
import asyncio
|
|
13
12
|
import logging
|
|
14
13
|
import time
|
|
15
|
-
from typing import TYPE_CHECKING, Any,
|
|
14
|
+
from typing import TYPE_CHECKING, Any, Optional, Protocol
|
|
16
15
|
|
|
17
16
|
from ..constants import AUTHZ_CACHE_TTL, MAX_CACHE_SIZE
|
|
18
17
|
|
|
@@ -32,7 +31,7 @@ class AuthorizationProvider(Protocol):
|
|
|
32
31
|
subject: str,
|
|
33
32
|
resource: str,
|
|
34
33
|
action: str,
|
|
35
|
-
user_object: Optional[
|
|
34
|
+
user_object: Optional[dict[str, Any]] = None,
|
|
36
35
|
) -> bool:
|
|
37
36
|
"""
|
|
38
37
|
Checks if a subject is allowed to perform an action on a resource.
|
|
@@ -53,18 +52,16 @@ class CasbinAdapter:
|
|
|
53
52
|
"""
|
|
54
53
|
self._enforcer = enforcer
|
|
55
54
|
# Cache for authorization results: {(subject, resource, action): (result, timestamp)}
|
|
56
|
-
self._cache:
|
|
55
|
+
self._cache: dict[tuple[str, str, str], tuple[bool, float]] = {}
|
|
57
56
|
self._cache_lock = asyncio.Lock()
|
|
58
|
-
logger.info(
|
|
59
|
-
"✔️ CasbinAdapter initialized with async thread pool execution and caching."
|
|
60
|
-
)
|
|
57
|
+
logger.info("✔️ CasbinAdapter initialized with async thread pool execution and caching.")
|
|
61
58
|
|
|
62
59
|
async def check(
|
|
63
60
|
self,
|
|
64
61
|
subject: str,
|
|
65
62
|
resource: str,
|
|
66
63
|
action: str,
|
|
67
|
-
user_object: Optional[
|
|
64
|
+
user_object: Optional[dict[str, Any]] = None,
|
|
68
65
|
) -> bool:
|
|
69
66
|
"""
|
|
70
67
|
Performs the authorization check using the wrapped enforcer.
|
|
@@ -79,9 +76,7 @@ class CasbinAdapter:
|
|
|
79
76
|
cached_result, cached_time = self._cache[cache_key]
|
|
80
77
|
# Check if cache entry is still valid
|
|
81
78
|
if current_time - cached_time < AUTHZ_CACHE_TTL:
|
|
82
|
-
logger.debug(
|
|
83
|
-
f"Authorization cache HIT for ({subject}, {resource}, {action})"
|
|
84
|
-
)
|
|
79
|
+
logger.debug(f"Authorization cache HIT for ({subject}, {resource}, {action})")
|
|
85
80
|
return cached_result
|
|
86
81
|
# Cache expired, remove it
|
|
87
82
|
del self._cache[cache_key]
|
|
@@ -89,9 +84,7 @@ class CasbinAdapter:
|
|
|
89
84
|
try:
|
|
90
85
|
# The .enforce() method on AsyncEnforcer is synchronous and blocks the event loop.
|
|
91
86
|
# Run it in a thread pool to prevent blocking.
|
|
92
|
-
result = await asyncio.to_thread(
|
|
93
|
-
self._enforcer.enforce, subject, resource, action
|
|
94
|
-
)
|
|
87
|
+
result = await asyncio.to_thread(self._enforcer.enforce, subject, resource, action)
|
|
95
88
|
|
|
96
89
|
# Cache the result
|
|
97
90
|
async with self._cache_lock:
|
|
@@ -211,18 +204,16 @@ class OsoAdapter:
|
|
|
211
204
|
"""
|
|
212
205
|
self._oso = oso_client
|
|
213
206
|
# Cache for authorization results: {(subject, resource, action): (result, timestamp)}
|
|
214
|
-
self._cache:
|
|
207
|
+
self._cache: dict[tuple[str, str, str], tuple[bool, float]] = {}
|
|
215
208
|
self._cache_lock = asyncio.Lock()
|
|
216
|
-
logger.info(
|
|
217
|
-
"✔️ OsoAdapter initialized with async thread pool execution and caching."
|
|
218
|
-
)
|
|
209
|
+
logger.info("✔️ OsoAdapter initialized with async thread pool execution and caching.")
|
|
219
210
|
|
|
220
211
|
async def check(
|
|
221
212
|
self,
|
|
222
213
|
subject: str,
|
|
223
214
|
resource: str,
|
|
224
215
|
action: str,
|
|
225
|
-
user_object: Optional[
|
|
216
|
+
user_object: Optional[dict[str, Any]] = None,
|
|
226
217
|
) -> bool:
|
|
227
218
|
"""
|
|
228
219
|
Performs the authorization check using OSO.
|
|
@@ -239,9 +230,7 @@ class OsoAdapter:
|
|
|
239
230
|
cached_result, cached_time = self._cache[cache_key]
|
|
240
231
|
# Check if cache entry is still valid
|
|
241
232
|
if current_time - cached_time < AUTHZ_CACHE_TTL:
|
|
242
|
-
logger.debug(
|
|
243
|
-
f"Authorization cache HIT for ({subject}, {resource}, {action})"
|
|
244
|
-
)
|
|
233
|
+
logger.debug(f"Authorization cache HIT for ({subject}, {resource}, {action})")
|
|
245
234
|
return cached_result
|
|
246
235
|
# Cache expired, remove it
|
|
247
236
|
del self._cache[cache_key]
|
|
@@ -267,9 +256,7 @@ class OsoAdapter:
|
|
|
267
256
|
resource_obj = resource
|
|
268
257
|
|
|
269
258
|
# Run in thread pool to prevent blocking the event loop
|
|
270
|
-
result = await asyncio.to_thread(
|
|
271
|
-
self._oso.authorize, actor, action, resource_obj
|
|
272
|
-
)
|
|
259
|
+
result = await asyncio.to_thread(self._oso.authorize, actor, action, resource_obj)
|
|
273
260
|
|
|
274
261
|
# Cache the result
|
|
275
262
|
async with self._cache_lock:
|
|
@@ -333,9 +320,7 @@ class OsoAdapter:
|
|
|
333
320
|
)
|
|
334
321
|
elif hasattr(self._oso, "register_constant"):
|
|
335
322
|
# OSO library - we'd need to use a different approach
|
|
336
|
-
logger.warning(
|
|
337
|
-
"OSO library mode: add_policy needs to be handled via policy files"
|
|
338
|
-
)
|
|
323
|
+
logger.warning("OSO library mode: add_policy needs to be handled via policy files")
|
|
339
324
|
result = True # Assume success for now
|
|
340
325
|
else:
|
|
341
326
|
logger.warning("OSO client doesn't support insert() or tell() method")
|
|
@@ -406,9 +391,7 @@ class OsoAdapter:
|
|
|
406
391
|
self._oso.tell, "has_role", user, role, resource
|
|
407
392
|
)
|
|
408
393
|
else:
|
|
409
|
-
result = await asyncio.to_thread(
|
|
410
|
-
self._oso.tell, "has_role", user, role
|
|
411
|
-
)
|
|
394
|
+
result = await asyncio.to_thread(self._oso.tell, "has_role", user, role)
|
|
412
395
|
elif hasattr(self._oso, "register_constant"):
|
|
413
396
|
# OSO library - we'd need to use a different approach
|
|
414
397
|
logger.warning(
|
|
@@ -507,9 +490,7 @@ class OsoAdapter:
|
|
|
507
490
|
# OSO library - query facts
|
|
508
491
|
result = await asyncio.to_thread(
|
|
509
492
|
lambda: list(
|
|
510
|
-
self._oso.query_rule(
|
|
511
|
-
"has_role", user, role, accept_expression=True
|
|
512
|
-
)
|
|
493
|
+
self._oso.query_rule("has_role", user, role, accept_expression=True)
|
|
513
494
|
)
|
|
514
495
|
)
|
|
515
496
|
return len(result) > 0
|
|
@@ -548,9 +529,7 @@ class OsoAdapter:
|
|
|
548
529
|
user, role = params
|
|
549
530
|
# OSO Cloud uses delete() method
|
|
550
531
|
if hasattr(self._oso, "delete"):
|
|
551
|
-
result = await asyncio.to_thread(
|
|
552
|
-
self._oso.delete, "has_role", user, role
|
|
553
|
-
)
|
|
532
|
+
result = await asyncio.to_thread(self._oso.delete, "has_role", user, role)
|
|
554
533
|
else:
|
|
555
534
|
logger.warning("OSO client doesn't support delete() method")
|
|
556
535
|
result = False
|
|
@@ -0,0 +1,504 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Rate Limiting for Authentication Endpoints
|
|
3
|
+
|
|
4
|
+
Provides rate limiting middleware to protect auth endpoints from brute-force attacks.
|
|
5
|
+
Supports both in-memory storage (single instance) and MongoDB-backed storage (distributed).
|
|
6
|
+
|
|
7
|
+
This module is part of MDB_ENGINE - MongoDB Engine.
|
|
8
|
+
|
|
9
|
+
Features:
|
|
10
|
+
- Sliding window rate limiting algorithm
|
|
11
|
+
- Per-endpoint configurable limits via manifest
|
|
12
|
+
- IP + optional email-based tracking
|
|
13
|
+
- In-memory (default) or MongoDB storage
|
|
14
|
+
- 429 Too Many Requests with Retry-After header
|
|
15
|
+
|
|
16
|
+
Usage:
|
|
17
|
+
# Via middleware (recommended for shared auth)
|
|
18
|
+
app.add_middleware(
|
|
19
|
+
AuthRateLimitMiddleware,
|
|
20
|
+
limits={
|
|
21
|
+
"/login": RateLimit(max_attempts=5, window_seconds=300),
|
|
22
|
+
"/register": RateLimit(max_attempts=3, window_seconds=3600),
|
|
23
|
+
}
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# Via decorator (for specific endpoints)
|
|
27
|
+
@app.post("/login")
|
|
28
|
+
@rate_limit(max_attempts=5, window_seconds=300)
|
|
29
|
+
async def login(request: Request):
|
|
30
|
+
...
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
import logging
|
|
34
|
+
import time
|
|
35
|
+
from collections import defaultdict
|
|
36
|
+
from dataclasses import dataclass
|
|
37
|
+
from datetime import datetime, timedelta
|
|
38
|
+
from functools import wraps
|
|
39
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
40
|
+
|
|
41
|
+
from pymongo.errors import OperationFailure
|
|
42
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
43
|
+
from starlette.requests import Request
|
|
44
|
+
from starlette.responses import JSONResponse, Response
|
|
45
|
+
|
|
46
|
+
# Module logger: used for store setup messages and rate-limit violation warnings.
logger = logging.getLogger(__name__)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class RateLimit:
    """
    Per-endpoint rate limit settings.

    Attributes:
        max_attempts: Number of attempts permitted inside one window.
        window_seconds: Length of the sliding window, in seconds.
    """

    max_attempts: int = 5
    window_seconds: int = 300  # five minutes

    def to_dict(self) -> Dict[str, int]:
        """Return the limit as a plain ``{"max_attempts": ..., "window_seconds": ...}`` dict."""
        return dict(max_attempts=self.max_attempts, window_seconds=self.window_seconds)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Built-in limits used when no explicit configuration is supplied
# (RateLimit positional order is max_attempts, window_seconds).
DEFAULT_AUTH_RATE_LIMITS: Dict[str, RateLimit] = {
    "/login": RateLimit(5, 300),  # 5 attempts per 5 minutes
    "/register": RateLimit(3, 3600),  # 3 attempts per hour
    "/logout": RateLimit(10, 60),  # 10 attempts per minute
}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class InMemoryRateLimitStore:
    """
    In-memory rate limit storage using sliding window algorithm.

    Suitable for single-instance deployments. For distributed systems,
    use MongoDBRateLimitStore instead.

    Entries for an identifier are only pruned when that identifier is touched
    again, so a stream of distinct identifiers (many source IPs or e-mail
    addresses) would otherwise grow ``_storage`` without bound.
    ``record_attempt`` therefore runs :meth:`cleanup` at most once every
    ``cleanup_interval_seconds``. The sweep uses the largest window passed to
    this store as its age cutoff, so no entry that can still count toward a
    limit is evicted.

    None of the methods contain an ``await``, so within one event loop each
    call runs to completion without interleaving with other tasks.
    """

    def __init__(self, cleanup_interval_seconds: Optional[float] = 60.0):
        """
        Args:
            cleanup_interval_seconds: Minimum seconds between automatic sweeps
                triggered by ``record_attempt``. ``None`` disables automatic
                sweeping (call :meth:`cleanup` yourself).
        """
        # Structure: {identifier: [(timestamp, count), ...]}
        self._storage: Dict[str, List[Tuple[float, int]]] = defaultdict(list)
        self._cleanup_interval = cleanup_interval_seconds
        self._last_cleanup = time.time()
        # Largest window seen by record_attempt/get_count. Entries older than
        # this cannot fall inside any window that has been used.
        self._max_window_seen = 0

    async def record_attempt(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """
        Record an attempt and return current count in window.

        Args:
            identifier: Unique identifier (e.g., "login:192.168.1.1:user@example.com")
            window_seconds: Time window in seconds

        Returns:
            Number of attempts in the current window (including this one)
        """
        now = time.time()
        self._max_window_seen = max(self._max_window_seen, window_seconds)
        # Sweep before touching `identifier` so its list is never deleted
        # out from under us below.
        self._maybe_cleanup(now)

        cutoff = now - window_seconds
        entries = self._storage[identifier]
        # Drop attempts that slid out of the window, then add this one.
        entries[:] = [(ts, c) for ts, c in entries if ts > cutoff]
        entries.append((now, 1))
        return sum(c for _, c in entries)

    async def get_count(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """Get current attempt count without recording."""
        self._max_window_seen = max(self._max_window_seen, window_seconds)
        cutoff = time.time() - window_seconds
        # .get() rather than [] so a read does not create an empty entry.
        entries = self._storage.get(identifier, [])
        return sum(c for ts, c in entries if ts > cutoff)

    async def reset(self, identifier: str) -> None:
        """Reset rate limit for an identifier (e.g., after successful login)."""
        self._storage.pop(identifier, None)

    def cleanup(self, max_age_seconds: int = 7200) -> int:
        """
        Clean up old entries to prevent memory growth.

        Args:
            max_age_seconds: Remove entries older than this (default: 2 hours)

        Returns:
            Number of identifiers cleaned up
        """
        cutoff = time.time() - max_age_seconds

        stale = []
        for identifier, entries in self._storage.items():
            entries[:] = [(ts, c) for ts, c in entries if ts > cutoff]
            if not entries:
                stale.append(identifier)

        # Delete after iterating: the dict must not change size mid-loop.
        for identifier in stale:
            del self._storage[identifier]

        return len(stale)

    def _maybe_cleanup(self, now: float) -> None:
        """Sweep stale identifiers if the automatic cleanup interval has elapsed."""
        if self._cleanup_interval is None:
            return
        if now - self._last_cleanup < self._cleanup_interval:
            return
        self._last_cleanup = now
        self.cleanup(max_age_seconds=self._max_window_seen)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class MongoDBRateLimitStore:
    """
    MongoDB-backed rate limit storage for distributed deployments.

    Uses TTL indexes for automatic cleanup. Every attempt is stored as one
    document with ``identifier``, ``timestamp`` and ``expires_at`` fields. The
    TTL index on ``expires_at`` lets the server delete an attempt once its
    window has passed.

    Timestamps are timezone-aware UTC datetimes (``datetime.utcnow()`` is
    deprecated as of Python 3.12). PyMongo encodes both naive (assumed UTC) and
    aware datetimes as UTC milliseconds, so documents written with naive
    timestamps still compare correctly against the aware cutoffs used here.
    """

    COLLECTION = "_mdb_engine_rate_limits"

    def __init__(self, db):
        """
        Initialize MongoDB rate limit store.

        Args:
            db: MongoDB database instance (Motor AsyncIOMotorDatabase)
        """
        self._db = db
        self._collection = db[self.COLLECTION]
        self._indexes_created = False

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        from datetime import timezone

        return datetime.now(timezone.utc)

    async def ensure_indexes(self) -> None:
        """
        Create the lookup and TTL indexes once per store instance.

        ``OperationFailure`` is logged and swallowed. The created-flag stays
        unset in that case, so creation is retried on the next call.
        """
        if self._indexes_created:
            return

        try:
            # Compound index for lookups
            await self._collection.create_index(
                [("identifier", 1), ("timestamp", -1)], name="identifier_timestamp_idx"
            )
            # TTL index for cleanup
            await self._collection.create_index(
                "expires_at", expireAfterSeconds=0, name="expires_at_ttl_idx"
            )
            self._indexes_created = True
            logger.info("Rate limit indexes ensured")
        except OperationFailure as e:
            logger.warning(f"Failed to create rate limit indexes: {e}")

    async def record_attempt(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """Record an attempt and return current count in window (including this one)."""
        await self.ensure_indexes()

        now = self._utcnow()
        window = timedelta(seconds=window_seconds)

        await self._collection.insert_one(
            {
                "identifier": identifier,
                "timestamp": now,
                # The TTL index removes the document once it can no longer
                # fall inside this window.
                "expires_at": now + window,
            }
        )

        return await self._count_since(identifier, now - window)

    async def get_count(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """Get current attempt count without recording."""
        await self.ensure_indexes()
        cutoff = self._utcnow() - timedelta(seconds=window_seconds)
        return await self._count_since(identifier, cutoff)

    async def reset(self, identifier: str) -> None:
        """Reset rate limit for an identifier."""
        await self._collection.delete_many({"identifier": identifier})

    async def _count_since(self, identifier: str, cutoff: datetime) -> int:
        """Count attempts for ``identifier`` whose timestamp is at or after ``cutoff``."""
        return await self._collection.count_documents(
            {
                "identifier": identifier,
                "timestamp": {"$gte": cutoff},
            }
        )
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
# Global in-memory store (shared across middleware instances in same process)
|
|
249
|
+
# NOTE: this store lives in process memory, so each worker process keeps its
# own counts; use MongoDBRateLimitStore when limits must be shared.
_default_store = InMemoryRateLimitStore()
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
class AuthRateLimitMiddleware(BaseHTTPMiddleware):
    """
    ASGI middleware for rate limiting authentication endpoints.

    Automatically protects /login, /register, and other auth endpoints
    from brute-force attacks.

    Features:
    - Configurable per-endpoint limits
    - IP + email tracking for login attempts
    - 429 responses with Retry-After header
    - Skips rate limiting for non-auth endpoints

    Only ``POST`` requests whose path exactly matches a key in ``limits`` are
    counted. A request is rejected when its identifier already has
    ``max_attempts`` attempts inside the window. Rejected requests are not
    recorded. A response with status below 300 from ``/login`` clears that
    identifier's history.

    Security note: when ``trust_proxy_headers`` is true (the default, kept for
    backward compatibility), the client IP is read from ``X-Forwarded-For`` /
    ``X-Real-IP``. Those headers are client-controlled unless a trusted proxy
    overwrites them. A client reaching the app directly can therefore rotate
    them to get a fresh bucket on every request. Pass
    ``trust_proxy_headers=False`` unless all traffic arrives through a proxy
    that sets these headers.

    Usage:
        # Basic usage with defaults
        app.add_middleware(AuthRateLimitMiddleware)

        # Custom limits
        app.add_middleware(
            AuthRateLimitMiddleware,
            limits={"/login": RateLimit(max_attempts=3, window_seconds=60)}
        )

        # With MongoDB storage for distributed deployments
        app.add_middleware(
            AuthRateLimitMiddleware,
            store=MongoDBRateLimitStore(db)
        )
    """

    def __init__(
        self,
        app: Callable,
        limits: Optional[Dict[str, RateLimit]] = None,
        store: Optional["InMemoryRateLimitStore | MongoDBRateLimitStore"] = None,
        include_email_in_key: bool = True,
        trust_proxy_headers: bool = True,
    ):
        """
        Initialize rate limit middleware.

        Args:
            app: ASGI application
            limits: Dict of path -> RateLimit config. Defaults to DEFAULT_AUTH_RATE_LIMITS.
            store: Rate limit storage backend (in-memory or MongoDB).
                Defaults to the process-wide in-memory store.
            include_email_in_key: Include email in rate limit key for more granular limits.
            trust_proxy_headers: Derive the client IP from X-Forwarded-For / X-Real-IP.
                Disable when the app is reachable without a trusted proxy
                (see the class docstring).
        """
        super().__init__(app)
        self._limits = limits or DEFAULT_AUTH_RATE_LIMITS
        self._store = store if store is not None else _default_store
        self._include_email_in_key = include_email_in_key
        self._trust_proxy_headers = trust_proxy_headers

        logger.info(
            f"AuthRateLimitMiddleware initialized with limits for: {list(self._limits.keys())}"
        )

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Response],
    ) -> Response:
        """Process request through rate limiter."""
        path = request.url.path
        method = request.method

        # Only rate limit POST requests to configured endpoints
        if method != "POST" or path not in self._limits:
            return await call_next(request)

        limit = self._limits[path]
        identifier = await self._build_identifier(request, path)

        # Check current count before recording, so rejected requests do not
        # extend the lockout.
        current_count = await self._store.get_count(identifier, limit.window_seconds)

        if current_count >= limit.max_attempts:
            logger.warning(
                f"Rate limit exceeded: {identifier} "
                f"({current_count}/{limit.max_attempts} in {limit.window_seconds}s)"
            )
            return self._rate_limit_response(limit.window_seconds)

        # Record this attempt
        await self._store.record_attempt(identifier, limit.window_seconds)

        # Process request
        response = await call_next(request)

        # Reset on successful login (status < 300). The path is hard-coded.
        if response.status_code < 300 and path == "/login":
            await self._store.reset(identifier)

        return response

    async def _build_identifier(self, request: Request, path: str) -> str:
        """Build the "path:ip[:email]" rate limit identifier for a request."""
        parts = [path, self._get_client_ip(request)]

        # Optionally add email for more granular rate limiting
        if self._include_email_in_key:
            email = await self._extract_email(request)
            if email:
                parts.append(email)

        return ":".join(parts)

    def _get_client_ip(self, request: Request) -> str:
        """Get client IP, honouring proxy headers only when trust_proxy_headers is set."""
        if self._trust_proxy_headers:
            # X-Forwarded-For: "client, proxy1, proxy2" - first entry is the client.
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                first_hop = forwarded_for.split(",")[0].strip()
                if first_hop:
                    return first_hop

            real_ip = request.headers.get("X-Real-IP")
            if real_ip and real_ip.strip():
                return real_ip.strip()

        # Fall back to direct client IP
        if request.client:
            return request.client.host

        return "unknown"

    async def _extract_email(self, request: Request) -> Optional[str]:
        """
        Try to extract a normalized email from a JSON request body.

        Returns the ``"email"`` field stripped and lower-cased, so that case
        variants share one bucket. Returns None when the body is not JSON, not
        a JSON object, or has no non-empty string ``"email"``.
        """
        content_type = request.headers.get("content-type", "")
        if "application/json" not in content_type:
            return None

        try:
            # NOTE(review): this reads the body inside BaseHTTPMiddleware.
            # Whether the route can still read it afterwards depends on the
            # Starlette version (newer releases cache it for replay) - confirm
            # for the pinned version.
            body = await request.body()
            if not body:
                return None
            import json

            data = json.loads(body)
        except (ValueError, UnicodeDecodeError):
            return None

        # A JSON body may be an array, string or number, and "email" may be any
        # JSON type. Only a string is usable in the key; anything else would
        # raise in .get() or ":".join().
        if not isinstance(data, dict):
            return None
        email = data.get("email")
        if not isinstance(email, str):
            return None

        email = email.strip().lower()
        return email or None

    @staticmethod
    def _rate_limit_response(retry_after: int) -> JSONResponse:
        """Return 429 Too Many Requests response."""
        return JSONResponse(
            status_code=429,
            content={
                "detail": f"Too many attempts. Please try again in {retry_after} seconds.",
                "error": "rate_limit_exceeded",
                "retry_after": retry_after,
            },
            headers={"Retry-After": str(retry_after)},
        )
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def create_rate_limit_middleware(
    manifest_auth: Dict[str, Any],
    store: Optional[InMemoryRateLimitStore] = None,
) -> type:
    """
    Factory function to create rate limit middleware from manifest config.

    Args:
        manifest_auth: Auth section from manifest
        store: Optional storage backend

    Returns:
        Configured middleware class. When no limits are configured it uses a
        copy of DEFAULT_AUTH_RATE_LIMITS.

    Manifest format:
        {
            "auth": {
                "rate_limits": {
                    "/login": {"max_attempts": 5, "window_seconds": 300},
                    "/register": {"max_attempts": 3, "window_seconds": 3600}
                }
            }
        }

    Values under ``rate_limits`` may also be ``RateLimit`` instances, which
    are used as-is.
    """
    # `or {}` also covers a missing auth section and an explicit
    # `"rate_limits": null`. `.get(..., {})` alone would return None for the
    # latter and crash on `.items()`.
    rate_limits_config = (manifest_auth or {}).get("rate_limits") or {}

    limits: Dict[str, RateLimit] = {}
    for path, config in rate_limits_config.items():
        if isinstance(config, RateLimit):
            limits[path] = config
        else:
            limits[path] = RateLimit(
                max_attempts=config.get("max_attempts", 5),
                window_seconds=config.get("window_seconds", 300),
            )

    # Use defaults if no config provided
    if not limits:
        limits = DEFAULT_AUTH_RATE_LIMITS.copy()

    class ConfiguredRateLimitMiddleware(AuthRateLimitMiddleware):
        """AuthRateLimitMiddleware pre-bound to the manifest-derived limits and store."""

        def __init__(self, app: Callable):
            super().__init__(app, limits=limits, store=store)

    return ConfiguredRateLimitMiddleware
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
def rate_limit(
    max_attempts: int = 5,
    window_seconds: int = 300,
    key_func: Optional[Callable[[Request], str]] = None,
    store: Optional["InMemoryRateLimitStore | MongoDBRateLimitStore"] = None,
):
    """
    Decorator for rate limiting individual endpoints.

    Usage:
        @app.post("/login")
        @rate_limit(max_attempts=5, window_seconds=300)
        async def login(request: Request):
            ...

    Semantics match AuthRateLimitMiddleware. At most ``max_attempts`` calls are
    allowed per window, and rejected calls are not recorded. A client that keeps
    retrying therefore regains access once its earlier attempts leave the
    window, instead of being locked out indefinitely.

    Args:
        max_attempts: Maximum attempts in window
        window_seconds: Time window in seconds
        key_func: Optional function to generate rate limit key from request.
            Defaults to "<function name>:<client ip>".
        store: Optional storage backend. Defaults to the process-wide
            in-memory store.
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            # Resolved per call so the module-level default is looked up lazily.
            backend = store if store is not None else _default_store

            # Build identifier
            if key_func:
                identifier = key_func(request)
            else:
                client_ip = request.client.host if request.client else "unknown"
                identifier = f"{func.__name__}:{client_ip}"

            # Check before recording, so rejected calls do not extend the lockout.
            count = await backend.get_count(identifier, window_seconds)

            if count >= max_attempts:
                logger.warning(f"Rate limit exceeded: {identifier} ({count}/{max_attempts})")
                return JSONResponse(
                    status_code=429,
                    content={
                        "detail": f"Too many attempts. Try again in {window_seconds} seconds.",
                        "error": "rate_limit_exceeded",
                    },
                    headers={"Retry-After": str(window_seconds)},
                )

            await backend.record_attempt(identifier, window_seconds)
            return await func(request, *args, **kwargs)

        return wrapper

    return decorator
|