mdb-engine 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -86,6 +86,16 @@ class CasbinAdapter(BaseAuthorizationProvider):
86
86
  self._cache_lock = asyncio.Lock()
87
87
  self._mark_initialized()
88
88
 
89
@property
def enforcer(self):
    """
    Expose the Casbin enforcer wrapped by this adapter.

    Provided so code outside the adapter does not have to reach into the
    private ``_enforcer`` attribute. There is no setter, so the attribute is
    read-only through this property.

    Returns:
        The enforcer object held by the adapter (a ``casbin.AsyncEnforcer``
        instance).
    """
    wrapped_enforcer = self._enforcer
    return wrapped_enforcer
98
+
89
99
  async def check(
90
100
  self,
91
101
  subject: str,
@@ -682,6 +682,47 @@ class SharedUserPool:
682
682
  )
683
683
  return result.modified_count > 0
684
684
 
685
+ async def update_user_metadata(
686
+ self,
687
+ email: str,
688
+ metadata: dict[str, Any],
689
+ ) -> dict[str, Any] | None:
690
+ """
691
+ Update user metadata fields.
692
+
693
+ This allows adding or updating custom fields on the user document
694
+ beyond the core schema (e.g., name, profile data, preferences).
695
+
696
+ Args:
697
+ email: User email
698
+ metadata: Dictionary of fields to update
699
+ (e.g., {"name": "John Doe", "preferences": {...}})
700
+
701
+ Returns:
702
+ Updated user document (without password_hash) or None if user not found
703
+
704
+ Example:
705
+ user = await pool.update_user_metadata(
706
+ "user@example.com",
707
+ {"name": "John Doe", "phone": "+1234567890"}
708
+ )
709
+ """
710
+ # Build update document, ensuring updated_at is always set
711
+ update_doc = {"$set": {**metadata, "updated_at": datetime.utcnow()}}
712
+
713
+ result = await self._collection.update_one(
714
+ {"email": email},
715
+ update_doc,
716
+ )
717
+
718
+ if result.modified_count > 0:
719
+ # Fetch and return updated user
720
+ updated_user = await self._collection.find_one({"email": email})
721
+ if updated_user:
722
+ logger.info(f"Updated metadata for user '{email}': {list(metadata.keys())}")
723
+ return self._sanitize_user(updated_user)
724
+ return None
725
+
685
726
  @staticmethod
