alchemist-nrel 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. alchemist_core/__init__.py +2 -2
  2. alchemist_core/acquisition/botorch_acquisition.py +84 -126
  3. alchemist_core/data/experiment_manager.py +196 -20
  4. alchemist_core/models/botorch_model.py +292 -63
  5. alchemist_core/models/sklearn_model.py +175 -15
  6. alchemist_core/session.py +3532 -76
  7. alchemist_core/utils/__init__.py +3 -1
  8. alchemist_core/utils/acquisition_utils.py +60 -0
  9. alchemist_core/visualization/__init__.py +45 -0
  10. alchemist_core/visualization/helpers.py +130 -0
  11. alchemist_core/visualization/plots.py +1449 -0
  12. alchemist_nrel-0.3.2.dist-info/METADATA +185 -0
  13. {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/RECORD +34 -29
  14. {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/WHEEL +1 -1
  15. {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/entry_points.txt +1 -1
  16. {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/top_level.txt +0 -1
  17. api/example_client.py +7 -2
  18. api/main.py +3 -2
  19. api/models/requests.py +76 -1
  20. api/models/responses.py +102 -2
  21. api/routers/acquisition.py +25 -0
  22. api/routers/experiments.py +352 -11
  23. api/routers/sessions.py +195 -11
  24. api/routers/visualizations.py +6 -4
  25. api/routers/websocket.py +132 -0
  26. run_api.py → api/run_api.py +8 -7
  27. api/services/session_store.py +370 -71
  28. api/static/assets/index-B6Cf6s_b.css +1 -0
  29. api/static/assets/{index-C0_glioA.js → index-B7njvc9r.js} +223 -208
  30. api/static/index.html +2 -2
  31. ui/gpr_panel.py +11 -5
  32. ui/target_column_dialog.py +299 -0
  33. ui/ui.py +52 -5
  34. alchemist_core/models/ax_model.py +0 -159
  35. alchemist_nrel-0.3.0.dist-info/METADATA +0 -223
  36. api/static/assets/index-CB4V1LI5.css +0 -1
  37. {alchemist_nrel-0.3.0.dist-info → alchemist_nrel-0.3.2.dist-info}/licenses/LICENSE +0 -0
api/static/index.html CHANGED
@@ -5,8 +5,8 @@
5
5
  <link rel="icon" type="image/svg+xml" href="/NEW_ICON.png" />
6
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
7
  <title>ALchemist - Active Learning Toolkit</title>
8
- <script type="module" crossorigin src="/assets/index-C0_glioA.js"></script>
9
- <link rel="stylesheet" crossorigin href="/assets/index-CB4V1LI5.css">
8
+ <script type="module" crossorigin src="/assets/index-B7njvc9r.js"></script>
9
+ <link rel="stylesheet" crossorigin href="/assets/index-B6Cf6s_b.css">
10
10
  </head>
11
11
  <body>
12
12
  <div id="root"></div>
ui/gpr_panel.py CHANGED
@@ -505,14 +505,17 @@ class GaussianProcessPanel(ctk.CTkFrame):
505
505
  print(f" RMSE = {session_metrics.get('rmse', 'N/A'):.3f}")
506
506
  print("Learned hyperparameters:", self.main_app.learned_hyperparameters)
507
507
 
508
- # Initialize visualizations
508
+ # Initialize visualization with model results
509
+ # Get target column name from experiment manager
510
+ target_col = self.main_app.experiment_manager.target_columns[0]
511
+
509
512
  self.visualizations = Visualizations(
510
513
  parent=self,
511
514
  search_space=self.main_app.search_space,
512
515
  gpr_model=self.main_app.gpr_model,
513
516
  exp_df=self.main_app.exp_df,
514
- encoded_X=self.main_app.exp_df.drop(columns='Output'),
515
- encoded_y=self.main_app.exp_df['Output']
517
+ encoded_X=self.main_app.exp_df.drop(columns=target_col),
518
+ encoded_y=self.main_app.exp_df[target_col]
516
519
  )
517
520
  self.visualizations.rmse_values = self.main_app.rmse_values
518
521
  self.visualizations.mae_values = self.main_app.mae_values
@@ -532,13 +535,16 @@ class GaussianProcessPanel(ctk.CTkFrame):
532
535
  # VISUALIZATIONS
533
536
  # ==========================
534
537
  def initialize_visualizations(self):
538
+ # Get target column name from experiment manager
539
+ target_col = self.main_app.experiment_manager.target_columns[0]
540
+
535
541
  self.visualizations = Visualizations(
536
542
  parent=self,
537
543
  search_space=self.main_app.search_space,
538
544
  gpr_model=self.main_app.gpr_model,
539
545
  exp_df=self.main_app.exp_df,
540
- encoded_X=self.main_app.exp_df.drop(columns='Output'),
541
- encoded_y=self.main_app.exp_df['Output']
546
+ encoded_X=self.main_app.exp_df.drop(columns=target_col),
547
+ encoded_y=self.main_app.exp_df[target_col]
542
548
  )
543
549
  self.visualizations.rmse_values = self.main_app.rmse_values
544
550
  self.visualizations.mae_values = self.main_app.mae_values
@@ -0,0 +1,299 @@
1
+ """
2
+ Target Column Selection Dialog
3
+
4
+ Allows users to select which column(s) in their CSV should be treated as optimization targets.
5
+ Supports both single-objective and multi-objective optimization.
6
+ """
7
+
8
import tkinter as tk
from typing import List, Optional, Tuple, Union

import customtkinter as ctk
11
+
12
+
13
class TargetColumnDialog(ctk.CTkToplevel):
    """
    Modal dialog for selecting target columns when loading experimental data.

    Features:
    - Single/Multi-objective mode toggle
    - Column selection (dropdown for single, checkboxes for multi)
    - Validation before confirming (multi-objective requires at least 2 columns)

    After the window closes, ``get_result()`` returns the chosen column name
    (single-objective), a list of chosen column names (multi-objective), or
    ``None`` if the user cancelled or closed the window.
    """

    def __init__(self, parent, available_columns: List[str],
                 default_column: Optional[str] = None):
        """
        Initialize the target column selection dialog.

        Args:
            parent: Parent window.
            available_columns: Column names available in the CSV; must be non-empty.
            default_column: Column pre-selected in both modes, if it is present in
                ``available_columns``.

        Raises:
            ValueError: If ``available_columns`` is empty.
        """
        # Validate before creating the Toplevel so a bad call does not leave an
        # orphaned, grab-holding window behind.
        if not available_columns:
            raise ValueError("available_columns must contain at least one column name.")

        super().__init__(parent)

        self.title("Select Target Column(s)")
        self.geometry("500x400")
        self.resizable(False, False)

        # Store data (copied so later mutation by the caller cannot affect the UI)
        self.available_columns = list(available_columns)
        self.default_column = default_column if default_column in self.available_columns else None
        self.result = None  # Will store selected column(s) when confirmed

        # UI state
        self.mode = "single"  # "single" or "multi"
        # Selection variables are created once and shared by both mode views, so a
        # choice made in one mode survives switching to the other mode and back.
        self.column_var = ctk.StringVar(value=self.default_column or self.available_columns[0])
        self.checkbox_vars = {
            col: ctk.BooleanVar(value=(col == self.default_column))
            for col in self.available_columns
        }

        # Closing via the window manager's close button behaves like Cancel.
        self.protocol("WM_DELETE_WINDOW", self._on_cancel)

        self._create_ui()

        # Center the dialog over the parent (clamped so it never starts off-screen)
        self.update_idletasks()
        x = parent.winfo_x() + (parent.winfo_width() // 2) - (self.winfo_width() // 2)
        y = parent.winfo_y() + (parent.winfo_height() // 2) - (self.winfo_height() // 2)
        self.geometry(f"+{max(x, 0)}+{max(y, 0)}")

        # Make dialog modal (grab taken last, once the window has been laid out)
        self.transient(parent)
        self._grab_modal()

    def _grab_modal(self, retries: int = 20):
        """
        Take the input grab so this dialog is modal.

        ``grab_set()`` raises TclError ("window not viewable") when called before
        the window manager has mapped the Toplevel; in that case retry shortly,
        up to ``retries`` more times.
        """
        try:
            self.grab_set()
        except tk.TclError:
            if retries > 0 and self.winfo_exists():
                self.after(50, lambda: self._grab_modal(retries - 1))

    def _create_ui(self):
        """Create the dialog UI elements (header, mode toggle, selection area, buttons)."""
        # Header
        header_frame = ctk.CTkFrame(self, fg_color="transparent")
        header_frame.pack(fill="x", padx=20, pady=(20, 10))

        ctk.CTkLabel(
            header_frame,
            text="Select Target Column(s)",
            font=ctk.CTkFont(size=16, weight="bold")
        ).pack(anchor="w")

        ctk.CTkLabel(
            header_frame,
            text="Choose which column(s) to optimize:",
            font=ctk.CTkFont(size=12),
            text_color="gray"
        ).pack(anchor="w", pady=(5, 0))

        # Mode selector (Single vs Multi-objective)
        mode_frame = ctk.CTkFrame(self)
        mode_frame.pack(fill="x", padx=20, pady=10)

        ctk.CTkLabel(
            mode_frame,
            text="Optimization Mode:",
            font=ctk.CTkFont(size=12, weight="bold")
        ).pack(side="left", padx=(10, 20))

        self.mode_var = ctk.StringVar(value="single")

        self.single_radio = ctk.CTkRadioButton(
            mode_frame,
            text="Single-Objective",
            variable=self.mode_var,
            value="single",
            command=self._on_mode_change
        )
        self.single_radio.pack(side="left", padx=10)

        self.multi_radio = ctk.CTkRadioButton(
            mode_frame,
            text="Multi-Objective",
            variable=self.mode_var,
            value="multi",
            command=self._on_mode_change
        )
        self.multi_radio.pack(side="left", padx=10)

        # Column selection area (content changes based on mode)
        self.selection_frame = ctk.CTkFrame(self)
        self.selection_frame.pack(fill="both", expand=True, padx=20, pady=10)

        self._update_selection_ui()

        # Buttons
        button_frame = ctk.CTkFrame(self, fg_color="transparent")
        button_frame.pack(fill="x", padx=20, pady=(10, 20))

        ctk.CTkButton(
            button_frame,
            text="Cancel",
            command=self._on_cancel,
            width=100
        ).pack(side="right", padx=(10, 0))

        ctk.CTkButton(
            button_frame,
            text="Confirm",
            command=self._on_confirm,
            width=100
        ).pack(side="right")

    def _on_mode_change(self):
        """Handle mode change between single and multi-objective."""
        self.mode = self.mode_var.get()
        self._update_selection_ui()

    def _update_selection_ui(self):
        """Rebuild the column selection widgets for the current mode."""
        # Clear existing widgets (the selection variables themselves are kept)
        for widget in self.selection_frame.winfo_children():
            widget.destroy()

        if self.mode == "single":
            self._create_single_objective_ui()
        else:
            self._create_multi_objective_ui()

    def _create_single_objective_ui(self):
        """Create UI for single-objective mode (dropdown bound to ``self.column_var``)."""
        ctk.CTkLabel(
            self.selection_frame,
            text="Select target column:",
            font=ctk.CTkFont(size=12)
        ).pack(anchor="w", padx=20, pady=(20, 10))

        # Dropdown menu
        self.column_dropdown = ctk.CTkOptionMenu(
            self.selection_frame,
            variable=self.column_var,
            values=self.available_columns,
            width=400
        )
        self.column_dropdown.pack(padx=20, pady=10)

        # Info text
        info_frame = ctk.CTkFrame(self.selection_frame, fg_color="transparent")
        info_frame.pack(fill="x", padx=20, pady=(20, 10))

        ctk.CTkLabel(
            info_frame,
            text="💡 Tip: This column will be maximized or minimized during optimization.",
            font=ctk.CTkFont(size=11),
            text_color="gray",
            wraplength=400,
            justify="left"
        ).pack(anchor="w")

    def _create_multi_objective_ui(self):
        """Create UI for multi-objective mode (checkboxes bound to ``self.checkbox_vars``)."""
        ctk.CTkLabel(
            self.selection_frame,
            text="Select target columns (2 or more):",
            font=ctk.CTkFont(size=12)
        ).pack(anchor="w", padx=20, pady=(20, 10))

        # Scrollable frame for checkboxes
        checkbox_frame = ctk.CTkScrollableFrame(
            self.selection_frame,
            height=150
        )
        checkbox_frame.pack(fill="both", expand=True, padx=20, pady=10)

        # One checkbox per column, reusing the persistent BooleanVars
        for col in self.available_columns:
            ctk.CTkCheckBox(
                checkbox_frame,
                text=col,
                variable=self.checkbox_vars[col]
            ).pack(anchor="w", pady=5, padx=10)

        # Info text
        info_frame = ctk.CTkFrame(self.selection_frame, fg_color="transparent")
        info_frame.pack(fill="x", padx=20, pady=(10, 10))

        ctk.CTkLabel(
            info_frame,
            text="💡 Tip: Multi-objective optimization finds trade-offs between objectives.",
            font=ctk.CTkFont(size=11),
            text_color="gray",
            wraplength=400,
            justify="left"
        ).pack(anchor="w")

    def _on_confirm(self):
        """Validate the current selection; store it and close, or show an error."""
        if self.mode == "single":
            # Single-objective: result is the selected column name
            selected = self.column_var.get()
            if selected:
                self.result = selected
                self.destroy()
            return

        # Multi-objective: result is the list of checked columns (at least 2)
        selected = [col for col, var in self.checkbox_vars.items() if var.get()]
        if len(selected) < 2:
            self._show_selection_error()
            return

        self.result = selected
        self.destroy()

    def _show_selection_error(self):
        """Show a warning popup that multi-objective mode needs at least 2 columns."""
        error_dialog = ctk.CTkToplevel(self)
        error_dialog.title("Invalid Selection")
        error_dialog.geometry("350x150")
        error_dialog.transient(self)

        def close_error():
            error_dialog.destroy()
            # Tk releases a grab when its window is destroyed and does not hand it
            # back to the previous holder, so re-take it to keep this dialog modal.
            self._grab_modal()

        error_dialog.protocol("WM_DELETE_WINDOW", close_error)

        ctk.CTkLabel(
            error_dialog,
            text="⚠️ Multi-Objective Mode",
            font=ctk.CTkFont(size=14, weight="bold")
        ).pack(pady=(20, 10))

        ctk.CTkLabel(
            error_dialog,
            text="Please select at least 2 target columns\nfor multi-objective optimization.",
            font=ctk.CTkFont(size=12)
        ).pack(pady=10)

        ctk.CTkButton(
            error_dialog,
            text="OK",
            command=close_error,
            width=100
        ).pack(pady=10)

        # Center error dialog over this dialog (clamped to the screen origin)
        error_dialog.update_idletasks()
        x = self.winfo_x() + (self.winfo_width() // 2) - (error_dialog.winfo_width() // 2)
        y = self.winfo_y() + (self.winfo_height() // 2) - (error_dialog.winfo_height() // 2)
        error_dialog.geometry(f"+{max(x, 0)}+{max(y, 0)}")

        try:
            error_dialog.grab_set()
        except tk.TclError:
            # Window not yet viewable: the warning is still shown, just not modal.
            pass

    def _on_cancel(self):
        """Handle Cancel (button or window close): clear the result and close."""
        self.result = None
        self.destroy()

    def get_result(self) -> Optional[Union[str, List[str]]]:
        """
        Get the user's selection.

        Returns:
            String for single-objective, list for multi-objective, or None if cancelled.
        """
        return self.result
282
+
283
+
284
def show_target_column_dialog(parent, available_columns: List[str],
                              default_column: Optional[str] = None) -> Optional[Union[str, List[str]]]:
    """
    Show the target column selection dialog modally and return the user's choice.

    Blocks (via ``parent.wait_window``) until the dialog window is closed.

    Args:
        parent: Parent window.
        available_columns: List of column names available in the CSV; must be non-empty.
        default_column: Default column to select (if it exists in ``available_columns``).

    Returns:
        The selected column name (single-objective), a list of column names
        (multi-objective), or None if the dialog was cancelled or closed.

    Raises:
        ValueError: If ``available_columns`` is empty.
    """
    # Fail fast: with no columns the dialog cannot offer any valid selection.
    if not available_columns:
        raise ValueError("available_columns must contain at least one column name.")

    dialog = TargetColumnDialog(parent, available_columns, default_column)
    parent.wait_window(dialog)
    return dialog.get_result()
ui/ui.py CHANGED
@@ -519,6 +519,46 @@ class ALchemistApp(ctk.CTk):
519
519
 
520
520
  if file_path:
521
521
  try:
522
+ # First, read the CSV to check for target column
523
+ import pandas as pd
524
+ preview_df = pd.read_csv(file_path)
525
+
526
+ # Check if any configured target column exists
527
+ # Default to looking for 'Output' if no target_columns configured
528
+ expected_targets = getattr(self.experiment_manager, 'target_columns', ['Output'])
529
+ missing_targets = [col for col in expected_targets if col not in preview_df.columns]
530
+
531
+ # If target column(s) missing, show selection dialog
532
+ target_columns_to_use = None
533
+ if missing_targets:
534
+ # Get non-metadata columns that could be targets
535
+ metadata_cols = {'Iteration', 'Reason', 'Noise'}
536
+ available_cols = [col for col in preview_df.columns if col not in metadata_cols]
537
+
538
+ if not available_cols:
539
+ raise ValueError("CSV file contains no columns that could be target columns.")
540
+
541
+ # Show target selection dialog
542
+ from ui.target_column_dialog import show_target_column_dialog
543
+ selected = show_target_column_dialog(
544
+ parent=self,
545
+ available_columns=available_cols,
546
+ default_column='output' if 'output' in available_cols else None
547
+ )
548
+
549
+ if selected is None:
550
+ # User cancelled
551
+ print("Data loading cancelled by user.")
552
+ return
553
+
554
+ target_columns_to_use = selected if isinstance(selected, list) else [selected]
555
+ print(f"User selected target column(s): {target_columns_to_use}")
556
+ else:
557
+ target_columns_to_use = expected_targets
558
+
559
+ # Configure experiment manager with selected target columns
560
+ self.experiment_manager.target_columns = target_columns_to_use
561
+
522
562
  # Load experiments using the ExperimentManager
523
563
  self.experiment_manager.load_from_csv(file_path)
524
564
 
@@ -536,6 +576,7 @@ class ALchemistApp(ctk.CTk):
536
576
 
537
577
  # Log the data loading
538
578
  print(f"Loaded {len(self.exp_df)} experiment points from {file_path}")
579
+ print(f"Target column(s): {target_columns_to_use}")
539
580
  if 'Noise' in self.exp_df.columns:
540
581
  print("Notice: Noise column detected. This will be used for model regularization if available.")
541
582
 
@@ -1695,8 +1736,9 @@ class ALchemistApp(ctk.CTk):
1695
1736
  # Ensure metadata columns have correct types
1696
1737
  exp_df_clean = self.exp_df.copy()
1697
1738
 
1698
- # Define metadata columns
1699
- metadata_cols = {'Output', 'Noise', 'Iteration', 'Reason'}
1739
+ # Define metadata columns (including configured target columns)
1740
+ target_cols = set(self.experiment_manager.target_columns) if hasattr(self.experiment_manager, 'target_columns') else {'Output'}
1741
+ metadata_cols = target_cols | {'Noise', 'Iteration', 'Reason'}
1700
1742
 
1701
1743
  # Ensure Iteration is numeric
1702
1744
  if 'Iteration' in exp_df_clean.columns:
@@ -1706,9 +1748,10 @@ class ALchemistApp(ctk.CTk):
1706
1748
  if 'Reason' in exp_df_clean.columns:
1707
1749
  exp_df_clean['Reason'] = exp_df_clean['Reason'].astype(str).replace('nan', 'Manual')
1708
1750
 
1709
- # Ensure Output is numeric
1710
- if 'Output' in exp_df_clean.columns:
1711
- exp_df_clean['Output'] = pd.to_numeric(exp_df_clean['Output'], errors='coerce')
1751
+ # Ensure target columns are numeric
1752
+ for target_col in target_cols:
1753
+ if target_col in exp_df_clean.columns:
1754
+ exp_df_clean[target_col] = pd.to_numeric(exp_df_clean[target_col], errors='coerce')
1712
1755
 
1713
1756
  # Ensure Noise is numeric if present
1714
1757
  if 'Noise' in exp_df_clean.columns:
@@ -1745,6 +1788,10 @@ class ALchemistApp(ctk.CTk):
1745
1788
  # Copy cleaned data to session's experiment manager
1746
1789
  self.session.experiment_manager.df = exp_df_clean
1747
1790
 
1791
+ # Copy target_columns configuration to session's experiment manager
1792
+ if hasattr(self.experiment_manager, 'target_columns'):
1793
+ self.session.experiment_manager.target_columns = self.experiment_manager.target_columns
1794
+
1748
1795
  # Update local exp_df with cleaned version
1749
1796
  self.exp_df = exp_df_clean
1750
1797
 
@@ -1,159 +0,0 @@
1
- from ax.service.ax_client import AxClient
2
- from .base_model import BaseModel
3
- import pandas as pd
4
- import numpy as np
5
- from skopt.space import Real, Integer, Categorical
6
-
7
class AxModel(BaseModel):
    """
    Model wrapper that replays past experiments into an Ax ``AxClient`` and asks
    it for new candidate points.
    """

    def __init__(self, search_space, experiment_name="experiment", random_state=42,
                 target_column="Output"):
        """
        Initialize the AxModel.

        Args:
            search_space: A list of skopt.space objects (Real, Integer, or Categorical).
            experiment_name: A name for the Ax experiment.
            random_state: Random seed for reproducibility.
            target_column: Name of the DataFrame column holding the observed
                objective value (defaults to the historical "Output").
        """
        self.experiment_name = experiment_name
        self.search_space = search_space
        self.random_state = random_state
        self.target_column = target_column
        self.ax_client = AxClient(random_seed=random_state)
        self.trained = False

    @property
    def is_trained(self):
        """Whether ``train`` has completed; alias of ``self.trained``."""
        return self.trained

    @is_trained.setter
    def is_trained(self, value):
        self.trained = bool(value)

    def _build_parameters(self):
        """
        Build the Ax parameters list from the search_space.

        Raises:
            ValueError: If a dimension is not Real, Integer, or Categorical.
        """
        parameters = []
        for dim in self.search_space:
            if isinstance(dim, Real):
                # For Real dimensions, use a continuous range.
                parameters.append({
                    "name": dim.name,
                    "type": "range",
                    "bounds": list(dim.bounds),
                    "value_type": "float",
                })
            elif isinstance(dim, Integer):
                # For Integer dimensions, use a range and specify value type as int.
                parameters.append({
                    "name": dim.name,
                    "type": "range",
                    "bounds": list(dim.bounds),
                    "value_type": "int",
                })
            elif isinstance(dim, Categorical):
                # For categorical dimensions, use "choice" and list the categories.
                # Here we assume that the categories are strings; if numeric, adjust "value_type" accordingly.
                parameters.append({
                    "name": dim.name,
                    "type": "choice",
                    "values": list(dim.categories),
                    "value_type": "str",
                })
            else:
                raise ValueError(f"Unsupported search space dimension type: {type(dim)}")
        return parameters

    @staticmethod
    def _coerce_value(dim, value):
        """Convert a DataFrame cell to the native Python type declared for ``dim``."""
        if isinstance(dim, Integer):
            return int(value)
        if isinstance(dim, Real):
            return float(value)
        # Categorical: unwrap numpy scalars, otherwise pass the value through.
        return value.item() if isinstance(value, np.generic) else value

    def train(self, exp_df, **kwargs):
        """
        Train the Ax model using the raw experiment DataFrame.

        Each row becomes a completed Ax trial: its search-space columns are attached
        as the trial's parameters and its ``target_column`` value is reported as the
        "objective" metric. Columns outside the search space are ignored.

        Args:
            exp_df: DataFrame with one column per search-space dimension plus
                ``target_column``.
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            KeyError: If ``target_column`` or a search-space column is missing.
        """
        param_names = [dim.name for dim in self.search_space]
        X = exp_df[param_names]
        y = exp_df[self.target_column]

        # Use a fresh client on every call: AxClient will not create a second
        # experiment on the same instance, so re-training would otherwise fail.
        self.ax_client = AxClient(random_seed=self.random_state)
        self.ax_client.create_experiment(
            name=self.experiment_name,
            parameters=self._build_parameters(),
        )
        # zip pairs rows positionally, independent of the DataFrame's index labels.
        for record, outcome in zip(X.to_dict(orient="records"), y):
            params = {
                dim.name: self._coerce_value(dim, record[dim.name])
                for dim in self.search_space
            }
            # A trial must be attached before it can be completed.
            _, trial_index = self.ax_client.attach_trial(parameters=params)
            self.ax_client.complete_trial(
                trial_index=trial_index,
                raw_data={"objective": float(outcome)},
            )
        self.trained = True

    def predict(self, X, return_std=False, **kwargs):
        """
        For Ax, prediction means asking for the next candidate.

        Args:
            X: Not used (the next candidate is computed based on the experiment history).
            return_std: Not applicable; always returns just the candidate.

        Returns:
            A dictionary with parameter names and suggested values.

        Raises:
            ValueError: If the model has not been trained.
        """
        if not self.trained:
            raise ValueError("The Ax experiment has not been trained with past data yet.")
        # get_next_trial also registers the suggested trial with the Ax experiment.
        parameters, _ = self.ax_client.get_next_trial()
        return parameters

    def predict_with_std(self, X):
        """
        Make predictions with standard deviation.

        Args:
            X: Input features (DataFrame or array)

        Returns:
            Tuple of (predictions, standard deviations)

        Raises:
            ValueError: If the model is untrained, no surrogate is available, or
                ``X`` cannot be converted to a DataFrame.
        """
        if not self.is_trained:
            raise ValueError("Model is not trained yet")

        # Nothing in this class assigns ``self.surrogate``; fail with a clear
        # message instead of an AttributeError.
        surrogate = getattr(self, "surrogate", None)
        if surrogate is None:
            raise ValueError("No surrogate model is available for predictions with std")

        # Convert to DataFrame if needed
        if not isinstance(X, pd.DataFrame):
            if getattr(self, "feature_names", None):
                X = pd.DataFrame(X, columns=self.feature_names)
            else:
                raise ValueError("Cannot convert input to DataFrame - feature names unknown")

        # Prepare the observations in Ax format (one parameter dict per row)
        obs = X.to_dict(orient="records")

        # Get the predictions
        means, covariances = surrogate.predict(obs)

        # Extract standard deviations from covariances
        stds = np.sqrt(np.diag(covariances))

        return means, stds

    def evaluate(self, X, y, **kwargs):
        """
        Evaluate the Ax model's performance using stored outcomes.
        In a more complete implementation, you could compute metrics such as RMSE across trials.

        Returns:
            A dictionary with evaluation metrics (here empty as a placeholder).
        """
        return {}

    def get_hyperparameters(self):
        """
        Get model hyperparameters.

        Returns:
            A dictionary with hyperparameter names and values, a status entry if the
            model is untrained, or an "error" entry if introspection fails.
        """
        if not self.is_trained:
            return {"status": "Model not trained"}

        try:
            params = {}
            # For Ax models, we can extract some basic info
            if hasattr(self, 'surrogate') and hasattr(self.surrogate, 'model'):
                model_type = type(self.surrogate.model).__name__
                params['model_type'] = model_type

                # Try to get some GPEI-specific attributes if available
                if hasattr(self.surrogate.model, 'model'):
                    inner_model = self.surrogate.model.model
                    if hasattr(inner_model, 'covar_module'):
                        params['covar_module'] = str(inner_model.covar_module)

            return params
        except Exception as e:
            # Best-effort introspection: report the failure rather than raise.
            return {"error": str(e)}