mdb-engine 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/auth/__init__.py +9 -0
- mdb_engine/auth/csrf.py +493 -144
- mdb_engine/auth/provider.py +10 -0
- mdb_engine/auth/shared_users.py +41 -0
- mdb_engine/auth/users.py +2 -1
- mdb_engine/auth/websocket_tickets.py +307 -0
- mdb_engine/core/app_registration.py +10 -0
- mdb_engine/core/engine.py +632 -37
- mdb_engine/core/manifest.py +14 -0
- mdb_engine/core/ray_integration.py +4 -4
- mdb_engine/core/types.py +1 -0
- mdb_engine/database/connection.py +6 -3
- mdb_engine/database/scoped_wrapper.py +3 -3
- mdb_engine/indexes/manager.py +3 -3
- mdb_engine/observability/health.py +7 -7
- mdb_engine/routing/README.md +9 -2
- mdb_engine/routing/websockets.py +453 -74
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.0.dist-info}/METADATA +128 -4
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.0.dist-info}/RECORD +23 -22
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.0.dist-info}/WHEEL +0 -0
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.0.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.0.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.0.dist-info}/top_level.txt +0 -0
mdb_engine/auth/provider.py
CHANGED
|
@@ -86,6 +86,16 @@ class CasbinAdapter(BaseAuthorizationProvider):
|
|
|
86
86
|
self._cache_lock = asyncio.Lock()
|
|
87
87
|
self._mark_initialized()
|
|
88
88
|
|
|
89
|
+
@property
def enforcer(self):
    """
    Expose the underlying Casbin enforcer.

    Read-only accessor for the private ``_enforcer`` attribute, so callers
    can reach the enforcer without touching the private member directly.

    Returns:
        The object stored in ``self._enforcer`` (a casbin.AsyncEnforcer
        instance).
    """
    current_enforcer = self._enforcer
    return current_enforcer
