workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
@@ -3,29 +3,30 @@
3
3
  import logging
4
4
  import os
5
5
  import tempfile
6
- import tarfile
6
+ import joblib
7
7
  import pickle
8
8
  import glob
9
9
  import awswrangler as wr
10
- from typing import Optional, List, Tuple
10
+ from typing import Optional, List, Tuple, Any
11
11
  import hashlib
12
12
  import pandas as pd
13
13
  import numpy as np
14
14
  import xgboost as xgb
15
- from typing import Dict, Any
16
15
  from sklearn.model_selection import KFold, StratifiedKFold
17
16
  from sklearn.metrics import (
18
17
  precision_recall_fscore_support,
19
- confusion_matrix,
20
18
  mean_squared_error,
21
19
  mean_absolute_error,
22
20
  r2_score,
21
+ median_absolute_error,
22
+ roc_auc_score,
23
23
  )
24
+ from scipy.stats import spearmanr
24
25
  from sklearn.preprocessing import LabelEncoder
25
26
 
26
27
  # Workbench Imports
27
- from workbench.utils.model_utils import load_category_mappings_from_s3
28
- from workbench.utils.pandas_utils import convert_categorical_types
28
+ from workbench.utils.model_utils import load_category_mappings_from_s3, safe_extract_tarfile
29
+ from workbench.utils.pandas_utils import convert_categorical_types, expand_proba_column
29
30
 
30
31
  # Set up the log
31
32
  log = logging.getLogger("workbench")
@@ -34,14 +35,12 @@ log = logging.getLogger("workbench")
34
35
  def xgboost_model_from_s3(model_artifact_uri: str):
35
36
  """
36
37
  Download and extract XGBoost model artifact from S3, then load the model into memory.
37
- Handles both direct XGBoost model files and pickled models.
38
- Ensures categorical feature support is enabled.
39
38
 
40
39
  Args:
41
40
  model_artifact_uri (str): S3 URI of the model artifact.
42
41
 
43
42
  Returns:
44
- Loaded XGBoost model or None if unavailable.
43
+ Loaded XGBoost model (XGBClassifier, XGBRegressor, or Booster) or None if unavailable.
45
44
  """
46
45
 
47
46
  with tempfile.TemporaryDirectory() as tmpdir:
@@ -50,68 +49,90 @@ def xgboost_model_from_s3(model_artifact_uri: str):
50
49
  wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)
51
50
 
52
51
  # Extract tarball
53
- with tarfile.open(local_tar_path, "r:gz") as tar:
54
- tar.extractall(path=tmpdir, filter="data")
52
+ safe_extract_tarfile(local_tar_path, tmpdir)
55
53
 
56
54
  # Define model file patterns to search for (in order of preference)
57
55
  patterns = [
58
- # Direct XGBoost model files
59
- os.path.join(tmpdir, "xgboost-model"),
60
- os.path.join(tmpdir, "model"),
61
- os.path.join(tmpdir, "*.bin"),
56
+ # Joblib models (preferred - preserves everything)
57
+ os.path.join(tmpdir, "*model*.joblib"),
58
+ os.path.join(tmpdir, "xgb*.joblib"),
59
+ os.path.join(tmpdir, "**", "*model*.joblib"),
60
+ os.path.join(tmpdir, "**", "xgb*.joblib"),
61
+ # Pickle models (also preserves everything)
62
+ os.path.join(tmpdir, "*model*.pkl"),
63
+ os.path.join(tmpdir, "xgb*.pkl"),
64
+ os.path.join(tmpdir, "**", "*model*.pkl"),
65
+ os.path.join(tmpdir, "**", "xgb*.pkl"),
66
+ # JSON models (fallback - requires reconstruction)
67
+ os.path.join(tmpdir, "*model*.json"),
68
+ os.path.join(tmpdir, "xgb*.json"),
62
69
  os.path.join(tmpdir, "**", "*model*.json"),
63
- os.path.join(tmpdir, "**", "rmse.json"),
64
- # Pickled models
65
- os.path.join(tmpdir, "*.pkl"),
66
- os.path.join(tmpdir, "**", "*.pkl"),
67
- os.path.join(tmpdir, "*.pickle"),
68
- os.path.join(tmpdir, "**", "*.pickle"),
70
+ os.path.join(tmpdir, "**", "xgb*.json"),
69
71
  ]
