workbench-0.8.178-py3-none-any.whl → workbench-0.8.180-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench has been flagged for review.

Files changed (26):
  1. workbench/api/endpoint.py +3 -2
  2. workbench/core/artifacts/endpoint_core.py +5 -5
  3. workbench/core/artifacts/feature_set_core.py +32 -2
  4. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  5. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  6. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
  7. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  8. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +37 -34
  9. workbench/model_scripts/custom_models/uq_models/mapie.template +35 -32
  10. workbench/model_scripts/custom_models/uq_models/meta_uq.template +7 -22
  11. workbench/model_scripts/custom_models/uq_models/ngboost.template +5 -12
  12. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
  13. workbench/model_scripts/pytorch_model/pytorch.template +9 -18
  14. workbench/model_scripts/quant_regression/quant_regression.template +5 -10
  15. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  16. workbench/model_scripts/xgb_model/generated_model_script.py +24 -33
  17. workbench/model_scripts/xgb_model/xgb_model.template +23 -32
  18. workbench/utils/model_utils.py +2 -1
  19. workbench/utils/shap_utils.py +10 -2
  20. workbench/utils/xgboost_model_utils.py +160 -137
  21. {workbench-0.8.178.dist-info → workbench-0.8.180.dist-info}/METADATA +1 -1
  22. {workbench-0.8.178.dist-info → workbench-0.8.180.dist-info}/RECORD +26 -26
  23. {workbench-0.8.178.dist-info → workbench-0.8.180.dist-info}/WHEEL +0 -0
  24. {workbench-0.8.178.dist-info → workbench-0.8.180.dist-info}/entry_points.txt +0 -0
  25. {workbench-0.8.178.dist-info → workbench-0.8.180.dist-info}/licenses/LICENSE +0 -0
  26. {workbench-0.8.178.dist-info → workbench-0.8.180.dist-info}/top_level.txt +0 -0
workbench/model_scripts/xgb_model/generated_model_script.py

@@ -32,10 +32,12 @@ TEMPLATE_PARAMS = {
  "target": "udm_asy_res_value",
  "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 
'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
  "compressed_features": [],
- "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/pka-a1-reg-0-nightly-100-test/training",
- "train_all_data": True
+ "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/logd-hyper-80/training",
+ "train_all_data": False,
+ "hyperparameters": {'n_estimators': 200, 'max_depth': 6, 'learning_rate': 0.05, 'subsample': 0.7, 'colsample_bytree': 0.3, 'colsample_bylevel': 0.5, 'min_child_weight': 5, 'gamma': 0.2, 'reg_alpha': 0.5, 'reg_lambda': 2.0, 'scale_pos_weight': 1},
  }

+
  # Function to check if dataframe is empty
  def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
  """
@@ -75,7 +77,7 @@ def expand_proba_column(df: pd.DataFrame, class_labels: List[str]) -> pd.DataFra
  proba_df = pd.DataFrame(df[proba_column].tolist(), columns=proba_splits)

  # Drop any proba columns and reset the index in prep for the concat
- df = df.drop(columns=[proba_column]+proba_splits, errors="ignore")
+ df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
  df = df.reset_index(drop=True)

  # Concatenate the new columns with the original DataFrame
@@ -140,8 +142,10 @@ def convert_categorical_types(df: pd.DataFrame, features: list, category_mapping
  return df, category_mappings


- def decompress_features(df: pd.DataFrame, features: List[str], compressed_features: List[str]) -> Tuple[pd.DataFrame, List[str]]:
- """Prepare features for the XGBoost model
+ def decompress_features(
+ df: pd.DataFrame, features: List[str], compressed_features: List[str]
+ ) -> Tuple[pd.DataFrame, List[str]]:
+ """Prepare features for the model

  Args:
  df (pd.DataFrame): The features DataFrame
@@ -204,6 +208,7 @@ if __name__ == "__main__":
  model_type = TEMPLATE_PARAMS["model_type"]
  model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
  train_all_data = TEMPLATE_PARAMS["train_all_data"]
+ hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
  validation_split = 0.2

  # Script arguments for input/output directories
@@ -216,11 +221,7 @@ if __name__ == "__main__":
  args = parser.parse_args()

  # Read the training data into DataFrames
- training_files = [
- os.path.join(args.train, file)
- for file in os.listdir(args.train)
- if file.endswith(".csv")
- ]
+ training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
  print(f"Training Files: {training_files}")

  # Combine files and read them all into a single pandas dataframe
@@ -255,15 +256,16 @@ if __name__ == "__main__":
  else:
  # Just do a random training Split
  print("WARNING: No training column found, splitting data with random state=42")
- df_train, df_val = train_test_split(
- all_df, test_size=validation_split, random_state=42
- )
+ df_train, df_val = train_test_split(all_df, test_size=validation_split, random_state=42)
  print(f"FIT/TRAIN: {df_train.shape}")
  print(f"VALIDATION: {df_val.shape}")

+ # Use any hyperparameters to set up both the trainer and model configurations
+ print(f"Hyperparameters: {hyperparameters}")
+
  # Now spin up our XGB Model
  if model_type == "classifier":
- xgb_model = xgb.XGBClassifier(enable_categorical=True)
+ xgb_model = xgb.XGBClassifier(enable_categorical=True, **hyperparameters)

  # Encode the target column
  label_encoder = LabelEncoder()
@@ -271,12 +273,12 @@ if __name__ == "__main__":
  df_val[target] = label_encoder.transform(df_val[target])

  else:
- xgb_model = xgb.XGBRegressor(enable_categorical=True)
+ xgb_model = xgb.XGBRegressor(enable_categorical=True, **hyperparameters)
  label_encoder = None # We don't need this for regression

  # Grab our Features, Target and Train the Model
  y_train = df_train[target]
- X_train= df_train[features]
+ X_train = df_train[features]
  xgb_model.fit(X_train, y_train)

  # Make Predictions on the Validation Set
@@ -315,9 +317,7 @@ if __name__ == "__main__":
  label_names = label_encoder.classes_

  # Calculate various model performance metrics
- scores = precision_recall_fscore_support(
- y_validate, preds, average=None, labels=label_names
- )
+ scores = precision_recall_fscore_support(y_validate, preds, average=None, labels=label_names)

  # Put the scores into a dataframe
  score_df = pd.DataFrame(
@@ -355,7 +355,9 @@ if __name__ == "__main__":
  print(f"NumRows: {len(df_val)}")

  # Now save the model to the standard place/name
- xgb_model.save_model(os.path.join(args.model_dir, "xgb_model.json"))
+ joblib.dump(xgb_model, os.path.join(args.model_dir, "xgb_model.joblib"))
+
+ # Save the label encoder if we have one
  if label_encoder:
  joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))

@@ -370,19 +372,8 @@ if __name__ == "__main__":

  def model_fn(model_dir):
  """Deserialize and return fitted XGBoost model"""
-
- model_path = os.path.join(model_dir, "xgb_model.json")
-
- with open(model_path, "r") as f:
- model_json = json.load(f)
-
- sklearn_data = model_json['learner']['attributes']['scikit_learn']
- model_type = json.loads(sklearn_data)['_estimator_type']
-
- model_class = xgb.XGBClassifier if model_type == "classifier" else xgb.XGBRegressor
- model = model_class(enable_categorical=True)
- model.load_model(model_path)
-
+ model_path = os.path.join(model_dir, "xgb_model.joblib")
+ model = joblib.load(model_path)
  return model

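The two substantive changes in this script are that TEMPLATE_PARAMS now carries a hyperparameters dict that is unpacked into the estimator, and that the fitted model is persisted with joblib instead of the Booster JSON format. A minimal sketch of the new save/load round trip; the hyperparameter values and the /tmp/model directory below are placeholders, not the generated script's real parameters:

    import os
    import joblib
    import xgboost as xgb

    # Placeholder values standing in for TEMPLATE_PARAMS["hyperparameters"] and args.model_dir
    hyperparameters = {"n_estimators": 200, "max_depth": 6, "learning_rate": 0.05}
    model_dir = "/tmp/model"

    # The dict is unpacked straight into the sklearn-style estimator
    xgb_model = xgb.XGBRegressor(enable_categorical=True, **hyperparameters)
    # ... xgb_model.fit(X_train, y_train) would run here ...

    # New persistence path: a joblib round trip instead of save_model()/load_model() on JSON
    os.makedirs(model_dir, exist_ok=True)
    joblib.dump(xgb_model, os.path.join(model_dir, "xgb_model.joblib"))
    restored = joblib.load(os.path.join(model_dir, "xgb_model.joblib"))

Because joblib pickles the whole sklearn wrapper, model_fn no longer has to sniff the estimator type out of the saved JSON's scikit_learn attributes; loading gives back the original XGBClassifier or XGBRegressor directly.
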
workbench/model_scripts/xgb_model/xgb_model.template

@@ -33,9 +33,11 @@ TEMPLATE_PARAMS = {
  "features": "{{feature_list}}",
  "compressed_features": "{{compressed_features}}",
  "model_metrics_s3_path": "{{model_metrics_s3_path}}",
- "train_all_data": "{{train_all_data}}"
+ "train_all_data": "{{train_all_data}}",
+ "hyperparameters": "{{hyperparameters}}",
  }

+
  # Function to check if dataframe is empty
  def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
  """
The remaining hunks in xgb_model.template are identical to those shown above for generated_model_script.py.

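Comparing the template with the generated script above, each quoted "{{placeholder}}" is replaced at generation time with a Python literal (False, a hyperparameter dict, an S3 path string) rather than being left as a string. A rough sketch of that kind of substitution, assuming a simple repr-based replacement rather than workbench's actual generator:

    # Hypothetical fragment standing in for xgb_model.template
    template = (
        'TEMPLATE_PARAMS = {\n'
        '    "train_all_data": "{{train_all_data}}",\n'
        '    "hyperparameters": "{{hyperparameters}}",\n'
        '}\n'
    )

    params = {"train_all_data": False, "hyperparameters": {"n_estimators": 200, "max_depth": 6}}

    # Replace the placeholder together with its surrounding quotes, so booleans and
    # dicts land in the generated script as literals rather than strings
    generated = template
    for name, value in params.items():
        generated = generated.replace('"{{' + name + '}}"', repr(value))

    print(generated)
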
workbench/utils/model_utils.py

@@ -222,7 +222,8 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
  lower_95, upper_95 = df["q_025"], df["q_975"]
  lower_90, upper_90 = df["q_05"], df["q_95"]
  lower_80, upper_80 = df["q_10"], df["q_90"]
- lower_68, upper_68 = df["q_16"], df["q_84"]
+ lower_68 = df.get("q_16", 0)
+ upper_68 = df.get("q_84", 0)
  lower_50, upper_50 = df["q_25"], df["q_75"]
  elif "prediction_std" in df.columns:
  lower_95 = df["prediction"] - 1.96 * df["prediction_std"]
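The only behavioral change here is the fallback for the 68% interval: pandas DataFrame.get returns the supplied default instead of raising a KeyError when a quantile column is absent. A small illustration with toy data:

    import pandas as pd

    # Toy prediction frame that has the 80% quantile columns but not q_16 / q_84
    df = pd.DataFrame({"q_10": [2.0, 2.5], "q_90": [8.0, 8.5]})

    lower_68 = df.get("q_16", 0)  # missing column -> scalar default 0, no KeyError
    upper_68 = df.get("q_84", 0)
    lower_80 = df.get("q_10", 0)  # present column -> returned as a Series
    print(lower_68, upper_68, list(lower_80))

Note that a present column comes back as a Series while a missing one comes back as the scalar default.
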
workbench/utils/shap_utils.py

@@ -212,6 +212,14 @@ def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
  log.error("No XGBoost model found in the artifact.")
  return None, None, None, None

+ # Get the booster (SHAP requires the booster, not the sklearn wrapper)
+ if hasattr(xgb_model, "get_booster"):
+ # Full sklearn model - extract the booster
+ booster = xgb_model.get_booster()
+ else:
+ # Already a booster
+ booster = xgb_model
+
  # Load category mappings if available
  category_mappings = load_category_mappings_from_s3(model_artifact_uri)

@@ -229,8 +237,8 @@ def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
  # Create a DMatrix with categorical support
  dmatrix = xgb.DMatrix(X, enable_categorical=True)

- # Use XGBoost's built-in SHAP calculation
- shap_values = xgb_model.predict(dmatrix, pred_contribs=True, strict_shape=True)
+ # Use XGBoost's built-in SHAP calculation (booster method, not sklearn)
+ shap_values = booster.predict(dmatrix, pred_contribs=True, strict_shape=True)
  features_with_bias = features + ["bias"]

  # Now we need to subset the columns based on top 10 SHAP values
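This fix unwraps the sklearn estimator to its underlying Booster, since pred_contribs is an option of Booster.predict rather than of the sklearn predict method. A self-contained sketch on synthetic data (the real helper additionally enables categorical support and loads the model and category mappings from the S3 artifact):

    import numpy as np
    import pandas as pd
    import xgboost as xgb

    # Tiny synthetic regression problem
    rng = np.random.default_rng(42)
    X = pd.DataFrame(rng.normal(size=(100, 3)), columns=["f0", "f1", "f2"])
    y = 2.0 * X["f0"] + rng.normal(scale=0.1, size=100)

    model = xgb.XGBRegressor(n_estimators=10)
    model.fit(X, y)

    # Unwrap to the Booster if we were handed the sklearn wrapper
    booster = model.get_booster() if hasattr(model, "get_booster") else model

    # Per-row SHAP contributions; the last entry on the final axis is the bias term
    dmatrix = xgb.DMatrix(X)
    shap_values = booster.predict(dmatrix, pred_contribs=True, strict_shape=True)
    print(shap_values.shape)
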