alchemist-nrel 0.2.1__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. alchemist_core/__init__.py +14 -7
  2. alchemist_core/acquisition/botorch_acquisition.py +15 -6
  3. alchemist_core/audit_log.py +594 -0
  4. alchemist_core/data/experiment_manager.py +76 -5
  5. alchemist_core/models/botorch_model.py +6 -4
  6. alchemist_core/models/sklearn_model.py +74 -8
  7. alchemist_core/session.py +788 -39
  8. alchemist_core/utils/doe.py +200 -0
  9. alchemist_nrel-0.3.1.dist-info/METADATA +185 -0
  10. alchemist_nrel-0.3.1.dist-info/RECORD +66 -0
  11. {alchemist_nrel-0.2.1.dist-info → alchemist_nrel-0.3.1.dist-info}/entry_points.txt +1 -0
  12. api/example_client.py +7 -2
  13. api/main.py +21 -4
  14. api/models/requests.py +95 -1
  15. api/models/responses.py +167 -0
  16. api/routers/acquisition.py +25 -0
  17. api/routers/experiments.py +134 -6
  18. api/routers/sessions.py +438 -10
  19. api/routers/visualizations.py +10 -5
  20. api/routers/websocket.py +132 -0
  21. api/run_api.py +56 -0
  22. api/services/session_store.py +285 -54
  23. api/static/NEW_ICON.ico +0 -0
  24. api/static/NEW_ICON.png +0 -0
  25. api/static/NEW_LOGO_DARK.png +0 -0
  26. api/static/NEW_LOGO_LIGHT.png +0 -0
  27. api/static/assets/api-vcoXEqyq.js +1 -0
  28. api/static/assets/index-DWfIKU9j.js +4094 -0
  29. api/static/assets/index-sMIa_1hV.css +1 -0
  30. api/static/index.html +14 -0
  31. api/static/vite.svg +1 -0
  32. ui/gpr_panel.py +7 -2
  33. ui/notifications.py +197 -10
  34. ui/ui.py +1117 -68
  35. ui/variables_setup.py +47 -2
  36. ui/visualizations.py +60 -3
  37. alchemist_core/models/ax_model.py +0 -159
  38. alchemist_nrel-0.2.1.dist-info/METADATA +0 -206
  39. alchemist_nrel-0.2.1.dist-info/RECORD +0 -54
  40. {alchemist_nrel-0.2.1.dist-info → alchemist_nrel-0.3.1.dist-info}/WHEEL +0 -0
  41. {alchemist_nrel-0.2.1.dist-info → alchemist_nrel-0.3.1.dist-info}/licenses/LICENSE +0 -0
  42. {alchemist_nrel-0.2.1.dist-info → alchemist_nrel-0.3.1.dist-info}/top_level.txt +0 -0
api/routers/sessions.py CHANGED
@@ -2,13 +2,24 @@
2
2
  Sessions router - Session lifecycle management.
3
3
  """
4
4
 
5
- from fastapi import APIRouter, HTTPException, status, UploadFile, File
6
- from fastapi.responses import Response
7
- from ..models.responses import SessionCreateResponse, SessionInfoResponse
5
+ from fastapi import APIRouter, HTTPException, status, UploadFile, File, Depends
6
+ from fastapi.responses import Response, FileResponse, JSONResponse
7
+ from typing import Optional
8
+ from ..models.requests import UpdateMetadataRequest, LockDecisionRequest, SessionLockRequest
9
+ from ..models.responses import (
10
+ SessionCreateResponse, SessionInfoResponse, SessionStateResponse,
11
+ SessionMetadataResponse, AuditLogResponse, AuditEntryResponse, LockDecisionResponse,
12
+ SessionLockResponse
13
+ )
14
+ from .websocket import broadcast_to_session
8
15
  from ..services import session_store
9
16
  from ..dependencies import get_session
17
+ from alchemist_core.session import OptimizationSession
10
18
  from datetime import datetime
11
19
  import logging
20
+ import json
21
+ from pathlib import Path
22
+ import tempfile
12
23
 
13
24
  logger = logging.getLogger(__name__)
14
25
 
@@ -50,6 +61,36 @@ async def get_session_info(session_id: str):
50
61
  return SessionInfoResponse(**info)
51
62
 
52
63
 
64
@router.get("/sessions/{session_id}/state", response_model=SessionStateResponse)
async def get_session_state(
    session_id: str,
    session: OptimizationSession = Depends(get_session)
):
    """
    Summarize the live state of a session.

    Meant for dashboards and autonomous controllers that poll progress and
    only need a few counters rather than the full session payload.

    Returns:
        SessionStateResponse with the number of search-space variables, the
        number of recorded experiments, whether a model is attached, and the
        most recent suggestion (None when absent or falsy).
    """
    return SessionStateResponse(
        session_id=session_id,
        n_variables=len(session.search_space.variables),
        n_experiments=len(session.experiment_manager.df),
        model_trained=session.model is not None,
        # A missing or falsy ``_last_suggestion`` attribute is reported as None.
        last_suggestion=getattr(session, '_last_suggestion', None) or None,
    )
92
+
93
+
53
94
  @router.delete("/sessions/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
54
95
  async def delete_session(session_id: str):
55
96
  """
