workbench 0.8.177__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

Files changed (140)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +5 -5
  27. workbench/core/artifacts/df_store_core.py +114 -0
  28. workbench/core/artifacts/endpoint_core.py +319 -204
  29. workbench/core/artifacts/feature_set_core.py +249 -45
  30. workbench/core/artifacts/model_core.py +135 -82
  31. workbench/core/artifacts/parameter_store_core.py +98 -0
  32. workbench/core/cloud_platform/cloud_meta.py +0 -1
  33. workbench/core/pipelines/pipeline_executor.py +1 -1
  34. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  35. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  36. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  37. workbench/core/views/training_view.py +113 -42
  38. workbench/core/views/view.py +53 -3
  39. workbench/core/views/view_utils.py +4 -4
  40. workbench/model_script_utils/model_script_utils.py +339 -0
  41. workbench/model_script_utils/pytorch_utils.py +405 -0
  42. workbench/model_script_utils/uq_harness.py +277 -0
  43. workbench/model_scripts/chemprop/chemprop.template +774 -0
  44. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  45. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  46. workbench/model_scripts/chemprop/requirements.txt +3 -0
  47. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  48. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  49. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  50. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  51. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  52. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  53. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  54. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  55. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  56. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  57. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  58. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  59. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  60. workbench/model_scripts/meta_model/meta_model.template +209 -0
  61. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  62. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  63. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  64. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  65. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  66. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  67. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  68. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  69. workbench/model_scripts/script_generation.py +15 -12
  70. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  71. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  72. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  73. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  74. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  75. workbench/repl/workbench_shell.py +18 -14
  76. workbench/resources/open_source_api.key +1 -1
  77. workbench/scripts/endpoint_test.py +162 -0
  78. workbench/scripts/lambda_test.py +73 -0
  79. workbench/scripts/meta_model_sim.py +35 -0
  80. workbench/scripts/ml_pipeline_sqs.py +122 -6
  81. workbench/scripts/training_test.py +85 -0
  82. workbench/themes/dark/custom.css +59 -0
  83. workbench/themes/dark/plotly.json +5 -5
  84. workbench/themes/light/custom.css +153 -40
  85. workbench/themes/light/plotly.json +9 -9
  86. workbench/themes/midnight_blue/custom.css +59 -0
  87. workbench/utils/aws_utils.py +0 -1
  88. workbench/utils/chem_utils/fingerprints.py +87 -46
  89. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  90. workbench/utils/chem_utils/projections.py +16 -6
  91. workbench/utils/chem_utils/vis.py +25 -27
  92. workbench/utils/chemprop_utils.py +141 -0
  93. workbench/utils/config_manager.py +2 -6
  94. workbench/utils/endpoint_utils.py +5 -7
  95. workbench/utils/license_manager.py +2 -6
  96. workbench/utils/markdown_utils.py +57 -0
  97. workbench/utils/meta_model_simulator.py +499 -0
  98. workbench/utils/metrics_utils.py +256 -0
  99. workbench/utils/model_utils.py +260 -76
  100. workbench/utils/pipeline_utils.py +0 -1
  101. workbench/utils/plot_utils.py +159 -34
  102. workbench/utils/pytorch_utils.py +87 -0
  103. workbench/utils/shap_utils.py +11 -57
  104. workbench/utils/theme_manager.py +95 -30
  105. workbench/utils/xgboost_local_crossfold.py +267 -0
  106. workbench/utils/xgboost_model_utils.py +127 -220
  107. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  108. workbench/web_interface/components/model_plot.py +16 -2
  109. workbench/web_interface/components/plugin_unit_test.py +5 -3
  110. workbench/web_interface/components/plugins/ag_table.py +2 -4
  111. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  112. workbench/web_interface/components/plugins/model_details.py +48 -80
  113. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  114. workbench/web_interface/components/settings_menu.py +184 -0
  115. workbench/web_interface/page_views/main_page.py +0 -1
  116. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  117. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/RECORD +121 -106
  118. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  119. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  120. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  121. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  122. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  123. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  124. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  125. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -494
  126. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -494
  127. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  128. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  129. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  130. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  131. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  132. workbench/themes/quartz/base_css.url +0 -1
  133. workbench/themes/quartz/custom.css +0 -117
  134. workbench/themes/quartz/plotly.json +0 -642
  135. workbench/themes/quartz_dark/base_css.url +0 -1
  136. workbench/themes/quartz_dark/custom.css +0 -131
  137. workbench/themes/quartz_dark/plotly.json +0 -642
  138. workbench/utils/resource_utils.py +0 -39
  139. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  140. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -1,494 +0,0 @@
- # Model: XGBoost for point predictions + LightGBM with MAPIE for conformalized intervals
- from mapie.regression import ConformalizedQuantileRegressor
- from lightgbm import LGBMRegressor
- from xgboost import XGBRegressor
- from sklearn.model_selection import train_test_split
-
- # Model Performance Scores
- from sklearn.metrics import (
-     mean_absolute_error,
-     r2_score,
-     root_mean_squared_error
- )
-
- from io import StringIO
- import json
- import argparse
- import joblib
- import os
- import numpy as np
- import pandas as pd
- from typing import List, Tuple
-
- # Template Placeholders
- TEMPLATE_PARAMS = {
-     "target": "{{target_column}}",
-     "features": "{{feature_list}}",
-     "compressed_features": "{{compressed_features}}",
-     "train_all_data": "{{train_all_data}}"
- }
-
-
- # Function to check if dataframe is empty
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-     """
-     Check if the provided dataframe is empty and raise an exception if it is.
-
-     Args:
-         df (pd.DataFrame): DataFrame to check
-         df_name (str): Name of the DataFrame
-     """
-     if df.empty:
-         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-         print(msg)
-         raise ValueError(msg)
-
-
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-     """
-     Matches and renames DataFrame columns to match model feature names (case-insensitive).
-     Prioritizes exact matches, then case-insensitive matches.
-
-     Raises ValueError if any model features cannot be matched.
-     """
-     df_columns_lower = {col.lower(): col for col in df.columns}
-     rename_dict = {}
-     missing = []
-     for feature in model_features:
-         if feature in df.columns:
-             continue  # Exact match
-         elif feature.lower() in df_columns_lower:
-             rename_dict[df_columns_lower[feature.lower()]] = feature
-         else:
-             missing.append(feature)
-
-     if missing:
-         raise ValueError(f"Features not found: {missing}")
-
-     # Rename the DataFrame columns to match the model features
-     return df.rename(columns=rename_dict)
-
-
- def convert_categorical_types(df: pd.DataFrame, features: list, category_mappings={}) -> tuple:
-     """
-     Converts appropriate columns to categorical type with consistent mappings.
-
-     Args:
-         df (pd.DataFrame): The DataFrame to process.
-         features (list): List of feature names to consider for conversion.
-         category_mappings (dict, optional): Existing category mappings. If empty dict, we're in
-             training mode. If populated, we're in inference mode.
-
-     Returns:
-         tuple: (processed DataFrame, category mappings dictionary)
-     """
-     # Training mode
-     if category_mappings == {}:
-         for col in df.select_dtypes(include=["object", "string"]):
-             if col in features and df[col].nunique() < 20:
-                 print(f"Training mode: Converting {col} to category")
-                 df[col] = df[col].astype("category")
-                 category_mappings[col] = df[col].cat.categories.tolist()  # Store category mappings
-
-     # Inference mode
-     else:
-         for col, categories in category_mappings.items():
-             if col in df.columns:
-                 print(f"Inference mode: Applying categorical mapping for {col}")
-                 df[col] = pd.Categorical(df[col], categories=categories)  # Apply consistent categorical mapping
-
-     return df, category_mappings
-
-
- def decompress_features(
-     df: pd.DataFrame, features: List[str], compressed_features: List[str]
- ) -> Tuple[pd.DataFrame, List[str]]:
-     """Prepare features for the model by decompressing bitstring features
-
-     Args:
-         df (pd.DataFrame): The features DataFrame
-         features (List[str]): Full list of feature names
-         compressed_features (List[str]): List of feature names to decompress (bitstrings)
-
-     Returns:
-         pd.DataFrame: DataFrame with the decompressed features
-         List[str]: Updated list of feature names after decompression
-
-     Raises:
-         ValueError: If any missing values are found in the specified features
-     """
-
-     # Check for any missing values in the required features
-     missing_counts = df[features].isna().sum()
-     if missing_counts.any():
-         missing_features = missing_counts[missing_counts > 0]
-         print(
-             f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
-             "WARNING: You might want to remove/replace all NaN values before processing."
-         )
-
-     # Decompress the specified compressed features
-     decompressed_features = features.copy()
-     for feature in compressed_features:
-         if (feature not in df.columns) or (feature not in features):
-             print(f"Feature '{feature}' not in the features list, skipping decompression.")
-             continue
-
-         # Remove the feature from the list of features to avoid duplication
-         decompressed_features.remove(feature)
-
-         # Handle all compressed features as bitstrings
-         bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
-         prefix = feature[:3]
-
-         # Create all new columns at once - avoids fragmentation
-         new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
-         new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
-
-         # Add to features list
-         decompressed_features.extend(new_col_names)
-
-         # Drop original column and concatenate new ones
-         df = df.drop(columns=[feature])
-         df = pd.concat([df, new_df], axis=1)
-
-     return df, decompressed_features
-
-
- if __name__ == "__main__":
-     # Template Parameters
-     target = TEMPLATE_PARAMS["target"]
-     features = TEMPLATE_PARAMS["features"]
-     orig_features = features.copy()
-     compressed_features = TEMPLATE_PARAMS["compressed_features"]
-     train_all_data = TEMPLATE_PARAMS["train_all_data"]
-     validation_split = 0.2
-
-     # Script arguments for input/output directories
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-     parser.add_argument(
-         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-     )
-     args = parser.parse_args()
-
-     # Read the training data into DataFrames
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train)
-         if file.endswith(".csv")
-     ]
-     print(f"Training Files: {training_files}")
-
-     # Combine files and read them all into a single pandas dataframe
-     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-     # Check if the dataframe is empty
-     check_dataframe(all_df, "training_df")
-
-     # Features/Target output
-     print(f"Target: {target}")
-     print(f"Features: {str(features)}")
-
-     # Convert any features that might be categorical to 'category' type
-     all_df, category_mappings = convert_categorical_types(all_df, features)
-
-     # If we have compressed features, decompress them
-     if compressed_features:
-         print(f"Decompressing features {compressed_features}...")
-         all_df, features = decompress_features(all_df, features, compressed_features)
-
-     # Do we want to train on all the data?
-     if train_all_data:
-         print("Training on ALL of the data")
-         df_train = all_df.copy()
-         df_val = all_df.copy()
-
-     # Does the dataframe have a training column?
-     elif "training" in all_df.columns:
-         print("Found training column, splitting data based on training column")
-         df_train = all_df[all_df["training"]]
-         df_val = all_df[~all_df["training"]]
-     else:
-         # Just do a random training Split
-         print("WARNING: No training column found, splitting data with random state=42")
-         df_train, df_val = train_test_split(
-             all_df, test_size=validation_split, random_state=42
-         )
-     print(f"FIT/TRAIN: {df_train.shape}")
-     print(f"VALIDATION: {df_val.shape}")
-
-     # Prepare features and targets for training
-     X_train = df_train[features]
-     X_validate = df_val[features]
-     y_train = df_train[target]
-     y_validate = df_val[target]
-
-     # Train XGBoost for point predictions
-     print("\nTraining XGBoost for point predictions...")
-     xgb_model = XGBRegressor(enable_categorical=True)
-     xgb_model.fit(X_train, y_train)
-
-     # Evaluate XGBoost performance
-     y_pred_xgb = xgb_model.predict(X_validate)
-     xgb_rmse = root_mean_squared_error(y_validate, y_pred_xgb)
-     xgb_mae = mean_absolute_error(y_validate, y_pred_xgb)
-     xgb_r2 = r2_score(y_validate, y_pred_xgb)
-
-     print(f"\nXGBoost Point Prediction Performance:")
-     print(f"RMSE: {xgb_rmse:.3f}")
-     print(f"MAE: {xgb_mae:.3f}")
-     print(f"R2: {xgb_r2:.3f}")
-
-     # Define confidence levels we want to model
-     confidence_levels = [0.50, 0.80, 0.90, 0.95]  # 50%, 80%, 90%, 95% confidence intervals
-
-     # Store MAPIE models for each confidence level
-     mapie_models = {}
-
-     # Train models for each confidence level
-     for confidence_level in confidence_levels:
-         alpha = 1 - confidence_level
-         lower_q = alpha / 2
-         upper_q = 1 - alpha / 2
-
-         print(f"\nTraining quantile models for {confidence_level * 100:.0f}% confidence interval...")
-         print(f" Quantiles: {lower_q:.3f}, {upper_q:.3f}, 0.500")
-
-         # Train three models for this confidence level
-         quantile_estimators = []
-         for q in [lower_q, upper_q, 0.5]:
-             print(f" Training model for quantile {q:.3f}...")
-             est = LGBMRegressor(
-                 objective="quantile",
-                 alpha=q,
-                 n_estimators=1000,
-                 max_depth=6,
-                 learning_rate=0.01,
-                 num_leaves=31,
-                 min_child_samples=20,
-                 subsample=0.8,
-                 colsample_bytree=0.8,
-                 random_state=42,
-                 verbose=-1,
-                 force_col_wise=True
-             )
-             est.fit(X_train, y_train)
-             quantile_estimators.append(est)
-
-         # Create MAPIE CQR model for this confidence level
-         print(f" Setting up MAPIE CQR for {confidence_level * 100:.0f}% confidence...")
-         mapie_model = ConformalizedQuantileRegressor(
-             quantile_estimators,
-             confidence_level=confidence_level,
-             prefit=True
-         )
-
-         # Conformalize the model
-         print(f" Conformalizing with validation data...")
-         mapie_model.conformalize(X_validate, y_validate)
-
-         # Store the model
-         mapie_models[f"mapie_{confidence_level:.2f}"] = mapie_model
-
-         # Validate coverage for this confidence level
-         y_pred, y_pis = mapie_model.predict_interval(X_validate)
-         coverage = np.mean((y_validate >= y_pis[:, 0, 0]) & (y_validate <= y_pis[:, 1, 0]))
-         print(f" Coverage: Target={confidence_level * 100:.0f}%, Empirical={coverage * 100:.1f}%")
-
-     print(f"\nOverall Model Performance Summary:")
-     print(f"XGBoost RMSE: {xgb_rmse:.3f}")
-     print(f"XGBoost MAE: {xgb_mae:.3f}")
-     print(f"XGBoost R2: {xgb_r2:.3f}")
-     print(f"NumRows: {len(df_val)}")
-
-     # Analyze interval widths across confidence levels
-     print(f"\nInterval Width Analysis:")
-     for conf_level in confidence_levels:
-         model = mapie_models[f"mapie_{conf_level:.2f}"]
-         _, y_pis = model.predict_interval(X_validate)
-         widths = y_pis[:, 1, 0] - y_pis[:, 0, 0]
-         print(f" {conf_level * 100:.0f}% CI: Mean width={np.mean(widths):.3f}, Std={np.std(widths):.3f}")
-
-     # Save the trained XGBoost model
-     xgb_model.save_model(os.path.join(args.model_dir, "xgb_model.json"))
-
-     # Save all MAPIE models
-     for model_name, model in mapie_models.items():
-         joblib.dump(model, os.path.join(args.model_dir, f"{model_name}.joblib"))
-
-     # Save the feature list
-     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-         json.dump(features, fp)
-
-     # Save category mappings if any
-     if category_mappings:
-         with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as fp:
-             json.dump(category_mappings, fp)
-
-     # Save model configuration
-     model_config = {
-         "model_type": "XGBoost_MAPIE_CQR_LightGBM",
-         "confidence_levels": confidence_levels,
-         "n_features": len(features),
-         "target": target,
-         "validation_metrics": {
-             "xgb_rmse": float(xgb_rmse),
-             "xgb_mae": float(xgb_mae),
-             "xgb_r2": float(xgb_r2),
-             "n_validation": len(df_val)
-         }
-     }
-     with open(os.path.join(args.model_dir, "model_config.json"), "w") as fp:
-         json.dump(model_config, fp, indent=2)
-
-     print(f"\nModel training complete!")
-     print(f"Saved 1 XGBoost model and {len(mapie_models)} MAPIE models to {args.model_dir}")
-
-
- #
- # Inference Section
- #
- def model_fn(model_dir) -> dict:
-     """Load XGBoost and all MAPIE models from the specified directory."""
-
-     # Load model configuration to know which models to load
-     with open(os.path.join(model_dir, "model_config.json")) as fp:
-         config = json.load(fp)
-
-     # Load XGBoost regressor
-     xgb_path = os.path.join(model_dir, "xgb_model.json")
-     xgb_model = XGBRegressor(enable_categorical=True)
-     xgb_model.load_model(xgb_path)
-
-     # Load all MAPIE models
-     mapie_models = {}
-     for conf_level in config["confidence_levels"]:
-         model_name = f"mapie_{conf_level:.2f}"
-         mapie_models[model_name] = joblib.load(os.path.join(model_dir, f"{model_name}.joblib"))
-
-     # Load category mappings if they exist
-     category_mappings = {}
-     category_path = os.path.join(model_dir, "category_mappings.json")
-     if os.path.exists(category_path):
-         with open(category_path) as fp:
-             category_mappings = json.load(fp)
-
-     return {
-         "xgb_model": xgb_model,
-         "mapie_models": mapie_models,
-         "confidence_levels": config["confidence_levels"],
-         "category_mappings": category_mappings
-     }
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data))
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     if "text/csv" in accept_type:
-         # Convert categorical columns to string to avoid fillna issues
-         for col in output_df.select_dtypes(include=['category']).columns:
-             output_df[col] = output_df[col].astype(str)
-         csv_output = output_df.fillna("N/A").to_csv(index=False)
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json"
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- def predict_fn(df, models) -> pd.DataFrame:
-     """Make predictions using XGBoost for point estimates and MAPIE for conformalized intervals
-
-     Args:
-         df (pd.DataFrame): The input DataFrame
-         models (dict): Dictionary containing XGBoost and MAPIE models
-
-     Returns:
-         pd.DataFrame: DataFrame with XGBoost predictions and conformalized intervals
-     """
-
-     # Grab our feature columns (from training)
-     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-     with open(os.path.join(model_dir, "feature_columns.json")) as fp:
-         model_features = json.load(fp)
-
-     # Match features in a case-insensitive manner
-     matched_df = match_features_case_insensitive(df, model_features)
-
-     # Apply categorical mappings if they exist
-     if models.get("category_mappings"):
-         matched_df, _ = convert_categorical_types(
-             matched_df,
-             model_features,
-             models["category_mappings"]
-         )
-
-     # Get features for prediction
-     X = matched_df[model_features]
-
-     # Get XGBoost point predictions
-     df["prediction"] = models["xgb_model"].predict(X)
-
-     # Get predictions from each MAPIE model for conformalized intervals
-     for conf_level in models["confidence_levels"]:
-         model_name = f"mapie_{conf_level:.2f}"
-         model = models["mapie_models"][model_name]
-
-         # Get conformalized predictions
-         y_pred, y_pis = model.predict_interval(X)
-
-         # Map confidence levels to quantile names
-         if conf_level == 0.50:  # 50% CI
-             df["q_25"] = y_pis[:, 0, 0]
-             df["q_75"] = y_pis[:, 1, 0]
-         elif conf_level == 0.80:  # 80% CI
-             df["q_10"] = y_pis[:, 0, 0]
-             df["q_90"] = y_pis[:, 1, 0]
-         elif conf_level == 0.90:  # 90% CI
-             df["q_05"] = y_pis[:, 0, 0]
-             df["q_95"] = y_pis[:, 1, 0]
-         elif conf_level == 0.95:  # 95% CI
-             df["q_025"] = y_pis[:, 0, 0]
-             df["q_975"] = y_pis[:, 1, 0]
-
-     # Add median (q_50) from XGBoost prediction
-     df["q_50"] = df["prediction"]
-
-     # Calculate uncertainty metrics based on 50% interval
-     interval_width = df["q_75"] - df["q_25"]
-     df["prediction_std"] = interval_width / 1.348
-
-     # Reorder the quantile columns for easier reading
-     quantile_cols = ["q_025", "q_05", "q_10", "q_25", "q_75", "q_90", "q_95", "q_975"]
-     other_cols = [col for col in df.columns if col not in quantile_cols]
-     df = df[other_cols + quantile_cols]
-
-     # Uncertainty score
-     df["uncertainty_score"] = interval_width / (np.abs(df["prediction"]) + 1e-6)
-
-     # Confidence bands
-     df["confidence_band"] = pd.cut(
-         df["uncertainty_score"],
-         bins=[0, 0.5, 1.0, 2.0, np.inf],
-         labels=["high", "medium", "low", "very_low"]
-     )
-
-     return df
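
For orientation only, and not part of the diff above: a minimal sketch of how the four SageMaker-style handlers in the removed script chain together to serve a single request. The local model_dir path and the two-column CSV payload are illustrative assumptions, and the handlers are assumed to be importable from the script shown above.

import os

# Illustrative sketch (assumptions: model artifacts exist locally and the handlers are importable).
model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")  # assumed artifact location
models = model_fn(model_dir)                                 # load XGBoost + all MAPIE models

csv_payload = "feat_a,feat_b\n0.12,3.4\n"                    # hypothetical feature row
df_in = input_fn(csv_payload, "text/csv")                    # parse the request body into a DataFrame
df_out = predict_fn(df_in, models)                           # point prediction + conformalized intervals
body, content_type = output_fn(df_out, "text/csv")           # serialize the response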