mdb-engine 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +13 -9
- mdb_engine/auth/__init__.py +18 -0
- mdb_engine/auth/csrf.py +651 -69
- mdb_engine/auth/provider.py +10 -0
- mdb_engine/auth/shared_users.py +73 -2
- mdb_engine/auth/users.py +2 -1
- mdb_engine/auth/utils.py +31 -6
- mdb_engine/auth/websocket_sessions.py +433 -0
- mdb_engine/auth/websocket_tickets.py +307 -0
- mdb_engine/core/app_registration.py +10 -0
- mdb_engine/core/engine.py +656 -21
- mdb_engine/core/manifest.py +26 -0
- mdb_engine/core/ray_integration.py +4 -4
- mdb_engine/core/types.py +2 -0
- mdb_engine/database/connection.py +6 -3
- mdb_engine/database/scoped_wrapper.py +3 -3
- mdb_engine/indexes/manager.py +3 -3
- mdb_engine/observability/health.py +7 -7
- mdb_engine/routing/README.md +9 -2
- mdb_engine/routing/websockets.py +479 -56
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/METADATA +128 -4
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/RECORD +26 -24
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/WHEEL +0 -0
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/top_level.txt +0 -0
mdb_engine/auth/provider.py
CHANGED
|
@@ -86,6 +86,16 @@ class CasbinAdapter(BaseAuthorizationProvider):
|
|
|
86
86
|
self._cache_lock = asyncio.Lock()
|
|
87
87
|
self._mark_initialized()
|
|
88
88
|
|
|
89
|
+
@property
def enforcer(self):
    """
    Expose the Casbin enforcer held by this adapter.

    Returns:
        The object stored in ``self._enforcer`` (documented by the original
        author as a ``casbin.AsyncEnforcer`` instance).
    """
    current_enforcer = self._enforcer
    return current_enforcer
