workbench 0.8.198__py3-none-any.whl → 0.8.203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. workbench/algorithms/dataframe/proximity.py +11 -4
  2. workbench/api/__init__.py +2 -1
  3. workbench/api/df_store.py +17 -108
  4. workbench/api/feature_set.py +48 -11
  5. workbench/api/model.py +1 -1
  6. workbench/api/parameter_store.py +3 -52
  7. workbench/core/artifacts/__init__.py +11 -2
  8. workbench/core/artifacts/artifact.py +5 -5
  9. workbench/core/artifacts/df_store_core.py +114 -0
  10. workbench/core/artifacts/endpoint_core.py +261 -78
  11. workbench/core/artifacts/feature_set_core.py +69 -1
  12. workbench/core/artifacts/model_core.py +48 -14
  13. workbench/core/artifacts/parameter_store_core.py +98 -0
  14. workbench/core/transforms/features_to_model/features_to_model.py +50 -33
  15. workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
  16. workbench/core/views/view.py +2 -2
  17. workbench/model_scripts/chemprop/chemprop.template +933 -0
  18. workbench/model_scripts/chemprop/generated_model_script.py +933 -0
  19. workbench/model_scripts/chemprop/requirements.txt +11 -0
  20. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  21. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  22. workbench/model_scripts/custom_models/proximity/proximity.py +11 -4
  23. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +11 -5
  24. workbench/model_scripts/custom_models/uq_models/meta_uq.template +11 -5
  25. workbench/model_scripts/custom_models/uq_models/ngboost.template +11 -5
  26. workbench/model_scripts/custom_models/uq_models/proximity.py +11 -4
  27. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +11 -5
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +365 -173
  29. workbench/model_scripts/pytorch_model/pytorch.template +362 -170
  30. workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
  31. workbench/model_scripts/script_generation.py +10 -7
  32. workbench/model_scripts/uq_models/generated_model_script.py +43 -27
  33. workbench/model_scripts/uq_models/mapie.template +40 -24
  34. workbench/model_scripts/xgb_model/generated_model_script.py +36 -7
  35. workbench/model_scripts/xgb_model/xgb_model.template +36 -7
  36. workbench/repl/workbench_shell.py +14 -5
  37. workbench/resources/open_source_api.key +1 -1
  38. workbench/scripts/endpoint_test.py +162 -0
  39. workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
  40. workbench/utils/chemprop_utils.py +761 -0
  41. workbench/utils/pytorch_utils.py +527 -0
  42. workbench/utils/xgboost_model_utils.py +10 -5
  43. workbench/web_interface/components/model_plot.py +7 -1
  44. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/METADATA +3 -3
  45. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/RECORD +49 -43
  46. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/entry_points.txt +2 -1
  47. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  48. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  49. workbench/model_scripts/__pycache__/script_generation.cpython-312.pyc +0 -0
  50. workbench/model_scripts/__pycache__/script_generation.cpython-313.pyc +0 -0
  51. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/WHEEL +0 -0
  52. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/licenses/LICENSE +0 -0
  53. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,527 @@
