workbench 0.8.168__py3-none-any.whl → 0.8.193__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries and reflects the changes between those versions as they appear in the public registry. It is provided for informational purposes only.
Files changed (90)
  1. workbench/algorithms/dataframe/proximity.py +143 -102
  2. workbench/algorithms/graph/light/proximity_graph.py +2 -1
  3. workbench/api/compound.py +1 -1
  4. workbench/api/endpoint.py +3 -2
  5. workbench/api/feature_set.py +4 -4
  6. workbench/api/model.py +16 -12
  7. workbench/api/monitor.py +1 -16
  8. workbench/core/artifacts/artifact.py +11 -3
  9. workbench/core/artifacts/data_capture_core.py +355 -0
  10. workbench/core/artifacts/endpoint_core.py +113 -27
  11. workbench/core/artifacts/feature_set_core.py +72 -13
  12. workbench/core/artifacts/model_core.py +71 -49
  13. workbench/core/artifacts/monitor_core.py +33 -249
  14. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  15. workbench/core/cloud_platform/aws/aws_meta.py +11 -4
  16. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  17. workbench/core/transforms/features_to_model/features_to_model.py +11 -6
  18. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  19. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  20. workbench/core/views/training_view.py +49 -53
  21. workbench/core/views/view.py +51 -1
  22. workbench/core/views/view_utils.py +4 -4
  23. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  24. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  25. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  26. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  27. workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
  28. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  29. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
  30. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  31. workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
  32. workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
  33. workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
  34. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  35. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
  36. workbench/model_scripts/pytorch_model/pytorch.template +9 -18
  37. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  38. workbench/model_scripts/script_generation.py +7 -2
  39. workbench/model_scripts/uq_models/mapie.template +492 -0
  40. workbench/model_scripts/uq_models/requirements.txt +1 -0
  41. workbench/model_scripts/xgb_model/generated_model_script.py +34 -43
  42. workbench/model_scripts/xgb_model/xgb_model.template +31 -40
  43. workbench/repl/workbench_shell.py +4 -4
  44. workbench/scripts/lambda_launcher.py +63 -0
  45. workbench/scripts/{ml_pipeline_launcher.py → ml_pipeline_batch.py} +49 -51
  46. workbench/scripts/ml_pipeline_sqs.py +186 -0
  47. workbench/utils/chem_utils/__init__.py +0 -0
  48. workbench/utils/chem_utils/fingerprints.py +134 -0
  49. workbench/utils/chem_utils/misc.py +194 -0
  50. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  51. workbench/utils/chem_utils/mol_standardize.py +450 -0
  52. workbench/utils/chem_utils/mol_tagging.py +348 -0
  53. workbench/utils/chem_utils/projections.py +209 -0
  54. workbench/utils/chem_utils/salts.py +256 -0
  55. workbench/utils/chem_utils/sdf.py +292 -0
  56. workbench/utils/chem_utils/toxicity.py +250 -0
  57. workbench/utils/chem_utils/vis.py +253 -0
  58. workbench/utils/config_manager.py +2 -6
  59. workbench/utils/endpoint_utils.py +5 -7
  60. workbench/utils/license_manager.py +2 -6
  61. workbench/utils/model_utils.py +89 -31
  62. workbench/utils/monitor_utils.py +44 -62
  63. workbench/utils/pandas_utils.py +3 -3
  64. workbench/utils/shap_utils.py +10 -2
  65. workbench/utils/workbench_sqs.py +1 -1
  66. workbench/utils/xgboost_model_utils.py +300 -151
  67. workbench/web_interface/components/model_plot.py +7 -1
  68. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  69. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  70. workbench/web_interface/components/plugins/model_details.py +7 -2
  71. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  72. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/METADATA +24 -2
  73. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/RECORD +77 -72
  74. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/entry_points.txt +3 -1
  75. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/licenses/LICENSE +1 -1
  76. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  77. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  78. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  79. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  80. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  81. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  82. workbench/model_scripts/pytorch_model/generated_model_script.py +0 -576
  83. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  84. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  85. workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
  86. workbench/utils/chem_utils.py +0 -1556
  87. workbench/utils/fast_inference.py +0 -167
  88. workbench/utils/resource_utils.py +0 -39
  89. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/WHEEL +0 -0
  90. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,355 @@
+ """DataCaptureCore class for managing SageMaker endpoint data capture"""
+
+ import logging
+ import re
+ import time
+ from datetime import datetime
+ from typing import Tuple
+ import pandas as pd
+ from sagemaker import Predictor
+ from sagemaker.model_monitor import DataCaptureConfig
+ import awswrangler as wr
+
+ # Workbench Imports
+ from workbench.core.artifacts.endpoint_core import EndpointCore
+ from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+ from workbench.utils.monitor_utils import process_data_capture
+
+ # Setup logging
+ log = logging.getLogger("workbench")
+
+
+ class DataCaptureCore:
+     """Manages data capture configuration and retrieval for SageMaker endpoints"""
+
+     def __init__(self, endpoint_name: str):
+         """DataCaptureCore Class
+
+         Args:
+             endpoint_name (str): Name of the endpoint to manage data capture for
+         """
+         self.log = logging.getLogger("workbench")
+         self.endpoint_name = endpoint_name
+         self.endpoint = EndpointCore(self.endpoint_name)
+
+         # Initialize Class Attributes
+         self.sagemaker_session = self.endpoint.sm_session
+         self.sagemaker_client = self.endpoint.sm_client
+         self.data_capture_path = self.endpoint.endpoint_data_capture_path
+         self.workbench_role_arn = AWSAccountClamp().aws_session.get_workbench_execution_role_arn()
+
+     def summary(self) -> dict:
+         """Return the summary of data capture configuration
+
+         Returns:
+             dict: Summary of data capture status
+         """
+         if self.endpoint.is_serverless():
+             return {"endpoint_type": "serverless", "data_capture": "not supported"}
+         else:
+             return {
+                 "endpoint_type": "realtime",
+                 "data_capture_enabled": self.is_enabled(),
+                 "capture_percentage": self.capture_percentage(),
+                 "capture_modes": self.capture_modes() if self.is_enabled() else [],
+                 "data_capture_path": self.data_capture_path if self.is_enabled() else None,
+             }
+
+     def enable(self, capture_percentage=100, capture_options=None, force_redeploy=False):
+         """
+         Enable data capture for the SageMaker endpoint.
+
+         Args:
+             capture_percentage (int): Percentage of data to capture. Defaults to 100.
+             capture_options (list): List of what to capture - ["REQUEST"], ["RESPONSE"], or ["REQUEST", "RESPONSE"].
+                 Defaults to ["REQUEST", "RESPONSE"] to capture both.
+             force_redeploy (bool): If True, force redeployment even if data capture is already enabled.
+         """
+         # Early returns for cases where we can't/don't need to add data capture
+         if self.endpoint.is_serverless():
+             self.log.warning("Data capture is not supported for serverless endpoints.")
+             return
+
+         # Default to capturing both if not specified
+         if capture_options is None:
+             capture_options = ["REQUEST", "RESPONSE"]
+
+         # Validate capture_options
+         valid_options = {"REQUEST", "RESPONSE"}
+         if not all(opt in valid_options for opt in capture_options):
+             self.log.error("Invalid capture_options. Must be a list containing 'REQUEST' and/or 'RESPONSE'")
+             return
+
+         if self.is_enabled() and not force_redeploy:
+             self.log.important(f"Data capture already configured for {self.endpoint_name}.")
+             return
+
+         # Get the current endpoint configuration name for later deletion
+         current_endpoint_config_name = self.endpoint.endpoint_config_name()
+
+         # Log the data capture operation
+         self.log.important(f"Enabling Data Capture for {self.endpoint_name} --> {self.data_capture_path}")
+         self.log.important(f"Capturing: {', '.join(capture_options)} at {capture_percentage}% sampling")
+         self.log.important("This will redeploy the endpoint...")
+
+         # Create and apply the data capture configuration
+         data_capture_config = DataCaptureConfig(
+             enable_capture=True,
+             sampling_percentage=capture_percentage,
+             destination_s3_uri=self.data_capture_path,
+             capture_options=capture_options,
+         )
+
+         # Update endpoint with the new capture configuration
+         Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
+             data_capture_config=data_capture_config
+         )
+
+         # Clean up old endpoint configuration
+         try:
+             self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
+             self.log.info(f"Deleted old endpoint configuration: {current_endpoint_config_name}")
+         except Exception as e:
+             self.log.warning(f"Could not delete old endpoint configuration {current_endpoint_config_name}: {e}")
+
+     def disable(self):
+         """
+         Disable data capture for the SageMaker endpoint.
+         """
+         # Early return if data capture isn't configured
+         if not self.is_enabled():
+             self.log.important(f"Data capture is not currently enabled for {self.endpoint_name}.")
+             return
+
+         # Get the current endpoint configuration name for later deletion
+         current_endpoint_config_name = self.endpoint.endpoint_config_name()
+
+         # Log the operation
+         self.log.important(f"Disabling Data Capture for {self.endpoint_name}")
+         self.log.important("This normally redeploys the endpoint...")
+
+         # Create a configuration with capture disabled
+         data_capture_config = DataCaptureConfig(enable_capture=False, destination_s3_uri=self.data_capture_path)
+
+         # Update endpoint with the new configuration
+         Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
+             data_capture_config=data_capture_config
+         )
+
+         # Clean up old endpoint configuration
+         self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
+
+     def is_enabled(self) -> bool:
+         """
+         Check if data capture is enabled on the endpoint.
+
+         Returns:
+             bool: True if data capture is enabled, False otherwise.
+         """
+         try:
+             endpoint_config_name = self.endpoint.endpoint_config_name()
+             endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
+             data_capture_config = endpoint_config.get("DataCaptureConfig", {})
+
+             # Check if data capture is enabled
+             is_enabled = data_capture_config.get("EnableCapture", False)
+             return is_enabled
+         except Exception as e:
+             self.log.error(f"Error checking data capture configuration: {e}")
+             return False
+
+     def capture_percentage(self) -> int:
+         """
+         Get the data capture percentage from the endpoint configuration.
+
+         Returns:
+             int: Data capture percentage if enabled, None otherwise.
+         """
+         try:
+             endpoint_config_name = self.endpoint.endpoint_config_name()
+             endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
+             data_capture_config = endpoint_config.get("DataCaptureConfig", {})
+
+             # Check if data capture is enabled and return the percentage
+             if data_capture_config.get("EnableCapture", False):
+                 return data_capture_config.get("InitialSamplingPercentage", 0)
+             else:
+                 return None
+         except Exception as e:
+             self.log.error(f"Error checking data capture percentage: {e}")
+             return None
+
+     def get_config(self) -> dict:
+         """
+         Returns the complete data capture configuration from the endpoint config.
+
+         Returns:
+             dict: Complete DataCaptureConfig from AWS, or None if not configured
+         """
+         config_name = self.endpoint.endpoint_config_name()
+         response = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=config_name)
+         data_capture_config = response.get("DataCaptureConfig")
+         if not data_capture_config:
+             self.log.error(f"No data capture configuration found for endpoint config {config_name}")
+             return None
+         return data_capture_config
+
+     def capture_modes(self) -> list:
+         """Get the current capture modes (REQUEST/RESPONSE)"""
+         if not self.is_enabled():
+             return []
+
+         config = self.get_config()
+         if not config:
+             return []
+
+         capture_options = config.get("CaptureOptions", [])
+         modes = [opt.get("CaptureMode") for opt in capture_options]
+         return ["REQUEST" if m == "Input" else "RESPONSE" for m in modes if m]
+
+     def get_captured_data(self, from_date: str = None, add_timestamp: bool = True) -> Tuple[pd.DataFrame, pd.DataFrame]:
+         """
+         Read and process captured data from S3.
+
+         Args:
+             from_date (str, optional): Only process files from this date onwards (YYYY-MM-DD format).
+                 Defaults to None to process all files.
+             add_timestamp (bool, optional): Whether to add a timestamp column to the DataFrame.
+
+         Returns:
+             Tuple[pd.DataFrame, pd.DataFrame]: Processed input and output DataFrames.
+         """
+         files = wr.s3.list_objects(self.data_capture_path)
+         if not files:
+             self.log.warning(f"No data capture files found in {self.data_capture_path}.")
+             return pd.DataFrame(), pd.DataFrame()
+
+         # Filter by date if specified
+         if from_date:
+             from_date_obj = datetime.strptime(from_date, "%Y-%m-%d").date()
+             files = [f for f in files if self._file_date_filter(f, from_date_obj)]
+             self.log.info(f"Processing {len(files)} files from {from_date} onwards.")
+         else:
+             self.log.info(f"Processing all {len(files)} files...")
+
+         # Check if any files remain after filtering
+         if not files:
+             self.log.info("No files to process after date filtering.")
+             return pd.DataFrame(), pd.DataFrame()
+
+         # Sort files by name (assumed to include timestamp)
+         files.sort()
+
+         # Get all timestamps in one batch if needed
+         timestamps = {}
+         if add_timestamp:
+             # Batch describe operation - much more efficient than per-file calls
+             timestamps = wr.s3.describe_objects(path=files)
+
+         # Process files using concurrent.futures
+         start_time = time.time()
+
+         def process_single_file(file_path):
+             """Process a single file and return input/output DataFrames."""
+             try:
+                 log.debug(f"Processing file: {file_path}...")
+                 df = wr.s3.read_json(path=file_path, lines=True)
+                 if not df.empty:
+                     input_df, output_df = process_data_capture(df)
+                     if add_timestamp and file_path in timestamps:
+                         output_df["timestamp"] = timestamps[file_path]["LastModified"]
+                     return input_df, output_df
+                 return pd.DataFrame(), pd.DataFrame()
+             except Exception as e:
+                 self.log.warning(f"Error processing {file_path}: {e}")
+                 return pd.DataFrame(), pd.DataFrame()
+
+         # Use ThreadPoolExecutor for I/O-bound operations
+         from concurrent.futures import ThreadPoolExecutor
+
+         max_workers = min(32, len(files))  # Cap at 32 threads or number of files
+
+         all_input_dfs, all_output_dfs = [], []
+         with ThreadPoolExecutor(max_workers=max_workers) as executor:
+             futures = [executor.submit(process_single_file, file_path) for file_path in files]
+             for future in futures:
+                 input_df, output_df = future.result()
+                 if not input_df.empty:
+                     all_input_dfs.append(input_df)
+                 if not output_df.empty:
+                     all_output_dfs.append(output_df)
+
+         if not all_input_dfs:
+             self.log.warning("No valid data was processed.")
+             return pd.DataFrame(), pd.DataFrame()
+
+         input_df = pd.concat(all_input_dfs, ignore_index=True)
+         output_df = pd.concat(all_output_dfs, ignore_index=True)
+
+         elapsed_time = time.time() - start_time
+         self.log.info(f"Processed {len(files)} files in {elapsed_time:.2f} seconds.")
+         return input_df, output_df
+
+     def _file_date_filter(self, file_path, from_date_obj):
+         """Extract date from S3 path and compare with from_date."""
+         try:
+             # Match YYYY/MM/DD pattern in the path
+             date_match = re.search(r"/(\d{4})/(\d{2})/(\d{2})/", file_path)
+             if date_match:
+                 year, month, day = date_match.groups()
+                 file_date = datetime(int(year), int(month), int(day)).date()
+                 return file_date >= from_date_obj
+             return False  # No date pattern found
+         except ValueError:
+             return False
+
+     def __repr__(self) -> str:
+         """String representation of this DataCaptureCore object
+
+         Returns:
+             str: String representation of this DataCaptureCore object
+         """
+         summary_dict = self.summary()
+         summary_items = [f" {repr(key)}: {repr(value)}" for key, value in summary_dict.items()]
+         summary_str = f"{self.__class__.__name__}: {self.endpoint_name}\n" + ",\n".join(summary_items)
+         return summary_str
+
+
+ # Test function for the class
+ if __name__ == "__main__":
+     """Exercise the MonitorCore class"""
+     from pprint import pprint
+
+     # Set options for actually seeing the dataframe
+     pd.set_option("display.max_columns", None)
+     pd.set_option("display.width", None)
+
+     # Create the Class and test it out
+     endpoint_name = "abalone-regression-rt"
+     my_endpoint = EndpointCore(endpoint_name)
+     if not my_endpoint.exists():
+         print(f"Endpoint {endpoint_name} does not exist.")
+         exit(1)
+     dc = my_endpoint.data_capture()
+
+     # Check the summary of the data capture class
+     pprint(dc.summary())
+
+     # Enable data capture on the endpoint
+     # dc.enable(force_redeploy=True)
+     my_endpoint.enable_data_capture()
+
+     # Test the data capture by running some predictions
+     # pred_df = my_endpoint.auto_inference()
+     # print(pred_df.head())
+
+     # Check that data capture is working
+     input_df, output_df = dc.get_captured_data(from_date="2025-09-01")
+     if input_df.empty and output_df.empty:
+         print("No data capture files found, for a new endpoint it may take a few minutes to start capturing data")
+     else:
+         print("Found data capture files")
+         print("Input")
+         print(input_df.head())
+         print("Output")
+         print(output_df.head())
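
The new DataCaptureCore above is reached through the EndpointCore accessors added in the endpoint_core.py hunks below. A minimal usage sketch, assuming a deployed realtime endpoint named "abalone-regression-rt" (the name used in the test block above); all method names are taken from this diff:

    from workbench.core.artifacts.endpoint_core import EndpointCore

    # Grab the endpoint and its data capture helper
    endpoint = EndpointCore("abalone-regression-rt")
    dc = endpoint.data_capture()

    # Enable capture of requests and responses at 100% sampling (this redeploys the endpoint)
    dc.enable(capture_percentage=100, capture_options=["REQUEST", "RESPONSE"])
    print(dc.summary())

    # Later: pull captured requests/responses back from S3 as DataFrames
    input_df, output_df = dc.get_captured_data(from_date="2025-09-01")
    print(input_df.head())
    print(output_df.head())

As the class code shows, enable() is a no-op for serverless endpoints, and capture files for a freshly enabled endpoint may take a few minutes to appear.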
@@ -8,7 +8,7 @@ import pandas as pd
  import numpy as np
  from io import StringIO
  import awswrangler as wr
- from typing import Union, Optional
+ from typing import Union, Optional, Tuple
  import hashlib

  # Model Performance Scores
@@ -32,11 +32,11 @@ from sagemaker import Predictor
  from workbench.core.artifacts.artifact import Artifact
  from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType
  from workbench.utils.endpoint_metrics import EndpointMetrics
- from workbench.utils.fast_inference import fast_inference
  from workbench.utils.cache import Cache
  from workbench.utils.s3_utils import compute_s3_object_hash
  from workbench.utils.model_utils import uq_metrics
  from workbench.utils.xgboost_model_utils import cross_fold_inference
+ from workbench_bridges.endpoints.fast_inference import fast_inference


  class EndpointCore(Artifact):
@@ -164,11 +164,17 @@ class EndpointCore(Artifact):
          """
          return "Serverless" in self.endpoint_meta["InstanceType"]

-     def add_data_capture(self):
+     def data_capture(self):
+         """Get the MonitorCore class for this endpoint"""
+         from workbench.core.artifacts.data_capture_core import DataCaptureCore
+
+         return DataCaptureCore(self.endpoint_name)
+
+     def enable_data_capture(self):
          """Add data capture to the endpoint"""
-         self.get_monitor().add_data_capture()
+         self.data_capture().enable()

-     def get_monitor(self):
+     def monitor(self):
          """Get the MonitorCore class for this endpoint"""
          from workbench.core.artifacts.monitor_core import MonitorCore

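
This hunk renames the public accessors: add_data_capture() becomes enable_data_capture(), get_monitor() becomes monitor(), and data_capture() is new. A hedged before/after sketch for callers upgrading from 0.8.168 (the endpoint name is illustrative):

    from workbench.core.artifacts.endpoint_core import EndpointCore

    endpoint = EndpointCore("abalone-regression-rt")

    # 0.8.168
    # endpoint.add_data_capture()
    # monitor = endpoint.get_monitor()

    # 0.8.193
    endpoint.enable_data_capture()   # thin wrapper around data_capture().enable()
    dc = endpoint.data_capture()     # DataCaptureCore for this endpoint
    monitor = endpoint.monitor()     # MonitorCore for this endpoint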
@@ -350,7 +356,7 @@ class EndpointCore(Artifact):
              return pd.DataFrame()

          # Grab the evaluation data from the FeatureSet
-         table = fs.view("training").table
+         table = model.training_view().table
          eval_df = fs.query(f'SELECT * FROM "{table}" where training = FALSE')
          capture_name = "auto_inference" if capture else None
          return self.inference(eval_df, capture_name, id_column=fs.id_column)
@@ -414,8 +420,12 @@ class EndpointCore(Artifact):

          # Capture the inference results and metrics
          if capture_name is not None:
+
+             # If we don't have an id_column, we'll pull it from the model's FeatureSet
+             if id_column is None:
+                 fs = FeatureSetCore(model.get_input())
+                 id_column = fs.id_column
              description = capture_name.replace("_", " ").title()
-             features = model.features()
              self._capture_inference_results(
                  capture_name, prediction_df, target_column, model_type, metrics, description, features, id_column
              )
@@ -423,31 +433,81 @@ class EndpointCore(Artifact):
          # For UQ Models we also capture the uncertainty metrics
          if model_type in [ModelType.UQ_REGRESSOR]:
              metrics = uq_metrics(prediction_df, target_column)
-
-             # Now put into the Parameter Store Model Inference Namespace
              self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)

          # Return the prediction DataFrame
          return prediction_df

-     def cross_fold_inference(self, nfolds: int = 5) -> dict:
+     def cross_fold_inference(self, nfolds: int = 5) -> Tuple[dict, pd.DataFrame]:
          """Run cross-fold inference (only works for XGBoost models)

          Args:
              nfolds (int): Number of folds to use for cross-fold (default: 5)

          Returns:
-             dict: Dictionary with the cross-fold inference results
+             Tuple[dict, pd.DataFrame]: Tuple of (cross_fold_metrics, out_of_fold_df)
          """

          # Grab our model
          model = ModelCore(self.model_name)

          # Compute CrossFold Metrics
-         cross_fold_metrics = cross_fold_inference(model, nfolds=nfolds)
+         cross_fold_metrics, out_of_fold_df = cross_fold_inference(model, nfolds=nfolds)
          if cross_fold_metrics:
              self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", cross_fold_metrics)
-         return cross_fold_metrics
+
+         # Capture the results
+         capture_name = "full_cross_fold"
+         description = capture_name.replace("_", " ").title()
+         target_column = model.target()
+         model_type = model.model_type
+
+         # Get the id_column from the model's FeatureSet
+         fs = FeatureSetCore(model.get_input())
+         id_column = fs.id_column
+
+         # Is this a UQ Model? If so, run full inference and merge the results
+         additional_columns = []
+         if model_type == ModelType.UQ_REGRESSOR:
+             self.log.important("UQ Regressor detected, running full inference to get uncertainty estimates...")
+
+             # Get the training view dataframe for inference
+             training_df = model.training_view().pull_dataframe()
+
+             # Run inference on the endpoint to get UQ outputs
+             uq_df = self.inference(training_df)
+
+             # Identify UQ-specific columns (quantiles and prediction_std)
+             uq_columns = [col for col in uq_df.columns if col.startswith("q_") or col == "prediction_std"]
+
+             # Merge UQ columns with out-of-fold predictions
+             if uq_columns:
+                 # Keep id_column and UQ columns, drop 'prediction' to avoid conflict when merging
+                 uq_df = uq_df[[id_column] + uq_columns]
+
+                 # Drop duplicates in uq_df based on id_column
+                 uq_df = uq_df.drop_duplicates(subset=[id_column])
+
+                 # Merge UQ columns into out_of_fold_df
+                 out_of_fold_df = pd.merge(out_of_fold_df, uq_df, on=id_column, how="left")
+                 additional_columns = uq_columns
+                 self.log.info(f"Added UQ columns: {', '.join(additional_columns)}")
+
+             # Also compute UQ metrics
+             metrics = uq_metrics(out_of_fold_df, target_column)
+             self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
+
+         self._capture_inference_results(
+             capture_name,
+             out_of_fold_df,
+             target_column,
+             model_type,
+             pd.DataFrame([cross_fold_metrics["summary_metrics"]]),
+             description,
+             features=additional_columns,
+             id_column=id_column,
+         )
+         return cross_fold_metrics, out_of_fold_df

      def fast_inference(self, eval_df: pd.DataFrame, threads: int = 4) -> pd.DataFrame:
          """Run inference on the Endpoint using the provided DataFrame
@@ -642,6 +702,10 @@ class EndpointCore(Artifact):
      @staticmethod
      def _hash_dataframe(df: pd.DataFrame, hash_length: int = 8):
          # Internal: Compute a data hash for the dataframe
+         if df.empty:
+             return "--hash--"
+
+         # Sort the dataframe by columns to ensure consistent ordering
          df = df.copy()
          df = df.sort_values(by=sorted(df.columns.tolist()))
          row_hashes = pd.util.hash_pandas_object(df, index=False)
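
The added guard and sort make the hash well-defined for empty frames and insensitive to row ordering. A hedged sketch of the whole helper under those assumptions (the final digest step is illustrative; this hunk does not show how the package reduces row_hashes to a short string):

    import hashlib
    import pandas as pd

    def hash_dataframe_sketch(df: pd.DataFrame, hash_length: int = 8) -> str:
        # Empty frames get a sentinel value
        if df.empty:
            return "--hash--"
        # Sort rows by all (sorted) columns so the hash ignores row ordering
        df = df.copy()
        df = df.sort_values(by=sorted(df.columns.tolist()))
        row_hashes = pd.util.hash_pandas_object(df, index=False)
        # Illustrative reduction of the per-row hashes to a short hex digest
        return hashlib.md5(row_hashes.values.tobytes()).hexdigest()[:hash_length]

    # Same content, different row order -> same hash
    a = pd.DataFrame({"x": [1, 2], "y": [3, 4]})
    b = a.iloc[::-1].reset_index(drop=True)
    assert hash_dataframe_sketch(a) == hash_dataframe_sketch(b)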
@@ -696,8 +760,8 @@ class EndpointCore(Artifact):
          wr.s3.to_csv(metrics, f"{inference_capture_path}/inference_metrics.csv", index=False)

          # Grab the target column, prediction column, any _proba columns, and the ID column (if present)
-         prediction_col = "prediction" if "prediction" in pred_results_df.columns else "predictions"
-         output_columns = [target_column, prediction_col]
+         output_columns = [target_column]
+         output_columns += [col for col in pred_results_df.columns if "prediction" in col]

          # Add any _proba columns to the output columns
          output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]
@@ -707,7 +771,7 @@ class EndpointCore(Artifact):

          # Add the ID column
          if id_column and id_column in pred_results_df.columns:
-             output_columns.append(id_column)
+             output_columns.insert(0, id_column)

          # Write the predictions to our S3 Model Inference Folder
          self.log.info(f"Writing predictions to {inference_capture_path}/inference_predictions.csv")
@@ -929,9 +993,9 @@ class EndpointCore(Artifact):
          self.upsert_workbench_meta({"workbench_input": input})

      def delete(self):
-         """ "Delete an existing Endpoint: Underlying Models, Configuration, and Endpoint"""
+         """Delete an existing Endpoint: Underlying Models, Configuration, and Endpoint"""
          if not self.exists():
-             self.log.warning(f"Trying to delete an Model that doesn't exist: {self.name}")
+             self.log.warning(f"Trying to delete an Endpoint that doesn't exist: {self.name}")

          # Remove this endpoint from the list of registered endpoints
          self.log.info(f"Removing {self.name} from the list of registered endpoints...")
@@ -972,12 +1036,23 @@ class EndpointCore(Artifact):
              cls.log.info(f"Deleting Monitoring Schedule {schedule['MonitoringScheduleName']}...")
              cls.sm_client.delete_monitoring_schedule(MonitoringScheduleName=schedule["MonitoringScheduleName"])

-         # Recursively delete all endpoint S3 artifacts (inference, data capture, monitoring, etc)
+         # Recursively delete all endpoint S3 artifacts (inference, etc)
+         # Note: We do not want to delete the data_capture/ files since these
+         # might be used for collection and data drift analysis
          base_endpoint_path = f"{cls.endpoints_s3_path}/{endpoint_name}"
-         s3_objects = wr.s3.list_objects(base_endpoint_path, boto3_session=cls.boto3_session)
-         cls.log.info(f"Deleting S3 Objects at {base_endpoint_path}...")
-         cls.log.info(f"{s3_objects}")
-         wr.s3.delete_objects(s3_objects, boto3_session=cls.boto3_session)
+         all_s3_objects = wr.s3.list_objects(base_endpoint_path, boto3_session=cls.boto3_session)
+
+         # Filter out objects that contain 'data_capture/' in their path
+         s3_objects_to_delete = [obj for obj in all_s3_objects if "/data_capture/" not in obj]
+         cls.log.info(f"Found {len(all_s3_objects)} total objects at {base_endpoint_path}")
+         cls.log.info(f"Filtering out data_capture files, will delete {len(s3_objects_to_delete)} objects...")
+         cls.log.info(f"Objects to delete: {s3_objects_to_delete}")
+
+         if s3_objects_to_delete:
+             wr.s3.delete_objects(s3_objects_to_delete, boto3_session=cls.boto3_session)
+             cls.log.info(f"Successfully deleted {len(s3_objects_to_delete)} objects")
+         else:
+             cls.log.info("No objects to delete (only data_capture files found)")

          # Delete any dataframes that were stored in the Dataframe Cache
          cls.log.info("Deleting Dataframe Cache...")
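
With this change, delete() now preserves everything under the endpoint's data_capture/ prefix so captured requests/responses remain available for drift analysis. If you do want those files gone as well, a hedged sketch using the same awswrangler calls as above (run before or after deleting the endpoint; uses the default AWS session, and the endpoint name is illustrative):

    import awswrangler as wr
    from workbench.core.artifacts.endpoint_core import EndpointCore

    endpoint = EndpointCore("abalone-regression-rt")
    dc = endpoint.data_capture()

    # Explicitly clear the captured data that delete() intentionally leaves behind
    leftover = wr.s3.list_objects(dc.data_capture_path)
    if leftover:
        wr.s3.delete_objects(leftover)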
@@ -1028,7 +1103,7 @@ class EndpointCore(Artifact):
  if __name__ == "__main__":
      """Exercise the Endpoint Class"""
      from workbench.api import FeatureSet
-     from workbench.utils.endpoint_utils import fs_evaluation_data
+     from workbench.utils.endpoint_utils import get_evaluation_data
      import random

      # Grab an EndpointCore object and pull some information from it
@@ -1036,7 +1111,7 @@ if __name__ == "__main__":

      # Test various error conditions (set row 42 length to pd.NA)
      # Note: This test should return ALL rows
-     my_eval_df = fs_evaluation_data(my_endpoint)
+     my_eval_df = get_evaluation_data(my_endpoint)
      my_eval_df.at[42, "length"] = pd.NA
      pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
      print(f"Sent rows: {len(my_eval_df)}")
@@ -1044,6 +1119,9 @@ if __name__ == "__main__":
      assert len(pred_results) == len(my_eval_df), "Predictions should match the number of sent rows"

      # Now we put in an invalid value
+     print("*" * 80)
+     print("NOW TESTING ERROR CONDITIONS...")
+     print("*" * 80)
      my_eval_df.at[42, "length"] = "invalid_value"
      pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
      print(f"Sent rows: {len(my_eval_df)}")
@@ -1104,16 +1182,20 @@ if __name__ == "__main__":
      # Run Inference where we provide the data
      # Note: This dataframe could be from a FeatureSet or any other source
      print("Running Inference...")
-     my_eval_df = fs_evaluation_data(my_endpoint)
+     my_eval_df = get_evaluation_data(my_endpoint)
      pred_results = my_endpoint.inference(my_eval_df)

      # Now set capture=True to save inference results and metrics
-     my_eval_df = fs_evaluation_data(my_endpoint)
+     my_eval_df = get_evaluation_data(my_endpoint)
      pred_results = my_endpoint.inference(my_eval_df, capture_name="holdout_xyz")

      # Run predictions using the fast_inference method
      fast_results = my_endpoint.fast_inference(my_eval_df)

+     # Test the cross_fold_inference method
+     print("Running Cross-Fold Inference...")
+     metrics, all_results = my_endpoint.cross_fold_inference()
+
      # Run Inference and metrics for a Classification Endpoint
      class_endpoint = EndpointCore("wine-classification")
      auto_predictions = class_endpoint.auto_inference()
@@ -1122,6 +1204,10 @@ if __name__ == "__main__":
      target = "wine_class"
      print(class_endpoint.generate_confusion_matrix(target, auto_predictions))

+     # Test the cross_fold_inference method
+     print("Running Cross-Fold Inference...")
+     metrics, all_results = class_endpoint.cross_fold_inference()
+
      # Test the class method delete (commented out for now)
      # from workbench.api import Model
      # model = Model("abalone-regression")