686
727
  def user_has_role(
687
728
  user: dict[str, Any],
mdb_engine/auth/users.py CHANGED
@@ -1376,7 +1376,8 @@ async def sync_app_user_to_casbin(
1376
1376
  logger.debug("sync_app_user_to_casbin: Provider is not CasbinAdapter, skipping")
1377
1377
  return False
1378
1378
 
1379
- enforcer = authz_provider._enforcer
1379
+ # Access enforcer via property if available, otherwise fallback to private member
1380
+ enforcer = getattr(authz_provider, "enforcer", None) or authz_provider._enforcer # noqa: SLF001
1380
1381
 
1381
1382
  # Get user ID
1382
1383
  user_id = str(user.get("_id") or user.get("app_user_id", ""))
@@ -0,0 +1,307 @@
1
+ """
2
+ WebSocket Ticket Store for Multi-App SSO
3
+
4
+ Manages short-lived, single-use tickets for WebSocket authentication.
5
+ Tickets are exchanged for JWT tokens and consumed immediately upon validation.
6
+
7
+ This module is part of MDB_ENGINE - MongoDB Engine.
8
+
9
+ Security Model:
10
+ - Tickets generated on authentication (JWT → Ticket exchange)
11
+ - Stored in-memory, per process (no database; not shared across workers)
12
+ - Short TTL (10 seconds default)
13
+ - Single-use (consumed immediately after validation)
14
+ - Secure-by-default for multi-app SSO setups
15
+ """
16
+
17
+ import asyncio
18
+ import logging
19
+ import time
20
+ import uuid
21
+ from collections.abc import Callable
22
+ from typing import Any
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Ticket configuration
27
+ DEFAULT_TICKET_TTL_SECONDS = 10 # Tickets expire after 10 seconds
28
+
29
+
30
+ class WebSocketTicketStore:
31
+ """
32
+ Manages WebSocket tickets using in-memory storage.
33
+
34
+ Tickets are:
35
+ - Generated on JWT → Ticket exchange
36
+ - Stored in-memory dictionary
37
+ - Validated and consumed immediately (single-use)
38
+ - Automatically expired after TTL
39
+ """
40
+
41
+ def __init__(self, ticket_ttl_seconds: int = DEFAULT_TICKET_TTL_SECONDS):
42
+ """
43
+ Initialize the WebSocket ticket store.
44
+
45
+ Args:
46
+ ticket_ttl_seconds: Ticket time-to-live in seconds (default: 10)
47
+ """
48
+ self._tickets: dict[str, dict[str, Any]] = {}
49
+ self._lock = asyncio.Lock()
50
+ self._ticket_ttl = ticket_ttl_seconds
51
+ logger.info(f"Initialized WebSocket ticket store (TTL: {ticket_ttl_seconds}s)")
52
+
53
+ @property
54
+ def ticket_ttl(self) -> int:
55
+ """
56
+ Get the ticket time-to-live in seconds.
57
+
58
+ Returns:
59
+ Ticket TTL in seconds
60
+ """
61
+ return self._ticket_ttl
62
+
63
+ def create_ticket(
64
+ self,
65
+ user_id: str,
66
+ user_email: str | None = None,
67
+ app_slug: str | None = None,
68
+ ) -> str:
69
+ """
70
+ Create a new WebSocket ticket.
71
+
72
+ Args:
73
+ user_id: User ID
74
+ user_email: Optional user email
75
+ app_slug: Optional app slug for scoping
76
+
77
+ Returns:
78
+ Ticket UUID string
79
+ """
80
+ ticket_id = str(uuid.uuid4())
81
+ expires_at = time.time() + self._ticket_ttl
82
+
83
+ ticket_data = {
84
+ "user_id": user_id,
85
+ "user_email": user_email,
86
+ "app_slug": app_slug,
87
+ "exp": expires_at,
88
+ "created_at": time.time(),
89
+ }
90
+
91
+ # Thread-safe ticket creation
92
+ self._tickets[ticket_id] = ticket_data
93
+
94
+ logger.debug(
95
+ f"Created WebSocket ticket for user '{user_id}' "
96
+ f"(app: {app_slug}, expires in {self._ticket_ttl}s)"
97
+ )
98
+
99
+ return ticket_id
100
+
101
+ async def validate_and_consume_ticket(self, ticket_id: str) -> dict[str, Any] | None:
102
+ """
103
+ Validate and consume a WebSocket ticket (atomic operation).
104
+
105
+ This method validates the ticket and removes it immediately,
106
+ ensuring single-use behavior.
107
+
108
+ Args:
109
+ ticket_id: Ticket UUID to validate
110
+
111
+ Returns:
112
+ Ticket data dict if valid, None otherwise
113
+ """
114
+ async with self._lock:
115
+ # Check if ticket exists
116
+ if ticket_id not in self._tickets:
117
+ logger.warning(f"WebSocket ticket not found: {ticket_id[:16]}...")
118
+ return None
119
+
120
+ ticket_data = self._tickets[ticket_id]
121
+
122
+ # Check expiration
123
+ if time.time() > ticket_data["exp"]:
124
+ logger.warning(
125
+ f"WebSocket ticket expired: {ticket_id[:16]}... "
126
+ f"(expired: {ticket_data['exp']})"
127
+ )
128
+ # Remove expired ticket
129
+ del self._tickets[ticket_id]
130
+ return None
131
+
132
+ # CONSUME TICKET (atomic operation - remove immediately)
133
+ # This ensures single-use behavior
134
+ user_id = ticket_data["user_id"]
135
+ user_email = ticket_data["user_email"]
136
+ app_slug = ticket_data.get("app_slug")
137
+
138
+ # Remove ticket before returning (single-use)
139
+ del self._tickets[ticket_id]
140
+
141
+ logger.debug(
142
+ f"Validated and consumed WebSocket ticket for user '{user_id}' "
143
+ f"(app: {app_slug})"
144
+ )
145
+
146
+ return {
147
+ "user_id": user_id,
148
+ "user_email": user_email,
149
+ "app_slug": app_slug,
150
+ }
151
+
152
+ async def cleanup_expired_tickets(self) -> int:
153
+ """
154
+ Clean up expired tickets.
155
+
156
+ Returns:
157
+ Number of tickets cleaned up
158
+ """
159
+ async with self._lock:
160
+ now = time.time()
161
+ expired_tickets = [
162
+ ticket_id
163
+ for ticket_id, ticket_data in self._tickets.items()
164
+ if ticket_data["exp"] < now
165
+ ]
166
+
167
+ for ticket_id in expired_tickets:
168
+ del self._tickets[ticket_id]
169
+
170
+ if expired_tickets:
171
+ logger.debug(f"Cleaned up {len(expired_tickets)} expired WebSocket tickets")
172
+
173
+ return len(expired_tickets)
174
+
175
+ def get_ticket_count(self) -> int:
176
+ """Get the number of active tickets."""
177
+ return len(self._tickets)
178
+
179
+
180
def create_websocket_ticket_endpoint(
    ticket_store: WebSocketTicketStore,
) -> Callable:
    """
    Build a FastAPI handler that issues one-time WebSocket tickets.

    The returned handler authenticates the caller, then mints a ticket via
    ``ticket_store.create_ticket``. The ticket lives for
    ``ticket_store.ticket_ttl`` seconds (10 by default) and can be redeemed
    only once.

    Args:
        ticket_store: WebSocketTicketStore that creates the tickets

    Returns:
        Async FastAPI route handler taking a ``Request``

    Example:
        ```python
        from mdb_engine.auth.websocket_tickets import (
            WebSocketTicketStore,
            create_websocket_ticket_endpoint,
        )

        # Initialize ticket store
        ticket_store = WebSocketTicketStore(ticket_ttl_seconds=10)

        # Create endpoint
        endpoint = create_websocket_ticket_endpoint(ticket_store)
        app.post("/auth/ticket")(endpoint)
        ```

    Handler behaviour:
    - Uses ``request.state.user`` when set (e.g. by SharedAuthMiddleware);
      otherwise, if ``request.app.state.user_pool`` exists, validates the
      auth cookie with it.
    - 401 ``{"detail": "Authentication required"}`` when no user is found.
    - 400 ``{"detail": "Invalid user data"}`` when the user has no
      ``user_id`` / ``sub`` / ``_id``.
    - 200 ``{"ticket": "...", "expires_in": <ttl>}`` on success.
    - 500 ``{"detail": "Failed to generate WebSocket ticket"}`` when ticket
      creation fails.
    """
    from fastapi import Request, status
    from fastapi.responses import JSONResponse

    def _detail_response(status_code: int, detail: str) -> JSONResponse:
        # Uniform error payload shared by every failure path.
        return JSONResponse(
            status_code=status_code,
            content={"detail": detail},
        )

    async def _user_from_auth_cookie(request: Request):
        # Fallback for when the endpoint is mounted on a parent app that has
        # no auth middleware: validate the auth cookie against the app's
        # user pool. Returns None whenever that is not possible.
        from .shared_middleware import AUTH_COOKIE_NAME

        pool = None
        try:
            if hasattr(request, "app") and hasattr(request.app, "state"):
                pool = getattr(request.app.state, "user_pool", None)
        except (AttributeError, TypeError):
            pass
        if pool is None:
            return None

        cookie_token = None
        try:
            if hasattr(request, "cookies"):
                cookie_token = request.cookies.get(AUTH_COOKIE_NAME)
        except (AttributeError, TypeError):
            pass
        if not cookie_token:
            return None

        try:
            return await pool.validate_token(cookie_token)
        except (TypeError, AttributeError):
            # e.g. a user pool whose validate_token is not awaitable
            return None

    async def websocket_ticket_endpoint(request: Request) -> JSONResponse:
        """
        Generate a WebSocket ticket for the authenticated user.

        Requires:
        - User to be authenticated (via request.state.user or auth cookie)
        - WebSocket ticket store to be available

        Returns:
        - JSONResponse with ticket and expires_in
        """
        user = getattr(request.state, "user", None)
        if not user:
            user = await _user_from_auth_cookie(request)
        if not user:
            return _detail_response(status.HTTP_401_UNAUTHORIZED, "Authentication required")

        # Identity preference: explicit user_id, then the JWT "sub" claim, then
        # the MongoDB document "_id". Email is metadata, never an identity.
        user_id = None
        for id_field in ("user_id", "sub", "_id"):
            user_id = user.get(id_field)
            if user_id:
                break
        if not user_id:
            logger.error("Cannot generate WebSocket ticket: user_id not found in user data")
            return _detail_response(status.HTTP_400_BAD_REQUEST, "Invalid user data")

        user_email = user.get("email")
        app_slug = getattr(request.state, "app_slug", None)

        try:
            ticket = ticket_store.create_ticket(
                user_id=str(user_id),
                user_email=user_email,
                app_slug=app_slug,
            )
            logger.info(f"Generated WebSocket ticket for user '{user_id}' (app: {app_slug})")
            payload = {
                "ticket": ticket,
                "expires_in": ticket_store.ticket_ttl,
            }
            return JSONResponse(payload)
        except (ValueError, TypeError, AttributeError, RuntimeError):
            logger.exception("Failed to generate WebSocket ticket")
            return _detail_response(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                "Failed to generate WebSocket ticket",
            )

    return websocket_ticket_endpoint
@@ -359,6 +359,16 @@ class AppRegistrationManager:
359
359
  """
360
360
  return self._apps.get(slug)
361
361
 
362
@property
def apps(self) -> dict[str, dict[str, Any]]:
    """
    Return every registered app, keyed by app slug.

    Note: this is the manager's live internal registry, not a copy, so
    mutating the returned dict mutates the manager's state.

    Returns:
        Dictionary mapping app slug to its registered app record
    """
    registry = self._apps
    return registry
371
+
362
372
  def list_apps(self) -> list[str]:
363
373
  """
364
374
  List all registered app slugs.