70
72
 
71
73
  # Try each pattern
72
74
  for pattern in patterns:
73
- # Use glob to find all matching files
74
75
  for model_path in glob.glob(pattern, recursive=True):
75
- # Determine file type by extension
76
+ # Skip files that are clearly not XGBoost models
77
+ filename = os.path.basename(model_path).lower()
78
+ if any(skip in filename for skip in ["label_encoder", "scaler", "preprocessor", "transformer"]):
79
+ log.debug(f"Skipping non-model file: {model_path}")
80
+ continue
81
+
76
82
  _, ext = os.path.splitext(model_path)
77
83
 
78
84
  try:
79
- if ext.lower() in [".pkl", ".pickle"]:
80
- # Handle pickled models
85
+ if ext == ".joblib":
86
+ model = joblib.load(model_path)
87
+ # Verify it's actually an XGBoost model
88
+ if isinstance(model, (xgb.XGBClassifier, xgb.XGBRegressor, xgb.Booster)):
89
+ log.important(f"Loaded XGBoost model from joblib: {model_path}")
90
+ return model
91
+ else:
92
+ log.debug(f"Skipping non-XGBoost object from {model_path}: {type(model)}")
93
+
94
+ elif ext in [".pkl", ".pickle"]:
81
95
  with open(model_path, "rb") as f:
82
96
  model = pickle.load(f)
83
-
84
- # Handle different model types
85
- if isinstance(model, xgb.Booster):
86
- log.important(f"Loaded XGBoost Booster from pickle: {model_path}")
97
+ # Verify it's actually an XGBoost model
98
+ if isinstance(model, (xgb.XGBClassifier, xgb.XGBRegressor, xgb.Booster)):
99
+ log.important(f"Loaded XGBoost model from pickle: {model_path}")
87
100
  return model
88
- elif hasattr(model, "get_booster"):
89
- log.important(f"Loaded XGBoost model from pipeline: {model_path}")
90
- booster = model.get_booster()
91
- return booster
92
- else:
93
- # Handle direct XGBoost model files
101
+ else:
102
+ log.debug(f"Skipping non-XGBoost object from {model_path}: {type(model)}")
103
+
104
+ elif ext == ".json":
105
+ # JSON files should be XGBoost models by definition
94
106
  booster = xgb.Booster()
95
107
  booster.load_model(model_path)
96
- log.important(f"Loaded XGBoost model directly: {model_path}")
108
+ log.important(f"Loaded XGBoost booster from JSON: {model_path}")
97
109
  return booster
110
+
98
111
  except Exception as e:
99
- log.info(f"Failed to load model from {model_path}: {e}")
100
- continue # Try the next file
112
+ log.debug(f"Failed to load {model_path}: {e}")
113
+ continue
101
114
 
102
- # If no model found
103
115
  log.error("No XGBoost model found in the artifact.")
104
116
  return None
105
117
 
106
118
 
