alchemist-nrel 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. alchemist_core/__init__.py +2 -2
  2. alchemist_core/acquisition/botorch_acquisition.py +84 -126
  3. alchemist_core/data/experiment_manager.py +196 -20
  4. alchemist_core/models/botorch_model.py +292 -63
  5. alchemist_core/models/sklearn_model.py +175 -15
  6. alchemist_core/session.py +3532 -76
  7. alchemist_core/utils/__init__.py +3 -1
  8. alchemist_core/utils/acquisition_utils.py +60 -0
  9. alchemist_core/visualization/__init__.py +45 -0
  10. alchemist_core/visualization/helpers.py +130 -0
  11. alchemist_core/visualization/plots.py +1449 -0
  12. alchemist_nrel-0.3.2.dist-info/METADATA +185 -0
  13. {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/RECORD +34 -29
  14. {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/WHEEL +1 -1
  15. {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/entry_points.txt +1 -1
  16. {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/top_level.txt +0 -1
  17. api/example_client.py +7 -2
  18. api/main.py +3 -2
  19. api/models/requests.py +76 -1
  20. api/models/responses.py +102 -2
  21. api/routers/acquisition.py +25 -0
  22. api/routers/experiments.py +352 -11
  23. api/routers/sessions.py +195 -11
  24. api/routers/visualizations.py +6 -4
  25. api/routers/websocket.py +132 -0
  26. run_api.py → api/run_api.py +8 -7
  27. api/services/session_store.py +370 -71
  28. api/static/assets/index-B6Cf6s_b.css +1 -0
  29. api/static/assets/{index-C0_glioA.js → index-B7njvc9r.js} +223 -208
  30. api/static/index.html +2 -2
  31. ui/gpr_panel.py +11 -5
  32. ui/target_column_dialog.py +299 -0
  33. ui/ui.py +52 -5
  34. alchemist_core/models/ax_model.py +0 -159
  35. alchemist_nrel-0.3.0.dist-info/METADATA +0 -223
  36. api/static/assets/index-CB4V1LI5.css +0 -1
  37. {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -194,9 +194,10 @@ async def get_contour_data(
194
194
  # CRITICAL FIX: Reorder columns to match training data
195
195
  # The model was trained with a specific column order, we must match it.
196
196
  # Exclude metadata columns that are part of the experiments table but
197
- # are not model input features (e.g., Iteration, Reason, Output, Noise).
197
+ # are not model input features (e.g., Iteration, Reason, target columns, Noise).
198
198
  train_data = session.experiment_manager.get_data()
199
- metadata_cols = {'Iteration', 'Reason', 'Output', 'Noise'}
199
+ target_cols = set(session.experiment_manager.target_columns)
200
+ metadata_cols = {'Iteration', 'Reason', 'Noise'} | target_cols
200
201
  feature_cols = [col for col in train_data.columns if col not in metadata_cols]
201
202
 
202
203
  # Safely align the prediction grid to the model feature order.
@@ -216,11 +217,12 @@ async def get_contour_data(
216
217
  experiments_data = None
217
218
  if request.include_experiments and len(session.experiment_manager) > 0:
218
219
  exp_df = session.experiment_manager.get_data()
219
- if request.x_var in exp_df.columns and request.y_var in exp_df.columns and "Output" in exp_df.columns:
220
+ target_col = session.experiment_manager.target_columns[0] # Use first target for visualization
221
+ if request.x_var in exp_df.columns and request.y_var in exp_df.columns and target_col in exp_df.columns:
220
222
  experiments_data = {
221
223
  "x": exp_df[request.x_var].tolist(),
222
224
  "y": exp_df[request.y_var].tolist(),
223
- "output": exp_df["Output"].tolist()
225
+ "output": exp_df[target_col].tolist()
224
226
  }
225
227
 
226
228
  # Get suggestion data if requested (would need to be stored in session)
@@ -0,0 +1,132 @@
1
+ """
2
+ WebSocket router for real-time session updates.
3
+
4
+ Provides real-time push notifications for session events like lock status changes,
5
+ eliminating the need for client-side polling.
6
+ """
7
+
8
+ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
9
+ from typing import Dict, Set
10
+ import json
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ router = APIRouter()
16
+
17
+ # Track active WebSocket connections per session
18
+ # Structure: {session_id: {websocket1, websocket2, ...}}
19
+ active_connections: Dict[str, Set[WebSocket]] = {}
20
+
21
+
22
@router.websocket("/ws/sessions/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """
    Serve one WebSocket client subscribed to a session's events.

    The connection is accepted, recorded in ``active_connections`` under
    ``session_id``, and sent a one-off ``"connected"`` event. After that the
    handler only reads incoming text frames and logs them; nothing is sent
    back in response to client messages.

    Args:
        websocket: WebSocket connection
        session_id: Session ID to subscribe to
    """
    await websocket.accept()
    logger.debug(f"WebSocket connected: session_id={session_id}")

    # Record the connection, creating the session's set on first use.
    active_connections.setdefault(session_id, set()).add(websocket)

    try:
        # Tell the client the subscription is live.
        await websocket.send_json({
            "event": "connected",
            "session_id": session_id,
            "message": "WebSocket connection established"
        })

        # Block on client frames; this keeps the handler (and therefore the
        # registration above) alive until the client goes away.
        while True:
            data = await websocket.receive_text()

            # Incoming messages are only logged, never answered.
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                logger.warning(f"Invalid JSON from client: {data}")
            else:
                logger.debug(f"Received from client: {message}")

    except WebSocketDisconnect:
        logger.debug(f"WebSocket disconnected: session_id={session_id}")
    finally:
        # Runs on every exit path. The session entry may already be gone,
        # so look it up defensively before touching it.
        session_connections = active_connections.get(session_id)
        if session_connections is not None:
            session_connections.discard(websocket)

            # Drop the session key once its last connection has left.
            if not session_connections:
                del active_connections[session_id]
                logger.debug(f"No more connections for session {session_id}")
73
+
74
+
75
async def broadcast_to_session(session_id: str, event: dict):
    """
    Broadcast an event to all WebSocket clients connected to a session.

    Connections whose send raises are treated as dead and removed from
    ``active_connections``; the session entry itself is removed once it has
    no connections left.

    Args:
        session_id: Session ID to broadcast to
        event: Event data to send (will be JSON serialized)
    """
    connections = active_connections.get(session_id)
    if connections is None:
        logger.debug(f"No active connections for session {session_id}")
        return

    # Track dead connections
    dead_connections = set()

    # Iterate over a snapshot: every send is an await point, and another
    # coroutine may add or remove connections for this session meanwhile.
    # Iterating the live set would raise "Set changed size during iteration".
    for connection in tuple(connections):
        try:
            await connection.send_json(event)
            logger.debug(f"Broadcast to session {session_id}: {event.get('event')}")
        except Exception as e:
            # Best effort: one failing client must not stop the broadcast.
            logger.warning(f"Failed to send to connection: {e}")
            dead_connections.add(connection)

    if not dead_connections:
        return

    # Re-read the registry instead of indexing it: the session entry may
    # have been removed while the sends above were awaiting.
    remaining = active_connections.get(session_id)
    if remaining is None:
        return

    remaining -= dead_connections
    logger.info(f"Cleaned up {len(dead_connections)} dead connections")

    # Remove session if no connections left
    if not remaining:
        del active_connections[session_id]
107
+
108
+
109
def get_connection_count(session_id: str) -> int:
    """
    Report how many WebSocket connections are registered for a session.

    Args:
        session_id: Session ID to check

    Returns:
        Number of active connections; 0 when the session has no entry in
        ``active_connections``
    """
    session_connections = active_connections.get(session_id)
    if session_connections is None:
        return 0
    return len(session_connections)
120
+
121
+
122
def get_all_connection_counts() -> Dict[str, int]:
    """
    Summarize the connection registry as per-session counts.

    Returns:
        Dictionary mapping each session_id present in ``active_connections``
        to its number of registered connections
    """
    counts: Dict[str, int] = {}
    for session_id, connections in active_connections.items():
        counts[session_id] = len(connections)
    return counts
@@ -2,10 +2,10 @@
2
2
  Startup script for ALchemist FastAPI server.
3
3
 
4
4
  Usage:
5
- python run_api.py # Development mode with auto-reload
6
- python run_api.py --production # Production mode (no reload)
7
- python run_api.py --dev # Explicitly start in development mode
8
- alchemist-web # Entry point (production mode by default)
5
+ python -m api.run_api # Development mode with auto-reload
6
+ python -m api.run_api --production # Production mode (no reload)
7
+ python -m api.run_api --dev # Explicitly start in development mode
8
+ alchemist-web # Entry point (production mode by default)
9
9
  """
10
10
 
11
11
  def main():
@@ -15,10 +15,10 @@ def main():
15
15
 
16
16
  # For the alchemist-web entry point, default to production mode
17
17
  # Only use dev mode if explicitly requested
18
- is_script_call = any(arg.endswith('run_api.py') for arg in sys.argv)
18
+ is_script_call = any(arg.endswith('run_api.py') or 'api.run_api' in arg for arg in sys.argv)
19
19
 
20
20
  if is_script_call:
21
- # Called as: python run_api.py
21
+ # Called as: python -m api.run_api
22
22
  # Default to dev mode unless --production flag is present
23
23
  production = "--production" in sys.argv or "--prod" in sys.argv
24
24
  else:
@@ -47,7 +47,8 @@ def main():
47
47
  host="0.0.0.0",
48
48
  port=8000,
49
49
  reload=True,
50
- log_level="info"
50
+ log_level="warning", # Suppress INFO logs from polling
51
+ access_log=False # Disable access logs entirely
51
52
  )
52
53
 
53
54