@@ -89,13 +130,30 @@ async def extend_session(session_id: str, hours: int = 24):
89
130
  }
90
131
 
91
132
 
133
@router.post("/sessions/{session_id}/save", status_code=status.HTTP_200_OK)
async def save_session_server_side(session_id: str):
    """
    Write the in-memory session to its server-side session file.

    Lets the web UI save in place on the server instead of triggering a
    browser download.

    Raises:
        HTTPException(404): when the session store reports the session is
            missing or could not be written.
    """
    if session_store.persist_session_to_disk(session_id):
        return {"message": "Session persisted to server storage"}

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Session {session_id} not found or failed to save"
    )
148
+
149
+
92
150
  @router.get("/sessions/{session_id}/export")
93
151
  async def export_session(session_id: str):
94
152
  """
95
153
  Export a session for download.
96
154
 
97
- Downloads the complete session state as a .pkl file that can be
98
- reimported later.
155
+ Downloads the complete session state as a .json file that can be
156
+ reimported later or used in desktop application.
99
157
  """
100
158
  session_data = session_store.export_session(session_id)
101
159
  if session_data is None:
@@ -106,9 +164,9 @@ async def export_session(session_id: str):
106
164
 
107
165
  return Response(
108
166
  content=session_data,
109
- media_type="application/octet-stream",
167
+ media_type="application/json",
110
168
  headers={
111
- "Content-Disposition": f"attachment; filename=session_{session_id}.pkl"
169
+ "Content-Disposition": f"attachment; filename=session_{session_id}.json"
112
170
  }
113
171
  )
114
172
 
@@ -118,12 +176,14 @@ async def import_session(file: UploadFile = File(...)):
118
176
  """
119
177
  Import a previously exported session.
120
178
 
121
- Uploads a .pkl session file and creates a new session with the imported data.
122
- A new session ID will be generated.
179
+ Uploads a .json session file and creates a new session with the imported data.
180
+ A new session ID will be generated. Compatible with desktop application sessions.
123
181
  """
124
182
  try:
125
183
  session_data = await file.read()
126
- session_id = session_store.import_session(session_data)
184
+ # Decode bytes to string for JSON
185
+ session_json = session_data.decode('utf-8')
186
+ session_id = session_store.import_session(session_json)
127
187
 
128
188
  if session_id is None:
129
189
  raise HTTPException(
@@ -144,3 +204,371 @@ async def import_session(file: UploadFile = File(...)):
144
204
  status_code=status.HTTP_400_BAD_REQUEST,
145
205
  detail=f"Failed to import session: {str(e)}"
146
206
  )
207
+
208
+
209
+ # ============================================================
210
+ # Metadata Management Endpoints
211
+ # ============================================================
212
+
213
@router.get("/sessions/{session_id}/metadata", response_model=SessionMetadataResponse)
async def get_metadata(
    session_id: str,
    session: OptimizationSession = Depends(get_session)
):
    """
    Return the metadata attached to a session.

    The response carries the session id, display name, creation and
    last-modified timestamps, description, and tags.
    """
    meta = session.metadata
    return SessionMetadataResponse(
        session_id=meta.session_id,
        name=meta.name,
        created_at=meta.created_at,
        last_modified=meta.last_modified,
        description=meta.description,
        tags=meta.tags,
    )
