workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,87 @@
+ """PyTorch Tabular utilities for Workbench models."""
+
+ import logging
+ import os
+ import tarfile
+ import tempfile
+ from typing import Any, Tuple
+
+ import awswrangler as wr
+ import pandas as pd
+
+ from workbench.utils.aws_utils import pull_s3_data
+ from workbench.utils.metrics_utils import compute_metrics_from_predictions
+
+ log = logging.getLogger("workbench")
+
+
+ def download_and_extract_model(s3_uri: str, model_dir: str) -> None:
+     """Download and extract a PyTorch model artifact from S3.
+
+     Args:
+         s3_uri: S3 URI of the model.tar.gz artifact
+         model_dir: Local directory to extract the model to
+     """
+     with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
+         tmp_path = tmp.name
+
+     try:
+         wr.s3.download(path=s3_uri, local_file=tmp_path)
+         with tarfile.open(tmp_path, "r:gz") as tar:
+             tar.extractall(model_dir)
+         log.info(f"Extracted model to {model_dir}")
+     finally:
+         if os.path.exists(tmp_path):
+             os.remove(tmp_path)
+
+
+ def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
+     """Pull cross-validation results from AWS training artifacts.
+
+     This retrieves the validation predictions saved during model training and
+     computes metrics directly from them. For PyTorch models trained with
+     n_folds > 1, these are out-of-fold predictions from k-fold cross-validation.
+
+     Args:
+         workbench_model: Workbench model object
+
+     Returns:
+         Tuple of:
+             - DataFrame with computed metrics
+             - DataFrame with validation predictions
+     """
+     # Get the validation predictions from S3
+     s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
+     predictions_df = pull_s3_data(s3_path)
+
+     if predictions_df is None:
+         raise ValueError(f"No validation predictions found at {s3_path}")
+
+     log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")
+
+     # Compute metrics from predictions
+     target = workbench_model.target()
+     class_labels = workbench_model.class_labels()
+
+     if target in predictions_df.columns and "prediction" in predictions_df.columns:
+         metrics_df = compute_metrics_from_predictions(predictions_df, target, class_labels)
+     else:
+         metrics_df = pd.DataFrame()
+
+     return metrics_df, predictions_df
+
+
+ if __name__ == "__main__":
+     from workbench.api import Model
+
+     # Test pulling CV results
+     model_name = "aqsol-reg-pytorch"
+     print(f"Loading Workbench model: {model_name}")
+     model = Model(model_name)
+     print(f"Model Framework: {model.model_framework}")
+
+     # Pull CV results from training artifacts
+     metrics_df, predictions_df = pull_cv_results(model)
+     print(f"\nMetrics:\n{metrics_df}")
+     print(f"\nPredictions shape: {predictions_df.shape}")
+     print(f"Predictions columns: {predictions_df.columns.tolist()}")
@@ -9,6 +9,7 @@ from typing import Optional, List, Tuple, Dict, Union
  from workbench.utils.xgboost_model_utils import xgboost_model_from_s3
  from workbench.utils.model_utils import load_category_mappings_from_s3
  from workbench.utils.pandas_utils import convert_categorical_types
+ from workbench.model_script_utils.model_script_utils import decompress_features

  # Set up the log
  log = logging.getLogger("workbench")
@@ -111,61 +112,6 @@ def shap_values_data(
      return result_df, feature_df


- def decompress_features(
-     df: pd.DataFrame, features: List[str], compressed_features: List[str]
- ) -> Tuple[pd.DataFrame, List[str]]:
-     """Prepare features for the XGBoost model
-
-     Args:
-         df (pd.DataFrame): The features DataFrame
-         features (List[str]): Full list of feature names
-         compressed_features (List[str]): List of feature names to decompress (bitstrings)
-
-     Returns:
-         pd.DataFrame: DataFrame with the decompressed features
-         List[str]: Updated list of feature names after decompression
-
-     Raises:
-         ValueError: If any missing values are found in the specified features
-     """
-
-     # Check for any missing values in the required features
-     missing_counts = df[features].isna().sum()
-     if missing_counts.any():
-         missing_features = missing_counts[missing_counts > 0]
-         print(
-             f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
-             "WARNING: You might want to remove/replace all NaN values before processing."
-         )
-
-     # Decompress the specified compressed features
-     decompressed_features = features
-     for feature in compressed_features:
-         if (feature not in df.columns) or (feature not in features):
-             print(f"Feature '{feature}' not in the features list, skipping decompression.")
-             continue
-
-         # Remove the feature from the list of features to avoid duplication
-         decompressed_features.remove(feature)
-
-         # Handle all compressed features as bitstrings
-         bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
-         prefix = feature[:3]
-
-         # Create all new columns at once - avoids fragmentation
-         new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
-         new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
-
-         # Add to features list
-         decompressed_features.extend(new_col_names)
-
-         # Drop original column and concatenate new ones
-         df = df.drop(columns=[feature])
-         df = pd.concat([df, new_df], axis=1)
-
-     return df, decompressed_features
-
-
  def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
      """
      Internal function to calculate SHAP values for Workbench Models.
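
The helper was not dropped: per the import hunk above, it now lives in workbench.model_script_utils.model_script_utils. A minimal sketch of its bitstring expansion, assuming the relocated function keeps the signature and behavior shown in the removed block:

import pandas as pd
from workbench.model_script_utils.model_script_utils import decompress_features

df = pd.DataFrame({"fingerprint": ["1010", "0110"], "logp": [1.2, 3.4]})
df, features = decompress_features(df, ["fingerprint", "logp"], compressed_features=["fingerprint"])
# The bitstring column is replaced by one uint8 column per bit, named from
# the first three characters of the original feature name
print(features)  # ['logp', 'fin_0', 'fin_1', 'fin_2', 'fin_3']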
@@ -212,6 +158,14 @@ def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
          log.error("No XGBoost model found in the artifact.")
          return None, None, None, None

+     # Get the booster (SHAP requires the booster, not the sklearn wrapper)
+     if hasattr(xgb_model, "get_booster"):
+         # Full sklearn model - extract the booster
+         booster = xgb_model.get_booster()
+     else:
+         # Already a booster
+         booster = xgb_model
+
      # Load category mappings if available
      category_mappings = load_category_mappings_from_s3(model_artifact_uri)
@@ -229,8 +183,8 @@
      # Create a DMatrix with categorical support
      dmatrix = xgb.DMatrix(X, enable_categorical=True)

-     # Use XGBoost's built-in SHAP calculation
-     shap_values = xgb_model.predict(dmatrix, pred_contribs=True, strict_shape=True)
+     # Use XGBoost's built-in SHAP calculation (booster method, not sklearn)
+     shap_values = booster.predict(dmatrix, pred_contribs=True, strict_shape=True)
      features_with_bias = features + ["bias"]

      # Now we need to subset the columns based on top 10 SHAP values
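
The change reflects that pred_contribs is an option on Booster.predict, not on the sklearn wrapper's predict(). A standalone sketch of the pattern (X here is any feature DataFrame; the variable names are illustrative):

import xgboost as xgb

# model may be an sklearn wrapper (XGBClassifier/XGBRegressor) or a raw Booster
booster = model.get_booster() if hasattr(model, "get_booster") else model
dmatrix = xgb.DMatrix(X, enable_categorical=True)
# Per-feature SHAP contributions; the final column is the bias term
shap_values = booster.predict(dmatrix, pred_contribs=True, strict_shape=True)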
@@ -181,9 +181,6 @@ def logging_setup(color_logs=True):
          log.debug("Debugging enabled via WORKBENCH_DEBUG environment variable.")
      else:
          log.setLevel(logging.INFO)
-     # Note: Not using the ThrottlingFilter for now
-     # throttle_filter = ThrottlingFilter(rate_seconds=5)
-     # handler.addFilter(throttle_filter)

      # Suppress specific logger
      logging.getLogger("sagemaker.config").setLevel(logging.WARNING)
@@ -12,7 +12,7 @@ class WorkbenchSQS:
          self.log = logging.getLogger("workbench")
          self.queue_url = queue_url

-         # Grab a Workbench Session (this allows us to assume the Workbench-ExecutionRole)
+         # Grab a Workbench Session
          self.boto3_session = AWSAccountClamp().boto3_session
          print(self.boto3_session)
@@ -0,0 +1,267 @@
+ """XGBoost Local Cross-Fold Validation Utilities
+
+ This module contains functions for running cross-validation locally on XGBoost models.
+ For most use cases, prefer using pull_cv_results() from xgboost_model_utils.py which
+ retrieves the CV results that were saved during training on SageMaker.
+
+ These local cross-fold functions are useful for:
+ - Re-running CV with different fold counts
+ - Leave-one-out cross-validation
+ - Custom CV experiments
+ """
+
+ import logging
+ from typing import Any, Tuple
+
+ import numpy as np
+ import pandas as pd
+ import xgboost as xgb
+ from sklearn.model_selection import KFold, StratifiedKFold
+ from sklearn.preprocessing import LabelEncoder
+
+ from workbench.utils.metrics_utils import compute_metrics_from_predictions
+ from workbench.utils.pandas_utils import expand_proba_column
+ from workbench.utils.xgboost_model_utils import xgboost_model_from_s3
+
+ log = logging.getLogger("workbench")
+
+
+ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[pd.DataFrame, pd.DataFrame]:
+     """
+     Performs K-fold cross-validation locally with detailed metrics.
+
+     Note: For most use cases, prefer using pull_cv_results() from xgboost_model_utils.py
+     which retrieves the CV results that were saved during training.
+
+     Args:
+         workbench_model: Workbench model object
+         nfolds: Number of folds for cross-validation (default is 5)
+     Returns:
+         Tuple of:
+             - DataFrame with per-class metrics (and 'all' row for overall metrics)
+             - DataFrame with columns: id, target, prediction, and *_proba columns (for classifiers)
+     """
+     from workbench.api import FeatureSet
+
+     # Load model
+     model_artifact_uri = workbench_model.model_data_url()
+     loaded_model = xgboost_model_from_s3(model_artifact_uri)
+     if loaded_model is None:
+         log.error("No XGBoost model found in the artifact.")
+         return pd.DataFrame(), pd.DataFrame()
+
+     # Check if we got a full sklearn model or need to create one
+     if isinstance(loaded_model, (xgb.XGBClassifier, xgb.XGBRegressor)):
+         is_classifier = isinstance(loaded_model, xgb.XGBClassifier)
+
+         # Get the model's hyperparameters and ensure enable_categorical=True
+         params = loaded_model.get_params()
+         params["enable_categorical"] = True
+
+         # Create new model with same params but enable_categorical=True
+         if is_classifier:
+             xgb_model = xgb.XGBClassifier(**params)
+         else:
+             xgb_model = xgb.XGBRegressor(**params)
+
+     elif isinstance(loaded_model, xgb.Booster):
+         # Legacy: got a booster, need to wrap it
+         log.warning("Deprecated: Loaded model is a Booster, wrapping in sklearn model.")
+         is_classifier = workbench_model.model_type.value == "classifier"
+         xgb_model = (
+             xgb.XGBClassifier(enable_categorical=True) if is_classifier else xgb.XGBRegressor(enable_categorical=True)
+         )
+         xgb_model._Booster = loaded_model
+     else:
+         log.error(f"Unexpected model type: {type(loaded_model)}")
+         return pd.DataFrame(), pd.DataFrame()
+
+     # Prepare data
+     fs = FeatureSet(workbench_model.get_input())
+     df = workbench_model.training_view().pull_dataframe()
+
+     # Extract sample weights if present
+     sample_weights = df.get("sample_weight")
+     if sample_weights is not None:
+         log.info(f"Using sample weights: min={sample_weights.min():.2f}, max={sample_weights.max():.2f}")
+
+     # Get columns
+     id_col = fs.id_column
+     target_col = workbench_model.target()
+     feature_cols = workbench_model.features()
+     print(f"Target column: {target_col}")
+     print(f"Feature columns: {len(feature_cols)} features")
+
+     # Convert string[python] to object, then to category for XGBoost compatibility
+     for col in feature_cols:
+         if pd.api.types.is_string_dtype(df[col]):
+             df[col] = df[col].astype("object").astype("category")
+
+     X = df[feature_cols]
+     y = df[target_col]
+     ids = df[id_col]
+
+     # Encode target if classifier
+     label_encoder = LabelEncoder() if is_classifier else None
+     if label_encoder:
+         y_encoded = label_encoder.fit_transform(y)
+         y_for_cv = pd.Series(y_encoded, index=y.index, name=target_col)
+     else:
+         y_for_cv = y
+
+     # Prepare KFold
+     kfold = (StratifiedKFold if is_classifier else KFold)(n_splits=nfolds, shuffle=True, random_state=42)
+
+     # Initialize predictions DataFrame
+     predictions_df = pd.DataFrame({id_col: ids, target_col: y})
+
+     # Perform cross-validation
+     for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y_for_cv), 1):
+         X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
+         y_train = y_for_cv.iloc[train_idx]
+
+         # Get sample weights for training fold
+         weights_train = sample_weights.iloc[train_idx] if sample_weights is not None else None
+
+         # Train and predict
+         xgb_model.fit(X_train, y_train, sample_weight=weights_train)
+         preds = xgb_model.predict(X_val)
+
+         # Store predictions (decode if classifier)
+         val_indices = X_val.index
+         if is_classifier:
+             predictions_df.loc[val_indices, "prediction"] = label_encoder.inverse_transform(preds.astype(int))
+             y_proba = xgb_model.predict_proba(X_val)
+             predictions_df.loc[val_indices, "pred_proba"] = pd.Series(y_proba.tolist(), index=val_indices)
+         else:
+             predictions_df.loc[val_indices, "prediction"] = preds
+
+     # Expand proba columns for classifiers
+     if is_classifier:
+         predictions_df = expand_proba_column(predictions_df, label_encoder.classes_)
+
+     # Compute metrics from the complete out-of-fold predictions
+     class_labels = list(label_encoder.classes_) if is_classifier else None
+     metrics_df = compute_metrics_from_predictions(predictions_df, target_col, class_labels)
+
+     return metrics_df, predictions_df
+
+
+ def leave_one_out_inference(workbench_model: Any) -> pd.DataFrame:
+     """
+     Performs leave-one-out cross-validation (parallelized).
+     For datasets > 1000 rows, first identifies the 1000 worst predictions via 10-fold CV,
+     then performs true leave-one-out on those samples.
+     Each model trains on ALL data except one sample.
+     """
+     from workbench.api import FeatureSet
+     from joblib import Parallel, delayed
+     from tqdm import tqdm
+
+     def train_and_predict_one(model_params, is_classifier, X, y, train_idx, val_idx):
+         """Train on train_idx, predict on val_idx."""
+         model = xgb.XGBClassifier(**model_params) if is_classifier else xgb.XGBRegressor(**model_params)
+         model.fit(X[train_idx], y[train_idx])
+         return model.predict(X[val_idx])[0]
+
+     # Load model and get params
+     model_artifact_uri = workbench_model.model_data_url()
+     loaded_model = xgboost_model_from_s3(model_artifact_uri)
+     if loaded_model is None:
+         log.error("No XGBoost model found in the artifact.")
+         return pd.DataFrame()
+
+     if isinstance(loaded_model, (xgb.XGBClassifier, xgb.XGBRegressor)):
+         is_classifier = isinstance(loaded_model, xgb.XGBClassifier)
+         model_params = loaded_model.get_params()
+     elif isinstance(loaded_model, xgb.Booster):
+         log.warning("Deprecated: Loaded model is a Booster, wrapping in sklearn model.")
+         is_classifier = workbench_model.model_type.value == "classifier"
+         model_params = {"enable_categorical": True}
+     else:
+         log.error(f"Unexpected model type: {type(loaded_model)}")
+         return pd.DataFrame()
+
+     # Load and prepare data
+     fs = FeatureSet(workbench_model.get_input())
+     df = workbench_model.training_view().pull_dataframe()
+     id_col = fs.id_column
+     target_col = workbench_model.target()
+     feature_cols = workbench_model.features()
+
+     # Convert string[python] to object, then to category for XGBoost compatibility
+     # This avoids XGBoost's issue with pandas 2.x string[python] dtype in categorical categories
+     for col in feature_cols:
+         if pd.api.types.is_string_dtype(df[col]):
+             # Double conversion: string[python] -> object -> category
+             df[col] = df[col].astype("object").astype("category")
+
+     # Determine which samples to run LOO on
+     if len(df) > 1000:
+         log.important(f"Dataset has {len(df)} rows. Running 10-fold CV to identify top 1000 worst predictions...")
+         _, predictions_df = cross_fold_inference(workbench_model, nfolds=10)
+         predictions_df["residual_abs"] = np.abs(predictions_df[target_col] - predictions_df["prediction"])
+         worst_samples = predictions_df.nlargest(1000, "residual_abs")
+         worst_ids = worst_samples[id_col].values
+         loo_indices = df[df[id_col].isin(worst_ids)].index.values
+         log.important(f"Running leave-one-out CV on 1000 worst samples. Each model trains on {len(df)-1} rows...")
+     else:
+         log.important(f"Running leave-one-out CV on all {len(df)} samples...")
+         loo_indices = df.index.values
+
+     # Prepare full dataset for training
+     X_full = df[feature_cols].values
+     y_full = df[target_col].values
+
+     # Encode target if classifier
+     label_encoder = LabelEncoder() if is_classifier else None
+     if label_encoder:
+         y_full = label_encoder.fit_transform(y_full)
+
+     # Generate LOO splits
+     splits = []
+     for loo_idx in loo_indices:
+         train_idx = np.delete(np.arange(len(X_full)), loo_idx)
+         val_idx = np.array([loo_idx])
+         splits.append((train_idx, val_idx))
+
+     # Parallel execution
+     predictions = Parallel(n_jobs=4)(
+         delayed(train_and_predict_one)(model_params, is_classifier, X_full, y_full, train_idx, val_idx)
+         for train_idx, val_idx in tqdm(splits, desc="LOO CV")
+     )
+
+     # Build results dataframe
+     predictions_array = np.array(predictions)
+     if label_encoder:
+         predictions_array = label_encoder.inverse_transform(predictions_array.astype(int))
+
+     predictions_df = pd.DataFrame(
+         {
+             id_col: df.loc[loo_indices, id_col].values,
+             target_col: df.loc[loo_indices, target_col].values,
+             "prediction": predictions_array,
+         }
+     )
+
+     predictions_df["residual_abs"] = np.abs(predictions_df[target_col] - predictions_df["prediction"])
+
+     return predictions_df
+
+
+ if __name__ == "__main__":
+     """Exercise the Local Cross-Fold Utilities"""
+     from workbench.api import Model
+     from pprint import pprint
+
+     print("\n=== LOCAL CROSS FOLD REGRESSION EXAMPLE ===")
+     model = Model("abalone-regression")
+     results, df = cross_fold_inference(model)
+     pprint(results)
+     print(df.head())
+
+     print("\n=== LOCAL CROSS FOLD CLASSIFICATION EXAMPLE ===")
+     model = Model("wine-classification")
+     results, df = cross_fold_inference(model)
+     pprint(results)
+     print(df.head())
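
The __main__ block above exercises cross_fold_inference only; a hedged usage sketch for the leave-one-out path (model name reused from that block):

from workbench.api import Model
from workbench.utils.xgboost_local_crossfold import leave_one_out_inference

model = Model("abalone-regression")
loo_df = leave_one_out_inference(model)
# Returned columns: id, target, prediction, residual_abs
print(loo_df.nlargest(10, "residual_abs"))  # the ten hardest-to-predict samples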