alchemist-nrel 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
- alchemist_core/__init__.py +2 -2
- alchemist_core/acquisition/botorch_acquisition.py +83 -126
- alchemist_core/data/experiment_manager.py +181 -12
- alchemist_core/models/botorch_model.py +292 -63
- alchemist_core/models/sklearn_model.py +145 -13
- alchemist_core/session.py +3330 -31
- alchemist_core/utils/__init__.py +3 -1
- alchemist_core/utils/acquisition_utils.py +60 -0
- alchemist_core/visualization/__init__.py +45 -0
- alchemist_core/visualization/helpers.py +130 -0
- alchemist_core/visualization/plots.py +1449 -0
- {alchemist_nrel-0.3.1.dist-info → alchemist_nrel-0.3.2.dist-info}/METADATA +13 -13
- {alchemist_nrel-0.3.1.dist-info → alchemist_nrel-0.3.2.dist-info}/RECORD +31 -26
- {alchemist_nrel-0.3.1.dist-info → alchemist_nrel-0.3.2.dist-info}/WHEEL +1 -1
- api/main.py +1 -1
- api/models/requests.py +52 -0
- api/models/responses.py +79 -2
- api/routers/experiments.py +333 -8
- api/routers/sessions.py +84 -9
- api/routers/visualizations.py +6 -4
- api/routers/websocket.py +2 -2
- api/services/session_store.py +295 -71
- api/static/assets/index-B6Cf6s_b.css +1 -0
- api/static/assets/{index-DWfIKU9j.js → index-B7njvc9r.js} +201 -196
- api/static/index.html +2 -2
- ui/gpr_panel.py +11 -5
- ui/target_column_dialog.py +299 -0
- ui/ui.py +52 -5
- api/static/assets/index-sMIa_1hV.css +0 -1
- {alchemist_nrel-0.3.1.dist-info → alchemist_nrel-0.3.2.dist-info}/entry_points.txt +0 -0
- {alchemist_nrel-0.3.1.dist-info → alchemist_nrel-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {alchemist_nrel-0.3.1.dist-info → alchemist_nrel-0.3.2.dist-info}/top_level.txt +0 -0
api/routers/experiments.py
CHANGED
@@ -3,15 +3,27 @@ Experiments router - Experimental data management.
 """
 
 from fastapi import APIRouter, Depends, UploadFile, File, Query
-from ..models.requests import
+from ..models.requests import (
+    AddExperimentRequest,
+    AddExperimentsBatchRequest,
+    InitialDesignRequest,
+    StageExperimentRequest,
+    StageExperimentsBatchRequest,
+    CompleteStagedExperimentsRequest
+)
 from ..models.responses import (
     ExperimentResponse,
     ExperimentsListResponse,
     ExperimentsSummaryResponse,
-    InitialDesignResponse
+    InitialDesignResponse,
+    StagedExperimentResponse,
+    StagedExperimentsListResponse,
+    StagedExperimentsClearResponse,
+    StagedExperimentsCompletedResponse
 )
 from ..dependencies import get_session
 from ..middleware.error_handlers import NoVariablesError
+from .websocket import broadcast_to_session
 from alchemist_core.session import OptimizationSession
 import logging
 import pandas as pd

@@ -96,6 +108,17 @@ async def add_experiment(
             logger.error(f"Auto-train failed for session {session_id}: {e}")
             # Don't fail the whole request, just log it
 
+    # Broadcast experiment update to WebSocket clients
+    await broadcast_to_session(session_id, {
+        "event": "experiments_updated",
+        "n_experiments": n_experiments
+    })
+    if model_trained:
+        await broadcast_to_session(session_id, {
+            "event": "model_trained",
+            "metrics": training_metrics
+        })
+
     return ExperimentResponse(
         message="Experiment added successfully",
         n_experiments=n_experiments,

@@ -158,6 +181,17 @@ async def add_experiments_batch(
         except Exception as e:
             logger.error(f"Auto-train failed for session {session_id}: {e}")
 
+    # Broadcast experiment update to WebSocket clients
+    await broadcast_to_session(session_id, {
+        "event": "experiments_updated",
+        "n_experiments": n_experiments
+    })
+    if model_trained:
+        await broadcast_to_session(session_id, {
+            "event": "model_trained",
+            "metrics": training_metrics
+        })
+
     return ExperimentResponse(
         message=f"Added {len(batch.experiments)} experiments successfully",
         n_experiments=n_experiments,

@@ -227,19 +261,84 @@ async def list_experiments(
     )
 
 
+@router.post("/{session_id}/experiments/preview")
+async def preview_csv_columns(
+    session_id: str,
+    file: UploadFile = File(...),
+    session: OptimizationSession = Depends(get_session)
+):
+    """
+    Preview CSV file columns before uploading to check for target columns.
+
+    Returns:
+        - available_columns: List of all columns in CSV
+        - has_output: Whether 'Output' column exists
+        - recommended_target: Suggested target column if 'Output' missing
+    """
+    # Save uploaded file temporarily
+    with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.csv') as tmp:
+        content = await file.read()
+        tmp.write(content)
+        tmp_path = tmp.name
+
+    try:
+        # Read CSV to get column names
+        df = pd.read_csv(tmp_path)
+        columns = df.columns.tolist()
+
+        # Check for 'Output' column
+        has_output = 'Output' in columns
+
+        # Filter out metadata columns
+        metadata_cols = {'Iteration', 'Reason', 'Noise'}
+        available_targets = [col for col in columns if col not in metadata_cols]
+
+        # Recommend target column
+        recommended = None
+        if not has_output:
+            # Look for common target column names
+            common_names = ['output', 'y', 'target', 'yield', 'response']
+            for name in common_names:
+                if name in [col.lower() for col in available_targets]:
+                    recommended = [col for col in available_targets if col.lower() == name][0]
+                    break
+
+            # If no common name found, use first numeric column
+            if not recommended and available_targets:
+                # Check if first available column is numeric
+                if pd.api.types.is_numeric_dtype(df[available_targets[0]]):
+                    recommended = available_targets[0]
+
+        return {
+            "columns": columns,
+            "available_targets": available_targets,
+            "has_output": has_output,
+            "recommended_target": recommended,
+            "n_rows": len(df)
+        }
+
+    finally:
+        # Clean up temp file
+        if os.path.exists(tmp_path):
+            os.unlink(tmp_path)
+
+
 @router.post("/{session_id}/experiments/upload")
 async def upload_experiments(
     session_id: str,
     file: UploadFile = File(...),
-
+    target_columns: str = "Output",  # Note: API accepts string, will be normalized by Session API
     session: OptimizationSession = Depends(get_session)
 ):
     """
     Upload experimental data from CSV file.
 
     The CSV should have columns matching the variable names,
-    plus
-
+    plus target column(s) (default: "Output") and optional noise column ("Noise").
+
+    Args:
+        target_columns: Target column name (single-objective) or comma-separated names (multi-objective).
+            Examples: "Output", "yield", "yield,selectivity"
     """
     # Check if variables are defined
     if len(session.search_space.variables) == 0:

@@ -252,17 +351,26 @@ async def upload_experiments(
         tmp_path = tmp.name
 
     try:
+        # Parse target_columns (handle comma-separated for future multi-objective support)
+        target_cols_parsed = target_columns.split(',') if ',' in target_columns else target_columns
+
         # Load data using session's load_data method
-        session.load_data(tmp_path,
+        session.load_data(tmp_path, target_columns=target_cols_parsed)
 
         n_experiments = len(session.experiment_manager.df)
         logger.info(f"Loaded {n_experiments} experiments from CSV for session {session_id}")
-
+
+        # Broadcast experiment update to WebSocket clients
+        await broadcast_to_session(session_id, {
+            "event": "experiments_updated",
+            "n_experiments": n_experiments
+        })
+
         return {
             "message": f"Loaded {n_experiments} experiments successfully",
            "n_experiments": n_experiments
         }
-
+
     finally:
         # Clean up temp file
         if os.path.exists(tmp_path):
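Taken together, the preview and upload endpoints above form a two-step flow: inspect the CSV first, then upload with an explicit target column. A minimal client sketch of that flow follows; the host, mount prefix, session id, and file name are assumptions, and target_columns is passed as a query parameter (FastAPI's default for a plain string parameter):

import requests

BASE = "http://localhost:8000"  # assumed host and mount prefix
SID = "my-session-id"           # hypothetical session id

with open("runs.csv", "rb") as f:  # hypothetical data file
    preview = requests.post(f"{BASE}/{SID}/experiments/preview",
                            files={"file": f}).json()

# Fall back to the server's recommendation when no 'Output' column exists
target = "Output" if preview["has_output"] else preview["recommended_target"]

with open("runs.csv", "rb") as f:
    result = requests.post(f"{BASE}/{SID}/experiments/upload",
                           params={"target_columns": target},
                           files={"file": f}).json()
print(result["message"])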
@@ -280,3 +388,220 @@ async def get_experiments_summary(
     Returns sample size, target variable statistics, and feature information.
     """
     return session.get_data_summary()
+
+
+# ============================================================
+# Staged Experiments Endpoints
+# ============================================================
+
+@router.post("/{session_id}/experiments/staged", response_model=StagedExperimentResponse)
+async def stage_experiment(
+    session_id: str,
+    request: StageExperimentRequest,
+    session: OptimizationSession = Depends(get_session)
+):
+    """
+    Stage an experiment for later execution.
+
+    Staged experiments are stored in a queue awaiting evaluation.
+    This is useful for autonomous workflows where the controller
+    needs to track which experiments are pending execution.
+
+    Use GET /experiments/staged to retrieve staged experiments,
+    and POST /experiments/staged/complete to finalize them with outputs.
+    """
+    # Check if variables are defined
+    if len(session.search_space.variables) == 0:
+        raise NoVariablesError("No variables defined. Add variables to search space first.")
+
+    # Add reason metadata if provided
+    inputs_with_meta = dict(request.inputs)
+    if request.reason:
+        inputs_with_meta['_reason'] = request.reason
+
+    session.add_staged_experiment(inputs_with_meta)
+
+    n_staged = len(session.get_staged_experiments())
+    logger.info(f"Staged experiment for session {session_id}. Total staged: {n_staged}")
+
+    return StagedExperimentResponse(
+        message="Experiment staged successfully",
+        n_staged=n_staged,
+        staged_inputs=request.inputs
+    )
+
+
+@router.post("/{session_id}/experiments/staged/batch", response_model=StagedExperimentsListResponse)
+async def stage_experiments_batch(
+    session_id: str,
+    request: StageExperimentsBatchRequest,
+    session: OptimizationSession = Depends(get_session)
+):
+    """
+    Stage multiple experiments at once.
+
+    Useful after acquisition functions suggest multiple points for parallel execution.
+    The `reason` parameter is stored as metadata and will be used when completing
+    the experiments (recorded in the 'Reason' column of the experiment data).
+    """
+    # Check if variables are defined
+    if len(session.search_space.variables) == 0:
+        raise NoVariablesError("No variables defined. Add variables to search space first.")
+
+    for inputs in request.experiments:
+        inputs_with_meta = dict(inputs)
+        if request.reason:
+            inputs_with_meta['_reason'] = request.reason
+        session.add_staged_experiment(inputs_with_meta)
+
+    logger.info(f"Staged {len(request.experiments)} experiments for session {session_id}. Total staged: {len(session.get_staged_experiments())}")
+
+    # Return clean experiments (without metadata) for client use
+    return StagedExperimentsListResponse(
+        experiments=request.experiments,  # Return the original clean inputs
+        n_staged=len(session.get_staged_experiments()),
+        reason=request.reason
+    )
+
+
+@router.get("/{session_id}/experiments/staged", response_model=StagedExperimentsListResponse)
+async def get_staged_experiments(
+    session_id: str,
+    session: OptimizationSession = Depends(get_session)
+):
+    """
+    Get all staged experiments awaiting execution.
+
+    Returns the list of experiments that have been queued but not yet
+    completed with output values. The response includes:
+    - experiments: Clean variable inputs only (no metadata)
+    - reason: The strategy/reason for these experiments (if provided when staging)
+    """
+    staged = session.get_staged_experiments()
+
+    # Extract reason from first experiment (if present) and clean all experiments
+    reason = None
+    clean_experiments = []
+    for exp in staged:
+        if '_reason' in exp and reason is None:
+            reason = exp['_reason']
+        # Return only variable values, not metadata
+        clean_exp = {k: v for k, v in exp.items() if not k.startswith('_')}
+        clean_experiments.append(clean_exp)
+
+    return StagedExperimentsListResponse(
+        experiments=clean_experiments,
+        n_staged=len(staged),
+        reason=reason
+    )
+
+
+@router.delete("/{session_id}/experiments/staged", response_model=StagedExperimentsClearResponse)
+async def clear_staged_experiments(
+    session_id: str,
+    session: OptimizationSession = Depends(get_session)
+):
+    """
+    Clear all staged experiments.
+
+    Use this to reset the staging queue if experiments were cancelled
+    or need to be regenerated.
+    """
+    n_cleared = session.clear_staged_experiments()
+    logger.info(f"Cleared {n_cleared} staged experiments for session {session_id}")
+
+    return StagedExperimentsClearResponse(
+        message="Staged experiments cleared",
+        n_cleared=n_cleared
+    )
+
+
+@router.post("/{session_id}/experiments/staged/complete", response_model=StagedExperimentsCompletedResponse)
+async def complete_staged_experiments(
+    session_id: str,
+    request: CompleteStagedExperimentsRequest,
+    auto_train: bool = Query(False, description="Auto-train model after adding data"),
+    training_backend: Optional[str] = Query(None, description="Model backend (sklearn/botorch)"),
+    training_kernel: Optional[str] = Query(None, description="Kernel type (rbf/matern)"),
+    session: OptimizationSession = Depends(get_session)
+):
+    """
+    Complete staged experiments by providing output values.
+
+    This pairs the staged experiment inputs with the provided outputs,
+    adds them to the experiment dataset, and clears the staging queue.
+
+    The number of outputs must match the number of staged experiments.
+    Outputs should be in the same order as the staged experiments were added.
+
+    Args:
+        auto_train: If True, retrain model after adding data
+        training_backend: Model backend (uses last if None)
+        training_kernel: Kernel type (uses last or 'rbf' if None)
+    """
+    staged = session.get_staged_experiments()
+
+    if len(staged) == 0:
+        return StagedExperimentsCompletedResponse(
+            message="No staged experiments to complete",
+            n_added=0,
+            n_experiments=len(session.experiment_manager.df),
+            model_trained=False
+        )
+
+    if len(request.outputs) != len(staged):
+        raise ValueError(
+            f"Number of outputs ({len(request.outputs)}) must match "
+            f"number of staged experiments ({len(staged)})"
+        )
+
+    # Use the core Session method to move staged experiments to dataset
+    n_added = session.move_staged_to_experiments(
+        outputs=request.outputs,
+        noises=request.noises,
+        iteration=request.iteration,
+        reason=request.reason
+    )
+
+    n_experiments = len(session.experiment_manager.df)
+    logger.info(f"Completed {n_added} staged experiments for session {session_id}. Total: {n_experiments}")
+
+    # Auto-train if requested
+    model_trained = False
+    training_metrics = None
+
+    if auto_train and n_experiments >= 5:
+        try:
+            backend = training_backend or (session.model_backend if session.model else "sklearn")
+            kernel = training_kernel or "rbf"
+
+            result = session.train_model(backend=backend, kernel=kernel)
+            model_trained = True
+            metrics = result.get("metrics", {})
+            training_metrics = {
+                "rmse": metrics.get("rmse"),
+                "r2": metrics.get("r2"),
+                "backend": backend
+            }
+            logger.info(f"Auto-trained model for session {session_id}: {training_metrics}")
+        except Exception as e:
+            logger.error(f"Auto-train failed for session {session_id}: {e}")
+
+    # Broadcast experiment update to WebSocket clients
+    await broadcast_to_session(session_id, {
+        "event": "experiments_updated",
+        "n_experiments": n_experiments
+    })
+    if model_trained:
+        await broadcast_to_session(session_id, {
+            "event": "model_trained",
+            "metrics": training_metrics
+        })
+
+    return StagedExperimentsCompletedResponse(
+        message="Staged experiments completed and added to dataset",
+        n_added=n_added,
+        n_experiments=n_experiments,
+        model_trained=model_trained,
+        training_metrics=training_metrics
+    )
api/routers/sessions.py
CHANGED
@@ -39,8 +39,7 @@ async def create_session():
 
     return SessionCreateResponse(
         session_id=session_id,
-        created_at=session_info["created_at"],
-        expires_at=session_info["expires_at"]
+        created_at=session_info["created_at"]
     )
 
 

@@ -123,10 +122,8 @@ async def extend_session(session_id: str, hours: int = 24):
            detail=f"Session {session_id} not found"
        )
 
-    info = session_store.get_info(session_id)
    return {
-        "message": "Session TTL extended",
-        "expires_at": info["expires_at"]
+        "message": "Session TTL extended (legacy endpoint - no longer has effect)"
    }
 
 

@@ -194,8 +191,7 @@ async def import_session(file: UploadFile = File(...)):
        session_info = session_store.get_info(session_id)
        return SessionCreateResponse(
            session_id=session_id,
-            created_at=session_info["created_at"],
-            expires_at=session_info["expires_at"]
+            created_at=session_info["created_at"]
        )
 
    except Exception as e:

@@ -452,8 +448,7 @@ async def upload_session(file: UploadFile = File(...)):
 
        return SessionCreateResponse(
            session_id=new_session_id,
-            created_at=session_info["created_at"],
-            expires_at=session_info["expires_at"]
+            created_at=session_info["created_at"]
        )
 
    finally:

@@ -468,6 +463,86 @@ async def upload_session(file: UploadFile = File(...)):
        )
 
 
+# ============================================================
+# Recovery / Backup Endpoints
+# ============================================================
+
+@router.post("/sessions/{session_id}/backup", status_code=status.HTTP_200_OK)
+async def create_recovery_backup(session_id: str):
+    """
+    Create a silent recovery backup for crash protection.
+
+    Called automatically by frontend every 30 seconds while user has session open.
+    User never sees these backups unless browser crashes and recovery is needed.
+    """
+    success = session_store.save_recovery_backup(session_id)
+    if not success:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Session {session_id} not found or backup failed"
+        )
+    return {"message": "Recovery backup created"}
+
+
+@router.delete("/sessions/{session_id}/backup", status_code=status.HTTP_200_OK)
+async def clear_recovery_backup(session_id: str):
+    """
+    Clear recovery backups for a session.
+
+    Called after user successfully saves their session to their computer.
+    This prevents recovery prompt from appearing unnecessarily.
+    """
+    deleted = session_store.clear_recovery_backup(session_id)
+    return {"message": "Recovery backup cleared", "deleted": deleted}
+
+
+@router.get("/recovery/list")
+async def list_recovery_sessions():
+    """
+    List available recovery sessions.
+
+    Called on app startup to check if there are any unsaved sessions
+    that can be recovered from a crash.
+    """
+    recoveries = session_store.list_recovery_sessions()
+    return {"recoveries": recoveries, "count": len(recoveries)}
+
+
+@router.post("/recovery/{session_id}/restore", response_model=SessionCreateResponse, status_code=status.HTTP_201_CREATED)
+async def restore_recovery_session(session_id: str):
+    """
+    Restore a session from recovery backup.
+
+    Called when user clicks "Restore" on the recovery banner.
+    Creates a new active session from the recovery file.
+    """
+    new_session_id = session_store.restore_from_recovery(session_id)
+
+    if new_session_id is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"No recovery data found for session {session_id}"
+        )
+
+    session_info = session_store.get_info(new_session_id)
+    return SessionCreateResponse(
+        session_id=new_session_id,
+        created_at=session_info["created_at"]
+    )
+
+
+@router.delete("/recovery/cleanup")
+async def cleanup_old_recoveries(max_age_hours: int = 24):
+    """
+    Clean up old recovery files.
+
+    Deletes recovery files older than specified hours.
+    Can be called manually or via scheduled task.
+    """
+    session_store.cleanup_old_recoveries(max_age_hours)
+    return {"message": f"Cleaned up recovery files older than {max_age_hours} hours"}
+
+
 # ============================================================
 # Session Locking Endpoints
 # ============================================================
api/routers/visualizations.py
CHANGED
@@ -194,9 +194,10 @@ async def get_contour_data(
    # CRITICAL FIX: Reorder columns to match training data
    # The model was trained with a specific column order, we must match it.
    # Exclude metadata columns that are part of the experiments table but
-    # are not model input features (e.g., Iteration, Reason,
+    # are not model input features (e.g., Iteration, Reason, target columns, Noise).
    train_data = session.experiment_manager.get_data()
-
+    target_cols = set(session.experiment_manager.target_columns)
+    metadata_cols = {'Iteration', 'Reason', 'Noise'} | target_cols
    feature_cols = [col for col in train_data.columns if col not in metadata_cols]
 
    # Safely align the prediction grid to the model feature order.

@@ -216,11 +217,12 @@ async def get_contour_data(
    experiments_data = None
    if request.include_experiments and len(session.experiment_manager) > 0:
        exp_df = session.experiment_manager.get_data()
-
+        target_col = session.experiment_manager.target_columns[0]  # Use first target for visualization
+        if request.x_var in exp_df.columns and request.y_var in exp_df.columns and target_col in exp_df.columns:
            experiments_data = {
                "x": exp_df[request.x_var].tolist(),
                "y": exp_df[request.y_var].tolist(),
-                "output": exp_df[
+                "output": exp_df[target_col].tolist()
            }
 
    # Get suggestion data if requested (would need to be stored in session)
api/routers/websocket.py
CHANGED
@@ -32,7 +32,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
        session_id: Session ID to subscribe to
    """
    await websocket.accept()
-    logger.
+    logger.debug(f"WebSocket connected: session_id={session_id}")
 
    # Register this connection for this session
    if session_id not in active_connections:

@@ -60,7 +60,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
                logger.warning(f"Invalid JSON from client: {data}")
 
    except WebSocketDisconnect:
-        logger.
+        logger.debug(f"WebSocket disconnected: session_id={session_id}")
    finally:
        # Clean up on disconnect
        if session_id in active_connections: