alchemist-nrel 0.2.1__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alchemist_core/__init__.py +14 -7
- alchemist_core/acquisition/botorch_acquisition.py +15 -6
- alchemist_core/audit_log.py +594 -0
- alchemist_core/data/experiment_manager.py +76 -5
- alchemist_core/models/botorch_model.py +6 -4
- alchemist_core/models/sklearn_model.py +74 -8
- alchemist_core/session.py +788 -39
- alchemist_core/utils/doe.py +200 -0
- alchemist_nrel-0.3.1.dist-info/METADATA +185 -0
- alchemist_nrel-0.3.1.dist-info/RECORD +66 -0
- {alchemist_nrel-0.2.1.dist-info → alchemist_nrel-0.3.1.dist-info}/entry_points.txt +1 -0
- api/example_client.py +7 -2
- api/main.py +21 -4
- api/models/requests.py +95 -1
- api/models/responses.py +167 -0
- api/routers/acquisition.py +25 -0
- api/routers/experiments.py +134 -6
- api/routers/sessions.py +438 -10
- api/routers/visualizations.py +10 -5
- api/routers/websocket.py +132 -0
- api/run_api.py +56 -0
- api/services/session_store.py +285 -54
- api/static/NEW_ICON.ico +0 -0
- api/static/NEW_ICON.png +0 -0
- api/static/NEW_LOGO_DARK.png +0 -0
- api/static/NEW_LOGO_LIGHT.png +0 -0
- api/static/assets/api-vcoXEqyq.js +1 -0
- api/static/assets/index-DWfIKU9j.js +4094 -0
- api/static/assets/index-sMIa_1hV.css +1 -0
- api/static/index.html +14 -0
- api/static/vite.svg +1 -0
- ui/gpr_panel.py +7 -2
- ui/notifications.py +197 -10
- ui/ui.py +1117 -68
- ui/variables_setup.py +47 -2
- ui/visualizations.py +60 -3
- alchemist_core/models/ax_model.py +0 -159
- alchemist_nrel-0.2.1.dist-info/METADATA +0 -206
- alchemist_nrel-0.2.1.dist-info/RECORD +0 -54
- {alchemist_nrel-0.2.1.dist-info → alchemist_nrel-0.3.1.dist-info}/WHEEL +0 -0
- {alchemist_nrel-0.2.1.dist-info → alchemist_nrel-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {alchemist_nrel-0.2.1.dist-info → alchemist_nrel-0.3.1.dist-info}/top_level.txt +0 -0
api/routers/sessions.py
CHANGED
|
@@ -2,13 +2,24 @@
|
|
|
2
2
|
Sessions router - Session lifecycle management.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from fastapi import APIRouter, HTTPException, status, UploadFile, File
|
|
6
|
-
from fastapi.responses import Response
|
|
7
|
-
from
|
|
5
|
+
from fastapi import APIRouter, HTTPException, status, UploadFile, File, Depends
|
|
6
|
+
from fastapi.responses import Response, FileResponse, JSONResponse
|
|
7
|
+
from typing import Optional
|
|
8
|
+
from ..models.requests import UpdateMetadataRequest, LockDecisionRequest, SessionLockRequest
|
|
9
|
+
from ..models.responses import (
|
|
10
|
+
SessionCreateResponse, SessionInfoResponse, SessionStateResponse,
|
|
11
|
+
SessionMetadataResponse, AuditLogResponse, AuditEntryResponse, LockDecisionResponse,
|
|
12
|
+
SessionLockResponse
|
|
13
|
+
)
|
|
14
|
+
from .websocket import broadcast_to_session
|
|
8
15
|
from ..services import session_store
|
|
9
16
|
from ..dependencies import get_session
|
|
17
|
+
from alchemist_core.session import OptimizationSession
|
|
10
18
|
from datetime import datetime
|
|
11
19
|
import logging
|
|
20
|
+
import json
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
import tempfile
|
|
12
23
|
|
|
13
24
|
logger = logging.getLogger(__name__)
|
|
14
25
|
|
|
@@ -50,6 +61,36 @@ async def get_session_info(session_id: str):
|
|
|
50
61
|
return SessionInfoResponse(**info)
|
|
51
62
|
|
|
52
63
|
|
|
64
|
+
@router.get("/sessions/{session_id}/state", response_model=SessionStateResponse)
async def get_session_state(
    session_id: str,
    session: OptimizationSession = Depends(get_session)
):
    """
    Report a lightweight snapshot of a session's progress.

    Intended for dashboards and autonomous controllers that poll progress
    without downloading the full session payload.

    Returns:
        SessionStateResponse with the number of search-space variables, the
        number of recorded experiments, whether ``session.model`` is set, and
        the most recent suggestion (``None`` when the session has no truthy
        ``_last_suggestion`` attribute).
    """
    # A missing attribute and a falsy value both collapse to None.
    latest = getattr(session, '_last_suggestion', None) or None

    return SessionStateResponse(
        session_id=session_id,
        n_variables=len(session.search_space.variables),
        n_experiments=len(session.experiment_manager.df),
        model_trained=session.model is not None,
        last_suggestion=latest,
    )
|
|
92
|
+
|
|
93
|
+
|
|
53
94
|
@router.delete("/sessions/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
|
|
54
95
|
async def delete_session(session_id: str):
|
|
55
96
|
"""
|
|
@@ -89,13 +130,30 @@ async def extend_session(session_id: str, hours: int = 24):
|
|
|
89
130
|
}
|
|
90
131
|
|
|
91
132
|
|
|
133
|
+
@router.post("/sessions/{session_id}/save", status_code=status.HTTP_200_OK)
async def save_session_server_side(session_id: str):
    """
    Write the in-memory session back to its server-side session file.

    Lets the web UI save changes straight into the session store instead of
    triggering a browser download.

    Raises:
        HTTPException: 404 when ``session_store.persist_session_to_disk``
            reports failure (session missing or the write failed).
    """
    if session_store.persist_session_to_disk(session_id):
        return {"message": "Session persisted to server storage"}

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Session {session_id} not found or failed to save"
    )
|
|
148
|
+
|
|
149
|
+
|
|
92
150
|
@router.get("/sessions/{session_id}/export")
|
|
93
151
|
async def export_session(session_id: str):
|
|
94
152
|
"""
|
|
95
153
|
Export a session for download.
|
|
96
154
|
|
|
97
|
-
Downloads the complete session state as a .
|
|
98
|
-
reimported later.
|
|
155
|
+
Downloads the complete session state as a .json file that can be
|
|
156
|
+
reimported later or used in desktop application.
|
|
99
157
|
"""
|
|
100
158
|
session_data = session_store.export_session(session_id)
|
|
101
159
|
if session_data is None:
|
|
@@ -106,9 +164,9 @@ async def export_session(session_id: str):
|
|
|
106
164
|
|
|
107
165
|
return Response(
|
|
108
166
|
content=session_data,
|
|
109
|
-
media_type="application/
|
|
167
|
+
media_type="application/json",
|
|
110
168
|
headers={
|
|
111
|
-
"Content-Disposition": f"attachment; filename=session_{session_id}.
|
|
169
|
+
"Content-Disposition": f"attachment; filename=session_{session_id}.json"
|
|
112
170
|
}
|
|
113
171
|
)
|
|
114
172
|
|
|
@@ -118,12 +176,14 @@ async def import_session(file: UploadFile = File(...)):
|
|
|
118
176
|
"""
|
|
119
177
|
Import a previously exported session.
|
|
120
178
|
|
|
121
|
-
Uploads a .
|
|
122
|
-
A new session ID will be generated.
|
|
179
|
+
Uploads a .json session file and creates a new session with the imported data.
|
|
180
|
+
A new session ID will be generated. Compatible with desktop application sessions.
|
|
123
181
|
"""
|
|
124
182
|
try:
|
|
125
183
|
session_data = await file.read()
|
|
126
|
-
|
|
184
|
+
# Decode bytes to string for JSON
|
|
185
|
+
session_json = session_data.decode('utf-8')
|
|
186
|
+
session_id = session_store.import_session(session_json)
|
|
127
187
|
|
|
128
188
|
if session_id is None:
|
|
129
189
|
raise HTTPException(
|
|
@@ -144,3 +204,371 @@ async def import_session(file: UploadFile = File(...)):
|
|
|
144
204
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
145
205
|
detail=f"Failed to import session: {str(e)}"
|
|
146
206
|
)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# ============================================================
|
|
210
|
+
# Metadata Management Endpoints
|
|
211
|
+
# ============================================================
|
|
212
|
+
|
|
213
|
+
@router.get("/sessions/{session_id}/metadata", response_model=SessionMetadataResponse)
async def get_metadata(
    session_id: str,
    session: OptimizationSession = Depends(get_session)
):
    """
    Return the descriptive metadata stored on a session.

    The response mirrors ``session.metadata``: its session ID, display name,
    creation and last-modified timestamps, description, and tags.
    """
    meta = session.metadata
    payload = dict(
        session_id=meta.session_id,
        name=meta.name,
        created_at=meta.created_at,
        last_modified=meta.last_modified,
        description=meta.description,
        tags=meta.tags,
    )
    return SessionMetadataResponse(**payload)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
@router.patch("/sessions/{session_id}/metadata", response_model=SessionMetadataResponse)
async def update_metadata(
    session_id: str,
    request: UpdateMetadataRequest,
    session: OptimizationSession = Depends(get_session)
):
    """
    Change a session's name, description and/or tags.

    Fields omitted from the request are forwarded as ``None`` to
    ``session.update_metadata``, which is expected to leave them unchanged.
    The response reflects the metadata as it stands after the update.
    """
    session.update_metadata(
        name=request.name,
        description=request.description,
        tags=request.tags,
    )

    # Read metadata only after the update so the response is current.
    meta = session.metadata
    payload = dict(
        session_id=meta.session_id,
        name=meta.name,
        created_at=meta.created_at,
        last_modified=meta.last_modified,
        description=meta.description,
        tags=meta.tags,
    )
    return SessionMetadataResponse(**payload)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
# ============================================================
|
|
262
|
+
# Audit Log Endpoints
|
|
263
|
+
# ============================================================
|
|
264
|
+
|
|
265
|
+
@router.get("/sessions/{session_id}/audit", response_model=AuditLogResponse)
async def get_audit_log(
    session_id: str,
    entry_type: Optional[str] = None,
    session: OptimizationSession = Depends(get_session)
):
    """
    Get audit log entries.

    Retrieves the complete audit trail or filters by entry type.

    Args:
        session_id: Session identifier
        entry_type: Optional filter ('data_locked', 'model_locked',
            'acquisition_locked'). When omitted or empty, every entry is
            returned.

    Returns:
        AuditLogResponse holding the serialized entries and their count.
    """
    # An empty string is treated the same as "no filter"; get_entries() is
    # only given an argument when a real filter value was supplied.
    if entry_type:
        entries = session.audit_log.get_entries(entry_type)
    else:
        entries = session.audit_log.get_entries()

    return AuditLogResponse(
        entries=[AuditEntryResponse(**e.to_dict()) for e in entries],
        n_entries=len(entries)
    )
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
@router.post("/sessions/{session_id}/audit/lock", response_model=LockDecisionResponse)
async def lock_decision(
    session_id: str,
    request: LockDecisionRequest,
    session: OptimizationSession = Depends(get_session)
):
    """
    Lock in a decision to the audit log.

    Creates an immutable audit entry for data, model, or acquisition decisions.
    This should be called when the user is satisfied with their configuration
    and ready to commit the decision to the audit trail.

    Args:
        session_id: Session identifier
        request: Lock decision request. ``lock_type`` selects 'data', 'model'
            or 'acquisition'; an acquisition lock additionally requires
            ``strategy``, ``parameters`` and ``suggestions``.

    Raises:
        HTTPException: 400 for an unknown ``lock_type``, an incomplete
            acquisition request, or a ``ValueError`` raised by the session's
            ``lock_*`` method.
    """
    notes = request.notes or ""

    # Only the session's lock_* calls are guarded. Building the response is kept
    # outside the try: pydantic's ValidationError subclasses ValueError, and a
    # serialization failure after a successful lock is a server error, not a
    # 400 client error.
    try:
        if request.lock_type == "data":
            entry = session.lock_data(notes=notes)
            message = "Data decision locked successfully"

        elif request.lock_type == "model":
            entry = session.lock_model(notes=notes)
            message = "Model decision locked successfully"

        elif request.lock_type == "acquisition":
            if not request.strategy or not request.parameters or not request.suggestions:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Acquisition lock requires strategy, parameters, and suggestions"
                )
            entry = session.lock_acquisition(
                strategy=request.strategy,
                parameters=request.parameters,
                suggestions=request.suggestions,
                notes=notes
            )
            message = "Acquisition decision locked successfully"

        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid lock_type: {request.lock_type}"
            )

    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e

    return LockDecisionResponse(
        success=True,
        entry=AuditEntryResponse(**entry.to_dict()),
        message=message
    )
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
@router.get("/sessions/{session_id}/audit/export")
async def export_audit_markdown(
    session_id: str,
    session: OptimizationSession = Depends(get_session)
):
    """
    Serve the session's audit trail as a downloadable Markdown file.

    The Markdown comes from ``session.export_audit_markdown()`` and is meant
    for pasting into publication methods sections. It is sent as an
    attachment named ``audit_log_<session_id>.md``.
    """
    disposition = f"attachment; filename=audit_log_{session_id}.md"
    body = session.export_audit_markdown()

    return Response(
        content=body,
        media_type="text/markdown",
        headers={"Content-Disposition": disposition},
    )
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
# ============================================================
|
|
372
|
+
# Session File Management (JSON Format)
|
|
373
|
+
# ============================================================
|
|
374
|
+
|
|
375
|
+
@router.get("/sessions/{session_id}/download")
async def download_session(
    session_id: str,
    session: OptimizationSession = Depends(get_session)
):
    """
    Download session as JSON file.

    Downloads the complete session state as a .json file with user-friendly
    naming support. The file includes metadata, audit log, search space,
    experiments, and configuration.

    The session is written to a temporary file via ``session.save_session``.
    That file is deleted by a background task once the response has been
    sent, or immediately if preparing the download fails.

    The download name is derived from ``session.metadata.name`` and falls back
    to ``session_<session_id>`` when the name is empty or nothing remains
    after sanitizing.

    Raises:
        HTTPException: 500 if the session cannot be written or the response
            cannot be prepared.
    """
    import re

    from fastapi import BackgroundTasks

    # Create temporary file (closed immediately; save_session writes by path)
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        temp_path = f.name

    try:
        # Save session to temp file
        session.save_session(temp_path)

        # Replace whitespace, path separators, quotes and control characters.
        # These are invalid in filenames or break the header. Other characters,
        # including non-ASCII, are kept; FileResponse encodes them per RFC 5987.
        safe_name = re.sub(r'[\s\\/:*?"<>|\x00-\x1f\x7f]', "_", session.metadata.name or "")
        safe_name = safe_name.strip(".")
        if not safe_name:
            safe_name = f"session_{session_id}"
        filename = f"{safe_name}.json"

        # FileResponse streams the file after this function returns, so the
        # temp file can only be removed once the response has completed.
        cleanup = BackgroundTasks()
        cleanup.add_task(Path(temp_path).unlink, missing_ok=True)

        # Let FileResponse build Content-Disposition from `filename`. It quotes
        # the name and switches to filename*=utf-8'' for non-ASCII names, which
        # a hand-written header cannot carry (header values are latin-1 encoded).
        return FileResponse(
            path=temp_path,
            media_type="application/json",
            filename=filename,
            background=cleanup,
        )
    except Exception as e:
        # Clean up temp file on error
        Path(temp_path).unlink(missing_ok=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to export session: {str(e)}"
        ) from e
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
@router.post("/sessions/upload", response_model=SessionCreateResponse, status_code=status.HTTP_201_CREATED)
async def upload_session(file: UploadFile = File(...)):
    """
    Upload and restore a session from JSON file.

    Uploads a .json session file and creates a new session with the restored
    data. The model is not retrained on load (``retrain_on_load=False``).

    A new session ID is generated for API use. The restored session's
    ``metadata.session_id`` is overwritten with that new ID, so the store key
    and the metadata agree.

    Raises:
        HTTPException: 400 if the upload cannot be read, loaded, or registered.
    """
    temp_path = None
    try:
        # load_session() reads from a path, so spool the upload to a temp file.
        # temp_path is recorded before writing so that a failed read or write
        # is still cleaned up by the finally clause below.
        with tempfile.NamedTemporaryFile(mode='wb', suffix='.json', delete=False) as f:
            temp_path = f.name
            f.write(await file.read())

        # Load session from file without retraining
        loaded_session = OptimizationSession.load_session(temp_path, retrain_on_load=False)

        # NOTE(review): if a step after create() fails, the newly created
        # session is left behind in the store.
        new_session_id = session_store.create()

        # Align metadata with the new store key. This stays best-effort, but a
        # failure is now logged instead of silently ignored.
        try:
            loaded_session.metadata.session_id = new_session_id
        except Exception:
            logger.warning(
                "Could not update metadata.session_id for uploaded session %s",
                new_session_id,
                exc_info=True,
            )

        # Replace the session object registered by create() with the loaded one.
        # NOTE(review): this reaches into session_store internals (_sessions,
        # _save_to_disk); a public session_store API would be preferable.
        record = session_store._sessions[new_session_id]
        record["session"] = loaded_session
        record["last_accessed"] = datetime.now()

        # Persist to disk
        session_store._save_to_disk(new_session_id)

        session_info = session_store.get_info(new_session_id)

        return SessionCreateResponse(
            session_id=new_session_id,
            created_at=session_info["created_at"],
            expires_at=session_info["expires_at"]
        )

    except Exception as e:
        logger.exception("Failed to upload session: %s", e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Failed to upload session: {str(e)}"
        ) from e
    finally:
        # Clean up temp file (also when the upload failed mid-write)
        if temp_path is not None:
            Path(temp_path).unlink(missing_ok=True)
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
# ============================================================
|
|
472
|
+
# Session Locking Endpoints
|
|
473
|
+
# ============================================================
|
|
474
|
+
|
|
475
|
+
@router.post("/sessions/{session_id}/lock", response_model=SessionLockResponse)
async def lock_session(
    session_id: str,
    request: SessionLockRequest
):
    """
    Lock a session for external programmatic control.

    When locked, the web UI should enter monitor-only mode.
    Returns a lock_token that must be used to unlock.

    Connected WebSocket clients are notified with a ``lock_status_changed``
    event. The notification is best-effort: once the lock has been acquired,
    the response (and its lock_token) is returned even if broadcasting fails.

    Raises:
        HTTPException: 404 when the store raises KeyError (session not found
            or expired); 500 when acquiring the lock fails for another reason.
    """
    try:
        result = session_store.lock_session(
            session_id=session_id,
            locked_by=request.locked_by,
            client_id=request.client_id
        )
    except KeyError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session {session_id} not found or expired"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to lock session: {str(e)}"
        ) from e

    # The lock is already held here. If a broadcast failure turned this into
    # an error response, the caller would never receive its lock_token.
    try:
        await broadcast_to_session(session_id, {
            "event": "lock_status_changed",
            "locked": True,
            "locked_by": request.locked_by,
            "locked_at": result["locked_at"]
        })
    except Exception:
        logger.exception("Failed to broadcast lock event for session %s", session_id)

    return SessionLockResponse(**result)
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
@router.delete("/sessions/{session_id}/lock", response_model=SessionLockResponse)
async def unlock_session(
    session_id: str,
    lock_token: Optional[str] = None
):
    """
    Unlock a session.

    Optionally provide lock_token for verification.
    If no token provided, forcibly unlocks (use with caution).

    Connected WebSocket clients are notified with a ``lock_status_changed``
    event. The notification is best-effort: a broadcast failure does not turn
    a completed unlock into an error response.

    Raises:
        HTTPException: 404 when the store raises KeyError (session not found
            or expired); 403 when the store raises ValueError; 500 for any
            other failure while unlocking.
    """
    # Only the store call is guarded. Previously a pydantic ValidationError
    # (a ValueError subclass) from building the response would map to 403.
    try:
        result = session_store.unlock_session(session_id=session_id, lock_token=lock_token)
    except KeyError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session {session_id} not found or expired"
        ) from e
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e)
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to unlock session: {str(e)}"
        ) from e

    # The session is already unlocked; report success even if the push fails.
    try:
        await broadcast_to_session(session_id, {
            "event": "lock_status_changed",
            "locked": False,
            "locked_by": None,
            "locked_at": None
        })
    except Exception:
        logger.exception("Failed to broadcast unlock event for session %s", session_id)

    return SessionLockResponse(**result)
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
@router.get("/sessions/{session_id}/lock", response_model=SessionLockResponse)
async def get_lock_status(session_id: str):
    """
    Report whether a session is currently locked, and by whom.

    The web UI polls this to notice when an external controller has taken
    control, so it can switch into monitor mode.

    Raises:
        HTTPException: 404 when the store raises KeyError; 500 for any other
            failure while reading or packaging the lock status.
    """
    try:
        lock_state = session_store.get_lock_status(session_id=session_id)
        response = SessionLockResponse(**lock_state)
    except KeyError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session {session_id} not found or expired"
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get lock status: {str(e)}"
        )
    return response
|
api/routers/visualizations.py
CHANGED
|
@@ -192,12 +192,17 @@ async def get_contour_data(
|
|
|
192
192
|
grid_df = pd.DataFrame(grid_points)
|
|
193
193
|
|
|
194
194
|
# CRITICAL FIX: Reorder columns to match training data
|
|
195
|
-
# The model was trained with a specific column order, we must match it
|
|
195
|
+
# The model was trained with a specific column order, we must match it.
|
|
196
|
+
# Exclude metadata columns that are part of the experiments table but
|
|
197
|
+
# are not model input features (e.g., Iteration, Reason, Output, Noise).
|
|
196
198
|
train_data = session.experiment_manager.get_data()
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
199
|
+
metadata_cols = {'Iteration', 'Reason', 'Output', 'Noise'}
|
|
200
|
+
feature_cols = [col for col in train_data.columns if col not in metadata_cols]
|
|
201
|
+
|
|
202
|
+
# Safely align the prediction grid to the model feature order.
|
|
203
|
+
# Use reindex so missing columns (shouldn't happen) are filled with the
|
|
204
|
+
# midpoint/defaults the grid already supplies; this avoids KeyError.
|
|
205
|
+
grid_df = grid_df.reindex(columns=feature_cols)
|
|
201
206
|
|
|
202
207
|
# IMPORTANT: The model's predict() method handles preprocessing internally
|
|
203
208
|
# (including categorical encoding), so we can pass the raw DataFrame directly
|
api/routers/websocket.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""
|
|
2
|
+
WebSocket router for real-time session updates.
|
|
3
|
+
|
|
4
|
+
Provides real-time push notifications for session events like lock status changes,
|
|
5
|
+
eliminating the need for client-side polling.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
9
|
+
from typing import Dict, Set
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)

router = APIRouter()

# Registry of open WebSocket connections, keyed by session ID.
# Structure: {session_id: {websocket1, websocket2, ...}}
# Entries are added by websocket_endpoint. A session's entry is removed when
# its last connection closes or when broadcast_to_session prunes the last
# dead connection.
# NOTE(review): module-level, in-process state; presumably not shared across
# multiple worker processes. Confirm the API runs with a single worker.
active_connections: Dict[str, Set[WebSocket]] = {}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@router.websocket("/ws/sessions/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """
    Stream real-time events for one session over a WebSocket.

    After the handshake, the connection is registered in ``active_connections``
    under ``session_id`` so ``broadcast_to_session`` can push events (lock
    status changes, etc.) to it. A ``connected`` confirmation is sent first.
    After that, incoming client messages are only decoded and logged.

    Args:
        websocket: WebSocket connection
        session_id: Session ID to subscribe to
    """
    await websocket.accept()
    logger.info(f"WebSocket connected: session_id={session_id}")

    active_connections.setdefault(session_id, set()).add(websocket)

    greeting = {
        "event": "connected",
        "session_id": session_id,
        "message": "WebSocket connection established"
    }

    try:
        await websocket.send_json(greeting)

        # Hold the connection open. Client messages are parsed and logged
        # only; nothing is echoed back (reserved for future bi-directional use).
        while True:
            raw = await websocket.receive_text()
            try:
                logger.debug(f"Received from client: {json.loads(raw)}")
            except json.JSONDecodeError:
                logger.warning(f"Invalid JSON from client: {raw}")

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: session_id={session_id}")
    finally:
        # Deregister this socket; drop the session key with its last socket.
        registered = active_connections.get(session_id)
        if registered is not None:
            registered.discard(websocket)
            if not registered:
                del active_connections[session_id]
                logger.debug(f"No more connections for session {session_id}")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
async def broadcast_to_session(session_id: str, event: dict):
    """
    Broadcast an event to all WebSocket clients connected to a session.

    Connections whose send raises are treated as dead and removed from
    ``active_connections``. The session's entry is dropped once it has no
    connections left.

    Args:
        session_id: Session ID to broadcast to
        event: Event data to send (will be JSON serialized)
    """
    connections = active_connections.get(session_id)
    if not connections:
        logger.debug(f"No active connections for session {session_id}")
        return

    # Track dead connections
    dead_connections = set()

    # Iterate over a snapshot. Every ``await`` yields to the event loop, and a
    # concurrent disconnect (websocket_endpoint's cleanup) may mutate the live
    # set, which would raise "Set changed size during iteration".
    for connection in list(connections):
        try:
            await connection.send_json(event)
            logger.debug(f"Broadcast to session {session_id}: {event.get('event')}")
        except Exception as e:
            logger.warning(f"Failed to send to connection: {e}")
            dead_connections.add(connection)

    # Re-fetch the entry: it may have been deleted or replaced while we were
    # awaiting sends, so indexing it directly could raise KeyError.
    remaining = active_connections.get(session_id)
    if remaining is not None:
        remaining -= dead_connections
        # Remove session if no connections left
        if not remaining:
            del active_connections[session_id]

    if dead_connections:
        logger.info(f"Cleaned up {len(dead_connections)} dead connections")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def get_connection_count(session_id: str) -> int:
    """
    Count the WebSocket clients currently registered for a session.

    Args:
        session_id: Session ID to check

    Returns:
        Number of active connections; 0 when the session has no entry.
    """
    registered = active_connections.get(session_id)
    return 0 if registered is None else len(registered)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def get_all_connection_counts() -> Dict[str, int]:
    """
    Snapshot how many WebSocket clients each registered session has.

    Returns:
        A new dict mapping each session_id present in ``active_connections``
        to the size of its connection set.
    """
    snapshot = {sid: len(sockets) for sid, sockets in active_connections.items()}
    return snapshot
|