workbench 0.8.205__py3-none-any.whl → 0.8.213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. workbench/algorithms/models/noise_model.py +388 -0
  2. workbench/api/endpoint.py +3 -6
  3. workbench/api/feature_set.py +1 -1
  4. workbench/api/model.py +5 -11
  5. workbench/cached/cached_model.py +4 -4
  6. workbench/core/artifacts/endpoint_core.py +63 -153
  7. workbench/core/artifacts/model_core.py +21 -19
  8. workbench/core/transforms/features_to_model/features_to_model.py +2 -2
  9. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +1 -1
  10. workbench/model_script_utils/model_script_utils.py +335 -0
  11. workbench/model_script_utils/pytorch_utils.py +395 -0
  12. workbench/model_script_utils/uq_harness.py +278 -0
  13. workbench/model_scripts/chemprop/chemprop.template +289 -666
  14. workbench/model_scripts/chemprop/generated_model_script.py +292 -669
  15. workbench/model_scripts/chemprop/model_script_utils.py +335 -0
  16. workbench/model_scripts/chemprop/requirements.txt +2 -10
  17. workbench/model_scripts/pytorch_model/generated_model_script.py +355 -612
  18. workbench/model_scripts/pytorch_model/model_script_utils.py +335 -0
  19. workbench/model_scripts/pytorch_model/pytorch.template +350 -607
  20. workbench/model_scripts/pytorch_model/pytorch_utils.py +395 -0
  21. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  22. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  23. workbench/model_scripts/script_generation.py +2 -5
  24. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  25. workbench/model_scripts/xgb_model/generated_model_script.py +349 -412
  26. workbench/model_scripts/xgb_model/model_script_utils.py +335 -0
  27. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  28. workbench/model_scripts/xgb_model/xgb_model.template +344 -407
  29. workbench/scripts/training_test.py +85 -0
  30. workbench/utils/chemprop_utils.py +18 -656
  31. workbench/utils/metrics_utils.py +172 -0
  32. workbench/utils/model_utils.py +104 -47
  33. workbench/utils/pytorch_utils.py +32 -472
  34. workbench/utils/xgboost_local_crossfold.py +267 -0
  35. workbench/utils/xgboost_model_utils.py +49 -356
  36. workbench/web_interface/components/plugins/model_details.py +30 -68
  37. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/METADATA +5 -5
  38. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/RECORD +42 -31
  39. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/entry_points.txt +1 -0
  40. workbench/model_scripts/uq_models/mapie.template +0 -605
  41. workbench/model_scripts/uq_models/requirements.txt +0 -1
  42. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/WHEEL +0 -0
  43. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/licenses/LICENSE +0 -0
  44. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/top_level.txt +0 -0
workbench/utils/metrics_utils.py
@@ -0,0 +1,172 @@
+ """Metrics utilities for computing model performance from predictions."""
+
+ import logging
+ from typing import List, Optional
+
+ import numpy as np
+ import pandas as pd
+ from scipy.stats import spearmanr
+ from sklearn.metrics import (
+     mean_absolute_error,
+     median_absolute_error,
+     precision_recall_fscore_support,
+     r2_score,
+     roc_auc_score,
+     root_mean_squared_error,
+ )
+
+ log = logging.getLogger("workbench")
+
+
+ def compute_classification_metrics(
+     predictions_df: pd.DataFrame,
+     target_col: str,
+     class_labels: List[str],
+     prediction_col: str = "prediction",
+ ) -> pd.DataFrame:
+     """Compute classification metrics from a predictions DataFrame.
+
+     Args:
+         predictions_df: DataFrame with target and prediction columns
+         target_col: Name of the target column
+         class_labels: List of class labels in order
+         prediction_col: Name of the prediction column (default: "prediction")
+
+     Returns:
+         DataFrame with per-class metrics (precision, recall, f1, roc_auc, support)
+         plus a weighted 'all' row
+     """
+     y_true = predictions_df[target_col]
+     y_pred = predictions_df[prediction_col]
+
+     # Precision, recall, f1, support per class
+     prec, rec, f1, support = precision_recall_fscore_support(y_true, y_pred, labels=class_labels, zero_division=0)
+
+     # ROC AUC per class (requires probability columns and sorted labels)
+     proba_cols = [f"{label}_proba" for label in class_labels]
+     if all(col in predictions_df.columns for col in proba_cols):
+         # roc_auc_score requires labels to be sorted, so we sort and reorder results back
+         sorted_labels = sorted(class_labels)
+         sorted_proba_cols = [f"{label}_proba" for label in sorted_labels]
+         y_score_sorted = predictions_df[sorted_proba_cols].values
+         roc_auc_sorted = roc_auc_score(y_true, y_score_sorted, labels=sorted_labels, multi_class="ovr", average=None)
+         # Map back to original class_labels order
+         label_to_auc = dict(zip(sorted_labels, roc_auc_sorted))
+         roc_auc = np.array([label_to_auc[label] for label in class_labels])
+     else:
+         roc_auc = np.array([None] * len(class_labels))
+
+     # Build per-class metrics
+     metrics_df = pd.DataFrame(
+         {
+             target_col: class_labels,
+             "precision": prec,
+             "recall": rec,
+             "f1": f1,
+             "roc_auc": roc_auc,
+             "support": support.astype(int),
+         }
+     )
+
+     # Add weighted 'all' row
+     total = support.sum()
+     all_row = {
+         target_col: "all",
+         "precision": (prec * support).sum() / total,
+         "recall": (rec * support).sum() / total,
+         "f1": (f1 * support).sum() / total,
+         "roc_auc": (roc_auc * support).sum() / total if roc_auc[0] is not None else None,
+         "support": int(total),
+     }
+     metrics_df = pd.concat([metrics_df, pd.DataFrame([all_row])], ignore_index=True)
+
+     return metrics_df
+
+
+ def compute_regression_metrics(
+     predictions_df: pd.DataFrame,
+     target_col: str,
+     prediction_col: str = "prediction",
+ ) -> pd.DataFrame:
+     """Compute regression metrics from a predictions DataFrame.
+
+     Args:
+         predictions_df: DataFrame with target and prediction columns
+         target_col: Name of the target column
+         prediction_col: Name of the prediction column (default: "prediction")
+
+     Returns:
+         DataFrame with regression metrics (rmse, mae, medae, r2, spearmanr, support)
+     """
+     y_true = predictions_df[target_col].values
+     y_pred = predictions_df[prediction_col].values
+
+     return pd.DataFrame(
+         [
+             {
+                 "rmse": root_mean_squared_error(y_true, y_pred),
+                 "mae": mean_absolute_error(y_true, y_pred),
+                 "medae": median_absolute_error(y_true, y_pred),
+                 "r2": r2_score(y_true, y_pred),
+                 "spearmanr": spearmanr(y_true, y_pred).correlation,
+                 "support": len(y_true),
+             }
+         ]
+     )
+
+
+ def compute_metrics_from_predictions(
+     predictions_df: pd.DataFrame,
+     target_col: str,
+     class_labels: Optional[List[str]] = None,
+     prediction_col: str = "prediction",
+ ) -> pd.DataFrame:
+     """Compute metrics from a predictions DataFrame.
+
+     Automatically determines if this is classification or regression based on
+     whether class_labels is provided.
+
+     Args:
+         predictions_df: DataFrame with target and prediction columns
+         target_col: Name of the target column
+         class_labels: List of class labels for classification (None for regression)
+         prediction_col: Name of the prediction column (default: "prediction")
+
+     Returns:
+         DataFrame with computed metrics
+     """
+     if target_col not in predictions_df.columns:
+         raise ValueError(f"Target column '{target_col}' not found in predictions DataFrame")
+     if prediction_col not in predictions_df.columns:
+         raise ValueError(f"Prediction column '{prediction_col}' not found in predictions DataFrame")
+
+     if class_labels:
+         return compute_classification_metrics(predictions_df, target_col, class_labels, prediction_col)
+     else:
+         return compute_regression_metrics(predictions_df, target_col, prediction_col)
+
+
+ if __name__ == "__main__":
+     # Test with sample data
+     print("Testing classification metrics...")
+     class_df = pd.DataFrame(
+         {
+             "target": ["a", "b", "c", "a", "b", "c", "a", "b", "c", "a"],
+             "prediction": ["a", "b", "c", "a", "b", "a", "a", "b", "c", "b"],
+             "a_proba": [0.8, 0.1, 0.1, 0.7, 0.2, 0.4, 0.9, 0.1, 0.1, 0.3],
+             "b_proba": [0.1, 0.8, 0.1, 0.2, 0.7, 0.3, 0.05, 0.8, 0.2, 0.6],
+             "c_proba": [0.1, 0.1, 0.8, 0.1, 0.1, 0.3, 0.05, 0.1, 0.7, 0.1],
+         }
+     )
+     metrics = compute_metrics_from_predictions(class_df, "target", ["a", "b", "c"])
+     print(metrics.to_string(index=False))
+
+     print("\nTesting regression metrics...")
+     reg_df = pd.DataFrame(
+         {
+             "target": [1.0, 2.0, 3.0, 4.0, 5.0],
+             "prediction": [1.1, 2.2, 2.9, 4.1, 4.8],
+         }
+     )
+     metrics = compute_metrics_from_predictions(reg_df, "target")
+     print(metrics.to_string(index=False))
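
The new module is self-contained, so the quickest way to exercise it on real predictions is to feed it an endpoint's inference frame. A minimal sketch, assuming a deployed regression endpoint; the "aqsol-regression" name and "solubility" target column are illustrative, borrowed from examples elsewhere in this diff:

    from workbench.api import Endpoint
    from workbench.utils.metrics_utils import compute_metrics_from_predictions

    # Hypothetical endpoint; any predictions DataFrame with the target column,
    # a "prediction" column, and optional "<label>_proba" columns works the same way
    end = Endpoint("aqsol-regression")
    df = end.auto_inference(capture=True)
    print(compute_metrics_from_predictions(df, target_col="solubility"))
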
workbench/utils/model_utils.py
@@ -125,8 +125,40 @@ def proximity_model_local(model: "Model"):
      return Proximity(full_df, id_column, features, target, track_columns=features)


- def proximity_model(model: "Model", prox_model_name: str, track_columns: list = None) -> "Model":
-     """Create a proximity model based on the given model
+ def noise_model_local(model: "Model"):
+     """Create a NoiseModel for detecting noisy/problematic samples in a Model's training data.
+
+     Args:
+         model (Model): The Model used to create the noise model
+
+     Returns:
+         NoiseModel: The noise model with precomputed noise scores for all samples
+     """
+     from workbench.algorithms.models.noise_model import NoiseModel  # noqa: F401 (avoid circular import)
+     from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)
+
+     # Get Feature and Target Columns from the existing given Model
+     features = model.features()
+     target = model.target()
+
+     # Backtrack our FeatureSet to get the ID column
+     fs = FeatureSet(model.get_input())
+     id_column = fs.id_column
+
+     # Create the NoiseModel from both the full FeatureSet and the Model training data
+     full_df = fs.pull_dataframe()
+     model_df = model.training_view().pull_dataframe()
+
+     # Mark rows that are in the model
+     model_ids = set(model_df[id_column])
+     full_df["in_model"] = full_df[id_column].isin(model_ids)
+
+     # Create and return the NoiseModel
+     return NoiseModel(full_df, id_column, features, target)
+
+
+ def published_proximity_model(model: "Model", prox_model_name: str, track_columns: list = None) -> "Model":
+     """Create a published proximity model based on the given model

      Args:
          model (Model): The model to create the proximity model from
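
The new noise_model_local mirrors the proximity_model_local helper directly above it. A minimal usage sketch, assuming the function lives in workbench.utils.model_utils alongside the other helpers in this hunk:

    from workbench.api import Model
    from workbench.utils.model_utils import noise_model_local  # assumed import path

    m = Model("aqsol-regression")
    noise = noise_model_local(m)  # NoiseModel with precomputed noise scores for every sample
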
@@ -159,38 +191,6 @@ def proximity_model(model: "Model", prox_model_name: str, track_columns: list =
      return prox_model


- def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -> "Model":
-     """Create a Uncertainty Quantification (UQ) model based on the given model
-
-     Args:
-         model (Model): The model to create the UQ model from
-         uq_model_name (str): The name of the UQ model to create
-         train_all_data (bool, optional): Whether to train the UQ model on all data (default: False)
-
-     Returns:
-         Model: The UQ model
-     """
-     from workbench.api import Model, ModelType, FeatureSet  # noqa: F401 (avoid circular import)
-
-     # Get Feature and Target Columns from the existing given Model
-     features = model.features()
-     target = model.target()
-
-     # Create the Proximity Model from our FeatureSet
-     fs = FeatureSet(model.get_input())
-     uq_model = fs.to_model(
-         name=uq_model_name,
-         model_type=ModelType.UQ_REGRESSOR,
-         feature_list=features,
-         target_column=target,
-         description=f"UQ Model for {model.name}",
-         tags=["uq", model.name],
-         train_all_data=train_all_data,
-         custom_args={"id_column": fs.id_column, "track_columns": [target]},
-     )
-     return uq_model
-
-
  def safe_extract_tarfile(tar_path: str, extract_path: str) -> None:
      """
      Extract a tarball safely, using data filter if available.
@@ -239,6 +239,63 @@ def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
      return category_mappings


+ def load_hyperparameters_from_s3(model_artifact_uri: str) -> Optional[dict]:
+     """
+     Download and extract hyperparameters from a model artifact in S3.
+
+     Args:
+         model_artifact_uri (str): S3 URI of the model artifact (model.tar.gz).
+
+     Returns:
+         dict: The loaded hyperparameters or None if not found.
+     """
+     hyperparameters = None
+
+     with tempfile.TemporaryDirectory() as tmpdir:
+         # Download model artifact
+         local_tar_path = os.path.join(tmpdir, "model.tar.gz")
+         wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)
+
+         # Extract tarball
+         safe_extract_tarfile(local_tar_path, tmpdir)
+
+         # Look for hyperparameters in base directory only
+         hyperparameters_path = os.path.join(tmpdir, "hyperparameters.json")
+
+         if os.path.exists(hyperparameters_path):
+             try:
+                 with open(hyperparameters_path, "r") as f:
+                     hyperparameters = json.load(f)
+                 log.info(f"Loaded hyperparameters from {hyperparameters_path}")
+             except Exception as e:
+                 log.warning(f"Failed to load hyperparameters from {hyperparameters_path}: {e}")
+
+     return hyperparameters
+
+
+ def get_model_hyperparameters(workbench_model: Any) -> Optional[dict]:
+     """Get the hyperparameters used to train a Workbench model.
+
+     This retrieves the hyperparameters.json file from the model artifacts
+     that was saved during model training.
+
+     Args:
+         workbench_model: Workbench model object
+
+     Returns:
+         dict: The hyperparameters used during training, or None if not found
+     """
+     # Get the model artifact URI
+     model_artifact_uri = workbench_model.model_data_url()
+
+     if model_artifact_uri is None:
+         log.warning(f"No model artifact found for {workbench_model.uuid}")
+         return None
+
+     log.info(f"Loading hyperparameters from {model_artifact_uri}")
+     return load_hyperparameters_from_s3(model_artifact_uri)
+
+
  def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
      """
      Evaluate uncertainty quantification model with essential metrics.
@@ -259,6 +316,13 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
      if "prediction" not in df.columns:
          raise ValueError("Prediction column 'prediction' not found in DataFrame.")

+     # Drop rows with NaN predictions (e.g., from models that can't handle missing features)
+     n_total = len(df)
+     df = df.dropna(subset=["prediction", target_col])
+     n_valid = len(df)
+     if n_valid < n_total:
+         log.info(f"UQ metrics: dropped {n_total - n_valid} rows with NaN predictions")
+
      # --- Coverage and Interval Width ---
      if "q_025" in df.columns and "q_975" in df.columns:
          lower_95, upper_95 = df["q_025"], df["q_975"]
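
The effect of the new guard is that unscoreable rows are excluded before any coverage or interval-width math. A standalone sketch of the same dropna pattern, with made-up values:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({
        "solubility": [1.0, 2.0, 3.0],
        "prediction": [1.1, np.nan, 2.9],  # one row the model could not score
    })
    df = df.dropna(subset=["prediction", "solubility"])  # mirrors the guard in uq_metrics
    assert len(df) == 2  # the NaN row no longer skews the metrics
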
@@ -350,7 +414,7 @@

  if __name__ == "__main__":
      """Exercise the Model Utilities"""
-     from workbench.api import Model, Endpoint
+     from workbench.api import Model

      # Get the instance information
      print(model_instance_info())
@@ -365,18 +429,11 @@ if __name__ == "__main__":
      # Get the custom script path
      print(get_custom_script_path("chem_info", "molecular_descriptors.py"))

-     # Test the proximity model
+     # Test loading hyperparameters
      m = Model("aqsol-regression")
+     hyperparams = get_model_hyperparameters(m)
+     print(hyperparams)
+
+     # Test the proximity model
      # prox_model = proximity_model(m, "aqsol-prox")
      # print(prox_model)#
-
-     # Test the UQ model
-     # uq_model_instance = uq_model(m, "aqsol-uq")
-     # print(uq_model_instance)
-     # uq_model_instance.to_endpoint()
-
-     # Test the uq_metrics function
-     end = Endpoint("aqsol-uq")
-     df = end.auto_inference(capture=True)
-     results = uq_metrics(df, target_col="solubility")
-     print(results)