alchemist-nrel 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
- alchemist_core/__init__.py +2 -2
- alchemist_core/acquisition/botorch_acquisition.py +84 -126
- alchemist_core/data/experiment_manager.py +196 -20
- alchemist_core/models/botorch_model.py +292 -63
- alchemist_core/models/sklearn_model.py +175 -15
- alchemist_core/session.py +3532 -76
- alchemist_core/utils/__init__.py +3 -1
- alchemist_core/utils/acquisition_utils.py +60 -0
- alchemist_core/visualization/__init__.py +45 -0
- alchemist_core/visualization/helpers.py +130 -0
- alchemist_core/visualization/plots.py +1449 -0
- alchemist_nrel-0.3.2.dist-info/METADATA +185 -0
- {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/RECORD +34 -29
- {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/WHEEL +1 -1
- {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/entry_points.txt +1 -1
- {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/top_level.txt +0 -1
- api/example_client.py +7 -2
- api/main.py +3 -2
- api/models/requests.py +76 -1
- api/models/responses.py +102 -2
- api/routers/acquisition.py +25 -0
- api/routers/experiments.py +352 -11
- api/routers/sessions.py +195 -11
- api/routers/visualizations.py +6 -4
- api/routers/websocket.py +132 -0
- run_api.py → api/run_api.py +8 -7
- api/services/session_store.py +370 -71
- api/static/assets/index-B6Cf6s_b.css +1 -0
- api/static/assets/{index-C0_glioA.js → index-B7njvc9r.js} +223 -208
- api/static/index.html +2 -2
- ui/gpr_panel.py +11 -5
- ui/target_column_dialog.py +299 -0
- ui/ui.py +52 -5
- alchemist_core/models/ax_model.py +0 -159
- alchemist_nrel-0.3.0.dist-info/METADATA +0 -223
- api/static/assets/index-CB4V1LI5.css +0 -1
- {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/licenses/LICENSE +0 -0
api/routers/experiments.py
CHANGED
```diff
@@ -3,15 +3,27 @@ Experiments router - Experimental data management.
 """
 
 from fastapi import APIRouter, Depends, UploadFile, File, Query
-from ..models.requests import
+from ..models.requests import (
+    AddExperimentRequest,
+    AddExperimentsBatchRequest,
+    InitialDesignRequest,
+    StageExperimentRequest,
+    StageExperimentsBatchRequest,
+    CompleteStagedExperimentsRequest
+)
 from ..models.responses import (
     ExperimentResponse,
     ExperimentsListResponse,
     ExperimentsSummaryResponse,
-    InitialDesignResponse
+    InitialDesignResponse,
+    StagedExperimentResponse,
+    StagedExperimentsListResponse,
+    StagedExperimentsClearResponse,
+    StagedExperimentsCompletedResponse
 )
 from ..dependencies import get_session
 from ..middleware.error_handlers import NoVariablesError
+from .websocket import broadcast_to_session
 from alchemist_core.session import OptimizationSession
 import logging
 import pandas as pd
```
```diff
@@ -51,35 +63,62 @@ async def add_experiment(
     session.add_experiment(
         inputs=experiment.inputs,
         output=experiment.output,
-        noise=experiment.noise
+        noise=experiment.noise,
+        iteration=experiment.iteration,
+        reason=experiment.reason
     )
 
     n_experiments = len(session.experiment_manager.df)
     logger.info(f"Added experiment to session {session_id}. Total: {n_experiments}")
 
-    # Auto-train if requested
+    # Auto-train if requested (need at least 5 points to train)
     model_trained = False
     training_metrics = None
 
-    if auto_train and n_experiments >= 5:
+    if auto_train and n_experiments >= 5:
         try:
             # Use previous config or provided config
             backend = training_backend or (session.model_backend if session.model else "sklearn")
             kernel = training_kernel or "rbf"
 
+            # Note: Input/output transforms are now automatically applied by core Session.train_model()
+            # for BoTorch models. No need to specify them here unless overriding defaults.
             result = session.train_model(backend=backend, kernel=kernel)
             model_trained = True
             metrics = result.get("metrics", {})
+            hyperparameters = result.get("hyperparameters", {})
             training_metrics = {
                 "rmse": metrics.get("rmse"),
                 "r2": metrics.get("r2"),
                 "backend": backend
             }
             logger.info(f"Auto-trained model for session {session_id}: {training_metrics}")
+
+            # Record in audit log if this is an optimization iteration
+            if experiment.iteration is not None and experiment.iteration > 0:
+                session.audit_log.lock_model(
+                    backend=backend,
+                    kernel=kernel,
+                    hyperparameters=hyperparameters,
+                    cv_metrics=metrics,
+                    iteration=experiment.iteration,
+                    notes=f"Auto-trained after iteration {experiment.iteration}"
+                )
         except Exception as e:
             logger.error(f"Auto-train failed for session {session_id}: {e}")
             # Don't fail the whole request, just log it
 
+    # Broadcast experiment update to WebSocket clients
+    await broadcast_to_session(session_id, {
+        "event": "experiments_updated",
+        "n_experiments": n_experiments
+    })
+    if model_trained:
+        await broadcast_to_session(session_id, {
+            "event": "model_trained",
+            "metrics": training_metrics
+        })
+
     return ExperimentResponse(
         message="Experiment added successfully",
         n_experiments=n_experiments,
```
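Both `add_experiment` and the batch endpoint below now push `experiments_updated` and `model_trained` events through `broadcast_to_session` (added in `api/routers/websocket.py`). A minimal monitor sketch, assuming the WebSocket route is mounted at `ws://localhost:8000/ws/{session_id}` (host, port, and path are assumptions; check `api/routers/websocket.py` for the actual route):

```python
# Hypothetical monitor client for the new broadcast events.
# The URI below is an assumption, not confirmed by this diff.
import asyncio
import json
import websockets

async def monitor(session_id: str) -> None:
    uri = f"ws://localhost:8000/ws/{session_id}"  # assumed mount point
    async with websockets.connect(uri) as ws:
        async for raw in ws:
            msg = json.loads(raw)
            if msg.get("event") == "experiments_updated":
                print(f"dataset now holds {msg['n_experiments']} experiments")
            elif msg.get("event") == "model_trained":
                print(f"model retrained: {msg['metrics']}")

asyncio.run(monitor("my-session-id"))
```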
```diff
@@ -142,6 +181,17 @@ async def add_experiments_batch(
         except Exception as e:
             logger.error(f"Auto-train failed for session {session_id}: {e}")
 
+    # Broadcast experiment update to WebSocket clients
+    await broadcast_to_session(session_id, {
+        "event": "experiments_updated",
+        "n_experiments": n_experiments
+    })
+    if model_trained:
+        await broadcast_to_session(session_id, {
+            "event": "model_trained",
+            "metrics": training_metrics
+        })
+
     return ExperimentResponse(
         message=f"Added {len(batch.experiments)} experiments successfully",
         n_experiments=n_experiments,
```
```diff
@@ -211,19 +261,84 @@ async def list_experiments(
     )
 
 
+@router.post("/{session_id}/experiments/preview")
+async def preview_csv_columns(
+    session_id: str,
+    file: UploadFile = File(...),
+    session: OptimizationSession = Depends(get_session)
+):
+    """
+    Preview CSV file columns before uploading to check for target columns.
+
+    Returns:
+    - available_columns: List of all columns in CSV
+    - has_output: Whether 'Output' column exists
+    - recommended_target: Suggested target column if 'Output' missing
+    """
+    # Save uploaded file temporarily
+    with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.csv') as tmp:
+        content = await file.read()
+        tmp.write(content)
+        tmp_path = tmp.name
+
+    try:
+        # Read CSV to get column names
+        df = pd.read_csv(tmp_path)
+        columns = df.columns.tolist()
+
+        # Check for 'Output' column
+        has_output = 'Output' in columns
+
+        # Filter out metadata columns
+        metadata_cols = {'Iteration', 'Reason', 'Noise'}
+        available_targets = [col for col in columns if col not in metadata_cols]
+
+        # Recommend target column
+        recommended = None
+        if not has_output:
+            # Look for common target column names
+            common_names = ['output', 'y', 'target', 'yield', 'response']
+            for name in common_names:
+                if name in [col.lower() for col in available_targets]:
+                    recommended = [col for col in available_targets if col.lower() == name][0]
+                    break
+
+            # If no common name found, use first numeric column
+            if not recommended and available_targets:
+                # Check if first available column is numeric
+                if pd.api.types.is_numeric_dtype(df[available_targets[0]]):
+                    recommended = available_targets[0]
+
+        return {
+            "columns": columns,
+            "available_targets": available_targets,
+            "has_output": has_output,
+            "recommended_target": recommended,
+            "n_rows": len(df)
+        }
+
+    finally:
+        # Clean up temp file
+        if os.path.exists(tmp_path):
+            os.unlink(tmp_path)
+
+
 @router.post("/{session_id}/experiments/upload")
 async def upload_experiments(
     session_id: str,
     file: UploadFile = File(...),
-
+    target_columns: str = "Output",  # Note: API accepts string, will be normalized by Session API
     session: OptimizationSession = Depends(get_session)
 ):
     """
     Upload experimental data from CSV file.
 
     The CSV should have columns matching the variable names,
-    plus
-
+    plus target column(s) (default: "Output") and optional noise column ("Noise").
+
+    Args:
+        target_columns: Target column name (single-objective) or comma-separated names (multi-objective).
+                        Examples: "Output", "yield", "yield,selectivity"
     """
     # Check if variables are defined
     if len(session.search_space.variables) == 0:
```
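A sketch of driving the new preview endpoint before committing an upload, assuming the API runs at `http://localhost:8000` and the experiments router is mounted under `/api/sessions` (both assumptions; `api/example_client.py` shows the canonical client setup):

```python
# Hypothetical preview call; BASE is an assumed prefix.
import requests

BASE = "http://localhost:8000/api/sessions"  # assumed mount point

def preview_csv(session_id: str, csv_path: str) -> dict:
    """POST the CSV to /experiments/preview and return the column report."""
    with open(csv_path, "rb") as f:
        resp = requests.post(
            f"{BASE}/{session_id}/experiments/preview",
            files={"file": ("data.csv", f, "text/csv")},
        )
    resp.raise_for_status()
    return resp.json()

info = preview_csv("my-session-id", "experiments.csv")
if not info["has_output"]:
    print("No 'Output' column; suggested target:", info["recommended_target"])
```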
```diff
@@ -236,17 +351,26 @@ async def upload_experiments(
         tmp_path = tmp.name
 
     try:
+        # Parse target_columns (handle comma-separated for future multi-objective support)
+        target_cols_parsed = target_columns.split(',') if ',' in target_columns else target_columns
+
         # Load data using session's load_data method
-        session.load_data(tmp_path,
+        session.load_data(tmp_path, target_columns=target_cols_parsed)
 
         n_experiments = len(session.experiment_manager.df)
         logger.info(f"Loaded {n_experiments} experiments from CSV for session {session_id}")
-
+
+        # Broadcast experiment update to WebSocket clients
+        await broadcast_to_session(session_id, {
+            "event": "experiments_updated",
+            "n_experiments": n_experiments
+        })
+
         return {
             "message": f"Loaded {n_experiments} experiments successfully",
             "n_experiments": n_experiments
         }
-
+
     finally:
         # Clean up temp file
         if os.path.exists(tmp_path):
```
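Because `target_columns` is declared as a plain `str` parameter next to the `UploadFile`, FastAPI reads it from the query string. A usage sketch with a non-default target, under the same assumed base URL as above:

```python
# Upload a CSV whose target column is "yield" instead of "Output".
# The comma-separated form ("yield,selectivity") is parsed server-side
# for future multi-objective support. BASE is an assumed prefix.
import requests

BASE = "http://localhost:8000/api/sessions"  # assumed mount point

with open("experiments.csv", "rb") as f:
    resp = requests.post(
        f"{BASE}/my-session-id/experiments/upload",
        params={"target_columns": "yield"},  # query parameter, not a form field
        files={"file": ("experiments.csv", f, "text/csv")},
    )
resp.raise_for_status()
print(resp.json()["message"])
```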
```diff
@@ -264,3 +388,220 @@ async def get_experiments_summary(
     Returns sample size, target variable statistics, and feature information.
     """
     return session.get_data_summary()
+
+
+# ============================================================
+# Staged Experiments Endpoints
+# ============================================================
+
+@router.post("/{session_id}/experiments/staged", response_model=StagedExperimentResponse)
+async def stage_experiment(
+    session_id: str,
+    request: StageExperimentRequest,
+    session: OptimizationSession = Depends(get_session)
+):
+    """
+    Stage an experiment for later execution.
+
+    Staged experiments are stored in a queue awaiting evaluation.
+    This is useful for autonomous workflows where the controller
+    needs to track which experiments are pending execution.
+
+    Use GET /experiments/staged to retrieve staged experiments,
+    and POST /experiments/staged/complete to finalize them with outputs.
+    """
+    # Check if variables are defined
+    if len(session.search_space.variables) == 0:
+        raise NoVariablesError("No variables defined. Add variables to search space first.")
+
+    # Add reason metadata if provided
+    inputs_with_meta = dict(request.inputs)
+    if request.reason:
+        inputs_with_meta['_reason'] = request.reason
+
+    session.add_staged_experiment(inputs_with_meta)
+
+    n_staged = len(session.get_staged_experiments())
+    logger.info(f"Staged experiment for session {session_id}. Total staged: {n_staged}")
+
+    return StagedExperimentResponse(
+        message="Experiment staged successfully",
+        n_staged=n_staged,
+        staged_inputs=request.inputs
+    )
+
+
+@router.post("/{session_id}/experiments/staged/batch", response_model=StagedExperimentsListResponse)
+async def stage_experiments_batch(
+    session_id: str,
+    request: StageExperimentsBatchRequest,
+    session: OptimizationSession = Depends(get_session)
+):
+    """
+    Stage multiple experiments at once.
+
+    Useful after acquisition functions suggest multiple points for parallel execution.
+    The `reason` parameter is stored as metadata and will be used when completing
+    the experiments (recorded in the 'Reason' column of the experiment data).
+    """
+    # Check if variables are defined
+    if len(session.search_space.variables) == 0:
+        raise NoVariablesError("No variables defined. Add variables to search space first.")
+
+    for inputs in request.experiments:
+        inputs_with_meta = dict(inputs)
+        if request.reason:
+            inputs_with_meta['_reason'] = request.reason
+        session.add_staged_experiment(inputs_with_meta)
+
+    logger.info(f"Staged {len(request.experiments)} experiments for session {session_id}. Total staged: {len(session.get_staged_experiments())}")
+
+    # Return clean experiments (without metadata) for client use
+    return StagedExperimentsListResponse(
+        experiments=request.experiments,  # Return the original clean inputs
+        n_staged=len(session.get_staged_experiments()),
+        reason=request.reason
+    )
+
+
+@router.get("/{session_id}/experiments/staged", response_model=StagedExperimentsListResponse)
+async def get_staged_experiments(
+    session_id: str,
+    session: OptimizationSession = Depends(get_session)
+):
+    """
+    Get all staged experiments awaiting execution.
+
+    Returns the list of experiments that have been queued but not yet
+    completed with output values. The response includes:
+    - experiments: Clean variable inputs only (no metadata)
+    - reason: The strategy/reason for these experiments (if provided when staging)
+    """
+    staged = session.get_staged_experiments()
+
+    # Extract reason from first experiment (if present) and clean all experiments
+    reason = None
+    clean_experiments = []
+    for exp in staged:
+        if '_reason' in exp and reason is None:
+            reason = exp['_reason']
+        # Return only variable values, not metadata
+        clean_exp = {k: v for k, v in exp.items() if not k.startswith('_')}
+        clean_experiments.append(clean_exp)
+
+    return StagedExperimentsListResponse(
+        experiments=clean_experiments,
+        n_staged=len(staged),
+        reason=reason
+    )
+
+
+@router.delete("/{session_id}/experiments/staged", response_model=StagedExperimentsClearResponse)
+async def clear_staged_experiments(
+    session_id: str,
+    session: OptimizationSession = Depends(get_session)
+):
+    """
+    Clear all staged experiments.
+
+    Use this to reset the staging queue if experiments were cancelled
+    or need to be regenerated.
+    """
+    n_cleared = session.clear_staged_experiments()
+    logger.info(f"Cleared {n_cleared} staged experiments for session {session_id}")
+
+    return StagedExperimentsClearResponse(
+        message="Staged experiments cleared",
+        n_cleared=n_cleared
+    )
+
+
+@router.post("/{session_id}/experiments/staged/complete", response_model=StagedExperimentsCompletedResponse)
+async def complete_staged_experiments(
+    session_id: str,
+    request: CompleteStagedExperimentsRequest,
+    auto_train: bool = Query(False, description="Auto-train model after adding data"),
+    training_backend: Optional[str] = Query(None, description="Model backend (sklearn/botorch)"),
+    training_kernel: Optional[str] = Query(None, description="Kernel type (rbf/matern)"),
+    session: OptimizationSession = Depends(get_session)
+):
+    """
+    Complete staged experiments by providing output values.
+
+    This pairs the staged experiment inputs with the provided outputs,
+    adds them to the experiment dataset, and clears the staging queue.
+
+    The number of outputs must match the number of staged experiments.
+    Outputs should be in the same order as the staged experiments were added.
+
+    Args:
+        auto_train: If True, retrain model after adding data
+        training_backend: Model backend (uses last if None)
+        training_kernel: Kernel type (uses last or 'rbf' if None)
+    """
+    staged = session.get_staged_experiments()
+
+    if len(staged) == 0:
+        return StagedExperimentsCompletedResponse(
+            message="No staged experiments to complete",
+            n_added=0,
+            n_experiments=len(session.experiment_manager.df),
+            model_trained=False
+        )
+
+    if len(request.outputs) != len(staged):
+        raise ValueError(
+            f"Number of outputs ({len(request.outputs)}) must match "
+            f"number of staged experiments ({len(staged)})"
+        )
+
+    # Use the core Session method to move staged experiments to dataset
+    n_added = session.move_staged_to_experiments(
+        outputs=request.outputs,
+        noises=request.noises,
+        iteration=request.iteration,
+        reason=request.reason
+    )
+
+    n_experiments = len(session.experiment_manager.df)
+    logger.info(f"Completed {n_added} staged experiments for session {session_id}. Total: {n_experiments}")
+
+    # Auto-train if requested
+    model_trained = False
+    training_metrics = None
+
+    if auto_train and n_experiments >= 5:
+        try:
+            backend = training_backend or (session.model_backend if session.model else "sklearn")
+            kernel = training_kernel or "rbf"
+
+            result = session.train_model(backend=backend, kernel=kernel)
+            model_trained = True
+            metrics = result.get("metrics", {})
+            training_metrics = {
+                "rmse": metrics.get("rmse"),
+                "r2": metrics.get("r2"),
+                "backend": backend
+            }
+            logger.info(f"Auto-trained model for session {session_id}: {training_metrics}")
+        except Exception as e:
+            logger.error(f"Auto-train failed for session {session_id}: {e}")
+
+    # Broadcast experiment update to WebSocket clients
+    await broadcast_to_session(session_id, {
+        "event": "experiments_updated",
+        "n_experiments": n_experiments
+    })
+    if model_trained:
+        await broadcast_to_session(session_id, {
+            "event": "model_trained",
+            "metrics": training_metrics
+        })
+
+    return StagedExperimentsCompletedResponse(
+        message="Staged experiments completed and added to dataset",
+        n_added=n_added,
+        n_experiments=n_experiments,
+        model_trained=model_trained,
+        training_metrics=training_metrics
+    )
```
api/routers/sessions.py
CHANGED
```diff
@@ -4,11 +4,14 @@ Sessions router - Session lifecycle management.
 
 from fastapi import APIRouter, HTTPException, status, UploadFile, File, Depends
 from fastapi.responses import Response, FileResponse, JSONResponse
-from
+from typing import Optional
+from ..models.requests import UpdateMetadataRequest, LockDecisionRequest, SessionLockRequest
 from ..models.responses import (
     SessionCreateResponse, SessionInfoResponse, SessionStateResponse,
-    SessionMetadataResponse, AuditLogResponse, AuditEntryResponse, LockDecisionResponse
+    SessionMetadataResponse, AuditLogResponse, AuditEntryResponse, LockDecisionResponse,
+    SessionLockResponse
 )
+from .websocket import broadcast_to_session
 from ..services import session_store
 from ..dependencies import get_session
 from alchemist_core.session import OptimizationSession
```
```diff
@@ -36,8 +39,7 @@ async def create_session():
 
     return SessionCreateResponse(
         session_id=session_id,
-        created_at=session_info["created_at"]
-        expires_at=session_info["expires_at"]
+        created_at=session_info["created_at"]
     )
 
 
```
```diff
@@ -120,10 +122,8 @@ async def extend_session(session_id: str, hours: int = 24):
             detail=f"Session {session_id} not found"
         )
 
-    info = session_store.get_info(session_id)
     return {
-        "message": "Session TTL extended"
-        "expires_at": info["expires_at"]
+        "message": "Session TTL extended (legacy endpoint - no longer has effect)"
     }
 
 
```
```diff
@@ -191,8 +191,7 @@ async def import_session(file: UploadFile = File(...)):
         session_info = session_store.get_info(session_id)
         return SessionCreateResponse(
             session_id=session_id,
-            created_at=session_info["created_at"]
-            expires_at=session_info["expires_at"]
+            created_at=session_info["created_at"]
         )
 
     except Exception as e:
```
```diff
@@ -449,8 +448,7 @@ async def upload_session(file: UploadFile = File(...)):
 
         return SessionCreateResponse(
             session_id=new_session_id,
-            created_at=session_info["created_at"]
-            expires_at=session_info["expires_at"]
+            created_at=session_info["created_at"]
         )
 
     finally:
```
```diff
@@ -463,3 +461,189 @@ async def upload_session(file: UploadFile = File(...)):
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=f"Failed to upload session: {str(e)}"
         )
+
+
+# ============================================================
+# Recovery / Backup Endpoints
+# ============================================================
+
+@router.post("/sessions/{session_id}/backup", status_code=status.HTTP_200_OK)
+async def create_recovery_backup(session_id: str):
+    """
+    Create a silent recovery backup for crash protection.
+
+    Called automatically by frontend every 30 seconds while user has session open.
+    User never sees these backups unless browser crashes and recovery is needed.
+    """
+    success = session_store.save_recovery_backup(session_id)
+    if not success:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Session {session_id} not found or backup failed"
+        )
+    return {"message": "Recovery backup created"}
+
+
+@router.delete("/sessions/{session_id}/backup", status_code=status.HTTP_200_OK)
+async def clear_recovery_backup(session_id: str):
+    """
+    Clear recovery backups for a session.
+
+    Called after user successfully saves their session to their computer.
+    This prevents recovery prompt from appearing unnecessarily.
+    """
+    deleted = session_store.clear_recovery_backup(session_id)
+    return {"message": "Recovery backup cleared", "deleted": deleted}
+
+
+@router.get("/recovery/list")
+async def list_recovery_sessions():
+    """
+    List available recovery sessions.
+
+    Called on app startup to check if there are any unsaved sessions
+    that can be recovered from a crash.
+    """
+    recoveries = session_store.list_recovery_sessions()
+    return {"recoveries": recoveries, "count": len(recoveries)}
+
+
+@router.post("/recovery/{session_id}/restore", response_model=SessionCreateResponse, status_code=status.HTTP_201_CREATED)
+async def restore_recovery_session(session_id: str):
+    """
+    Restore a session from recovery backup.
+
+    Called when user clicks "Restore" on the recovery banner.
+    Creates a new active session from the recovery file.
+    """
+    new_session_id = session_store.restore_from_recovery(session_id)
+
+    if new_session_id is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"No recovery data found for session {session_id}"
+        )
+
+    session_info = session_store.get_info(new_session_id)
+    return SessionCreateResponse(
+        session_id=new_session_id,
+        created_at=session_info["created_at"]
+    )
+
+
+@router.delete("/recovery/cleanup")
+async def cleanup_old_recoveries(max_age_hours: int = 24):
+    """
+    Clean up old recovery files.
+
+    Deletes recovery files older than specified hours.
+    Can be called manually or via scheduled task.
+    """
+    session_store.cleanup_old_recoveries(max_age_hours)
+    return {"message": f"Cleaned up recovery files older than {max_age_hours} hours"}
```
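The recovery endpoints are built around a heartbeat: the frontend posts a silent backup every 30 seconds and checks `/recovery/list` on startup. The same flow from a script, as a minimal sketch assuming the sessions router is mounted under `/api` (an assumption; see `api/main.py` for the real prefix):

```python
# Hypothetical crash-protection heartbeat; API_BASE is assumed.
import time
import requests

API_BASE = "http://localhost:8000/api"  # assumed mount point
SID = "my-session-id"

# On startup: list sessions that can be restored after a crash.
for rec in requests.get(f"{API_BASE}/recovery/list").json()["recoveries"]:
    print("recoverable session:", rec)

# While the session is open: silent backup every 30 seconds (bounded here).
for _ in range(3):
    requests.post(f"{API_BASE}/sessions/{SID}/backup").raise_for_status()
    time.sleep(30)
```

The session-locking hunk continues below.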
```diff
+
+
+# ============================================================
+# Session Locking Endpoints
+# ============================================================
+
+@router.post("/sessions/{session_id}/lock", response_model=SessionLockResponse)
+async def lock_session(
+    session_id: str,
+    request: SessionLockRequest
+):
+    """
+    Lock a session for external programmatic control.
+
+    When locked, the web UI should enter monitor-only mode.
+    Returns a lock_token that must be used to unlock.
+    """
+    try:
+        result = session_store.lock_session(
+            session_id=session_id,
+            locked_by=request.locked_by,
+            client_id=request.client_id
+        )
+
+        # Broadcast lock event to WebSocket clients
+        await broadcast_to_session(session_id, {
+            "event": "lock_status_changed",
+            "locked": True,
+            "locked_by": request.locked_by,
+            "locked_at": result["locked_at"]
+        })
+
+        return SessionLockResponse(**result)
+    except KeyError:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Session {session_id} not found or expired"
+        )
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to lock session: {str(e)}"
+        )
+
+
+@router.delete("/sessions/{session_id}/lock", response_model=SessionLockResponse)
+async def unlock_session(
+    session_id: str,
+    lock_token: Optional[str] = None
+):
+    """
+    Unlock a session.
+
+    Optionally provide lock_token for verification.
+    If no token provided, forcibly unlocks (use with caution).
+    """
+    try:
+        result = session_store.unlock_session(session_id=session_id, lock_token=lock_token)
+
+        # Broadcast unlock event to WebSocket clients
+        await broadcast_to_session(session_id, {
+            "event": "lock_status_changed",
+            "locked": False,
+            "locked_by": None,
+            "locked_at": None
+        })
+
+        return SessionLockResponse(**result)
+    except KeyError:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Session {session_id} not found or expired"
+        )
+    except ValueError as e:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=str(e)
+        )
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to unlock session: {str(e)}"
+        )
+
+
+@router.get("/sessions/{session_id}/lock", response_model=SessionLockResponse)
+async def get_lock_status(session_id: str):
+    """
+    Get current lock status of a session.
+
+    Used by web UI to detect when external controller has taken control
+    and automatically enter monitor mode.
+    """
+    try:
+        result = session_store.get_lock_status(session_id=session_id)
+        return SessionLockResponse(**result)
+    except KeyError:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Session {session_id} not found or expired"
+        )
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to get lock status: {str(e)}"
+        )
```
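An external controller would lock the session before driving it, putting the web UI into monitor-only mode, and release the lock with the returned token when finished. A minimal sketch, assuming the same `/api` prefix and that the response carries the token in a `lock_token` field (field names beyond those shown in the diff are assumptions):

```python
# Hypothetical external-controller lock flow; API_BASE is assumed.
import requests

API_BASE = "http://localhost:8000/api"  # assumed mount point
SID = "my-session-id"

resp = requests.post(
    f"{API_BASE}/sessions/{SID}/lock",
    json={"locked_by": "robot-controller", "client_id": "rig-01"},
)
resp.raise_for_status()
lock_token = resp.json().get("lock_token")  # token required for a verified unlock

try:
    pass  # ... drive the optimization loop programmatically ...
finally:
    # DELETE with the token releases the lock; omitting it force-unlocks.
    requests.delete(f"{API_BASE}/sessions/{SID}/lock",
                    params={"lock_token": lock_token}).raise_for_status()
```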