|
|
98
|
+
|
|
89
99
|
async def check(
|
|
90
100
|
self,
|
|
91
101
|
subject: str,
|
mdb_engine/auth/shared_users.py
CHANGED
|
@@ -682,6 +682,47 @@ class SharedUserPool:
|
|
|
682
682
|
)
|
|
683
683
|
return result.modified_count > 0
|
|
684
684
|
|
|
685
|
+
async def update_user_metadata(
|
|
686
|
+
self,
|
|
687
|
+
email: str,
|
|
688
|
+
metadata: dict[str, Any],
|
|
689
|
+
) -> dict[str, Any] | None:
|
|
690
|
+
"""
|
|
691
|
+
Update user metadata fields.
|
|
692
|
+
|
|
693
|
+
This allows adding or updating custom fields on the user document
|
|
694
|
+
beyond the core schema (e.g., name, profile data, preferences).
|
|
695
|
+
|
|
696
|
+
Args:
|
|
697
|
+
email: User email
|
|
698
|
+
metadata: Dictionary of fields to update
|
|
699
|
+
(e.g., {"name": "John Doe", "preferences": {...}})
|
|
700
|
+
|
|
701
|
+
Returns:
|
|
702
|
+
Updated user document (without password_hash) or None if user not found
|
|
703
|
+
|
|
704
|
+
Example:
|
|
705
|
+
user = await pool.update_user_metadata(
|
|
706
|
+
"user@example.com",
|
|
707
|
+
{"name": "John Doe", "phone": "+1234567890"}
|
|
708
|
+
)
|
|
709
|
+
"""
|
|
710
|
+
# Build update document, ensuring updated_at is always set
|
|
711
|
+
update_doc = {"$set": {**metadata, "updated_at": datetime.utcnow()}}
|
|
712
|
+
|
|
713
|
+
result = await self._collection.update_one(
|
|
714
|
+
{"email": email},
|
|
715
|
+
update_doc,
|
|
716
|
+
)
|
|
717
|
+
|
|
718
|
+
if result.modified_count > 0:
|
|
719
|
+
# Fetch and return updated user
|
|
720
|
+
updated_user = await self._collection.find_one({"email": email})
|
|
721
|
+
if updated_user:
|
|
722
|
+
logger.info(f"Updated metadata for user '{email}': {list(metadata.keys())}")
|
|
723
|
+
return self._sanitize_user(updated_user)
|
|
724
|
+
return None
|
|
725
|
+
|
|
685
726
|
@staticmethod
|
|
686
727
|
def user_has_role(
|
|
687
728
|
user: dict[str, Any],
|
mdb_engine/auth/users.py
CHANGED
|
@@ -1376,7 +1376,8 @@ async def sync_app_user_to_casbin(
|
|
|
1376
1376
|
logger.debug("sync_app_user_to_casbin: Provider is not CasbinAdapter, skipping")
|
|
1377
1377
|
return False
|
|
1378
1378
|
|
|
1379
|
-
enforcer
|
|
1379
|
+
# Access enforcer via property if available, otherwise fallback to private member
|
|
1380
|
+
enforcer = getattr(authz_provider, "enforcer", None) or authz_provider._enforcer # noqa: SLF001
|
|
1380
1381
|
|
|
1381
1382
|
# Get user ID
|
|
1382
1383
|
user_id = str(user.get("_id") or user.get("app_user_id", ""))
|
|
mdb_engine/auth/websocket_tickets.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
1
|
+
"""
|
|
2
|
+
WebSocket Ticket Store for Multi-App SSO
|
|
3
|
+
|
|
4
|
+
Manages short-lived, single-use tickets for WebSocket authentication.
|
|
5
|
+
Tickets are exchanged for JWT tokens and consumed immediately upon validation.
|
|
6
|
+
|
|
7
|
+
This module is part of MDB_ENGINE - MongoDB Engine.
|
|
8
|
+
|
|
9
|
+
Security Model:
|
|
10
|
+
- Tickets generated on authentication (JWT → Ticket exchange)
|
|
11
|
+
- Stored in-memory (no database)
|
|
12
|
+
- Short TTL (10 seconds default)
|
|
13
|
+
- Single-use (consumed immediately after validation)
|
|
14
|
+
- Secure-by-default for multi-app SSO setups
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import logging
|
|
19
|
+
import time
|
|
20
|
+
import uuid
|
|
21
|
+
from collections.abc import Callable
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
# Ticket configuration
|
|
27
|
+
DEFAULT_TICKET_TTL_SECONDS = 10 # Tickets expire after 10 seconds
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class WebSocketTicketStore:
|
|
31
|
+
"""
|
|
32
|
+
Manages WebSocket tickets using in-memory storage.
|
|
33
|
+
|
|
34
|
+
Tickets are:
|
|
35
|
+
- Generated on JWT → Ticket exchange
|
|
36
|
+
- Stored in-memory dictionary
|
|
37
|
+
- Validated and consumed immediately (single-use)
|
|
38
|
+
- Automatically expired after TTL
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
def __init__(self, ticket_ttl_seconds: int = DEFAULT_TICKET_TTL_SECONDS):
|
|
42
|
+
"""
|
|
43
|
+
Initialize the WebSocket ticket store.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
ticket_ttl_seconds: Ticket time-to-live in seconds (default: 10)
|
|
47
|
+
"""
|
|
48
|
+
self._tickets: dict[str, dict[str, Any]] = {}
|
|
49
|
+
self._lock = asyncio.Lock()
|
|
50
|
+
self._ticket_ttl = ticket_ttl_seconds
|
|
51
|
+
logger.info(f"Initialized WebSocket ticket store (TTL: {ticket_ttl_seconds}s)")
|
|
52
|
+
|
|
53
|
+
@property
|
|
54
|
+
def ticket_ttl(self) -> int:
|
|
55
|
+
"""
|
|
56
|
+
Get the ticket time-to-live in seconds.
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
Ticket TTL in seconds
|
|
60
|
+
"""
|
|
61
|
+
return self._ticket_ttl
|
|
62
|
+
|
|
63
|
+
def create_ticket(
|
|
64
|
+
self,
|
|
65
|
+
user_id: str,
|
|
66
|
+
user_email: str | None = None,
|
|
67
|
+
app_slug: str | None = None,
|
|
68
|
+
) -> str:
|
|
69
|
+
"""
|
|
70
|
+
Create a new WebSocket ticket.
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
user_id: User ID
|
|
74
|
+
user_email: Optional user email
|
|
75
|
+
app_slug: Optional app slug for scoping
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
Ticket UUID string
|
|
79
|
+
"""
|
|
80
|
+
ticket_id = str(uuid.uuid4())
|
|
81
|
+
expires_at = time.time() + self._ticket_ttl
|
|
82
|
+
|
|
83
|
+
ticket_data = {
|
|
84
|
+
"user_id": user_id,
|
|
85
|
+
"user_email": user_email,
|
|
86
|
+
"app_slug": app_slug,
|
|
87
|
+
"exp": expires_at,
|
|
88
|
+
"created_at": time.time(),
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
# Thread-safe ticket creation
|
|
92
|
+
self._tickets[ticket_id] = ticket_data
|
|
93
|
+
|
|
94
|
+
logger.debug(
|
|
95
|
+
f"Created WebSocket ticket for user '{user_id}' "
|
|
96
|
+
f"(app: {app_slug}, expires in {self._ticket_ttl}s)"
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
return ticket_id
|
|
100
|
+
|
|
101
|
+
async def validate_and_consume_ticket(self, ticket_id: str) -> dict[str, Any] | None:
|
|
102
|
+
"""
|
|
103
|
+
Validate and consume a WebSocket ticket (atomic operation).
|
|
104
|
+
|
|
105
|
+
This method validates the ticket and removes it immediately,
|
|
106
|
+
ensuring single-use behavior.
|
|
107
|
+
|
|
108
|
+
Args:
|
|
109
|
+
ticket_id: Ticket UUID to validate
|
|
110
|
+
|
|
111
|
+
Returns:
|
|
112
|
+
Ticket data dict if valid, None otherwise
|
|
113
|
+
"""
|
|
114
|
+
async with self._lock:
|
|
115
|
+
# Check if ticket exists
|
|
116
|
+
if ticket_id not in self._tickets:
|
|
117
|
+
logger.warning(f"WebSocket ticket not found: {ticket_id[:16]}...")
|
|
118
|
+
return None
|
|
119
|
+
|
|
120
|
+
ticket_data = self._tickets[ticket_id]
|
|
121
|
+
|
|
122
|
+
# Check expiration
|
|
123
|
+
if time.time() > ticket_data["exp"]:
|
|
124
|
+
logger.warning(
|
|
125
|
+
f"WebSocket ticket expired: {ticket_id[:16]}... "
|
|
126
|
+
f"(expired: {ticket_data['exp']})"
|
|
127
|
+
)
|
|
128
|
+
# Remove expired ticket
|
|
129
|
+
del self._tickets[ticket_id]
|
|
130
|
+
return None
|
|
131
|
+
|
|
132
|
+
# CONSUME TICKET (atomic operation - remove immediately)
|
|
133
|
+
# This ensures single-use behavior
|
|
134
|
+
user_id = ticket_data["user_id"]
|
|
135
|
+
user_email = ticket_data["user_email"]
|
|
136
|
+
app_slug = ticket_data.get("app_slug")
|
|
137
|
+
|
|
138
|
+
# Remove ticket before returning (single-use)
|
|
139
|
+
del self._tickets[ticket_id]
|
|
140
|
+
|
|
141
|
+
logger.debug(
|
|
142
|
+
f"Validated and consumed WebSocket ticket for user '{user_id}' "
|
|
143
|
+
f"(app: {app_slug})"
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
return {
|
|
147
|
+
"user_id": user_id,
|
|
148
|
+
"user_email": user_email,
|
|
149
|
+
"app_slug": app_slug,
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
async def cleanup_expired_tickets(self) -> int:
|
|
153
|
+
"""
|
|
154
|
+
Clean up expired tickets.
|
|
155
|
+
|
|
156
|
+
Returns:
|
|
157
|
+
Number of tickets cleaned up
|
|
158
|
+
"""
|
|
159
|
+
async with self._lock:
|
|
160
|
+
now = time.time()
|
|
161
|
+
expired_tickets = [
|
|
162
|
+
ticket_id
|
|
163
|
+
for ticket_id, ticket_data in self._tickets.items()
|
|
164
|
+
if ticket_data["exp"] < now
|
|
165
|
+
]
|
|
166
|
+
|
|
167
|
+
for ticket_id in expired_tickets:
|
|
168
|
+
del self._tickets[ticket_id]
|
|
169
|
+
|
|
170
|
+
if expired_tickets:
|
|
171
|
+
logger.debug(f"Cleaned up {len(expired_tickets)} expired WebSocket tickets")
|
|
172
|
+
|
|
173
|
+
return len(expired_tickets)
|
|
174
|
+
|
|
175
|
+
def get_ticket_count(self) -> int:
|
|
176
|
+
"""Get the number of active tickets."""
|
|
177
|
+
return len(self._tickets)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def create_websocket_ticket_endpoint(
    ticket_store: WebSocketTicketStore,
) -> Callable:
    """
    Create a FastAPI endpoint for generating WebSocket tickets.

    This endpoint requires authentication and generates a new one-time ticket
    for the authenticated user. The ticket is single-use and expires after
    ``ticket_store.ticket_ttl`` seconds (10 seconds with the store's default).

    Args:
        ticket_store: WebSocketTicketStore instance

    Returns:
        FastAPI route handler function

    Example:
        ```python
        from mdb_engine.auth.websocket_tickets import (
            WebSocketTicketStore,
            create_websocket_ticket_endpoint,
        )

        # Initialize ticket store
        ticket_store = WebSocketTicketStore(ticket_ttl_seconds=10)

        # Create endpoint
        endpoint = create_websocket_ticket_endpoint(ticket_store)
        app.post("/auth/ticket")(endpoint)
        ```

    The endpoint:
    - Requires authentication (user must be logged in)
    - Returns JSON: `{"ticket": "...", "expires_in": <ticket_store.ticket_ttl>}`
    - Uses user info from `request.state.user` (set by SharedAuthMiddleware),
      falling back to validating the auth cookie with `app.state.user_pool`
    - Sends `Cache-Control: no-store` on every response, because a ticket is a
      bearer credential that browsers and proxies must not cache
    """
    from fastapi import Request, status
    from fastapi.responses import JSONResponse

    # A ticket is a short-lived bearer credential. Forbid caching of the
    # success response and of the error responses alike.
    no_store_headers = {"Cache-Control": "no-store", "Pragma": "no-cache"}

    def _json(content: dict[str, Any], status_code: int = status.HTTP_200_OK) -> JSONResponse:
        """Build a JSONResponse that carries the no-store cache headers."""
        return JSONResponse(content=content, status_code=status_code, headers=no_store_headers)

    async def websocket_ticket_endpoint(request: Request) -> JSONResponse:
        """
        Generate a WebSocket ticket for the authenticated user.

        Requires:
        - User to be authenticated (via request.state.user or auth cookie)
        - WebSocket ticket store to be available

        Returns:
        - JSONResponse with ticket and expires_in (200), or an error
          detail (401 / 400 / 500)
        """
        # Check if user is authenticated (set by middleware)
        user = getattr(request.state, "user", None)

        # If not set by middleware, try to authenticate using the cookie.
        # This handles the endpoint living on a parent app without auth middleware.
        if not user:
            from .shared_middleware import AUTH_COOKIE_NAME

            # Get user pool from app state
            user_pool = None
            try:
                if hasattr(request, "app") and hasattr(request.app, "state"):
                    user_pool = getattr(request.app.state, "user_pool", None)
            except (AttributeError, TypeError):
                pass

            # Only try to authenticate if we have a real user pool (not None)
            if user_pool is not None:
                token = None
                try:
                    if hasattr(request, "cookies"):
                        token = request.cookies.get(AUTH_COOKIE_NAME)
                except (AttributeError, TypeError):
                    pass

                if token:
                    try:
                        user = await user_pool.validate_token(token)
                    except (TypeError, AttributeError):
                        # e.g. a user_pool test double whose validate_token is not
                        # awaitable: treat as unauthenticated, but leave a trace.
                        logger.debug(
                            "Cookie token validation failed; treating request as unauthenticated",
                            exc_info=True,
                        )

        if not user:
            return _json({"detail": "Authentication required"}, status.HTTP_401_UNAUTHORIZED)

        # Prefer user_id, sub (JWT standard), or _id (MongoDB document ID).
        user_id = user.get("user_id") or user.get("sub") or user.get("_id")
        if not user_id:
            # Email is not a valid user_id - it's just metadata
            logger.error("Cannot generate WebSocket ticket: user_id not found in user data")
            return _json({"detail": "Invalid user data"}, status.HTTP_400_BAD_REQUEST)

        user_email = user.get("email")
        app_slug = getattr(request.state, "app_slug", None)

        try:
            ticket = ticket_store.create_ticket(
                user_id=str(user_id),
                user_email=user_email,
                app_slug=app_slug,
            )
        except (ValueError, TypeError, AttributeError, RuntimeError):
            logger.exception("Failed to generate WebSocket ticket")
            return _json(
                {"detail": "Failed to generate WebSocket ticket"},
                status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        logger.info(f"Generated WebSocket ticket for user '{user_id}' (app: {app_slug})")

        return _json({"ticket": ticket, "expires_in": ticket_store.ticket_ttl})

    return websocket_ticket_endpoint
|
|
mdb_engine/core/app_registration.py
CHANGED
@@ -359,6 +359,16 @@ class AppRegistrationManager:
|
|
|
359
359
|
"""
|
|
360
360
|
return self._apps.get(slug)
|
|
361
361
|
|
|
362
|
+
@property
def apps(self) -> dict[str, dict[str, Any]]:
    """
    Return the registry of every app known to this manager.

    The mapping is keyed by app slug, the same key ``get_app`` looks up.
    This is the live internal dictionary, not a copy, so mutating it
    mutates the manager's registry.

    Returns:
        Dictionary of registered apps
    """
    registry = self._apps
    return registry
|
|
371
|
+
|
|
362
372
|
def list_apps(self) -> list[str]:
|
|
363
373
|
"""
|
|
364
374
|
List all registered app slugs.
|