mdb-engine 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -86,6 +86,16 @@ class CasbinAdapter(BaseAuthorizationProvider):
86
86
  self._cache_lock = asyncio.Lock()
87
87
  self._mark_initialized()
88
88
 
89
@property
def enforcer(self):
    """
    Expose the adapter's Casbin enforcer.

    Read-only accessor, so callers do not need to reach into the private
    ``_enforcer`` attribute directly.

    Returns:
        The object held in ``self._enforcer`` (a casbin.AsyncEnforcer instance)
    """
    return self._enforcer
98
+
89
99
  async def check(
90
100
  self,
91
101
  subject: str,
@@ -120,6 +120,7 @@ class SharedUserPool:
120
120
  token_expiry_hours: int = DEFAULT_TOKEN_EXPIRY_HOURS,
121
121
  allow_insecure_dev: bool = False,
122
122
  blacklist_fail_closed: bool = True,
123
+ websocket_session_manager: Any | None = None,
123
124
  ):
124
125
  """
125
126
  Initialize the shared user pool.
@@ -174,6 +175,7 @@ class SharedUserPool:
174
175
 
175
176
  self._token_expiry_hours = token_expiry_hours
176
177
  self._blacklist_indexes_created = False
178
+ self._websocket_session_manager = websocket_session_manager
177
179
 
178
180
  logger.info(f"SharedUserPool initialized (algorithm={jwt_algorithm})")
179
181
 
@@ -340,7 +342,9 @@ class SharedUserPool:
340
342
  ip_address: str | None = None,
341
343
  fingerprint: str | None = None,
342
344
  session_binding: dict[str, Any] | None = None,
343
- ) -> str | None:
345
+ generate_websocket_session: bool = True,
346
+ app_slug: str | None = None,
347
+ ) -> str | tuple[str, str] | None:
344
348
  """
345
349
  Authenticate user and return JWT token.
346
350
 
@@ -352,9 +356,14 @@ class SharedUserPool:
352
356
  session_binding: Session binding config from manifest:
353
357
  - bind_ip: Include IP in token claims
354
358
  - bind_fingerprint: Include fingerprint in token claims
359
+ generate_websocket_session: If True and WebSocket session manager available,
360
+ also generate WebSocket session key (default: True)
361
+ app_slug: Optional app slug for WebSocket session scoping
355
362
 
356
363
  Returns:
357
- JWT token if authentication succeeds, None otherwise
364
+ JWT token if authentication succeeds, None otherwise.
365
+ If generate_websocket_session=True and session manager available,
366
+ returns tuple (jwt_token, websocket_session_key), otherwise just jwt_token.
358
367
  """
359
368
  user = await self._collection.find_one(
360
369
  {
@@ -392,7 +401,28 @@ class SharedUserPool:
392
401
  # Generate JWT token with session binding claims
393
402
  token = self._generate_token(user, extra_claims=extra_claims or None)
394
403
 
404
+ # Generate WebSocket session key if requested and manager available
405
+ websocket_session_key = None
406
+ if generate_websocket_session and self._websocket_session_manager:
407
+ try:
408
+ user_id = str(user["_id"])
409
+ websocket_session_key = await self._websocket_session_manager.create_session(
410
+ user_id=user_id,
411
+ user_email=email,
412
+ app_slug=app_slug,
413
+ )
414
+ logger.debug(
415
+ f"Generated WebSocket session key for user '{email}' " f"(app: {app_slug})"
416
+ )
417
+ except (ValueError, TypeError, AttributeError, RuntimeError) as e:
418
+ # Log but don't fail authentication if WebSocket session generation fails
419
+ logger.warning(f"Failed to generate WebSocket session key: {e}")
420
+
395
421
  logger.info(f"User '{email}' authenticated successfully")
422
+
423
+ # Return tuple if WebSocket session key was generated, otherwise just token
424
+ if websocket_session_key:
425
+ return (token, websocket_session_key)
396
426
  return token
397
427
 
398
428
  async def validate_token(self, token: str) -> dict[str, Any] | None:
@@ -652,6 +682,47 @@ class SharedUserPool:
652
682
  )
653
683
  return result.modified_count > 0
654
684
 
685
+ async def update_user_metadata(
686
+ self,
687
+ email: str,
688
+ metadata: dict[str, Any],
689
+ ) -> dict[str, Any] | None:
690
+ """
691
+ Update user metadata fields.
692
+
693
+ This allows adding or updating custom fields on the user document
694
+ beyond the core schema (e.g., name, profile data, preferences).
695
+
696
+ Args:
697
+ email: User email
698
+ metadata: Dictionary of fields to update
699
+ (e.g., {"name": "John Doe", "preferences": {...}})
700
+
701
+ Returns:
702
+ Updated user document (without password_hash) or None if user not found
703
+
704
+ Example:
705
+ user = await pool.update_user_metadata(
706
+ "user@example.com",
707
+ {"name": "John Doe", "phone": "+1234567890"}
708
+ )
709
+ """
710
+ # Build update document, ensuring updated_at is always set
711
+ update_doc = {"$set": {**metadata, "updated_at": datetime.utcnow()}}
712
+
713
+ result = await self._collection.update_one(
714
+ {"email": email},
715
+ update_doc,
716
+ )
717
+
718
+ if result.modified_count > 0:
719
+ # Fetch and return updated user
720
+ updated_user = await self._collection.find_one({"email": email})
721
+ if updated_user:
722
+ logger.info(f"Updated metadata for user '{email}': {list(metadata.keys())}")
723
+ return self._sanitize_user(updated_user)
724
+ return None
725
+
655
726
  @staticmethod
656
727
  def user_has_role(
657
728
  user: dict[str, Any],
mdb_engine/auth/users.py CHANGED
@@ -1376,7 +1376,8 @@ async def sync_app_user_to_casbin(
1376
1376
  logger.debug("sync_app_user_to_casbin: Provider is not CasbinAdapter, skipping")
1377
1377
  return False
1378
1378
 
1379
- enforcer = authz_provider._enforcer
1379
+ # Access enforcer via property if available, otherwise fallback to private member
1380
+ enforcer = getattr(authz_provider, "enforcer", None) or authz_provider._enforcer # noqa: SLF001
1380
1381
 
1381
1382
  # Get user ID
1382
1383
  user_id = str(user.get("_id") or user.get("app_user_id", ""))
mdb_engine/auth/utils.py CHANGED
@@ -514,16 +514,41 @@ async def login_user(
514
514
  ip_address=device_info.get("ip_address"),
515
515
  )
516
516
 
517
+ # Generate WebSocket session key if WebSocket session manager available
518
+ websocket_session_key = None
519
+ try:
520
+ # Try to get WebSocket session manager from app state
521
+ app = getattr(request, "app", None)
522
+ if app:
523
+ websocket_session_manager = getattr(app.state, "websocket_session_manager", None)
524
+ if websocket_session_manager:
525
+ # Get app slug from request state if available
526
+ app_slug = getattr(request.state, "app_slug", None)
527
+ websocket_session_key = await websocket_session_manager.create_session(
528
+ user_id=str(user["_id"]),
529
+ user_email=user["email"],
530
+ app_slug=app_slug,
531
+ )
532
+ logger.debug(
533
+ f"Generated WebSocket session key for user '{user['email']}' "
534
+ f"(app: {app_slug})"
535
+ )
536
+ except (ValueError, TypeError, AttributeError, RuntimeError) as e:
537
+ # Log but don't fail login if WebSocket session generation fails
538
+ logger.warning(f"Failed to generate WebSocket session key during login: {e}")
539
+
517
540
  # Create response
541
+ response_data = {
542
+ "success": True,
543
+ "user": {"email": user["email"], "user_id": str(user["_id"])},
544
+ }
545
+ if websocket_session_key:
546
+ response_data["websocket_session_key"] = websocket_session_key
547
+
518
548
  if redirect_url:
519
549
  response = RedirectResponse(url=redirect_url, status_code=302)
520
550
  else:
521
- response = JSONResponse(
522
- {
523
- "success": True,
524
- "user": {"email": user["email"], "user_id": str(user["_id"])},
525
- }
526
- )
551
+ response = JSONResponse(response_data)
527
552
 
528
553
  # Set cookies
529
554
  set_auth_cookies(
@@ -0,0 +1,433 @@
1
+ """
2
+ WebSocket Session Manager with Envelope Encryption
3
+
4
+ Manages WebSocket session keys using envelope encryption and private collections.
5
+ Provides secure-by-default WebSocket authentication without relying on CSRF cookies.
6
+
7
+ This module is part of MDB_ENGINE - MongoDB Engine.
8
+
9
+ Security Model:
10
+ - Session keys generated on authentication
11
+ - Stored encrypted in _mdb_engine_websocket_sessions collection
12
+ - Validated during WebSocket upgrade
13
+ - Uses envelope encryption (same as app secrets)
14
+ - Security by default: CSRF always required
15
+ """
16
+
17
+ import base64
18
+ import logging
19
+ import secrets
20
+ from collections.abc import Callable
21
+ from datetime import datetime, timedelta
22
+ from typing import Any
23
+
24
+ from motor.motor_asyncio import AsyncIOMotorDatabase
25
+ from pymongo.errors import OperationFailure, PyMongoError
26
+
27
+ from ..core.encryption import EnvelopeEncryptionService
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # Name of the private MongoDB collection that stores WebSocket session documents
32
+ WEBSOCKET_SESSIONS_COLLECTION_NAME = "_mdb_engine_websocket_sessions"
33
+
34
+ # Session key configuration
35
+ SESSION_KEY_SIZE = 32 # random bytes per session key (256 bits)
36
+ SESSION_TTL_HOURS = 24 # hours until a newly created session expires
37
+
38
+
39
class WebSocketSessionManager:
    """
    Manages WebSocket session keys using envelope encryption.

    Session keys are:
    - Generated on user authentication
    - Encrypted using envelope encryption
    - Stored in private collection (_mdb_engine_websocket_sessions)
    - Validated during WebSocket upgrade
    - Rejected once past their expiry; expired documents are removed lazily
      by validate_session() or in bulk by cleanup_expired_sessions()

    NOTE(review): the plaintext session key is also used as the document
    ``_id``, so the encrypted copy does not hide the key from anyone who can
    read the collection - confirm this is the intended threat model.
    """

    def __init__(
        self,
        mongo_db: AsyncIOMotorDatabase,
        encryption_service: EnvelopeEncryptionService,
    ):
        """
        Initialize the WebSocket session manager.

        Args:
            mongo_db: MongoDB database instance (raw, not scoped)
            encryption_service: Envelope encryption service instance
        """
        self._mongo_db = mongo_db
        self._encryption_service = encryption_service
        self._sessions_collection = mongo_db[WEBSOCKET_SESSIONS_COLLECTION_NAME]

    @staticmethod
    def generate_session_key() -> str:
        """
        Generate a random WebSocket session key.

        Returns:
            URL-safe base64 session key string with the ``=`` padding stripped
        """
        key_bytes = secrets.token_bytes(SESSION_KEY_SIZE)
        return base64.urlsafe_b64encode(key_bytes).decode().rstrip("=")

    @staticmethod
    def _is_usable_key(session_key: Any) -> bool:
        """Return True only for non-empty strings.

        Anything else (None, or a dict that MongoDB would interpret as query
        operators) must never be placed into an ``{"_id": ...}`` filter.
        """
        return isinstance(session_key, str) and bool(session_key)

    async def create_session(
        self,
        user_id: str,
        user_email: str | None = None,
        app_slug: str | None = None,
    ) -> str:
        """
        Create a new WebSocket session with encrypted session key.

        Args:
            user_id: User ID
            user_email: Optional user email
            app_slug: Optional app slug for scoping

        Returns:
            Plaintext session key (to be sent to client)

        Raises:
            PyMongoError: If the MongoDB insert fails (logged, then re-raised)
        """
        try:
            # Generate session key
            session_key = self.generate_session_key()

            # Encrypt session key using envelope encryption
            encrypted_key, encrypted_dek = self._encryption_service.encrypt_secret(session_key)

            # One timestamp for both fields, so that
            # expires_at - created_at is exactly the configured TTL
            created_at = datetime.utcnow()
            expires_at = created_at + timedelta(hours=SESSION_TTL_HOURS)

            document = {
                "_id": session_key,  # Use session key as ID for fast lookup
                "user_id": user_id,
                "user_email": user_email,
                "app_slug": app_slug,
                # Ciphertexts are stored base64-encoded
                "encrypted_key": base64.b64encode(encrypted_key).decode(),
                "encrypted_dek": base64.b64encode(encrypted_dek).decode(),
                "algorithm": "AES-256-GCM",
                "created_at": created_at,
                "expires_at": expires_at,
            }

            # Store in private collection
            await self._sessions_collection.insert_one(document)

            logger.info(
                f"Created WebSocket session for user '{user_id}' "
                f"(app: {app_slug}, expires: {expires_at})"
            )

            return session_key

        except (OperationFailure, PyMongoError):
            logger.exception("Failed to create WebSocket session")
            raise

    async def validate_session(
        self,
        session_key: str,
        user_id: str | None = None,
    ) -> dict[str, Any] | None:
        """
        Validate a WebSocket session key.

        A session is valid when the document exists, has not expired, belongs
        to ``user_id`` (if given) and its stored ciphertext decrypts back to
        ``session_key``.

        Args:
            session_key: Session key to validate
            user_id: Optional user ID for additional validation

        Returns:
            Dict with user_id, user_email, app_slug, created_at and expires_at
            if valid, None otherwise (including for non-string or empty keys)

        Raises:
            PyMongoError: If a MongoDB operation fails (logged, then re-raised)
        """
        # Reject before querying: a dict here would be run as query operators
        if not self._is_usable_key(session_key):
            logger.warning("Rejected WebSocket session key: expected a non-empty string")
            return None

        try:
            # Find session by key
            session_doc = await self._sessions_collection.find_one({"_id": session_key})

            if not session_doc:
                logger.warning(f"WebSocket session not found: {session_key[:16]}...")
                return None

            # Check expiration
            expires_at = session_doc.get("expires_at")
            if expires_at and expires_at < datetime.utcnow():
                logger.warning(
                    f"WebSocket session expired: {session_key[:16]}... " f"(expired: {expires_at})"
                )
                # Clean up expired session
                await self._sessions_collection.delete_one({"_id": session_key})
                return None

            # Optional: Validate user_id matches
            if user_id and session_doc.get("user_id") != user_id:
                logger.warning(
                    f"WebSocket session user mismatch: "
                    f"session_user={session_doc.get('user_id')}, "
                    f"provided_user={user_id}"
                )
                return None

            # Decrypt session key to verify it's valid
            try:
                encrypted_key = base64.b64decode(session_doc["encrypted_key"])
                encrypted_dek = base64.b64decode(session_doc["encrypted_dek"])
                decrypted_key = self._encryption_service.decrypt_secret(
                    encrypted_key, encrypted_dek
                )

                # Constant-time comparison on bytes; a non-string result from
                # decryption is treated as a mismatch
                keys_match = isinstance(decrypted_key, str) and secrets.compare_digest(
                    decrypted_key.encode(), session_key.encode()
                )
                if not keys_match:
                    logger.error(
                        f"WebSocket session key decryption mismatch: "
                        f"session_key={session_key[:16]}..."
                    )
                    return None

            except (ValueError, TypeError, AttributeError, KeyError):
                logger.exception("Failed to decrypt WebSocket session key")
                return None

            logger.debug(
                f"Validated WebSocket session for user '{session_doc.get('user_id')}' "
                f"(app: {session_doc.get('app_slug')})"
            )

            # Only non-secret fields are handed back to the caller
            return {
                "user_id": session_doc.get("user_id"),
                "user_email": session_doc.get("user_email"),
                "app_slug": session_doc.get("app_slug"),
                "created_at": session_doc.get("created_at"),
                "expires_at": session_doc.get("expires_at"),
            }

        except (OperationFailure, PyMongoError):
            logger.exception("Failed to validate WebSocket session")
            raise

    async def revoke_session(self, session_key: str) -> bool:
        """
        Revoke a WebSocket session.

        Args:
            session_key: Session key to revoke

        Returns:
            True if session was revoked; False if it was not found, the key
            was not a non-empty string, or the MongoDB operation failed
        """
        # Reject before querying: a dict here could delete an arbitrary session
        if not self._is_usable_key(session_key):
            logger.warning("Rejected WebSocket session key: expected a non-empty string")
            return False

        try:
            result = await self._sessions_collection.delete_one({"_id": session_key})
            if result.deleted_count > 0:
                logger.info(f"Revoked WebSocket session: {session_key[:16]}...")
                return True
            return False
        except (OperationFailure, PyMongoError):
            logger.exception("Failed to revoke WebSocket session")
            return False

    async def revoke_user_sessions(self, user_id: str, app_slug: str | None = None) -> int:
        """
        Revoke all sessions for a user.

        Args:
            user_id: User ID
            app_slug: Optional app slug filter

        Returns:
            Number of sessions revoked (0 if the MongoDB operation failed)
        """
        try:
            query = {"user_id": user_id}
            if app_slug:
                query["app_slug"] = app_slug

            result = await self._sessions_collection.delete_many(query)
            logger.info(
                f"Revoked {result.deleted_count} WebSocket sessions "
                f"for user '{user_id}' (app: {app_slug})"
            )
            return result.deleted_count
        except (OperationFailure, PyMongoError):
            logger.exception("Failed to revoke user WebSocket sessions")
            return 0

    async def cleanup_expired_sessions(self) -> int:
        """
        Clean up expired WebSocket sessions.

        Returns:
            Number of sessions cleaned up (0 if the MongoDB operation failed)
        """
        try:
            result = await self._sessions_collection.delete_many(
                {"expires_at": {"$lt": datetime.utcnow()}}
            )
            if result.deleted_count > 0:
                logger.info(f"Cleaned up {result.deleted_count} expired WebSocket sessions")
            return result.deleted_count
        except (OperationFailure, PyMongoError):
            logger.exception("Failed to cleanup expired WebSocket sessions")
            return 0
284
+
285
+
286
async def _user_from_auth_cookie(request: Any) -> dict[str, Any] | None:
    """Best-effort fallback: authenticate from the auth cookie via ``app.state.user_pool``.

    Returns whatever ``user_pool.validate_token`` returns, or None when there
    is no user pool, no cookie, or the lookup raises TypeError/AttributeError.
    """
    from .shared_middleware import AUTH_COOKIE_NAME

    # Only try to authenticate if we have a real user pool (not None)
    app_state = getattr(getattr(request, "app", None), "state", None)
    user_pool = getattr(app_state, "user_pool", None)
    if user_pool is None:
        return None

    # Extract token from cookie
    try:
        cookies = getattr(request, "cookies", None)
        token = cookies.get(AUTH_COOKIE_NAME) if cookies is not None else None
    except (AttributeError, TypeError):
        logger.debug("Could not read auth cookie from request", exc_info=True)
        return None
    if not token:
        return None

    try:
        # Validate token and get user
        return await user_pool.validate_token(token)
    except (TypeError, AttributeError):
        # e.g. user_pool is a stand-in whose validate_token cannot be awaited;
        # treated as "not authenticated", but no longer swallowed silently
        logger.debug("user_pool.validate_token could not be used", exc_info=True)
        return None


def create_websocket_session_endpoint(
    session_manager: WebSocketSessionManager,
) -> Callable:
    """
    Create a FastAPI endpoint for generating WebSocket session keys.

    This endpoint requires authentication and generates a new WebSocket session key
    for the authenticated user. The session key is encrypted and stored in the
    private collection.

    Args:
        session_manager: WebSocketSessionManager instance

    Returns:
        FastAPI route handler function

    Example:
        ```python
        from mdb_engine.auth.websocket_sessions import (
            WebSocketSessionManager,
            create_websocket_session_endpoint,
        )
        from mdb_engine.core.encryption import EnvelopeEncryptionService

        # Initialize session manager
        encryption_service = EnvelopeEncryptionService()
        session_manager = WebSocketSessionManager(
            mongo_db=db,
            encryption_service=encryption_service,
        )

        # Create endpoint
        endpoint = create_websocket_session_endpoint(session_manager)
        app.get("/auth/websocket-session")(endpoint)
        ```

    The endpoint:
    - Requires authentication (user must be logged in)
    - Returns JSON: `{"session_key": "...", "expires_at": "...", "ttl_hours": ...}`
    - Uses user info from `request.state.user`, falling back to the auth
      cookie validated through `request.app.state.user_pool`
    - Responds 401 without a user, 400 when the user data carries no ID,
      and 500 when session creation fails
    """
    from fastapi import Request, status
    from fastapi.responses import JSONResponse

    async def websocket_session_endpoint(request: Request) -> JSONResponse:
        """
        Generate a WebSocket session key for the authenticated user.

        Requires:
        - User to be authenticated (via request.state.user or auth cookie)
        - WebSocket session manager to be available

        Returns:
        - JSONResponse with session_key, expires_at and ttl_hours
        """
        # Check if user is authenticated (set by middleware)
        user = getattr(request.state, "user", None)

        # If not set by middleware, try to authenticate using cookie
        # This handles the case where endpoint is on parent app without auth middleware
        if not user:
            user = await _user_from_auth_cookie(request)

        if not user:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={"detail": "Authentication required"},
            )

        # Prefer user_id, sub (JWT standard), or _id (MongoDB document ID)
        user_id = user.get("user_id") or user.get("sub") or user.get("_id")
        if not user_id:
            # Email is not a valid user_id - it's just metadata
            logger.error("Cannot generate WebSocket session: user_id not found in user data")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": "Invalid user data"},
            )
        user_email = user.get("email")
        app_slug = getattr(request.state, "app_slug", None)

        # Taken before the session is created, so the expiry reported to the
        # client is never later than the one create_session() stores
        issued_at = datetime.utcnow()

        try:
            session_key = await session_manager.create_session(
                user_id=str(user_id),
                user_email=user_email,
                app_slug=app_slug,
            )
        except (
            ValueError,
            TypeError,
            AttributeError,
            RuntimeError,
            OperationFailure,
            PyMongoError,
        ):
            logger.exception("Failed to generate WebSocket session key")
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={"detail": "Failed to generate WebSocket session key"},
            )

        expires_at = issued_at + timedelta(hours=SESSION_TTL_HOURS)

        logger.info(
            f"Generated WebSocket session key for user '{user_id}' " f"(app: {app_slug})"
        )

        return JSONResponse(
            {
                "session_key": session_key,
                "expires_at": expires_at.isoformat(),
                "ttl_hours": SESSION_TTL_HOURS,
            }
        )

    return websocket_session_endpoint