workbench 0.8.168__py3-none-any.whl → 0.8.193__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/proximity.py +143 -102
- workbench/algorithms/graph/light/proximity_graph.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +3 -2
- workbench/api/feature_set.py +4 -4
- workbench/api/model.py +16 -12
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +113 -27
- workbench/core/artifacts/feature_set_core.py +72 -13
- workbench/core/artifacts/model_core.py +71 -49
- workbench/core/artifacts/monitor_core.py +33 -249
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +11 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +11 -6
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +49 -53
- workbench/core/views/view.py +51 -1
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
- workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
- workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
- workbench/model_scripts/pytorch_model/pytorch.template +9 -18
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +7 -2
- workbench/model_scripts/uq_models/mapie.template +492 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +34 -43
- workbench/model_scripts/xgb_model/xgb_model.template +31 -40
- workbench/repl/workbench_shell.py +4 -4
- workbench/scripts/lambda_launcher.py +63 -0
- workbench/scripts/{ml_pipeline_launcher.py → ml_pipeline_batch.py} +49 -51
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +89 -31
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +300 -151
- workbench/web_interface/components/model_plot.py +7 -1
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +7 -2
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/METADATA +24 -2
- {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/RECORD +77 -72
- {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/entry_points.txt +3 -1
- {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/licenses/LICENSE +1 -1
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/pytorch_model/generated_model_script.py +0 -576
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/WHEEL +0 -0
- {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/top_level.txt +0 -0
workbench/utils/model_utils.py
CHANGED
@@ -3,6 +3,7 @@
 import logging
 import pandas as pd
 import numpy as np
+from scipy.stats import spearmanr
 import importlib.resources
 from pathlib import Path
 import os

@@ -92,6 +93,31 @@ def get_custom_script_path(package: str, script_name: str) -> Path:
     return script_path


+def proximity_model_local(model: "Model"):
+    """Create a Proximity Model for this Model
+
+    Args:
+        model (Model): The Model/FeatureSet used to create the proximity model
+
+    Returns:
+        Proximity: The proximity model
+    """
+    from workbench.algorithms.dataframe.proximity import Proximity  # noqa: F401 (avoid circular import)
+    from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)
+
+    # Get Feature and Target Columns from the existing given Model
+    features = model.features()
+    target = model.target()
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Create the Proximity Model from our Training Data
+    df = model.training_view().pull_dataframe()
+    return Proximity(df, id_column, features, target, track_columns=features)
+
+
 def proximity_model(model: "Model", prox_model_name: str, track_columns: list = None) -> "Model":
     """Create a proximity model based on the given model

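Note: the new `proximity_model_local` helper builds a Proximity model in-process from a Model's training view, rather than publishing it as a separate model. A minimal usage sketch, assuming an existing Workbench Model (the model name and the neighbor query are illustrative, not part of this diff):

```python
from workbench.api import Model
from workbench.utils.model_utils import proximity_model_local

model = Model("abalone-regression")   # hypothetical existing Model
prox = proximity_model_local(model)   # Proximity built from the model's training view
# prox can then be queried for nearest neighbors of new rows (Proximity API not shown in this diff)
```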
@@ -140,7 +166,7 @@ def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -
     from workbench.api import Model, ModelType, FeatureSet  # noqa: F401 (avoid circular import)

     # Get the custom script path for the UQ model
-    script_path = get_custom_script_path("uq_models", "
+    script_path = get_custom_script_path("uq_models", "mapie.template")

     # Get Feature and Target Columns from the existing given Model
     features = model.features()

@@ -162,6 +188,20 @@ def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -
     return uq_model


+def safe_extract_tarfile(tar_path: str, extract_path: str) -> None:
+    """
+    Extract a tarball safely, using data filter if available.
+
+    The filter parameter was backported to Python 3.8+, 3.9+, 3.10.13+, 3.11+
+    as a security patch, but may not be present in older patch versions.
+    """
+    with tarfile.open(tar_path, "r:gz") as tar:
+        if hasattr(tarfile, "data_filter"):
+            tar.extractall(path=extract_path, filter="data")
+        else:
+            tar.extractall(path=extract_path)
+
+
 def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
     """
     Download and extract category mappings from a model artifact in S3.

@@ -180,8 +220,7 @@ def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
         wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)

         # Extract tarball
-
-            tar.extractall(path=tmpdir, filter="data")
+        safe_extract_tarfile(local_tar_path, tmpdir)

         # Look for category mappings in base directory only
         mappings_path = os.path.join(tmpdir, "category_mappings.json")
@@ -220,28 +259,41 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     # --- Coverage and Interval Width ---
     if "q_025" in df.columns and "q_975" in df.columns:
         lower_95, upper_95 = df["q_025"], df["q_975"]
+        lower_90, upper_90 = df["q_05"], df["q_95"]
+        lower_80, upper_80 = df["q_10"], df["q_90"]
+        lower_68 = df.get("q_16", df["q_10"])  # fallback to 80% interval
+        upper_68 = df.get("q_84", df["q_90"])  # fallback to 80% interval
         lower_50, upper_50 = df["q_25"], df["q_75"]
     elif "prediction_std" in df.columns:
         lower_95 = df["prediction"] - 1.96 * df["prediction_std"]
         upper_95 = df["prediction"] + 1.96 * df["prediction_std"]
+        lower_90 = df["prediction"] - 1.645 * df["prediction_std"]
+        upper_90 = df["prediction"] + 1.645 * df["prediction_std"]
+        lower_80 = df["prediction"] - 1.282 * df["prediction_std"]
+        upper_80 = df["prediction"] + 1.282 * df["prediction_std"]
+        lower_68 = df["prediction"] - 1.0 * df["prediction_std"]
+        upper_68 = df["prediction"] + 1.0 * df["prediction_std"]
         lower_50 = df["prediction"] - 0.674 * df["prediction_std"]
         upper_50 = df["prediction"] + 0.674 * df["prediction_std"]
     else:
         raise ValueError(
             "Either quantile columns (q_025, q_975, q_25, q_75) or 'prediction_std' column must be present."
         )
+    median_std = df["prediction_std"].median()
     coverage_95 = np.mean((df[target_col] >= lower_95) & (df[target_col] <= upper_95))
-
-
-
+    coverage_90 = np.mean((df[target_col] >= lower_90) & (df[target_col] <= upper_90))
+    coverage_80 = np.mean((df[target_col] >= lower_80) & (df[target_col] <= upper_80))
+    coverage_68 = np.mean((df[target_col] >= lower_68) & (df[target_col] <= upper_68))
+    median_width_95 = np.median(upper_95 - lower_95)
+    median_width_90 = np.median(upper_90 - lower_90)
+    median_width_80 = np.median(upper_80 - lower_80)
+    median_width_50 = np.median(upper_50 - lower_50)
+    median_width_68 = np.median(upper_68 - lower_68)

     # --- CRPS (measures calibration + sharpness) ---
-
-
-
-        mean_crps = np.mean(crps)
-    else:
-        mean_crps = np.nan
+    z = (df[target_col] - df["prediction"]) / df["prediction_std"]
+    crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
+    mean_crps = np.mean(crps)

     # --- Interval Score @ 95% (penalizes miscoverage) ---
     alpha_95 = 0.05
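A note on the constants introduced in this hunk: 1.96, 1.645, 1.282, 1.0, and 0.674 are the standard-normal half-widths of central 95/90/80/68/50% intervals (the 68% case is the usual ±1σ convention). A quick standalone check with SciPy, not package code:

```python
from scipy.stats import norm

# A central interval of mass p spans +/- norm.ppf((1 + p) / 2) standard deviations
for p in (0.95, 0.90, 0.80, 0.68, 0.50):
    print(f"{p:.2f} -> {norm.ppf((1 + p) / 2):.3f}")
# 0.95 -> 1.960, 0.90 -> 1.645, 0.80 -> 1.282, 0.68 -> 0.994 (~1.0), 0.50 -> 0.674
```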
@@ -252,31 +304,43 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     )
     mean_is_95 = np.mean(is_95)

-    # ---
+    # --- Interval to Error Correlation ---
     abs_residuals = np.abs(df[target_col] - df["prediction"])
-
-
+    width_68 = upper_68 - lower_68
+
+    # Spearman correlation for robustness
+    interval_to_error_corr = spearmanr(width_68, abs_residuals)[0]

     # Collect results
     results = {
+        "coverage_68": coverage_68,
+        "coverage_80": coverage_80,
+        "coverage_90": coverage_90,
         "coverage_95": coverage_95,
-        "
-        "
-        "
-        "
-        "
-        "
+        "median_std": median_std,
+        "median_width_50": median_width_50,
+        "median_width_68": median_width_68,
+        "median_width_80": median_width_80,
+        "median_width_90": median_width_90,
+        "median_width_95": median_width_95,
+        "interval_to_error_corr": interval_to_error_corr,
         "n_samples": len(df),
     }

     print("\n=== UQ Metrics ===")
+    print(f"Coverage @ 68%: {coverage_68:.3f} (target: 0.68)")
+    print(f"Coverage @ 80%: {coverage_80:.3f} (target: 0.80)")
+    print(f"Coverage @ 90%: {coverage_90:.3f} (target: 0.90)")
     print(f"Coverage @ 95%: {coverage_95:.3f} (target: 0.95)")
-    print(f"
-    print(f"
-    print(f"
+    print(f"Median Prediction StdDev: {median_std:.3f}")
+    print(f"Median 50% Width: {median_width_50:.3f}")
+    print(f"Median 68% Width: {median_width_68:.3f}")
+    print(f"Median 80% Width: {median_width_80:.3f}")
+    print(f"Median 90% Width: {median_width_90:.3f}")
+    print(f"Median 95% Width: {median_width_95:.3f}")
     print(f"CRPS: {mean_crps:.3f} (lower is better)")
     print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
-    print(f"
+    print(f"Interval/Error Corr: {interval_to_error_corr:.3f} (higher is better, target: >0.5)")
     print(f"Samples: {len(df)}")
     return results

@@ -313,9 +377,3 @@ if __name__ == "__main__":
     df = end.auto_inference(capture=True)
     results = uq_metrics(df, target_col="solubility")
     print(results)
-
-    # Test the uq_metrics function
-    end = Endpoint("aqsol-uq-100")
-    df = end.auto_inference(capture=True)
-    results = uq_metrics(df, target_col="solubility")
-    print(results)
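The CRPS computed above uses the closed-form expression for a Gaussian predictive distribution, CRPS(N(mu, sigma), y) = sigma * [z(2*Phi(z) - 1) + 2*phi(z) - 1/sqrt(pi)] with z = (y - mu)/sigma. A self-contained sketch of the same calculation on toy values (illustrative, not package code):

```python
import numpy as np
from scipy.stats import norm

y = np.array([1.2, 0.4, -0.7])      # observed targets (toy values)
mu = np.array([1.0, 0.0, -1.0])     # predicted means
sigma = np.array([0.5, 0.8, 0.6])   # predicted standard deviations

# Closed-form CRPS for a Gaussian predictive distribution
z = (y - mu) / sigma
crps = sigma * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
print(crps, crps.mean())  # lower is better
```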
workbench/utils/monitor_utils.py
CHANGED
@@ -14,7 +14,7 @@ from workbench.utils.s3_utils import read_content_from_s3
 log = logging.getLogger("workbench")


-def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:
+def pull_data_capture_for_testing(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:
     """
     Read and process captured data from S3.

@@ -26,7 +26,12 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:

     Returns:
         Union[pd.DataFrame, None]: A dataframe of the captured data (or None if no data is found).
+
+    Notes:
+        This method is really only for testing and debugging.
     """
+    log.important("This method is for testing and debugging only.")
+
     # List files in the specified S3 path
     files = wr.s3.list_objects(data_capture_path)
     if not files:
@@ -64,59 +69,53 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:
 def process_data_capture(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Process the captured data DataFrame to extract input and output data.
-
+    Handles cases where input or output might not be captured.
+
     Args:
         df (DataFrame): DataFrame with captured data.
     Returns:
         tuple[DataFrame, DataFrame]: Input and output DataFrames.
     """
+
+    def parse_endpoint_data(data: dict) -> pd.DataFrame:
+        """Parse endpoint data based on encoding type."""
+        encoding = data["encoding"].upper()
+
+        if encoding == "CSV":
+            return pd.read_csv(StringIO(data["data"]))
+        elif encoding == "JSON":
+            json_data = json.loads(data["data"])
+            if isinstance(json_data, dict):
+                return pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
+            else:
+                return pd.DataFrame(json_data)
+        else:
+            return None  # Unknown encoding
+
     input_dfs = []
     output_dfs = []

-
+    # Use itertuples() instead of iterrows() for better performance
+    for row in df.itertuples(index=True):
         try:
-            capture_data = row
-
-            #
-            if "endpointInput"
-
-
-
-
-
-            )
-
-
-            # Process input data
-            input_data = capture_data["endpointInput"]
-            if input_data["encoding"].upper() == "CSV":
-                input_df = pd.read_csv(StringIO(input_data["data"]))
-            elif input_data["encoding"].upper() == "JSON":
-                json_data = json.loads(input_data["data"])
-                if isinstance(json_data, dict):
-                    input_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
-                else:
-                    input_df = pd.DataFrame(json_data)
-
-            # Process output data
-            output_data = capture_data["endpointOutput"]
-            if output_data["encoding"].upper() == "CSV":
-                output_df = pd.read_csv(StringIO(output_data["data"]))
-            elif output_data["encoding"].upper() == "JSON":
-                json_data = json.loads(output_data["data"])
-                if isinstance(json_data, dict):
-                    output_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
-                else:
-                    output_df = pd.DataFrame(json_data)
-
-            # If we get here, both processed successfully
-            input_dfs.append(input_df)
-            output_dfs.append(output_df)
+            capture_data = row.captureData
+
+            # Process input data if present
+            if "endpointInput" in capture_data:
+                input_df = parse_endpoint_data(capture_data["endpointInput"])
+                if input_df is not None:
+                    input_dfs.append(input_df)
+
+            # Process output data if present
+            if "endpointOutput" in capture_data:
+                output_df = parse_endpoint_data(capture_data["endpointOutput"])
+                if output_df is not None:
+                    output_dfs.append(output_df)

         except Exception as e:
-            log.
+            log.debug(f"Row {row.Index}: Failed to process row: {e}")
             continue
+
     # Combine and return results
     return (
         pd.concat(input_dfs, ignore_index=True) if input_dfs else pd.DataFrame(),
@@ -178,23 +177,6 @@ def parse_monitoring_results(results_json: str) -> Dict[str, Any]:
         return {"error": str(e)}


-"""TEMP
-    # If the status is "CompletedWithViolations", we grab the lastest
-    # violation file and add it to the result
-    if status == "CompletedWithViolations":
-        violation_file = f"{self.monitoring_path}/
-            {last_run['CreationTime'].strftime('%Y/%m/%d')}/constraint_violations.json"
-        if wr.s3.does_object_exist(violation_file):
-            violations_json = read_content_from_s3(violation_file)
-            violations = parse_monitoring_results(violations_json)
-            result["violations"] = violations.get("constraint_violations", [])
-            result["violation_count"] = len(result["violations"])
-        else:
-            result["violations"] = []
-            result["violation_count"] = 0
-"""
-
-
 def preprocessing_script(feature_list: list[str]) -> str:
     """
     A preprocessing script for monitoring jobs.

@@ -245,8 +227,8 @@ if __name__ == "__main__":
     from workbench.api.monitor import Monitor

     # Test pulling data capture
-    mon = Monitor("
-    df =
+    mon = Monitor("abalone-regression-rt")
+    df = pull_data_capture_for_testing(mon.data_capture_path)
     print("Data Capture:")
     print(df.head())

@@ -262,4 +244,4 @@ if __name__ == "__main__":
     # Test preprocessing script
     script = preprocessing_script(["feature1", "feature2", "feature3"])
     print("\nPreprocessing Script:")
-    print(script)
+    # print(script)
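The refactored `process_data_capture` walks SageMaker data-capture records, where each record's `captureData` holds `endpointInput`/`endpointOutput` dicts with an `encoding` and a `data` payload. A standalone sketch of parsing one such record (the record content is illustrative, not taken from the package):

```python
import json
from io import StringIO
import pandas as pd

# Illustrative data-capture record (CSV-encoded input, JSON-encoded output)
record = {
    "captureData": {
        "endpointInput": {"encoding": "CSV", "data": "length,diameter\n0.455,0.365\n"},
        "endpointOutput": {"encoding": "JSON", "data": '{"prediction": [9.1]}'},
    }
}

def parse(payload: dict) -> pd.DataFrame:
    # Mirror of the encoding dispatch used in process_data_capture
    if payload["encoding"].upper() == "CSV":
        return pd.read_csv(StringIO(payload["data"]))
    return pd.DataFrame(json.loads(payload["data"]))

print(parse(record["captureData"]["endpointInput"]))
print(parse(record["captureData"]["endpointOutput"]))
```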
workbench/utils/pandas_utils.py
CHANGED
@@ -152,7 +152,7 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li

     # Check for differences in common columns
     for column in common_columns:
-        if pd.api.types.is_string_dtype(df1[column])
+        if pd.api.types.is_string_dtype(df1[column]) and pd.api.types.is_string_dtype(df2[column]):
             # String comparison with NaNs treated as equal
             differences = ~(df1[column].fillna("") == df2[column].fillna(""))
         elif pd.api.types.is_float_dtype(df1[column]) or pd.api.types.is_float_dtype(df2[column]):

@@ -161,8 +161,8 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
                 pd.isna(df1[column]) & pd.isna(df2[column])
             )
         else:
-            # Other types (
-            differences =
+            # Other types (int, Int64, etc.) - compare with NaNs treated as equal
+            differences = (df1[column] != df2[column]) & ~(pd.isna(df1[column]) & pd.isna(df2[column]))

         # If differences exist, display them
         if differences.any():
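The new fallback branch treats two NA values as equal for non-string, non-float columns (e.g. nullable Int64). A quick standalone illustration of that expression on toy data, not package code:

```python
import pandas as pd

a = pd.Series([1, 2, None, 4], dtype="Int64")
b = pd.Series([1, 3, None, 5], dtype="Int64")

# Rows differ only where the values differ and it is not the case that both are NA
differences = (a != b) & ~(a.isna() & b.isna())
print(differences.tolist())  # [False, True, False, True]
```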
workbench/utils/shap_utils.py
CHANGED
@@ -212,6 +212,14 @@ def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
         log.error("No XGBoost model found in the artifact.")
         return None, None, None, None

+    # Get the booster (SHAP requires the booster, not the sklearn wrapper)
+    if hasattr(xgb_model, "get_booster"):
+        # Full sklearn model - extract the booster
+        booster = xgb_model.get_booster()
+    else:
+        # Already a booster
+        booster = xgb_model
+
     # Load category mappings if available
     category_mappings = load_category_mappings_from_s3(model_artifact_uri)

@@ -229,8 +237,8 @@
     # Create a DMatrix with categorical support
     dmatrix = xgb.DMatrix(X, enable_categorical=True)

-    # Use XGBoost's built-in SHAP calculation
-    shap_values =
+    # Use XGBoost's built-in SHAP calculation (booster method, not sklearn)
+    shap_values = booster.predict(dmatrix, pred_contribs=True, strict_shape=True)
     features_with_bias = features + ["bias"]

     # Now we need to subset the columns based on top 10 SHAP values
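For reference, `Booster.predict(..., pred_contribs=True)` returns per-feature SHAP contributions plus a trailing bias column, which is why the booster (not the sklearn wrapper) is extracted above. A minimal self-contained sketch on toy data, not package code:

```python
import numpy as np
import xgboost as xgb

# Tiny toy regression dataset
X = np.random.rand(20, 3)
y = X[:, 0] * 2.0 + np.random.rand(20) * 0.1

booster = xgb.train({"max_depth": 2}, xgb.DMatrix(X, label=y), num_boost_round=5)

# One contribution column per feature plus a final bias column: shape (20, 4)
contribs = booster.predict(xgb.DMatrix(X), pred_contribs=True)
print(contribs.shape)

# Row sums of the contributions reproduce the raw predictions
print(np.allclose(contribs.sum(axis=1), booster.predict(xgb.DMatrix(X))))
```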
workbench/utils/workbench_sqs.py
CHANGED
@@ -12,7 +12,7 @@ class WorkbenchSQS:
         self.log = logging.getLogger("workbench")
         self.queue_url = queue_url

-        # Grab a Workbench Session
+        # Grab a Workbench Session
         self.boto3_session = AWSAccountClamp().boto3_session
         print(self.boto3_session)
