workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,339 @@
+ """Shared utility functions for model training scripts (templates).
+
+ These functions are used across multiple model templates (XGBoost, PyTorch, ChemProp)
+ to reduce code duplication and ensure consistent behavior.
+ """
+
+ from io import StringIO
+ import json
+ import numpy as np
+ import pandas as pd
+ from sklearn.metrics import (
+     confusion_matrix,
+     mean_absolute_error,
+     median_absolute_error,
+     precision_recall_fscore_support,
+     r2_score,
+     root_mean_squared_error,
+ )
+ from scipy.stats import spearmanr
+
+
+ def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
+     """Check if the provided dataframe is empty and raise an exception if it is.
+
+     Args:
+         df: DataFrame to check
+         df_name: Name of the DataFrame (for error message)
+
+     Raises:
+         ValueError: If the DataFrame is empty
+     """
+     if df.empty:
+         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
+         print(msg)
+         raise ValueError(msg)
+
+
+ def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
+     """Expands a column containing a list of probabilities into separate columns.
+
+     Handles None values for rows where predictions couldn't be made.
+
+     Args:
+         df: DataFrame containing a "pred_proba" column
+         class_labels: List of class labels
+
+     Returns:
+         DataFrame with the "pred_proba" expanded into separate columns (e.g., "class1_proba")
+
+     Raises:
+         ValueError: If DataFrame does not contain a "pred_proba" column
+     """
+     proba_column = "pred_proba"
+     if proba_column not in df.columns:
+         raise ValueError('DataFrame does not contain a "pred_proba" column')
+
+     proba_splits = [f"{label}_proba" for label in class_labels]
+     n_classes = len(class_labels)
+
+     # Handle None values by replacing with list of NaNs
+     proba_values = []
+     for val in df[proba_column]:
+         if val is None:
+             proba_values.append([np.nan] * n_classes)
+         else:
+             proba_values.append(val)
+
+     proba_df = pd.DataFrame(proba_values, columns=proba_splits)
+
+     # Drop any existing proba columns and reset index for concat
+     df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
+     df = df.reset_index(drop=True)
+     df = pd.concat([df, proba_df], axis=1)
+     return df
+
+
+ def match_features_case_insensitive(df: pd.DataFrame, model_features: list[str]) -> pd.DataFrame:
+     """Matches and renames DataFrame columns to match model feature names (case-insensitive).
+
+     Prioritizes exact matches, then case-insensitive matches.
+
+     Args:
+         df: Input DataFrame
+         model_features: List of feature names expected by the model
+
+     Returns:
+         DataFrame with columns renamed to match model features
+
+     Raises:
+         ValueError: If any model features cannot be matched
+     """
+     df_columns_lower = {col.lower(): col for col in df.columns}
+     rename_dict = {}
+     missing = []
+     for feature in model_features:
+         if feature in df.columns:
+             continue  # Exact match
+         elif feature.lower() in df_columns_lower:
+             rename_dict[df_columns_lower[feature.lower()]] = feature
+         else:
+             missing.append(feature)
+
+     if missing:
+         raise ValueError(f"Features not found: {missing}")
+
+     return df.rename(columns=rename_dict)
+
+
+ def convert_categorical_types(
+     df: pd.DataFrame, features: list[str], category_mappings: dict[str, list[str]] | None = None
+ ) -> tuple[pd.DataFrame, dict[str, list[str]]]:
+     """Converts appropriate columns to categorical type with consistent mappings.
+
+     In training mode (category_mappings is None or empty), detects object/string columns
+     with <20 unique values and converts them to categorical.
+     In inference mode (category_mappings provided), applies the stored mappings.
+
+     Args:
+         df: The DataFrame to process
+         features: List of feature names to consider for conversion
+         category_mappings: Existing category mappings. If None or empty, training mode.
+             If populated, inference mode.
+
+     Returns:
+         Tuple of (processed DataFrame, category mappings dictionary)
+     """
+     if category_mappings is None:
+         category_mappings = {}
+
+     # Training mode
+     if not category_mappings:
+         for col in df.select_dtypes(include=["object", "string"]):
+             if col in features and df[col].nunique() < 20:
+                 print(f"Training mode: Converting {col} to category")
+                 df[col] = df[col].astype("category")
+                 category_mappings[col] = df[col].cat.categories.tolist()
+
+     # Inference mode
+     else:
+         for col, categories in category_mappings.items():
+             if col in df.columns:
+                 print(f"Inference mode: Applying categorical mapping for {col}")
+                 df[col] = pd.Categorical(df[col], categories=categories)
+
+     return df, category_mappings
+
+
+ def decompress_features(
+     df: pd.DataFrame, features: list[str], compressed_features: list[str]
+ ) -> tuple[pd.DataFrame, list[str]]:
+     """Decompress compressed features (bitstrings or count vectors) into individual columns.
+
+     Supports two formats (auto-detected):
+     - Bitstrings: "10110010..." → individual uint8 columns (0 or 1)
+     - Count vectors: "0,3,0,1,5,..." → individual uint8 columns (0-255)
+
+     Args:
+         df: The features DataFrame
+         features: Full list of feature names
+         compressed_features: List of feature names to decompress
+
+     Returns:
+         Tuple of (DataFrame with decompressed features, updated feature list)
+     """
+     # Check for any missing values in the required features
+     missing_counts = df[features].isna().sum()
+     if missing_counts.any():
+         missing_features = missing_counts[missing_counts > 0]
+         print(
+             f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
+             "WARNING: You might want to remove/replace all NaN values before processing."
+         )
+
+     # Make a copy to avoid mutating the original list
+     decompressed_features = features.copy()
+
+     for feature in compressed_features:
+         if (feature not in df.columns) or (feature not in decompressed_features):
+             print(f"Feature '{feature}' not in the features list, skipping decompression.")
+             continue
+
+         # Remove the feature from the list to avoid duplication
+         decompressed_features.remove(feature)
+
+         # Auto-detect format and parse: comma-separated counts or bitstring
+         sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
+         parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
+         feature_matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)
+
+         # Create new columns with prefix from feature name
+         prefix = feature[:3]
+         new_col_names = [f"{prefix}_{i}" for i in range(feature_matrix.shape[1])]
+         new_df = pd.DataFrame(feature_matrix, columns=new_col_names, index=df.index)
+
+         # Update features list and dataframe
+         decompressed_features.extend(new_col_names)
+         df = df.drop(columns=[feature])
+         df = pd.concat([df, new_df], axis=1)
+
+     return df, decompressed_features
+
+
+ def input_fn(input_data, content_type: str) -> pd.DataFrame:
+     """Parse input data and return a DataFrame.
+
+     Args:
+         input_data: Raw input data (bytes or string)
+         content_type: MIME type of the input data
+
+     Returns:
+         Parsed DataFrame
+
+     Raises:
+         ValueError: If input is empty or content_type is not supported
+     """
+     if not input_data:
+         raise ValueError("Empty input data is not supported!")
+
+     if isinstance(input_data, bytes):
+         input_data = input_data.decode("utf-8")
+
+     if "text/csv" in content_type:
+         return pd.read_csv(StringIO(input_data))
+     elif "application/json" in content_type:
+         return pd.DataFrame(json.loads(input_data))
+     else:
+         raise ValueError(f"{content_type} not supported!")
+
+
+ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
+     """Convert output DataFrame to requested format.
+
+     Args:
+         output_df: DataFrame to convert
+         accept_type: Requested MIME type
+
+     Returns:
+         Tuple of (formatted output string, MIME type)
+
+     Raises:
+         RuntimeError: If accept_type is not supported
+     """
+     if "text/csv" in accept_type:
+         csv_output = output_df.fillna("N/A").to_csv(index=False)
+         return csv_output, "text/csv"
+     elif "application/json" in accept_type:
+         return output_df.to_json(orient="records"), "application/json"
+     else:
+         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
+
+
+ def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
+     """Compute standard regression metrics.
+
+     Args:
+         y_true: Ground truth target values
+         y_pred: Predicted values
+
+     Returns:
+         Dictionary with keys: rmse, mae, medae, r2, spearmanr, support
+     """
+     return {
+         "rmse": root_mean_squared_error(y_true, y_pred),
+         "mae": mean_absolute_error(y_true, y_pred),
+         "medae": median_absolute_error(y_true, y_pred),
+         "r2": r2_score(y_true, y_pred),
+         "spearmanr": spearmanr(y_true, y_pred).correlation,
+         "support": len(y_true),
+     }
+
+
+ def print_regression_metrics(metrics: dict[str, float]) -> None:
+     """Print regression metrics in the format expected by SageMaker metric definitions.
+
+     Args:
+         metrics: Dictionary of metric name -> value
+     """
+     print(f"rmse: {metrics['rmse']:.3f}")
+     print(f"mae: {metrics['mae']:.3f}")
+     print(f"medae: {metrics['medae']:.3f}")
+     print(f"r2: {metrics['r2']:.3f}")
+     print(f"spearmanr: {metrics['spearmanr']:.3f}")
+     print(f"support: {metrics['support']}")
+
+
+ def compute_classification_metrics(
+     y_true: np.ndarray, y_pred: np.ndarray, label_names: list[str], target_col: str
+ ) -> pd.DataFrame:
+     """Compute per-class classification metrics.
+
+     Args:
+         y_true: Ground truth labels
+         y_pred: Predicted labels
+         label_names: List of class label names
+         target_col: Name of the target column (for DataFrame output)
+
+     Returns:
+         DataFrame with columns: target_col, precision, recall, f1, support
+     """
+     scores = precision_recall_fscore_support(y_true, y_pred, average=None, labels=label_names)
+     return pd.DataFrame(
+         {
+             target_col: label_names,
+             "precision": scores[0],
+             "recall": scores[1],
+             "f1": scores[2],
+             "support": scores[3],
+         }
+     )
+
+
+ def print_classification_metrics(score_df: pd.DataFrame, target_col: str, label_names: list[str]) -> None:
+     """Print per-class classification metrics in the format expected by SageMaker.
+
+     Args:
+         score_df: DataFrame from compute_classification_metrics
+         target_col: Name of the target column
+         label_names: List of class label names
+     """
+     metrics = ["precision", "recall", "f1", "support"]
+     for t in label_names:
+         for m in metrics:
+             value = score_df.loc[score_df[target_col] == t, m].iloc[0]
+             print(f"Metrics:{t}:{m} {value}")
+
+
+ def print_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, label_names: list[str]) -> None:
+     """Print confusion matrix in the format expected by SageMaker.
+
+     Args:
+         y_true: Ground truth labels
+         y_pred: Predicted labels
+         label_names: List of class label names
+     """
+     conf_mtx = confusion_matrix(y_true, y_pred, labels=label_names)
+     for i, row_name in enumerate(label_names):
+         for j, col_name in enumerate(label_names):
+             value = conf_mtx[i, j]
+             print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
@@ -0,0 +1,277 @@
+ """UQ Harness: Uncertainty Quantification using MAPIE Conformalized Quantile Regression.
+
+ This module provides a reusable UQ harness that can wrap any point predictor model
+ (XGBoost, PyTorch, ChemProp, etc.) to provide calibrated prediction intervals.
+
+ Usage:
+     # Training
+     uq_models, uq_metadata = train_uq_models(X_train, y_train, X_val, y_val)
+     save_uq_models(uq_models, uq_metadata, model_dir)
+
+     # Inference
+     uq_models, uq_metadata = load_uq_models(model_dir)
+     df = predict_intervals(df, X, uq_models, uq_metadata)
+     df = compute_confidence(df, uq_metadata["median_interval_width"])
+ """
+
+ import json
+ import os
+ import numpy as np
+ import pandas as pd
+ import joblib
+ from lightgbm import LGBMRegressor
+ from mapie.regression import ConformalizedQuantileRegressor
+
+ # Default confidence levels for prediction intervals
+ DEFAULT_CONFIDENCE_LEVELS = [0.50, 0.68, 0.80, 0.90, 0.95]
+
+
+ def train_uq_models(
+     X_train: pd.DataFrame | np.ndarray,
+     y_train: pd.Series | np.ndarray,
+     X_val: pd.DataFrame | np.ndarray,
+     y_val: pd.Series | np.ndarray,
+     confidence_levels: list[float] | None = None,
+ ) -> tuple[dict, dict]:
+     """Train MAPIE UQ models for multiple confidence levels.
+
+     Args:
+         X_train: Training features
+         y_train: Training targets
+         X_val: Validation features for conformalization
+         y_val: Validation targets for conformalization
+         confidence_levels: List of confidence levels (default: [0.50, 0.68, 0.80, 0.90, 0.95])
+
+     Returns:
+         Tuple of (uq_models dict, uq_metadata dict)
+     """
+     if confidence_levels is None:
+         confidence_levels = DEFAULT_CONFIDENCE_LEVELS
+
+     mapie_models = {}
+
+     for confidence_level in confidence_levels:
+         alpha = 1 - confidence_level
+         lower_q = alpha / 2
+         upper_q = 1 - alpha / 2
+
+         print(f"\nTraining quantile models for {confidence_level * 100:.0f}% confidence interval...")
+         print(f" Quantiles: {lower_q:.3f}, {upper_q:.3f}, 0.500")
+
+         # Train three LightGBM quantile models for this confidence level
+         quantile_estimators = []
+         for q in [lower_q, upper_q, 0.5]:
+             print(f" Training model for quantile {q:.3f}...")
+             est = LGBMRegressor(
+                 objective="quantile",
+                 alpha=q,
+                 n_estimators=1000,
+                 max_depth=6,
+                 learning_rate=0.01,
+                 num_leaves=31,
+                 min_child_samples=20,
+                 subsample=0.8,
+                 colsample_bytree=0.8,
+                 random_state=42,
+                 verbose=-1,
+                 force_col_wise=True,
+             )
+             est.fit(X_train, y_train)
+             quantile_estimators.append(est)
+
+         # Create MAPIE CQR model for this confidence level
+         print(f" Setting up MAPIE CQR for {confidence_level * 100:.0f}% confidence...")
+         mapie_model = ConformalizedQuantileRegressor(
+             quantile_estimators, confidence_level=confidence_level, prefit=True
+         )
+
+         # Conformalize the model with validation data
+         print(" Conformalizing with validation data...")
+         mapie_model.conformalize(X_val, y_val)
+
+         # Store the model
+         model_name = f"mapie_{confidence_level:.2f}"
+         mapie_models[model_name] = mapie_model
+
+         # Validate coverage for this confidence level
+         y_pred, y_pis = mapie_model.predict_interval(X_val)
+         coverage = np.mean((y_val >= y_pis[:, 0, 0]) & (y_val <= y_pis[:, 1, 0]))
+         print(f" Coverage: Target={confidence_level * 100:.0f}%, Empirical={coverage * 100:.1f}%")
+
+     # Compute median interval width for confidence calculation (using 80% CI = q_10 to q_90)
+     print("\nComputing normalization statistics for confidence scores...")
+     model_80 = mapie_models["mapie_0.80"]
+     _, y_pis_80 = model_80.predict_interval(X_val)
+     interval_width = np.abs(y_pis_80[:, 1, 0] - y_pis_80[:, 0, 0])
+     median_interval_width = float(np.median(interval_width))
+     print(f" Median interval width (q_10-q_90): {median_interval_width:.6f}")
+
+     # Analyze interval widths across confidence levels
+     print("\nInterval Width Analysis:")
+     for conf_level in confidence_levels:
+         model = mapie_models[f"mapie_{conf_level:.2f}"]
+         _, y_pis = model.predict_interval(X_val)
+         widths = y_pis[:, 1, 0] - y_pis[:, 0, 0]
+         print(f" {conf_level * 100:.0f}% CI: Mean width={np.mean(widths):.3f}, Std={np.std(widths):.3f}")
+
+     uq_metadata = {
+         "confidence_levels": confidence_levels,
+         "median_interval_width": median_interval_width,
+     }
+
+     return mapie_models, uq_metadata
+
+
+ def save_uq_models(uq_models: dict, uq_metadata: dict, model_dir: str) -> None:
+     """Save UQ models and metadata to disk.
+
+     Args:
+         uq_models: Dictionary of MAPIE models keyed by name (e.g., "mapie_0.80")
+         uq_metadata: Dictionary with confidence_levels and median_interval_width
+         model_dir: Directory to save models
+     """
+     # Save each MAPIE model
+     for model_name, model in uq_models.items():
+         joblib.dump(model, os.path.join(model_dir, f"{model_name}.joblib"))
+
+     # Save median interval width
+     with open(os.path.join(model_dir, "median_interval_width.json"), "w") as fp:
+         json.dump(uq_metadata["median_interval_width"], fp)
+
+     # Save UQ metadata
+     with open(os.path.join(model_dir, "uq_metadata.json"), "w") as fp:
+         json.dump(uq_metadata, fp, indent=2)
+
+     print(f"Saved {len(uq_models)} UQ models to {model_dir}")
+
+
+ def load_uq_models(model_dir: str) -> tuple[dict, dict]:
+     """Load UQ models and metadata from disk.
+
+     Args:
+         model_dir: Directory containing saved models
+
+     Returns:
+         Tuple of (uq_models dict, uq_metadata dict)
+     """
+     # Load UQ metadata
+     uq_metadata_path = os.path.join(model_dir, "uq_metadata.json")
+     if os.path.exists(uq_metadata_path):
+         with open(uq_metadata_path) as fp:
+             uq_metadata = json.load(fp)
+     else:
+         # Fallback for older models that only have median_interval_width.json
+         uq_metadata = {"confidence_levels": DEFAULT_CONFIDENCE_LEVELS}
+         median_width_path = os.path.join(model_dir, "median_interval_width.json")
+         if os.path.exists(median_width_path):
+             with open(median_width_path) as fp:
+                 uq_metadata["median_interval_width"] = json.load(fp)
+
+     # Load all MAPIE models
+     uq_models = {}
+     for conf_level in uq_metadata["confidence_levels"]:
+         model_name = f"mapie_{conf_level:.2f}"
+         model_path = os.path.join(model_dir, f"{model_name}.joblib")
+         if os.path.exists(model_path):
+             uq_models[model_name] = joblib.load(model_path)
+
+     return uq_models, uq_metadata
+
+
+ def predict_intervals(
+     df: pd.DataFrame,
+     X: pd.DataFrame | np.ndarray,
+     uq_models: dict,
+     uq_metadata: dict,
+ ) -> pd.DataFrame:
+     """Add prediction intervals to a DataFrame.
+
+     Args:
+         df: DataFrame to add interval columns to
+         X: Features for prediction (must match training features)
+         uq_models: Dictionary of MAPIE models
+         uq_metadata: Dictionary with confidence_levels
+
+     Returns:
+         DataFrame with added quantile columns (q_025, q_05, ..., q_975)
+     """
+     confidence_levels = uq_metadata["confidence_levels"]
+
+     for conf_level in confidence_levels:
+         model_name = f"mapie_{conf_level:.2f}"
+         model = uq_models[model_name]
+
+         # Get conformalized predictions
+         y_pred, y_pis = model.predict_interval(X)
+
+         # Map confidence levels to quantile column names
+         if conf_level == 0.50:  # 50% CI
+             df["q_25"] = y_pis[:, 0, 0]
+             df["q_75"] = y_pis[:, 1, 0]
+             df["q_50"] = y_pred  # Median prediction
+         elif conf_level == 0.68:  # 68% CI (~1 std)
+             df["q_16"] = y_pis[:, 0, 0]
+             df["q_84"] = y_pis[:, 1, 0]
+         elif conf_level == 0.80:  # 80% CI
+             df["q_10"] = y_pis[:, 0, 0]
+             df["q_90"] = y_pis[:, 1, 0]
+         elif conf_level == 0.90:  # 90% CI
+             df["q_05"] = y_pis[:, 0, 0]
+             df["q_95"] = y_pis[:, 1, 0]
+         elif conf_level == 0.95:  # 95% CI
+             df["q_025"] = y_pis[:, 0, 0]
+             df["q_975"] = y_pis[:, 1, 0]
+
+     # Calculate pseudo-standard deviation from the 68% interval width
+     if "q_84" in df.columns and "q_16" in df.columns:
+         df["prediction_std"] = (df["q_84"] - df["q_16"]).abs() / 2.0
+
+     # Reorder quantile columns for easier reading
+     quantile_cols = ["q_025", "q_05", "q_10", "q_16", "q_25", "q_50", "q_75", "q_84", "q_90", "q_95", "q_975"]
+     existing_q_cols = [c for c in quantile_cols if c in df.columns]
+     other_cols = [c for c in df.columns if c not in quantile_cols]
+     df = df[other_cols + existing_q_cols]
+
+     return df
+
+
+ def compute_confidence(
+     df: pd.DataFrame,
+     median_interval_width: float,
+     lower_q: str = "q_10",
+     upper_q: str = "q_90",
+     alpha: float = 1.0,
+     beta: float = 1.0,
+ ) -> pd.DataFrame:
+     """Compute confidence scores (0.0 to 1.0) based on prediction interval width.
+
+     Uses exponential decay based on:
+     1. Interval width relative to median (alpha weight)
+     2. Distance from median prediction (beta weight)
+
+     Args:
+         df: DataFrame with 'prediction', 'q_50', and quantile columns
+         median_interval_width: Pre-computed median interval width from training data
+         lower_q: Lower quantile column name (default: 'q_10')
+         upper_q: Upper quantile column name (default: 'q_90')
+         alpha: Weight for interval width term (default: 1.0)
+         beta: Weight for distance from median term (default: 1.0)
+
+     Returns:
+         DataFrame with added 'confidence' column
+     """
+     # Interval width
+     interval_width = (df[upper_q] - df[lower_q]).abs()
+
+     # Distance from median, normalized by interval width
+     distance_from_median = (df["prediction"] - df["q_50"]).abs()
+     normalized_distance = distance_from_median / (interval_width + 1e-6)
+
+     # Cap the distance penalty at 1.0
+     normalized_distance = np.minimum(normalized_distance, 1.0)
+
+     # Confidence using exponential decay
+     interval_term = interval_width / median_interval_width
+     df["confidence"] = np.exp(-(alpha * interval_term + beta * normalized_distance))
+
+     return df
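
The uq_harness docstring sketches a train/save/load/predict cycle; the hedged example below expands that into a self-contained round trip on synthetic data. The synthetic target, the /tmp/uq_models directory, and the use of the 50% MAPIE model's median output as the point "prediction" are assumptions for illustration; in the generated model scripts the point prediction would come from the wrapped estimator (XGBoost, PyTorch, ChemProp).

# Hypothetical round trip, for illustration only: data, paths, and the point
# predictor stand-in are assumptions, not taken from the release.
import os

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from uq_harness import (
    compute_confidence,
    load_uq_models,
    predict_intervals,
    save_uq_models,
    train_uq_models,
)

# Synthetic regression data
rng = np.random.default_rng(42)
X = pd.DataFrame({"x1": rng.normal(size=2000), "x2": rng.normal(size=2000)})
y = 2.0 * X["x1"] + np.sin(X["x2"]) + rng.normal(scale=0.3, size=2000)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=42)

# Training: fit and conformalize one MAPIE CQR model per confidence level, then persist
model_dir = "/tmp/uq_models"
os.makedirs(model_dir, exist_ok=True)
uq_models, uq_metadata = train_uq_models(X_train, y_train, X_val, y_val)
save_uq_models(uq_models, uq_metadata, model_dir)

# Inference: reload, attach quantile columns, then a 0-1 confidence score
uq_models, uq_metadata = load_uq_models(model_dir)
df = X_val.reset_index(drop=True).copy()
df["prediction"] = uq_models["mapie_0.50"].predict(X_val)  # stand-in for the wrapped point predictor
df = predict_intervals(df, X_val, uq_models, uq_metadata)
df = compute_confidence(df, uq_metadata["median_interval_width"])
print(df[["prediction", "q_10", "q_50", "q_90", "confidence"]].head())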