107
- def feature_importance(workbench_model, importance_type: str = "weight") -> Optional[List[Tuple[str, float]]]:
119
+ def feature_importance(workbench_model, importance_type: str = "gain") -> Optional[List[Tuple[str, float]]]:
108
120
  """
109
121
  Get sorted feature importances from a Workbench Model object.
110
122
 
111
123
  Args:
112
124
  workbench_model: Workbench model object
113
- importance_type: Type of feature importance.
114
- Options: 'weight', 'gain', 'cover', 'total_gain', 'total_cover'
125
+ importance_type: Type of feature importance. Options:
126
+ - 'gain' (default): Average improvement in loss/objective when feature is used.
127
+ Best for understanding predictive power of features.
128
+ - 'weight': Number of times a feature appears in trees (split count).
129
+ Useful for understanding model complexity and feature usage frequency.
130
+ - 'cover': Average number of samples affected when feature is used.
131
+ Shows the relative quantity of observations related to this feature.
132
+ - 'total_gain': Total improvement in loss/objective across all splits.
133
+ Similar to 'gain' but not averaged (can be biased toward frequent features).
134
+ - 'total_cover': Total number of samples affected across all splits.
135
+ Similar to 'cover' but not averaged.
115
136
 
116
137
  Returns:
117
138
  List of tuples (feature, importance) sorted by importance value (descending).
@@ -120,7 +141,8 @@ def feature_importance(workbench_model, importance_type: str = "weight") -> Opti
120
141
 
121
142
  Note:
122
143
  XGBoost's get_score() only returns features with non-zero importance.
123
- This function ensures all model features are included in the output.
144
+ This function ensures all model features are included in the output,
145
+ adding zero values for features that weren't used in any tree splits.
124
146
  """
125
147
  model_artifact_uri = workbench_model.model_data_url()
126
148
  xgb_model = xgboost_model_from_s3(model_artifact_uri)
@@ -128,11 +150,18 @@ def feature_importance(workbench_model, importance_type: str = "weight") -> Opti
128
150
  log.error("No XGBoost model found in the artifact.")
129
151
  return None
130
152
 
131
- # Get feature importances (only non-zero features)
132
- importances = xgb_model.get_score(importance_type=importance_type)
153
+ # Check if we got a full sklearn model or just a booster (for backwards compatibility)
154
+ if hasattr(xgb_model, "get_booster"):
155
+ # Full sklearn model - get the booster for feature importance
156
+ booster = xgb_model.get_booster()
157
+ all_features = booster.feature_names
158
+ else:
159
+ # Already a booster (legacy JSON load)
160
+ booster = xgb_model
161
+ all_features = xgb_model.feature_names
133
162
 
134
- # Get all feature names from the model
135
- all_features = xgb_model.feature_names
163
+ # Get feature importances (only non-zero features)
164
+ importances = booster.get_score(importance_type=importance_type)
136
165
 
137
166
  # Create complete importance dict with zeros for missing features
138
167
  complete_importances = {feat: importances.get(feat, 0.0) for feat in all_features}
@@ -229,148 +258,326 @@ def leaf_stats(df: pd.DataFrame, target_col: str) -> pd.DataFrame:
229
258
  return result_df
230
259
 
231
260
 
232
- def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Dict[str, Any]:
261
+ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[pd.DataFrame, pd.DataFrame]:
233
262
  """
234
263
  Performs K-fold cross-validation with detailed metrics.
235
264
  Args:
236
265
  workbench_model: Workbench model object
237
266
  nfolds: Number of folds for cross-validation (default is 5)
238
267
  Returns:
239
- Dictionary containing:
240
- - folds: Dictionary of formatted strings for each fold
241
- - summary_metrics: Summary metrics across folds
242
- - overall_metrics: Overall metrics for all folds
268
+ Tuple of:
269
+ - DataFrame with per-class metrics (and 'all' row for overall metrics)
270
+ - DataFrame with columns: id, target, prediction, and *_proba columns (for classifiers)
243
271
  """
244
272
  from workbench.api import FeatureSet
245
273
 
246
274
  # Load model
247
- model_type = workbench_model.model_type.value
248
275
  model_artifact_uri = workbench_model.model_data_url()
249
- loaded_booster = xgboost_model_from_s3(model_artifact_uri)
250
- if loaded_booster is None:
276
+ loaded_model = xgboost_model_from_s3(model_artifact_uri)
277
+ if loaded_model is None:
251
278
  log.error("No XGBoost model found in the artifact.")