1
+ """PyTorch Tabular utilities for Workbench models."""
2
+
3
+ # flake8: noqa: E402
4
+ import logging
5
+ import os
6
+ import tempfile
7
+ from pprint import pformat
8
+ from typing import Any, Tuple
9
+
10
+ # Disable OpenMP parallelism to avoid segfaults on macOS with conflicting OpenMP runtimes
11
+ # (libomp from LLVM vs libiomp from Intel). Must be set before importing numpy/sklearn/torch.
12
+ # See: https://github.com/scikit-learn/scikit-learn/issues/21302
13
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
14
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+ from scipy.stats import spearmanr
19
+ from sklearn.metrics import (
20
+ mean_absolute_error,
21
+ mean_squared_error,
22
+ median_absolute_error,
23
+ precision_recall_fscore_support,
24
+ r2_score,
25
+ roc_auc_score,
26
+ )
27
+ from sklearn.model_selection import KFold, StratifiedKFold
28
+ from sklearn.preprocessing import LabelEncoder
29
+
30
+ from workbench.utils.model_utils import safe_extract_tarfile
31
+ from workbench.utils.pandas_utils import expand_proba_column
32
+ from workbench.utils.aws_utils import pull_s3_data
33
+
34
+ log = logging.getLogger("workbench")
35
+
36
+
37
def download_and_extract_model(s3_uri: str, model_dir: str) -> None:
    """Download a model artifact from S3 and extract it into ``model_dir``.

    The archive is staged as ``model.tar.gz`` inside ``model_dir`` and is removed
    again once extraction has been attempted, whether or not it succeeded.

    Args:
        s3_uri: S3 URI to the model artifact (model.tar.gz)
        model_dir: Directory to extract model artifacts to
    """
    import awswrangler as wr

    log.info(f"Downloading model from {s3_uri}...")

    # Stage the archive inside the target directory
    local_tar_path = os.path.join(model_dir, "model.tar.gz")
    wr.s3.download(path=s3_uri, local_file=local_tar_path)

    try:
        # Extract using safe extraction
        log.info(f"Extracting to {model_dir}...")
        safe_extract_tarfile(local_tar_path, model_dir)
    finally:
        # Always remove the archive so a failed extraction does not leave it behind
        if os.path.exists(local_tar_path):
            os.unlink(local_tar_path)
58
+
59
+
60
def load_pytorch_model_artifacts(model_dir: str) -> Tuple[Any, dict]:
    """Load PyTorch Tabular model and artifacts from an extracted model directory.

    Args:
        model_dir: Directory containing extracted model artifacts

    Returns:
        Tuple of (TabularModel, artifacts_dict).
        artifacts_dict contains 'label_encoder' and 'category_mappings' if present.

    Raises:
        FileNotFoundError: If ``model_dir`` has no ``tabular_model`` entry.
    """
    # Resolve to an absolute path up front: the working directory is changed below,
    # which would otherwise invalidate a relative model_dir.
    model_dir = os.path.abspath(model_dir)
    model_path = os.path.join(model_dir, "tabular_model")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No tabular_model directory found in {model_dir}")

    # Heavy / optional imports are deferred until we know there is something to load
    import json

    import joblib

    # pytorch-tabular saves complex objects, use legacy loading behavior
    # NOTE(review): this setting is process-wide, and both it and joblib.load below
    # rely on pickle-style deserialization -- only point this at trusted artifacts.
    os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
    from pytorch_tabular import TabularModel

    # PyTorch Tabular needs write access, so chdir to the system temp directory
    original_cwd = os.getcwd()
    try:
        os.chdir(tempfile.gettempdir())
        model = TabularModel.load_model(model_path)
    finally:
        os.chdir(original_cwd)

    # Load additional artifacts (each one is optional)
    artifacts = {}

    label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
    if os.path.exists(label_encoder_path):
        artifacts["label_encoder"] = joblib.load(label_encoder_path)

    category_mappings_path = os.path.join(model_dir, "category_mappings.json")
    if os.path.exists(category_mappings_path):
        with open(category_mappings_path, encoding="utf-8") as f:
            artifacts["category_mappings"] = json.load(f)

    return model, artifacts
