workbench 0.8.168__py3-none-any.whl → 0.8.193__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. workbench/algorithms/dataframe/proximity.py +143 -102
  2. workbench/algorithms/graph/light/proximity_graph.py +2 -1
  3. workbench/api/compound.py +1 -1
  4. workbench/api/endpoint.py +3 -2
  5. workbench/api/feature_set.py +4 -4
  6. workbench/api/model.py +16 -12
  7. workbench/api/monitor.py +1 -16
  8. workbench/core/artifacts/artifact.py +11 -3
  9. workbench/core/artifacts/data_capture_core.py +355 -0
  10. workbench/core/artifacts/endpoint_core.py +113 -27
  11. workbench/core/artifacts/feature_set_core.py +72 -13
  12. workbench/core/artifacts/model_core.py +71 -49
  13. workbench/core/artifacts/monitor_core.py +33 -249
  14. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  15. workbench/core/cloud_platform/aws/aws_meta.py +11 -4
  16. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  17. workbench/core/transforms/features_to_model/features_to_model.py +11 -6
  18. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  19. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  20. workbench/core/views/training_view.py +49 -53
  21. workbench/core/views/view.py +51 -1
  22. workbench/core/views/view_utils.py +4 -4
  23. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  24. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  25. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  26. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  27. workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
  28. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  29. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
  30. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  31. workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
  32. workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
  33. workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
  34. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  35. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
  36. workbench/model_scripts/pytorch_model/pytorch.template +9 -18
  37. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  38. workbench/model_scripts/script_generation.py +7 -2
  39. workbench/model_scripts/uq_models/mapie.template +492 -0
  40. workbench/model_scripts/uq_models/requirements.txt +1 -0
  41. workbench/model_scripts/xgb_model/generated_model_script.py +34 -43
  42. workbench/model_scripts/xgb_model/xgb_model.template +31 -40
  43. workbench/repl/workbench_shell.py +4 -4
  44. workbench/scripts/lambda_launcher.py +63 -0
  45. workbench/scripts/{ml_pipeline_launcher.py → ml_pipeline_batch.py} +49 -51
  46. workbench/scripts/ml_pipeline_sqs.py +186 -0
  47. workbench/utils/chem_utils/__init__.py +0 -0
  48. workbench/utils/chem_utils/fingerprints.py +134 -0
  49. workbench/utils/chem_utils/misc.py +194 -0
  50. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  51. workbench/utils/chem_utils/mol_standardize.py +450 -0
  52. workbench/utils/chem_utils/mol_tagging.py +348 -0
  53. workbench/utils/chem_utils/projections.py +209 -0
  54. workbench/utils/chem_utils/salts.py +256 -0
  55. workbench/utils/chem_utils/sdf.py +292 -0
  56. workbench/utils/chem_utils/toxicity.py +250 -0
  57. workbench/utils/chem_utils/vis.py +253 -0
  58. workbench/utils/config_manager.py +2 -6
  59. workbench/utils/endpoint_utils.py +5 -7
  60. workbench/utils/license_manager.py +2 -6
  61. workbench/utils/model_utils.py +89 -31
  62. workbench/utils/monitor_utils.py +44 -62
  63. workbench/utils/pandas_utils.py +3 -3
  64. workbench/utils/shap_utils.py +10 -2
  65. workbench/utils/workbench_sqs.py +1 -1
  66. workbench/utils/xgboost_model_utils.py +300 -151
  67. workbench/web_interface/components/model_plot.py +7 -1
  68. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  69. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  70. workbench/web_interface/components/plugins/model_details.py +7 -2
  71. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  72. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/METADATA +24 -2
  73. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/RECORD +77 -72
  74. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/entry_points.txt +3 -1
  75. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/licenses/LICENSE +1 -1
  76. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  77. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  78. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  79. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  80. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  81. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  82. workbench/model_scripts/pytorch_model/generated_model_script.py +0 -576
  83. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  84. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  85. workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
  86. workbench/utils/chem_utils.py +0 -1556
  87. workbench/utils/fast_inference.py +0 -167
  88. workbench/utils/resource_utils.py +0 -39
  89. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/WHEEL +0 -0
  90. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@
  import logging
  import pandas as pd
  import numpy as np
+ from scipy.stats import spearmanr
  import importlib.resources
  from pathlib import Path
  import os
@@ -92,6 +93,31 @@ def get_custom_script_path(package: str, script_name: str) -> Path:
      return script_path


+ def proximity_model_local(model: "Model"):
+     """Create a Proximity Model for this Model
+
+     Args:
+         model (Model): The Model/FeatureSet used to create the proximity model
+
+     Returns:
+         Proximity: The proximity model
+     """
+     from workbench.algorithms.dataframe.proximity import Proximity  # noqa: F401 (avoid circular import)
+     from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)
+
+     # Get Feature and Target Columns from the existing given Model
+     features = model.features()
+     target = model.target()
+
+     # Backtrack our FeatureSet to get the ID column
+     fs = FeatureSet(model.get_input())
+     id_column = fs.id_column
+
+     # Create the Proximity Model from our Training Data
+     df = model.training_view().pull_dataframe()
+     return Proximity(df, id_column, features, target, track_columns=features)
+
+
  def proximity_model(model: "Model", prox_model_name: str, track_columns: list = None) -> "Model":
      """Create a proximity model based on the given model

@@ -140,7 +166,7 @@ def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -
      from workbench.api import Model, ModelType, FeatureSet  # noqa: F401 (avoid circular import)

      # Get the custom script path for the UQ model
-     script_path = get_custom_script_path("uq_models", "meta_uq.template")
+     script_path = get_custom_script_path("uq_models", "mapie.template")

      # Get Feature and Target Columns from the existing given Model
      features = model.features()
@@ -162,6 +188,20 @@ def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -
      return uq_model


+ def safe_extract_tarfile(tar_path: str, extract_path: str) -> None:
+     """
+     Extract a tarball safely, using data filter if available.
+
+     The filter parameter was backported to Python 3.8+, 3.9+, 3.10.13+, 3.11+
+     as a security patch, but may not be present in older patch versions.
+     """
+     with tarfile.open(tar_path, "r:gz") as tar:
+         if hasattr(tarfile, "data_filter"):
+             tar.extractall(path=extract_path, filter="data")
+         else:
+             tar.extractall(path=extract_path)
+
+
  def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
      """
      Download and extract category mappings from a model artifact in S3.
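The new safe_extract_tarfile helper relies on the tarfile extraction filters from PEP 706: with filter="data", members using absolute paths, path traversal, or dangerous metadata are rejected instead of silently extracted. A minimal standalone sketch of that behavior (illustrative only, not part of the package; it assumes an interpreter recent enough to ship the filter, otherwise the filter keyword itself is unavailable):

    import io
    import tarfile
    import tempfile

    # Build an in-memory tarball whose single member tries to escape the destination
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w:gz") as tar:
        info = tarfile.TarInfo(name="../evil.txt")
        payload = b"oops"
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
    buf.seek(0)

    with tempfile.TemporaryDirectory() as dest, tarfile.open(fileobj=buf, mode="r:gz") as tar:
        try:
            tar.extractall(path=dest, filter="data")  # the data filter refuses the traversal member
        except Exception as exc:  # a tarfile.FilterError subclass on patched interpreters
            print(f"Blocked unsafe member: {type(exc).__name__}: {exc}")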
@@ -180,8 +220,7 @@ def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
          wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)

          # Extract tarball
-         with tarfile.open(local_tar_path, "r:gz") as tar:
-             tar.extractall(path=tmpdir, filter="data")
+         safe_extract_tarfile(local_tar_path, tmpdir)

          # Look for category mappings in base directory only
          mappings_path = os.path.join(tmpdir, "category_mappings.json")
@@ -220,28 +259,41 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
      # --- Coverage and Interval Width ---
      if "q_025" in df.columns and "q_975" in df.columns:
          lower_95, upper_95 = df["q_025"], df["q_975"]
+         lower_90, upper_90 = df["q_05"], df["q_95"]
+         lower_80, upper_80 = df["q_10"], df["q_90"]
+         lower_68 = df.get("q_16", df["q_10"])  # fallback to 80% interval
+         upper_68 = df.get("q_84", df["q_90"])  # fallback to 80% interval
          lower_50, upper_50 = df["q_25"], df["q_75"]
      elif "prediction_std" in df.columns:
          lower_95 = df["prediction"] - 1.96 * df["prediction_std"]
          upper_95 = df["prediction"] + 1.96 * df["prediction_std"]
+         lower_90 = df["prediction"] - 1.645 * df["prediction_std"]
+         upper_90 = df["prediction"] + 1.645 * df["prediction_std"]
+         lower_80 = df["prediction"] - 1.282 * df["prediction_std"]
+         upper_80 = df["prediction"] + 1.282 * df["prediction_std"]
+         lower_68 = df["prediction"] - 1.0 * df["prediction_std"]
+         upper_68 = df["prediction"] + 1.0 * df["prediction_std"]
          lower_50 = df["prediction"] - 0.674 * df["prediction_std"]
          upper_50 = df["prediction"] + 0.674 * df["prediction_std"]
      else:
          raise ValueError(
              "Either quantile columns (q_025, q_975, q_25, q_75) or 'prediction_std' column must be present."
          )
+     median_std = df["prediction_std"].median()
      coverage_95 = np.mean((df[target_col] >= lower_95) & (df[target_col] <= upper_95))
-     coverage_50 = np.mean((df[target_col] >= lower_50) & (df[target_col] <= upper_50))
-     avg_width_95 = np.mean(upper_95 - lower_95)
-     avg_width_50 = np.mean(upper_50 - lower_50)
+     coverage_90 = np.mean((df[target_col] >= lower_90) & (df[target_col] <= upper_90))
+     coverage_80 = np.mean((df[target_col] >= lower_80) & (df[target_col] <= upper_80))
+     coverage_68 = np.mean((df[target_col] >= lower_68) & (df[target_col] <= upper_68))
+     median_width_95 = np.median(upper_95 - lower_95)
+     median_width_90 = np.median(upper_90 - lower_90)
+     median_width_80 = np.median(upper_80 - lower_80)
+     median_width_50 = np.median(upper_50 - lower_50)
+     median_width_68 = np.median(upper_68 - lower_68)

      # --- CRPS (measures calibration + sharpness) ---
-     if "prediction_std" in df.columns:
-         z = (df[target_col] - df["prediction"]) / df["prediction_std"]
-         crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
-         mean_crps = np.mean(crps)
-     else:
-         mean_crps = np.nan
+     z = (df[target_col] - df["prediction"]) / df["prediction_std"]
+     crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
+     mean_crps = np.mean(crps)

      # --- Interval Score @ 95% (penalizes miscoverage) ---
      alpha_95 = 0.05
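The multipliers in the prediction_std branch are the two-sided standard-normal quantiles for each nominal coverage level (the 68% band uses the conventional ±1σ rather than the exact 0.994 quantile), and the CRPS line is the closed form for a Gaussian predictive distribution, CRPS(y) = σ·[z·(2Φ(z) − 1) + 2φ(z) − 1/√π] with z = (y − μ)/σ. A small scipy-only sketch (not part of the package) that reproduces the multipliers:

    from scipy.stats import norm

    # Two-sided z multipliers for central prediction intervals
    for level in (0.50, 0.68, 0.80, 0.90, 0.95):
        z = norm.ppf(0.5 + level / 2.0)
        print(f"{int(level * 100)}% interval: z = {z:.3f}")
    # 50%: 0.674, 68%: 0.994 (the code rounds to 1.0), 80%: 1.282, 90%: 1.645, 95%: 1.960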
@@ -252,31 +304,43 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
      )
      mean_is_95 = np.mean(is_95)

-     # --- Adaptive Calibration (correlation between errors and uncertainty) ---
+     # --- Interval to Error Correlation ---
      abs_residuals = np.abs(df[target_col] - df["prediction"])
-     width_95 = upper_95 - lower_95
-     adaptive_calibration = np.corrcoef(abs_residuals, width_95)[0, 1]
+     width_68 = upper_68 - lower_68
+
+     # Spearman correlation for robustness
+     interval_to_error_corr = spearmanr(width_68, abs_residuals)[0]

      # Collect results
      results = {
+         "coverage_68": coverage_68,
+         "coverage_80": coverage_80,
+         "coverage_90": coverage_90,
          "coverage_95": coverage_95,
-         "coverage_50": coverage_50,
-         "avg_width_95": avg_width_95,
-         "avg_width_50": avg_width_50,
-         "crps": mean_crps,
-         "interval_score_95": mean_is_95,
-         "adaptive_calibration": adaptive_calibration,
+         "median_std": median_std,
+         "median_width_50": median_width_50,
+         "median_width_68": median_width_68,
+         "median_width_80": median_width_80,
+         "median_width_90": median_width_90,
+         "median_width_95": median_width_95,
+         "interval_to_error_corr": interval_to_error_corr,
          "n_samples": len(df),
      }

      print("\n=== UQ Metrics ===")
+     print(f"Coverage @ 68%: {coverage_68:.3f} (target: 0.68)")
+     print(f"Coverage @ 80%: {coverage_80:.3f} (target: 0.80)")
+     print(f"Coverage @ 90%: {coverage_90:.3f} (target: 0.90)")
      print(f"Coverage @ 95%: {coverage_95:.3f} (target: 0.95)")
-     print(f"Coverage @ 50%: {coverage_50:.3f} (target: 0.50)")
-     print(f"Average 95% Width: {avg_width_95:.3f}")
-     print(f"Average 50% Width: {avg_width_50:.3f}")
+     print(f"Median Prediction StdDev: {median_std:.3f}")
+     print(f"Median 50% Width: {median_width_50:.3f}")
+     print(f"Median 68% Width: {median_width_68:.3f}")
+     print(f"Median 80% Width: {median_width_80:.3f}")
+     print(f"Median 90% Width: {median_width_90:.3f}")
+     print(f"Median 95% Width: {median_width_95:.3f}")
      print(f"CRPS: {mean_crps:.3f} (lower is better)")
      print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
-     print(f"Adaptive Calibration: {adaptive_calibration:.3f} (higher is better, target: >0.5)")
+     print(f"Interval/Error Corr: {interval_to_error_corr:.3f} (higher is better, target: >0.5)")
      print(f"Samples: {len(df)}")
      return results

@@ -313,9 +377,3 @@ if __name__ == "__main__":
      df = end.auto_inference(capture=True)
      results = uq_metrics(df, target_col="solubility")
      print(results)
-
-     # Test the uq_metrics function
-     end = Endpoint("aqsol-uq-100")
-     df = end.auto_inference(capture=True)
-     results = uq_metrics(df, target_col="solubility")
-     print(results)
@@ -14,7 +14,7 @@ from workbench.utils.s3_utils import read_content_from_s3
  log = logging.getLogger("workbench")


- def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:
+ def pull_data_capture_for_testing(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:
      """
      Read and process captured data from S3.

@@ -26,7 +26,12 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, Non

      Returns:
          Union[pd.DataFrame, None]: A dataframe of the captured data (or None if no data is found).
+
+     Notes:
+         This method is really only for testing and debugging.
      """
+     log.important("This method is for testing and debugging only.")
+
      # List files in the specified S3 path
      files = wr.s3.list_objects(data_capture_path)
      if not files:
@@ -64,59 +69,53 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, Non
  def process_data_capture(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
      """
      Process the captured data DataFrame to extract input and output data.
-     Continues processing even if individual files are malformed.
+     Handles cases where input or output might not be captured.
+
      Args:
          df (DataFrame): DataFrame with captured data.
      Returns:
          tuple[DataFrame, DataFrame]: Input and output DataFrames.
      """
+
+     def parse_endpoint_data(data: dict) -> pd.DataFrame:
+         """Parse endpoint data based on encoding type."""
+         encoding = data["encoding"].upper()
+
+         if encoding == "CSV":
+             return pd.read_csv(StringIO(data["data"]))
+         elif encoding == "JSON":
+             json_data = json.loads(data["data"])
+             if isinstance(json_data, dict):
+                 return pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
+             else:
+                 return pd.DataFrame(json_data)
+         else:
+             return None  # Unknown encoding
+
      input_dfs = []
      output_dfs = []

-     for idx, row in df.iterrows():
+     # Use itertuples() instead of iterrows() for better performance
+     for row in df.itertuples(index=True):
          try:
-             capture_data = row["captureData"]
-
-             # Check if this capture has the required fields (all or nothing)
-             if "endpointInput" not in capture_data:
-                 log.warning(f"Row {idx}: No endpointInput found in capture data.")
-                 continue
-
-             if "endpointOutput" not in capture_data:
-                 log.critical(
-                     f"Row {idx}: No endpointOutput found in capture data. DataCapture needs to include Output capture!"
-                 )
-                 continue
-
-             # Process input data
-             input_data = capture_data["endpointInput"]
-             if input_data["encoding"].upper() == "CSV":
-                 input_df = pd.read_csv(StringIO(input_data["data"]))
-             elif input_data["encoding"].upper() == "JSON":
-                 json_data = json.loads(input_data["data"])
-                 if isinstance(json_data, dict):
-                     input_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
-                 else:
-                     input_df = pd.DataFrame(json_data)
-
-             # Process output data
-             output_data = capture_data["endpointOutput"]
-             if output_data["encoding"].upper() == "CSV":
-                 output_df = pd.read_csv(StringIO(output_data["data"]))
-             elif output_data["encoding"].upper() == "JSON":
-                 json_data = json.loads(output_data["data"])
-                 if isinstance(json_data, dict):
-                     output_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
-                 else:
-                     output_df = pd.DataFrame(json_data)
-
-             # If we get here, both processed successfully
-             input_dfs.append(input_df)
-             output_dfs.append(output_df)
+             capture_data = row.captureData
+
+             # Process input data if present
+             if "endpointInput" in capture_data:
+                 input_df = parse_endpoint_data(capture_data["endpointInput"])
+                 if input_df is not None:
+                     input_dfs.append(input_df)
+
+             # Process output data if present
+             if "endpointOutput" in capture_data:
+                 output_df = parse_endpoint_data(capture_data["endpointOutput"])
+                 if output_df is not None:
+                     output_dfs.append(output_df)

          except Exception as e:
-             log.error(f"Row {idx}: Failed to process row: {e}")
+             log.debug(f"Row {row.Index}: Failed to process row: {e}")
              continue
+
      # Combine and return results
      return (
          pd.concat(input_dfs, ignore_index=True) if input_dfs else pd.DataFrame(),
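For context on process_data_capture: each capture record is a JSON object whose captureData block carries the encoded endpointInput and endpointOutput payloads (real SageMaker data-capture files also include eventMetadata and related fields, omitted here). A minimal, hypothetical record and the same CSV/JSON decoding the helper applies:

    import json
    from io import StringIO

    import pandas as pd

    # Hypothetical capture record; field values are illustrative only
    record = {
        "captureData": {
            "endpointInput": {"encoding": "CSV", "data": "feature_1,feature_2\n0.1,0.2\n"},
            "endpointOutput": {"encoding": "JSON", "data": '{"prediction": [0.7]}'},
        }
    }

    input_df = pd.read_csv(StringIO(record["captureData"]["endpointInput"]["data"]))
    output_df = pd.DataFrame(json.loads(record["captureData"]["endpointOutput"]["data"]))
    print(input_df)
    print(output_df)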
@@ -178,23 +177,6 @@ def parse_monitoring_results(results_json: str) -> Dict[str, Any]:
          return {"error": str(e)}


- """TEMP
- # If the status is "CompletedWithViolations", we grab the lastest
- # violation file and add it to the result
- if status == "CompletedWithViolations":
-     violation_file = f"{self.monitoring_path}/
-         {last_run['CreationTime'].strftime('%Y/%m/%d')}/constraint_violations.json"
-     if wr.s3.does_object_exist(violation_file):
-         violations_json = read_content_from_s3(violation_file)
-         violations = parse_monitoring_results(violations_json)
-         result["violations"] = violations.get("constraint_violations", [])
-         result["violation_count"] = len(result["violations"])
-     else:
-         result["violations"] = []
-         result["violation_count"] = 0
- """
-
-
  def preprocessing_script(feature_list: list[str]) -> str:
      """
      A preprocessing script for monitoring jobs.
@@ -245,8 +227,8 @@ if __name__ == "__main__":
      from workbench.api.monitor import Monitor

      # Test pulling data capture
-     mon = Monitor("caco2-pappab-class-0")
-     df = pull_data_capture(mon.data_capture_path)
+     mon = Monitor("abalone-regression-rt")
+     df = pull_data_capture_for_testing(mon.data_capture_path)
      print("Data Capture:")
      print(df.head())

@@ -262,4 +244,4 @@
      # Test preprocessing script
      script = preprocessing_script(["feature1", "feature2", "feature3"])
      print("\nPreprocessing Script:")
-     print(script)
+     # print(script)
@@ -152,7 +152,7 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li

      # Check for differences in common columns
      for column in common_columns:
-         if pd.api.types.is_string_dtype(df1[column]) or pd.api.types.is_string_dtype(df2[column]):
+         if pd.api.types.is_string_dtype(df1[column]) and pd.api.types.is_string_dtype(df2[column]):
              # String comparison with NaNs treated as equal
              differences = ~(df1[column].fillna("") == df2[column].fillna(""))
          elif pd.api.types.is_float_dtype(df1[column]) or pd.api.types.is_float_dtype(df2[column]):
@@ -161,8 +161,8 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
                  pd.isna(df1[column]) & pd.isna(df2[column])
              )
          else:
-             # Other types (e.g., int) with NaNs treated as equal
-             differences = ~(df1[column].fillna(0) == df2[column].fillna(0))
+             # Other types (int, Int64, etc.) - compare with NaNs treated as equal
+             differences = (df1[column] != df2[column]) & ~(pd.isna(df1[column]) & pd.isna(df2[column]))

          # If differences exist, display them
          if differences.any():
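The motivation for the else-branch change: filling NAs with 0 before comparing can hide a real difference whenever one frame holds a genuine 0 and the other holds a missing value, since both sides become 0 after the fill. The new expression compares values directly and only treats a position as equal when both sides are missing. A small illustration of the old failure mode:

    import pandas as pd

    s1 = pd.Series([0, 5], dtype="Int64")      # genuine zero in the first row
    s2 = pd.Series([pd.NA, 5], dtype="Int64")  # missing value in the first row

    # Old approach: filling NAs with 0 makes the 0 vs <NA> pair look identical
    old_differences = ~(s1.fillna(0) == s2.fillna(0))
    print(old_differences.tolist())  # [False, False] -- the mismatch is hidden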
@@ -212,6 +212,14 @@ def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
          log.error("No XGBoost model found in the artifact.")
          return None, None, None, None

+     # Get the booster (SHAP requires the booster, not the sklearn wrapper)
+     if hasattr(xgb_model, "get_booster"):
+         # Full sklearn model - extract the booster
+         booster = xgb_model.get_booster()
+     else:
+         # Already a booster
+         booster = xgb_model
+
      # Load category mappings if available
      category_mappings = load_category_mappings_from_s3(model_artifact_uri)

@@ -229,8 +237,8 @@ def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
      # Create a DMatrix with categorical support
      dmatrix = xgb.DMatrix(X, enable_categorical=True)

-     # Use XGBoost's built-in SHAP calculation
-     shap_values = xgb_model.predict(dmatrix, pred_contribs=True, strict_shape=True)
+     # Use XGBoost's built-in SHAP calculation (booster method, not sklearn)
+     shap_values = booster.predict(dmatrix, pred_contribs=True, strict_shape=True)
      features_with_bias = features + ["bias"]

      # Now we need to subset the columns based on top 10 SHAP values
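Background for the booster change: pred_contribs=True (per-feature SHAP contributions plus a bias column) is an argument of xgboost.Booster.predict rather than the sklearn wrapper's predict, so a fitted XGBRegressor or XGBClassifier has to be unwrapped with get_booster() first. A minimal sketch on synthetic data (illustrative only):

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(50, 3)
    y = X @ np.array([1.0, -2.0, 0.5])

    model = xgb.XGBRegressor(n_estimators=10).fit(X, y)

    # The sklearn wrapper's predict() has no pred_contribs argument;
    # SHAP-style contributions come from the underlying Booster.
    booster = model.get_booster()
    contribs = booster.predict(xgb.DMatrix(X), pred_contribs=True)
    print(contribs.shape)  # (50, 4): one column per feature plus the bias term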
@@ -12,7 +12,7 @@ class WorkbenchSQS:
          self.log = logging.getLogger("workbench")
          self.queue_url = queue_url

-         # Grab a Workbench Session (this allows us to assume the Workbench-ExecutionRole)
+         # Grab a Workbench Session
          self.boto3_session = AWSAccountClamp().boto3_session
          print(self.boto3_session)