workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,526 @@
1
+ """PyTorch Tabular utilities for Workbench models."""
2
+
3
+ # flake8: noqa: E402
4
+ import logging
5
+ import os
6
+ import tempfile
7
+ from pprint import pformat
8
+ from typing import Any, Tuple
9
+
10
+ # Disable OpenMP parallelism to avoid segfaults on macOS with conflicting OpenMP runtimes
11
+ # (libomp from LLVM vs libiomp from Intel). Must be set before importing numpy/sklearn/torch.
12
+ # See: https://github.com/scikit-learn/scikit-learn/issues/21302
13
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
14
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+ from scipy.stats import spearmanr
19
+ from sklearn.metrics import (
20
+ mean_absolute_error,
21
+ mean_squared_error,
22
+ median_absolute_error,
23
+ precision_recall_fscore_support,
24
+ r2_score,
25
+ roc_auc_score,
26
+ )
27
+ from sklearn.model_selection import KFold, StratifiedKFold
28
+ from sklearn.preprocessing import LabelEncoder
29
+
30
+ from workbench.utils.model_utils import safe_extract_tarfile
31
+ from workbench.utils.pandas_utils import expand_proba_column
32
+ from workbench.utils.aws_utils import pull_s3_data
33
+
34
+ log = logging.getLogger("workbench")
35
+
36
+
37
def download_and_extract_model(s3_uri: str, model_dir: str) -> None:
    """Download model artifact from S3 and extract it.

    The tarball is staged inside ``model_dir`` itself and is always removed
    afterwards, even when the download or the extraction fails.

    Args:
        s3_uri: S3 URI to the model artifact (model.tar.gz)
        model_dir: Directory to extract model artifacts to (created if missing)
    """
    import awswrangler as wr

    log.info(f"Downloading model from {s3_uri}...")

    # Make sure the destination exists before staging the tarball in it
    os.makedirs(model_dir, exist_ok=True)
    local_tar_path = os.path.join(model_dir, "model.tar.gz")

    try:
        wr.s3.download(path=s3_uri, local_file=local_tar_path)

        # Extract using safe extraction
        log.info(f"Extracting to {model_dir}...")
        safe_extract_tarfile(local_tar_path, model_dir)
    finally:
        # Never leave the tarball behind; it may not exist if the download failed
        if os.path.exists(local_tar_path):
            os.unlink(local_tar_path)
58
+
59
+
60
def load_pytorch_model_artifacts(model_dir: str) -> Tuple[Any, dict]:
    """Load PyTorch Tabular model and artifacts from an extracted model directory.

    Args:
        model_dir: Directory containing extracted model artifacts (may be relative)

    Returns:
        Tuple of (TabularModel, artifacts_dict).
        artifacts_dict contains 'label_encoder' and 'category_mappings' if the
        corresponding files are present in model_dir.

    Raises:
        FileNotFoundError: If model_dir has no 'tabular_model' entry
    """
    import json

    import joblib

    # pytorch-tabular saves complex objects, use legacy loading behavior.
    # Note: this is a process-wide environment change and is not reverted.
    os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
    from pytorch_tabular import TabularModel

    # Resolve to an absolute path up front: the working directory is switched
    # below, which would silently invalidate a relative model_dir.
    model_dir = os.path.abspath(model_dir)
    model_path = os.path.join(model_dir, "tabular_model")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No tabular_model directory found in {model_dir}")

    # PyTorch Tabular needs write access, so chdir to /tmp while loading
    original_cwd = os.getcwd()
    try:
        os.chdir("/tmp")
        model = TabularModel.load_model(model_path)
    finally:
        os.chdir(original_cwd)

    # Load additional artifacts (each one is optional)
    artifacts = {}

    label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
    if os.path.exists(label_encoder_path):
        # NOTE(review): joblib.load unpickles arbitrary objects; only call this
        # on model artifacts that come from a trusted source.
        artifacts["label_encoder"] = joblib.load(label_encoder_path)

    category_mappings_path = os.path.join(model_dir, "category_mappings.json")
    if os.path.exists(category_mappings_path):
        with open(category_mappings_path, encoding="utf-8") as f:
            artifacts["category_mappings"] = json.load(f)

    return model, artifacts