231
+
232
+
233
@router.patch("/sessions/{session_id}/metadata", response_model=SessionMetadataResponse)
async def update_metadata(
    session_id: str,
    request: UpdateMetadataRequest,
    session: OptimizationSession = Depends(get_session)
):
    """
    Apply a partial metadata update to a session and return the result.

    The request's ``name``, ``description`` and ``tags`` are forwarded to
    ``session.update_metadata``. Leaving omitted fields unchanged is that
    method's responsibility, not this handler's.
    """
    session.update_metadata(
        name=request.name,
        description=request.description,
        tags=request.tags,
    )

    # Read metadata only after the update so the response reflects it.
    meta = session.metadata
    return SessionMetadataResponse(
        session_id=meta.session_id,
        name=meta.name,
        created_at=meta.created_at,
        last_modified=meta.last_modified,
        description=meta.description,
        tags=meta.tags,
    )
259
+
260
+
261
+ # ============================================================
262
+ # Audit Log Endpoints
263
+ # ============================================================
264
+
265
@router.get("/sessions/{session_id}/audit", response_model=AuditLogResponse)
async def get_audit_log(
    session_id: str,
    entry_type: Optional[str] = None,
    session: OptimizationSession = Depends(get_session)
):
    """
    Return audit-log entries for a session, optionally filtered by type.

    Args:
        session_id: Session identifier (resolved to ``session`` by the
            ``get_session`` dependency).
        entry_type: Optional query parameter forwarded to
            ``audit_log.get_entries``. A missing or empty value returns every
            entry. Expected filter values are 'data_locked', 'model_locked'
            and 'acquisition_locked' (TODO confirm against audit_log).
        session: Injected OptimizationSession.

    Returns:
        AuditLogResponse with the serialized entries and their count.
    """
    audit_log = session.audit_log
    entries = audit_log.get_entries(entry_type) if entry_type else audit_log.get_entries()

    return AuditLogResponse(
        entries=[AuditEntryResponse(**e.to_dict()) for e in entries],
        n_entries=len(entries),
    )
289
+
290
+
291
@router.post("/sessions/{session_id}/audit/lock", response_model=LockDecisionResponse)
async def lock_decision(
    session_id: str,
    request: LockDecisionRequest,
    session: OptimizationSession = Depends(get_session)
):
    """
    Commit a data, model, or acquisition decision to the audit log.

    Call this once the user is satisfied with the current configuration.
    The matching ``session.lock_*`` method records the decision, and the
    resulting entry is returned to the caller.

    Args:
        session_id: Session identifier.
        request: Lock request. ``lock_type`` selects the decision kind.
            Acquisition locks also need truthy ``strategy``, ``parameters``
            and ``suggestions``.

    Raises:
        HTTPException(400): unknown ``lock_type``, missing acquisition
            fields, or a ValueError raised while locking or building the
            response.
    """
    try:
        kind = request.lock_type

        if kind in ("data", "model"):
            lock_fn = session.lock_data if kind == "data" else session.lock_model
            entry = lock_fn(notes=request.notes or "")
            message = (
                "Data decision locked successfully"
                if kind == "data"
                else "Model decision locked successfully"
            )
        elif kind == "acquisition":
            has_all_inputs = request.strategy and request.parameters and request.suggestions
            if not has_all_inputs:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Acquisition lock requires strategy, parameters, and suggestions"
                )
            entry = session.lock_acquisition(
                strategy=request.strategy,
                parameters=request.parameters,
                suggestions=request.suggestions,
                notes=request.notes or ""
            )
            message = "Acquisition decision locked successfully"
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid lock_type: {request.lock_type}"
            )

        audit_entry = AuditEntryResponse(**entry.to_dict())
        return LockDecisionResponse(success=True, entry=audit_entry, message=message)

    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
348
+
349
+
350
@router.get("/sessions/{session_id}/audit/export")
async def export_audit_markdown(
    session_id: str,
    session: OptimizationSession = Depends(get_session)
):
    """
    Download the session's audit trail as a Markdown attachment.

    The Markdown text comes from ``session.export_audit_markdown()`` and is
    intended for publication methods sections.
    """
    markdown_text = session.export_audit_markdown()
    disposition = f"attachment; filename=audit_log_{session_id}.md"

    return Response(
        content=markdown_text,
        media_type="text/markdown",
        headers={"Content-Disposition": disposition},
    )
