workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of workbench might be problematic.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
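The single hunk below appears to be the full contents of the removed workbench/model_scripts/custom_models/uq_models/mapie.template (the only file in this release with 502 deleted lines): an XGBoost point-prediction model paired with LightGBM quantile models conformalized via MAPIE.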
@@ -1,502 +0,0 @@
- # Model: XGBoost for point predictions + LightGBM with MAPIE for conformalized intervals
- from mapie.regression import ConformalizedQuantileRegressor
- from lightgbm import LGBMRegressor
- from xgboost import XGBRegressor
- from sklearn.model_selection import train_test_split
-
- # Model Performance Scores
- from sklearn.metrics import (
-     mean_absolute_error,
-     r2_score,
-     root_mean_squared_error
- )
-
- from io import StringIO
- import json
- import argparse
- import joblib
- import os
- import numpy as np
- import pandas as pd
- from typing import List, Tuple
-
- # Template Placeholders
- TEMPLATE_PARAMS = {
-     "target": "{{target_column}}",
-     "features": "{{feature_list}}",
-     "compressed_features": "{{compressed_features}}",
-     "train_all_data": "{{train_all_data}}"
- }
-
-
- # Function to check if dataframe is empty
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-     """
-     Check if the provided dataframe is empty and raise an exception if it is.
-
-     Args:
-         df (pd.DataFrame): DataFrame to check
-         df_name (str): Name of the DataFrame
-     """
-     if df.empty:
-         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-         print(msg)
-         raise ValueError(msg)
-
-
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-     """
-     Matches and renames DataFrame columns to match model feature names (case-insensitive).
-     Prioritizes exact matches, then case-insensitive matches.
-
-     Raises ValueError if any model features cannot be matched.
-     """
-     df_columns_lower = {col.lower(): col for col in df.columns}
-     rename_dict = {}
-     missing = []
-     for feature in model_features:
-         if feature in df.columns:
-             continue # Exact match
-         elif feature.lower() in df_columns_lower:
-             rename_dict[df_columns_lower[feature.lower()]] = feature
-         else:
-             missing.append(feature)
-
-     if missing:
-         raise ValueError(f"Features not found: {missing}")
-
-     # Rename the DataFrame columns to match the model features
-     return df.rename(columns=rename_dict)
-
-
- def convert_categorical_types(df: pd.DataFrame, features: list, category_mappings={}) -> tuple:
-     """
-     Converts appropriate columns to categorical type with consistent mappings.
-
-     Args:
-         df (pd.DataFrame): The DataFrame to process.
-         features (list): List of feature names to consider for conversion.
-         category_mappings (dict, optional): Existing category mappings. If empty dict, we're in
-             training mode. If populated, we're in inference mode.
-
-     Returns:
-         tuple: (processed DataFrame, category mappings dictionary)
-     """
-     # Training mode
-     if category_mappings == {}:
-         for col in df.select_dtypes(include=["object", "string"]):
-             if col in features and df[col].nunique() < 20:
-                 print(f"Training mode: Converting {col} to category")
-                 df[col] = df[col].astype("category")
-                 category_mappings[col] = df[col].cat.categories.tolist() # Store category mappings
-
-     # Inference mode
-     else:
-         for col, categories in category_mappings.items():
-             if col in df.columns:
-                 print(f"Inference mode: Applying categorical mapping for {col}")
-                 df[col] = pd.Categorical(df[col], categories=categories) # Apply consistent categorical mapping
-
-     return df, category_mappings
-
-
- def decompress_features(
-     df: pd.DataFrame, features: List[str], compressed_features: List[str]
- ) -> Tuple[pd.DataFrame, List[str]]:
-     """Prepare features for the model by decompressing bitstring features
-
-     Args:
-         df (pd.DataFrame): The features DataFrame
-         features (List[str]): Full list of feature names
-         compressed_features (List[str]): List of feature names to decompress (bitstrings)
-
-     Returns:
-         pd.DataFrame: DataFrame with the decompressed features
-         List[str]: Updated list of feature names after decompression
-
-     Raises:
-         ValueError: If any missing values are found in the specified features
-     """
-
-     # Check for any missing values in the required features
-     missing_counts = df[features].isna().sum()
-     if missing_counts.any():
-         missing_features = missing_counts[missing_counts > 0]
-         print(
-             f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
-             "WARNING: You might want to remove/replace all NaN values before processing."
-         )
-
-     # Decompress the specified compressed features
-     decompressed_features = features.copy()
-     for feature in compressed_features:
-         if (feature not in df.columns) or (feature not in features):
-             print(f"Feature '{feature}' not in the features list, skipping decompression.")
-             continue
-
-         # Remove the feature from the list of features to avoid duplication
-         decompressed_features.remove(feature)
-
-         # Handle all compressed features as bitstrings
-         bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
-         prefix = feature[:3]
-
-         # Create all new columns at once - avoids fragmentation
-         new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
-         new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
-
-         # Add to features list
-         decompressed_features.extend(new_col_names)
-
-         # Drop original column and concatenate new ones
-         df = df.drop(columns=[feature])
-         df = pd.concat([df, new_df], axis=1)
-
-     return df, decompressed_features
-
-
- if __name__ == "__main__":
-     # Template Parameters
-     target = TEMPLATE_PARAMS["target"]
-     features = TEMPLATE_PARAMS["features"]
-     orig_features = features.copy()
-     compressed_features = TEMPLATE_PARAMS["compressed_features"]
-     train_all_data = TEMPLATE_PARAMS["train_all_data"]
-     validation_split = 0.2
-
-     # Script arguments for input/output directories
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-     parser.add_argument(
-         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-     )
-     args = parser.parse_args()
-
-     # Read the training data into DataFrames
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train)
-         if file.endswith(".csv")
-     ]
-     print(f"Training Files: {training_files}")
-
-     # Combine files and read them all into a single pandas dataframe
-     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-     # Check if the dataframe is empty
-     check_dataframe(all_df, "training_df")
-
-     # Features/Target output
-     print(f"Target: {target}")
-     print(f"Features: {str(features)}")
-
-     # Convert any features that might be categorical to 'category' type
-     all_df, category_mappings = convert_categorical_types(all_df, features)
-
-     # If we have compressed features, decompress them
-     if compressed_features:
-         print(f"Decompressing features {compressed_features}...")
-         all_df, features = decompress_features(all_df, features, compressed_features)
-
-     # Do we want to train on all the data?
-     if train_all_data:
-         print("Training on ALL of the data")
-         df_train = all_df.copy()
-         df_val = all_df.copy()
-
-     # Does the dataframe have a training column?
-     elif "training" in all_df.columns:
-         print("Found training column, splitting data based on training column")
-         df_train = all_df[all_df["training"]]
-         df_val = all_df[~all_df["training"]]
-     else:
-         # Just do a random training Split
-         print("WARNING: No training column found, splitting data with random state=42")
-         df_train, df_val = train_test_split(
-             all_df, test_size=validation_split, random_state=42
-         )
-     print(f"FIT/TRAIN: {df_train.shape}")
-     print(f"VALIDATION: {df_val.shape}")
-
-     # Prepare features and targets for training
-     X_train = df_train[features]
-     X_validate = df_val[features]
-     y_train = df_train[target]
-     y_validate = df_val[target]
-
-     # Train XGBoost for point predictions
-     print("\nTraining XGBoost for point predictions...")
-     xgb_model = XGBRegressor(
-         n_estimators=1000,
-         max_depth=6,
-         learning_rate=0.01,
-         subsample=0.8,
-         colsample_bytree=0.8,
-         random_state=42,
-         verbosity=0
-     )
-     xgb_model.fit(X_train, y_train)
-
-     # Evaluate XGBoost performance
-     y_pred_xgb = xgb_model.predict(X_validate)
-     xgb_rmse = root_mean_squared_error(y_validate, y_pred_xgb)
-     xgb_mae = mean_absolute_error(y_validate, y_pred_xgb)
-     xgb_r2 = r2_score(y_validate, y_pred_xgb)
-
-     print(f"\nXGBoost Point Prediction Performance:")
-     print(f"RMSE: {xgb_rmse:.3f}")
-     print(f"MAE: {xgb_mae:.3f}")
-     print(f"R2: {xgb_r2:.3f}")
-
-     # Define confidence levels we want to model
-     confidence_levels = [0.50, 0.80, 0.90, 0.95] # 50%, 80%, 90%, 95% confidence intervals
-
-     # Store MAPIE models for each confidence level
-     mapie_models = {}
-
-     # Train models for each confidence level
-     for confidence_level in confidence_levels:
-         alpha = 1 - confidence_level
-         lower_q = alpha / 2
-         upper_q = 1 - alpha / 2
-
-         print(f"\nTraining quantile models for {confidence_level * 100:.0f}% confidence interval...")
-         print(f" Quantiles: {lower_q:.3f}, {upper_q:.3f}, 0.500")
-
-         # Train three models for this confidence level
-         quantile_estimators = []
-         for q in [lower_q, upper_q, 0.5]:
-             print(f" Training model for quantile {q:.3f}...")
-             est = LGBMRegressor(
-                 objective="quantile",
-                 alpha=q,
-                 n_estimators=1000,
-                 max_depth=6,
-                 learning_rate=0.01,
-                 num_leaves=31,
-                 min_child_samples=20,
-                 subsample=0.8,
-                 colsample_bytree=0.8,
-                 random_state=42,
-                 verbose=-1,
-                 force_col_wise=True
-             )
-             est.fit(X_train, y_train)
-             quantile_estimators.append(est)
-
-         # Create MAPIE CQR model for this confidence level
-         print(f" Setting up MAPIE CQR for {confidence_level * 100:.0f}% confidence...")
-         mapie_model = ConformalizedQuantileRegressor(
-             quantile_estimators,
-             confidence_level=confidence_level,
-             prefit=True
-         )
-
-         # Conformalize the model
-         print(f" Conformalizing with validation data...")
-         mapie_model.conformalize(X_validate, y_validate)
-
-         # Store the model
-         mapie_models[f"mapie_{confidence_level:.2f}"] = mapie_model
-
-         # Validate coverage for this confidence level
-         y_pred, y_pis = mapie_model.predict_interval(X_validate)
-         coverage = np.mean((y_validate >= y_pis[:, 0, 0]) & (y_validate <= y_pis[:, 1, 0]))
-         print(f" Coverage: Target={confidence_level * 100:.0f}%, Empirical={coverage * 100:.1f}%")
-
-     print(f"\nOverall Model Performance Summary:")
-     print(f"XGBoost RMSE: {xgb_rmse:.3f}")
-     print(f"XGBoost MAE: {xgb_mae:.3f}")
-     print(f"XGBoost R2: {xgb_r2:.3f}")
-     print(f"NumRows: {len(df_val)}")
-
-     # Analyze interval widths across confidence levels
-     print(f"\nInterval Width Analysis:")
-     for conf_level in confidence_levels:
-         model = mapie_models[f"mapie_{conf_level:.2f}"]
-         _, y_pis = model.predict_interval(X_validate)
-         widths = y_pis[:, 1, 0] - y_pis[:, 0, 0]
-         print(f" {conf_level * 100:.0f}% CI: Mean width={np.mean(widths):.3f}, Std={np.std(widths):.3f}")
-
-     # Save the trained XGBoost model
-     xgb_model.save_model(os.path.join(args.model_dir, "xgb_model.json"))
-
-     # Save all MAPIE models
-     for model_name, model in mapie_models.items():
-         joblib.dump(model, os.path.join(args.model_dir, f"{model_name}.joblib"))
-
-     # Save the feature list
-     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-         json.dump(features, fp)
-
-     # Save category mappings if any
-     if category_mappings:
-         with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as fp:
-             json.dump(category_mappings, fp)
-
-     # Save model configuration
-     model_config = {
-         "model_type": "XGBoost_MAPIE_CQR_LightGBM",
-         "confidence_levels": confidence_levels,
-         "n_features": len(features),
-         "target": target,
-         "validation_metrics": {
-             "xgb_rmse": float(xgb_rmse),
-             "xgb_mae": float(xgb_mae),
-             "xgb_r2": float(xgb_r2),
-             "n_validation": len(df_val)
-         }
-     }
-     with open(os.path.join(args.model_dir, "model_config.json"), "w") as fp:
-         json.dump(model_config, fp, indent=2)
-
-     print(f"\nModel training complete!")
-     print(f"Saved 1 XGBoost model and {len(mapie_models)} MAPIE models to {args.model_dir}")
-
-
- #
- # Inference Section
- #
- def model_fn(model_dir) -> dict:
-     """Load XGBoost and all MAPIE models from the specified directory."""
-
-     # Load model configuration to know which models to load
-     with open(os.path.join(model_dir, "model_config.json")) as fp:
-         config = json.load(fp)
-
-     # Load XGBoost regressor
-     xgb_path = os.path.join(model_dir, "xgb_model.json")
-     xgb_model = XGBRegressor(enable_categorical=True)
-     xgb_model.load_model(xgb_path)
-
-     # Load all MAPIE models
-     mapie_models = {}
-     for conf_level in config["confidence_levels"]:
-         model_name = f"mapie_{conf_level:.2f}"
-         mapie_models[model_name] = joblib.load(os.path.join(model_dir, f"{model_name}.joblib"))
-
-     # Load category mappings if they exist
-     category_mappings = {}
-     category_path = os.path.join(model_dir, "category_mappings.json")
-     if os.path.exists(category_path):
-         with open(category_path) as fp:
-             category_mappings = json.load(fp)
-
-     return {
-         "xgb_model": xgb_model,
-         "mapie_models": mapie_models,
-         "confidence_levels": config["confidence_levels"],
-         "category_mappings": category_mappings
-     }
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data))
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     if "text/csv" in accept_type:
-         # Convert categorical columns to string to avoid fillna issues
-         for col in output_df.select_dtypes(include=['category']).columns:
-             output_df[col] = output_df[col].astype(str)
-         csv_output = output_df.fillna("N/A").to_csv(index=False)
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json"
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- def predict_fn(df, models) -> pd.DataFrame:
-     """Make predictions using XGBoost for point estimates and MAPIE for conformalized intervals
-
-     Args:
-         df (pd.DataFrame): The input DataFrame
-         models (dict): Dictionary containing XGBoost and MAPIE models
-
-     Returns:
-         pd.DataFrame: DataFrame with XGBoost predictions and conformalized intervals
-     """
-
-     # Grab our feature columns (from training)
-     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-     with open(os.path.join(model_dir, "feature_columns.json")) as fp:
-         model_features = json.load(fp)
-
-     # Match features in a case-insensitive manner
-     matched_df = match_features_case_insensitive(df, model_features)
-
-     # Apply categorical mappings if they exist
-     if models.get("category_mappings"):
-         matched_df, _ = convert_categorical_types(
-             matched_df,
-             model_features,
-             models["category_mappings"]
-         )
-
-     # Get features for prediction
-     X = matched_df[model_features]
-
-     # Get XGBoost point predictions
-     df["prediction"] = models["xgb_model"].predict(X)
-
-     # Get predictions from each MAPIE model for conformalized intervals
-     for conf_level in models["confidence_levels"]:
-         model_name = f"mapie_{conf_level:.2f}"
-         model = models["mapie_models"][model_name]
-
-         # Get conformalized predictions
-         y_pred, y_pis = model.predict_interval(X)
-
-         # Map confidence levels to quantile names
-         if conf_level == 0.50: # 50% CI
-             df["q_25"] = y_pis[:, 0, 0]
-             df["q_75"] = y_pis[:, 1, 0]
-         elif conf_level == 0.80: # 80% CI
-             df["q_10"] = y_pis[:, 0, 0]
-             df["q_90"] = y_pis[:, 1, 0]
-         elif conf_level == 0.90: # 90% CI
-             df["q_05"] = y_pis[:, 0, 0]
-             df["q_95"] = y_pis[:, 1, 0]
-         elif conf_level == 0.95: # 95% CI
-             df["q_025"] = y_pis[:, 0, 0]
-             df["q_975"] = y_pis[:, 1, 0]
-
-     # Add median (q_50) from XGBoost prediction
-     df["q_50"] = df["prediction"]
-
-     # Calculate uncertainty metrics based on 95% interval
-     interval_width = df["q_975"] - df["q_025"]
-     df["prediction_std"] = interval_width / 3.92
-
-     # Reorder the quantile columns for easier reading
-     quantile_cols = ["q_025", "q_05", "q_10", "q_25", "q_75", "q_90", "q_95", "q_975"]
-     other_cols = [col for col in df.columns if col not in quantile_cols]
-     df = df[other_cols + quantile_cols]
-
-     # Uncertainty score
-     df["uncertainty_score"] = interval_width / (np.abs(df["prediction"]) + 1e-6)
-
-     # Confidence bands
-     df["confidence_band"] = pd.cut(
-         df["uncertainty_score"],
-         bins=[0, 0.5, 1.0, 2.0, np.inf],
-         labels=["high", "medium", "low", "very_low"]
-     )
-
-     return df
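
Note on the removed template's uncertainty math: the 95% interval width is divided by 3.92 (twice the normal 97.5th-percentile z-value of 1.96) to approximate a standard deviation, which assumes roughly Gaussian prediction errors. The following is a minimal, hypothetical local driver (not part of the package) sketching how the SageMaker-style entry points in the removed script chain together; the module name mapie_template, the feature names, and the model directory are placeholders.

import os

import mapie_template as t  # hypothetical: the rendered template saved as an importable module

model_dir = "/opt/ml/model"             # assumed directory holding the saved training artifacts
os.environ["SM_MODEL_DIR"] = model_dir  # predict_fn reads feature_columns.json from this location

models = t.model_fn(model_dir)                                               # XGBoost + per-confidence MAPIE models
df = t.input_fn('{"feat_1": [0.12], "feat_2": [3.4]}', "application/json")   # placeholder feature payload
preds = t.predict_fn(df, models)                                             # adds prediction, q_* and uncertainty columns
body, content_type = t.output_fn(preds, "text/csv")
print(content_type, body)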