103
+
104
+
105
def _resolve_config_values(config: Any, defaults: dict, label: str) -> dict:
    """Read each key in `defaults` from `config`, falling back (with a warning) to the default."""
    resolved = {}
    for key, default in defaults.items():
        if hasattr(config, key):
            resolved[key] = getattr(config, key)
        else:
            log.warning(f"{label} config '{key}' not found in loaded model, using default: {default}")
            resolved[key] = default
    return resolved


def _extract_model_configs(loaded_model: Any, n_train: int) -> dict:
    """Extract trainer and model configs from a loaded PyTorch Tabular model.

    Every setting is read from ``loaded_model.config``; a setting missing there
    falls back to the default listed below and a warning is logged.

    Args:
        loaded_model: Loaded TabularModel instance
        n_train: Number of training samples (used for the default batch_size)

    Returns:
        Dictionary with 'trainer' and 'model' config dictionaries
    """
    config = loaded_model.config

    # Trainer config - extract from loaded model, matching template defaults
    trainer_defaults = {
        "auto_lr_find": False,
        # Default batch size scales with the data but is clamped to [32, 128]
        "batch_size": min(128, max(32, n_train // 16)),
        "max_epochs": 100,
        "min_epochs": 10,
        "early_stopping": "valid_loss",
        "early_stopping_patience": 10,
        "gradient_clip_val": 1.0,
    }

    # Model config - extract from loaded model, matching template defaults
    model_defaults = {
        "layers": "256-128-64",
        "activation": "LeakyReLU",
        "learning_rate": 1e-3,
        "dropout": 0.3,
        "use_batch_norm": True,
        "initialization": "kaiming",
    }

    return {
        "trainer": _resolve_config_values(config, trainer_defaults, "Trainer"),
        "model": _resolve_config_values(config, model_defaults, "Model"),
    }
153
+
154
+
155
def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Fetch the cross-validation outputs that were stored when the model was trained.

    The validation predictions are read from
    ``<model_training_path>/validation_predictions.csv`` and the training
    metrics from the 'workbench_training_metrics' entry of the model metadata.

    Args:
        workbench_model: Workbench model object

    Returns:
        Tuple of:
            - DataFrame with training metrics
            - DataFrame with validation predictions

    Raises:
        ValueError: If either the predictions or the training metrics are missing
    """
    # Validation predictions live next to the other training artifacts in S3
    predictions_uri = f"{workbench_model.model_training_path}/validation_predictions.csv"
    predictions_df = pull_s3_data(predictions_uri)
    if predictions_df is None:
        raise ValueError(f"No validation predictions found at {predictions_uri}")
    log.info(f"Pulled {len(predictions_df)} validation predictions from {predictions_uri}")

    # Training metrics are kept in the model's Workbench metadata
    meta = workbench_model.workbench_meta()
    training_metrics = meta.get("workbench_training_metrics")
    if training_metrics is None:
        raise ValueError(f"No training metrics found in model metadata for {workbench_model.model_name}")

    metrics_df = pd.DataFrame.from_dict(training_metrics)
    metrics_text = metrics_df.to_string(index=False)
    log.info(f"Metrics summary:\n{metrics_text}")

    return metrics_df, predictions_df
188
+
189
+
190
def _fold_roc_auc(y_encoded: Any, probs: Any, n_classes: int) -> Tuple[float, np.ndarray]:
    """Return (macro ROC AUC, per-class one-vs-rest ROC AUC array) for one validation fold."""
    # No probabilities available for this fold: ROC AUC is undefined
    if probs is None:
        return float("nan"), np.full(n_classes, np.nan)

    if n_classes == 2:
        # roc_auc_score rejects a 2-column score matrix for binary targets and
        # returns a scalar for average=None, so score each class one-vs-rest.
        y_arr = np.asarray(y_encoded)
        per_class = np.array([roc_auc_score(y_arr == k, probs[:, k]) for k in range(n_classes)])
        return float(per_class.mean()), per_class

    overall = roc_auc_score(y_encoded, probs, multi_class="ovr", average="macro")
    per_class = roc_auc_score(y_encoded, probs, multi_class="ovr", average=None)
    return overall, per_class


def cross_fold_inference(
    workbench_model: Any,
    nfolds: int = 5,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Performs K-fold cross-validation for PyTorch Tabular models.

    Replicates the training setup from the original model to ensure
    cross-validation results are comparable to the deployed model. The model
    artifact is downloaded into a temporary directory that is always removed
    before returning.

    Args:
        workbench_model: Workbench model object
        nfolds: Number of folds for cross-validation (default is 5, minimum 2)

    Returns:
        Tuple of:
            - DataFrame with per-class metrics (and 'all' row for overall metrics)
              for classifiers, or a single row of averaged metrics for regressors
            - DataFrame with columns: id, target, prediction, and *_proba columns (for classifiers)

    Raises:
        ValueError: If nfolds is smaller than 2
    """
    import shutil

    from pytorch_tabular import TabularModel
    from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
    from pytorch_tabular.models import CategoryEmbeddingModelConfig

    from workbench.api import FeatureSet

    # Fail fast, before any download: K-fold splitting needs at least two folds
    if nfolds < 2:
        raise ValueError(f"nfolds must be at least 2, got {nfolds}")

    # Create a temporary model directory
    model_dir = tempfile.mkdtemp(prefix="pytorch_cv_")
    log.info(f"Using model directory: {model_dir}")

    try:
        # Download and extract model artifacts to get config and artifacts
        model_artifact_uri = workbench_model.model_data_url()
        download_and_extract_model(model_artifact_uri, model_dir)

        # Load model and artifacts
        loaded_model, artifacts = load_pytorch_model_artifacts(model_dir)
        category_mappings = artifacts.get("category_mappings", {})

        # Determine if classifier from the loaded model's config
        is_classifier = loaded_model.config.task == "classification"

        # Use saved label encoder if available, otherwise create fresh one
        if is_classifier:
            label_encoder = artifacts.get("label_encoder")
            if label_encoder is None:
                log.warning("No saved label encoder found, creating fresh one")
                label_encoder = LabelEncoder()
        else:
            label_encoder = None

        # Prepare data
        fs = FeatureSet(workbench_model.get_input())
        df = workbench_model.training_view().pull_dataframe()

        # Get columns
        id_col = fs.id_column
        target_col = workbench_model.target()
        feature_cols = workbench_model.features()
        print(f"Target column: {target_col}")
        print(f"Feature columns: {len(feature_cols)} features")

        # Convert string columns to category for PyTorch Tabular compatibility
        for col in feature_cols:
            if pd.api.types.is_string_dtype(df[col]):
                if col in category_mappings:
                    df[col] = pd.Categorical(df[col], categories=category_mappings[col])
                else:
                    df[col] = df[col].astype("category")

        # Determine categorical and continuous columns
        categorical_cols = [col for col in feature_cols if df[col].dtype.name == "category"]
        continuous_cols = [col for col in feature_cols if col not in categorical_cols]

        # Cast continuous columns to float
        if continuous_cols:
            df[continuous_cols] = df[continuous_cols].astype("float64")

        # Drop rows with NaN features or target (PyTorch Tabular cannot handle NaN values)
        nan_mask = df[feature_cols].isna().any(axis=1) | df[target_col].isna()
        if nan_mask.any():
            n_nan_rows = nan_mask.sum()
            log.warning(
                f"Dropping {n_nan_rows} rows ({100*n_nan_rows/len(df):.1f}%) with NaN values for cross-validation"
            )
            df = df[~nan_mask].reset_index(drop=True)

        X = df[feature_cols]
        y = df[target_col]
        ids = df[id_col]

        # Encode target if classifier (fit only when the encoder is not already fitted)
        if label_encoder is not None:
            if not hasattr(label_encoder, "classes_"):
                label_encoder.fit(y)
            y_encoded = label_encoder.transform(y)
            y_for_cv = pd.Series(y_encoded, index=y.index, name=target_col)
        else:
            y_for_cv = y

        # Extract configs from loaded model (pass approx train size for batch_size calculation)
        n_train_approx = int(len(df) * (1 - 1 / nfolds))
        configs = _extract_model_configs(loaded_model, n_train_approx)
        trainer_params = configs["trainer"]
        model_params = configs["model"]

        log.info(f"Trainer config:\n{pformat(trainer_params)}")
        log.info(f"Model config:\n{pformat(model_params)}")

        # Prepare KFold (stratified for classifiers so every fold sees every class)
        kfold = (StratifiedKFold if is_classifier else KFold)(n_splits=nfolds, shuffle=True, random_state=42)

        # Initialize results collection
        fold_metrics = []
        predictions_df = pd.DataFrame({id_col: ids, target_col: y})
        if is_classifier:
            predictions_df["pred_proba"] = [None] * len(predictions_df)
        # Flipped to False as soon as one fold yields no probability columns
        all_folds_have_probs = True

        # Perform cross-validation
        for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y_for_cv), 1):
            print(f"\n{'='*50}")
            print(f"Fold {fold_idx}/{nfolds}")
            print(f"{'='*50}")

            # Split data
            df_train = df.iloc[train_idx].copy()
            df_val = df.iloc[val_idx].copy()

            # Encode target for this fold
            if is_classifier:
                df_train[target_col] = label_encoder.transform(df_train[target_col])
                df_val[target_col] = label_encoder.transform(df_val[target_col])

            # Create configs for this fold - matching the training template exactly
            data_config = DataConfig(
                target=[target_col],
                continuous_cols=continuous_cols,
                categorical_cols=categorical_cols,
            )

            trainer_config = TrainerConfig(
                auto_lr_find=trainer_params["auto_lr_find"],
                batch_size=trainer_params["batch_size"],
                max_epochs=trainer_params["max_epochs"],
                min_epochs=trainer_params["min_epochs"],
                early_stopping=trainer_params["early_stopping"],
                early_stopping_patience=trainer_params["early_stopping_patience"],
                gradient_clip_val=trainer_params["gradient_clip_val"],
                checkpoints="valid_loss",  # Save best model based on validation loss
                accelerator="cpu",
            )

            optimizer_config = OptimizerConfig()

            model_config = CategoryEmbeddingModelConfig(
                task="classification" if is_classifier else "regression",
                layers=model_params["layers"],
                activation=model_params["activation"],
                learning_rate=model_params["learning_rate"],
                dropout=model_params["dropout"],
                use_batch_norm=model_params["use_batch_norm"],
                initialization=model_params["initialization"],
            )

            # Create and train fresh model
            tabular_model = TabularModel(
                data_config=data_config,
                model_config=model_config,
                optimizer_config=optimizer_config,
                trainer_config=trainer_config,
            )

            # Change to /tmp for training (PyTorch Tabular needs write access)
            original_cwd = os.getcwd()
            try:
                os.chdir("/tmp")
                # Clean up checkpoint directory from previous fold
                checkpoint_dir = "/tmp/saved_models"
                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(checkpoint_dir)
                tabular_model.fit(train=df_train, validation=df_val)
            finally:
                os.chdir(original_cwd)

            # Make predictions
            result = tabular_model.predict(df_val[feature_cols])

            # Extract predictions
            prediction_col = f"{target_col}_prediction"
            preds = result[prediction_col].values

            # Store predictions at the correct indices
            val_indices = df.iloc[val_idx].index
            if is_classifier:
                preds_decoded = label_encoder.inverse_transform(preds.astype(int))
                predictions_df.loc[val_indices, "prediction"] = preds_decoded

                # Get probabilities and store at validation indices only.
                # NOTE(review): columns are ordered by name; presumably that matches the
                # encoded class order - confirm for models with 10 or more classes.
                prob_cols = sorted(col for col in result.columns if col.endswith("_probability"))
                # Reset per fold so a fold without probabilities never reuses stale values
                probs = result[prob_cols].values if prob_cols else None
                if probs is not None:
                    for i, idx in enumerate(val_indices):
                        predictions_df.at[idx, "pred_proba"] = probs[i].tolist()
                else:
                    all_folds_have_probs = False
                    log.warning(f"Fold {fold_idx}: no probability columns in predictions, ROC-AUC set to NaN")
            else:
                predictions_df.loc[val_indices, "prediction"] = preds

            # Calculate fold metrics
            if is_classifier:
                y_val_orig = label_encoder.inverse_transform(df_val[target_col])
                preds_orig = preds_decoded

                prec, rec, f1, _ = precision_recall_fscore_support(
                    y_val_orig, preds_orig, average="weighted", zero_division=0
                )

                prec_per_class, rec_per_class, f1_per_class, _ = precision_recall_fscore_support(
                    y_val_orig, preds_orig, average=None, zero_division=0, labels=label_encoder.classes_
                )

                y_val_encoded = df_val[target_col].values
                roc_auc_overall, roc_auc_per_class = _fold_roc_auc(
                    y_val_encoded, probs, len(label_encoder.classes_)
                )

                fold_metrics.append(
                    {
                        "fold": fold_idx,
                        "precision": prec,
                        "recall": rec,
                        "f1": f1,
                        "roc_auc": roc_auc_overall,
                        "precision_per_class": prec_per_class,
                        "recall_per_class": rec_per_class,
                        "f1_per_class": f1_per_class,
                        "roc_auc_per_class": roc_auc_per_class,
                    }
                )

                print(f"Fold {fold_idx} - F1: {f1:.4f}, ROC-AUC: {roc_auc_overall:.4f}")
            else:
                y_val = df_val[target_col].values
                spearman_corr, _ = spearmanr(y_val, preds)
                rmse = np.sqrt(mean_squared_error(y_val, preds))

                fold_metrics.append(
                    {
                        "fold": fold_idx,
                        "rmse": rmse,
                        "mae": mean_absolute_error(y_val, preds),
                        "medae": median_absolute_error(y_val, preds),
                        "r2": r2_score(y_val, preds),
                        "spearmanr": spearman_corr,
                    }
                )

                print(f"Fold {fold_idx} - RMSE: {rmse:.4f}, R2: {fold_metrics[-1]['r2']:.4f}")

        # Calculate summary metrics
        fold_df = pd.DataFrame(fold_metrics)

        if is_classifier:
            if all_folds_have_probs:
                predictions_df = expand_proba_column(predictions_df, label_encoder.classes_)
            else:
                # The placeholder column holds no usable probabilities; drop it
                predictions_df = predictions_df.drop(columns=["pred_proba"])

            # Decode the full target once; it is the same for every class
            y_orig = label_encoder.inverse_transform(y_for_cv)

            metric_rows = []
            for idx, class_name in enumerate(label_encoder.classes_):
                prec_scores = np.array([fold["precision_per_class"][idx] for fold in fold_metrics])
                rec_scores = np.array([fold["recall_per_class"][idx] for fold in fold_metrics])
                f1_scores = np.array([fold["f1_per_class"][idx] for fold in fold_metrics])
                roc_auc_scores = np.array([fold["roc_auc_per_class"][idx] for fold in fold_metrics])

                support = int((y_orig == class_name).sum())

                metric_rows.append(
                    {
                        "class": class_name,
                        "precision": prec_scores.mean(),
                        "recall": rec_scores.mean(),
                        "f1": f1_scores.mean(),
                        "roc_auc": roc_auc_scores.mean(),
                        "support": support,
                    }
                )

            metric_rows.append(
                {
                    "class": "all",
                    "precision": fold_df["precision"].mean(),
                    "recall": fold_df["recall"].mean(),
                    "f1": fold_df["f1"].mean(),
                    "roc_auc": fold_df["roc_auc"].mean(),
                    "support": len(y_for_cv),
                }
            )

            metrics_df = pd.DataFrame(metric_rows)
        else:
            metrics_df = pd.DataFrame(
                [
                    {
                        "rmse": fold_df["rmse"].mean(),
                        "mae": fold_df["mae"].mean(),
                        "medae": fold_df["medae"].mean(),
                        "r2": fold_df["r2"].mean(),
                        "spearmanr": fold_df["spearmanr"].mean(),
                        "support": len(y_for_cv),
                    }
                ]
            )

        print(f"\n{'='*50}")
        print("Cross-Validation Summary")
        print(f"{'='*50}")
        print(metrics_df.to_string(index=False))

        return metrics_df, predictions_df

    finally:
        log.info(f"Cleaning up model directory: {model_dir}")
        shutil.rmtree(model_dir, ignore_errors=True)