103
+
104
+
105
+ def _extract_model_configs(loaded_model: Any, n_train: int) -> dict:
106
+ """Extract trainer and model configs from a loaded PyTorch Tabular model.
107
+
108
+ Args:
109
+ loaded_model: Loaded TabularModel instance
110
+ n_train: Number of training samples (used for batch_size calculation)
111
+
112
+ Returns:
113
+ Dictionary with 'trainer' and 'model' config dictionaries
114
+ """
115
+ config = loaded_model.config
116
+
117
+ # Trainer config - extract from loaded model, matching template defaults
118
+ trainer_defaults = {
119
+ "auto_lr_find": False,
120
+ "batch_size": min(128, max(32, n_train // 16)),
121
+ "max_epochs": 100,
122
+ "min_epochs": 10,
123
+ "early_stopping": "valid_loss",
124
+ "early_stopping_patience": 10,
125
+ "gradient_clip_val": 1.0,
126
+ }
127
+
128
+ trainer_config = {}
129
+ for key, default in trainer_defaults.items():
130
+ value = getattr(config, key, default)
131
+ if value == default and not hasattr(config, key):
132
+ log.warning(f"Trainer config '{key}' not found in loaded model, using default: {default}")
133
+ trainer_config[key] = value
134
+
135
+ # Model config - extract from loaded model, matching template defaults
136
+ model_defaults = {
137
+ "layers": "256-128-64",
138
+ "activation": "LeakyReLU",
139
+ "learning_rate": 1e-3,
140
+ "dropout": 0.3,
141
+ "use_batch_norm": True,
142
+ "initialization": "kaiming",
143
+ }
144
+
145
+ model_config = {}
146
+ for key, default in model_defaults.items():
147
+ value = getattr(config, key, default)
148
+ if value == default and not hasattr(config, key):
149
+ log.warning(f"Model config '{key}' not found in loaded model, using default: {default}")
150
+ model_config[key] = value
151
+
152
+ return {"trainer": trainer_config, "model": model_config}
153
+
154
+
155
def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Fetch the stored cross-validation results for a model.

    The validation predictions are read from ``validation_predictions.csv`` under
    the model's training path; the training metrics come from the model's
    Workbench metadata.

    Args:
        workbench_model: Workbench model object

    Returns:
        Tuple of:
            - DataFrame with training metrics (a single 'error' column when the
              metadata holds no training metrics)
            - DataFrame with validation predictions

    Raises:
        ValueError: If no validation predictions could be pulled.
    """
    # Validation predictions live next to the other training artifacts in S3
    predictions_uri = f"{workbench_model.model_training_path}/validation_predictions.csv"
    cv_predictions = pull_s3_data(predictions_uri)
    if cv_predictions is None:
        raise ValueError(f"No validation predictions found at {predictions_uri}")
    log.info(f"Pulled {len(cv_predictions)} validation predictions from {predictions_uri}")

    # Training metrics are stored in the model metadata
    stored_metrics = workbench_model.workbench_meta().get("workbench_training_metrics")
    if stored_metrics is not None:
        metrics_table = pd.DataFrame.from_dict(stored_metrics)
        log.info(f"Metrics summary:\n{metrics_table.to_string(index=False)}")
        return metrics_table, cv_predictions

    # No metrics recorded: hand back a one-row frame describing the problem
    log.warning(f"No training metrics found in model metadata for {workbench_model.model_name}")
    metrics_table = pd.DataFrame({"error": [f"No training metrics found for {workbench_model.model_name}"]})
    return metrics_table, cv_predictions
189
+
190
+
191
def _probability_columns_in_class_order(columns: Any) -> list:
    """Return the ``*_probability`` column names ordered by class index."""
    prob_cols = [col for col in columns if col.endswith("_probability")]
    # Sort by (length, name) instead of plain name: a lexicographic sort would put
    # e.g. "10_probability" ahead of "2_probability" once there are 10+ classes.
    # NOTE(review): assumes the names differ only in an integer class index -- confirm
    # against the pytorch-tabular version in use.
    return sorted(prob_cols, key=lambda col: (len(col), col))


def _one_vs_rest_roc_auc(y_encoded: np.ndarray, probs: np.ndarray) -> Tuple[float, np.ndarray]:
    """Compute the macro and per-class one-vs-rest ROC-AUC from a probability matrix.

    Every class column is scored as its own binary problem. This also covers
    two-class targets, for which roc_auc_score expects a 1-D score vector rather
    than the full probability matrix.

    Args:
        y_encoded: Integer-encoded true labels
        probs: Probability matrix; column k is assumed to belong to encoded class k

    Returns:
        Tuple of (macro average, array of per-class scores)
    """
    per_class = np.array(
        [
            roc_auc_score((y_encoded == class_idx).astype(int), probs[:, class_idx])
            for class_idx in range(probs.shape[1])
        ]
    )
    return float(per_class.mean()), per_class


def cross_fold_inference(
    workbench_model: Any,
    nfolds: int = 5,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Performs K-fold cross-validation for PyTorch Tabular models.

    Replicates the training setup from the original model to ensure
    cross-validation results are comparable to the deployed model. The model
    artifact is downloaded into a temporary directory that is removed again
    before returning.

    Args:
        workbench_model: Workbench model object
        nfolds: Number of folds for cross-validation (default is 5)

    Returns:
        Tuple of:
            - DataFrame with per-class metrics (and 'all' row for overall metrics)
              for classifiers, or a single row of averaged metrics for regressors
            - DataFrame with columns: id, target, prediction, and *_proba columns (for classifiers)
    """
    import shutil

    from pytorch_tabular import TabularModel
    from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
    from pytorch_tabular.models import CategoryEmbeddingModelConfig

    from workbench.api import FeatureSet

    # Create a temporary model directory
    model_dir = tempfile.mkdtemp(prefix="pytorch_cv_")
    log.info(f"Using model directory: {model_dir}")

    # Writable scratch directory used as the working directory while training
    scratch_dir = tempfile.gettempdir()

    try:
        # Download and extract model artifacts to get config and artifacts
        model_artifact_uri = workbench_model.model_data_url()
        download_and_extract_model(model_artifact_uri, model_dir)

        # Load model and artifacts
        loaded_model, artifacts = load_pytorch_model_artifacts(model_dir)
        category_mappings = artifacts.get("category_mappings", {})

        # Determine if classifier from the loaded model's config
        is_classifier = loaded_model.config.task == "classification"

        # Use saved label encoder if available, otherwise create fresh one
        if is_classifier:
            label_encoder = artifacts.get("label_encoder")
            if label_encoder is None:
                log.warning("No saved label encoder found, creating fresh one")
                label_encoder = LabelEncoder()
        else:
            label_encoder = None

        # Prepare data
        fs = FeatureSet(workbench_model.get_input())
        df = workbench_model.training_view().pull_dataframe()

        # Get columns
        id_col = fs.id_column
        target_col = workbench_model.target()
        feature_cols = workbench_model.features()
        print(f"Target column: {target_col}")
        print(f"Feature columns: {len(feature_cols)} features")

        # Convert string columns to category for PyTorch Tabular compatibility
        for col in feature_cols:
            if pd.api.types.is_string_dtype(df[col]):
                if col in category_mappings:
                    df[col] = pd.Categorical(df[col], categories=category_mappings[col])
                else:
                    df[col] = df[col].astype("category")

        # Determine categorical and continuous columns
        categorical_cols = [col for col in feature_cols if df[col].dtype.name == "category"]
        continuous_cols = [col for col in feature_cols if col not in categorical_cols]

        # Cast continuous columns to float
        if continuous_cols:
            df[continuous_cols] = df[continuous_cols].astype("float64")

        # Drop rows with NaN features or target (PyTorch Tabular cannot handle NaN values)
        nan_mask = df[feature_cols].isna().any(axis=1) | df[target_col].isna()
        if nan_mask.any():
            n_nan_rows = nan_mask.sum()
            log.warning(
                f"Dropping {n_nan_rows} rows ({100*n_nan_rows/len(df):.1f}%) with NaN values for cross-validation"
            )
            df = df[~nan_mask].reset_index(drop=True)

        X = df[feature_cols]
        y = df[target_col]
        ids = df[id_col]

        # Encode target if classifier (a fresh encoder has no classes_ yet and must be fit)
        if label_encoder is not None:
            if not hasattr(label_encoder, "classes_"):
                label_encoder.fit(y)
            y_encoded = label_encoder.transform(y)
            y_for_cv = pd.Series(y_encoded, index=y.index, name=target_col)
        else:
            y_for_cv = y

        # Extract configs from loaded model (pass approx train size for batch_size calculation)
        n_train_approx = int(len(df) * (1 - 1 / nfolds))
        configs = _extract_model_configs(loaded_model, n_train_approx)
        trainer_params = configs["trainer"]
        model_params = configs["model"]

        log.info(f"Trainer config:\n{pformat(trainer_params)}")
        log.info(f"Model config:\n{pformat(model_params)}")

        # Prepare KFold
        kfold = (StratifiedKFold if is_classifier else KFold)(n_splits=nfolds, shuffle=True, random_state=42)

        # Initialize results collection
        fold_metrics = []
        predictions_df = pd.DataFrame({id_col: ids, target_col: y})
        if is_classifier:
            predictions_df["pred_proba"] = [None] * len(predictions_df)

        # Perform cross-validation
        for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y_for_cv), 1):
            print(f"\n{'='*50}")
            print(f"Fold {fold_idx}/{nfolds}")
            print(f"{'='*50}")

            # Split data
            df_train = df.iloc[train_idx].copy()
            df_val = df.iloc[val_idx].copy()

            # Encode target for this fold
            if is_classifier:
                df_train[target_col] = label_encoder.transform(df_train[target_col])
                df_val[target_col] = label_encoder.transform(df_val[target_col])

            # Create configs for this fold - matching the training template exactly
            data_config = DataConfig(
                target=[target_col],
                continuous_cols=continuous_cols,
                categorical_cols=categorical_cols,
            )

            trainer_config = TrainerConfig(
                auto_lr_find=trainer_params["auto_lr_find"],
                batch_size=trainer_params["batch_size"],
                max_epochs=trainer_params["max_epochs"],
                min_epochs=trainer_params["min_epochs"],
                early_stopping=trainer_params["early_stopping"],
                early_stopping_patience=trainer_params["early_stopping_patience"],
                gradient_clip_val=trainer_params["gradient_clip_val"],
                checkpoints="valid_loss",  # Save best model based on validation loss
                accelerator="cpu",
            )

            optimizer_config = OptimizerConfig()

            model_config = CategoryEmbeddingModelConfig(
                task="classification" if is_classifier else "regression",
                layers=model_params["layers"],
                activation=model_params["activation"],
                learning_rate=model_params["learning_rate"],
                dropout=model_params["dropout"],
                use_batch_norm=model_params["use_batch_norm"],
                initialization=model_params["initialization"],
            )

            # Create and train fresh model
            tabular_model = TabularModel(
                data_config=data_config,
                model_config=model_config,
                optimizer_config=optimizer_config,
                trainer_config=trainer_config,
            )

            # Change to the scratch directory for training (PyTorch Tabular needs write access)
            original_cwd = os.getcwd()
            try:
                os.chdir(scratch_dir)
                # Clean up checkpoint directory from previous fold
                checkpoint_dir = os.path.join(scratch_dir, "saved_models")
                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(checkpoint_dir)
                tabular_model.fit(train=df_train, validation=df_val)
            finally:
                os.chdir(original_cwd)

            # Make predictions
            result = tabular_model.predict(df_val[feature_cols])

            # Extract predictions
            prediction_col = f"{target_col}_prediction"
            preds = result[prediction_col].values

            # Store predictions at the correct indices
            val_indices = df.iloc[val_idx].index

            if is_classifier:
                preds_decoded = label_encoder.inverse_transform(preds.astype(int))
                predictions_df.loc[val_indices, "prediction"] = preds_decoded

                # Get probabilities and store at validation indices only.
                # probs is reset every fold so a missing set of probability columns
                # can never reuse the previous fold's values.
                prob_cols = _probability_columns_in_class_order(result.columns)
                probs = result[prob_cols].values if prob_cols else None
                if probs is not None:
                    for row_pos, idx in enumerate(val_indices):
                        predictions_df.at[idx, "pred_proba"] = probs[row_pos].tolist()

                # Calculate fold metrics on the original (decoded) labels
                y_val_encoded = df_val[target_col].values
                y_val_orig = label_encoder.inverse_transform(y_val_encoded)

                prec, rec, f1, _ = precision_recall_fscore_support(
                    y_val_orig, preds_decoded, average="weighted", zero_division=0
                )

                prec_per_class, rec_per_class, f1_per_class, _ = precision_recall_fscore_support(
                    y_val_orig, preds_decoded, average=None, zero_division=0, labels=label_encoder.classes_
                )

                if probs is not None:
                    roc_auc_overall, roc_auc_per_class = _one_vs_rest_roc_auc(y_val_encoded, probs)
                else:
                    log.warning("No probability columns found in predictions, ROC-AUC set to NaN")
                    roc_auc_overall = np.nan
                    roc_auc_per_class = np.full(len(label_encoder.classes_), np.nan)

                fold_metrics.append(
                    {
                        "fold": fold_idx,
                        "precision": prec,
                        "recall": rec,
                        "f1": f1,
                        "roc_auc": roc_auc_overall,
                        "precision_per_class": prec_per_class,
                        "recall_per_class": rec_per_class,
                        "f1_per_class": f1_per_class,
                        "roc_auc_per_class": roc_auc_per_class,
                    }
                )

                print(f"Fold {fold_idx} - F1: {f1:.4f}, ROC-AUC: {roc_auc_overall:.4f}")
            else:
                predictions_df.loc[val_indices, "prediction"] = preds

                # Calculate fold metrics
                y_val = df_val[target_col].values
                spearman_corr, _ = spearmanr(y_val, preds)
                rmse = np.sqrt(mean_squared_error(y_val, preds))

                fold_metrics.append(
                    {
                        "fold": fold_idx,
                        "rmse": rmse,
                        "mae": mean_absolute_error(y_val, preds),
                        "medae": median_absolute_error(y_val, preds),
                        "r2": r2_score(y_val, preds),
                        "spearmanr": spearman_corr,
                    }
                )

                print(f"Fold {fold_idx} - RMSE: {rmse:.4f}, R2: {fold_metrics[-1]['r2']:.4f}")

        # Calculate summary metrics
        fold_df = pd.DataFrame(fold_metrics)

        if is_classifier:
            # Only expand the probabilities when at least one fold produced them
            if "pred_proba" in predictions_df.columns and predictions_df["pred_proba"].notna().any():
                predictions_df = expand_proba_column(predictions_df, label_encoder.classes_)

            # Decode the full target once; it is the same for every class row
            y_orig = label_encoder.inverse_transform(y_for_cv)

            metric_rows = []
            for class_pos, class_name in enumerate(label_encoder.classes_):
                prec_scores = np.array([fold["precision_per_class"][class_pos] for fold in fold_metrics])
                rec_scores = np.array([fold["recall_per_class"][class_pos] for fold in fold_metrics])
                f1_scores = np.array([fold["f1_per_class"][class_pos] for fold in fold_metrics])
                roc_auc_scores = np.array([fold["roc_auc_per_class"][class_pos] for fold in fold_metrics])

                metric_rows.append(
                    {
                        "class": class_name,
                        "precision": prec_scores.mean(),
                        "recall": rec_scores.mean(),
                        "f1": f1_scores.mean(),
                        "roc_auc": roc_auc_scores.mean(),
                        "support": int((y_orig == class_name).sum()),
                    }
                )

            metric_rows.append(
                {
                    "class": "all",
                    "precision": fold_df["precision"].mean(),
                    "recall": fold_df["recall"].mean(),
                    "f1": fold_df["f1"].mean(),
                    "roc_auc": fold_df["roc_auc"].mean(),
                    "support": len(y_for_cv),
                }
            )

            metrics_df = pd.DataFrame(metric_rows)
        else:
            metrics_df = pd.DataFrame(
                [
                    {
                        "rmse": fold_df["rmse"].mean(),
                        "mae": fold_df["mae"].mean(),
                        "medae": fold_df["medae"].mean(),
                        "r2": fold_df["r2"].mean(),
                        "spearmanr": fold_df["spearmanr"].mean(),
                        "support": len(y_for_cv),
                    }
                ]
            )

        print(f"\n{'='*50}")
        print("Cross-Validation Summary")
        print(f"{'='*50}")
        print(metrics_df.to_string(index=False))

        return metrics_df, predictions_df

    finally:
        log.info(f"Cleaning up model directory: {model_dir}")
        shutil.rmtree(model_dir, ignore_errors=True)
