alchemist-nrel 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alchemist_core/__init__.py +2 -2
- alchemist_core/acquisition/botorch_acquisition.py +84 -126
- alchemist_core/data/experiment_manager.py +196 -20
- alchemist_core/models/botorch_model.py +292 -63
- alchemist_core/models/sklearn_model.py +175 -15
- alchemist_core/session.py +3532 -76
- alchemist_core/utils/__init__.py +3 -1
- alchemist_core/utils/acquisition_utils.py +60 -0
- alchemist_core/visualization/__init__.py +45 -0
- alchemist_core/visualization/helpers.py +130 -0
- alchemist_core/visualization/plots.py +1449 -0
- alchemist_nrel-0.3.2.dist-info/METADATA +185 -0
- {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/RECORD +34 -29
- {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/WHEEL +1 -1
- {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/entry_points.txt +1 -1
- {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/top_level.txt +0 -1
- api/example_client.py +7 -2
- api/main.py +3 -2
- api/models/requests.py +76 -1
- api/models/responses.py +102 -2
- api/routers/acquisition.py +25 -0
- api/routers/experiments.py +352 -11
- api/routers/sessions.py +195 -11
- api/routers/visualizations.py +6 -4
- api/routers/websocket.py +132 -0
- run_api.py → api/run_api.py +8 -7
- api/services/session_store.py +370 -71
- api/static/assets/index-B6Cf6s_b.css +1 -0
- api/static/assets/{index-C0_glioA.js → index-B7njvc9r.js} +223 -208
- api/static/index.html +2 -2
- ui/gpr_panel.py +11 -5
- ui/target_column_dialog.py +299 -0
- ui/ui.py +52 -5
- alchemist_core/models/ax_model.py +0 -159
- alchemist_nrel-0.3.0.dist-info/METADATA +0 -223
- api/static/assets/index-CB4V1LI5.css +0 -1
- {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/licenses/LICENSE +0 -0
api/routers/visualizations.py
CHANGED
|
@@ -194,9 +194,10 @@ async def get_contour_data(
|
|
|
194
194
|
# CRITICAL FIX: Reorder columns to match training data
|
|
195
195
|
# The model was trained with a specific column order, we must match it.
|
|
196
196
|
# Exclude metadata columns that are part of the experiments table but
|
|
197
|
-
# are not model input features (e.g., Iteration, Reason,
|
|
197
|
+
# are not model input features (e.g., Iteration, Reason, target columns, Noise).
|
|
198
198
|
train_data = session.experiment_manager.get_data()
|
|
199
|
-
|
|
199
|
+
target_cols = set(session.experiment_manager.target_columns)
|
|
200
|
+
metadata_cols = {'Iteration', 'Reason', 'Noise'} | target_cols
|
|
200
201
|
feature_cols = [col for col in train_data.columns if col not in metadata_cols]
|
|
201
202
|
|
|
202
203
|
# Safely align the prediction grid to the model feature order.
|
|
@@ -216,11 +217,12 @@ async def get_contour_data(
|
|
|
216
217
|
experiments_data = None
|
|
217
218
|
if request.include_experiments and len(session.experiment_manager) > 0:
|
|
218
219
|
exp_df = session.experiment_manager.get_data()
|
|
219
|
-
|
|
220
|
+
target_col = session.experiment_manager.target_columns[0] # Use first target for visualization
|
|
221
|
+
if request.x_var in exp_df.columns and request.y_var in exp_df.columns and target_col in exp_df.columns:
|
|
220
222
|
experiments_data = {
|
|
221
223
|
"x": exp_df[request.x_var].tolist(),
|
|
222
224
|
"y": exp_df[request.y_var].tolist(),
|
|
223
|
-
"output": exp_df[
|
|
225
|
+
"output": exp_df[target_col].tolist()
|
|
224
226
|
}
|
|
225
227
|
|
|
226
228
|
# Get suggestion data if requested (would need to be stored in session)
|
api/routers/websocket.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""
|
|
2
|
+
WebSocket router for real-time session updates.
|
|
3
|
+
|
|
4
|
+
Provides real-time push notifications for session events like lock status changes,
|
|
5
|
+
eliminating the need for client-side polling.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
9
|
+
from typing import Dict, Set
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
router = APIRouter()
|
|
16
|
+
|
|
17
|
+
# Track active WebSocket connections per session
|
|
18
|
+
# Structure: {session_id: {websocket1, websocket2, ...}}
|
|
19
|
+
active_connections: Dict[str, Set[WebSocket]] = {}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@router.websocket("/ws/sessions/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """
    Serve one WebSocket client subscribed to a session's events.

    The socket is accepted, added to the module-level ``active_connections``
    registry under ``session_id``, and sent a one-off "connected" message.
    The handler then blocks reading text frames until the client disconnects.
    Incoming frames are only parsed and logged; nothing is sent back.

    Args:
        websocket: WebSocket connection
        session_id: Session ID to subscribe to
    """
    await websocket.accept()
    logger.debug(f"WebSocket connected: session_id={session_id}")

    # Create the per-session set on first use and register this socket in it.
    active_connections.setdefault(session_id, set()).add(websocket)

    greeting = {
        "event": "connected",
        "session_id": session_id,
        "message": "WebSocket connection established"
    }

    try:
        # Tell the client the subscription is live.
        await websocket.send_json(greeting)

        # Reading in a loop keeps the handler (and thus the socket) alive.
        while True:
            raw = await websocket.receive_text()

            # Client messages are not acted on; they are logged for debugging.
            try:
                payload = json.loads(raw)
            except json.JSONDecodeError:
                logger.warning(f"Invalid JSON from client: {raw}")
            else:
                logger.debug(f"Received from client: {payload}")

    except WebSocketDisconnect:
        logger.debug(f"WebSocket disconnected: session_id={session_id}")
    finally:
        # Always unregister, whichever way the handler exits.
        peers = active_connections.get(session_id)
        if peers is not None:
            peers.discard(websocket)

            # Drop the session key once its last socket is gone.
            if not peers:
                del active_connections[session_id]
                logger.debug(f"No more connections for session {session_id}")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
async def broadcast_to_session(session_id: str, event: dict):
    """
    Broadcast an event to all WebSocket clients connected to a session.

    Connections whose ``send_json`` raises are treated as dead and removed
    from ``active_connections`` afterwards; the session entry itself is
    removed once it has no connections left.

    Args:
        session_id: Session ID to broadcast to
        event: Event data to send (will be JSON serialized)
    """
    connections = active_connections.get(session_id)
    if connections is None:
        logger.debug(f"No active connections for session {session_id}")
        return

    # Track dead connections
    dead_connections = set()

    # Iterate over a snapshot, not the live set: every ``await`` below yields
    # to the event loop, where websocket_endpoint may add or discard sockets
    # for this session. Mutating a set while it is being iterated raises
    # "RuntimeError: Set changed size during iteration".
    for connection in tuple(connections):
        try:
            await connection.send_json(event)
            logger.debug(f"Broadcast to session {session_id}: {event.get('event')}")
        except Exception as e:
            # Best effort: one broken client must not stop the broadcast.
            logger.warning(f"Failed to send to connection: {e}")
            dead_connections.add(connection)

    if not dead_connections:
        return

    # Re-read the registry instead of indexing it: the session entry may have
    # been deleted (or recreated) by a disconnect handler while we awaited.
    remaining = active_connections.get(session_id)
    if remaining is not None:
        remaining.difference_update(dead_connections)
    logger.info(f"Cleaned up {len(dead_connections)} dead connections")

    # Remove session if no connections left
    if remaining is not None and not remaining:
        del active_connections[session_id]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def get_connection_count(session_id: str) -> int:
    """
    Report how many WebSocket connections are registered for a session.

    Args:
        session_id: Session ID to check

    Returns:
        Size of the session's entry in ``active_connections``, or 0 when the
        session has no entry.
    """
    sockets = active_connections.get(session_id)
    if sockets is None:
        return 0
    return len(sockets)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def get_all_connection_counts() -> Dict[str, int]:
    """
    Summarise the connection registry as per-session counts.

    Returns:
        A new dictionary mapping each session_id currently present in
        ``active_connections`` to the number of connections stored for it.
    """
    counts = {}
    for sid, sockets in active_connections.items():
        counts[sid] = len(sockets)
    return counts
|
run_api.py → api/run_api.py
RENAMED
|
@@ -2,10 +2,10 @@
|
|
|
2
2
|
Startup script for ALchemist FastAPI server.
|
|
3
3
|
|
|
4
4
|
Usage:
|
|
5
|
-
python run_api
|
|
6
|
-
python run_api
|
|
7
|
-
python run_api
|
|
8
|
-
alchemist-web
|
|
5
|
+
python -m api.run_api # Development mode with auto-reload
|
|
6
|
+
python -m api.run_api --production # Production mode (no reload)
|
|
7
|
+
python -m api.run_api --dev # Explicitly start in development mode
|
|
8
|
+
alchemist-web # Entry point (production mode by default)
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
11
|
def main():
|
|
@@ -15,10 +15,10 @@ def main():
|
|
|
15
15
|
|
|
16
16
|
# For the alchemist-web entry point, default to production mode
|
|
17
17
|
# Only use dev mode if explicitly requested
|
|
18
|
-
is_script_call = any(arg.endswith('run_api.py') for arg in sys.argv)
|
|
18
|
+
is_script_call = any(arg.endswith('run_api.py') or 'api.run_api' in arg for arg in sys.argv)
|
|
19
19
|
|
|
20
20
|
if is_script_call:
|
|
21
|
-
# Called as: python run_api.py
|
|
21
|
+
# Called as: python -m api.run_api
|
|
22
22
|
# Default to dev mode unless --production flag is present
|
|
23
23
|
production = "--production" in sys.argv or "--prod" in sys.argv
|
|
24
24
|
else:
|
|
@@ -47,7 +47,8 @@ def main():
|
|
|
47
47
|
host="0.0.0.0",
|
|
48
48
|
port=8000,
|
|
49
49
|
reload=True,
|
|
50
|
-
log_level="
|
|
50
|
+
log_level="warning", # Suppress INFO logs from polling
|
|
51
|
+
access_log=False # Disable access logs entirely
|
|
51
52
|
)
|
|
52
53
|
|
|
53
54
|
|