510
+
511
+
512
if __name__ == "__main__":

    # Manual exercise of the PyTorch utilities against a named Workbench model
    from workbench.api import Endpoint, Model

    # Model to load (alternate candidate: "aqsol-pytorch-reg")
    model_name = "caco2-er-reg-pytorch-test"
    print(f"Loading Workbench model: {model_name}")
    model = Model(model_name)
    print(f"Model Framework: {model.model_framework}")

    # Run cross-fold inference through the first endpoint listed for the model
    first_endpoint_name = model.endpoints()[0]
    end = Endpoint(first_endpoint_name)
    end.cross_fold_inference()
@@ -212,6 +212,14 @@ def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
212
212
  log.error("No XGBoost model found in the artifact.")
213
213
  return None, None, None, None
214
214
 
215
+ # Get the booster (SHAP requires the booster, not the sklearn wrapper)
216
+ if hasattr(xgb_model, "get_booster"):
217
+ # Full sklearn model - extract the booster
218
+ booster = xgb_model.get_booster()
219
+ else:
220
+ # Already a booster
221
+ booster = xgb_model
222
+
215
223
  # Load category mappings if available
216
224
  category_mappings = load_category_mappings_from_s3(model_artifact_uri)
217
225
 
@@ -229,8 +237,8 @@ def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
229
237
  # Create a DMatrix with categorical support