511
+
512
+
513
if __name__ == "__main__":

    # Smoke test for the PyTorch utilities
    from workbench.api import Endpoint, Model

    # Look up the Workbench model to exercise (alternative: "aqsol-pytorch-reg")
    test_model_name = "caco2-er-reg-pytorch-test"
    print(f"Loading Workbench model: {test_model_name}")
    test_model = Model(test_model_name)
    print(f"Model Framework: {test_model.model_framework}")

    # Run cross-fold inference through the model's first endpoint
    first_endpoint_name = test_model.endpoints()[0]
    Endpoint(first_endpoint_name).cross_fold_inference()
@@ -308,7 +308,12 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[pd.Data
308
308
  fs = FeatureSet(workbench_model.get_input())
309
309
  df = workbench_model.training_view().pull_dataframe()
310
310
 
311
- # Get id column - assuming FeatureSet has an id_column attribute or similar
311
+ # Extract sample weights if present
312
+ sample_weights = df.get("sample_weight")
313
+ if sample_weights is not None:
314
+ log.info(f"Using sample weights: min={sample_weights.min():.2f}, max={sample_weights.max():.2f}")
315
+
316
+ # Get columns
312
317
  id_col = fs.id_column
313
318
  target_col = workbench_model.target()
314
319
  feature_cols = workbench_model.features()
@@ -316,10 +321,8 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[pd.Data
316
321
  print(f"Feature columns: {len(feature_cols)} features")
317
322
 
318
323
  # Convert string[python] to object, then to category for XGBoost compatibility
319
- # This avoids XGBoost's issue with pandas 2.x string[python] dtype in categorical categories
320
324
  for col in feature_cols:
321
325
  if pd.api.types.is_string_dtype(df[col]):
322
- # Double conversion: string[python] -> object -> category
323
326
  df[col] = df[col].astype("object").astype("category")
324
327
 
325
328
  X = df[feature_cols]
@@ -335,7 +338,6 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[pd.Data
335
338
  y_for_cv = y
336
339
 
337
340
  # Prepare KFold
338
- # Note: random_state=42 seems to not actually give us reproducible results
339
341
  kfold = (StratifiedKFold if is_classifier else KFold)(n_splits=nfolds, shuffle=True, random_state=42)
340
342
 
341
343
  # Initialize results collection
@@ -347,8 +349,11 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[pd.Data
347
349
  X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
348
350
  y_train, y_val = y_for_cv.iloc[train_idx], y_for_cv.iloc[val_idx]
349
351
 
352
+ # Get sample weights for training fold
353
+ weights_train = sample_weights.iloc[train_idx] if sample_weights is not None else None
354
+
350
355
  # Train and predict
351
- xgb_model.fit(X_train, y_train)
356
+ xgb_model.fit(X_train, y_train, sample_weight=weights_train)
352
357
  preds = xgb_model.predict(X_val)
353
358
 
354
359
  # Store predictions (decode if classifier)
@@ -36,8 +36,14 @@ class ModelPlot(ComponentInterface):
36
36
  if df is None:
37
37
  return self.display_text("No Data")
38
38
 
39
- # Calculate the distance from the diagonal for each point
39
+ # Grab the target(s) for this model
40
40
  target = model.target()
41
+
42
+ # For multi-task models, match target to inference_run name or default to first
43
+ if isinstance(target, list):
44
+ target = next((t for t in target if t in inference_run), target[0])
45
+
46
+ # Compute error for coloring
41
47
  df["error"] = abs(df["prediction"] - df[target])
42
48
  return ScatterPlot().update_properties(
43
49
  df,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: workbench
3
- Version: 0.8.198
3
+ Version: 0.8.203
4
4
  Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
5
5
  Author-email: SuperCowPowers LLC <support@supercowpowers.com>
6
6
  License: MIT License
@@ -42,7 +42,7 @@ Requires-Dist: redis>=5.0.1
42
42
  Requires-Dist: numpy>=1.26.4
43
43
  Requires-Dist: pandas>=2.2.1
44
44
  Requires-Dist: awswrangler>=3.4.0
45
- Requires-Dist: sagemaker>=2.143
45
+ Requires-Dist: sagemaker<3.0,>=2.143
46
46
  Requires-Dist: cryptography>=44.0.2
47
47
  Requires-Dist: ipython>=8.37.0
48
48
  Requires-Dist: pyreadline3; sys_platform == "win32"
@@ -52,7 +52,7 @@ Requires-Dist: joblib>=1.3.2
52
52
  Requires-Dist: requests>=2.26.0
53
53
  Requires-Dist: rdkit>=2024.9.5
54
54
  Requires-Dist: mordredcommunity>=2.0.6
55
- Requires-Dist: workbench-bridges>=0.1.10
55
+ Requires-Dist: workbench-bridges>=0.1.15
56
56
  Provides-Extra: ui
57
57
  Requires-Dist: plotly>=6.0.0; extra == "ui"
58
58
  Requires-Dist: dash>3.0.0; extra == "ui"