369
+
370
+
371
+ # ============================================================
372
+ # Session File Management (JSON Format)
373
+ # ============================================================
374
+
375
@router.get("/sessions/{session_id}/download")
async def download_session(
    session_id: str,
    session: OptimizationSession = Depends(get_session)
):
    """
    Download the session as a JSON file.

    ``session.save_session`` serializes the session into a temporary file.
    That file is streamed back as an attachment named after the session,
    falling back to ``session_<id>`` when the name is empty.

    The temporary file is removed after the response body has been sent. If
    serialization or response construction fails, it is removed immediately.

    Raises:
        HTTPException(500): if serializing the session or building the
            response fails.
    """
    # Local import: BackgroundTasks runs the temp-file cleanup after the
    # response has been fully streamed.
    from fastapi import BackgroundTasks

    # Reserve a temp path and close the handle so save_session can reopen it
    # (an open NamedTemporaryFile cannot be reopened on Windows).
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        temp_path = f.name

    try:
        session.save_session(temp_path)

        # Build a download name without path separators or spaces; an empty
        # or missing name falls back to the session id.
        raw_name = (session.metadata.name or "").strip()
        safe_name = raw_name.replace(" ", "_").replace("/", "_").replace("\\", "_")
        if not safe_name:
            safe_name = f"session_{session_id}"
        filename = f"{safe_name}.json"

        # Previously the file was never deleted after a successful download.
        cleanup = BackgroundTasks()
        cleanup.add_task(Path(temp_path).unlink, missing_ok=True)

        # Let FileResponse build Content-Disposition from ``filename``. It
        # quotes the value and RFC 5987-encodes non-ASCII names. A hand-built
        # header with raw non-Latin-1 characters cannot be encoded and made
        # the request fail.
        return FileResponse(
            path=temp_path,
            media_type="application/json",
            filename=filename,
            background=cleanup,
        )
    except Exception as e:
        # Clean up the temp file immediately on error.
        Path(temp_path).unlink(missing_ok=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to export session: {str(e)}"
        ) from e
412
+
413
+
414
@router.post("/sessions/upload", response_model=SessionCreateResponse, status_code=status.HTTP_201_CREATED)
async def upload_session(file: UploadFile = File(...)):
    """
    Upload a session JSON file and restore it as a new session.

    The upload is written to a temporary file and loaded with
    ``OptimizationSession.load_session(..., retrain_on_load=False)``. A fresh
    session id is then allocated in the store and the loaded session is
    installed under it. The loaded session's ``metadata.session_id`` is set
    to the new id when possible.

    Returns:
        SessionCreateResponse for the newly created session.

    Raises:
        HTTPException(400): if reading, parsing, or registering the uploaded
            session fails for any reason.
    """
    temp_path = None
    try:
        # Read before creating the temp file so a failed read leaves nothing
        # behind. temp_path is recorded before writing so a failed write is
        # still cleaned up in ``finally``.
        content = await file.read()
        with tempfile.NamedTemporaryFile(mode='wb', suffix='.json', delete=False) as f:
            temp_path = f.name
            f.write(content)

        # Load without retraining to keep the upload fast.
        loaded_session = OptimizationSession.load_session(temp_path, retrain_on_load=False)

        new_session_id = session_store.create()

        # Align the stored metadata id with the API-facing id. This is best
        # effort: failures are logged rather than silently ignored.
        try:
            loaded_session.metadata.session_id = new_session_id
        except Exception as meta_err:
            logger.warning(
                "Could not update metadata.session_id for uploaded session %s: %s",
                new_session_id, meta_err
            )

        # Swap the freshly created placeholder for the loaded session.
        store_entry = session_store._sessions[new_session_id]
        store_entry["session"] = loaded_session
        store_entry["last_accessed"] = datetime.now()

        session_store._save_to_disk(new_session_id)

        session_info = session_store.get_info(new_session_id)
        return SessionCreateResponse(
            session_id=new_session_id,
            created_at=session_info["created_at"],
            expires_at=session_info["expires_at"]
        )

    except Exception as e:
        # logger.exception records the traceback at ERROR level.
        logger.exception("Failed to upload session: %s", e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Failed to upload session: {str(e)}"
        ) from e
    finally:
        if temp_path is not None:
            Path(temp_path).unlink(missing_ok=True)
