workbench 0.8.168__py3-none-any.whl → 0.8.192__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/proximity.py +143 -102
- workbench/algorithms/graph/light/proximity_graph.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +3 -2
- workbench/api/feature_set.py +4 -4
- workbench/api/model.py +16 -12
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +113 -27
- workbench/core/artifacts/feature_set_core.py +72 -13
- workbench/core/artifacts/model_core.py +50 -15
- workbench/core/artifacts/monitor_core.py +33 -249
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +11 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +9 -4
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +49 -53
- workbench/core/views/view.py +51 -1
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
- workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
- workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
- workbench/model_scripts/pytorch_model/pytorch.template +9 -18
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +7 -2
- workbench/model_scripts/uq_models/mapie.template +492 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/xgb_model.template +31 -40
- workbench/repl/workbench_shell.py +4 -4
- workbench/scripts/lambda_launcher.py +63 -0
- workbench/scripts/{ml_pipeline_launcher.py → ml_pipeline_batch.py} +49 -51
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +76 -30
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +283 -145
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/METADATA +2 -1
- {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/RECORD +74 -70
- {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/entry_points.txt +3 -1
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/pytorch_model/generated_model_script.py +0 -576
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
- workbench/model_scripts/xgb_model/generated_model_script.py +0 -477
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/WHEEL +0 -0
- {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/top_level.txt +0 -0
workbench/core/artifacts/data_capture_core.py (new file, +355 lines)

@@ -0,0 +1,355 @@
+"""DataCaptureCore class for managing SageMaker endpoint data capture"""
+
+import logging
+import re
+import time
+from datetime import datetime
+from typing import Tuple
+import pandas as pd
+from sagemaker import Predictor
+from sagemaker.model_monitor import DataCaptureConfig
+import awswrangler as wr
+
+# Workbench Imports
+from workbench.core.artifacts.endpoint_core import EndpointCore
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+from workbench.utils.monitor_utils import process_data_capture
+
+# Setup logging
+log = logging.getLogger("workbench")
+
+
+class DataCaptureCore:
+    """Manages data capture configuration and retrieval for SageMaker endpoints"""
+
+    def __init__(self, endpoint_name: str):
+        """DataCaptureCore Class
+
+        Args:
+            endpoint_name (str): Name of the endpoint to manage data capture for
+        """
+        self.log = logging.getLogger("workbench")
+        self.endpoint_name = endpoint_name
+        self.endpoint = EndpointCore(self.endpoint_name)
+
+        # Initialize Class Attributes
+        self.sagemaker_session = self.endpoint.sm_session
+        self.sagemaker_client = self.endpoint.sm_client
+        self.data_capture_path = self.endpoint.endpoint_data_capture_path
+        self.workbench_role_arn = AWSAccountClamp().aws_session.get_workbench_execution_role_arn()
+
+    def summary(self) -> dict:
+        """Return the summary of data capture configuration
+
+        Returns:
+            dict: Summary of data capture status
+        """
+        if self.endpoint.is_serverless():
+            return {"endpoint_type": "serverless", "data_capture": "not supported"}
+        else:
+            return {
+                "endpoint_type": "realtime",
+                "data_capture_enabled": self.is_enabled(),
+                "capture_percentage": self.capture_percentage(),
+                "capture_modes": self.capture_modes() if self.is_enabled() else [],
+                "data_capture_path": self.data_capture_path if self.is_enabled() else None,
+            }
+
+    def enable(self, capture_percentage=100, capture_options=None, force_redeploy=False):
+        """
+        Enable data capture for the SageMaker endpoint.
+
+        Args:
+            capture_percentage (int): Percentage of data to capture. Defaults to 100.
+            capture_options (list): List of what to capture - ["REQUEST"], ["RESPONSE"], or ["REQUEST", "RESPONSE"].
+                Defaults to ["REQUEST", "RESPONSE"] to capture both.
+            force_redeploy (bool): If True, force redeployment even if data capture is already enabled.
+        """
+        # Early returns for cases where we can't/don't need to add data capture
+        if self.endpoint.is_serverless():
+            self.log.warning("Data capture is not supported for serverless endpoints.")
+            return
+
+        # Default to capturing both if not specified
+        if capture_options is None:
+            capture_options = ["REQUEST", "RESPONSE"]
+
+        # Validate capture_options
+        valid_options = {"REQUEST", "RESPONSE"}
+        if not all(opt in valid_options for opt in capture_options):
+            self.log.error("Invalid capture_options. Must be a list containing 'REQUEST' and/or 'RESPONSE'")
+            return
+
+        if self.is_enabled() and not force_redeploy:
+            self.log.important(f"Data capture already configured for {self.endpoint_name}.")
+            return
+
+        # Get the current endpoint configuration name for later deletion
+        current_endpoint_config_name = self.endpoint.endpoint_config_name()
+
+        # Log the data capture operation
+        self.log.important(f"Enabling Data Capture for {self.endpoint_name} --> {self.data_capture_path}")
+        self.log.important(f"Capturing: {', '.join(capture_options)} at {capture_percentage}% sampling")
+        self.log.important("This will redeploy the endpoint...")
+
+        # Create and apply the data capture configuration
+        data_capture_config = DataCaptureConfig(
+            enable_capture=True,
+            sampling_percentage=capture_percentage,
+            destination_s3_uri=self.data_capture_path,
+            capture_options=capture_options,
+        )
+
+        # Update endpoint with the new capture configuration
+        Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
+            data_capture_config=data_capture_config
+        )
+
+        # Clean up old endpoint configuration
+        try:
+            self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
+            self.log.info(f"Deleted old endpoint configuration: {current_endpoint_config_name}")
+        except Exception as e:
+            self.log.warning(f"Could not delete old endpoint configuration {current_endpoint_config_name}: {e}")
+
+    def disable(self):
+        """
+        Disable data capture for the SageMaker endpoint.
+        """
+        # Early return if data capture isn't configured
+        if not self.is_enabled():
+            self.log.important(f"Data capture is not currently enabled for {self.endpoint_name}.")
+            return
+
+        # Get the current endpoint configuration name for later deletion
+        current_endpoint_config_name = self.endpoint.endpoint_config_name()
+
+        # Log the operation
+        self.log.important(f"Disabling Data Capture for {self.endpoint_name}")
+        self.log.important("This normally redeploys the endpoint...")
+
+        # Create a configuration with capture disabled
+        data_capture_config = DataCaptureConfig(enable_capture=False, destination_s3_uri=self.data_capture_path)
+
+        # Update endpoint with the new configuration
+        Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
+            data_capture_config=data_capture_config
+        )
+
+        # Clean up old endpoint configuration
+        self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
+
+    def is_enabled(self) -> bool:
+        """
+        Check if data capture is enabled on the endpoint.
+
+        Returns:
+            bool: True if data capture is enabled, False otherwise.
+        """
+        try:
+            endpoint_config_name = self.endpoint.endpoint_config_name()
+            endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
+            data_capture_config = endpoint_config.get("DataCaptureConfig", {})
+
+            # Check if data capture is enabled
+            is_enabled = data_capture_config.get("EnableCapture", False)
+            return is_enabled
+        except Exception as e:
+            self.log.error(f"Error checking data capture configuration: {e}")
+            return False
+
+    def capture_percentage(self) -> int:
+        """
+        Get the data capture percentage from the endpoint configuration.
+
+        Returns:
+            int: Data capture percentage if enabled, None otherwise.
+        """
+        try:
+            endpoint_config_name = self.endpoint.endpoint_config_name()
+            endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
+            data_capture_config = endpoint_config.get("DataCaptureConfig", {})
+
+            # Check if data capture is enabled and return the percentage
+            if data_capture_config.get("EnableCapture", False):
+                return data_capture_config.get("InitialSamplingPercentage", 0)
+            else:
+                return None
+        except Exception as e:
+            self.log.error(f"Error checking data capture percentage: {e}")
+            return None
+
+    def get_config(self) -> dict:
+        """
+        Returns the complete data capture configuration from the endpoint config.
+
+        Returns:
+            dict: Complete DataCaptureConfig from AWS, or None if not configured
+        """
+        config_name = self.endpoint.endpoint_config_name()
+        response = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=config_name)
+        data_capture_config = response.get("DataCaptureConfig")
+        if not data_capture_config:
+            self.log.error(f"No data capture configuration found for endpoint config {config_name}")
+            return None
+        return data_capture_config
+
+    def capture_modes(self) -> list:
+        """Get the current capture modes (REQUEST/RESPONSE)"""
+        if not self.is_enabled():
+            return []
+
+        config = self.get_config()
+        if not config:
+            return []
+
+        capture_options = config.get("CaptureOptions", [])
+        modes = [opt.get("CaptureMode") for opt in capture_options]
+        return ["REQUEST" if m == "Input" else "RESPONSE" for m in modes if m]
+
+    def get_captured_data(self, from_date: str = None, add_timestamp: bool = True) -> Tuple[pd.DataFrame, pd.DataFrame]:
+        """
+        Read and process captured data from S3.
+
+        Args:
+            from_date (str, optional): Only process files from this date onwards (YYYY-MM-DD format).
+                Defaults to None to process all files.
+            add_timestamp (bool, optional): Whether to add a timestamp column to the DataFrame.
+
+        Returns:
+            Tuple[pd.DataFrame, pd.DataFrame]: Processed input and output DataFrames.
+        """
+        files = wr.s3.list_objects(self.data_capture_path)
+        if not files:
+            self.log.warning(f"No data capture files found in {self.data_capture_path}.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        # Filter by date if specified
+        if from_date:
+            from_date_obj = datetime.strptime(from_date, "%Y-%m-%d").date()
+            files = [f for f in files if self._file_date_filter(f, from_date_obj)]
+            self.log.info(f"Processing {len(files)} files from {from_date} onwards.")
+        else:
+            self.log.info(f"Processing all {len(files)} files...")
+
+        # Check if any files remain after filtering
+        if not files:
+            self.log.info("No files to process after date filtering.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        # Sort files by name (assumed to include timestamp)
+        files.sort()
+
+        # Get all timestamps in one batch if needed
+        timestamps = {}
+        if add_timestamp:
+            # Batch describe operation - much more efficient than per-file calls
+            timestamps = wr.s3.describe_objects(path=files)
+
+        # Process files using concurrent.futures
+        start_time = time.time()
+
+        def process_single_file(file_path):
+            """Process a single file and return input/output DataFrames."""
+            try:
+                log.debug(f"Processing file: {file_path}...")
+                df = wr.s3.read_json(path=file_path, lines=True)
+                if not df.empty:
+                    input_df, output_df = process_data_capture(df)
+                    if add_timestamp and file_path in timestamps:
+                        output_df["timestamp"] = timestamps[file_path]["LastModified"]
+                    return input_df, output_df
+                return pd.DataFrame(), pd.DataFrame()
+            except Exception as e:
+                self.log.warning(f"Error processing {file_path}: {e}")
+                return pd.DataFrame(), pd.DataFrame()
+
+        # Use ThreadPoolExecutor for I/O-bound operations
+        from concurrent.futures import ThreadPoolExecutor
+
+        max_workers = min(32, len(files))  # Cap at 32 threads or number of files
+
+        all_input_dfs, all_output_dfs = [], []
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = [executor.submit(process_single_file, file_path) for file_path in files]
+            for future in futures:
+                input_df, output_df = future.result()
+                if not input_df.empty:
+                    all_input_dfs.append(input_df)
+                if not output_df.empty:
+                    all_output_dfs.append(output_df)
+
+        if not all_input_dfs:
+            self.log.warning("No valid data was processed.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        input_df = pd.concat(all_input_dfs, ignore_index=True)
+        output_df = pd.concat(all_output_dfs, ignore_index=True)
+
+        elapsed_time = time.time() - start_time
+        self.log.info(f"Processed {len(files)} files in {elapsed_time:.2f} seconds.")
+        return input_df, output_df
+
+    def _file_date_filter(self, file_path, from_date_obj):
+        """Extract date from S3 path and compare with from_date."""
+        try:
+            # Match YYYY/MM/DD pattern in the path
+            date_match = re.search(r"/(\d{4})/(\d{2})/(\d{2})/", file_path)
+            if date_match:
+                year, month, day = date_match.groups()
+                file_date = datetime(int(year), int(month), int(day)).date()
+                return file_date >= from_date_obj
+            return False  # No date pattern found
+        except ValueError:
+            return False
+
+    def __repr__(self) -> str:
+        """String representation of this DataCaptureCore object
+
+        Returns:
+            str: String representation of this DataCaptureCore object
+        """
+        summary_dict = self.summary()
+        summary_items = [f" {repr(key)}: {repr(value)}" for key, value in summary_dict.items()]
+        summary_str = f"{self.__class__.__name__}: {self.endpoint_name}\n" + ",\n".join(summary_items)
+        return summary_str
+
+
+# Test function for the class
+if __name__ == "__main__":
+    """Exercise the MonitorCore class"""
+    from pprint import pprint
+
+    # Set options for actually seeing the dataframe
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", None)
+
+    # Create the Class and test it out
+    endpoint_name = "abalone-regression-rt"
+    my_endpoint = EndpointCore(endpoint_name)
+    if not my_endpoint.exists():
+        print(f"Endpoint {endpoint_name} does not exist.")
+        exit(1)
+    dc = my_endpoint.data_capture()
+
+    # Check the summary of the data capture class
+    pprint(dc.summary())
+
+    # Enable data capture on the endpoint
+    # dc.enable(force_redeploy=True)
+    my_endpoint.enable_data_capture()
+
+    # Test the data capture by running some predictions
+    # pred_df = my_endpoint.auto_inference()
+    # print(pred_df.head())
+
+    # Check that data capture is working
+    input_df, output_df = dc.get_captured_data(from_date="2025-09-01")
+    if input_df.empty and output_df.empty:
+        print("No data capture files found, for a new endpoint it may take a few minutes to start capturing data")
+    else:
+        print("Found data capture files")
+        print("Input")
+        print(input_df.head())
+        print("Output")
+        print(output_df.head())
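For orientation, a minimal usage sketch of the data-capture API added above; the endpoint name is taken from the module's own __main__ block, and the calls shown (data_capture(), summary(), enable(), get_captured_data()) are the ones introduced in this release:

# Sketch: using the new DataCaptureCore via an EndpointCore handle
from workbench.core.artifacts.endpoint_core import EndpointCore

endpoint = EndpointCore("abalone-regression-rt")  # real-time endpoint name from the __main__ block above

# New data_capture() accessor returns a DataCaptureCore for this endpoint
dc = endpoint.data_capture()
print(dc.summary())  # endpoint_type, data_capture_enabled, capture_percentage, capture_modes, path

# Enable capture (no-op on serverless endpoints; redeploys a real-time endpoint)
dc.enable(capture_percentage=100, capture_options=["REQUEST", "RESPONSE"])

# Later, pull captured requests/responses back as DataFrames
input_df, output_df = dc.get_captured_data(from_date="2025-09-01")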
workbench/core/artifacts/endpoint_core.py

@@ -8,7 +8,7 @@ import pandas as pd
 import numpy as np
 from io import StringIO
 import awswrangler as wr
-from typing import Union, Optional
+from typing import Union, Optional, Tuple
 import hashlib

 # Model Performance Scores
@@ -32,11 +32,11 @@ from sagemaker import Predictor
 from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType
 from workbench.utils.endpoint_metrics import EndpointMetrics
-from workbench.utils.fast_inference import fast_inference
 from workbench.utils.cache import Cache
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.model_utils import uq_metrics
 from workbench.utils.xgboost_model_utils import cross_fold_inference
+from workbench_bridges.endpoints.fast_inference import fast_inference


 class EndpointCore(Artifact):
@@ -164,11 +164,17 @@ class EndpointCore(Artifact):
         """
         return "Serverless" in self.endpoint_meta["InstanceType"]

-    def
+    def data_capture(self):
+        """Get the MonitorCore class for this endpoint"""
+        from workbench.core.artifacts.data_capture_core import DataCaptureCore
+
+        return DataCaptureCore(self.endpoint_name)
+
+    def enable_data_capture(self):
         """Add data capture to the endpoint"""
-        self.
+        self.data_capture().enable()

-    def
+    def monitor(self):
         """Get the MonitorCore class for this endpoint"""
         from workbench.core.artifacts.monitor_core import MonitorCore

@@ -350,7 +356,7 @@ class EndpointCore(Artifact):
             return pd.DataFrame()

         # Grab the evaluation data from the FeatureSet
-        table =
+        table = model.training_view().table
         eval_df = fs.query(f'SELECT * FROM "{table}" where training = FALSE')
         capture_name = "auto_inference" if capture else None
         return self.inference(eval_df, capture_name, id_column=fs.id_column)
@@ -414,8 +420,12 @@ class EndpointCore(Artifact):

         # Capture the inference results and metrics
         if capture_name is not None:
+
+            # If we don't have an id_column, we'll pull it from the model's FeatureSet
+            if id_column is None:
+                fs = FeatureSetCore(model.get_input())
+                id_column = fs.id_column
             description = capture_name.replace("_", " ").title()
-            features = model.features()
             self._capture_inference_results(
                 capture_name, prediction_df, target_column, model_type, metrics, description, features, id_column
             )
@@ -423,31 +433,81 @@ class EndpointCore(Artifact):
             # For UQ Models we also capture the uncertainty metrics
             if model_type in [ModelType.UQ_REGRESSOR]:
                 metrics = uq_metrics(prediction_df, target_column)
-
-                # Now put into the Parameter Store Model Inference Namespace
                 self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)

         # Return the prediction DataFrame
         return prediction_df

-    def cross_fold_inference(self, nfolds: int = 5) -> dict:
+    def cross_fold_inference(self, nfolds: int = 5) -> Tuple[dict, pd.DataFrame]:
         """Run cross-fold inference (only works for XGBoost models)

         Args:
             nfolds (int): Number of folds to use for cross-fold (default: 5)

         Returns:
-            dict:
+            Tuple[dict, pd.DataFrame]: Tuple of (cross_fold_metrics, out_of_fold_df)
         """

         # Grab our model
         model = ModelCore(self.model_name)

         # Compute CrossFold Metrics
-        cross_fold_metrics = cross_fold_inference(model, nfolds=nfolds)
+        cross_fold_metrics, out_of_fold_df = cross_fold_inference(model, nfolds=nfolds)
         if cross_fold_metrics:
             self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", cross_fold_metrics)
-
+
+            # Capture the results
+            capture_name = "full_cross_fold"
+            description = capture_name.replace("_", " ").title()
+            target_column = model.target()
+            model_type = model.model_type
+
+            # Get the id_column from the model's FeatureSet
+            fs = FeatureSetCore(model.get_input())
+            id_column = fs.id_column
+
+            # Is this a UQ Model? If so, run full inference and merge the results
+            additional_columns = []
+            if model_type == ModelType.UQ_REGRESSOR:
+                self.log.important("UQ Regressor detected, running full inference to get uncertainty estimates...")
+
+                # Get the training view dataframe for inference
+                training_df = model.training_view().pull_dataframe()
+
+                # Run inference on the endpoint to get UQ outputs
+                uq_df = self.inference(training_df)
+
+                # Identify UQ-specific columns (quantiles and prediction_std)
+                uq_columns = [col for col in uq_df.columns if col.startswith("q_") or col == "prediction_std"]
+
+                # Merge UQ columns with out-of-fold predictions
+                if uq_columns:
+                    # Keep id_column and UQ columns, drop 'prediction' to avoid conflict when merging
+                    uq_df = uq_df[[id_column] + uq_columns]
+
+                    # Drop duplicates in uq_df based on id_column
+                    uq_df = uq_df.drop_duplicates(subset=[id_column])
+
+                    # Merge UQ columns into out_of_fold_df
+                    out_of_fold_df = pd.merge(out_of_fold_df, uq_df, on=id_column, how="left")
+                    additional_columns = uq_columns
+                    self.log.info(f"Added UQ columns: {', '.join(additional_columns)}")
+
+                # Also compute UQ metrics
+                metrics = uq_metrics(out_of_fold_df, target_column)
+                self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
+
+            self._capture_inference_results(
+                capture_name,
+                out_of_fold_df,
+                target_column,
+                model_type,
+                pd.DataFrame([cross_fold_metrics["summary_metrics"]]),
+                description,
+                features=additional_columns,
+                id_column=id_column,
+            )
+        return cross_fold_metrics, out_of_fold_df

     def fast_inference(self, eval_df: pd.DataFrame, threads: int = 4) -> pd.DataFrame:
         """Run inference on the Endpoint using the provided DataFrame
@@ -642,6 +702,10 @@ class EndpointCore(Artifact):
     @staticmethod
     def _hash_dataframe(df: pd.DataFrame, hash_length: int = 8):
         # Internal: Compute a data hash for the dataframe
+        if df.empty:
+            return "--hash--"
+
+        # Sort the dataframe by columns to ensure consistent ordering
         df = df.copy()
         df = df.sort_values(by=sorted(df.columns.tolist()))
         row_hashes = pd.util.hash_pandas_object(df, index=False)
@@ -696,8 +760,8 @@ class EndpointCore(Artifact):
         wr.s3.to_csv(metrics, f"{inference_capture_path}/inference_metrics.csv", index=False)

         # Grab the target column, prediction column, any _proba columns, and the ID column (if present)
-
-        output_columns
+        output_columns = [target_column]
+        output_columns += [col for col in pred_results_df.columns if "prediction" in col]

         # Add any _proba columns to the output columns
         output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]
@@ -707,7 +771,7 @@ class EndpointCore(Artifact):

         # Add the ID column
         if id_column and id_column in pred_results_df.columns:
-            output_columns.
+            output_columns.insert(0, id_column)

         # Write the predictions to our S3 Model Inference Folder
         self.log.info(f"Writing predictions to {inference_capture_path}/inference_predictions.csv")
@@ -929,9 +993,9 @@ class EndpointCore(Artifact):
         self.upsert_workbench_meta({"workbench_input": input})

     def delete(self):
-        """
+        """Delete an existing Endpoint: Underlying Models, Configuration, and Endpoint"""
         if not self.exists():
-            self.log.warning(f"Trying to delete an
+            self.log.warning(f"Trying to delete an Endpoint that doesn't exist: {self.name}")

         # Remove this endpoint from the list of registered endpoints
         self.log.info(f"Removing {self.name} from the list of registered endpoints...")
@@ -972,12 +1036,23 @@
             cls.log.info(f"Deleting Monitoring Schedule {schedule['MonitoringScheduleName']}...")
             cls.sm_client.delete_monitoring_schedule(MonitoringScheduleName=schedule["MonitoringScheduleName"])

-        # Recursively delete all endpoint S3 artifacts (inference,
+        # Recursively delete all endpoint S3 artifacts (inference, etc)
+        # Note: We do not want to delete the data_capture/ files since these
+        # might be used for collection and data drift analysis
         base_endpoint_path = f"{cls.endpoints_s3_path}/{endpoint_name}"
-
-
-
-
+        all_s3_objects = wr.s3.list_objects(base_endpoint_path, boto3_session=cls.boto3_session)
+
+        # Filter out objects that contain 'data_capture/' in their path
+        s3_objects_to_delete = [obj for obj in all_s3_objects if "/data_capture/" not in obj]
+        cls.log.info(f"Found {len(all_s3_objects)} total objects at {base_endpoint_path}")
+        cls.log.info(f"Filtering out data_capture files, will delete {len(s3_objects_to_delete)} objects...")
+        cls.log.info(f"Objects to delete: {s3_objects_to_delete}")
+
+        if s3_objects_to_delete:
+            wr.s3.delete_objects(s3_objects_to_delete, boto3_session=cls.boto3_session)
+            cls.log.info(f"Successfully deleted {len(s3_objects_to_delete)} objects")
+        else:
+            cls.log.info("No objects to delete (only data_capture files found)")

         # Delete any dataframes that were stored in the Dataframe Cache
         cls.log.info("Deleting Dataframe Cache...")
@@ -1028,7 +1103,7 @@
 if __name__ == "__main__":
     """Exercise the Endpoint Class"""
     from workbench.api import FeatureSet
-    from workbench.utils.endpoint_utils import
+    from workbench.utils.endpoint_utils import get_evaluation_data
     import random

     # Grab an EndpointCore object and pull some information from it
@@ -1036,7 +1111,7 @@ if __name__ == "__main__":

     # Test various error conditions (set row 42 length to pd.NA)
     # Note: This test should return ALL rows
-    my_eval_df =
+    my_eval_df = get_evaluation_data(my_endpoint)
     my_eval_df.at[42, "length"] = pd.NA
     pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
     print(f"Sent rows: {len(my_eval_df)}")
@@ -1044,6 +1119,9 @@ if __name__ == "__main__":
     assert len(pred_results) == len(my_eval_df), "Predictions should match the number of sent rows"

     # Now we put in an invalid value
+    print("*" * 80)
+    print("NOW TESTING ERROR CONDITIONS...")
+    print("*" * 80)
     my_eval_df.at[42, "length"] = "invalid_value"
     pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
     print(f"Sent rows: {len(my_eval_df)}")
@@ -1104,16 +1182,20 @@ if __name__ == "__main__":
     # Run Inference where we provide the data
     # Note: This dataframe could be from a FeatureSet or any other source
     print("Running Inference...")
-    my_eval_df =
+    my_eval_df = get_evaluation_data(my_endpoint)
     pred_results = my_endpoint.inference(my_eval_df)

     # Now set capture=True to save inference results and metrics
-    my_eval_df =
+    my_eval_df = get_evaluation_data(my_endpoint)
     pred_results = my_endpoint.inference(my_eval_df, capture_name="holdout_xyz")

     # Run predictions using the fast_inference method
     fast_results = my_endpoint.fast_inference(my_eval_df)

+    # Test the cross_fold_inference method
+    print("Running Cross-Fold Inference...")
+    metrics, all_results = my_endpoint.cross_fold_inference()
+
     # Run Inference and metrics for a Classification Endpoint
     class_endpoint = EndpointCore("wine-classification")
     auto_predictions = class_endpoint.auto_inference()
@@ -1122,6 +1204,10 @@ if __name__ == "__main__":
     target = "wine_class"
     print(class_endpoint.generate_confusion_matrix(target, auto_predictions))

+    # Test the cross_fold_inference method
+    print("Running Cross-Fold Inference...")
+    metrics, all_results = class_endpoint.cross_fold_inference()
+
     # Test the class method delete (commented out for now)
     # from workbench.api import Model
     # model = Model("abalone-regression")