workbench 0.8.202__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release.

This version of workbench might be problematic.

Files changed (84)
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  3. workbench/algorithms/dataframe/fingerprint_proximity.py +421 -85
  4. workbench/algorithms/dataframe/projection_2d.py +44 -21
  5. workbench/algorithms/dataframe/proximity.py +78 -150
  6. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  7. workbench/algorithms/models/cleanlab_model.py +382 -0
  8. workbench/algorithms/models/noise_model.py +388 -0
  9. workbench/algorithms/sql/outliers.py +3 -3
  10. workbench/api/__init__.py +3 -0
  11. workbench/api/df_store.py +17 -108
  12. workbench/api/endpoint.py +13 -11
  13. workbench/api/feature_set.py +111 -8
  14. workbench/api/meta_model.py +289 -0
  15. workbench/api/model.py +45 -12
  16. workbench/api/parameter_store.py +3 -52
  17. workbench/cached/cached_model.py +4 -4
  18. workbench/core/artifacts/artifact.py +5 -5
  19. workbench/core/artifacts/df_store_core.py +114 -0
  20. workbench/core/artifacts/endpoint_core.py +228 -237
  21. workbench/core/artifacts/feature_set_core.py +185 -230
  22. workbench/core/artifacts/model_core.py +34 -26
  23. workbench/core/artifacts/parameter_store_core.py +98 -0
  24. workbench/core/pipelines/pipeline_executor.py +1 -1
  25. workbench/core/transforms/features_to_model/features_to_model.py +22 -10
  26. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +41 -10
  27. workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
  28. workbench/model_script_utils/model_script_utils.py +339 -0
  29. workbench/model_script_utils/pytorch_utils.py +405 -0
  30. workbench/model_script_utils/uq_harness.py +278 -0
  31. workbench/model_scripts/chemprop/chemprop.template +428 -631
  32. workbench/model_scripts/chemprop/generated_model_script.py +432 -635
  33. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  34. workbench/model_scripts/chemprop/requirements.txt +2 -10
  35. workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
  36. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  37. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  38. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  39. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  40. workbench/model_scripts/meta_model/meta_model.template +209 -0
  41. workbench/model_scripts/pytorch_model/generated_model_script.py +374 -613
  42. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  43. workbench/model_scripts/pytorch_model/pytorch.template +370 -609
  44. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  45. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  46. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  47. workbench/model_scripts/script_generation.py +6 -5
  48. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  49. workbench/model_scripts/xgb_model/generated_model_script.py +372 -395
  50. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  51. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  52. workbench/model_scripts/xgb_model/xgb_model.template +366 -396
  53. workbench/repl/workbench_shell.py +0 -5
  54. workbench/resources/open_source_api.key +1 -1
  55. workbench/scripts/endpoint_test.py +2 -2
  56. workbench/scripts/meta_model_sim.py +35 -0
  57. workbench/scripts/training_test.py +85 -0
  58. workbench/utils/chem_utils/fingerprints.py +87 -46
  59. workbench/utils/chem_utils/projections.py +16 -6
  60. workbench/utils/chemprop_utils.py +36 -655
  61. workbench/utils/meta_model_simulator.py +499 -0
  62. workbench/utils/metrics_utils.py +256 -0
  63. workbench/utils/model_utils.py +192 -54
  64. workbench/utils/pytorch_utils.py +33 -472
  65. workbench/utils/shap_utils.py +1 -55
  66. workbench/utils/xgboost_local_crossfold.py +267 -0
  67. workbench/utils/xgboost_model_utils.py +49 -356
  68. workbench/web_interface/components/model_plot.py +7 -1
  69. workbench/web_interface/components/plugins/model_details.py +30 -68
  70. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  71. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/METADATA +6 -5
  72. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/RECORD +76 -60
  73. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/entry_points.txt +2 -0
  74. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  75. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -296
  76. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  77. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  78. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  79. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  80. workbench/model_scripts/uq_models/mapie.template +0 -605
  81. workbench/model_scripts/uq_models/requirements.txt +0 -1
  82. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  83. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +0 -0
  84. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/utils/metrics_utils.py (new file, +256 -0)
@@ -0,0 +1,256 @@
+"""Metrics utilities for computing model performance from predictions."""
+
+import logging
+from typing import List, Optional
+
+import numpy as np
+import pandas as pd
+from scipy.stats import spearmanr
+from sklearn.metrics import (
+    mean_absolute_error,
+    median_absolute_error,
+    precision_recall_fscore_support,
+    r2_score,
+    roc_auc_score,
+    root_mean_squared_error,
+)
+
+log = logging.getLogger("workbench")
+
+
+def validate_proba_columns(predictions_df: pd.DataFrame, class_labels: List[str], guessing: bool = False) -> bool:
+    """Validate that probability columns match class labels.
+
+    Args:
+        predictions_df: DataFrame with prediction results
+        class_labels: List of class labels
+        guessing: Whether class labels were guessed from data
+
+    Returns:
+        True if validation passes
+
+    Raises:
+        ValueError: If probability columns don't match class labels
+    """
+    proba_columns = [col.replace("_proba", "") for col in predictions_df.columns if col.endswith("_proba")]
+
+    if sorted(class_labels) != sorted(proba_columns):
+        label_type = "GUESSED class_labels" if guessing else "class_labels"
+        raise ValueError(f"_proba columns {proba_columns} != {label_type} {class_labels}!")
+    return True
+
+
+def compute_classification_metrics(
+    predictions_df: pd.DataFrame,
+    target_col: str,
+    class_labels: Optional[List[str]] = None,
+    prediction_col: str = "prediction",
+) -> pd.DataFrame:
+    """Compute classification metrics from a predictions DataFrame.
+
+    Args:
+        predictions_df: DataFrame with target and prediction columns
+        target_col: Name of the target column
+        class_labels: List of class labels in order (if None, inferred from target column)
+        prediction_col: Name of the prediction column (default: "prediction")
+
+    Returns:
+        DataFrame with per-class metrics (precision, recall, f1, roc_auc, support)
+        plus a weighted 'all' row. Returns empty DataFrame if validation fails.
+    """
+    # Validate inputs
+    if predictions_df.empty:
+        log.warning("Empty DataFrame provided. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if prediction_col not in predictions_df.columns:
+        log.warning(f"Prediction column '{prediction_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if target_col not in predictions_df.columns:
+        log.warning(f"Target column '{target_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    # Handle NaN predictions
+    df = predictions_df.copy()
+    nan_pred = df[prediction_col].isnull().sum()
+    if nan_pred > 0:
+        log.warning(f"Dropping {nan_pred} rows with NaN predictions.")
+        df = df[~df[prediction_col].isnull()]
+
+    if df.empty:
+        log.warning("No valid rows after dropping NaNs. Returning empty metrics.")
+        return pd.DataFrame()
+
+    # Handle class labels
+    guessing = False
+    if class_labels is None:
+        log.warning("Class labels not provided. Inferring from target column.")
+        class_labels = df[target_col].unique().tolist()
+        guessing = True
+
+    # Validate probability columns if present
+    proba_cols = [col for col in df.columns if col.endswith("_proba")]
+    if proba_cols:
+        validate_proba_columns(df, class_labels, guessing=guessing)
+
+    y_true = df[target_col]
+    y_pred = df[prediction_col]
+
+    # Precision, recall, f1, support per class
+    prec, rec, f1, support = precision_recall_fscore_support(y_true, y_pred, labels=class_labels, zero_division=0)
+
+    # ROC AUC per class (requires probability columns and sorted labels)
+    proba_col_names = [f"{label}_proba" for label in class_labels]
+    if all(col in df.columns for col in proba_col_names):
+        # roc_auc_score requires labels to be sorted, so we sort and reorder results back
+        sorted_labels = sorted(class_labels)
+        sorted_proba_cols = [f"{label}_proba" for label in sorted_labels]
+        y_score_sorted = df[sorted_proba_cols].values
+        roc_auc_sorted = roc_auc_score(y_true, y_score_sorted, labels=sorted_labels, multi_class="ovr", average=None)
+        # Map back to original class_labels order
+        label_to_auc = dict(zip(sorted_labels, roc_auc_sorted))
+        roc_auc = np.array([label_to_auc[label] for label in class_labels])
+    else:
+        roc_auc = np.array([None] * len(class_labels))
+
+    # Build per-class metrics
+    metrics_df = pd.DataFrame(
+        {
+            target_col: class_labels,
+            "precision": prec,
+            "recall": rec,
+            "f1": f1,
+            "roc_auc": roc_auc,
+            "support": support.astype(int),
+        }
+    )
+
+    # Add weighted 'all' row
+    total = support.sum()
+    all_row = {
+        target_col: "all",
+        "precision": (prec * support).sum() / total,
+        "recall": (rec * support).sum() / total,
+        "f1": (f1 * support).sum() / total,
+        "roc_auc": (roc_auc * support).sum() / total if roc_auc[0] is not None else None,
+        "support": int(total),
+    }
+    metrics_df = pd.concat([metrics_df, pd.DataFrame([all_row])], ignore_index=True)
+
+    return metrics_df
+
+
+def compute_regression_metrics(
+    predictions_df: pd.DataFrame,
+    target_col: str,
+    prediction_col: str = "prediction",
+) -> pd.DataFrame:
+    """Compute regression metrics from a predictions DataFrame.
+
+    Args:
+        predictions_df: DataFrame with target and prediction columns
+        target_col: Name of the target column
+        prediction_col: Name of the prediction column (default: "prediction")
+
+    Returns:
+        DataFrame with regression metrics (rmse, mae, medae, r2, spearmanr, support)
+        Returns empty DataFrame if validation fails or no valid data.
+    """
+    # Validate inputs
+    if predictions_df.empty:
+        log.warning("Empty DataFrame provided. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if prediction_col not in predictions_df.columns:
+        log.warning(f"Prediction column '{prediction_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if target_col not in predictions_df.columns:
+        log.warning(f"Target column '{target_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    # Handle NaN values
+    df = predictions_df[[target_col, prediction_col]].copy()
+    nan_target = df[target_col].isnull().sum()
+    nan_pred = df[prediction_col].isnull().sum()
+    if nan_target > 0 or nan_pred > 0:
+        log.warning(f"NaNs found: {target_col}={nan_target}, {prediction_col}={nan_pred}. Dropping NaN rows.")
+        df = df.dropna()
+
+    if df.empty:
+        log.warning("No valid rows after dropping NaNs. Returning empty metrics.")
+        return pd.DataFrame()
+
+    y_true = df[target_col].values
+    y_pred = df[prediction_col].values
+
+    return pd.DataFrame(
+        [
+            {
+                "rmse": root_mean_squared_error(y_true, y_pred),
+                "mae": mean_absolute_error(y_true, y_pred),
+                "medae": median_absolute_error(y_true, y_pred),
+                "r2": r2_score(y_true, y_pred),
+                "spearmanr": spearmanr(y_true, y_pred).correlation,
+                "support": len(y_true),
+            }
+        ]
+    )
+
+
+def compute_metrics_from_predictions(
+    predictions_df: pd.DataFrame,
+    target_col: str,
+    class_labels: Optional[List[str]] = None,
+    prediction_col: str = "prediction",
+) -> pd.DataFrame:
+    """Compute metrics from a predictions DataFrame.
+
+    Automatically determines if this is classification or regression based on
+    whether class_labels is provided.
+
+    Args:
+        predictions_df: DataFrame with target and prediction columns
+        target_col: Name of the target column
+        class_labels: List of class labels for classification (None for regression)
+        prediction_col: Name of the prediction column (default: "prediction")
+
+    Returns:
+        DataFrame with computed metrics
+    """
+    if target_col not in predictions_df.columns:
+        raise ValueError(f"Target column '{target_col}' not found in predictions DataFrame")
+    if prediction_col not in predictions_df.columns:
+        raise ValueError(f"Prediction column '{prediction_col}' not found in predictions DataFrame")

+    if class_labels:
+        return compute_classification_metrics(predictions_df, target_col, class_labels, prediction_col)
+    else:
+        return compute_regression_metrics(predictions_df, target_col, prediction_col)
+
+
+if __name__ == "__main__":
+    # Test with sample data
+    print("Testing classification metrics...")
+    class_df = pd.DataFrame(
+        {
+            "target": ["a", "b", "c", "a", "b", "c", "a", "b", "c", "a"],
+            "prediction": ["a", "b", "c", "a", "b", "a", "a", "b", "c", "b"],
+            "a_proba": [0.8, 0.1, 0.1, 0.7, 0.2, 0.4, 0.9, 0.1, 0.1, 0.3],
+            "b_proba": [0.1, 0.8, 0.1, 0.2, 0.7, 0.3, 0.05, 0.8, 0.2, 0.6],
+            "c_proba": [0.1, 0.1, 0.8, 0.1, 0.1, 0.3, 0.05, 0.1, 0.7, 0.1],
+        }
+    )
+    metrics = compute_metrics_from_predictions(class_df, "target", ["a", "b", "c"])
+    print(metrics.to_string(index=False))
+
+    print("\nTesting regression metrics...")
+    reg_df = pd.DataFrame(
+        {
+            "target": [1.0, 2.0, 3.0, 4.0, 5.0],
+            "prediction": [1.1, 2.2, 2.9, 4.1, 4.8],
+        }
+    )
+    metrics = compute_metrics_from_predictions(reg_df, "target")
+    print(metrics.to_string(index=False))
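Note: compute_metrics_from_predictions dispatches on whether class_labels is supplied, so callers never name the problem type explicitly. A minimal usage sketch against a deployed endpoint follows; the endpoint name is hypothetical, while Endpoint.auto_inference() and the "solubility" target column appear elsewhere in this diff:

    from workbench.api import Endpoint
    from workbench.utils.metrics_utils import compute_metrics_from_predictions

    # Pull a predictions DataFrame from a (hypothetical) regression endpoint
    end = Endpoint("aqsol-regression-end")
    preds = end.auto_inference(capture=True)

    # No class_labels -> regression path: rmse, mae, medae, r2, spearmanr, support
    print(compute_metrics_from_predictions(preds, target_col="solubility").to_string(index=False))

    # For classifiers, pass class_labels; matching <label>_proba columns enable per-class roc_auc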
workbench/utils/model_utils.py (+192 -54)
@@ -93,16 +93,17 @@ def get_custom_script_path(package: str, script_name: str) -> Path:
     return script_path


-def proximity_model_local(model: "Model"):
-    """Create a Proximity Model for this Model
+def proximity_model_local(model: "Model", include_all_columns: bool = False):
+    """Create a FeatureSpaceProximity Model for this Model

     Args:
         model (Model): The Model/FeatureSet used to create the proximity model
+        include_all_columns (bool): Include all DataFrame columns in neighbor results (default: False)

     Returns:
-        Proximity: The proximity model
+        FeatureSpaceProximity: The proximity model
     """
-    from workbench.algorithms.dataframe.proximity import Proximity  # noqa: F401 (avoid circular import)
+    from workbench.algorithms.dataframe.feature_space_proximity import FeatureSpaceProximity  # noqa: F401
     from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)

     # Get Feature and Target Columns from the existing given Model
@@ -121,74 +122,154 @@ def proximity_model_local(model: "Model"):
     model_ids = set(model_df[id_column])
     full_df["in_model"] = full_df[id_column].isin(model_ids)

-    # Create and return the Proximity Model
-    return Proximity(full_df, id_column, features, target, track_columns=features)
+    # Create and return the FeatureSpaceProximity Model
+    return FeatureSpaceProximity(
+        full_df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns
+    )


-def proximity_model(model: "Model", prox_model_name: str, track_columns: list = None) -> "Model":
-    """Create a proximity model based on the given model
+def fingerprint_prox_model_local(
+    model: "Model",
+    include_all_columns: bool = False,
+    radius: int = 2,
+    n_bits: int = 1024,
+    counts: bool = False,
+):
+    """Create a FingerprintProximity Model for this Model

     Args:
-        model (Model): The model to create the proximity model from
-        prox_model_name (str): The name of the proximity model to create
-        track_columns (list, optional): List of columns to track in the proximity model
+        model (Model): The Model used to create the fingerprint proximity model
+        include_all_columns (bool): Include all DataFrame columns in neighbor results (default: False)
+        radius (int): Morgan fingerprint radius (default: 2)
+        n_bits (int): Number of bits for the fingerprint (default: 1024)
+        counts (bool): Use count fingerprints instead of binary (default: False)
+
     Returns:
-        Model: The proximity model
+        FingerprintProximity: The fingerprint proximity model
     """
-    from workbench.api import Model, ModelType, FeatureSet  # noqa: F401 (avoid circular import)
+    from workbench.algorithms.dataframe.fingerprint_proximity import FingerprintProximity  # noqa: F401
+    from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)

-    # Get the custom script path for the proximity model
-    script_path = get_custom_script_path("proximity", "feature_space_proximity.template")
+    # Get Target Column from the existing given Model
+    target = model.target()
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Create the Proximity Model from both the full FeatureSet and the Model training data
+    full_df = fs.pull_dataframe()
+    model_df = model.training_view().pull_dataframe()
+
+    # Mark rows that are in the model
+    model_ids = set(model_df[id_column])
+    full_df["in_model"] = full_df[id_column].isin(model_ids)
+
+    # Create and return the FingerprintProximity Model
+    return FingerprintProximity(
+        full_df,
+        id_column=id_column,
+        target=target,
+        include_all_columns=include_all_columns,
+        radius=radius,
+        n_bits=n_bits,
+    )
+
+
+def noise_model_local(model: "Model"):
+    """Create a NoiseModel for detecting noisy/problematic samples in a Model's training data.
+
+    Args:
+        model (Model): The Model used to create the noise model
+
+    Returns:
+        NoiseModel: The noise model with precomputed noise scores for all samples
+    """
+    from workbench.algorithms.models.noise_model import NoiseModel  # noqa: F401 (avoid circular import)
+    from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)

     # Get Feature and Target Columns from the existing given Model
     features = model.features()
     target = model.target()

-    # Create the Proximity Model from our FeatureSet
+    # Backtrack our FeatureSet to get the ID column
     fs = FeatureSet(model.get_input())
-    prox_model = fs.to_model(
-        name=prox_model_name,
-        model_type=ModelType.PROXIMITY,
-        feature_list=features,
-        target_column=target,
-        description=f"Proximity Model for {model.name}",
-        tags=["proximity", model.name],
-        custom_script=script_path,
-        custom_args={"track_columns": track_columns},
-    )
-    return prox_model
+    id_column = fs.id_column
+
+    # Create the NoiseModel from both the full FeatureSet and the Model training data
+    full_df = fs.pull_dataframe()
+    model_df = model.training_view().pull_dataframe()
+
+    # Mark rows that are in the model
+    model_ids = set(model_df[id_column])
+    full_df["in_model"] = full_df[id_column].isin(model_ids)

+    # Create and return the NoiseModel
+    return NoiseModel(full_df, id_column, features, target)

-def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -> "Model":
-    """Create a Uncertainty Quantification (UQ) model based on the given model
+
+def cleanlab_model_local(model: "Model"):
+    """Create a CleanlabModels instance for detecting data quality issues in a Model's training data.

     Args:
-        model (Model): The model to create the UQ model from
-        uq_model_name (str): The name of the UQ model to create
-        train_all_data (bool, optional): Whether to train the UQ model on all data (default: False)
+        model (Model): The Model used to create the cleanlab models

     Returns:
-        Model: The UQ model
+        CleanlabModels: Factory providing access to CleanLearning and Datalab models.
+            - clean_learning(): CleanLearning model with enhanced get_label_issues()
+            - datalab(): Datalab instance with report(), get_issues()
+    """
+    from workbench.algorithms.models.cleanlab_model import create_cleanlab_model  # noqa: F401 (avoid circular import)
+    from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)
+
+    # Get Feature and Target Columns from the existing given Model
+    features = model.features()
+    target = model.target()
+    model_type = model.model_type
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Get the full FeatureSet data
+    full_df = fs.pull_dataframe()
+
+    # Create and return the CleanLearning model
+    return create_cleanlab_model(full_df, id_column, features, target, model_type=model_type)
+
+
+def published_proximity_model(model: "Model", prox_model_name: str, include_all_columns: bool = False) -> "Model":
+    """Create a published proximity model based on the given model
+
+    Args:
+        model (Model): The model to create the proximity model from
+        prox_model_name (str): The name of the proximity model to create
+        include_all_columns (bool): Include all DataFrame columns in results (default: False)
+    Returns:
+        Model: The proximity model
     """
     from workbench.api import Model, ModelType, FeatureSet  # noqa: F401 (avoid circular import)

+    # Get the custom script path for the proximity model
+    script_path = get_custom_script_path("proximity", "feature_space_proximity.template")
+
     # Get Feature and Target Columns from the existing given Model
     features = model.features()
     target = model.target()

     # Create the Proximity Model from our FeatureSet
     fs = FeatureSet(model.get_input())
-    uq_model = fs.to_model(
-        name=uq_model_name,
-        model_type=ModelType.UQ_REGRESSOR,
+    prox_model = fs.to_model(
+        name=prox_model_name,
+        model_type=ModelType.PROXIMITY,
         feature_list=features,
         target_column=target,
-        description=f"UQ Model for {model.name}",
-        tags=["uq", model.name],
-        train_all_data=train_all_data,
-        custom_args={"id_column": fs.id_column, "track_columns": [target]},
+        description=f"Proximity Model for {model.name}",
+        tags=["proximity", model.name],
+        custom_script=script_path,
+        custom_args={"include_all_columns": include_all_columns},
     )
-    return uq_model
+    return prox_model


 def safe_extract_tarfile(tar_path: str, extract_path: str) -> None:
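Net effect of the hunk above: the old proximity_model() and uq_model() publishers are removed, replaced by in-process constructors (proximity_model_local, fingerprint_prox_model_local, noise_model_local, cleanlab_model_local) plus a single publisher, published_proximity_model. A sketch of the new call pattern, assuming these helpers live in workbench/utils/model_utils.py as this diff indicates (the model and output names are illustrative, borrowed from the diff's own examples):

    from workbench.api import Model
    from workbench.utils.model_utils import proximity_model_local, published_proximity_model

    m = Model("aqsol-regression")

    # In-process FeatureSpaceProximity built over the full FeatureSet (no endpoint involved)
    prox = proximity_model_local(m, include_all_columns=True)

    # Or publish a Workbench proximity Model backed by feature_space_proximity.template
    prox_model = published_proximity_model(m, "aqsol-prox")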
@@ -239,6 +320,63 @@ def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
     return category_mappings


+def load_hyperparameters_from_s3(model_artifact_uri: str) -> Optional[dict]:
+    """
+    Download and extract hyperparameters from a model artifact in S3.
+
+    Args:
+        model_artifact_uri (str): S3 URI of the model artifact (model.tar.gz).
+
+    Returns:
+        dict: The loaded hyperparameters or None if not found.
+    """
+    hyperparameters = None
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # Download model artifact
+        local_tar_path = os.path.join(tmpdir, "model.tar.gz")
+        wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)
+
+        # Extract tarball
+        safe_extract_tarfile(local_tar_path, tmpdir)
+
+        # Look for hyperparameters in base directory only
+        hyperparameters_path = os.path.join(tmpdir, "hyperparameters.json")
+
+        if os.path.exists(hyperparameters_path):
+            try:
+                with open(hyperparameters_path, "r") as f:
+                    hyperparameters = json.load(f)
+                log.info(f"Loaded hyperparameters from {hyperparameters_path}")
+            except Exception as e:
+                log.warning(f"Failed to load hyperparameters from {hyperparameters_path}: {e}")
+
+    return hyperparameters
+
+
+def get_model_hyperparameters(workbench_model: Any) -> Optional[dict]:
+    """Get the hyperparameters used to train a Workbench model.
+
+    This retrieves the hyperparameters.json file from the model artifacts
+    that was saved during model training.
+
+    Args:
+        workbench_model: Workbench model object
+
+    Returns:
+        dict: The hyperparameters used during training, or None if not found
+    """
+    # Get the model artifact URI
+    model_artifact_uri = workbench_model.model_data_url()
+
+    if model_artifact_uri is None:
+        log.warning(f"No model artifact found for {workbench_model.uuid}")
+        return None
+
+    log.info(f"Loading hyperparameters from {model_artifact_uri}")
+    return load_hyperparameters_from_s3(model_artifact_uri)
+
+
 def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     """
     Evaluate uncertainty quantification model with essential metrics.
@@ -259,6 +397,13 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     if "prediction" not in df.columns:
         raise ValueError("Prediction column 'prediction' not found in DataFrame.")

+    # Drop rows with NaN predictions (e.g., from models that can't handle missing features)
+    n_total = len(df)
+    df = df.dropna(subset=["prediction", target_col])
+    n_valid = len(df)
+    if n_valid < n_total:
+        log.info(f"UQ metrics: dropped {n_total - n_valid} rows with NaN predictions")
+
     # --- Coverage and Interval Width ---
     if "q_025" in df.columns and "q_975" in df.columns:
         lower_95, upper_95 = df["q_025"], df["q_975"]
@@ -350,7 +495,7 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:

 if __name__ == "__main__":
     """Exercise the Model Utilities"""
-    from workbench.api import Model, Endpoint
+    from workbench.api import Model

     # Get the instance information
     print(model_instance_info())
@@ -365,18 +510,11 @@ if __name__ == "__main__":
     # Get the custom script path
     print(get_custom_script_path("chem_info", "molecular_descriptors.py"))

-    # Test the proximity model
+    # Test loading hyperparameters
     m = Model("aqsol-regression")
+    hyperparams = get_model_hyperparameters(m)
+    print(hyperparams)
+
+    # Test the proximity model
     # prox_model = proximity_model(m, "aqsol-prox")
     # print(prox_model)#
-
-    # Test the UQ model
-    # uq_model_instance = uq_model(m, "aqsol-uq")
-    # print(uq_model_instance)
-    # uq_model_instance.to_endpoint()
-
-    # Test the uq_metrics function
-    end = Endpoint("aqsol-uq")
-    df = end.auto_inference(capture=True)
-    results = uq_metrics(df, target_col="solubility")
-    print(results)
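The @@ -259,6 +397,13 @@ hunk makes uq_metrics tolerant of NaN predictions by dropping incomplete rows before computing coverage. A sketch of the call it hardens, reconstructed from the __main__ lines removed above (the endpoint name "aqsol-uq" and target "solubility" come from that removed code; the q_025/q_975 quantile columns are expected in the inference output):

    from workbench.api import Endpoint
    from workbench.utils.model_utils import uq_metrics

    # Inference output may contain NaN predictions; uq_metrics now drops those rows itself
    end = Endpoint("aqsol-uq")
    df = end.auto_inference(capture=True)

    # Coverage and interval-width metrics computed from q_025/q_975 against the target
    results = uq_metrics(df, target_col="solubility")
    print(results)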