469
+
470
+
471
+ # ============================================================
472
+ # Session Locking Endpoints
473
+ # ============================================================
474
+
475
@router.post("/sessions/{session_id}/lock", response_model=SessionLockResponse)
async def lock_session(
    session_id: str,
    request: SessionLockRequest
):
    """
    Lock a session for external programmatic control.

    While locked, the web UI is expected to switch to monitor-only mode. The
    response includes the ``lock_token`` required to unlock.

    Connected WebSocket clients get a ``lock_status_changed`` event. The
    notification is best effort: by the time it is sent the lock is already
    held, so a failed broadcast is logged instead of failing the request.
    Failing the request would leave the caller without the token.

    Raises:
        HTTPException(404): the store raised KeyError (session missing or
            expired).
        HTTPException(500): any other failure while locking or building the
            response.
    """
    try:
        result = session_store.lock_session(
            session_id=session_id,
            locked_by=request.locked_by,
            client_id=request.client_id
        )

        # Broadcast lock event to WebSocket clients (best effort).
        try:
            await broadcast_to_session(session_id, {
                "event": "lock_status_changed",
                "locked": True,
                "locked_by": request.locked_by,
                "locked_at": result.get("locked_at")
            })
        except Exception:
            logger.exception("Failed to broadcast lock event for session %s", session_id)

        return SessionLockResponse(**result)
    except KeyError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session {session_id} not found or expired"
        )
    except Exception as e:
        logger.exception("Failed to lock session %s", session_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to lock session: {str(e)}"
        ) from e
512
+
513
+
514
@router.delete("/sessions/{session_id}/lock", response_model=SessionLockResponse)
async def unlock_session(
    session_id: str,
    lock_token: Optional[str] = None
):
    """
    Release a session lock.

    ``lock_token`` is forwarded to ``session_store.unlock_session`` for
    verification. Per this endpoint's contract, omitting it forcibly unlocks
    the session (use with caution); the store implements that behavior.

    Connected WebSocket clients get a ``lock_status_changed`` event. A failed
    notification is logged and does not turn the completed unlock into an
    error response.

    Raises:
        HTTPException(404): the store raised KeyError (session missing or
            expired).
        HTTPException(403): the store raised ValueError (presumably a token
            mismatch; confirm in session_store).
        HTTPException(500): any other failure.
    """
    try:
        result = session_store.unlock_session(session_id=session_id, lock_token=lock_token)

        # Broadcast unlock event to WebSocket clients (best effort).
        try:
            await broadcast_to_session(session_id, {
                "event": "lock_status_changed",
                "locked": False,
                "locked_by": None,
                "locked_at": None
            })
        except Exception:
            logger.exception("Failed to broadcast unlock event for session %s", session_id)

        return SessionLockResponse(**result)
    except KeyError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session {session_id} not found or expired"
        )
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e)
        )
    except Exception as e:
        logger.exception("Failed to unlock session %s", session_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to unlock session: {str(e)}"
        ) from e
552
+
553
+
554
@router.get("/sessions/{session_id}/lock", response_model=SessionLockResponse)
async def get_lock_status(session_id: str):
    """
    Report the current lock state of a session.

    The web UI uses this to notice that an external controller has taken
    over, so it can switch into monitor mode automatically.

    Raises:
        HTTPException(404): the store raised KeyError (session missing or
            expired).
        HTTPException(500): any other failure reading the lock state.
    """
    try:
        lock_state = session_store.get_lock_status(session_id=session_id)
        return SessionLockResponse(**lock_state)
    except KeyError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session {session_id} not found or expired"
        )
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get lock status: {str(err)}"
        )
