mdb-engine 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +38 -6
- mdb_engine/auth/README.md +534 -11
- mdb_engine/auth/__init__.py +129 -28
- mdb_engine/auth/audit.py +592 -0
- mdb_engine/auth/casbin_factory.py +10 -14
- mdb_engine/auth/config_helpers.py +7 -6
- mdb_engine/auth/cookie_utils.py +3 -7
- mdb_engine/auth/csrf.py +373 -0
- mdb_engine/auth/decorators.py +3 -10
- mdb_engine/auth/dependencies.py +37 -45
- mdb_engine/auth/helpers.py +3 -3
- mdb_engine/auth/integration.py +30 -73
- mdb_engine/auth/jwt.py +2 -6
- mdb_engine/auth/middleware.py +77 -34
- mdb_engine/auth/oso_factory.py +16 -36
- mdb_engine/auth/provider.py +17 -38
- mdb_engine/auth/rate_limiter.py +504 -0
- mdb_engine/auth/restrictions.py +8 -24
- mdb_engine/auth/session_manager.py +14 -29
- mdb_engine/auth/shared_middleware.py +600 -0
- mdb_engine/auth/shared_users.py +759 -0
- mdb_engine/auth/token_store.py +14 -28
- mdb_engine/auth/users.py +54 -113
- mdb_engine/auth/utils.py +213 -15
- mdb_engine/cli/commands/generate.py +545 -9
- mdb_engine/cli/commands/validate.py +3 -7
- mdb_engine/cli/utils.py +3 -3
- mdb_engine/config.py +7 -21
- mdb_engine/constants.py +65 -0
- mdb_engine/core/README.md +117 -6
- mdb_engine/core/__init__.py +39 -7
- mdb_engine/core/app_registration.py +22 -41
- mdb_engine/core/app_secrets.py +290 -0
- mdb_engine/core/connection.py +18 -9
- mdb_engine/core/encryption.py +223 -0
- mdb_engine/core/engine.py +758 -95
- mdb_engine/core/index_management.py +12 -16
- mdb_engine/core/manifest.py +424 -135
- mdb_engine/core/ray_integration.py +435 -0
- mdb_engine/core/seeding.py +10 -18
- mdb_engine/core/service_initialization.py +12 -23
- mdb_engine/core/types.py +2 -5
- mdb_engine/database/README.md +112 -16
- mdb_engine/database/__init__.py +17 -6
- mdb_engine/database/abstraction.py +25 -37
- mdb_engine/database/connection.py +11 -18
- mdb_engine/database/query_validator.py +367 -0
- mdb_engine/database/resource_limiter.py +204 -0
- mdb_engine/database/scoped_wrapper.py +713 -196
- mdb_engine/embeddings/__init__.py +17 -9
- mdb_engine/embeddings/dependencies.py +1 -3
- mdb_engine/embeddings/service.py +11 -25
- mdb_engine/exceptions.py +92 -0
- mdb_engine/indexes/README.md +30 -13
- mdb_engine/indexes/__init__.py +1 -0
- mdb_engine/indexes/helpers.py +1 -1
- mdb_engine/indexes/manager.py +50 -114
- mdb_engine/memory/README.md +2 -2
- mdb_engine/memory/__init__.py +1 -2
- mdb_engine/memory/service.py +30 -87
- mdb_engine/observability/README.md +4 -2
- mdb_engine/observability/__init__.py +26 -9
- mdb_engine/observability/health.py +8 -9
- mdb_engine/observability/metrics.py +32 -12
- mdb_engine/routing/README.md +1 -1
- mdb_engine/routing/__init__.py +1 -3
- mdb_engine/routing/websockets.py +25 -60
- mdb_engine-0.1.7.dist-info/METADATA +285 -0
- mdb_engine-0.1.7.dist-info/RECORD +85 -0
- mdb_engine-0.1.6.dist-info/METADATA +0 -213
- mdb_engine-0.1.6.dist-info/RECORD +0 -75
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/WHEEL +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/top_level.txt +0 -0
mdb_engine/auth/cookie_utils.py
CHANGED
|
@@ -51,8 +51,7 @@ def get_secure_cookie_settings(
|
|
|
51
51
|
# Auto-detect: secure if HTTPS or production environment
|
|
52
52
|
is_https = request.url.scheme == "https"
|
|
53
53
|
is_production = (
|
|
54
|
-
os.getenv("G_NOME_ENV") == "production"
|
|
55
|
-
or os.getenv("ENVIRONMENT") == "production"
|
|
54
|
+
os.getenv("G_NOME_ENV") == "production" or os.getenv("ENVIRONMENT") == "production"
|
|
56
55
|
)
|
|
57
56
|
secure = is_https or is_production
|
|
58
57
|
elif cookie_secure == "true":
|
|
@@ -63,8 +62,7 @@ def get_secure_cookie_settings(
|
|
|
63
62
|
# No config - use environment-based defaults
|
|
64
63
|
is_https = request.url.scheme == "https"
|
|
65
64
|
is_production = (
|
|
66
|
-
os.getenv("G_NOME_ENV") == "production"
|
|
67
|
-
or os.getenv("ENVIRONMENT") == "production"
|
|
65
|
+
os.getenv("G_NOME_ENV") == "production" or os.getenv("ENVIRONMENT") == "production"
|
|
68
66
|
)
|
|
69
67
|
secure = is_https or is_production
|
|
70
68
|
|
|
@@ -153,6 +151,4 @@ def clear_auth_cookies(response, request: Optional[Request] = None):
|
|
|
153
151
|
response.delete_cookie(key="token", httponly=True, secure=secure, samesite=samesite)
|
|
154
152
|
|
|
155
153
|
# Delete refresh token cookie
|
|
156
|
-
response.delete_cookie(
|
|
157
|
-
key="refresh_token", httponly=True, secure=secure, samesite=samesite
|
|
158
|
-
)
|
|
154
|
+
response.delete_cookie(key="refresh_token", httponly=True, secure=secure, samesite=samesite)
|
mdb_engine/auth/csrf.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CSRF Protection Middleware
|
|
3
|
+
|
|
4
|
+
Implements the Double-Submit Cookie pattern for Cross-Site Request Forgery protection.
|
|
5
|
+
Auto-enabled for shared auth mode, with manifest-configurable options.
|
|
6
|
+
|
|
7
|
+
This module is part of MDB_ENGINE - MongoDB Engine.
|
|
8
|
+
|
|
9
|
+
Security Features:
|
|
10
|
+
- Double-submit cookie pattern (industry standard)
|
|
11
|
+
- Cryptographically secure token generation
|
|
12
|
+
- Configurable exempt routes for APIs
|
|
13
|
+
- SameSite cookie attribute for additional protection
|
|
14
|
+
- Token rotation on each request (optional)
|
|
15
|
+
|
|
16
|
+
Usage:
|
|
17
|
+
# Auto-enabled for shared auth mode in engine.create_app()
|
|
18
|
+
|
|
19
|
+
# Or manual usage:
|
|
20
|
+
from mdb_engine.auth.csrf import CSRFMiddleware
|
|
21
|
+
app.add_middleware(CSRFMiddleware, exempt_routes=["/api/*"])
|
|
22
|
+
|
|
23
|
+
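# In templates, include the token (NOTE: the middleware currently validates only the X-CSRF-Token header, so a form must also send the token in that header, e.g. via JavaScript):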
|
|
24
|
+
<input type="hidden" name="csrf_token" value="{{ csrf_token }}">
|
|
25
|
+
|
|
26
|
+
# Or in JavaScript:
|
|
27
|
+
fetch('/endpoint', {
|
|
28
|
+
headers: {'X-CSRF-Token': getCookie('csrf_token')}
|
|
29
|
+
})
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
import fnmatch
|
|
33
|
+
import hashlib
|
|
34
|
+
import hmac
|
|
35
|
+
import logging
|
|
36
|
+
import os
|
|
37
|
+
import secrets
|
|
38
|
+
import time
|
|
39
|
+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set
|
|
40
|
+
|
|
41
|
+
from fastapi import Request, Response, status
|
|
42
|
+
from fastapi.responses import JSONResponse
|
|
43
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
44
|
+
|
|
45
|
+
logger = logging.getLogger(__name__)
|
|
46
|
+
|
|
47
|
+
# Token settings
# Passed to secrets.token_urlsafe() as the number of random bytes, and reused by
# validate_csrf_token() as the minimum accepted length of an unsigned token.
CSRF_TOKEN_LENGTH = 32  # 256 bits
# Default name of the cookie that carries the token.
CSRF_COOKIE_NAME = "csrf_token"
# Default name of the request header the token must be echoed in.
CSRF_HEADER_NAME = "X-CSRF-Token"
# Default form field name. CSRFMiddleware stores it, but its dispatch() does not
# read form bodies (form extraction is marked as a TODO there).
CSRF_FORM_FIELD = "csrf_token"
# Default max age for signed tokens and default max_age of the CSRF cookie.
DEFAULT_TOKEN_TTL = 3600  # 1 hour

# Methods that require CSRF validation
# NOTE(review): this set does not appear to be referenced in this module;
# CSRFMiddleware decides via its exempt_methods attribute instead. Confirm
# whether other modules import it before relying on or removing it.
UNSAFE_METHODS = {"POST", "PUT", "DELETE", "PATCH"}
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def generate_csrf_token(secret: Optional[str] = None) -> str:
    """
    Create a fresh random CSRF token, optionally HMAC-signed.

    Args:
        secret: Signing secret. When truthy, the result has the form
            "<random>:<unix-timestamp>:<signature>", where the signature is
            the first 16 hex characters of HMAC-SHA256 over
            "<random>:<unix-timestamp>". When falsy, only the random part
            is returned.

    Returns:
        The token string (URL-safe random text, plus timestamp and signature
        when signed).
    """
    nonce = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
    if not secret:
        # Unsigned mode: the random value alone is the token.
        return nonce

    # Signed mode: bind the random value to its issue time so that
    # validate_csrf_token() can check both age and integrity.
    issued_at = str(int(time.time()))
    payload = f"{nonce}:{issued_at}"
    digest = hmac.new(secret.encode(), payload.encode(), hashlib.sha256).hexdigest()
    return f"{payload}:{digest[:16]}"
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def validate_csrf_token(
    token: str,
    secret: Optional[str] = None,
    max_age: int = DEFAULT_TOKEN_TTL,
) -> bool:
    """
    Validate a CSRF token.

    Without a secret, the token only has to be present and at least
    CSRF_TOKEN_LENGTH characters long. With a secret, the token MUST be in
    the signed "<random>:<timestamp>:<signature>" format produced by
    generate_csrf_token(); unsigned or malformed tokens are rejected so that
    signing cannot be bypassed by simply omitting the signature.

    Args:
        token: The token to validate
        secret: Optional secret for HMAC verification
        max_age: Maximum token age in seconds

    Returns:
        True if valid, False otherwise. Never raises for malformed input.
    """
    if not token:
        return False

    if not secret:
        # Simple token validation (just check it exists and has reasonable length)
        return len(token) >= CSRF_TOKEN_LENGTH

    # A secret is configured: only signed tokens are acceptable.
    parts = token.split(":")
    if len(parts) != 3:
        logger.warning("CSRF token is not in signed format")
        return False

    raw_token, timestamp_str, signature = parts

    try:
        timestamp = int(timestamp_str)

        # Check age
        if time.time() - timestamp > max_age:
            logger.debug("CSRF token expired")
            return False

        # Verify signature
        message = f"{raw_token}:{timestamp_str}"
        expected_sig = hmac.new(secret.encode(), message.encode(), hashlib.sha256).hexdigest()[
            :16
        ]

        # Compare as bytes: hmac.compare_digest() raises TypeError for str
        # arguments containing non-ASCII characters, and the signature part
        # comes straight from the client.
        if not hmac.compare_digest(signature.encode(), expected_sig.encode()):
            logger.warning("CSRF token signature mismatch")
            return False

        return True
    except ValueError as e:
        # Covers a non-numeric timestamp and UnicodeEncodeError (a ValueError
        # subclass) from un-encodable characters in the token.
        logger.warning(f"CSRF token validation error: {e}")
        return False
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class CSRFMiddleware(BaseHTTPMiddleware):
    """
    CSRF Protection Middleware using Double-Submit Cookie pattern.

    The double-submit cookie pattern works by:
    1. Setting a CSRF token in a cookie (with HttpOnly=False so JS can read it)
    2. Requiring the same token in a request header
    3. Since attackers can't read cookies from other domains, they can't forge requests

    Additional protection from SameSite=Lax cookies prevents the browser from
    sending cookies on cross-site requests.

    The token that will be valid for the *next* request is exposed to route
    handlers and templates as ``request.state.csrf_token`` before the handler
    runs.

    Note: form-field based token submission is not implemented; the token must
    be sent in the configured header.
    """

    def __init__(
        self,
        app,
        secret: Optional[str] = None,
        exempt_routes: Optional[List[str]] = None,
        exempt_methods: Optional[Set[str]] = None,
        cookie_name: str = CSRF_COOKIE_NAME,
        header_name: str = CSRF_HEADER_NAME,
        form_field: str = CSRF_FORM_FIELD,
        token_ttl: int = DEFAULT_TOKEN_TTL,
        rotate_tokens: bool = False,
        secure_cookies: bool = True,
    ):
        """
        Initialize CSRF middleware.

        Args:
            app: FastAPI application
            secret: Secret for HMAC token signing (recommended for production).
                Falls back to the MDB_ENGINE_CSRF_SECRET environment variable.
            exempt_routes: Routes exempt from CSRF (supports wildcards: /api/*)
            exempt_methods: HTTP methods exempt from CSRF (default: safe methods).
                An empty set is treated like None and yields the default.
            cookie_name: Name of the CSRF cookie
            header_name: Name of the CSRF header
            form_field: Name of the CSRF form field (stored; not read by dispatch)
            token_ttl: Token time-to-live in seconds
            rotate_tokens: Rotate token on each request (more secure, less convenient)
            secure_cookies: Use Secure cookie flag (auto-detect HTTPS)
        """
        super().__init__(app)
        self.secret = secret or os.getenv("MDB_ENGINE_CSRF_SECRET")
        self.exempt_routes = exempt_routes or []
        self.exempt_methods = exempt_methods or {"GET", "HEAD", "OPTIONS", "TRACE"}
        self.cookie_name = cookie_name
        self.header_name = header_name
        self.form_field = form_field
        self.token_ttl = token_ttl
        self.rotate_tokens = rotate_tokens
        self.secure_cookies = secure_cookies

        logger.info(
            f"CSRFMiddleware initialized (exempt_routes={self.exempt_routes}, "
            f"rotate_tokens={rotate_tokens})"
        )

    def _is_exempt(self, path: str) -> bool:
        """Check if a path is exempt from CSRF validation."""
        return any(fnmatch.fnmatch(path, pattern) for pattern in self.exempt_routes)

    @staticmethod
    def _tokens_match(cookie_token: str, submitted_token: str) -> bool:
        """
        Constant-time equality check of the cookie and submitted tokens.

        Both values are client-controlled. hmac.compare_digest() raises
        TypeError for str arguments that contain non-ASCII characters, so the
        comparison is done on encoded bytes; un-encodable input is a mismatch.
        """
        try:
            return hmac.compare_digest(cookie_token.encode(), submitted_token.encode())
        except UnicodeEncodeError:
            return False

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """
        Process request through CSRF middleware.

        - Exempt routes pass through untouched.
        - Exempt (safe) methods pass through; a token cookie is issued if the
          client does not have one yet.
        - All other methods must carry the cookie token and echo it in the
          CSRF header; otherwise a 403 JSON response is returned.
        """
        path = request.url.path
        method = request.method

        # Skip exempt routes
        if self._is_exempt(path):
            return await call_next(request)

        # Skip safe methods
        if method in self.exempt_methods:
            existing_token = request.cookies.get(self.cookie_name)

            # Decide on the token BEFORE running the handler so templates can
            # read it, and so the value rendered into the page is the same
            # one that ends up in the cookie.
            token = existing_token or generate_csrf_token(self.secret)
            request.state.csrf_token = token

            response = await call_next(request)

            # Set CSRF token cookie if not present
            if not existing_token:
                self._set_csrf_cookie(request, response, token)

            return response

        # Validate CSRF token for unsafe methods
        cookie_token = request.cookies.get(self.cookie_name)
        if not cookie_token:
            logger.warning(f"CSRF cookie missing for {method} {path}")
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token missing"},
            )

        # Get token from header or form
        header_token = request.headers.get(self.header_name)
        form_token = None

        # Note: Form-based CSRF token extraction not implemented.
        # For now, we rely on header-based CSRF for all requests.
        # TODO: Implement request.form() based extraction if needed.

        submitted_token = header_token or form_token

        if not submitted_token:
            logger.warning(f"CSRF token not submitted for {method} {path}")
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token not provided in header or form"},
            )

        # Compare tokens (constant-time comparison)
        if not self._tokens_match(cookie_token, submitted_token):
            logger.warning(f"CSRF token mismatch for {method} {path}")
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token invalid"},
            )

        # Validate token (check signature if secret is used)
        if self.secret and not validate_csrf_token(cookie_token, self.secret, self.token_ttl):
            logger.warning(f"CSRF token validation failed for {method} {path}")
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token expired or invalid"},
            )

        # Optionally rotate token. The replacement is generated before the
        # handler runs so that anything rendered in this response (e.g. a
        # re-displayed form) already carries the token the cookie will hold.
        new_token = generate_csrf_token(self.secret) if self.rotate_tokens else None
        request.state.csrf_token = new_token or cookie_token

        # Process request
        response = await call_next(request)

        if new_token:
            self._set_csrf_cookie(request, response, new_token)

        return response

    def _set_csrf_cookie(
        self,
        request: Request,
        response: Response,
        token: str,
    ) -> None:
        """
        Set the CSRF token cookie on the response.

        The Secure flag is applied only when secure_cookies is enabled AND the
        request came over HTTPS or ENVIRONMENT=production is set.
        """
        is_https = request.url.scheme == "https"
        is_production = os.getenv("ENVIRONMENT", "").lower() == "production"

        response.set_cookie(
            key=self.cookie_name,
            value=token,
            httponly=False,  # Must be readable by JavaScript
            secure=self.secure_cookies and (is_https or is_production),
            samesite="lax",  # Provides CSRF protection + allows top-level navigation
            max_age=self.token_ttl,
            path="/",
        )
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def create_csrf_middleware(
    manifest_auth: Dict[str, Any],
    secret: Optional[str] = None,
) -> type:
    """
    Build a CSRF middleware class from the manifest's auth section.

    The "csrf_protection" entry (default True) may be:
    - False: a pass-through middleware class is returned.
    - True: CSRF is enabled with defaults; "public_routes" from the auth
      section are used as exempt routes.
    - a mapping: "exempt_routes" (default: "public_routes"), "rotate_tokens"
      (default False) and "token_ttl" (default DEFAULT_TOKEN_TTL) are read
      from it.

    Args:
        manifest_auth: Auth section from manifest
        secret: Optional CSRF secret (defaults to env var, read when the
            middleware is instantiated)

    Returns:
        A middleware class taking only ``app`` in its constructor
    """
    setting = manifest_auth.get("csrf_protection", True)

    if setting is False:
        # CSRF explicitly disabled: hand back a middleware that does nothing.
        class NoOpMiddleware(BaseHTTPMiddleware):
            async def dispatch(self, request, call_next):
                return await call_next(request)

        return NoOpMiddleware

    if setting is True:
        # Boolean "on": use defaults
        routes = manifest_auth.get("public_routes", [])
        rotate = False
        ttl = DEFAULT_TOKEN_TTL
    else:
        # Object configuration
        routes = setting.get("exempt_routes", manifest_auth.get("public_routes", []))
        rotate = setting.get("rotate_tokens", False)
        ttl = setting.get("token_ttl", DEFAULT_TOKEN_TTL)

    # Bind the resolved options into a subclass whose constructor only
    # needs the app.
    class ConfiguredCSRFMiddleware(CSRFMiddleware):
        def __init__(self, app):
            super().__init__(
                app,
                secret=secret or os.getenv("MDB_ENGINE_CSRF_SECRET"),
                exempt_routes=routes,
                rotate_tokens=rotate,
                token_ttl=ttl,
            )

    return ConfiguredCSRFMiddleware
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
# Dependency for getting CSRF token in routes
|
|
353
|
+
def get_csrf_token(request: Request) -> str:
    """
    Return the CSRF token for the current request, for use in templates.

    Lookup order:
    1. ``request.state.csrf_token`` (set by the middleware)
    2. the cookie named by CSRF_COOKIE_NAME
    3. a newly generated token, signed with MDB_ENGINE_CSRF_SECRET if set

    Usage in FastAPI route:
        @app.get("/form")
        def form_page(csrf_token: str = Depends(get_csrf_token)):
            return templates.TemplateResponse("form.html", {"csrf_token": csrf_token})
    """
    state = request.state
    try:
        # Prefer the value the middleware stored for this request.
        return state.csrf_token
    except AttributeError:
        pass

    # Fall back to the cookie, and only then mint a new token.
    from_cookie = request.cookies.get(CSRF_COOKIE_NAME)
    return from_cookie or generate_csrf_token(os.getenv("MDB_ENGINE_CSRF_SECRET"))
|
mdb_engine/auth/decorators.py
CHANGED
|
@@ -70,10 +70,7 @@ def _is_production_environment() -> bool:
|
|
|
70
70
|
"""Check if running in production environment."""
|
|
71
71
|
import os
|
|
72
72
|
|
|
73
|
-
return (
|
|
74
|
-
os.getenv("G_NOME_ENV") == "production"
|
|
75
|
-
or os.getenv("ENVIRONMENT") == "production"
|
|
76
|
-
)
|
|
73
|
+
return os.getenv("G_NOME_ENV") == "production" or os.getenv("ENVIRONMENT") == "production"
|
|
77
74
|
|
|
78
75
|
|
|
79
76
|
def _validate_https(request: Request) -> None:
|
|
@@ -189,17 +186,13 @@ def rate_limit_auth(
|
|
|
189
186
|
|
|
190
187
|
# Use provided values or config values or defaults
|
|
191
188
|
if max_attempts is None:
|
|
192
|
-
max_attempts_val = (
|
|
193
|
-
rate_limit_config.get("max_attempts") if rate_limit_config else 5
|
|
194
|
-
)
|
|
189
|
+
max_attempts_val = rate_limit_config.get("max_attempts") if rate_limit_config else 5
|
|
195
190
|
else:
|
|
196
191
|
max_attempts_val = max_attempts
|
|
197
192
|
|
|
198
193
|
if window_seconds is None:
|
|
199
194
|
window_seconds_val = (
|
|
200
|
-
rate_limit_config.get("window_seconds")
|
|
201
|
-
if rate_limit_config
|
|
202
|
-
else 300
|
|
195
|
+
rate_limit_config.get("window_seconds") if rate_limit_config else 300
|
|
203
196
|
)
|
|
204
197
|
else:
|
|
205
198
|
window_seconds_val = window_seconds
|
mdb_engine/auth/dependencies.py
CHANGED
|
@@ -14,9 +14,11 @@ from typing import Any, Dict, Mapping, Optional, Tuple
|
|
|
14
14
|
|
|
15
15
|
import jwt
|
|
16
16
|
from fastapi import Cookie, Depends, HTTPException, Request, status
|
|
17
|
+
from pymongo.errors import PyMongoError
|
|
17
18
|
|
|
18
19
|
from ..exceptions import ConfigurationError
|
|
19
20
|
from .jwt import decode_jwt_token, extract_token_metadata
|
|
21
|
+
|
|
20
22
|
# Import from local modules
|
|
21
23
|
from .provider import AuthorizationProvider
|
|
22
24
|
from .session_manager import SessionManager
|
|
@@ -164,9 +166,7 @@ async def get_current_user(
|
|
|
164
166
|
if blacklist:
|
|
165
167
|
is_revoked = await blacklist.is_revoked(jti)
|
|
166
168
|
if is_revoked:
|
|
167
|
-
logger.info(
|
|
168
|
-
f"get_current_user: Token {jti} is blacklisted (revoked)"
|
|
169
|
-
)
|
|
169
|
+
logger.info(f"get_current_user: Token {jti} is blacklisted (revoked)")
|
|
170
170
|
return None
|
|
171
171
|
|
|
172
172
|
# Also check user-level revocation
|
|
@@ -174,9 +174,7 @@ async def get_current_user(
|
|
|
174
174
|
if user_id:
|
|
175
175
|
user_revoked = await blacklist.is_user_revoked(user_id)
|
|
176
176
|
if user_revoked:
|
|
177
|
-
logger.info(
|
|
178
|
-
f"get_current_user: All tokens for user {user_id} are revoked"
|
|
179
|
-
)
|
|
177
|
+
logger.info(f"get_current_user: All tokens for user {user_id} are revoked")
|
|
180
178
|
return None
|
|
181
179
|
|
|
182
180
|
payload = decode_jwt_token(token, str(SECRET_KEY))
|
|
@@ -184,9 +182,7 @@ async def get_current_user(
|
|
|
184
182
|
# Verify token type (should be access token for backward compatibility, or no type)
|
|
185
183
|
token_type = payload.get("type")
|
|
186
184
|
if token_type and token_type not in ("access", None):
|
|
187
|
-
logger.warning(
|
|
188
|
-
f"get_current_user: Invalid token type '{token_type}' for access token"
|
|
189
|
-
)
|
|
185
|
+
logger.warning(f"get_current_user: Invalid token type '{token_type}' for access token")
|
|
190
186
|
return None
|
|
191
187
|
|
|
192
188
|
logger.debug(
|
|
@@ -203,10 +199,12 @@ async def get_current_user(
|
|
|
203
199
|
except (ValueError, TypeError):
|
|
204
200
|
logger.exception("Validation error decoding JWT token")
|
|
205
201
|
return None
|
|
206
|
-
except
|
|
207
|
-
logger.exception("
|
|
208
|
-
|
|
209
|
-
|
|
202
|
+
except PyMongoError:
|
|
203
|
+
logger.exception("Database error checking token blacklist")
|
|
204
|
+
return None
|
|
205
|
+
except (AttributeError, KeyError):
|
|
206
|
+
logger.exception("State access error in get_current_user")
|
|
207
|
+
return None
|
|
210
208
|
|
|
211
209
|
|
|
212
210
|
async def get_current_user_from_request(request: Request) -> Optional[Dict[str, Any]]:
|
|
@@ -276,10 +274,12 @@ async def get_current_user_from_request(request: Request) -> Optional[Dict[str,
|
|
|
276
274
|
except (ValueError, TypeError):
|
|
277
275
|
logger.exception("Validation error decoding JWT token from request")
|
|
278
276
|
return None
|
|
279
|
-
except
|
|
280
|
-
logger.exception("
|
|
281
|
-
|
|
282
|
-
|
|
277
|
+
except PyMongoError:
|
|
278
|
+
logger.exception("Database error checking token blacklist from request")
|
|
279
|
+
return None
|
|
280
|
+
except (AttributeError, KeyError):
|
|
281
|
+
logger.exception("State access error in get_current_user_from_request")
|
|
282
|
+
return None
|
|
283
283
|
|
|
284
284
|
|
|
285
285
|
async def get_refresh_token(
|
|
@@ -314,9 +314,7 @@ async def get_refresh_token(
|
|
|
314
314
|
if blacklist:
|
|
315
315
|
is_revoked = await blacklist.is_revoked(jti)
|
|
316
316
|
if is_revoked:
|
|
317
|
-
logger.info(
|
|
318
|
-
f"get_refresh_token: Refresh token {jti} is blacklisted"
|
|
319
|
-
)
|
|
317
|
+
logger.info(f"get_refresh_token: Refresh token {jti} is blacklisted")
|
|
320
318
|
return None
|
|
321
319
|
|
|
322
320
|
payload = decode_jwt_token(refresh_token, str(SECRET_KEY))
|
|
@@ -350,13 +348,9 @@ async def get_refresh_token(
|
|
|
350
348
|
if stored_fingerprint:
|
|
351
349
|
from .utils import generate_session_fingerprint
|
|
352
350
|
|
|
353
|
-
device_id = request.cookies.get("device_id") or payload.get(
|
|
354
|
-
"device_id"
|
|
355
|
-
)
|
|
351
|
+
device_id = request.cookies.get("device_id") or payload.get("device_id")
|
|
356
352
|
if device_id:
|
|
357
|
-
current_fingerprint = generate_session_fingerprint(
|
|
358
|
-
request, device_id
|
|
359
|
-
)
|
|
353
|
+
current_fingerprint = generate_session_fingerprint(request, device_id)
|
|
360
354
|
if current_fingerprint != stored_fingerprint:
|
|
361
355
|
logger.warning(
|
|
362
356
|
f"get_refresh_token: Session fingerprint mismatch "
|
|
@@ -377,10 +371,12 @@ async def get_refresh_token(
|
|
|
377
371
|
except (ValueError, TypeError):
|
|
378
372
|
logger.exception("Validation error decoding refresh token")
|
|
379
373
|
return None
|
|
380
|
-
except
|
|
381
|
-
logger.exception("
|
|
382
|
-
|
|
383
|
-
|
|
374
|
+
except PyMongoError:
|
|
375
|
+
logger.exception("Database error checking refresh token")
|
|
376
|
+
return None
|
|
377
|
+
except (AttributeError, KeyError):
|
|
378
|
+
logger.exception("State access error in get_refresh_token")
|
|
379
|
+
return None
|
|
384
380
|
|
|
385
381
|
|
|
386
382
|
async def require_admin(
|
|
@@ -504,14 +500,14 @@ async def get_current_user_or_redirect(
|
|
|
504
500
|
headers={"Location": redirect_url},
|
|
505
501
|
detail="Not authenticated. Redirecting to login.",
|
|
506
502
|
)
|
|
507
|
-
except (ValueError, KeyError, AttributeError):
|
|
503
|
+
except (ValueError, KeyError, AttributeError) as e:
|
|
508
504
|
logger.exception(
|
|
509
505
|
f"Failed to generate login redirect URL for route '{login_route_name}'"
|
|
510
506
|
)
|
|
511
507
|
raise HTTPException(
|
|
512
508
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
513
509
|
detail="Authentication required, but redirect failed.",
|
|
514
|
-
)
|
|
510
|
+
) from e
|
|
515
511
|
return dict(user)
|
|
516
512
|
|
|
517
513
|
|
|
@@ -619,9 +615,7 @@ async def refresh_access_token(
|
|
|
619
615
|
from ..config import TOKEN_ROTATION_ENABLED
|
|
620
616
|
from .jwt import generate_token_pair
|
|
621
617
|
|
|
622
|
-
user_id = refresh_token_payload.get("user_id") or refresh_token_payload.get(
|
|
623
|
-
"email"
|
|
624
|
-
)
|
|
618
|
+
user_id = refresh_token_payload.get("user_id") or refresh_token_payload.get("email")
|
|
625
619
|
old_refresh_jti = refresh_token_payload.get("jti")
|
|
626
620
|
device_id = refresh_token_payload.get("device_id")
|
|
627
621
|
|
|
@@ -653,9 +647,7 @@ async def refresh_access_token(
|
|
|
653
647
|
|
|
654
648
|
device_id = device_id or request.cookies.get("device_id")
|
|
655
649
|
if device_id:
|
|
656
|
-
current_fingerprint = generate_session_fingerprint(
|
|
657
|
-
request, device_id
|
|
658
|
-
)
|
|
650
|
+
current_fingerprint = generate_session_fingerprint(request, device_id)
|
|
659
651
|
if current_fingerprint != stored_fingerprint:
|
|
660
652
|
logger.warning(
|
|
661
653
|
f"refresh_access_token: Session fingerprint mismatch "
|
|
@@ -671,9 +663,7 @@ async def refresh_access_token(
|
|
|
671
663
|
|
|
672
664
|
# Use existing device_id or generate new one
|
|
673
665
|
if not device_id:
|
|
674
|
-
device_id = (
|
|
675
|
-
str(uuid.uuid4()) if not device_info else device_info.get("device_id")
|
|
676
|
-
)
|
|
666
|
+
device_id = str(uuid.uuid4()) if not device_info else device_info.get("device_id")
|
|
677
667
|
|
|
678
668
|
if device_info:
|
|
679
669
|
device_info["device_id"] = device_id
|
|
@@ -741,7 +731,9 @@ async def refresh_access_token(
|
|
|
741
731
|
except (ValueError, TypeError, jwt.InvalidTokenError):
|
|
742
732
|
logger.exception("Validation error refreshing token")
|
|
743
733
|
return None
|
|
744
|
-
except
|
|
745
|
-
logger.exception("
|
|
746
|
-
|
|
747
|
-
|
|
734
|
+
except PyMongoError:
|
|
735
|
+
logger.exception("Database error refreshing token")
|
|
736
|
+
return None
|
|
737
|
+
except (AttributeError, KeyError):
|
|
738
|
+
logger.exception("State access error refreshing token")
|
|
739
|
+
return None
|
mdb_engine/auth/helpers.py
CHANGED
|
@@ -19,7 +19,7 @@ async def initialize_token_management(app, db):
|
|
|
19
19
|
|
|
20
20
|
Args:
|
|
21
21
|
app: FastAPI application instance
|
|
22
|
-
db: MongoDB database instance (
|
|
22
|
+
db: Scoped MongoDB database instance (ScopedMongoWrapper)
|
|
23
23
|
|
|
24
24
|
Example:
|
|
25
25
|
from mdb_engine.auth.helpers import initialize_token_management
|
|
@@ -27,8 +27,8 @@ async def initialize_token_management(app, db):
|
|
|
27
27
|
|
|
28
28
|
@app.on_event("startup")
|
|
29
29
|
async def startup():
|
|
30
|
-
# Get database from engine
|
|
31
|
-
db = engine.
|
|
30
|
+
# Get scoped database from engine
|
|
31
|
+
db = engine.get_scoped_db("my_app")
|
|
32
32
|
|
|
33
33
|
# Initialize token management
|
|
34
34
|
await initialize_token_management(app, db)
|