230
238
  dmatrix = xgb.DMatrix(X, enable_categorical=True)
231
239
 
232
- # Use XGBoost's built-in SHAP calculation
233
- shap_values = xgb_model.predict(dmatrix, pred_contribs=True, strict_shape=True)
240
+ # Use XGBoost's built-in SHAP calculation (booster method, not sklearn)
241
+ shap_values = booster.predict(dmatrix, pred_contribs=True, strict_shape=True)
234
242
  features_with_bias = features + ["bias"]
235
243
 
236
244
  # Now we need to subset the columns based on top 10 SHAP values
@@ -181,9 +181,6 @@ def logging_setup(color_logs=True):
181
181
  log.debug("Debugging enabled via WORKBENCH_DEBUG environment variable.")
182
182
  else:
183
183
  log.setLevel(logging.INFO)
184
- # Note: Not using the ThrottlingFilter for now
185
- # throttle_filter = ThrottlingFilter(rate_seconds=5)
186
- # handler.addFilter(throttle_filter)
187
184
 
188
185
  # Suppress specific logger
189
186
  logging.getLogger("sagemaker.config").setLevel(logging.WARNING)
@@ -12,7 +12,7 @@ class WorkbenchSQS:
12
12
  self.log = logging.getLogger("workbench")
13
13
  self.queue_url = queue_url
14
14
 
15
- # Grab a Workbench Session (this allows us to assume the Workbench-ExecutionRole)
15
+ # Grab a Workbench Session
16
16
  self.boto3_session = AWSAccountClamp().boto3_session
17
17
  print(self.boto3_session)
18
18