workbench 0.8.173__py3-none-any.whl → 0.8.175__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -2,6 +2,7 @@
 
 import logging
 import re
+import time
 from datetime import datetime
 from typing import Tuple
 import pandas as pd
@@ -14,6 +15,9 @@ from workbench.core.artifacts.endpoint_core import EndpointCore
 from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
 from workbench.utils.monitor_utils import process_data_capture
 
+# Setup logging
+log = logging.getLogger("workbench")
+
 
 class DataCaptureCore:
     """Manages data capture configuration and retrieval for SageMaker endpoints"""
@@ -203,7 +207,7 @@ class DataCaptureCore:
         modes = [opt.get("CaptureMode") for opt in capture_options]
         return ["REQUEST" if m == "Input" else "RESPONSE" for m in modes if m]
 
-    def get_captured_data(self, from_date=None, add_timestamp=True) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    def get_captured_data(self, from_date: str = None, add_timestamp: bool = True) -> Tuple[pd.DataFrame, pd.DataFrame]:
         """
         Read and process captured data from S3.
 
@@ -226,29 +230,65 @@
             files = [f for f in files if self._file_date_filter(f, from_date_obj)]
             self.log.info(f"Processing {len(files)} files from {from_date} onwards.")
         else:
-            self.log.info(f"Processing all {len(files)} files.")
+            self.log.info(f"Processing all {len(files)} files...")
+
+        # Check if any files remain after filtering
+        if not files:
+            self.log.info("No files to process after date filtering.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        # Sort files by name (assumed to include timestamp)
         files.sort()
 
-        # Process files
-        all_input_dfs, all_output_dfs = [], []
-        for file_path in files:
+        # Get all timestamps in one batch if needed
+        timestamps = {}
+        if add_timestamp:
+            # Batch describe operation - much more efficient than per-file calls
+            timestamps = wr.s3.describe_objects(path=files)
+
+        # Process files using concurrent.futures
+        start_time = time.time()
+
+        def process_single_file(file_path):
+            """Process a single file and return input/output DataFrames."""
             try:
+                log.debug(f"Processing file: {file_path}...")
                 df = wr.s3.read_json(path=file_path, lines=True)
                 if not df.empty:
                     input_df, output_df = process_data_capture(df)
-                    if add_timestamp:
-                        timestamp = wr.s3.describe_objects(path=file_path)[file_path]["LastModified"]
-                        output_df["timestamp"] = timestamp
-                    all_input_dfs.append(input_df)
-                    all_output_dfs.append(output_df)
+                    if add_timestamp and file_path in timestamps:
+                        output_df["timestamp"] = timestamps[file_path]["LastModified"]
+                    return input_df, output_df
+                return pd.DataFrame(), pd.DataFrame()
             except Exception as e:
                 self.log.warning(f"Error processing {file_path}: {e}")
+                return pd.DataFrame(), pd.DataFrame()
+
+        # Use ThreadPoolExecutor for I/O-bound operations
+        from concurrent.futures import ThreadPoolExecutor
+
+        max_workers = min(32, len(files))  # Cap at 32 threads or number of files
+
+        all_input_dfs, all_output_dfs = [], []
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = [executor.submit(process_single_file, file_path) for file_path in files]
+            for future in futures:
+                input_df, output_df = future.result()
+                if not input_df.empty:
+                    all_input_dfs.append(input_df)
+                if not output_df.empty:
+                    all_output_dfs.append(output_df)
 
         if not all_input_dfs:
             self.log.warning("No valid data was processed.")
             return pd.DataFrame(), pd.DataFrame()
 
-        return pd.concat(all_input_dfs, ignore_index=True), pd.concat(all_output_dfs, ignore_index=True)
+        input_df = pd.concat(all_input_dfs, ignore_index=True)
+        output_df = pd.concat(all_output_dfs, ignore_index=True)
+
+        elapsed_time = time.time() - start_time
+        self.log.info(f"Processed {len(files)} files in {elapsed_time:.2f} seconds.")
+        return input_df, output_df
 
     def _file_date_filter(self, file_path, from_date_obj):
         """Extract date from S3 path and compare with from_date."""
@@ -304,7 +344,7 @@ if __name__ == "__main__":
     # print(pred_df.head())
 
     # Check that data capture is working
-    input_df, output_df = dc.get_captured_data()
+    input_df, output_df = dc.get_captured_data(from_date="2025-09-01")
     if input_df.empty and output_df.empty:
         print("No data capture files found, for a new endpoint it may take a few minutes to start capturing data")
     else:
@@ -1,7 +1,7 @@
-# Model: NGBoost Regressor with Distribution output
-from ngboost import NGBRegressor
-from ngboost.distns import Cauchy, T
-from xgboost import XGBRegressor # Point Estimator
+# Model: XGBoost for point predictions + LightGBM with MAPIE for conformalized intervals
+from mapie.regression import ConformalizedQuantileRegressor
+from lightgbm import LGBMRegressor
+from xgboost import XGBRegressor
 from sklearn.model_selection import train_test_split
 
 # Model Performance Scores
@@ -20,19 +20,12 @@ import numpy as np
 import pandas as pd
 from typing import List, Tuple
 
-# Local Imports
-from proximity import Proximity
-
-
-
 # Template Placeholders
 TEMPLATE_PARAMS = {
-    "id_column": "udm_mol_id",
     "target": "udm_asy_res_value",
- "features": ['bcut2d_logplow', 'numradicalelectrons', 'smr_vsa5', 'fr_lactam', 'fr_morpholine', 'fr_aldehyde', 'slogp_vsa1', 'fr_amidine', 'bpol', 'fr_ester', 'fr_azo', 'kappa3', 'peoe_vsa5', 'fr_ketone_topliss', 'vsa_estate9', 'estate_vsa9', 'bcut2d_mrhi', 'fr_ndealkylation1', 'numrotatablebonds', 'minestateindex', 'fr_quatn', 'peoe_vsa3', 'fr_epoxide', 'fr_aniline', 'minpartialcharge', 'fr_nitroso', 'fpdensitymorgan2', 'fr_oxime', 'fr_sulfone', 'smr_vsa1', 'kappa1', 'fr_pyridine', 'numaromaticrings', 'vsa_estate6', 'molmr', 'estate_vsa1', 'fr_dihydropyridine', 'vsa_estate10', 'fr_alkyl_halide', 'chi2n', 'fr_thiocyan', 'fpdensitymorgan1', 'fr_unbrch_alkane', 'slogp_vsa9', 'chi4n', 'fr_nitro_arom', 'fr_al_oh', 'fr_furan', 'fr_c_s', 'peoe_vsa8', 'peoe_vsa14', 'numheteroatoms', 'fr_ndealkylation2', 'maxabspartialcharge', 'vsa_estate2', 'peoe_vsa7', 'apol', 'numhacceptors', 'fr_tetrazole', 'vsa_estate1', 'peoe_vsa9', 'naromatom', 'bcut2d_chghi', 'fr_sh', 'fr_halogen', 'slogp_vsa4', 'fr_benzodiazepine', 'molwt', 'fr_isocyan', 'fr_prisulfonamd', 'maxabsestateindex', 'minabsestateindex', 'peoe_vsa11', 'slogp_vsa12', 'estate_vsa5', 'numaliphaticcarbocycles', 'bcut2d_mwlow', 'slogp_vsa7', 'fr_allylic_oxid', 'fr_methoxy', 'fr_nh0', 'fr_coo2', 'fr_phenol', 'nacid', 'nbase', 'chi3v', 'fr_ar_nh', 'fr_nitrile', 'fr_imidazole', 'fr_urea', 'bcut2d_mrlow', 'chi1', 'smr_vsa6', 'fr_aryl_methyl', 'narombond', 'fr_alkyl_carbamate', 'fr_piperzine', 'exactmolwt', 'qed', 'chi0n', 'fr_sulfonamd', 'fr_thiazole', 'numvalenceelectrons', 'fr_phos_acid', 'peoe_vsa12', 'fr_nh1', 'fr_hdrzine', 'fr_c_o_nocoo', 'fr_lactone', 'estate_vsa6', 'bcut2d_logphi', 'vsa_estate7', 'peoe_vsa13', 'numsaturatedcarbocycles', 'fr_nitro', 'fr_phenol_noorthohbond', 'rotratio', 'fr_barbitur', 'fr_isothiocyan', 'balabanj', 'fr_arn', 'fr_imine', 'maxpartialcharge', 'fr_sulfide', 'slogp_vsa11', 'fr_hoccn', 'fr_n_o', 'peoe_vsa1', 'slogp_vsa6', 'heavyatommolwt', 'fractioncsp3', 'estate_vsa8', 'peoe_vsa10', 'numaliphaticrings', 'fr_thiophene', 'maxestateindex', 'smr_vsa10', 'labuteasa', 'smr_vsa2', 'fpdensitymorgan3', 'smr_vsa9', 'slogp_vsa10', 'numaromaticheterocycles', 'fr_nh2', 'fr_diazo', 'chi3n', 'fr_ar_coo', 'slogp_vsa5', 'fr_bicyclic', 'fr_amide', 'estate_vsa10', 'fr_guanido', 'chi1n', 'numsaturatedrings', 'fr_piperdine', 'fr_term_acetylene', 'estate_vsa4', 'slogp_vsa3', 'fr_coo', 'fr_ether', 'estate_vsa7', 'bcut2d_chglo', 'fr_oxazole', 'peoe_vsa6', 'hallkieralpha', 'peoe_vsa2', 'chi2v', 'nocount', 'vsa_estate5', 'fr_nhpyrrole', 'fr_al_coo', 'bertzct', 'estate_vsa11', 'minabspartialcharge', 'slogp_vsa8', 'fr_imide', 'kappa2', 'numaliphaticheterocycles', 'numsaturatedheterocycles', 'fr_hdrzone', 'smr_vsa4', 'fr_ar_n', 'nrot', 'smr_vsa8', 'slogp_vsa2', 'chi4v', 'fr_phos_ester', 'fr_para_hydroxylation', 'smr_vsa3', 'nhohcount', 'estate_vsa2', 'mollogp', 'tpsa', 'fr_azide', 'peoe_vsa4', 'numhdonors', 'fr_al_oh_notert', 'fr_c_o', 'chi0', 'fr_nitro_arom_nonortho', 'vsa_estate3', 'fr_benzene', 'fr_ketone', 'vsa_estate8', 'smr_vsa7', 'fr_ar_oh', 'fr_priamide', 'ringcount', 'estate_vsa3', 'numaromaticcarbocycles', 'bcut2d_mwhi', 'chi1v', 'heavyatomcount', 'vsa_estate4', 'chi0v', 'chiral_centers', 'r_cnt', 's_cnt', 'db_stereo', 'e_cnt', 'z_cnt', 'chiral_fp', 'db_fp'],
+ "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 
'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
     "compressed_features": [],
-    "train_all_data": False,
-    "track_columns": "udm_asy_res_value"
+    "train_all_data": True
 }
 
 
@@ -108,7 +101,7 @@ def convert_categorical_types(df: pd.DataFrame, features: list, category_mapping
 
 
 def decompress_features(
-        df: pd.DataFrame, features: List[str], compressed_features: List[str]
+    df: pd.DataFrame, features: List[str], compressed_features: List[str]
 ) -> Tuple[pd.DataFrame, List[str]]:
     """Prepare features for the model by decompressing bitstring features
 
@@ -164,13 +157,11 @@ def decompress_features(
 
 if __name__ == "__main__":
     # Template Parameters
-    id_column = TEMPLATE_PARAMS["id_column"]
     target = TEMPLATE_PARAMS["target"]
     features = TEMPLATE_PARAMS["features"]
     orig_features = features.copy()
     compressed_features = TEMPLATE_PARAMS["compressed_features"]
     train_all_data = TEMPLATE_PARAMS["train_all_data"]
-    track_columns = TEMPLATE_PARAMS["track_columns"] # Can be None
     validation_split = 0.2
 
     # Script arguments for input/output directories
@@ -228,78 +219,167 @@ if __name__ == "__main__":
     print(f"FIT/TRAIN: {df_train.shape}")
     print(f"VALIDATION: {df_val.shape}")
 
-    # We're using XGBoost for point predictions and NGBoost for uncertainty quantification
-    xgb_model = XGBRegressor()
-    ngb_model = NGBRegressor() # Dist=Cauchy) Seems to give HUGE prediction intervals
-    ngb_model = NGBRegressor(
-        Dist=T,
-        learning_rate=0.005,
-        minibatch_frac=0.1, # Very small batches
-        col_sample=0.8 # This parameter DOES exist
-    ) # Testing this out
-    print("NGBoost using T distribution for uncertainty quantification")
-
     # Prepare features and targets for training
     X_train = df_train[features]
     X_validate = df_val[features]
     y_train = df_train[target]
     y_validate = df_val[target]
 
-    # Train both models using the training data
+    # Train XGBoost for point predictions
+    print("\nTraining XGBoost for point predictions...")
+    xgb_model = XGBRegressor(enable_categorical=True)
     xgb_model.fit(X_train, y_train)
-    ngb_model.fit(X_train, y_train, X_val=X_validate, Y_val=y_validate)
-
-    # Make Predictions on the Validation Set
-    print(f"Making Predictions on Validation Set...")
-    preds = xgb_model.predict(X_validate)
-
-    # Calculate various model performance metrics (regression)
-    rmse = root_mean_squared_error(y_validate, preds)
-    mae = mean_absolute_error(y_validate, preds)
-    r2 = r2_score(y_validate, preds)
-    print(f"RMSE: {rmse:.3f}")
-    print(f"MAE: {mae:.3f}")
-    print(f"R2: {r2:.3f}")
+
+    # Evaluate XGBoost performance
+    y_pred_xgb = xgb_model.predict(X_validate)
+    xgb_rmse = root_mean_squared_error(y_validate, y_pred_xgb)
+    xgb_mae = mean_absolute_error(y_validate, y_pred_xgb)
+    xgb_r2 = r2_score(y_validate, y_pred_xgb)
+
+    print(f"\nXGBoost Point Prediction Performance:")
+    print(f"RMSE: {xgb_rmse:.3f}")
+    print(f"MAE: {xgb_mae:.3f}")
+    print(f"R2: {xgb_r2:.3f}")
+
+    # Define confidence levels we want to model
+    confidence_levels = [0.50, 0.80, 0.90, 0.95] # 50%, 80%, 90%, 95% confidence intervals
+
+    # Store MAPIE models for each confidence level
+    mapie_models = {}
+
+    # Train models for each confidence level
+    for confidence_level in confidence_levels:
+        alpha = 1 - confidence_level
+        lower_q = alpha / 2
+        upper_q = 1 - alpha / 2
+
+        print(f"\nTraining quantile models for {confidence_level * 100:.0f}% confidence interval...")
+        print(f" Quantiles: {lower_q:.3f}, {upper_q:.3f}, 0.500")
+
+        # Train three models for this confidence level
+        quantile_estimators = []
+        for q in [lower_q, upper_q, 0.5]:
+            print(f" Training model for quantile {q:.3f}...")
+            est = LGBMRegressor(
+                objective="quantile",
+                alpha=q,
+                n_estimators=1000,
+                max_depth=6,
+                learning_rate=0.01,
+                num_leaves=31,
+                min_child_samples=20,
+                subsample=0.8,
+                colsample_bytree=0.8,
+                random_state=42,
+                verbose=-1,
+                force_col_wise=True
+            )
+            est.fit(X_train, y_train)
+            quantile_estimators.append(est)
+
+        # Create MAPIE CQR model for this confidence level
+        print(f" Setting up MAPIE CQR for {confidence_level * 100:.0f}% confidence...")
+        mapie_model = ConformalizedQuantileRegressor(
+            quantile_estimators,
+            confidence_level=confidence_level,
+            prefit=True
+        )
+
+        # Conformalize the model
+        print(f" Conformalizing with validation data...")
+        mapie_model.conformalize(X_validate, y_validate)
+
+        # Store the model
+        mapie_models[f"mapie_{confidence_level:.2f}"] = mapie_model
+
+        # Validate coverage for this confidence level
+        y_pred, y_pis = mapie_model.predict_interval(X_validate)
+        coverage = np.mean((y_validate >= y_pis[:, 0, 0]) & (y_validate <= y_pis[:, 1, 0]))
+        print(f" Coverage: Target={confidence_level * 100:.0f}%, Empirical={coverage * 100:.1f}%")
+
+    print(f"\nOverall Model Performance Summary:")
+    print(f"XGBoost RMSE: {xgb_rmse:.3f}")
+    print(f"XGBoost MAE: {xgb_mae:.3f}")
+    print(f"XGBoost R2: {xgb_r2:.3f}")
     print(f"NumRows: {len(df_val)}")
 
+    # Analyze interval widths across confidence levels
+    print(f"\nInterval Width Analysis:")
+    for conf_level in confidence_levels:
+        model = mapie_models[f"mapie_{conf_level:.2f}"]
+        _, y_pis = model.predict_interval(X_validate)
+        widths = y_pis[:, 1, 0] - y_pis[:, 0, 0]
+        print(f" {conf_level * 100:.0f}% CI: Mean width={np.mean(widths):.3f}, Std={np.std(widths):.3f}")
+
     # Save the trained XGBoost model
     xgb_model.save_model(os.path.join(args.model_dir, "xgb_model.json"))
 
-    # Save the trained NGBoost model
-    joblib.dump(ngb_model, os.path.join(args.model_dir, "ngb_model.joblib"))
+    # Save all MAPIE models
+    for model_name, model in mapie_models.items():
+        joblib.dump(model, os.path.join(args.model_dir, f"{model_name}.joblib"))
 
-    # Save the features (this will validate input during predictions)
+    # Save the feature list
     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-        json.dump(orig_features, fp) # We save the original features, not the decompressed ones
-
-    # Now the Proximity model
-    model = Proximity(df_train, id_column, features, target, track_columns=track_columns)
+        json.dump(features, fp)
+
+    # Save category mappings if any
+    if category_mappings:
+        with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as fp:
+            json.dump(category_mappings, fp)
+
+    # Save model configuration
+    model_config = {
+        "model_type": "XGBoost_MAPIE_CQR_LightGBM",
+        "confidence_levels": confidence_levels,
+        "n_features": len(features),
+        "target": target,
+        "validation_metrics": {
+            "xgb_rmse": float(xgb_rmse),
+            "xgb_mae": float(xgb_mae),
+            "xgb_r2": float(xgb_r2),
+            "n_validation": len(df_val)
+        }
+    }
+    with open(os.path.join(args.model_dir, "model_config.json"), "w") as fp:
+        json.dump(model_config, fp, indent=2)
 
-    # Now serialize the model
-    model.serialize(args.model_dir)
+    print(f"\nModel training complete!")
+    print(f"Saved 1 XGBoost model and {len(mapie_models)} MAPIE models to {args.model_dir}")
 
 
 #
 # Inference Section
 #
 def model_fn(model_dir) -> dict:
-    """Load and return XGBoost, NGBoost, and Prox Model from model directory."""
+    """Load XGBoost and all MAPIE models from the specified directory."""
+
+    # Load model configuration to know which models to load
+    with open(os.path.join(model_dir, "model_config.json")) as fp:
+        config = json.load(fp)
 
     # Load XGBoost regressor
     xgb_path = os.path.join(model_dir, "xgb_model.json")
     xgb_model = XGBRegressor(enable_categorical=True)
     xgb_model.load_model(xgb_path)
 
-    # Load NGBoost regressor
-    ngb_model = joblib.load(os.path.join(model_dir, "ngb_model.joblib"))
+    # Load all MAPIE models
+    mapie_models = {}
+    for conf_level in config["confidence_levels"]:
+        model_name = f"mapie_{conf_level:.2f}"
+        mapie_models[model_name] = joblib.load(os.path.join(model_dir, f"{model_name}.joblib"))
 
-    # Deserialize the proximity model
-    prox_model = Proximity.deserialize(model_dir)
+    # Load category mappings if they exist
+    category_mappings = {}
+    category_path = os.path.join(model_dir, "category_mappings.json")
+    if os.path.exists(category_path):
+        with open(category_path) as fp:
+            category_mappings = json.load(fp)
 
     return {
-        "xgboost": xgb_model,
-        "ngboost": ngb_model,
-        "proximity": prox_model
+        "xgb_model": xgb_model,
+        "mapie_models": mapie_models,
+        "confidence_levels": config["confidence_levels"],
+        "category_mappings": category_mappings
    }
 
 
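Note: the training section above replaces the single NGBoost distributional model with one MAPIE conformalized quantile regressor (CQR) per confidence level, each backed by three pre-fit LightGBM quantile models (lower, upper, median). The condensed sketch below mirrors the calls made in the diff (`ConformalizedQuantileRegressor(..., prefit=True)`, `conformalize()`, `predict_interval()`) on synthetic data; it assumes a MAPIE release that exposes this v1-style API, plus LightGBM and scikit-learn installed.

```python
# Condensed sketch of the CQR flow above, on synthetic data.
import numpy as np
from lightgbm import LGBMRegressor
from mapie.regression import ConformalizedQuantileRegressor
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

X, y = make_regression(n_samples=2000, n_features=10, noise=10.0, random_state=0)
X_train, X_conf, y_train, y_conf = train_test_split(X, y, test_size=0.25, random_state=0)

confidence_level = 0.90
alpha = 1 - confidence_level

# Three pre-fit quantile models: lower, upper, and median
estimators = []
for q in [alpha / 2, 1 - alpha / 2, 0.5]:
    est = LGBMRegressor(objective="quantile", alpha=q, verbose=-1)
    est.fit(X_train, y_train)
    estimators.append(est)

# Wrap the pre-fit models, then conformalize on held-out data
cqr = ConformalizedQuantileRegressor(estimators, confidence_level=confidence_level, prefit=True)
cqr.conformalize(X_conf, y_conf)

# y_pis has shape (n_samples, 2, 1): conformalized lower/upper bounds
y_pred, y_pis = cqr.predict_interval(X_conf)
coverage = np.mean((y_conf >= y_pis[:, 0, 0]) & (y_conf <= y_pis[:, 1, 0]))
print(f"Empirical coverage at {confidence_level:.0%}: {coverage:.1%}")
```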
@@ -315,7 +395,7 @@ def input_fn(input_data, content_type):
     if "text/csv" in content_type:
         return pd.read_csv(StringIO(input_data))
     elif "application/json" in content_type:
-        return pd.DataFrame(json.loads(input_data)) # Assumes JSON array of records
+        return pd.DataFrame(json.loads(input_data))
     else:
         raise ValueError(f"{content_type} not supported!")
 
@@ -323,23 +403,26 @@ def input_fn(input_data, content_type):
 def output_fn(output_df, accept_type):
     """Supports both CSV and JSON output formats."""
     if "text/csv" in accept_type:
-        csv_output = output_df.fillna("N/A").to_csv(index=False) # CSV with N/A for missing values
+        # Convert categorical columns to string to avoid fillna issues
+        for col in output_df.select_dtypes(include=['category']).columns:
+            output_df[col] = output_df[col].astype(str)
+        csv_output = output_df.fillna("N/A").to_csv(index=False)
         return csv_output, "text/csv"
     elif "application/json" in accept_type:
-        return output_df.to_json(orient="records"), "application/json" # JSON array of records (NaNs -> null)
+        return output_df.to_json(orient="records"), "application/json"
     else:
         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
 
 
 def predict_fn(df, models) -> pd.DataFrame:
-    """Make Predictions with our XGB Quantile Regression Model
+    """Make predictions using XGBoost for point estimates and MAPIE for conformalized intervals
 
     Args:
         df (pd.DataFrame): The input DataFrame
-        models (dict): The dictionary of models to use for predictions
+        models (dict): Dictionary containing XGBoost and MAPIE models
 
     Returns:
-        pd.DataFrame: The DataFrame with the predictions added
+        pd.DataFrame: DataFrame with XGBoost predictions and conformalized intervals
     """
 
     # Grab our feature columns (from training)
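Note: the CSV branch of `output_fn()` now casts `category` columns to strings before `fillna("N/A")`. The reason is that pandas refuses to fill a Categorical with a value that is not one of its declared categories; a small reproduction of that behavior follows (the cast to `object` here is just to keep the missing values fillable in the demo, while the template itself casts to `str`).

```python
# Why the cast is needed: filling a Categorical with an unseen value raises.
import pandas as pd

s = pd.Series(pd.Categorical(["high", None, "low"]))
try:
    s.fillna("N/A")  # "N/A" is not one of the declared categories
except (TypeError, ValueError) as e:
    print(f"fillna on a categorical column failed: {e}")

# Casting away the categorical dtype first makes fillna (and to_csv) safe
print(s.astype(object).fillna("N/A").to_list())  # ['high', 'N/A', 'low']
```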
@@ -350,44 +433,62 @@ def predict_fn(df, models) -> pd.DataFrame:
     # Match features in a case-insensitive manner
     matched_df = match_features_case_insensitive(df, model_features)
 
-    # Use XGBoost for point predictions
-    df["prediction"] = models["xgboost"].predict(matched_df[model_features])
-
-    # NGBoost predict returns distribution objects
-    y_dists = models["ngboost"].pred_dist(matched_df[model_features])
-
-    # Extract parameters from distribution
-    dist_params = y_dists.params
-
-    # Extract mean and std from distribution parameters
-    df["prediction_uq"] = dist_params['loc'] # mean
-    df["prediction_std"] = dist_params['scale'] # standard deviation
-
-    # Add 95% prediction intervals using ppf (percent point function)
-    # Note: Our hybrid model uses XGB point prediction and NGBoost UQ
-    # so we need to adjust the bounds to include the point prediction
-    df["q_025"] = np.minimum(y_dists.ppf(0.025), df["prediction"])
-    df["q_975"] = np.maximum(y_dists.ppf(0.975), df["prediction"])
-
-    # Add 90% prediction intervals
-    df["q_05"] = y_dists.ppf(0.05) # 5th percentile
-    df["q_95"] = y_dists.ppf(0.95) # 95th percentile
-
-    # Add 80% prediction intervals
-    df["q_10"] = y_dists.ppf(0.10) # 10th percentile
-    df["q_90"] = y_dists.ppf(0.90) # 90th percentile
+    # Apply categorical mappings if they exist
+    if models.get("category_mappings"):
+        matched_df, _ = convert_categorical_types(
+            matched_df,
+            model_features,
+            models["category_mappings"]
+        )
 
-    # Add 50% prediction intervals
-    df["q_25"] = y_dists.ppf(0.25) # 25th percentile
-    df["q_75"] = y_dists.ppf(0.75) # 75th percentile
+    # Get features for prediction
+    X = matched_df[model_features]
+
+    # Get XGBoost point predictions
+    df["prediction"] = models["xgb_model"].predict(X)
+
+    # Get predictions from each MAPIE model for conformalized intervals
+    for conf_level in models["confidence_levels"]:
+        model_name = f"mapie_{conf_level:.2f}"
+        model = models["mapie_models"][model_name]
+
+        # Get conformalized predictions
+        y_pred, y_pis = model.predict_interval(X)
+
+        # Map confidence levels to quantile names
+        if conf_level == 0.50: # 50% CI
+            df["q_25"] = y_pis[:, 0, 0]
+            df["q_75"] = y_pis[:, 1, 0]
+        elif conf_level == 0.80: # 80% CI
+            df["q_10"] = y_pis[:, 0, 0]
+            df["q_90"] = y_pis[:, 1, 0]
+        elif conf_level == 0.90: # 90% CI
+            df["q_05"] = y_pis[:, 0, 0]
+            df["q_95"] = y_pis[:, 1, 0]
+        elif conf_level == 0.95: # 95% CI
+            df["q_025"] = y_pis[:, 0, 0]
+            df["q_975"] = y_pis[:, 1, 0]
+
+    # Add median (q_50) from XGBoost prediction
+    df["q_50"] = df["prediction"]
+
+    # Calculate uncertainty metrics based on 95% interval
+    interval_width = df["q_975"] - df["q_025"]
+    df["prediction_std"] = interval_width / 3.92
 
     # Reorder the quantile columns for easier reading
     quantile_cols = ["q_025", "q_05", "q_10", "q_25", "q_75", "q_90", "q_95", "q_975"]
     other_cols = [col for col in df.columns if col not in quantile_cols]
     df = df[other_cols + quantile_cols]
 
-    # Compute Nearest neighbors with Proximity model
-    models["proximity"].neighbors(df)
+    # Uncertainty score
+    df["uncertainty_score"] = interval_width / (np.abs(df["prediction"]) + 1e-6)
+
+    # Confidence bands
+    df["confidence_band"] = pd.cut(
+        df["uncertainty_score"],
+        bins=[0, 0.5, 1.0, 2.0, np.inf],
+        labels=["high", "medium", "low", "very_low"]
+    )
 
-    # Return the modified DataFrame
     return df
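Note: with NGBoost gone, `prediction_std` is now back-calculated from the conformalized 95% interval. Under a normal approximation the central 95% interval spans about 2 × 1.96 ≈ 3.92 standard deviations, so `width / 3.92` is a rough sigma estimate, and `uncertainty_score` is the interval width relative to the magnitude of the prediction. A small sketch of those derived columns, assuming the `q_025`, `q_975`, and `prediction` columns already exist:

```python
# Sketch of the derived uncertainty columns.
import numpy as np
import pandas as pd

df = pd.DataFrame({"prediction": [5.0, 0.2], "q_025": [3.0, -1.0], "q_975": [7.0, 1.4]})

width = df["q_975"] - df["q_025"]

# Normal approximation: a 95% interval is ~3.92 sigma wide (2 * 1.96)
df["prediction_std"] = width / 3.92

# Relative width: the small epsilon keeps near-zero predictions from dividing by zero
df["uncertainty_score"] = width / (np.abs(df["prediction"]) + 1e-6)

# Coarse bands: lower score -> tighter interval -> higher confidence
df["confidence_band"] = pd.cut(
    df["uncertainty_score"],
    bins=[0, 0.5, 1.0, 2.0, np.inf],
    labels=["high", "medium", "low", "very_low"],
)
print(df)
```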
@@ -227,15 +227,7 @@ if __name__ == "__main__":
 
     # Train XGBoost for point predictions
     print("\nTraining XGBoost for point predictions...")
-    xgb_model = XGBRegressor(
-        n_estimators=1000,
-        max_depth=6,
-        learning_rate=0.01,
-        subsample=0.8,
-        colsample_bytree=0.8,
-        random_state=42,
-        verbosity=0
-    )
+    xgb_model = XGBRegressor(enable_categorical=True)
     xgb_model.fit(X_train, y_train)
 
     # Evaluate XGBoost performance
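Note: both templates now construct the point model as `XGBRegressor(enable_categorical=True)` with default boosting parameters. The flag only takes effect when the training DataFrame actually carries pandas `category` dtype columns; a minimal sketch (the explicit `tree_method="hist"` reflects XGBoost's requirement that categorical support use histogram-based training, which recent releases default to anyway):

```python
# enable_categorical=True only matters for pandas 'category' dtype columns.
import pandas as pd
from xgboost import XGBRegressor

df = pd.DataFrame(
    {
        "size": pd.Categorical(["s", "m", "l", "m", "s", "l"]),
        "value": [1.0, 2.0, 3.0, 2.5, 0.5, 3.5],
    }
)
y = [10.0, 20.0, 30.0, 25.0, 5.0, 35.0]

model = XGBRegressor(enable_categorical=True, tree_method="hist")
model.fit(df, y)
print(model.predict(df.head(2)))
```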
@@ -76,55 +76,44 @@ def process_data_capture(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
     Returns:
         tuple[DataFrame, DataFrame]: Input and output DataFrames.
     """
+
+    def parse_endpoint_data(data: dict) -> pd.DataFrame:
+        """Parse endpoint data based on encoding type."""
+        encoding = data["encoding"].upper()
+
+        if encoding == "CSV":
+            return pd.read_csv(StringIO(data["data"]))
+        elif encoding == "JSON":
+            json_data = json.loads(data["data"])
+            if isinstance(json_data, dict):
+                return pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
+            else:
+                return pd.DataFrame(json_data)
+        else:
+            return None # Unknown encoding
+
     input_dfs = []
     output_dfs = []
 
-    for idx, row in df.iterrows():
+    # Use itertuples() instead of iterrows() for better performance
+    for row in df.itertuples(index=True):
         try:
-            capture_data = row["captureData"]
+            capture_data = row.captureData
 
             # Process input data if present
             if "endpointInput" in capture_data:
-                input_data = capture_data["endpointInput"]
-                encoding = input_data["encoding"].upper()
-
-                if encoding == "CSV":
-                    input_df = pd.read_csv(StringIO(input_data["data"]))
-                elif encoding == "JSON":
-                    json_data = json.loads(input_data["data"])
-                    if isinstance(json_data, dict):
-                        input_df = pd.DataFrame(
-                            {k: [v] if not isinstance(v, list) else v for k, v in json_data.items()}
-                        )
-                    else:
-                        input_df = pd.DataFrame(json_data)
-                else:
-                    continue # Skip unknown encodings
-
-                input_dfs.append(input_df)
+                input_df = parse_endpoint_data(capture_data["endpointInput"])
+                if input_df is not None:
+                    input_dfs.append(input_df)
 
             # Process output data if present
             if "endpointOutput" in capture_data:
-                output_data = capture_data["endpointOutput"]
-                encoding = output_data["encoding"].upper()
-
-                if encoding == "CSV":
-                    output_df = pd.read_csv(StringIO(output_data["data"]))
-                elif encoding == "JSON":
-                    json_data = json.loads(output_data["data"])
-                    if isinstance(json_data, dict):
-                        output_df = pd.DataFrame(
-                            {k: [v] if not isinstance(v, list) else v for k, v in json_data.items()}
-                        )
-                    else:
-                        output_df = pd.DataFrame(json_data)
-                else:
-                    continue # Skip unknown encodings
-
-                output_dfs.append(output_df)
+                output_df = parse_endpoint_data(capture_data["endpointOutput"])
+                if output_df is not None:
+                    output_dfs.append(output_df)
 
         except Exception as e:
-            log.debug(f"Row {idx}: Failed to process row: {e}")
+            log.debug(f"Row {row.Index}: Failed to process row: {e}")
             continue
 
     # Combine and return results
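Note: `process_data_capture()` now funnels both `endpointInput` and `endpointOutput` payloads through a single `parse_endpoint_data()` helper and iterates with `itertuples()` (namedtuple rows, `row.Index` for the index) instead of `iterrows()`. Below is a standalone sketch of the refactored parsing on simplified capture records; the record layout here is reduced to just the fields the helper touches, not the full SageMaker data-capture schema.

```python
# Standalone sketch of the refactored parsing on simplified capture records.
import json
from io import StringIO
from typing import Optional

import pandas as pd


def parse_endpoint_data(data: dict) -> Optional[pd.DataFrame]:
    """Decode one endpointInput/endpointOutput payload by its declared encoding."""
    encoding = data["encoding"].upper()
    if encoding == "CSV":
        return pd.read_csv(StringIO(data["data"]))
    if encoding == "JSON":
        payload = json.loads(data["data"])
        if isinstance(payload, dict):
            return pd.DataFrame({k: v if isinstance(v, list) else [v] for k, v in payload.items()})
        return pd.DataFrame(payload)
    return None  # unknown encoding


records = pd.DataFrame(
    {
        "captureData": [
            {"endpointInput": {"encoding": "CSV", "data": "x,y\n1,2\n3,4\n"}},
            {"endpointOutput": {"encoding": "JSON", "data": '{"prediction": [0.1, 0.9]}'}},
        ]
    }
)

# itertuples() yields namedtuples: attribute access for columns, row.Index for the index
for row in records.itertuples(index=True):
    capture = row.captureData
    if "endpointInput" in capture:
        print(f"row {row.Index} input:\n{parse_endpoint_data(capture['endpointInput'])}")
    if "endpointOutput" in capture:
        print(f"row {row.Index} output:\n{parse_endpoint_data(capture['endpointOutput'])}")
```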
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: workbench
-Version: 0.8.173
+Version: 0.8.175
 Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
 Author-email: SuperCowPowers LLC <support@supercowpowers.com>
 License-Expression: MIT
@@ -51,7 +51,7 @@ workbench/core/artifacts/__init__.py,sha256=ps7rA_rbWnDbvWbg4kvu--IKMY8WmbPRyv4S
 workbench/core/artifacts/artifact.py,sha256=AtTw8wfMd-fi7cHJHsBAXHUk53kRW_6lyBwwsIbHw54,17750
 workbench/core/artifacts/athena_source.py,sha256=RNmCe7s6uH4gVHpcdJcL84aSbF5Q1ahJBLLGwHYRXEU,26081
 workbench/core/artifacts/cached_artifact_mixin.py,sha256=ngqFLZ4cQx_TFouXZgXZQsv_7W6XCvxVGXXSfzzaft8,3775
-workbench/core/artifacts/data_capture_core.py,sha256=VJL5AcXOx8PxY1Urw0AFm-czqvs55cDiwH_ZTcr2LS0,13207
+workbench/core/artifacts/data_capture_core.py,sha256=q8f79rRTYiZ7T4IQRWXl8ZvPpcvZyNxYERwvo8o0OQc,14858
 workbench/core/artifacts/data_source_abstract.py,sha256=5IRCzFVK-17cd4NXPMRfx99vQAmQ0WHE5jcm5RfsVTg,10619
 workbench/core/artifacts/data_source_factory.py,sha256=YL_tA5fsgubbB3dPF6T4tO0rGgz-6oo3ge4i_YXVC-M,2380
 workbench/core/artifacts/endpoint_core.py,sha256=lwgiz0jttW8C4YqcKaA8nf231WI3kol-nLnKcAbFJko,49049
@@ -140,8 +140,8 @@ workbench/model_scripts/custom_models/uq_models/Readme.md,sha256=UVpL-lvtTrLqwBe
 workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template,sha256=U4LIlpp8Rbu3apyzPR7-55lvlutpTsCro_PUvQ5pklY,6457
 workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template,sha256=0IJnSBACQ556ldEiPqR7yPCOOLJs1hQhHmPBvB2d9tY,13491
 workbench/model_scripts/custom_models/uq_models/gaussian_process.template,sha256=QbDUfkiPCwJ-c-4Twgu4utZuYZaAyeW_3T1IP-_tutw,6683
-workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=yJOL29TWtIAPbhuqK1m9w-MfWq0MVfJcI412VVgDO04,17583
-workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=D94Y3U7IruGQlu9m6gXyLRjm502qZafYrwhEM9GP6oE,18337
+workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=AcLf-vXOmn_vpTeiKpNKCW_dRhR8Co1sMFC84EPT4IE,22392
+workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=VkFM0eZM2d-hzDbngk9s08DD5vn2nQRD4coCUfj36Fk,18181
 workbench/model_scripts/custom_models/uq_models/meta_uq.template,sha256=eawh0Fp3DhbdCXzWN6KloczT5ZS_ou4ayW65yUTTE4o,14109
 workbench/model_scripts/custom_models/uq_models/ngboost.template,sha256=9-O6P-SW50ul5Wl6es2DMWXSbrwOg7HWsdc8Qdln0MM,8278
 workbench/model_scripts/custom_models/uq_models/proximity.py,sha256=zqmNlX70LnWXr5fdtFFQppSNTLjlOciQVrjGr-g9jRE,13716
@@ -221,7 +221,7 @@ workbench/utils/license_manager.py,sha256=sDuhk1mZZqUbFmnuFXehyGnui_ALxrmYBg7gYw
 workbench/utils/log_utils.py,sha256=7n1NJXO_jUX82e6LWAQug6oPo3wiPDBYsqk9gsYab_A,3167
 workbench/utils/markdown_utils.py,sha256=4lEqzgG4EVmLcvvKKNUwNxVCySLQKJTJmWDiaDroI1w,8306
 workbench/utils/model_utils.py,sha256=JeEztmFyDJ7yqRozDX0L6apuhLgKx1sgNlO5duB73qc,11938
-workbench/utils/monitor_utils.py,sha256=LbfZImf4tHqYz9J8NnW_ggZP45Has_4QwXHQ-Wi3sLw,8381
+workbench/utils/monitor_utils.py,sha256=kVaJ7BgUXs3VPMFYfLC03wkIV4Dq-pEhoXS0wkJFxCc,7858
 workbench/utils/pandas_utils.py,sha256=uTUx-d1KYfjbS9PMQp2_9FogCV7xVZR6XLzU5YAGmfs,39371
 workbench/utils/performance_utils.py,sha256=WDNvz-bOdC99cDuXl0urAV4DJ7alk_V3yzKPwvqgST4,1329
 workbench/utils/pipeline_utils.py,sha256=yzR5tgAzz6zNqvxzZR6YqsbS7r3QDKzBXozaM_ADXlc,2171
@@ -288,9 +288,9 @@ workbench/web_interface/page_views/main_page.py,sha256=X4-KyGTKLAdxR-Zk2niuLJB2Y
 workbench/web_interface/page_views/models_page_view.py,sha256=M0bdC7bAzLyIaE2jviY12FF4abdMFZmg6sFuOY_LaGI,2650
 workbench/web_interface/page_views/page_view.py,sha256=Gh6YnpOGlUejx-bHZAf5pzqoQ1H1R0OSwOpGhOBO06w,455
 workbench/web_interface/page_views/pipelines_page_view.py,sha256=v2pxrIbsHBcYiblfius3JK766NZ7ciD2yPx0t3E5IJo,2656
-workbench-0.8.173.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
-workbench-0.8.173.dist-info/METADATA,sha256=b1gas8B3zXhFnVPVFB8vLCeqoeb8brx4rdMXRus-YJo,9210
-workbench-0.8.173.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-workbench-0.8.173.dist-info/entry_points.txt,sha256=zPFPruY9uayk8-wsKrhfnIyIB6jvZOW_ibyllEIsLWo,356
-workbench-0.8.173.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
-workbench-0.8.173.dist-info/RECORD,,
+workbench-0.8.175.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
+workbench-0.8.175.dist-info/METADATA,sha256=hAjhM-oXEqxffYyDwawIsSdTv3iKsRs5_OiZw1sv2RQ,9210
+workbench-0.8.175.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+workbench-0.8.175.dist-info/entry_points.txt,sha256=zPFPruY9uayk8-wsKrhfnIyIB6jvZOW_ibyllEIsLWo,356
+workbench-0.8.175.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
+workbench-0.8.175.dist-info/RECORD,,