252
- return {}
253
- # Create the model wrapper
254
- is_classifier = model_type == "classifier"
255
- xgb_model = (
256
- xgb.XGBClassifier(enable_categorical=True) if is_classifier else xgb.XGBRegressor(enable_categorical=True)
257
- )
258
- xgb_model._Booster = loaded_booster
279
+ return pd.DataFrame(), pd.DataFrame()
280
+
281
+ # Check if we got a full sklearn model or need to create one
282
+ if isinstance(loaded_model, (xgb.XGBClassifier, xgb.XGBRegressor)):
283
+ is_classifier = isinstance(loaded_model, xgb.XGBClassifier)
284
+
285
+ # Get the model's hyperparameters and ensure enable_categorical=True
286
+ params = loaded_model.get_params()
287
+ params["enable_categorical"] = True
288
+
289
+ # Create new model with same params but enable_categorical=True
290
+ if is_classifier:
291
+ xgb_model = xgb.XGBClassifier(**params)
292
+ else:
293
+ xgb_model = xgb.XGBRegressor(**params)
294
+
295
+ elif isinstance(loaded_model, xgb.Booster):
296
+ # Legacy: got a booster, need to wrap it
297
+ log.warning("Deprecated: Loaded model is a Booster, wrapping in sklearn model.")
298
+ is_classifier = workbench_model.model_type.value == "classifier"
299
+ xgb_model = (
300
+ xgb.XGBClassifier(enable_categorical=True) if is_classifier else xgb.XGBRegressor(enable_categorical=True)
301
+ )
302
+ xgb_model._Booster = loaded_model
303
+ else:
304
+ log.error(f"Unexpected model type: {type(loaded_model)}")
305
+ return pd.DataFrame(), pd.DataFrame()
306
+
259
307
  # Prepare data
260
308
  fs = FeatureSet(workbench_model.get_input())
261
- df = fs.pull_dataframe()
309
+ df = workbench_model.training_view().pull_dataframe()
310
+
311
+ # Extract sample weights if present
312
+ sample_weights = df.get("sample_weight")
313
+ if sample_weights is not None:
314
+ log.info(f"Using sample weights: min={sample_weights.min():.2f}, max={sample_weights.max():.2f}")
315
+
316
+ # Get columns
317
+ id_col = fs.id_column
318
+ target_col = workbench_model.target()
262
319
  feature_cols = workbench_model.features()
263
- # Convert string features to categorical
320
+ print(f"Target column: {target_col}")
321
+ print(f"Feature columns: {len(feature_cols)} features")
322
+
323
+ # Convert string[python] to object, then to category for XGBoost compatibility
264
324
  for col in feature_cols:
265
- if df[col].dtype in ["object", "string"]:
266
- df[col] = df[col].astype("category")
267
- # Split X and y
268
- X = df[workbench_model.features()]
269
- y = df[workbench_model.target()]
325
+ if pd.api.types.is_string_dtype(df[col]):
326
+ df[col] = df[col].astype("object").astype("category")
327
+
328
+ X = df[feature_cols]
329
+ y = df[target_col]
330
+ ids = df[id_col]
270
331
 
271
- # Encode target if it's a classification problem
332
+ # Encode target if classifier
272
333
  label_encoder = LabelEncoder() if is_classifier else None
273
334
  if label_encoder:
274
- y = pd.Series(label_encoder.fit_transform(y), name=workbench_model.target())
335
+ y_encoded = label_encoder.fit_transform(y)
336
+ y_for_cv = pd.Series(y_encoded, index=y.index, name=target_col)
337
+ else:
338
+ y_for_cv = y
339
+
275
340
  # Prepare KFold
276
- kfold = (
277
- StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=42)
278
- if is_classifier
279
- else KFold(n_splits=nfolds, shuffle=True, random_state=42)
280
- )
341
+ kfold = (StratifiedKFold if is_classifier else KFold)(n_splits=nfolds, shuffle=True, random_state=42)
342
+
343
+ # Initialize results collection
344
+ fold_metrics = []
345
+ predictions_df = pd.DataFrame({id_col: ids, target_col: y})
281
346
 
282
- fold_results = []
283
- all_predictions = []
284
- all_actuals = []
285
- for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y)):
347
+ # Perform cross-validation
348
+ for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y_for_cv), 1):
286
349
  X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
287
- y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
350
+ y_train, y_val = y_for_cv.iloc[train_idx], y_for_cv.iloc[val_idx]
288
351
 