|
|
98
|
+
|
|
89
99
|
async def check(
|
|
90
100
|
self,
|
|
91
101
|
subject: str,
|
mdb_engine/auth/shared_users.py
CHANGED
|
@@ -120,6 +120,7 @@ class SharedUserPool:
|
|
|
120
120
|
token_expiry_hours: int = DEFAULT_TOKEN_EXPIRY_HOURS,
|
|
121
121
|
allow_insecure_dev: bool = False,
|
|
122
122
|
blacklist_fail_closed: bool = True,
|
|
123
|
+
websocket_session_manager: Any | None = None,
|
|
123
124
|
):
|
|
124
125
|
"""
|
|
125
126
|
Initialize the shared user pool.
|
|
@@ -174,6 +175,7 @@ class SharedUserPool:
|
|
|
174
175
|
|
|
175
176
|
self._token_expiry_hours = token_expiry_hours
|
|
176
177
|
self._blacklist_indexes_created = False
|
|
178
|
+
self._websocket_session_manager = websocket_session_manager
|
|
177
179
|
|
|
178
180
|
logger.info(f"SharedUserPool initialized (algorithm={jwt_algorithm})")
|
|
179
181
|
|
|
@@ -340,7 +342,9 @@ class SharedUserPool:
|
|
|
340
342
|
ip_address: str | None = None,
|
|
341
343
|
fingerprint: str | None = None,
|
|
342
344
|
session_binding: dict[str, Any] | None = None,
|
|
343
|
-
|
|
345
|
+
generate_websocket_session: bool = True,
|
|
346
|
+
app_slug: str | None = None,
|
|
347
|
+
) -> str | tuple[str, str] | None:
|
|
344
348
|
"""
|
|
345
349
|
Authenticate user and return JWT token.
|
|
346
350
|
|
|
@@ -352,9 +356,14 @@ class SharedUserPool:
|
|
|
352
356
|
session_binding: Session binding config from manifest:
|
|
353
357
|
- bind_ip: Include IP in token claims
|
|
354
358
|
- bind_fingerprint: Include fingerprint in token claims
|
|
359
|
+
generate_websocket_session: If True and WebSocket session manager available,
|
|
360
|
+
also generate WebSocket session key (default: True)
|
|
361
|
+
app_slug: Optional app slug for WebSocket session scoping
|
|
355
362
|
|
|
356
363
|
Returns:
|
|
357
|
-
JWT token if authentication succeeds, None otherwise
|
|
364
|
+
JWT token if authentication succeeds, None otherwise.
|
|
365
|
+
If generate_websocket_session=True and session manager available,
|
|
366
|
+
returns tuple (jwt_token, websocket_session_key), otherwise just jwt_token.
|
|
358
367
|
"""
|
|
359
368
|
user = await self._collection.find_one(
|
|
360
369
|
{
|
|
@@ -392,7 +401,28 @@ class SharedUserPool:
|
|
|
392
401
|
# Generate JWT token with session binding claims
|
|
393
402
|
token = self._generate_token(user, extra_claims=extra_claims or None)
|
|
394
403
|
|
|
404
|
+
# Generate WebSocket session key if requested and manager available
|
|
405
|
+
websocket_session_key = None
|
|
406
|
+
if generate_websocket_session and self._websocket_session_manager:
|
|
407
|
+
try:
|
|
408
|
+
user_id = str(user["_id"])
|
|
409
|
+
websocket_session_key = await self._websocket_session_manager.create_session(
|
|
410
|
+
user_id=user_id,
|
|
411
|
+
user_email=email,
|
|
412
|
+
app_slug=app_slug,
|
|
413
|
+
)
|
|
414
|
+
logger.debug(
|
|
415
|
+
f"Generated WebSocket session key for user '{email}' " f"(app: {app_slug})"
|
|
416
|
+
)
|
|
417
|
+
except (ValueError, TypeError, AttributeError, RuntimeError) as e:
|
|
418
|
+
# Log but don't fail authentication if WebSocket session generation fails
|
|
419
|
+
logger.warning(f"Failed to generate WebSocket session key: {e}")
|
|
420
|
+
|
|
395
421
|
logger.info(f"User '{email}' authenticated successfully")
|
|
422
|
+
|
|
423
|
+
# Return tuple if WebSocket session key was generated, otherwise just token
|
|
424
|
+
if websocket_session_key:
|
|
425
|
+
return (token, websocket_session_key)
|
|
396
426
|
return token
|
|
397
427
|
|
|
398
428
|
async def validate_token(self, token: str) -> dict[str, Any] | None:
|
|
@@ -652,6 +682,47 @@ class SharedUserPool:
|
|
|
652
682
|
)
|
|
653
683
|
return result.modified_count > 0
|
|
654
684
|
|
|
685
|
+
async def update_user_metadata(
|
|
686
|
+
self,
|
|
687
|
+
email: str,
|
|
688
|
+
metadata: dict[str, Any],
|
|
689
|
+
) -> dict[str, Any] | None:
|
|
690
|
+
"""
|
|
691
|
+
Update user metadata fields.
|
|
692
|
+
|
|
693
|
+
This allows adding or updating custom fields on the user document
|
|
694
|
+
beyond the core schema (e.g., name, profile data, preferences).
|
|
695
|
+
|
|
696
|
+
Args:
|
|
697
|
+
email: User email
|
|
698
|
+
metadata: Dictionary of fields to update
|
|
699
|
+
(e.g., {"name": "John Doe", "preferences": {...}})
|
|
700
|
+
|
|
701
|
+
Returns:
|
|
702
|
+
Updated user document (without password_hash) or None if user not found
|
|
703
|
+
|
|
704
|
+
Example:
|
|
705
|
+
user = await pool.update_user_metadata(
|
|
706
|
+
"user@example.com",
|
|
707
|
+
{"name": "John Doe", "phone": "+1234567890"}
|
|
708
|
+
)
|
|
709
|
+
"""
|
|
710
|
+
# Build update document, ensuring updated_at is always set
|
|
711
|
+
update_doc = {"$set": {**metadata, "updated_at": datetime.utcnow()}}
|
|
712
|
+
|
|
713
|
+
result = await self._collection.update_one(
|
|
714
|
+
{"email": email},
|
|
715
|
+
update_doc,
|
|
716
|
+
)
|
|
717
|
+
|
|
718
|
+
if result.modified_count > 0:
|
|
719
|
+
# Fetch and return updated user
|
|
720
|
+
updated_user = await self._collection.find_one({"email": email})
|
|
721
|
+
if updated_user:
|
|
722
|
+
logger.info(f"Updated metadata for user '{email}': {list(metadata.keys())}")
|
|
723
|
+
return self._sanitize_user(updated_user)
|
|
724
|
+
return None
|
|
725
|
+
|
|
655
726
|
@staticmethod
|
|
656
727
|
def user_has_role(
|
|
657
728
|
user: dict[str, Any],
|
mdb_engine/auth/users.py
CHANGED
|
@@ -1376,7 +1376,8 @@ async def sync_app_user_to_casbin(
|
|
|
1376
1376
|
logger.debug("sync_app_user_to_casbin: Provider is not CasbinAdapter, skipping")
|
|
1377
1377
|
return False
|
|
1378
1378
|
|
|
1379
|
-
enforcer
|
|
1379
|
+
# Access enforcer via property if available, otherwise fallback to private member
|
|
1380
|
+
enforcer = getattr(authz_provider, "enforcer", None) or authz_provider._enforcer # noqa: SLF001
|
|
1380
1381
|
|
|
1381
1382
|
# Get user ID
|
|
1382
1383
|
user_id = str(user.get("_id") or user.get("app_user_id", ""))
|
mdb_engine/auth/utils.py
CHANGED
|
@@ -514,16 +514,41 @@ async def login_user(
|
|
|
514
514
|
ip_address=device_info.get("ip_address"),
|
|
515
515
|
)
|
|
516
516
|
|
|
517
|
+
# Generate WebSocket session key if WebSocket session manager available
|
|
518
|
+
websocket_session_key = None
|
|
519
|
+
try:
|
|
520
|
+
# Try to get WebSocket session manager from app state
|
|
521
|
+
app = getattr(request, "app", None)
|
|
522
|
+
if app:
|
|
523
|
+
websocket_session_manager = getattr(app.state, "websocket_session_manager", None)
|
|
524
|
+
if websocket_session_manager:
|
|
525
|
+
# Get app slug from request state if available
|
|
526
|
+
app_slug = getattr(request.state, "app_slug", None)
|
|
527
|
+
websocket_session_key = await websocket_session_manager.create_session(
|
|
528
|
+
user_id=str(user["_id"]),
|
|
529
|
+
user_email=user["email"],
|
|
530
|
+
app_slug=app_slug,
|
|
531
|
+
)
|
|
532
|
+
logger.debug(
|
|
533
|
+
f"Generated WebSocket session key for user '{user['email']}' "
|
|
534
|
+
f"(app: {app_slug})"
|
|
535
|
+
)
|
|
536
|
+
except (ValueError, TypeError, AttributeError, RuntimeError) as e:
|
|
537
|
+
# Log but don't fail login if WebSocket session generation fails
|
|
538
|
+
logger.warning(f"Failed to generate WebSocket session key during login: {e}")
|
|
539
|
+
|
|
517
540
|
# Create response
|
|
541
|
+
response_data = {
|
|
542
|
+
"success": True,
|
|
543
|
+
"user": {"email": user["email"], "user_id": str(user["_id"])},
|
|
544
|
+
}
|
|
545
|
+
if websocket_session_key:
|
|
546
|
+
response_data["websocket_session_key"] = websocket_session_key
|
|
547
|
+
|
|
518
548
|
if redirect_url:
|
|
519
549
|
response = RedirectResponse(url=redirect_url, status_code=302)
|
|
520
550
|
else:
|
|
521
|
-
response = JSONResponse(
|
|
522
|
-
{
|
|
523
|
-
"success": True,
|
|
524
|
-
"user": {"email": user["email"], "user_id": str(user["_id"])},
|
|
525
|
-
}
|
|
526
|
-
)
|
|
551
|
+
response = JSONResponse(response_data)
|
|
527
552
|
|
|
528
553
|
# Set cookies
|
|
529
554
|
set_auth_cookies(
|
|
@@ -0,0 +1,433 @@
|
|
|
1
|
+
"""
|
|
2
|
+
WebSocket Session Manager with Envelope Encryption
|
|
3
|
+
|
|
4
|
+
Manages WebSocket session keys using envelope encryption and private collections.
|
|
5
|
+
Provides secure-by-default WebSocket authentication without relying on CSRF cookies.
|
|
6
|
+
|
|
7
|
+
This module is part of MDB_ENGINE - MongoDB Engine.
|
|
8
|
+
|
|
9
|
+
Security Model:
|
|
10
|
+
- Session keys generated on authentication
|
|
11
|
+
- Stored encrypted in _mdb_engine_websocket_sessions collection
|
|
12
|
+
- Validated during WebSocket upgrade
|
|
13
|
+
- Uses envelope encryption (same as app secrets)
|
|
14
|
+
- Security by default: CSRF always required
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import base64
|
|
18
|
+
import logging
|
|
19
|
+
import secrets
|
|
20
|
+
from collections.abc import Callable
|
|
21
|
+
from datetime import datetime, timedelta
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
from motor.motor_asyncio import AsyncIOMotorDatabase
|
|
25
|
+
from pymongo.errors import OperationFailure, PyMongoError
|
|
26
|
+
|
|
27
|
+
from ..core.encryption import EnvelopeEncryptionService
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)

