mdb-engine 0.1.6__py3-none-any.whl → 0.4.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +116 -11
- mdb_engine/auth/ARCHITECTURE.md +112 -0
- mdb_engine/auth/README.md +654 -11
- mdb_engine/auth/__init__.py +136 -29
- mdb_engine/auth/audit.py +592 -0
- mdb_engine/auth/base.py +252 -0
- mdb_engine/auth/casbin_factory.py +265 -70
- mdb_engine/auth/config_defaults.py +5 -5
- mdb_engine/auth/config_helpers.py +19 -18
- mdb_engine/auth/cookie_utils.py +12 -16
- mdb_engine/auth/csrf.py +483 -0
- mdb_engine/auth/decorators.py +10 -16
- mdb_engine/auth/dependencies.py +69 -71
- mdb_engine/auth/helpers.py +3 -3
- mdb_engine/auth/integration.py +61 -88
- mdb_engine/auth/jwt.py +11 -15
- mdb_engine/auth/middleware.py +79 -35
- mdb_engine/auth/oso_factory.py +21 -41
- mdb_engine/auth/provider.py +270 -171
- mdb_engine/auth/rate_limiter.py +505 -0
- mdb_engine/auth/restrictions.py +21 -36
- mdb_engine/auth/session_manager.py +24 -41
- mdb_engine/auth/shared_middleware.py +977 -0
- mdb_engine/auth/shared_users.py +775 -0
- mdb_engine/auth/token_lifecycle.py +10 -12
- mdb_engine/auth/token_store.py +17 -32
- mdb_engine/auth/users.py +99 -159
- mdb_engine/auth/utils.py +236 -42
- mdb_engine/cli/commands/generate.py +546 -10
- mdb_engine/cli/commands/validate.py +3 -7
- mdb_engine/cli/utils.py +7 -7
- mdb_engine/config.py +13 -28
- mdb_engine/constants.py +65 -0
- mdb_engine/core/README.md +117 -6
- mdb_engine/core/__init__.py +39 -7
- mdb_engine/core/app_registration.py +31 -50
- mdb_engine/core/app_secrets.py +289 -0
- mdb_engine/core/connection.py +20 -12
- mdb_engine/core/encryption.py +222 -0
- mdb_engine/core/engine.py +2862 -115
- mdb_engine/core/index_management.py +12 -16
- mdb_engine/core/manifest.py +628 -204
- mdb_engine/core/ray_integration.py +436 -0
- mdb_engine/core/seeding.py +13 -21
- mdb_engine/core/service_initialization.py +20 -30
- mdb_engine/core/types.py +40 -43
- mdb_engine/database/README.md +140 -17
- mdb_engine/database/__init__.py +17 -6
- mdb_engine/database/abstraction.py +37 -50
- mdb_engine/database/connection.py +51 -30
- mdb_engine/database/query_validator.py +367 -0
- mdb_engine/database/resource_limiter.py +204 -0
- mdb_engine/database/scoped_wrapper.py +747 -237
- mdb_engine/dependencies.py +427 -0
- mdb_engine/di/__init__.py +34 -0
- mdb_engine/di/container.py +247 -0
- mdb_engine/di/providers.py +206 -0
- mdb_engine/di/scopes.py +139 -0
- mdb_engine/embeddings/README.md +54 -24
- mdb_engine/embeddings/__init__.py +31 -24
- mdb_engine/embeddings/dependencies.py +38 -155
- mdb_engine/embeddings/service.py +78 -75
- mdb_engine/exceptions.py +104 -12
- mdb_engine/indexes/README.md +30 -13
- mdb_engine/indexes/__init__.py +1 -0
- mdb_engine/indexes/helpers.py +11 -11
- mdb_engine/indexes/manager.py +59 -123
- mdb_engine/memory/README.md +95 -4
- mdb_engine/memory/__init__.py +1 -2
- mdb_engine/memory/service.py +363 -1168
- mdb_engine/observability/README.md +4 -2
- mdb_engine/observability/__init__.py +26 -9
- mdb_engine/observability/health.py +17 -17
- mdb_engine/observability/logging.py +10 -10
- mdb_engine/observability/metrics.py +40 -19
- mdb_engine/repositories/__init__.py +34 -0
- mdb_engine/repositories/base.py +325 -0
- mdb_engine/repositories/mongo.py +233 -0
- mdb_engine/repositories/unit_of_work.py +166 -0
- mdb_engine/routing/README.md +1 -1
- mdb_engine/routing/__init__.py +1 -3
- mdb_engine/routing/websockets.py +41 -75
- mdb_engine/utils/__init__.py +3 -1
- mdb_engine/utils/mongo.py +117 -0
- mdb_engine-0.4.12.dist-info/METADATA +492 -0
- mdb_engine-0.4.12.dist-info/RECORD +97 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/WHEEL +1 -1
- mdb_engine-0.1.6.dist-info/METADATA +0 -213
- mdb_engine-0.1.6.dist-info/RECORD +0 -75
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,505 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Rate Limiting for Authentication Endpoints
|
|
3
|
+
|
|
4
|
+
Provides rate limiting middleware to protect auth endpoints from brute-force attacks.
|
|
5
|
+
Supports both in-memory storage (single instance) and MongoDB-backed storage (distributed).
|
|
6
|
+
|
|
7
|
+
This module is part of MDB_ENGINE - MongoDB Engine.
|
|
8
|
+
|
|
9
|
+
Features:
|
|
10
|
+
- Sliding window rate limiting algorithm
|
|
11
|
+
- Per-endpoint configurable limits via manifest
|
|
12
|
+
- IP + optional email-based tracking
|
|
13
|
+
- In-memory (default) or MongoDB storage
|
|
14
|
+
- 429 Too Many Requests with Retry-After header
|
|
15
|
+
|
|
16
|
+
Usage:
|
|
17
|
+
# Via middleware (recommended for shared auth)
|
|
18
|
+
app.add_middleware(
|
|
19
|
+
AuthRateLimitMiddleware,
|
|
20
|
+
limits={
|
|
21
|
+
"/login": RateLimit(max_attempts=5, window_seconds=300),
|
|
22
|
+
"/register": RateLimit(max_attempts=3, window_seconds=3600),
|
|
23
|
+
}
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# Via decorator (for specific endpoints)
|
|
27
|
+
@app.post("/login")
|
|
28
|
+
@rate_limit(max_attempts=5, window_seconds=300)
|
|
29
|
+
async def login(request: Request):
|
|
30
|
+
...
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
import logging
|
|
34
|
+
import time
|
|
35
|
+
from collections import defaultdict
|
|
36
|
+
from collections.abc import Callable
|
|
37
|
+
from dataclasses import dataclass
|
|
38
|
+
from datetime import datetime, timedelta, timezone
|
|
39
|
+
from functools import wraps
|
|
40
|
+
from typing import Any
|
|
41
|
+
|
|
42
|
+
from pymongo.errors import OperationFailure
|
|
43
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
44
|
+
from starlette.requests import Request
|
|
45
|
+
from starlette.responses import JSONResponse, Response
|
|
46
|
+
|
|
47
|
+
# Module-level logger for rate-limit events (initialization, limit hits, index setup).
logger = logging.getLogger(__name__)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class RateLimit:
    """
    Rate limit settings for a single endpoint.

    Holds the maximum number of attempts and the length (in seconds) of the
    sliding window in which they are counted.
    """

    max_attempts: int = 5  # attempts permitted per window
    window_seconds: int = 300  # window length in seconds (5 minutes)

    def to_dict(self) -> dict[str, int]:
        """Serialize these settings into a plain dictionary."""
        serialized: dict[str, int] = {}
        serialized["max_attempts"] = self.max_attempts
        serialized["window_seconds"] = self.window_seconds
        return serialized
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Built-in per-path limits for the standard auth endpoints. They are used when
# AuthRateLimitMiddleware gets no ``limits`` and when the manifest passed to
# create_rate_limit_middleware configures none.
DEFAULT_AUTH_RATE_LIMITS: dict[str, RateLimit] = {
    # 5 attempts per 5 minutes
    "/login": RateLimit(5, 300),
    # 3 attempts per hour
    "/register": RateLimit(3, 3600),
    # 10 attempts per minute
    "/logout": RateLimit(10, 60),
}
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class InMemoryRateLimitStore:
    """
    In-memory rate limit storage using sliding window algorithm.

    Suitable for single-instance deployments. For distributed systems,
    use MongoDBRateLimitStore instead.

    Each ``record_attempt`` call creates a bucket for its identifier.
    Identifiers can embed request data (client IP, submitted email). For that
    reason, buckets whose attempts have all aged out are pruned automatically,
    at most once per ``cleanup_interval_seconds``, so the store cannot grow
    without bound. ``cleanup()`` can still be called explicitly.
    """

    # Minimum age (seconds) an attempt must reach before the automatic sweep
    # may discard it; equals cleanup()'s default max age.
    _MIN_AUTO_CLEANUP_AGE = 7200

    def __init__(self, cleanup_interval_seconds: float = 60.0):
        """
        Initialize an empty store.

        Args:
            cleanup_interval_seconds: Minimum time between automatic sweeps of
                stale identifiers, triggered from ``record_attempt``.
        """
        # Structure: {identifier: [(timestamp, count), ...]}
        self._storage: dict[str, list[tuple[float, int]]] = defaultdict(list)
        self._cleanup_interval = cleanup_interval_seconds
        self._last_cleanup = time.time()
        # Largest window passed to record_attempt so far. The automatic sweep
        # never drops attempts younger than this, so live windows are unaffected.
        self._max_window_seconds = 0

    async def record_attempt(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """
        Record an attempt and return current count in window.

        Args:
            identifier: Unique identifier (e.g., "login:192.168.1.1:user@example.com")
            window_seconds: Time window in seconds

        Returns:
            Number of attempts in the current window (including this one)
        """
        now = time.time()
        cutoff = now - window_seconds

        self._max_window_seconds = max(self._max_window_seconds, window_seconds)
        # Sweep BEFORE fetching this identifier's bucket: cleanup() may delete
        # the bucket, and appending to a detached list would lose the attempt.
        self._maybe_cleanup(now)

        # Drop attempts that fell out of the window, then add this one.
        entries = self._storage[identifier]
        entries[:] = [(ts, c) for ts, c in entries if ts > cutoff]
        entries.append((now, 1))

        return sum(c for _, c in entries)

    async def get_count(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """Get current attempt count without recording."""
        cutoff = time.time() - window_seconds
        # .get() (not indexing) so a lookup does not create an empty bucket.
        entries = self._storage.get(identifier, [])
        return sum(c for ts, c in entries if ts > cutoff)

    async def reset(self, identifier: str) -> None:
        """Reset rate limit for an identifier (e.g., after successful login)."""
        self._storage.pop(identifier, None)

    def cleanup(self, max_age_seconds: int = 7200) -> int:
        """
        Clean up old entries to prevent memory growth.

        Args:
            max_age_seconds: Remove entries older than this (default: 2 hours)

        Returns:
            Number of identifiers cleaned up
        """
        now = time.time()
        cutoff = now - max_age_seconds

        stale_identifiers = []
        for identifier, entries in self._storage.items():
            entries[:] = [(ts, c) for ts, c in entries if ts > cutoff]
            if not entries:
                stale_identifiers.append(identifier)

        # Delete after iterating; the dict must not change size mid-iteration.
        for identifier in stale_identifiers:
            del self._storage[identifier]

        self._last_cleanup = now
        return len(stale_identifiers)

    def _maybe_cleanup(self, now: float) -> None:
        """Run cleanup() if the automatic sweep interval has elapsed."""
        if now - self._last_cleanup < self._cleanup_interval:
            return
        self.cleanup(max(self._MIN_AUTO_CLEANUP_AGE, self._max_window_seconds))
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class MongoDBRateLimitStore:
    """
    MongoDB-backed rate limit storage for distributed deployments.

    Each attempt is stored as its own document. A TTL index on ``expires_at``
    lets MongoDB delete an attempt once its window has passed.
    """

    COLLECTION = "_mdb_engine_rate_limits"

    def __init__(self, db):
        """
        Initialize MongoDB rate limit store.

        Args:
            db: MongoDB database instance (Motor AsyncIOMotorDatabase)
        """
        self._db = db
        self._collection = db[self.COLLECTION]
        self._indexes_created = False

    async def ensure_indexes(self) -> None:
        """
        Create the lookup and TTL indexes (once per store instance).

        An ``OperationFailure`` is logged, not raised. Because the
        "created" flag is only set on success, creation is retried on the
        next call.
        """
        if self._indexes_created:
            return

        try:
            # Compound index serving the per-identifier window count.
            await self._collection.create_index(
                [("identifier", 1), ("timestamp", -1)], name="identifier_timestamp_idx"
            )
            # TTL index: a document is removed once its expires_at has passed.
            await self._collection.create_index(
                "expires_at", expireAfterSeconds=0, name="expires_at_ttl_idx"
            )
            self._indexes_created = True
            logger.info("Rate limit indexes ensured")
        except OperationFailure as e:
            logger.warning("Failed to create rate limit indexes: %s", e)

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        # datetime.utcnow() returns a naive value and is deprecated since Python 3.12.
        return datetime.now(timezone.utc)

    async def _count_since(self, identifier: str, cutoff: datetime) -> int:
        """Count attempts stored for ``identifier`` with timestamp >= ``cutoff``."""
        return await self._collection.count_documents(
            {
                "identifier": identifier,
                "timestamp": {"$gte": cutoff},
            }
        )

    async def record_attempt(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """Record an attempt and return current count in window."""
        await self.ensure_indexes()

        now = self._utcnow()
        window = timedelta(seconds=window_seconds)

        # One document per attempt; it expires when it can no longer count.
        await self._collection.insert_one(
            {
                "identifier": identifier,
                "timestamp": now,
                "expires_at": now + window,
            }
        )

        return await self._count_since(identifier, now - window)

    async def get_count(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """Get current attempt count without recording."""
        await self.ensure_indexes()

        cutoff = self._utcnow() - timedelta(seconds=window_seconds)
        return await self._count_since(identifier, cutoff)

    async def reset(self, identifier: str) -> None:
        """Reset rate limit for an identifier."""
        await self._collection.delete_many({"identifier": identifier})
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
# Process-wide fallback store, shared within this process. It is used by every
# AuthRateLimitMiddleware created without an explicit ``store`` and by the
# @rate_limit decorator.
_default_store: InMemoryRateLimitStore = InMemoryRateLimitStore()
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
class AuthRateLimitMiddleware(BaseHTTPMiddleware):
    """
    ASGI middleware for rate limiting authentication endpoints.

    Automatically protects /login, /register, and other auth endpoints
    from brute-force attacks.

    Features:
    - Configurable per-endpoint limits
    - IP + email tracking for login attempts
    - 429 responses with Retry-After header
    - Skips rate limiting for non-auth endpoints

    Behavior:
    - Only POST requests whose path exactly matches a configured key are limited.
    - A request is rejected when its window already holds ``max_attempts``
      attempts. Rejected requests are not counted.
    - A ``/login`` response with status below 300 clears the caller's bucket.

    Security note:
        With ``trust_proxy_headers=True`` (the default), the client IP is read
        from ``X-Forwarded-For`` / ``X-Real-IP`` when present. Unless a trusted
        reverse proxy overwrites those headers, a client can set them itself
        and rotate them to get a fresh bucket on every request. Pass
        ``trust_proxy_headers=False`` when the app is reachable directly.

    Usage:
        # Basic usage with defaults
        app.add_middleware(AuthRateLimitMiddleware)

        # Custom limits
        app.add_middleware(
            AuthRateLimitMiddleware,
            limits={"/login": RateLimit(max_attempts=3, window_seconds=60)}
        )

        # With MongoDB storage for distributed deployments
        app.add_middleware(
            AuthRateLimitMiddleware,
            store=MongoDBRateLimitStore(db)
        )
    """

    # Longest email accepted into a rate-limit key. Longer values are ignored,
    # so request bodies cannot create arbitrarily large keys or log lines.
    _MAX_EMAIL_LENGTH = 320

    def __init__(
        self,
        app: Callable,
        limits: dict[str, RateLimit] | None = None,
        store: InMemoryRateLimitStore | MongoDBRateLimitStore | None = None,
        include_email_in_key: bool = True,
        *,
        trust_proxy_headers: bool = True,
    ):
        """
        Initialize rate limit middleware.

        Args:
            app: ASGI application
            limits: Dict of path -> RateLimit config. Defaults to DEFAULT_AUTH_RATE_LIMITS.
            store: Rate limit storage backend. Defaults to in-memory store.
            include_email_in_key: Include email in rate limit key for more granular limits.
            trust_proxy_headers: Derive the client IP from X-Forwarded-For /
                X-Real-IP when present. Only safe behind a proxy that sets them.
        """
        super().__init__(app)
        self._limits = limits or DEFAULT_AUTH_RATE_LIMITS
        self._store = store or _default_store
        self._include_email_in_key = include_email_in_key
        self._trust_proxy_headers = trust_proxy_headers

        logger.info(
            "AuthRateLimitMiddleware initialized with limits for: %s",
            list(self._limits.keys()),
        )

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Response],
    ) -> Response:
        """Process request through rate limiter."""
        path = request.url.path

        # Only rate limit POST requests to configured endpoints
        if request.method != "POST" or path not in self._limits:
            return await call_next(request)

        limit = self._limits[path]
        identifier = await self._build_identifier(request, path)

        # Check current count (before recording this attempt), so rejected
        # requests do not extend the window.
        current_count = await self._store.get_count(identifier, limit.window_seconds)

        if current_count >= limit.max_attempts:
            logger.warning(
                "Rate limit exceeded: %s (%s/%s in %ss)",
                identifier,
                current_count,
                limit.max_attempts,
                limit.window_seconds,
            )
            return self._rate_limit_response(limit.window_seconds)

        await self._store.record_attempt(identifier, limit.window_seconds)

        response = await call_next(request)

        # Reset on successful login (status < 300)
        if response.status_code < 300 and path == "/login":
            await self._store.reset(identifier)

        return response

    async def _build_identifier(self, request: Request, path: str) -> str:
        """Build the rate limit key "<path>:<client ip>[:<email>]" for a request."""
        parts = [path, self._get_client_ip(request)]

        # Optionally add email for more granular rate limiting
        if self._include_email_in_key:
            email = await self._extract_email(request)
            if email:
                parts.append(email)

        return ":".join(parts)

    def _get_client_ip(self, request: Request) -> str:
        """
        Return the client IP used in rate-limit keys.

        When proxy headers are trusted, the first non-empty X-Forwarded-For
        entry is used, then X-Real-IP. Otherwise, or when neither yields a
        value, the direct peer address is used, or "unknown" if unavailable.
        """
        if self._trust_proxy_headers:
            # Left-most entry = originating client as reported by the proxy
            # chain. It is client-controlled when no proxy rewrites the header.
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                first_hop = forwarded_for.split(",")[0].strip()
                if first_hop:
                    return first_hop

            real_ip = request.headers.get("X-Real-IP")
            if real_ip and real_ip.strip():
                return real_ip.strip()

        if request.client:
            return request.client.host

        return "unknown"

    async def _extract_email(self, request: Request) -> str | None:
        """
        Try to extract a normalized email from a JSON request body.

        Returns None in any of these cases:
        - the body is not declared as JSON, is empty, or fails to parse;
        - the parsed value is not a JSON object;
        - ``email`` is missing or not a string;
        - the normalized value is empty, longer than ``_MAX_EMAIL_LENGTH``,
          or contains non-printable characters (e.g. newlines that could
          forge log lines).

        The email is stripped and lower-cased, so case/whitespace variants of
        one address share a single bucket instead of each getting a fresh one.
        """
        content_type = request.headers.get("content-type", "")
        if "application/json" not in content_type:
            return None

        try:
            # NOTE(review): this reads the body inside BaseHTTPMiddleware. It
            # relies on Starlette making the body available to the route
            # afterwards; confirm for the Starlette version in use.
            body = await request.body()
            if not body:
                return None
            import json

            data = json.loads(body)
        except (ValueError, UnicodeDecodeError):
            return None

        # A JSON array/scalar body has no .get(); previously this crashed with 500.
        if not isinstance(data, dict):
            return None

        email = data.get("email")
        # Non-string values would break ":".join() in _build_identifier.
        if not isinstance(email, str):
            return None

        email = email.strip().lower()
        if not email or len(email) > self._MAX_EMAIL_LENGTH or not email.isprintable():
            return None
        return email

    @staticmethod
    def _rate_limit_response(retry_after: int) -> JSONResponse:
        """Return 429 Too Many Requests response."""
        return JSONResponse(
            status_code=429,
            content={
                "detail": f"Too many attempts. Please try again in {retry_after} seconds.",
                "error": "rate_limit_exceeded",
                "retry_after": retry_after,
            },
            headers={"Retry-After": str(retry_after)},
        )
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def create_rate_limit_middleware(
    manifest_auth: dict[str, Any],
    store: InMemoryRateLimitStore | MongoDBRateLimitStore | None = None,
) -> type:
    """
    Factory function to create rate limit middleware from manifest config.

    Args:
        manifest_auth: The ``auth`` section of the manifest (not the whole
            manifest). Its optional ``rate_limits`` key maps request paths to
            ``{"max_attempts": int, "window_seconds": int}``.
        store: Optional storage backend handed to the middleware. When None,
            the middleware uses its process-wide in-memory store.

    Returns:
        Configured middleware class (a subclass of AuthRateLimitMiddleware)

    Manifest format (this function receives the value of ``"auth"``):
        {
            "auth": {
                "rate_limits": {
                    "/login": {"max_attempts": 5, "window_seconds": 300},
                    "/register": {"max_attempts": 3, "window_seconds": 3600}
                }
            }
        }

    Missing per-path keys default to 5 attempts / 300 seconds. A null entry
    uses both defaults. If ``rate_limits`` is absent, null, or empty,
    DEFAULT_AUTH_RATE_LIMITS is used.
    """
    # `or {}` also covers an explicit null (e.g. `rate_limits:` in YAML).
    rate_limits_config = manifest_auth.get("rate_limits") or {}

    limits: dict[str, RateLimit] = {}
    for path, config in rate_limits_config.items():
        settings = config or {}
        limits[path] = RateLimit(
            max_attempts=settings.get("max_attempts", 5),
            window_seconds=settings.get("window_seconds", 300),
        )

    # Use defaults if no config provided
    if not limits:
        limits = DEFAULT_AUTH_RATE_LIMITS.copy()

    class ConfiguredRateLimitMiddleware(AuthRateLimitMiddleware):
        """AuthRateLimitMiddleware bound to the limits and store resolved above."""

        def __init__(self, app: Callable):
            super().__init__(app, limits=limits, store=store)

    return ConfiguredRateLimitMiddleware
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def rate_limit(
    max_attempts: int = 5,
    window_seconds: int = 300,
    key_func: Callable[[Request], str] | None = None,
):
    """
    Decorator for rate limiting individual endpoints.

    Attempts are counted in the module-level in-memory store. Every call is
    recorded, including rejected ones, so a client that keeps retrying while
    limited keeps its window full.

    The endpoint must receive a starlette ``Request``, under any parameter
    name, positionally or by keyword. It is located among the call's
    arguments rather than assumed to be the first positional one, because
    FastAPI invokes endpoints with keyword arguments named after the
    endpoint's parameters.

    Usage:
        @app.post("/login")
        @rate_limit(max_attempts=5, window_seconds=300)
        async def login(request: Request):
            ...

    Args:
        max_attempts: Maximum attempts in window
        window_seconds: Time window in seconds
        key_func: Optional function to generate rate limit key from request.
            By default the key is the endpoint's module-qualified name plus
            the direct client address (proxy headers are not consulted).

    Raises:
        TypeError: When the decorated endpoint is called without a Request.
    """

    def decorator(func: Callable) -> Callable:
        # Module-qualified so endpoints that merely share a function name
        # (e.g. ``login`` in two routers) do not share a bucket.
        endpoint_name = f"{func.__module__}.{func.__qualname__}"

        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Find the Request wherever it was passed and forward all
            # arguments unchanged, preserving the caller's convention.
            request = next(
                (value for value in (*args, *kwargs.values()) if isinstance(value, Request)),
                None,
            )
            if request is None:
                raise TypeError(
                    f"@rate_limit requires {endpoint_name} to receive a Request argument"
                )

            # Build identifier
            if key_func:
                identifier = key_func(request)
            else:
                client_ip = request.client.host if request.client else "unknown"
                identifier = f"{endpoint_name}:{client_ip}"

            # Record first, then compare: rejected calls are counted too.
            count = await _default_store.record_attempt(identifier, window_seconds)

            if count > max_attempts:
                logger.warning(
                    "Rate limit exceeded: %s (%s/%s)", identifier, count, max_attempts
                )
                return JSONResponse(
                    status_code=429,
                    content={
                        "detail": f"Too many attempts. Try again in {window_seconds} seconds.",
                        "error": "rate_limit_exceeded",
                        "retry_after": window_seconds,
                    },
                    headers={"Retry-After": str(window_seconds)},
                )

            return await func(*args, **kwargs)

        return wrapper

    return decorator
|
mdb_engine/auth/restrictions.py
CHANGED
|
@@ -16,7 +16,8 @@ This module is part of MDB_ENGINE - MongoDB Engine.
|
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
18
|
import logging
|
|
19
|
-
from
|
|
19
|
+
from collections.abc import Awaitable, Callable
|
|
20
|
+
from typing import Any
|
|
20
21
|
|
|
21
22
|
from fastapi import HTTPException, Request, status
|
|
22
23
|
|
|
@@ -27,9 +28,7 @@ from .users import get_app_user
|
|
|
27
28
|
logger = logging.getLogger(__name__)
|
|
28
29
|
|
|
29
30
|
|
|
30
|
-
def is_demo_user(
|
|
31
|
-
user: Optional[Dict[str, Any]] = None, email: Optional[str] = None
|
|
32
|
-
) -> bool:
|
|
31
|
+
def is_demo_user(user: dict[str, Any] | None = None, email: str | None = None) -> bool:
|
|
33
32
|
"""
|
|
34
33
|
Check if a user is a demo user.
|
|
35
34
|
|
|
@@ -57,7 +56,7 @@ def is_demo_user(
|
|
|
57
56
|
return False
|
|
58
57
|
|
|
59
58
|
|
|
60
|
-
async def _get_platform_user(request: Request) ->
|
|
59
|
+
async def _get_platform_user(request: Request) -> dict[str, Any] | None:
|
|
61
60
|
"""Try to get user from platform authentication."""
|
|
62
61
|
try:
|
|
63
62
|
platform_user = await get_current_user_from_request(request)
|
|
@@ -72,9 +71,9 @@ async def _get_platform_user(request: Request) -> Optional[Dict[str, Any]]:
|
|
|
72
71
|
async def _get_sub_auth_user(
|
|
73
72
|
request: Request,
|
|
74
73
|
slug_id: str,
|
|
75
|
-
get_app_config_func: Callable[[Request, str,
|
|
74
|
+
get_app_config_func: Callable[[Request, str, dict], Awaitable[dict]],
|
|
76
75
|
get_app_db_func: Callable[[Request], Awaitable[Any]],
|
|
77
|
-
) ->
|
|
76
|
+
) -> dict[str, Any] | None:
|
|
78
77
|
"""Try to get user from sub-authentication."""
|
|
79
78
|
try:
|
|
80
79
|
db = await get_app_db_func(request)
|
|
@@ -85,9 +84,7 @@ async def _get_sub_auth_user(
|
|
|
85
84
|
if not (config and users_config.get("enabled", False)):
|
|
86
85
|
return None
|
|
87
86
|
|
|
88
|
-
app_user = await get_app_user(
|
|
89
|
-
request, slug_id, db, config, allow_demo_fallback=False
|
|
90
|
-
)
|
|
87
|
+
app_user = await get_app_user(request, slug_id, db, config, allow_demo_fallback=False)
|
|
91
88
|
if not app_user:
|
|
92
89
|
return None
|
|
93
90
|
|
|
@@ -108,9 +105,9 @@ async def _get_sub_auth_user(
|
|
|
108
105
|
async def _get_authenticated_user(
|
|
109
106
|
request: Request,
|
|
110
107
|
slug_id: str,
|
|
111
|
-
get_app_config_func:
|
|
112
|
-
get_app_db_func:
|
|
113
|
-
) ->
|
|
108
|
+
get_app_config_func: Callable[[Request, str, dict], Awaitable[dict]] | None,
|
|
109
|
+
get_app_db_func: Callable[[Request], Awaitable[Any]] | None,
|
|
110
|
+
) -> dict[str, Any] | None:
|
|
114
111
|
"""Get authenticated user from platform or sub-auth."""
|
|
115
112
|
# Try platform auth first
|
|
116
113
|
user = await _get_platform_user(request)
|
|
@@ -119,9 +116,7 @@ async def _get_authenticated_user(
|
|
|
119
116
|
|
|
120
117
|
# Try sub-auth if platform auth didn't work
|
|
121
118
|
if get_app_db_func and get_app_config_func:
|
|
122
|
-
return await _get_sub_auth_user(
|
|
123
|
-
request, slug_id, get_app_config_func, get_app_db_func
|
|
124
|
-
)
|
|
119
|
+
return await _get_sub_auth_user(request, slug_id, get_app_config_func, get_app_db_func)
|
|
125
120
|
|
|
126
121
|
return None
|
|
127
122
|
|
|
@@ -138,8 +133,8 @@ def _validate_slug_id(request: Request) -> str:
|
|
|
138
133
|
|
|
139
134
|
|
|
140
135
|
def _validate_dependencies(
|
|
141
|
-
get_app_config_func:
|
|
142
|
-
get_app_db_func:
|
|
136
|
+
get_app_config_func: Callable[[Request, str, dict], Awaitable[dict]] | None,
|
|
137
|
+
get_app_db_func: Callable[[Request], Awaitable[Any]] | None,
|
|
143
138
|
) -> None:
|
|
144
139
|
"""Validate that required dependencies are provided."""
|
|
145
140
|
if not get_app_db_func:
|
|
@@ -158,11 +153,9 @@ def _validate_dependencies(
|
|
|
158
153
|
|
|
159
154
|
async def require_non_demo_user(
|
|
160
155
|
request: Request,
|
|
161
|
-
get_app_config_func:
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
get_app_db_func: Optional[Callable[[Request], Awaitable[Any]]] = None,
|
|
165
|
-
) -> Dict[str, Any]:
|
|
156
|
+
get_app_config_func: Callable[[Request, str, dict], Awaitable[dict]] | None = None,
|
|
157
|
+
get_app_db_func: Callable[[Request], Awaitable[Any]] | None = None,
|
|
158
|
+
) -> dict[str, Any]:
|
|
166
159
|
"""
|
|
167
160
|
FastAPI dependency that blocks demo users from accessing an endpoint.
|
|
168
161
|
|
|
@@ -189,9 +182,7 @@ async def require_non_demo_user(
|
|
|
189
182
|
if get_app_db_func and get_app_config_func:
|
|
190
183
|
_validate_dependencies(get_app_config_func, get_app_db_func)
|
|
191
184
|
|
|
192
|
-
user = await _get_authenticated_user(
|
|
193
|
-
request, slug_id, get_app_config_func, get_app_db_func
|
|
194
|
-
)
|
|
185
|
+
user = await _get_authenticated_user(request, slug_id, get_app_config_func, get_app_db_func)
|
|
195
186
|
|
|
196
187
|
# Check if user is demo
|
|
197
188
|
if user and is_demo_user(user):
|
|
@@ -215,10 +206,8 @@ async def require_non_demo_user(
|
|
|
215
206
|
|
|
216
207
|
async def block_demo_users(
|
|
217
208
|
request: Request,
|
|
218
|
-
get_app_config_func:
|
|
219
|
-
|
|
220
|
-
] = None,
|
|
221
|
-
get_app_db_func: Optional[Callable[[Request], Awaitable[Any]]] = None,
|
|
209
|
+
get_app_config_func: Callable[[Request, str, dict], Awaitable[dict]] | None = None,
|
|
210
|
+
get_app_db_func: Callable[[Request], Awaitable[Any]] | None = None,
|
|
222
211
|
):
|
|
223
212
|
"""
|
|
224
213
|
FastAPI dependency that blocks demo users and returns an error response.
|
|
@@ -253,15 +242,11 @@ async def block_demo_users(
|
|
|
253
242
|
|
|
254
243
|
# Check sub-auth if platform auth didn't work
|
|
255
244
|
if not user and get_app_db_func and get_app_config_func and slug_id:
|
|
256
|
-
user = await _get_sub_auth_user(
|
|
257
|
-
request, slug_id, get_app_config_func, get_app_db_func
|
|
258
|
-
)
|
|
245
|
+
user = await _get_sub_auth_user(request, slug_id, get_app_config_func, get_app_db_func)
|
|
259
246
|
|
|
260
247
|
# Block if demo user (only if user exists)
|
|
261
248
|
if user and is_demo_user(user):
|
|
262
|
-
logger.info(
|
|
263
|
-
f"Demo user '{user.get('email')}' blocked from accessing: {request.url.path}"
|
|
264
|
-
)
|
|
249
|
+
logger.info(f"Demo user '{user.get('email')}' blocked from accessing: {request.url.path}")
|
|
265
250
|
raise HTTPException(
|
|
266
251
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
267
252
|
detail="Demo users cannot access this endpoint. Demo mode is read-only.",
|