289
- # Train the model
290
- xgb_model.fit(X_train, y_train)
352
+ # Get sample weights for training fold
353
+ weights_train = sample_weights.iloc[train_idx] if sample_weights is not None else None
354
+
355
+ # Train and predict
356
+ xgb_model.fit(X_train, y_train, sample_weight=weights_train)
291
357
  preds = xgb_model.predict(X_val)
292
- all_predictions.extend(preds)
293
- all_actuals.extend(y_val)
294
358
 
295
- # Calculate metrics for this fold
296
- fold_metrics = {"fold": fold_idx + 1}
359
+ # Store predictions (decode if classifier)
360
+ val_indices = X_val.index
361
+ if is_classifier:
362
+ predictions_df.loc[val_indices, "prediction"] = label_encoder.inverse_transform(preds.astype(int))
363
+ y_proba = xgb_model.predict_proba(X_val)
364
+ predictions_df.loc[val_indices, "pred_proba"] = pd.Series(y_proba.tolist(), index=val_indices)
365
+ else:
366
+ predictions_df.loc[val_indices, "prediction"] = preds
297
367
 
368
+ # Calculate fold metrics
298
369
  if is_classifier:
299
- y_val_original = label_encoder.inverse_transform(y_val)
300
- preds_original = label_encoder.inverse_transform(preds.astype(int))
301
- scores = precision_recall_fscore_support(
302
- y_val_original, preds_original, average="weighted", zero_division=0
370
+ y_val_orig = label_encoder.inverse_transform(y_val)
371
+ preds_orig = label_encoder.inverse_transform(preds.astype(int))
372
+
373
+ # Overall weighted metrics
374
+ prec, rec, f1, _ = precision_recall_fscore_support(
375
+ y_val_orig, preds_orig, average="weighted", zero_division=0
376
+ )
377
+
378
+ # Per-class F1
379
+ prec_per_class, rec_per_class, f1_per_class, _ = precision_recall_fscore_support(
380
+ y_val_orig, preds_orig, average=None, zero_division=0, labels=label_encoder.classes_
381
+ )
382
+
383
+ # ROC-AUC (overall and per-class)
384
+ roc_auc_overall = roc_auc_score(y_val, y_proba, multi_class="ovr", average="macro")
385
+ roc_auc_per_class = roc_auc_score(y_val, y_proba, multi_class="ovr", average=None)
386
+
387
+ fold_metrics.append(
388
+ {
389
+ "fold": fold_idx,
390
+ "precision": prec,
391
+ "recall": rec,
392
+ "f1": f1,
393
+ "roc_auc": roc_auc_overall,
394
+ "precision_per_class": prec_per_class,
395
+ "recall_per_class": rec_per_class,
396
+ "f1_per_class": f1_per_class,
397
+ "roc_auc_per_class": roc_auc_per_class,
398
+ }
303
399
  )
304
- fold_metrics.update({"precision": float(scores[0]), "recall": float(scores[1]), "fscore": float(scores[2])})
305
400
  else:
306
- fold_metrics.update(
401
+ spearman_corr, _ = spearmanr(y_val, preds)
402
+ fold_metrics.append(
307
403
  {
308
- "rmse": float(np.sqrt(mean_squared_error(y_val, preds))),
309
- "mae": float(mean_absolute_error(y_val, preds)),
310
- "r2": float(r2_score(y_val, preds)),
404
+ "fold": fold_idx,
405
+ "rmse": np.sqrt(mean_squared_error(y_val, preds)),
406
+ "mae": mean_absolute_error(y_val, preds),
407
+ "medae": median_absolute_error(y_val, preds),
408
+ "r2": r2_score(y_val, preds),
409
+ "spearmanr": spearman_corr,
311
410
  }
312
411
  )
313
412
 
314
- fold_results.append(fold_metrics)
315
- # Calculate overall metrics
316
- overall_metrics = {}
413
+ # Calculate summary metrics
414
+ fold_df = pd.DataFrame(fold_metrics)
415
+
317
416
  if is_classifier:
318
- all_actuals_original = label_encoder.inverse_transform(all_actuals)
319
- all_predictions_original = label_encoder.inverse_transform(all_predictions)
320
- scores = precision_recall_fscore_support(
321
- all_actuals_original, all_predictions_original, average="weighted", zero_division=0
322
- )
323
- overall_metrics.update(
417
+ # Expand the *_proba columns into separate columns for easier handling
418
+ predictions_df = expand_proba_column(predictions_df, label_encoder.classes_)
419
+
420
+ # Build per-class metrics DataFrame
421
+ metric_rows = []
422
+
423
+ # Per-class rows
424
+ for idx, class_name in enumerate(label_encoder.classes_):
425
+ prec_scores = np.array([fold["precision_per_class"][idx] for fold in fold_metrics])
426
+ rec_scores = np.array([fold["recall_per_class"][idx] for fold in fold_metrics])
427
+ f1_scores = np.array([fold["f1_per_class"][idx] for fold in fold_metrics])
428
+ roc_auc_scores = np.array([fold["roc_auc_per_class"][idx] for fold in fold_metrics])
429
+
430
+ y_orig = label_encoder.inverse_transform(y_for_cv)
431
+ support = int((y_orig == class_name).sum())
432
+
433
+ metric_rows.append(
434
+ {
435
+ "class": class_name,
436
+ "precision": prec_scores.mean(),
437
+ "recall": rec_scores.mean(),
438
+ "f1": f1_scores.mean(),
439
+ "roc_auc": roc_auc_scores.mean(),
440
+ "support": support,
441
+ }
442
+ )
443
+
444
+ # Overall 'all' row
445
+ metric_rows.append(
324
446
  {
325
- "precision": float(scores[0]),
326
- "recall": float(scores[1]),
327
- "fscore": float(scores[2]),
328
- "confusion_matrix": confusion_matrix(
329
- all_actuals_original, all_predictions_original, labels=label_encoder.classes_
330
- ).tolist(),
331
- "label_names": list(label_encoder.classes_),
447
+ "class": "all",
448
+ "precision": fold_df["precision"].mean(),
449
+ "recall": fold_df["recall"].mean(),
450
+ "f1": fold_df["f1"].mean(),
451
+ "roc_auc": fold_df["roc_auc"].mean(),
452
+ "support": len(y_for_cv),
332
453
  }
333
454
  )
455
+
456
+ metrics_df = pd.DataFrame(metric_rows)
457
+
334
458
  else:
335
- overall_metrics.update(
336
- {
337
- "rmse": float(np.sqrt(mean_squared_error(all_actuals, all_predictions))),
338
- "mae": float(mean_absolute_error(all_actuals, all_predictions)),
339
- "r2": float(r2_score(all_actuals, all_predictions)),
340
- }
459
+ # Regression metrics
460
+ metrics_df = pd.DataFrame(
461
+ [
462
+ {
463
+ "rmse": fold_df["rmse"].mean(),
464
+ "mae": fold_df["mae"].mean(),
465
+ "medae": fold_df["medae"].mean(),
466
+ "r2": fold_df["r2"].mean(),
467
+ "spearmanr": fold_df["spearmanr"].mean(),
468
+ "support": len(y_for_cv),
469
+ }
470
+ ]
341
471
  )
342
- # Calculate summary metrics across folds
343
- summary_metrics = {}
344
- metrics_to_aggregate = ["precision", "recall", "fscore"] if is_classifier else ["rmse", "mae", "r2"]
345
-
346
- for metric in metrics_to_aggregate:
347
- values = [fold[metric] for fold in fold_results]
348
- summary_metrics[metric] = f"{float(np.mean(values)):.3f} ±{float(np.std(values)):.3f}"
349
- # Format fold results as strings (TBD section)
350
- formatted_folds = {}
351
- for fold_data in fold_results:
352
- fold_key = f"Fold {fold_data['fold']}"
353
- if is_classifier:
354
- formatted_folds[fold_key] = (
355
- f"precision: {fold_data['precision']:.3f} "
356
- f"recall: {fold_data['recall']:.3f} "
357
- f"fscore: {fold_data['fscore']:.3f}"
358
- )
359
- else:
360
- formatted_folds[fold_key] = (
361
- f"rmse: {fold_data['rmse']:.3f} mae: {fold_data['mae']:.3f} r2: {fold_data['r2']:.3f}"
362
- )
363
- # Return the results
364
- return {
365
- "summary_metrics": summary_metrics,
366
- # "overall_metrics": overall_metrics,
367
- "folds": formatted_folds,
368
- }
472
+
473
+ return metrics_df, predictions_df
474
+
475
+
476
+ def leave_one_out_inference(workbench_model: Any) -> pd.DataFrame:
477
+ """
478
+ Performs leave-one-out cross-validation (parallelized).
479
+ For datasets > 1000 rows, first identifies top 100 worst predictions via 10-fold CV,
480
+ then performs true leave-one-out on those 100 samples.
481
+ Each model trains on ALL data except one sample.
482
+ """
483
+ from workbench.api import FeatureSet
484
+ from joblib import Parallel, delayed
485
+ from tqdm import tqdm
486
+
487
+ def train_and_predict_one(model_params, is_classifier, X, y, train_idx, val_idx):
488
+ """Train on train_idx, predict on val_idx."""
489
+ model = xgb.XGBClassifier(**model_params) if is_classifier else xgb.XGBRegressor(**model_params)
490
+ model.fit(X[train_idx], y[train_idx])
491
+ return model.predict(X[val_idx])[0]
492
+
493
+ # Load model and get params
494
+ model_artifact_uri = workbench_model.model_data_url()
495
+ loaded_model = xgboost_model_from_s3(model_artifact_uri)
496
+ if loaded_model is None:
497
+ log.error("No XGBoost model found in the artifact.")
498
+ return pd.DataFrame()
499
+
500
+ if isinstance(loaded_model, (xgb.XGBClassifier, xgb.XGBRegressor)):
501
+ is_classifier = isinstance(loaded_model, xgb.XGBClassifier)
502
+ model_params = loaded_model.get_params()
503
+ elif isinstance(loaded_model, xgb.Booster):
504
+ log.warning("Deprecated: Loaded model is a Booster, wrapping in sklearn model.")
505
+ is_classifier = workbench_model.model_type.value == "classifier"
506
+ model_params = {"enable_categorical": True}
507
+ else:
508
+ log.error(f"Unexpected model type: {type(loaded_model)}")
509
+ return pd.DataFrame()
510
+
511
+ # Load and prepare data
512
+ fs = FeatureSet(workbench_model.get_input())
513
+ df = workbench_model.training_view().pull_dataframe()
514
+ id_col = fs.id_column
515
+ target_col = workbench_model.target()
516
+ feature_cols = workbench_model.features()
517
+
518
+ # Convert string[python] to object, then to category for XGBoost compatibility
519
+ # This avoids XGBoost's issue with pandas 2.x string[python] dtype in categorical categories
520
+ for col in feature_cols:
521
+ if pd.api.types.is_string_dtype(df[col]):
522
+ # Double conversion: string[python] -> object -> category
523
+ df[col] = df[col].astype("object").astype("category")
524
+
525
+ # Determine which samples to run LOO on
526
+ if len(df) > 1000:
527
+ log.important(f"Dataset has {len(df)} rows. Running 10-fold CV to identify top 1000 worst predictions...")
528
+ _, predictions_df = cross_fold_inference(workbench_model, nfolds=10)
529
+ predictions_df["residual_abs"] = np.abs(predictions_df[target_col] - predictions_df["prediction"])
530
+ worst_samples = predictions_df.nlargest(1000, "residual_abs")
531
+ worst_ids = worst_samples[id_col].values
532
+ loo_indices = df[df[id_col].isin(worst_ids)].index.values
533
+ log.important(f"Running leave-one-out CV on 1000 worst samples. Each model trains on {len(df)-1} rows...")
534
+ else:
535
+ log.important(f"Running leave-one-out CV on all {len(df)} samples...")
536
+ loo_indices = df.index.values
537
+
538
+ # Prepare full dataset for training
539
+ X_full = df[feature_cols].values
540
+ y_full = df[target_col].values
541
+
542
+ # Encode target if classifier
543
+ label_encoder = LabelEncoder() if is_classifier else None
544
+ if label_encoder:
545
+ y_full = label_encoder.fit_transform(y_full)
546
+
547
+ # Generate LOO splits
548
+ splits = []
549
+ for loo_idx in loo_indices:
550
+ train_idx = np.delete(np.arange(len(X_full)), loo_idx)
551
+ val_idx = np.array([loo_idx])
552
+ splits.append((train_idx, val_idx))
553
+
554
+ # Parallel execution
555
+ predictions = Parallel(n_jobs=4)(
556
+ delayed(train_and_predict_one)(model_params, is_classifier, X_full, y_full, train_idx, val_idx)
557
+ for train_idx, val_idx in tqdm(splits, desc="LOO CV")
558
+ )
559
+
560
+ # Build results dataframe
561
+ predictions_array = np.array(predictions)
562
+ if label_encoder:
563
+ predictions_array = label_encoder.inverse_transform(predictions_array.astype(int))
564
+
565
+ predictions_df = pd.DataFrame(
566
+ {
567
+ id_col: df.loc[loo_indices, id_col].values,
568
+ target_col: df.loc[loo_indices, target_col].values,
569
+ "prediction": predictions_array,
570
+ }
571
+ )
572
+
573
+ predictions_df["residual_abs"] = np.abs(predictions_df[target_col] - predictions_df["prediction"])
574
+
575
+ return predictions_df
369
576
 
370
577
 
371
578
  if __name__ == "__main__":
372
579
  """Exercise the Model Utilities"""
373
- from workbench.api import Model, FeatureSet
580
+ from workbench.api import Model
374
581
  from pprint import pprint
375
582
 
376
583
  # Test the XGBoost model loading and feature importance
@@ -383,11 +590,28 @@ if __name__ == "__main__":
383
590
  model_artifact_uri = model.model_data_url()
384
591
  xgb_model = xgboost_model_from_s3(model_artifact_uri)
385
592
 
593
+ # Verify enable_categorical is preserved (for debugging/confidence)
594
+ print(f"Model parameters: {xgb_model.get_params()}")
595
+ print(f"enable_categorical: {xgb_model.enable_categorical}")
596
+
386
597
  # Test with UQ Model
387
598
  uq_model = Model("aqsol-uq")
388
599
  _xgb_model = xgboost_model_from_s3(uq_model.model_data_url())
389
600
 
601
+ print("\n=== CROSS FOLD REGRESSION EXAMPLE ===")
602
+ model = Model("abalone-regression")
603
+ results, df = cross_fold_inference(model)
604
+ pprint(results)
605
+ print(df.head())
606
+
607
+ print("\n=== CROSS FOLD CLASSIFICATION EXAMPLE ===")
608
+ model = Model("wine-classification")
609
+ results, df = cross_fold_inference(model)
610
+ pprint(results)
611
+ print(df.head())
612
+
390
613
  # Test XGBoost add_leaf_hash
614
+ """
391
615
  input_df = FeatureSet(model.get_input()).pull_dataframe()
392
616
  leaf_df = add_leaf_hash(model, input_df)
393
617
  print("DataFrame with Leaf Hash:")
@@ -404,13 +628,4 @@ if __name__ == "__main__":
404
628
  stats_df = leaf_stats(leaf_df, target_col)
405
629
  print("DataFrame with Leaf Statistics:")
406
630
  print(stats_df)
407
-
408
- print("\n=== CROSS FOLD REGRESSION EXAMPLE ===")
409
- model = Model("abalone-regression")
410
- results = cross_fold_inference(model)
411
- pprint(results)
412
-
413
- print("\n=== CROSS FOLD CLASSIFICATION EXAMPLE ===")
414
- model = Model("wine-classification")
415
- results = cross_fold_inference(model)
416
- pprint(results)
631
+ """