# Private collection that holds the encrypted WebSocket session keys.
WEBSOCKET_SESSIONS_COLLECTION_NAME = "_mdb_engine_websocket_sessions"

# Number of random bytes in a session key (256 bits).
SESSION_KEY_SIZE = 256 // 8
# Lifetime of a WebSocket session, in hours.
SESSION_TTL_HOURS = 24
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class WebSocketSessionManager:
    """
    Manages WebSocket session keys using envelope encryption.

    Session keys are:
    - Generated on user authentication (SESSION_KEY_SIZE random bytes,
      URL-safe base64 without padding)
    - Encrypted using envelope encryption and stored in the private
      collection (_mdb_engine_websocket_sessions)
    - Stored under a SHA-256 digest of the key as ``_id``, so the collection
      never holds a usable plaintext key
    - Validated during WebSocket upgrade (lookup, expiry check, optional user
      check, then decrypt-and-compare)
    - Rejected once ``expires_at`` (SESSION_TTL_HOURS after creation) has
      passed. Expired documents are deleted lazily by ``validate_session`` or
      in bulk by ``cleanup_expired_sessions``; this class creates no TTL index.
    """

    def __init__(
        self,
        mongo_db: AsyncIOMotorDatabase,
        encryption_service: EnvelopeEncryptionService,
    ):
        """
        Initialize the WebSocket session manager.

        Args:
            mongo_db: MongoDB database instance (raw, not scoped)
            encryption_service: Envelope encryption service instance
        """
        self._mongo_db = mongo_db
        self._encryption_service = encryption_service
        self._sessions_collection = mongo_db[WEBSOCKET_SESSIONS_COLLECTION_NAME]

    @staticmethod
    def generate_session_key() -> str:
        """
        Generate a random WebSocket session key.

        Returns:
            URL-safe base64 session key string with ``=`` padding stripped
        """
        key_bytes = secrets.token_bytes(SESSION_KEY_SIZE)
        return base64.urlsafe_b64encode(key_bytes).decode().rstrip("=")

    @staticmethod
    def _storage_id(session_key: str) -> str:
        """Return the hex SHA-256 digest used as the document ``_id`` for a key."""
        import hashlib

        return hashlib.sha256(session_key.encode("utf-8")).hexdigest()

    @classmethod
    def _key_filter(cls, session_key: str) -> dict[str, Any]:
        """
        Build the query that locates the stored document for ``session_key``.

        Documents written by earlier builds used the plaintext key as ``_id``.
        Matching it as well keeps those sessions valid and revocable until they
        expire (at most SESSION_TTL_HOURS), after which the plaintext branch can
        be removed. Presenting a digest instead of the real key does not
        authenticate: ``validate_session`` still decrypts and compares the key.
        """
        return {"_id": {"$in": [cls._storage_id(session_key), session_key]}}

    @classmethod
    def _key_ref(cls, session_key: str) -> str:
        """Return a short, non-secret reference to a key for log messages."""
        return cls._storage_id(session_key)[:12]

    @staticmethod
    def _as_naive_utc(value: Any) -> datetime | None:
        """
        Normalize a stored expiry to a naive UTC datetime for comparison.

        ``datetime.utcnow()`` is naive, so a timezone-aware value (e.g. from a
        tz-aware Mongo client) is converted to UTC and stripped of its tzinfo.
        Returns None when ``value`` is not a datetime.
        """
        if not isinstance(value, datetime):
            return None
        if value.tzinfo is not None:
            value = value.replace(tzinfo=None) - (value.utcoffset() or timedelta(0))
        return value

    async def create_session(
        self,
        user_id: str,
        user_email: str | None = None,
        app_slug: str | None = None,
    ) -> str:
        """
        Create a new WebSocket session with encrypted session key.

        Args:
            user_id: User ID
            user_email: Optional user email
            app_slug: Optional app slug for scoping

        Returns:
            Plaintext session key (to be sent to client; only its digest and an
            encrypted copy are stored)

        Raises:
            OperationFailure / PyMongoError: If the MongoDB insert fails
        """
        session_key = self.generate_session_key()

        # Encrypt session key using envelope encryption, base64-encoded for storage
        encrypted_key, encrypted_dek = self._encryption_service.encrypt_secret(session_key)
        encrypted_key_b64 = base64.b64encode(encrypted_key).decode()
        encrypted_dek_b64 = base64.b64encode(encrypted_dek).decode()

        # One clock read so created_at and expires_at are exactly TTL apart
        now = datetime.utcnow()
        expires_at = now + timedelta(hours=SESSION_TTL_HOURS)

        document = {
            # Digest, not the plaintext key: reading the collection must not
            # yield keys that can be replayed against the WebSocket endpoint.
            "_id": self._storage_id(session_key),
            "user_id": user_id,
            "user_email": user_email,
            "app_slug": app_slug,
            "encrypted_key": encrypted_key_b64,
            "encrypted_dek": encrypted_dek_b64,
            "algorithm": "AES-256-GCM",
            "created_at": now,
            "expires_at": expires_at,
        }

        try:
            await self._sessions_collection.insert_one(document)
        except (OperationFailure, PyMongoError):
            logger.exception("Failed to create WebSocket session")
            raise

        logger.info(
            f"Created WebSocket session for user '{user_id}' "
            f"(app: {app_slug}, expires: {expires_at})"
        )
        return session_key

    def _decrypted_key_matches(self, session_doc: dict[str, Any], session_key: str) -> bool:
        """
        Decrypt the stored key and compare it, in constant time, with ``session_key``.

        Returns False (after logging) when decoding/decryption fails or the
        decrypted key differs.
        """
        try:
            encrypted_key = base64.b64decode(session_doc["encrypted_key"])
            encrypted_dek = base64.b64decode(session_doc["encrypted_dek"])
            decrypted_key = self._encryption_service.decrypt_secret(encrypted_key, encrypted_dek)
            matches = secrets.compare_digest(decrypted_key, session_key)
        except (ValueError, TypeError, AttributeError, KeyError):
            # NOTE(review): exception types raised by decrypt_secret on a bad
            # ciphertext are not visible here; others will propagate.
            logger.exception("Failed to decrypt WebSocket session key")
            return False

        if not matches:
            logger.error(
                f"WebSocket session key decryption mismatch: "
                f"session_key={self._key_ref(session_key)}..."
            )
        return matches

    async def validate_session(
        self,
        session_key: str,
        user_id: str | None = None,
    ) -> dict[str, Any] | None:
        """
        Validate a WebSocket session key.

        Args:
            session_key: Session key presented by the client
            user_id: Optional user ID that must match the session's owner

        Returns:
            Dict with user_id, user_email, app_slug, created_at and expires_at
            if valid. None if the key is empty, unknown, expired (or has no
            usable expiry), belongs to a different user, or fails decryption.

        Raises:
            OperationFailure / PyMongoError: If a MongoDB operation fails
        """
        if not isinstance(session_key, str) or not session_key:
            logger.warning("WebSocket session validation called without a session key")
            return None
        key_ref = self._key_ref(session_key)

        try:
            session_doc = await self._sessions_collection.find_one(self._key_filter(session_key))
            if not session_doc:
                logger.warning(f"WebSocket session not found: {key_ref}...")
                return None

            # Fail closed: a document without a usable expiry is treated as expired.
            expires_at = self._as_naive_utc(session_doc.get("expires_at"))
            if expires_at is None or expires_at < datetime.utcnow():
                logger.warning(
                    f"WebSocket session expired: {key_ref}... "
                    f"(expired: {session_doc.get('expires_at')})"
                )
                # Clean up expired session
                await self._sessions_collection.delete_one({"_id": session_doc["_id"]})
                return None

            # Optional: Validate user_id matches
            if user_id and session_doc.get("user_id") != user_id:
                logger.warning(
                    f"WebSocket session user mismatch: "
                    f"session_user={session_doc.get('user_id')}, "
                    f"provided_user={user_id}"
                )
                return None
        except (OperationFailure, PyMongoError):
            logger.exception("Failed to validate WebSocket session")
            raise

        if not self._decrypted_key_matches(session_doc, session_key):
            return None

        logger.debug(
            f"Validated WebSocket session for user '{session_doc.get('user_id')}' "
            f"(app: {session_doc.get('app_slug')})"
        )
        return {
            "user_id": session_doc.get("user_id"),
            "user_email": session_doc.get("user_email"),
            "app_slug": session_doc.get("app_slug"),
            "created_at": session_doc.get("created_at"),
            "expires_at": session_doc.get("expires_at"),
        }

    async def revoke_session(self, session_key: str) -> bool:
        """
        Revoke a WebSocket session.

        Args:
            session_key: Session key to revoke

        Returns:
            True if a session document was deleted; False if none was found,
            the key is empty, or the delete failed (failures are logged)
        """
        if not isinstance(session_key, str) or not session_key:
            return False
        try:
            result = await self._sessions_collection.delete_one(self._key_filter(session_key))
        except (OperationFailure, PyMongoError):
            logger.exception("Failed to revoke WebSocket session")
            return False

        if result.deleted_count > 0:
            logger.info(f"Revoked WebSocket session: {self._key_ref(session_key)}...")
            return True
        return False

    async def revoke_user_sessions(self, user_id: str, app_slug: str | None = None) -> int:
        """
        Revoke all sessions for a user.

        Args:
            user_id: User ID
            app_slug: Optional app slug filter (falsy means all apps)

        Returns:
            Number of sessions revoked (0 if the delete failed; failures are logged)
        """
        query: dict[str, Any] = {"user_id": user_id}
        if app_slug:
            query["app_slug"] = app_slug

        try:
            result = await self._sessions_collection.delete_many(query)
        except (OperationFailure, PyMongoError):
            logger.exception("Failed to revoke user WebSocket sessions")
            return 0

        logger.info(
            f"Revoked {result.deleted_count} WebSocket sessions "
            f"for user '{user_id}' (app: {app_slug})"
        )
        return result.deleted_count

    async def cleanup_expired_sessions(self) -> int:
        """
        Clean up expired WebSocket sessions.

        Returns:
            Number of sessions deleted (0 if the delete failed; failures are logged)
        """
        try:
            result = await self._sessions_collection.delete_many(
                {"expires_at": {"$lt": datetime.utcnow()}}
            )
        except (OperationFailure, PyMongoError):
            logger.exception("Failed to cleanup expired WebSocket sessions")
            return 0

        if result.deleted_count > 0:
            logger.info(f"Cleaned up {result.deleted_count} expired WebSocket sessions")
        return result.deleted_count
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def create_websocket_session_endpoint(
    session_manager: WebSocketSessionManager,
) -> Callable:
    """
    Create a FastAPI endpoint for generating WebSocket session keys.

    The endpoint requires an authenticated user and issues a new WebSocket
    session key for them through ``session_manager.create_session``, which
    encrypts and stores it in the private collection.

    Args:
        session_manager: WebSocketSessionManager instance

    Returns:
        FastAPI route handler function

    Example:
        ```python
        from mdb_engine.auth.websocket_sessions import (
            WebSocketSessionManager,
            create_websocket_session_endpoint,
        )
        from mdb_engine.core.encryption import EnvelopeEncryptionService

        # Initialize session manager
        encryption_service = EnvelopeEncryptionService()
        session_manager = WebSocketSessionManager(
            mongo_db=db,
            encryption_service=encryption_service,
        )

        # Create endpoint
        endpoint = create_websocket_session_endpoint(session_manager)
        app.get("/auth/websocket-session")(endpoint)
        ```

    The endpoint:
    - Uses ``request.state.user`` (set by SharedAuthMiddleware) or, when that
      is absent, validates the auth cookie against ``request.app.state.user_pool``
    - Returns JSON: `{"session_key": "...", "expires_at": "...", "ttl_hours": ...}`
    - Responds 401 when unauthenticated, 400 when the user data has no id,
      and 500 when the session cannot be created
    """
    from fastapi import Request, status
    from fastapi.responses import JSONResponse

    async def _user_from_auth_cookie(request: Request) -> Any:
        """
        Authenticate from the auth cookie when no middleware set ``request.state.user``.

        Covers endpoints mounted on a parent app without the auth middleware.
        Returns the result of ``user_pool.validate_token`` or None.
        """
        from .shared_middleware import AUTH_COOKIE_NAME

        try:
            app_state = getattr(getattr(request, "app", None), "state", None)
            user_pool = getattr(app_state, "user_pool", None)
        except (AttributeError, TypeError):
            user_pool = None
        if user_pool is None:
            return None

        try:
            token = request.cookies.get(AUTH_COOKIE_NAME)
        except (AttributeError, TypeError):
            token = None
        if not token:
            return None

        try:
            return await user_pool.validate_token(token)
        except (TypeError, AttributeError):
            # user_pool has no usable async validate_token; treat as unauthenticated.
            return None

    async def websocket_session_endpoint(request: Request) -> JSONResponse:
        """
        Generate a WebSocket session key for the authenticated user.

        Requires:
        - User to be authenticated (via request.state.user or auth cookie)
        - WebSocket session manager to be available

        Returns:
        - JSONResponse with session_key, expires_at and ttl_hours
        """
        user = getattr(request.state, "user", None)
        if not user:
            user = await _user_from_auth_cookie(request)
        if not user:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={"detail": "Authentication required"},
            )

        # Prefer user_id, sub (JWT standard), or _id (MongoDB document ID).
        # Email is deliberately not used as a fallback id.
        user_id = user.get("user_id") or user.get("sub") or user.get("_id")
        if not user_id:
            logger.error("Cannot generate WebSocket session: user_id not found in user data")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": "Invalid user data"},
            )
        user_email = user.get("email")
        app_slug = getattr(request.state, "app_slug", None)

        # Computed before the session is stored so the advertised expiry is
        # never later than the expiry create_session() records.
        expires_at = datetime.utcnow() + timedelta(hours=SESSION_TTL_HOURS)

        try:
            session_key = await session_manager.create_session(
                user_id=str(user_id),
                user_email=user_email,
                app_slug=app_slug,
            )
        except (
            ValueError,
            TypeError,
            AttributeError,
            RuntimeError,
            OperationFailure,
            PyMongoError,
        ):
            logger.exception("Failed to generate WebSocket session key")
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={"detail": "Failed to generate WebSocket session key"},
            )

        logger.info(
            f"Generated WebSocket session key for user '{user_id}' " f"(app: {app_slug})"
        )
        return JSONResponse(
            {
                "session_key": session_key,
                "expires_at": expires_at.isoformat(),
                "ttl_hours": SESSION_TTL_HOURS,
            }
        )

    return websocket_session_endpoint
|