workbench 0.8.205-py3-none-any.whl → 0.8.213-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. workbench/algorithms/models/noise_model.py +388 -0
  2. workbench/api/endpoint.py +3 -6
  3. workbench/api/feature_set.py +1 -1
  4. workbench/api/model.py +5 -11
  5. workbench/cached/cached_model.py +4 -4
  6. workbench/core/artifacts/endpoint_core.py +63 -153
  7. workbench/core/artifacts/model_core.py +21 -19
  8. workbench/core/transforms/features_to_model/features_to_model.py +2 -2
  9. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +1 -1
  10. workbench/model_script_utils/model_script_utils.py +335 -0
  11. workbench/model_script_utils/pytorch_utils.py +395 -0
  12. workbench/model_script_utils/uq_harness.py +278 -0
  13. workbench/model_scripts/chemprop/chemprop.template +289 -666
  14. workbench/model_scripts/chemprop/generated_model_script.py +292 -669
  15. workbench/model_scripts/chemprop/model_script_utils.py +335 -0
  16. workbench/model_scripts/chemprop/requirements.txt +2 -10
  17. workbench/model_scripts/pytorch_model/generated_model_script.py +355 -612
  18. workbench/model_scripts/pytorch_model/model_script_utils.py +335 -0
  19. workbench/model_scripts/pytorch_model/pytorch.template +350 -607
  20. workbench/model_scripts/pytorch_model/pytorch_utils.py +395 -0
  21. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  22. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  23. workbench/model_scripts/script_generation.py +2 -5
  24. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  25. workbench/model_scripts/xgb_model/generated_model_script.py +349 -412
  26. workbench/model_scripts/xgb_model/model_script_utils.py +335 -0
  27. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  28. workbench/model_scripts/xgb_model/xgb_model.template +344 -407
  29. workbench/scripts/training_test.py +85 -0
  30. workbench/utils/chemprop_utils.py +18 -656
  31. workbench/utils/metrics_utils.py +172 -0
  32. workbench/utils/model_utils.py +104 -47
  33. workbench/utils/pytorch_utils.py +32 -472
  34. workbench/utils/xgboost_local_crossfold.py +267 -0
  35. workbench/utils/xgboost_model_utils.py +49 -356
  36. workbench/web_interface/components/plugins/model_details.py +30 -68
  37. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/METADATA +5 -5
  38. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/RECORD +42 -31
  39. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/entry_points.txt +1 -0
  40. workbench/model_scripts/uq_models/mapie.template +0 -605
  41. workbench/model_scripts/uq_models/requirements.txt +0 -1
  42. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/WHEEL +0 -0
  43. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/licenses/LICENSE +0 -0
  44. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/top_level.txt +0 -0
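
The bulk of this release moves the MAPIE/LightGBM uncertainty-quantification code out of the individual model script templates and into shared modules (model_script_utils.py and uq_harness.py) that now ship alongside each generated script. The sketch below shows the training/inference flow the templates follow after the change; the function signatures are inferred from the template usage visible in the diff that follows, not from the uq_harness source itself, and the data variables are placeholders.

# Hedged sketch of the new shared-module flow (signatures inferred from the diff below).
# X_train/y_train, X_validate/y_validate, df/X, and model_dir stand in for the data and
# paths each generated script already sets up.
from uq_harness import train_uq_models, save_uq_models, load_uq_models, predict_intervals, compute_confidence

# Training side: fit the UQ models on the train/validation split, then persist them
# next to the point-prediction model.
uq_models, uq_metadata = train_uq_models(X_train, y_train, X_validate, y_validate)
save_uq_models(uq_models, uq_metadata, model_dir)

# Inference side: reload the UQ models, add interval/quantile columns, then a confidence score.
uq_models, uq_metadata = load_uq_models(model_dir)
df = predict_intervals(df, X, uq_models, uq_metadata)
df = compute_confidence(
    df,
    median_interval_width=uq_metadata["median_interval_width"],
    lower_q="q_10",
    upper_q="q_90",
)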
@@ -1,198 +1,45 @@
- # Model: XGBoost for point predictions + LightGBM with MAPIE for conformalized intervals
- from mapie.regression import ConformalizedQuantileRegressor
- from lightgbm import LGBMRegressor
+ # Model: XGBoost for point predictions + MAPIE UQ Harness for conformalized intervals
  from xgboost import XGBRegressor
  from sklearn.model_selection import train_test_split
 
- # Model Performance Scores
- from sklearn.metrics import mean_absolute_error, median_absolute_error, r2_score, root_mean_squared_error
- from scipy.stats import spearmanr
-
- from io import StringIO
  import json
  import argparse
  import joblib
  import os
  import numpy as np
  import pandas as pd
- from typing import List, Tuple, Optional, Dict
+
+ # Shared model script utilities
+ from model_script_utils import (
+     check_dataframe,
+     match_features_case_insensitive,
+     convert_categorical_types,
+     decompress_features,
+     input_fn,
+     output_fn,
+     compute_regression_metrics,
+     print_regression_metrics,
+ )
+
+ # UQ Harness for uncertainty quantification
+ from uq_harness import (
+     train_uq_models,
+     save_uq_models,
+     load_uq_models,
+     predict_intervals,
+     compute_confidence,
+ )
 
  # Template Placeholders
  TEMPLATE_PARAMS = {
-     "target": "udm_asy_res_efflux_ratio",
- "features": ['smr_vsa4', 'tpsa', 'nhohcount', 'peoe_vsa1', 'mollogp', 'vsa_estate3', 'xc_4dv', 'smr_vsa3', 'tertiary_amine_count', 'peoe_vsa8', 'minpartialcharge', 'nitrogen_span', 'vsa_estate2', 'chi1v', 'hba_hbd_ratio', 'molecular_axis_length', 'molmr', 'vsa_estate4', 'num_s_centers', 'vsa_estate6', 'qed', 'numhdonors', 'mi', 'estate_vsa4', 'axp_7d', 'kappa3', 'asphericity', 'estate_vsa8', 'estate_vsa2', 'estate_vsa3', 'peoe_vsa3', 'xp_6dv', 'bcut2d_logphi', 'vsa_estate8', 'amphiphilic_moment', 'type_ii_pattern_count', 'minestateindex', 'charge_centroid_distance', 'molecular_asymmetry', 'molecular_volume_3d', 'bcut2d_mrlow', 'axp_1d', 'vsa_estate9', 'aromatic_interaction_score', 'xp_7dv', 'bcut2d_mwlow', 'axp_7dv', 'slogp_vsa1', 'maxestateindex', 'fr_al_oh', 'nbase', 'xp_2dv', 'radius_of_gyration', 'sps', 'xch_7d', 'bcut2d_mrhi', 'axp_0dv', 'vsa_estate5', 'hallkieralpha', 'xp_0dv', 'fr_nhpyrrole', 'smr_vsa1', 'smr_vsa6', 'chi2v', 'bcut2d_mwhi', 'estate_vsa6', 'bcut2d_logplow', 'peoe_vsa2', 'fractioncsp3', 'slogp_vsa2', 'c3sp3', 'peoe_vsa7', 'estate_vsa9', 'peoe_vsa9', 'avgipc', 'smr_vsa9', 'xpc_4dv', 'balabanj', 'axp_1dv', 'mv', 'minabsestateindex', 'bcut2d_chglo', 'fpdensitymorgan2', 'axp_4d', 'numsaturatedheterocycles', 'fpdensitymorgan1', 'axp_3dv', 'axp_5d', 'smr_vsa5', 'bcut2d_chghi', 'axp_3d', 'xpc_5dv', 'chi4n', 'peoe_vsa10', 'vsa_estate7', 'peoe_vsa11', 'estate_vsa10', 'xp_7d', 'slogp_vsa5', 'xch_7dv', 'vsa_estate10', 'labuteasa', 'estate_vsa5', 'xp_3d', 'chi1', 'xch_4dv', 'xp_6d', 'estate_vsa1', 'axp_4dv', 'phi', 'xp_3dv', 'xch_6dv', 'smr_vsa10', 'num_r_centers', 'xc_5d', 'maxpartialcharge', 'xc_3d', 'peoe_vsa6', 'fr_imidazole', 'axp_2d', 'slogp_vsa3', 'mz', 'axp_6dv', 'xch_6d', 'mm', 'numatomstereocenters', 'c1sp3', 'chi1n', 'fpdensitymorgan3', 'xp_5dv', 'chi3v', 'slogp_vsa4', 'fr_ether', 'xp_2d', 'chi3n', 'xch_5dv', 'axp_6d', 'xc_5dv', 'numheterocycles', 'mpe', 'fr_hoccn', 'xc_3dv', 'type_i_pattern_count', 'chi0v', 'xch_4d', 'numsaturatedcarbocycles', 'mp', 'xch_5d', 'maxabspartialcharge', 'axp_2dv', 'bertzct', 'sse', 'xpc_6dv', 'sv', 'xpc_4d', 'si', 'chi0n', 'mse', 'xpc_6d', 'peoe_vsa12', 'xpc_5d', 'kappa2', 'axp_5dv', 'kappa1', 'chi2n', 'intramolecular_hbond_potential', 'fr_nh0', 'numaliphaticheterocycles', 'smr_vsa7', 'mare', 'fr_priamide', 'vsa_estate1', 'num_stereocenters', 'fr_nh1', 'estate_vsa7', 'fr_piperzine', 'c1sp2', 'slogp_vsa6', 'xp_5d', 'fr_aryl_methyl', 'molwt', 'chi4v', 'xc_6dv', 'heavyatommolwt', 'xp_4d', 'sp', 'slogp_vsa7', 'numhacceptors', 'c2sp3', 'peoe_vsa4', 'slogp_vsa10', 'fr_morpholine', 'fr_methoxy', 'fr_aniline', 'xp_4dv', 'fr_urea', 'c3sp2', 'fr_pyridine', 'hybratio', 'fr_thiazole', 'minabspartialcharge', 'sm', 'axp_0d', 'numaromaticheterocycles', 'nocount', 'xc_4d', 'peoe_vsa13', 'fr_amide', 'num_defined_stereocenters', 'amide_count', 'xc_6d', 'numrotatablebonds', 'c2sp2', 'fr_piperdine', 'numvalenceelectrons', 'c1sp1', 'fr_nitrile', 'fr_phenol', 'c4sp3', 'spe', 'numheteroatoms', 'estate_vsa11', 'sz', 'chi0', 'smr_vsa2', 'fr_ketone_topliss', 'slogp_vsa11', 'fr_benzene', 'fr_ndealkylation2', 'peoe_vsa5', 'fr_c_o', 'numsaturatedrings', 'exactmolwt', 'sare', 'numaliphaticrings', 'fr_al_oh_notert', 'fr_imine', 'frac_defined_stereo', 'numunspecifiedatomstereocenters', 'fr_ar_n', 'fr_bicyclic', 'fr_c_o_nocoo', 'numspiroatoms', 'fr_sulfone', 'fr_ndealkylation1'],
+     "target": "solubility",
+     "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
      "compressed_features": [],
-     "train_all_data": True,
-     "hyperparameters": {'n_estimators': 500, 'max_depth': 6, 'learning_rate': 0.04},
+     "train_all_data": False,
+     "hyperparameters": {'training_config': {'max_epochs': 150}, 'model_config': {'layers': '128-64-32'}},
  }
 
 
- def compute_confidence(
-     df: pd.DataFrame,
-     median_interval_width: float,
-     lower_q: str = "q_10",
-     upper_q: str = "q_90",
-     alpha: float = 1.0,
-     beta: float = 1.0,
- ) -> pd.DataFrame:
-     """
-     Compute confidence scores (0.0 to 1.0) based on prediction interval width
-     and distance from median using exponential decay.
-
-     Args:
-         df: DataFrame with 'prediction', 'q_50', and quantile columns
-         median_interval_width: Pre-computed median interval width from training data
-         lower_q: Lower quantile column name (default: 'q_10')
-         upper_q: Upper quantile column name (default: 'q_90')
-         alpha: Weight for interval width term (default: 1.0)
-         beta: Weight for distance from median term (default: 1.0)
-
-     Returns:
-         DataFrame with added 'confidence' column
-     """
-     # Interval width
-     interval_width = (df[upper_q] - df[lower_q]).abs()
-
-     # Distance from median, normalized by interval width
-     distance_from_median = (df['prediction'] - df['q_50']).abs()
-     normalized_distance = distance_from_median / (interval_width + 1e-6)
-
-     # Cap the distance penalty at 1.0
-     normalized_distance = np.minimum(normalized_distance, 1.0)
-
-     # Confidence using exponential decay
-     interval_term = interval_width / median_interval_width
-     df['confidence'] = np.exp(-(alpha * interval_term + beta * normalized_distance))
-
-     return df
-
-
- # Function to check if dataframe is empty
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-     """
-     Check if the provided dataframe is empty and raise an exception if it is.
-
-     Args:
-         df (pd.DataFrame): DataFrame to check
-         df_name (str): Name of the DataFrame
-     """
-     if df.empty:
-         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-         print(msg)
-         raise ValueError(msg)
-
-
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-     """
-     Matches and renames DataFrame columns to match model feature names (case-insensitive).
-     Prioritizes exact matches, then case-insensitive matches.
-
-     Raises ValueError if any model features cannot be matched.
-     """
-     df_columns_lower = {col.lower(): col for col in df.columns}
-     rename_dict = {}
-     missing = []
-     for feature in model_features:
-         if feature in df.columns:
-             continue # Exact match
-         elif feature.lower() in df_columns_lower:
-             rename_dict[df_columns_lower[feature.lower()]] = feature
-         else:
-             missing.append(feature)
-
-     if missing:
-         raise ValueError(f"Features not found: {missing}")
-
-     # Rename the DataFrame columns to match the model features
-     return df.rename(columns=rename_dict)
-
-
- def convert_categorical_types(df: pd.DataFrame, features: list, category_mappings={}) -> tuple:
-     """
-     Converts appropriate columns to categorical type with consistent mappings.
-
-     Args:
-         df (pd.DataFrame): The DataFrame to process.
-         features (list): List of feature names to consider for conversion.
-         category_mappings (dict, optional): Existing category mappings. If empty dict, we're in
-             training mode. If populated, we're in inference mode.
-
-     Returns:
-         tuple: (processed DataFrame, category mappings dictionary)
-     """
-     # Training mode
-     if category_mappings == {}:
-         for col in df.select_dtypes(include=["object", "string"]):
-             if col in features and df[col].nunique() < 20:
-                 print(f"Training mode: Converting {col} to category")
-                 df[col] = df[col].astype("category")
-                 category_mappings[col] = df[col].cat.categories.tolist() # Store category mappings
-
-     # Inference mode
-     else:
-         for col, categories in category_mappings.items():
-             if col in df.columns:
-                 print(f"Inference mode: Applying categorical mapping for {col}")
-                 df[col] = pd.Categorical(df[col], categories=categories) # Apply consistent categorical mapping
-
-     return df, category_mappings
-
-
- def decompress_features(
-     df: pd.DataFrame, features: List[str], compressed_features: List[str]
- ) -> Tuple[pd.DataFrame, List[str]]:
-     """Prepare features for the model by decompressing bitstring features
-
-     Args:
-         df (pd.DataFrame): The features DataFrame
-         features (List[str]): Full list of feature names
-         compressed_features (List[str]): List of feature names to decompress (bitstrings)
-
-     Returns:
-         pd.DataFrame: DataFrame with the decompressed features
-         List[str]: Updated list of feature names after decompression
-
-     Raises:
-         ValueError: If any missing values are found in the specified features
-     """
-
-     # Check for any missing values in the required features
-     missing_counts = df[features].isna().sum()
-     if missing_counts.any():
-         missing_features = missing_counts[missing_counts > 0]
-         print(
-             f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
-             "WARNING: You might want to remove/replace all NaN values before processing."
-         )
-
-     # Decompress the specified compressed features
-     decompressed_features = features.copy()
-     for feature in compressed_features:
-         if (feature not in df.columns) or (feature not in features):
-             print(f"Feature '{feature}' not in the features list, skipping decompression.")
-             continue
-
-         # Remove the feature from the list of features to avoid duplication
-         decompressed_features.remove(feature)
-
-         # Handle all compressed features as bitstrings
-         bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
-         prefix = feature[:3]
-
-         # Create all new columns at once - avoids fragmentation
-         new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
-         new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
-
-         # Add to features list
-         decompressed_features.extend(new_col_names)
-
-         # Drop original column and concatenate new ones
-         df = df.drop(columns=[feature])
-         df = pd.concat([df, new_df], axis=1)
-
-     return df, decompressed_features
-
-
  if __name__ == "__main__":
      # Template Parameters
      target = TEMPLATE_PARAMS["target"]
@@ -200,7 +47,7 @@ if __name__ == "__main__":
      orig_features = features.copy()
      compressed_features = TEMPLATE_PARAMS["compressed_features"]
      train_all_data = TEMPLATE_PARAMS["train_all_data"]
-     hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
+     hyperparameters = TEMPLATE_PARAMS["hyperparameters"] or {}
      validation_split = 0.2
 
      # Script arguments for input/output directories
@@ -253,8 +100,8 @@ if __name__ == "__main__":
      print(f"VALIDATION: {df_val.shape}")
 
      # Extract sample weights if present
-     if 'sample_weight' in df_train.columns:
-         sample_weights = df_train['sample_weight']
+     if "sample_weight" in df_train.columns:
+         sample_weights = df_train["sample_weight"]
          print(f"Using sample weights: min={sample_weights.min():.2f}, max={sample_weights.max():.2f}, mean={sample_weights.mean():.2f}")
      else:
          sample_weights = None
@@ -266,7 +113,9 @@ if __name__ == "__main__":
      y_train = df_train[target]
      y_validate = df_val[target]
 
+     # ==========================================
      # Train XGBoost for point predictions
+     # ==========================================
      print("\nTraining XGBoost for point predictions...")
      print(f" Hyperparameters: {hyperparameters}")
      xgb_model = XGBRegressor(enable_categorical=True, **hyperparameters)
@@ -274,136 +123,27 @@ if __name__ == "__main__":
 
      # Evaluate XGBoost performance
      y_pred_xgb = xgb_model.predict(X_validate)
-     xgb_rmse = root_mean_squared_error(y_validate, y_pred_xgb)
-     xgb_mae = mean_absolute_error(y_validate, y_pred_xgb)
-     xgb_r2 = r2_score(y_validate, y_pred_xgb)
-
-     xgb_medae = median_absolute_error(y_validate, y_pred_xgb)
-     xgb_spearman = spearmanr(y_validate, y_pred_xgb).correlation
+     xgb_metrics = compute_regression_metrics(y_validate, y_pred_xgb)
 
      print(f"\nXGBoost Point Prediction Performance:")
-     print(f"rmse: {xgb_rmse:.3f}")
-     print(f"mae: {xgb_mae:.3f}")
-     print(f"medae: {xgb_medae:.3f}")
-     print(f"r2: {xgb_r2:.3f}")
-     print(f"spearmanr: {xgb_spearman:.3f}")
-
-     # Define confidence levels we want to model
-     confidence_levels = [0.50, 0.68, 0.80, 0.90, 0.95] # 50%, 68%, 80%, 90%, 95% confidence intervals
-
-     # Store MAPIE models for each confidence level
-     mapie_models = {}
-
-     # Train models for each confidence level
-     for confidence_level in confidence_levels:
-         alpha = 1 - confidence_level
-         lower_q = alpha / 2
-         upper_q = 1 - alpha / 2
-
-         print(f"\nTraining quantile models for {confidence_level * 100:.0f}% confidence interval...")
-         print(f" Quantiles: {lower_q:.3f}, {upper_q:.3f}, 0.500")
-
-         # Train three models for this confidence level
-         quantile_estimators = []
-         for q in [lower_q, upper_q, 0.5]:
-             print(f" Training model for quantile {q:.3f}...")
-             est = LGBMRegressor(
-                 objective="quantile",
-                 alpha=q,
-                 n_estimators=1000,
-                 max_depth=6,
-                 learning_rate=0.01,
-                 num_leaves=31,
-                 min_child_samples=20,
-                 subsample=0.8,
-                 colsample_bytree=0.8,
-                 random_state=42,
-                 verbose=-1,
-                 force_col_wise=True,
-             )
-             est.fit(X_train, y_train)
-             quantile_estimators.append(est)
-
-         # Create MAPIE CQR model for this confidence level
-         print(f" Setting up MAPIE CQR for {confidence_level * 100:.0f}% confidence...")
-         mapie_model = ConformalizedQuantileRegressor(
-             quantile_estimators, confidence_level=confidence_level, prefit=True
-         )
-
-         # Conformalize the model
-         print(f" Conformalizing with validation data...")
-         mapie_model.conformalize(X_validate, y_validate)
-
-         # Store the model
-         mapie_models[f"mapie_{confidence_level:.2f}"] = mapie_model
-
-         # Validate coverage for this confidence level
-         y_pred, y_pis = mapie_model.predict_interval(X_validate)
-         coverage = np.mean((y_validate >= y_pis[:, 0, 0]) & (y_validate <= y_pis[:, 1, 0]))
-         print(f" Coverage: Target={confidence_level * 100:.0f}%, Empirical={coverage * 100:.1f}%")
-
-     support = len(df_val)
+     print_regression_metrics(xgb_metrics)
+
+     # ==========================================
+     # Train UQ models using the harness
+     # ==========================================
+     uq_models, uq_metadata = train_uq_models(X_train, y_train, X_validate, y_validate)
+
      print(f"\nOverall Model Performance Summary:")
-     print(f"rmse: {xgb_rmse:.3f}")
-     print(f"mae: {xgb_mae:.3f}")
-     print(f"medae: {xgb_medae:.3f}")
-     print(f"r2: {xgb_r2:.3f}")
-     print(f"spearmanr: {xgb_spearman:.3f}")
-     print(f"support: {support}")
-
-     # Analyze interval widths across confidence levels
-     print(f"\nInterval Width Analysis:")
-     for conf_level in confidence_levels:
-         model = mapie_models[f"mapie_{conf_level:.2f}"]
-         _, y_pis = model.predict_interval(X_validate)
-         widths = y_pis[:, 1, 0] - y_pis[:, 0, 0]
-         print(f" {conf_level * 100:.0f}% CI: Mean width={np.mean(widths):.3f}, Std={np.std(widths):.3f}")
-
-     # Compute normalization statistics for confidence calculation
-     print(f"\nComputing normalization statistics for confidence scores...")
-
-     # Add predictions directly to validation dataframe
-     df_val["prediction"] = xgb_model.predict(X_validate)
-
-     # Add all quantile predictions
-     for conf_level in confidence_levels:
-         model_name = f"mapie_{conf_level:.2f}"
-         model = mapie_models[model_name]
-         y_pred, y_pis = model.predict_interval(X_validate)
-
-         if conf_level == 0.50:
-             df_val["q_25"] = y_pis[:, 0, 0]
-             df_val["q_75"] = y_pis[:, 1, 0]
-             # y_pred is the median prediction
-             df_val["q_50"] = y_pred
-         elif conf_level == 0.68:
-             df_val["q_16"] = y_pis[:, 0, 0]
-             df_val["q_84"] = y_pis[:, 1, 0]
-         elif conf_level == 0.80:
-             df_val["q_10"] = y_pis[:, 0, 0]
-             df_val["q_90"] = y_pis[:, 1, 0]
-         elif conf_level == 0.90:
-             df_val["q_05"] = y_pis[:, 0, 0]
-             df_val["q_95"] = y_pis[:, 1, 0]
-         elif conf_level == 0.95:
-             df_val["q_025"] = y_pis[:, 0, 0]
-             df_val["q_975"] = y_pis[:, 1, 0]
-
-     # Compute normalization stats using q_10 and q_90 (default range)
-     interval_width = (df_val["q_90"] - df_val["q_10"]).abs()
-     median_interval_width = float(interval_width.median())
-     print(f" Median interval width (q_10-q_90): {median_interval_width:.6f}")
-
-     # Save median interval width for confidence calculation
-     with open(os.path.join(args.model_dir, "median_interval_width.json"), "w") as fp:
-         json.dump(median_interval_width, fp)
+     print_regression_metrics(xgb_metrics)
 
+     # ==========================================
+     # Save all models
+     # ==========================================
      # Save the trained XGBoost model
      joblib.dump(xgb_model, os.path.join(args.model_dir, "xgb_model.joblib"))
 
-     # Save all MAPIE models
-     for model_name, model in mapie_models.items():
-         joblib.dump(model, os.path.join(args.model_dir, f"{model_name}.joblib"))
+     # Save UQ models using the harness
+     save_uq_models(uq_models, uq_metadata, args.model_dir)
 
      # Save the feature list
      with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
@@ -416,14 +156,14 @@ if __name__ == "__main__":
 
      # Save model configuration
      model_config = {
-         "model_type": "XGBoost_MAPIE_CQR_LightGBM",
-         "confidence_levels": confidence_levels,
+         "model_type": "XGBoost_MAPIE_UQ",
+         "confidence_levels": uq_metadata["confidence_levels"],
          "n_features": len(features),
          "target": target,
          "validation_metrics": {
-             "xgb_rmse": float(xgb_rmse),
-             "xgb_mae": float(xgb_mae),
-             "xgb_r2": float(xgb_r2),
+             "xgb_rmse": float(xgb_metrics["rmse"]),
+             "xgb_mae": float(xgb_metrics["mae"]),
+             "xgb_r2": float(xgb_metrics["r2"]),
              "n_validation": len(df_val),
          },
      }
@@ -431,16 +171,16 @@ if __name__ == "__main__":
          json.dump(model_config, fp, indent=2)
 
      print(f"\nModel training complete!")
-     print(f"Saved 1 XGBoost model and {len(mapie_models)} MAPIE models to {args.model_dir}")
+     print(f"Saved XGBoost model and {len(uq_models)} UQ models to {args.model_dir}")
 
 
  #
  # Inference Section
  #
  def model_fn(model_dir) -> dict:
-     """Load XGBoost and all MAPIE models from the specified directory."""
+     """Load XGBoost and all UQ models from the specified directory."""
 
-     # Load model configuration to know which models to load
+     # Load model configuration
      with open(os.path.join(model_dir, "model_config.json")) as fp:
          config = json.load(fp)
 
@@ -448,11 +188,8 @@ def model_fn(model_dir) -> dict:
      xgb_path = os.path.join(model_dir, "xgb_model.joblib")
      xgb_model = joblib.load(xgb_path)
 
-     # Load all MAPIE models
-     mapie_models = {}
-     for conf_level in config["confidence_levels"]:
-         model_name = f"mapie_{conf_level:.2f}"
-         mapie_models[model_name] = joblib.load(os.path.join(model_dir, f"{model_name}.joblib"))
+     # Load UQ models using the harness
+     uq_models, uq_metadata = load_uq_models(model_dir)
 
      # Load category mappings if they exist
      category_mappings = {}
@@ -461,68 +198,24 @@ def model_fn(model_dir) -> dict:
          with open(category_path) as fp:
              category_mappings = json.load(fp)
 
-     # Load median interval width for confidence calculation
-     median_interval_width = None
-     median_width_path = os.path.join(model_dir, "median_interval_width.json")
-     if os.path.exists(median_width_path):
-         with open(median_width_path) as fp:
-             median_interval_width = json.load(fp)
-
      return {
          "xgb_model": xgb_model,
-         "mapie_models": mapie_models,
-         "confidence_levels": config["confidence_levels"],
+         "uq_models": uq_models,
+         "uq_metadata": uq_metadata,
          "category_mappings": category_mappings,
-         "median_interval_width": median_interval_width,
      }
 
 
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data))
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     if "text/csv" in accept_type:
-         # Convert categorical columns to string to avoid fillna issues
-         for col in output_df.select_dtypes(include=["category"]).columns:
-             output_df[col] = output_df[col].astype(str)
-         csv_output = output_df.fillna("N/A").to_csv(index=False)
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json"
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
  def predict_fn(df, models) -> pd.DataFrame:
-     """Make predictions using XGBoost for point estimates and MAPIE for conformalized intervals
+     """Make predictions using XGBoost for point estimates and UQ harness for intervals.
 
      Args:
          df (pd.DataFrame): The input DataFrame
-         models (dict): Dictionary containing XGBoost and MAPIE models
+         models (dict): Dictionary containing XGBoost and UQ models
 
      Returns:
-         pd.DataFrame: DataFrame with XGBoost predictions and conformalized intervals
+         pd.DataFrame: DataFrame with predictions and conformalized intervals
      """
-
-     # Flag for outlier stretch adjustment for the prediction intervals
-     # if the predicted values are outside the intervals
-     outlier_stretch = False
-
      # Grab our feature columns (from training)
      model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
      with open(os.path.join(model_dir, "feature_columns.json")) as fp:
@@ -541,65 +234,15 @@ def predict_fn(df, models) -> pd.DataFrame:
      # Get XGBoost point predictions
      df["prediction"] = models["xgb_model"].predict(X)
 
-     # Get predictions from each MAPIE model for conformalized intervals
-     for conf_level in models["confidence_levels"]:
-         model_name = f"mapie_{conf_level:.2f}"
-         model = models["mapie_models"][model_name]
-
-         # Get conformalized predictions
-         y_pred, y_pis = model.predict_interval(X)
-
-         # Map confidence levels to quantile names
-         if conf_level == 0.50: # 50% CI
-             df["q_25"] = y_pis[:, 0, 0]
-             df["q_75"] = y_pis[:, 1, 0]
-             # y_pred is the median prediction
-             df["q_50"] = y_pred
-         elif conf_level == 0.68: # 68% CI
-             df["q_16"] = y_pis[:, 0, 0]
-             df["q_84"] = y_pis[:, 1, 0]
-         elif conf_level == 0.80: # 80% CI
-             df["q_10"] = y_pis[:, 0, 0]
-             df["q_90"] = y_pis[:, 1, 0]
-         elif conf_level == 0.90: # 90% CI
-             df["q_05"] = y_pis[:, 0, 0]
-             df["q_95"] = y_pis[:, 1, 0]
-         elif conf_level == 0.95: # 95% CI
-             df["q_025"] = y_pis[:, 0, 0]
-             df["q_975"] = y_pis[:, 1, 0]
-
-     # Calculate a pseudo-standard deviation from the 68% interval width
-     df["prediction_std"] = (df["q_84"] - df["q_16"]).abs() / 2.0
-
-     # Reorder the quantile columns for easier reading
-     quantile_cols = ["q_025", "q_05", "q_10", "q_16", "q_25", "q_50", "q_75", "q_84", "q_90", "q_95", "q_975"]
-     other_cols = [col for col in df.columns if col not in quantile_cols]
-     df = df[other_cols + quantile_cols]
-
-     # Adjust the outer quantiles to ensure they encompass the prediction
-     if outlier_stretch:
-         # Lower intervals adjustments
-         df["q_025"] = np.minimum(df["q_025"], df["prediction"])
-         df["q_05"] = np.minimum(df["q_05"], df["prediction"])
-         df["q_10"] = np.minimum(df["q_10"], df["prediction"])
-         df["q_16"] = np.minimum(df["q_16"], df["prediction"])
-         df["q_25"] = np.minimum(df["q_25"], df["prediction"])
-
-         # Upper intervals adjustments
-         df["q_75"] = np.maximum(df["q_75"], df["prediction"])
-         df["q_84"] = np.maximum(df["q_84"], df["prediction"])
-         df["q_90"] = np.maximum(df["q_90"], df["prediction"])
-         df["q_95"] = np.maximum(df["q_95"], df["prediction"])
-         df["q_975"] = np.maximum(df["q_975"], df["prediction"])
-
-     # Compute confidence scores using pre-computed normalization stats
+     # Get prediction intervals using UQ harness
+     df = predict_intervals(df, X, models["uq_models"], models["uq_metadata"])
+
+     # Compute confidence scores
      df = compute_confidence(
          df,
+         median_interval_width=models["uq_metadata"]["median_interval_width"],
          lower_q="q_10",
          upper_q="q_90",
-         alpha=1.0,
-         beta=1.0,
-         median_interval_width=models["median_interval_width"],
      )
 
      return df