@@ -192,12 +192,17 @@ async def get_contour_data(
192
192
  grid_df = pd.DataFrame(grid_points)
193
193
 
194
194
  # CRITICAL FIX: Reorder columns to match training data
195
- # The model was trained with a specific column order, we must match it
195
+ # The model was trained with a specific column order, we must match it.
196
+ # Exclude metadata columns that are part of the experiments table but
197
+ # are not model input features (e.g., Iteration, Reason, Output, Noise).
196
198
  train_data = session.experiment_manager.get_data()
197
- train_columns = [col for col in train_data.columns if col != 'Output']
198
-
199
- # Reorder grid_df to match training column order
200
- grid_df = grid_df[train_columns]
199
+ metadata_cols = {'Iteration', 'Reason', 'Output', 'Noise'}
200
+ feature_cols = [col for col in train_data.columns if col not in metadata_cols]
201
+
202
+ # Safely align the prediction grid to the model feature order.
203
+ # Use reindex so missing columns (shouldn't happen) are filled with the
204
+ # midpoint/defaults the grid already supplies; this avoids KeyError.
205
+ grid_df = grid_df.reindex(columns=feature_cols)
201
206
 
202
207
  # IMPORTANT: The model's predict() method handles preprocessing internally
203
208
  # (including categorical encoding), so we can pass the raw DataFrame directly
@@ -0,0 +1,132 @@
1
+ """
2
+ WebSocket router for real-time session updates.
3
+
4
+ Provides real-time push notifications for session events like lock status changes,
5
+ eliminating the need for client-side polling.
6
+ """
7
+
8
+ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
9
+ from typing import Dict, Set
10
+ import json
11
+ import logging
12
+
13
# Module-level logger for WebSocket lifecycle and broadcast events.
logger = logging.getLogger(__name__)

# Router that carries this module's WebSocket route.
router = APIRouter()

# Registry of open WebSocket connections keyed by session ID; each value is
# the set of sockets currently subscribed to that session.
active_connections: Dict[str, Set[WebSocket]] = dict()
20
+
21
+
22
@router.websocket("/ws/sessions/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """
    Subscribe a WebSocket client to push events for one session.

    After the handshake the socket is registered in ``active_connections`` so
    ``broadcast_to_session`` can reach it, and a ``connected`` event is sent.
    The handler then reads client messages until the client disconnects.
    Incoming messages are only parsed and logged; nothing is sent back. On
    exit the socket is unregistered, and the session's entry is removed once
    no sockets remain.

    Args:
        websocket: The incoming WebSocket connection.
        session_id: Session whose events the client wants to receive.
    """
    await websocket.accept()
    logger.info(f"WebSocket connected: session_id={session_id}")

    active_connections.setdefault(session_id, set()).add(websocket)

    try:
        await websocket.send_json({
            "event": "connected",
            "session_id": session_id,
            "message": "WebSocket connection established"
        })

        # Client messages are currently informational only (reserved for
        # future bi-directional features): parse, log, and keep listening.
        while True:
            raw_text = await websocket.receive_text()
            try:
                payload = json.loads(raw_text)
            except json.JSONDecodeError:
                logger.warning(f"Invalid JSON from client: {raw_text}")
            else:
                logger.debug(f"Received from client: {payload}")

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: session_id={session_id}")
    finally:
        # Unregister this socket; drop the session entry when it was the last.
        subscribers = active_connections.get(session_id)
        if subscribers is not None:
            subscribers.discard(websocket)
            if not subscribers:
                del active_connections[session_id]
                logger.debug(f"No more connections for session {session_id}")
73
+
74
+
75
async def broadcast_to_session(session_id: str, event: dict):
    """
    Send ``event`` to every WebSocket client subscribed to ``session_id``.

    Delivery is best effort. A client whose send raises is logged and
    removed from ``active_connections``, and the session entry is dropped
    when no clients remain. Nothing happens if the session has no
    subscribers.

    Args:
        session_id: Session ID to broadcast to.
        event: Event payload; it is sent with ``send_json``, so it must be
            JSON-serializable.
    """
    subscribers = active_connections.get(session_id)
    if not subscribers:
        logger.debug(f"No active connections for session {session_id}")
        return

    dead_connections = set()

    # Iterate over a snapshot. Each ``await`` yields to the event loop, where
    # other handlers may add or discard sockets in the live set. Mutating it
    # during iteration would raise "Set changed size during iteration".
    for connection in list(subscribers):
        try:
            await connection.send_json(event)
            logger.debug(f"Broadcast to session {session_id}: {event.get('event')}")
        except Exception as e:
            logger.warning(f"Failed to send to connection: {e}")
            dead_connections.add(connection)

    if dead_connections:
        # Re-fetch the entry. While we were awaiting, a disconnecting handler
        # may have deleted it, or a new connection may have recreated it.
        current = active_connections.get(session_id)
        if current is not None:
            current -= dead_connections
            if not current:
                del active_connections[session_id]
        logger.info(f"Cleaned up {len(dead_connections)} dead connections")
107
+
108
+
109
def get_connection_count(session_id: str) -> int:
    """
    Count the WebSocket clients currently subscribed to a session.

    Args:
        session_id: Session ID to check.

    Returns:
        Number of registered connections, or 0 when the session has none.
    """
    subscribers = active_connections.get(session_id)
    return 0 if subscribers is None else len(subscribers)
120
+
121
+
122
def get_all_connection_counts() -> Dict[str, int]:
    """
    Snapshot the number of WebSocket clients for every tracked session.

    Returns:
        A new dict mapping each session ID in ``active_connections`` to the
        size of its connection set.
    """
    counts: Dict[str, int] = {}
    for sid, sockets in active_connections.items():
        counts[sid] = len(sockets)
    return counts
+ }