mdb-engine 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. mdb_engine/__init__.py +38 -6
  2. mdb_engine/auth/README.md +534 -11
  3. mdb_engine/auth/__init__.py +129 -28
  4. mdb_engine/auth/audit.py +592 -0
  5. mdb_engine/auth/casbin_factory.py +10 -14
  6. mdb_engine/auth/config_helpers.py +7 -6
  7. mdb_engine/auth/cookie_utils.py +3 -7
  8. mdb_engine/auth/csrf.py +373 -0
  9. mdb_engine/auth/decorators.py +3 -10
  10. mdb_engine/auth/dependencies.py +37 -45
  11. mdb_engine/auth/helpers.py +3 -3
  12. mdb_engine/auth/integration.py +30 -73
  13. mdb_engine/auth/jwt.py +2 -6
  14. mdb_engine/auth/middleware.py +77 -34
  15. mdb_engine/auth/oso_factory.py +16 -36
  16. mdb_engine/auth/provider.py +17 -38
  17. mdb_engine/auth/rate_limiter.py +504 -0
  18. mdb_engine/auth/restrictions.py +8 -24
  19. mdb_engine/auth/session_manager.py +14 -29
  20. mdb_engine/auth/shared_middleware.py +600 -0
  21. mdb_engine/auth/shared_users.py +759 -0
  22. mdb_engine/auth/token_store.py +14 -28
  23. mdb_engine/auth/users.py +54 -113
  24. mdb_engine/auth/utils.py +213 -15
  25. mdb_engine/cli/commands/generate.py +545 -9
  26. mdb_engine/cli/commands/validate.py +3 -7
  27. mdb_engine/cli/utils.py +3 -3
  28. mdb_engine/config.py +7 -21
  29. mdb_engine/constants.py +65 -0
  30. mdb_engine/core/README.md +117 -6
  31. mdb_engine/core/__init__.py +39 -7
  32. mdb_engine/core/app_registration.py +22 -41
  33. mdb_engine/core/app_secrets.py +290 -0
  34. mdb_engine/core/connection.py +18 -9
  35. mdb_engine/core/encryption.py +223 -0
  36. mdb_engine/core/engine.py +758 -95
  37. mdb_engine/core/index_management.py +12 -16
  38. mdb_engine/core/manifest.py +424 -135
  39. mdb_engine/core/ray_integration.py +435 -0
  40. mdb_engine/core/seeding.py +10 -18
  41. mdb_engine/core/service_initialization.py +12 -23
  42. mdb_engine/core/types.py +2 -5
  43. mdb_engine/database/README.md +112 -16
  44. mdb_engine/database/__init__.py +17 -6
  45. mdb_engine/database/abstraction.py +25 -37
  46. mdb_engine/database/connection.py +11 -18
  47. mdb_engine/database/query_validator.py +367 -0
  48. mdb_engine/database/resource_limiter.py +204 -0
  49. mdb_engine/database/scoped_wrapper.py +713 -196
  50. mdb_engine/embeddings/__init__.py +17 -9
  51. mdb_engine/embeddings/dependencies.py +1 -3
  52. mdb_engine/embeddings/service.py +11 -25
  53. mdb_engine/exceptions.py +92 -0
  54. mdb_engine/indexes/README.md +30 -13
  55. mdb_engine/indexes/__init__.py +1 -0
  56. mdb_engine/indexes/helpers.py +1 -1
  57. mdb_engine/indexes/manager.py +50 -114
  58. mdb_engine/memory/README.md +2 -2
  59. mdb_engine/memory/__init__.py +1 -2
  60. mdb_engine/memory/service.py +30 -87
  61. mdb_engine/observability/README.md +4 -2
  62. mdb_engine/observability/__init__.py +26 -9
  63. mdb_engine/observability/health.py +8 -9
  64. mdb_engine/observability/metrics.py +32 -12
  65. mdb_engine/routing/README.md +1 -1
  66. mdb_engine/routing/__init__.py +1 -3
  67. mdb_engine/routing/websockets.py +25 -60
  68. mdb_engine-0.1.7.dist-info/METADATA +285 -0
  69. mdb_engine-0.1.7.dist-info/RECORD +85 -0
  70. mdb_engine-0.1.6.dist-info/METADATA +0 -213
  71. mdb_engine-0.1.6.dist-info/RECORD +0 -75
  72. {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/WHEEL +0 -0
  73. {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/entry_points.txt +0 -0
  74. {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/licenses/LICENSE +0 -0
  75. {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/top_level.txt +0 -0
@@ -6,13 +6,12 @@ Defines the pluggable Authorization (AuthZ) interface for the platform.
6
6
  This module is part of MDB_ENGINE - MongoDB Engine.
7
7
  """
8
8
 
9
- from __future__ import \
10
- annotations # MUST be first import for string type hints
9
+ from __future__ import annotations # MUST be first import for string type hints
11
10
 
12
11
  import asyncio
13
12
  import logging
14
13
  import time
15
- from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, Tuple
14
+ from typing import TYPE_CHECKING, Any, Optional, Protocol
16
15
 
17
16
  from ..constants import AUTHZ_CACHE_TTL, MAX_CACHE_SIZE
18
17
 
@@ -32,7 +31,7 @@ class AuthorizationProvider(Protocol):
32
31
  subject: str,
33
32
  resource: str,
34
33
  action: str,
35
- user_object: Optional[Dict[str, Any]] = None,
34
+ user_object: Optional[dict[str, Any]] = None,
36
35
  ) -> bool:
37
36
  """
38
37
  Checks if a subject is allowed to perform an action on a resource.
@@ -53,18 +52,16 @@ class CasbinAdapter:
53
52
  """
54
53
  self._enforcer = enforcer
55
54
  # Cache for authorization results: {(subject, resource, action): (result, timestamp)}
56
- self._cache: Dict[Tuple[str, str, str], Tuple[bool, float]] = {}
55
+ self._cache: dict[tuple[str, str, str], tuple[bool, float]] = {}
57
56
  self._cache_lock = asyncio.Lock()
58
- logger.info(
59
- "✔️ CasbinAdapter initialized with async thread pool execution and caching."
60
- )
57
+ logger.info("✔️ CasbinAdapter initialized with async thread pool execution and caching.")
61
58
 
62
59
  async def check(
63
60
  self,
64
61
  subject: str,
65
62
  resource: str,
66
63
  action: str,
67
- user_object: Optional[Dict[str, Any]] = None,
64
+ user_object: Optional[dict[str, Any]] = None,
68
65
  ) -> bool:
69
66
  """
70
67
  Performs the authorization check using the wrapped enforcer.
@@ -79,9 +76,7 @@ class CasbinAdapter:
79
76
  cached_result, cached_time = self._cache[cache_key]
80
77
  # Check if cache entry is still valid
81
78
  if current_time - cached_time < AUTHZ_CACHE_TTL:
82
- logger.debug(
83
- f"Authorization cache HIT for ({subject}, {resource}, {action})"
84
- )
79
+ logger.debug(f"Authorization cache HIT for ({subject}, {resource}, {action})")
85
80
  return cached_result
86
81
  # Cache expired, remove it
87
82
  del self._cache[cache_key]
@@ -89,9 +84,7 @@ class CasbinAdapter:
89
84
  try:
90
85
  # The .enforce() method on AsyncEnforcer is synchronous and blocks the event loop.
91
86
  # Run it in a thread pool to prevent blocking.
92
- result = await asyncio.to_thread(
93
- self._enforcer.enforce, subject, resource, action
94
- )
87
+ result = await asyncio.to_thread(self._enforcer.enforce, subject, resource, action)
95
88
 
96
89
  # Cache the result
97
90
  async with self._cache_lock:
@@ -211,18 +204,16 @@ class OsoAdapter:
211
204
  """
212
205
  self._oso = oso_client
213
206
  # Cache for authorization results: {(subject, resource, action): (result, timestamp)}
214
- self._cache: Dict[Tuple[str, str, str], Tuple[bool, float]] = {}
207
+ self._cache: dict[tuple[str, str, str], tuple[bool, float]] = {}
215
208
  self._cache_lock = asyncio.Lock()
216
- logger.info(
217
- "✔️ OsoAdapter initialized with async thread pool execution and caching."
218
- )
209
+ logger.info("✔️ OsoAdapter initialized with async thread pool execution and caching.")
219
210
 
220
211
  async def check(
221
212
  self,
222
213
  subject: str,
223
214
  resource: str,
224
215
  action: str,
225
- user_object: Optional[Dict[str, Any]] = None,
216
+ user_object: Optional[dict[str, Any]] = None,
226
217
  ) -> bool:
227
218
  """
228
219
  Performs the authorization check using OSO.
@@ -239,9 +230,7 @@ class OsoAdapter:
239
230
  cached_result, cached_time = self._cache[cache_key]
240
231
  # Check if cache entry is still valid
241
232
  if current_time - cached_time < AUTHZ_CACHE_TTL:
242
- logger.debug(
243
- f"Authorization cache HIT for ({subject}, {resource}, {action})"
244
- )
233
+ logger.debug(f"Authorization cache HIT for ({subject}, {resource}, {action})")
245
234
  return cached_result
246
235
  # Cache expired, remove it
247
236
  del self._cache[cache_key]
@@ -267,9 +256,7 @@ class OsoAdapter:
267
256
  resource_obj = resource
268
257
 
269
258
  # Run in thread pool to prevent blocking the event loop
270
- result = await asyncio.to_thread(
271
- self._oso.authorize, actor, action, resource_obj
272
- )
259
+ result = await asyncio.to_thread(self._oso.authorize, actor, action, resource_obj)
273
260
 
274
261
  # Cache the result
275
262
  async with self._cache_lock:
@@ -333,9 +320,7 @@ class OsoAdapter:
333
320
  )
334
321
  elif hasattr(self._oso, "register_constant"):
335
322
  # OSO library - we'd need to use a different approach
336
- logger.warning(
337
- "OSO library mode: add_policy needs to be handled via policy files"
338
- )
323
+ logger.warning("OSO library mode: add_policy needs to be handled via policy files")
339
324
  result = True # Assume success for now
340
325
  else:
341
326
  logger.warning("OSO client doesn't support insert() or tell() method")
@@ -406,9 +391,7 @@ class OsoAdapter:
406
391
  self._oso.tell, "has_role", user, role, resource
407
392
  )
408
393
  else:
409
- result = await asyncio.to_thread(
410
- self._oso.tell, "has_role", user, role
411
- )
394
+ result = await asyncio.to_thread(self._oso.tell, "has_role", user, role)
412
395
  elif hasattr(self._oso, "register_constant"):
413
396
  # OSO library - we'd need to use a different approach
414
397
  logger.warning(
@@ -507,9 +490,7 @@ class OsoAdapter:
507
490
  # OSO library - query facts
508
491
  result = await asyncio.to_thread(
509
492
  lambda: list(
510
- self._oso.query_rule(
511
- "has_role", user, role, accept_expression=True
512
- )
493
+ self._oso.query_rule("has_role", user, role, accept_expression=True)
513
494
  )
514
495
  )
515
496
  return len(result) > 0
@@ -548,9 +529,7 @@ class OsoAdapter:
548
529
  user, role = params
549
530
  # OSO Cloud uses delete() method
550
531
  if hasattr(self._oso, "delete"):
551
- result = await asyncio.to_thread(
552
- self._oso.delete, "has_role", user, role
553
- )
532
+ result = await asyncio.to_thread(self._oso.delete, "has_role", user, role)
554
533
  else:
555
534
  logger.warning("OSO client doesn't support delete() method")
556
535
  result = False
@@ -0,0 +1,504 @@
1
+ """
2
+ Rate Limiting for Authentication Endpoints
3
+
4
+ Provides rate limiting middleware to protect auth endpoints from brute-force attacks.
5
+ Supports both in-memory storage (single instance) and MongoDB-backed storage (distributed).
6
+
7
+ This module is part of MDB_ENGINE - MongoDB Engine.
8
+
9
+ Features:
10
+ - Sliding window rate limiting algorithm
11
+ - Per-endpoint configurable limits via manifest
12
+ - IP + optional email-based tracking
13
+ - In-memory (default) or MongoDB storage
14
+ - 429 Too Many Requests with Retry-After header
15
+
16
+ Usage:
17
+ # Via middleware (recommended for shared auth)
18
+ app.add_middleware(
19
+ AuthRateLimitMiddleware,
20
+ limits={
21
+ "/login": RateLimit(max_attempts=5, window_seconds=300),
22
+ "/register": RateLimit(max_attempts=3, window_seconds=3600),
23
+ }
24
+ )
25
+
26
+ # Via decorator (for specific endpoints)
27
+ @app.post("/login")
28
+ @rate_limit(max_attempts=5, window_seconds=300)
29
+ async def login(request: Request):
30
+ ...
31
+ """
32
+
33
+ import logging
34
+ import time
35
+ from collections import defaultdict
36
+ from dataclasses import dataclass
37
+ from datetime import datetime, timedelta
38
+ from functools import wraps
39
+ from typing import Any, Callable, Dict, List, Optional, Tuple
40
+
41
+ from pymongo.errors import OperationFailure
42
+ from starlette.middleware.base import BaseHTTPMiddleware
43
+ from starlette.requests import Request
44
+ from starlette.responses import JSONResponse, Response
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+
49
@dataclass
class RateLimit:
    """
    Attempt budget applied to a single endpoint.

    ``max_attempts`` is the number of attempts permitted inside a rolling
    window of ``window_seconds`` seconds.
    """

    max_attempts: int = 5
    window_seconds: int = 300  # 5 minutes

    def to_dict(self) -> Dict[str, int]:
        """Serialize to a plain dict with ``max_attempts`` and ``window_seconds`` keys."""
        return dict(max_attempts=self.max_attempts, window_seconds=self.window_seconds)
61
+
62
+
63
# Default rate limits for auth endpoints, keyed by exact request path:
#   /login    -> 5 attempts per 5 minutes
#   /register -> 3 attempts per hour
#   /logout   -> 10 attempts per minute
DEFAULT_AUTH_RATE_LIMITS: Dict[str, RateLimit] = {
    "/login": RateLimit(5, 300),
    "/register": RateLimit(3, 3600),
    "/logout": RateLimit(10, 60),
}
69
+
70
+
71
class InMemoryRateLimitStore:
    """
    In-memory rate limit storage using sliding window algorithm.

    Suitable for single-instance deployments. For distributed systems,
    use MongoDBRateLimitStore instead.

    State lives in a dict inside this process, so counts are not shared
    between workers. Because identifiers are built from request data (IP,
    email), a client can create arbitrarily many of them; identifiers whose
    entries have all aged past the largest window seen so far are therefore
    swept automatically, at most once every ``cleanup_interval`` seconds.
    """

    def __init__(self, cleanup_interval: float = 300.0):
        """
        Args:
            cleanup_interval: Minimum number of seconds between automatic
                sweeps of stale identifiers (``0`` sweeps on every recorded
                attempt).
        """
        # Structure: {identifier: [(timestamp, count), ...]}
        self._storage: Dict[str, List[Tuple[float, int]]] = defaultdict(list)
        self._cleanup_interval = cleanup_interval
        # Largest window any caller has used. Entries older than this cannot
        # contribute to any count computed with a window up to this size, so
        # dropping them is safe.
        self._max_window_seconds = 0
        self._last_cleanup = time.time()

    def _note_window(self, window_seconds: int) -> None:
        """Track the largest window used so automatic sweeps never drop live data."""
        if window_seconds > self._max_window_seconds:
            self._max_window_seconds = window_seconds

    def _prune(self, cutoff: float) -> int:
        """Drop entries at or before ``cutoff`` and delete emptied identifiers; return deletions."""
        stale = []
        for identifier, entries in self._storage.items():
            # In-place slice assignment: the dict size does not change while iterating.
            entries[:] = [(ts, c) for ts, c in entries if ts > cutoff]
            if not entries:
                stale.append(identifier)

        for identifier in stale:
            del self._storage[identifier]

        return len(stale)

    def _maybe_cleanup(self, now: float) -> None:
        """Run a sweep relative to ``now`` if ``cleanup_interval`` has elapsed."""
        if now - self._last_cleanup < self._cleanup_interval:
            return
        self._last_cleanup = now
        self._prune(now - self._max_window_seconds)

    async def record_attempt(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """
        Record an attempt and return current count in window.

        Args:
            identifier: Unique identifier (e.g., "login:192.168.1.1:user@example.com")
            window_seconds: Time window in seconds

        Returns:
            Number of attempts in the current window (including this one)
        """
        now = time.time()
        self._note_window(window_seconds)
        # Sweep before touching this identifier. The sweep cutoff
        # (now - max window) is never later than this call's own cutoff, so it
        # only removes entries the window filter below would discard anyway.
        self._maybe_cleanup(now)

        cutoff = now - window_seconds

        # Clean old entries and count current
        entries = self._storage[identifier]
        entries[:] = [(ts, c) for ts, c in entries if ts > cutoff]

        # Add new attempt
        entries.append((now, 1))

        # Return total count
        return sum(c for _, c in entries)

    async def get_count(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """Get current attempt count without recording."""
        self._note_window(window_seconds)
        cutoff = time.time() - window_seconds

        # .get() rather than [] so lookups don't create empty defaultdict entries.
        entries = self._storage.get(identifier, [])
        return sum(c for ts, c in entries if ts > cutoff)

    async def reset(self, identifier: str) -> None:
        """Reset rate limit for an identifier (e.g., after successful login)."""
        self._storage.pop(identifier, None)

    def cleanup(self, max_age_seconds: int = 7200) -> int:
        """
        Clean up old entries to prevent memory growth.

        Args:
            max_age_seconds: Remove entries older than this (default: 2 hours)

        Returns:
            Number of identifiers cleaned up
        """
        return self._prune(time.time() - max_age_seconds)
152
+
153
+
154
class MongoDBRateLimitStore:
    """
    MongoDB-backed rate limit storage for distributed deployments.

    Each attempt is stored as its own document
    ``{"identifier", "timestamp", "expires_at"}``. A TTL index on
    ``expires_at`` makes documents eligible for deletion by MongoDB once
    their window has passed, so no manual cleanup is needed.
    """

    COLLECTION = "_mdb_engine_rate_limits"

    def __init__(self, db):
        """
        Initialize MongoDB rate limit store.

        Args:
            db: MongoDB database instance (Motor AsyncIOMotorDatabase); it is
                indexed by ``COLLECTION`` to obtain the backing collection.
        """
        self._db = db
        self._collection = db[self.COLLECTION]
        self._indexes_created = False

    @staticmethod
    def _utcnow() -> datetime:
        """
        Return the current time as a timezone-aware UTC datetime.

        Replaces ``datetime.utcnow()``, which is deprecated since Python 3.12.
        BSON encodes aware datetimes in UTC, so stored values are unchanged.
        """
        from datetime import timezone

        return datetime.now(timezone.utc)

    async def ensure_indexes(self) -> None:
        """
        Create the lookup and TTL indexes once per store instance.

        On ``OperationFailure`` a warning is logged and the "created" flag
        stays unset, so creation is attempted again on the next call.
        """
        if self._indexes_created:
            return

        try:
            # Compound index for the per-identifier window count query
            await self._collection.create_index(
                [("identifier", 1), ("timestamp", -1)], name="identifier_timestamp_idx"
            )
            # TTL index for cleanup: expireAfterSeconds=0 expires a document
            # at the time stored in its expires_at field
            await self._collection.create_index(
                "expires_at", expireAfterSeconds=0, name="expires_at_ttl_idx"
            )
            self._indexes_created = True
            logger.info("Rate limit indexes ensured")
        except OperationFailure as e:
            logger.warning(f"Failed to create rate limit indexes: {e}")

    async def _count_since(self, identifier: str, cutoff: datetime) -> int:
        """Count attempts for ``identifier`` with ``timestamp >= cutoff``."""
        return await self._collection.count_documents(
            {
                "identifier": identifier,
                "timestamp": {"$gte": cutoff},
            }
        )

    async def record_attempt(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """Record an attempt and return current count in window (including this one)."""
        await self.ensure_indexes()

        now = self._utcnow()
        window = timedelta(seconds=window_seconds)

        # Insert attempt; it expires (via the TTL index) once the window has passed
        await self._collection.insert_one(
            {
                "identifier": identifier,
                "timestamp": now,
                "expires_at": now + window,
            }
        )

        # Count attempts in window
        return await self._count_since(identifier, now - window)

    async def get_count(
        self,
        identifier: str,
        window_seconds: int,
    ) -> int:
        """Get current attempt count without recording."""
        await self.ensure_indexes()

        cutoff = self._utcnow() - timedelta(seconds=window_seconds)
        return await self._count_since(identifier, cutoff)

    async def reset(self, identifier: str) -> None:
        """Reset rate limit for an identifier by deleting all of its attempt documents."""
        await self._collection.delete_many({"identifier": identifier})
246
+
247
+
248
# Process-wide in-memory store. AuthRateLimitMiddleware falls back to it when
# no store is passed, and the rate_limit decorator uses it by default, so all
# of them share counts within one process.
_default_store: InMemoryRateLimitStore = InMemoryRateLimitStore()
250
+
251
+
252
class AuthRateLimitMiddleware(BaseHTTPMiddleware):
    """
    ASGI middleware for rate limiting authentication endpoints.

    Automatically protects /login, /register, and other auth endpoints
    from brute-force attacks.

    Features:
    - Configurable per-endpoint limits
    - IP + email tracking for login attempts
    - 429 responses with Retry-After header
    - Skips rate limiting for non-auth endpoints

    Only POST requests whose path exactly matches a key in ``limits`` are
    counted. A request is rejected when the store already holds
    ``max_attempts`` attempts for its identifier inside the window. A 2xx
    response to ``/login`` clears that identifier's counter.

    Usage:
        # Basic usage with defaults
        app.add_middleware(AuthRateLimitMiddleware)

        # Custom limits
        app.add_middleware(
            AuthRateLimitMiddleware,
            limits={"/login": RateLimit(max_attempts=3, window_seconds=60)}
        )

        # With MongoDB storage for distributed deployments
        app.add_middleware(
            AuthRateLimitMiddleware,
            store=MongoDBRateLimitStore(db)
        )

    Security note:
        By default the client IP comes from ``X-Forwarded-For`` /
        ``X-Real-IP`` when present. These headers are client-controlled unless
        a trusted reverse proxy overwrites them, so a client reaching the app
        directly can rotate them to evade limits. Pass
        ``trust_proxy_headers=False`` when no such proxy sits in front of the app.
    """

    def __init__(
        self,
        app: Callable,
        limits: Optional[Dict[str, RateLimit]] = None,
        store: Optional["InMemoryRateLimitStore | MongoDBRateLimitStore"] = None,
        include_email_in_key: bool = True,
        trust_proxy_headers: bool = True,
    ):
        """
        Initialize rate limit middleware.

        Args:
            app: ASGI application
            limits: Dict of path -> RateLimit config. Defaults to DEFAULT_AUTH_RATE_LIMITS.
            store: Rate limit storage backend (in-memory or MongoDB). Defaults to
                the process-wide in-memory store.
            include_email_in_key: Include email in rate limit key for more granular limits.
            trust_proxy_headers: Derive the client IP from X-Forwarded-For /
                X-Real-IP. Only safe behind a proxy that sets these headers.
        """
        super().__init__(app)
        self._limits = limits or DEFAULT_AUTH_RATE_LIMITS
        self._store = store or _default_store
        self._include_email_in_key = include_email_in_key
        self._trust_proxy_headers = trust_proxy_headers

        logger.info(
            f"AuthRateLimitMiddleware initialized with limits for: {list(self._limits.keys())}"
        )

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Response],
    ) -> Response:
        """Process request through rate limiter."""
        path = request.url.path
        method = request.method

        # Only rate limit POST requests to configured endpoints
        if method != "POST" or path not in self._limits:
            return await call_next(request)

        limit = self._limits[path]
        identifier = await self._build_identifier(request, path)

        # Check current count (before recording this attempt).
        # NOTE(review): check-then-record is two separate awaits, so with a
        # shared store (e.g. MongoDB) concurrent requests may each pass the
        # check before any of them is recorded.
        current_count = await self._store.get_count(identifier, limit.window_seconds)

        if current_count >= limit.max_attempts:
            logger.warning(
                f"Rate limit exceeded: {identifier} "
                f"({current_count}/{limit.max_attempts} in {limit.window_seconds}s)"
            )
            return self._rate_limit_response(limit.window_seconds)

        # Record this attempt
        await self._store.record_attempt(identifier, limit.window_seconds)

        # Process request
        response = await call_next(request)

        # Reset on successful login (2xx response)
        if response.status_code < 300 and path == "/login":
            await self._store.reset(identifier)

        return response

    async def _build_identifier(self, request: Request, path: str) -> str:
        """Build the rate limit key as ``path:client_ip[:email]``."""
        parts = [path]

        # Add client IP
        client_ip = self._get_client_ip(request)
        parts.append(client_ip)

        # Optionally add email for more granular rate limiting
        if self._include_email_in_key:
            email = await self._extract_email(request)
            if email:
                parts.append(email)

        return ":".join(parts)

    def _get_client_ip(self, request: Request) -> str:
        """Get client IP, consulting proxy headers only when they are trusted."""
        if self._trust_proxy_headers:
            # X-Forwarded-For (set by proxies/load balancers): first hop is the original client
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                first_hop = forwarded_for.split(",")[0].strip()
                if first_hop:
                    return first_hop

            real_ip = request.headers.get("X-Real-IP")
            if real_ip and real_ip.strip():
                return real_ip.strip()

        # Fall back to direct client IP
        if request.client:
            return request.client.host

        return "unknown"

    async def _extract_email(self, request: Request) -> Optional[str]:
        """
        Try to extract a normalized email from a JSON request body.

        Returns the ``email`` field stripped and lower-cased, so case variants
        of one address share a counter. Returns None for non-JSON requests, empty
        or malformed bodies, non-object JSON, and non-string or blank emails.
        """
        # Only try to read body for JSON requests
        content_type = request.headers.get("content-type", "")
        if "application/json" not in content_type:
            return None

        try:
            # NOTE(review): reading the body inside BaseHTTPMiddleware relies on
            # Starlette re-supplying it to the downstream app — confirm the
            # pinned Starlette version does this.
            body = await request.body()
            if not body:
                return None
            import json

            data = json.loads(body)
        except (ValueError, UnicodeDecodeError, RecursionError):
            # Malformed / undecodable / pathologically nested JSON: no email key part.
            return None

        # A JSON array, string or number has no "email" field (and no .get()).
        if not isinstance(data, dict):
            return None

        email = data.get("email")
        if not isinstance(email, str):
            return None

        email = email.strip().lower()
        return email or None

    @staticmethod
    def _rate_limit_response(retry_after: int) -> JSONResponse:
        """Return 429 Too Many Requests response."""
        return JSONResponse(
            status_code=429,
            content={
                "detail": f"Too many attempts. Please try again in {retry_after} seconds.",
                "error": "rate_limit_exceeded",
                "retry_after": retry_after,
            },
            headers={"Retry-After": str(retry_after)},
        )
410
+
411
+
412
def create_rate_limit_middleware(
    manifest_auth: Dict[str, Any],
    store: Optional["InMemoryRateLimitStore | MongoDBRateLimitStore"] = None,
) -> type:
    """
    Factory function to create rate limit middleware from manifest config.

    Args:
        manifest_auth: Auth section from manifest. ``None``, a missing key, or
            an explicit ``null``/empty ``rate_limits`` falls back to
            DEFAULT_AUTH_RATE_LIMITS.
        store: Optional storage backend (AuthRateLimitMiddleware uses the
            process-wide in-memory store when this is None)

    Returns:
        Configured middleware class. Extra keyword arguments given to that
        class (e.g. ``include_email_in_key`` via ``app.add_middleware``) are
        forwarded to AuthRateLimitMiddleware.

    Manifest format:
        {
            "auth": {
                "rate_limits": {
                    "/login": {"max_attempts": 5, "window_seconds": 300},
                    "/register": {"max_attempts": 3, "window_seconds": 3600}
                }
            }
        }

    Each per-path value may also be a RateLimit instance.
    """
    # `or {}` guards against manifests that spell the section as null.
    rate_limits_config = (manifest_auth or {}).get("rate_limits") or {}

    limits: Dict[str, RateLimit] = {}
    for path, config in rate_limits_config.items():
        if isinstance(config, RateLimit):
            limits[path] = config
            continue
        limits[path] = RateLimit(
            max_attempts=config.get("max_attempts", 5),
            window_seconds=config.get("window_seconds", 300),
        )

    # Use defaults if no config provided
    if not limits:
        limits = DEFAULT_AUTH_RATE_LIMITS.copy()

    class ConfiguredRateLimitMiddleware(AuthRateLimitMiddleware):
        def __init__(self, app: Callable, **kwargs: Any):
            # limits/store are fixed by the factory; other options pass through.
            super().__init__(app, limits=limits, store=store, **kwargs)

    return ConfiguredRateLimitMiddleware
454
+
455
+
456
def rate_limit(
    max_attempts: int = 5,
    window_seconds: int = 300,
    key_func: Optional[Callable[[Request], str]] = None,
    store: Optional["InMemoryRateLimitStore | MongoDBRateLimitStore"] = None,
):
    """
    Decorator for rate limiting individual endpoints.

    Usage:
        @app.post("/login")
        @rate_limit(max_attempts=5, window_seconds=300)
        async def login(request: Request):
            ...

    Every call is recorded first, and the request is rejected with 429 once
    the recorded count exceeds ``max_attempts``. Rejected calls are recorded
    too, so retrying while blocked keeps the window full.

    Args:
        max_attempts: Maximum attempts in window
        window_seconds: Time window in seconds
        key_func: Optional function to generate rate limit key from request.
            Defaults to ``"<function name>:<client ip>"``.
        store: Optional storage backend (e.g. MongoDBRateLimitStore for
            multi-instance deployments). Defaults to the process-wide
            in-memory store.
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            # Build identifier
            if key_func:
                identifier = key_func(request)
            else:
                client_ip = request.client.host if request.client else "unknown"
                identifier = f"{func.__name__}:{client_ip}"

            # Resolved per call so the module-level default is used unless overridden.
            active_store = store if store is not None else _default_store

            # Check rate limit
            count = await active_store.record_attempt(identifier, window_seconds)

            if count > max_attempts:
                logger.warning(f"Rate limit exceeded: {identifier} ({count}/{max_attempts})")
                return JSONResponse(
                    status_code=429,
                    content={
                        "detail": f"Too many attempts. Try again in {window_seconds} seconds.",
                        "error": "rate_limit_exceeded",
                    },
                    headers={"Retry-After": str(window_seconds)},
                )

            return await func(request, *args, **kwargs)

        return wrapper

    return decorator