workbench-0.8.177-py3-none-any.whl → workbench-0.8.227-py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of workbench might be problematic; see the registry listing for details.

Files changed (140)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +5 -5
  27. workbench/core/artifacts/df_store_core.py +114 -0
  28. workbench/core/artifacts/endpoint_core.py +319 -204
  29. workbench/core/artifacts/feature_set_core.py +249 -45
  30. workbench/core/artifacts/model_core.py +135 -82
  31. workbench/core/artifacts/parameter_store_core.py +98 -0
  32. workbench/core/cloud_platform/cloud_meta.py +0 -1
  33. workbench/core/pipelines/pipeline_executor.py +1 -1
  34. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  35. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  36. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  37. workbench/core/views/training_view.py +113 -42
  38. workbench/core/views/view.py +53 -3
  39. workbench/core/views/view_utils.py +4 -4
  40. workbench/model_script_utils/model_script_utils.py +339 -0
  41. workbench/model_script_utils/pytorch_utils.py +405 -0
  42. workbench/model_script_utils/uq_harness.py +277 -0
  43. workbench/model_scripts/chemprop/chemprop.template +774 -0
  44. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  45. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  46. workbench/model_scripts/chemprop/requirements.txt +3 -0
  47. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  48. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  49. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  50. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  51. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  52. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  53. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  54. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  55. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  56. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  57. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  58. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  59. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  60. workbench/model_scripts/meta_model/meta_model.template +209 -0
  61. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  62. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  63. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  64. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  65. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  66. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  67. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  68. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  69. workbench/model_scripts/script_generation.py +15 -12
  70. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  71. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  72. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  73. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  74. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  75. workbench/repl/workbench_shell.py +18 -14
  76. workbench/resources/open_source_api.key +1 -1
  77. workbench/scripts/endpoint_test.py +162 -0
  78. workbench/scripts/lambda_test.py +73 -0
  79. workbench/scripts/meta_model_sim.py +35 -0
  80. workbench/scripts/ml_pipeline_sqs.py +122 -6
  81. workbench/scripts/training_test.py +85 -0
  82. workbench/themes/dark/custom.css +59 -0
  83. workbench/themes/dark/plotly.json +5 -5
  84. workbench/themes/light/custom.css +153 -40
  85. workbench/themes/light/plotly.json +9 -9
  86. workbench/themes/midnight_blue/custom.css +59 -0
  87. workbench/utils/aws_utils.py +0 -1
  88. workbench/utils/chem_utils/fingerprints.py +87 -46
  89. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  90. workbench/utils/chem_utils/projections.py +16 -6
  91. workbench/utils/chem_utils/vis.py +25 -27
  92. workbench/utils/chemprop_utils.py +141 -0
  93. workbench/utils/config_manager.py +2 -6
  94. workbench/utils/endpoint_utils.py +5 -7
  95. workbench/utils/license_manager.py +2 -6
  96. workbench/utils/markdown_utils.py +57 -0
  97. workbench/utils/meta_model_simulator.py +499 -0
  98. workbench/utils/metrics_utils.py +256 -0
  99. workbench/utils/model_utils.py +260 -76
  100. workbench/utils/pipeline_utils.py +0 -1
  101. workbench/utils/plot_utils.py +159 -34
  102. workbench/utils/pytorch_utils.py +87 -0
  103. workbench/utils/shap_utils.py +11 -57
  104. workbench/utils/theme_manager.py +95 -30
  105. workbench/utils/xgboost_local_crossfold.py +267 -0
  106. workbench/utils/xgboost_model_utils.py +127 -220
  107. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  108. workbench/web_interface/components/model_plot.py +16 -2
  109. workbench/web_interface/components/plugin_unit_test.py +5 -3
  110. workbench/web_interface/components/plugins/ag_table.py +2 -4
  111. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  112. workbench/web_interface/components/plugins/model_details.py +48 -80
  113. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  114. workbench/web_interface/components/settings_menu.py +184 -0
  115. workbench/web_interface/page_views/main_page.py +0 -1
  116. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  117. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/RECORD +121 -106
  118. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  119. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  120. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  121. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  122. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  123. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  124. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  125. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -494
  126. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -494
  127. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  128. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  129. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  130. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  131. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  132. workbench/themes/quartz/base_css.url +0 -1
  133. workbench/themes/quartz/custom.css +0 -117
  134. workbench/themes/quartz/plotly.json +0 -642
  135. workbench/themes/quartz_dark/base_css.url +0 -1
  136. workbench/themes/quartz_dark/custom.css +0 -131
  137. workbench/themes/quartz_dark/plotly.json +0 -642
  138. workbench/utils/resource_utils.py +0 -39
  139. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  140. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,405 @@
+ """PyTorch utilities for tabular data modeling.
+
+ Provides a lightweight TabularMLP model with categorical embeddings and
+ training utilities for use in Workbench model scripts.
+ """
+
+ import json
+ import os
+ from typing import Optional
+
+ import joblib
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader, TensorDataset
+
+
+ class FeatureScaler:
+     """Standard scaler for continuous features (zero mean, unit variance)."""
+
+     def __init__(self):
+         self.means: Optional[np.ndarray] = None
+         self.stds: Optional[np.ndarray] = None
+         self.feature_names: Optional[list[str]] = None
+
+     def fit(self, df: pd.DataFrame, continuous_cols: list[str]) -> "FeatureScaler":
+         """Fit the scaler on training data."""
+         self.feature_names = continuous_cols
+         data = df[continuous_cols].values.astype(np.float32)
+         self.means = np.nanmean(data, axis=0)
+         self.stds = np.nanstd(data, axis=0)
+         # Avoid division by zero for constant features
+         self.stds[self.stds == 0] = 1.0
+         return self
+
+     def transform(self, df: pd.DataFrame) -> np.ndarray:
+         """Transform data using fitted parameters."""
+         data = df[self.feature_names].values.astype(np.float32)
+         # Fill NaN with mean before scaling
+         for i, mean in enumerate(self.means):
+             data[np.isnan(data[:, i]), i] = mean
+         return (data - self.means) / self.stds
+
+     def fit_transform(self, df: pd.DataFrame, continuous_cols: list[str]) -> np.ndarray:
+         """Fit and transform in one step."""
+         self.fit(df, continuous_cols)
+         return self.transform(df)
+
+     def save(self, path: str) -> None:
+         """Save scaler parameters."""
+         joblib.dump(
+             {
+                 "means": self.means.tolist(),
+                 "stds": self.stds.tolist(),
+                 "feature_names": self.feature_names,
+             },
+             path,
+         )
+
+     @classmethod
+     def load(cls, path: str) -> "FeatureScaler":
+         """Load scaler from saved parameters."""
+         data = joblib.load(path)
+         scaler = cls()
+         scaler.means = np.array(data["means"], dtype=np.float32)
+         scaler.stds = np.array(data["stds"], dtype=np.float32)
+         scaler.feature_names = data["feature_names"]
+         return scaler
+
+
+ class TabularMLP(nn.Module):
+     """Feedforward neural network for tabular data with optional categorical embeddings.
+
+     Args:
+         n_continuous: Number of continuous input features
+         categorical_cardinalities: List of cardinalities for each categorical feature
+         embedding_dims: List of embedding dimensions for each categorical feature
+         hidden_layers: List of hidden layer sizes (e.g., [256, 128, 64])
+         n_outputs: Number of output units
+         task: "regression" or "classification"
+         dropout: Dropout rate
+         use_batch_norm: Whether to use batch normalization
+     """
+
+     def __init__(
+         self,
+         n_continuous: int,
+         categorical_cardinalities: list[int],
+         embedding_dims: list[int],
+         hidden_layers: list[int],
+         n_outputs: int,
+         task: str = "regression",
+         dropout: float = 0.1,
+         use_batch_norm: bool = True,
+     ):
+         super().__init__()
+         self.task = task
+         self.n_continuous = n_continuous
+         self.categorical_cardinalities = categorical_cardinalities
+
+         # Embedding layers for categorical features
+         self.embeddings = nn.ModuleList(
+             [nn.Embedding(n_cats, emb_dim) for n_cats, emb_dim in zip(categorical_cardinalities, embedding_dims)]
+         )
+
+         # Calculate input dimension
+         total_emb_dim = sum(embedding_dims)
+         input_dim = n_continuous + total_emb_dim
+
+         # Build MLP layers
+         layers = []
+         for hidden_dim in hidden_layers:
+             layers.append(nn.Linear(input_dim, hidden_dim))
+             if use_batch_norm:
+                 layers.append(nn.BatchNorm1d(hidden_dim))
+             layers.append(nn.LeakyReLU())
+             layers.append(nn.Dropout(dropout))
+             input_dim = hidden_dim
+
+         self.mlp = nn.Sequential(*layers)
+         self.head = nn.Linear(input_dim, n_outputs)
+
+     def forward(self, x_cont: torch.Tensor, x_cat: Optional[torch.Tensor] = None) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             x_cont: Continuous features tensor of shape (batch, n_continuous)
+             x_cat: Categorical features tensor of shape (batch, n_categoricals), optional
+
+         Returns:
+             Output tensor of shape (batch, n_outputs)
+         """
+         # Embed categorical features and concatenate with continuous
+         if x_cat is not None and len(self.embeddings) > 0:
+             embs = [emb(x_cat[:, i]) for i, emb in enumerate(self.embeddings)]
+             x = torch.cat([x_cont] + embs, dim=1)
+         else:
+             x = x_cont
+
+         x = self.mlp(x)
+         out = self.head(x)
+
+         if self.task == "classification":
+             out = torch.softmax(out, dim=1)
+
+         return out
+
+
+ def compute_embedding_dims(cardinalities: list[int], max_dim: int = 50) -> list[int]:
+     """Compute embedding dimensions using the rule of thumb: min(50, (n+1)//2)."""
+     return [min(max_dim, (n + 1) // 2) for n in cardinalities]
+
+
+ def prepare_data(
+     df: pd.DataFrame,
+     continuous_cols: list[str],
+     categorical_cols: list[str],
+     target_col: Optional[str] = None,
+     category_mappings: Optional[dict] = None,
+     scaler: Optional[FeatureScaler] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], dict, Optional[FeatureScaler]]:
+     """Prepare dataframe for model input.
+
+     Args:
+         df: Input dataframe
+         continuous_cols: List of continuous feature column names
+         categorical_cols: List of categorical feature column names
+         target_col: Target column name (optional, for training)
+         category_mappings: Existing category mappings (for inference)
+         scaler: Existing FeatureScaler (for inference), or None to fit a new one
+
+     Returns:
+         Tuple of (x_cont, x_cat, y, category_mappings, scaler)
+     """
+     # Continuous features with standardization
+     if scaler is None:
+         scaler = FeatureScaler()
+         cont_data = scaler.fit_transform(df, continuous_cols)
+     else:
+         cont_data = scaler.transform(df)
+     x_cont = torch.tensor(cont_data, dtype=torch.float32)
+
+     # Categorical features
+     x_cat = None
+     if categorical_cols:
+         if category_mappings is None:
+             category_mappings = {}
+             for col in categorical_cols:
+                 unique_vals = df[col].unique().tolist()
+                 category_mappings[col] = {v: i for i, v in enumerate(unique_vals)}
+
+         cat_indices = []
+         for col in categorical_cols:
+             mapping = category_mappings[col]
+             # Map values to indices, use 0 for unknown categories
+             indices = df[col].map(lambda x: mapping.get(x, 0)).values
+             cat_indices.append(indices)
+
+         x_cat = torch.tensor(np.column_stack(cat_indices), dtype=torch.long)
+
+     # Target
+     y = None
+     if target_col is not None:
+         y = torch.tensor(df[target_col].values, dtype=torch.float32)
+         if len(y.shape) == 1:
+             y = y.unsqueeze(1)
+
+     return x_cont, x_cat, y, category_mappings, scaler
+
+
+ def create_model(
+     n_continuous: int,
+     categorical_cardinalities: list[int],
+     hidden_layers: list[int],
+     n_outputs: int,
+     task: str = "regression",
+     dropout: float = 0.1,
+     use_batch_norm: bool = True,
+ ) -> TabularMLP:
+     """Create a TabularMLP model with appropriate embedding dimensions."""
+     embedding_dims = compute_embedding_dims(categorical_cardinalities)
+     return TabularMLP(
+         n_continuous=n_continuous,
+         categorical_cardinalities=categorical_cardinalities,
+         embedding_dims=embedding_dims,
+         hidden_layers=hidden_layers,
+         n_outputs=n_outputs,
+         task=task,
+         dropout=dropout,
+         use_batch_norm=use_batch_norm,
+     )
+
+
+ def train_model(
+     model: TabularMLP,
+     train_x_cont: torch.Tensor,
+     train_x_cat: Optional[torch.Tensor],
+     train_y: torch.Tensor,
+     val_x_cont: torch.Tensor,
+     val_x_cat: Optional[torch.Tensor],
+     val_y: torch.Tensor,
+     task: str = "regression",
+     max_epochs: int = 200,
+     patience: int = 20,
+     batch_size: int = 128,
+     learning_rate: float = 1e-3,
+     loss: str = "L1Loss",
+     device: str = "cpu",
+ ) -> tuple[TabularMLP, dict]:
+     """Train the model with early stopping.
+
+     Returns:
+         Tuple of (trained model, training history dict)
+     """
+     model = model.to(device)
+
+     # Create dataloaders
+     if train_x_cat is not None:
+         train_dataset = TensorDataset(train_x_cont, train_x_cat, train_y)
+         val_dataset = TensorDataset(val_x_cont, val_x_cat, val_y)
+     else:
+         # Use dummy categorical tensor
+         dummy_cat = torch.zeros(train_x_cont.shape[0], 0, dtype=torch.long)
+         dummy_val_cat = torch.zeros(val_x_cont.shape[0], 0, dtype=torch.long)
+         train_dataset = TensorDataset(train_x_cont, dummy_cat, train_y)
+         val_dataset = TensorDataset(val_x_cont, dummy_val_cat, val_y)
+
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+
+     # Loss and optimizer
+     if task == "classification":
+         criterion = nn.CrossEntropyLoss()
+     else:
+         # Map loss name to PyTorch loss class
+         loss_map = {
+             "L1Loss": nn.L1Loss,
+             "MSELoss": nn.MSELoss,
+             "HuberLoss": nn.HuberLoss,
+             "SmoothL1Loss": nn.SmoothL1Loss,
+         }
+         if loss not in loss_map:
+             raise ValueError(f"Unknown loss '{loss}'. Supported: {list(loss_map.keys())}")
+         criterion = loss_map[loss]()
+
+     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
+     # Training loop with early stopping
+     best_val_loss = float("inf")
+     best_state = None
+     epochs_without_improvement = 0
+     history = {"train_loss": [], "val_loss": []}
+
+     for epoch in range(max_epochs):
+         # Training
+         model.train()
+         train_losses = []
+         for batch in train_loader:
+             x_cont, x_cat, y = [b.to(device) for b in batch]
+             x_cat = x_cat if x_cat.shape[1] > 0 else None
+
+             optimizer.zero_grad()
+             out = model(x_cont, x_cat)
+
+             if task == "classification":
+                 loss = criterion(out, y.squeeze().long())
+             else:
+                 loss = criterion(out, y)
+
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             optimizer.step()
+             train_losses.append(loss.item())
+
+         # Validation
+         model.eval()
+         val_losses = []
+         with torch.no_grad():
+             for batch in val_loader:
+                 x_cont, x_cat, y = [b.to(device) for b in batch]
+                 x_cat = x_cat if x_cat.shape[1] > 0 else None
+                 out = model(x_cont, x_cat)
+
+                 if task == "classification":
+                     loss = criterion(out, y.squeeze().long())
+                 else:
+                     loss = criterion(out, y)
+                 val_losses.append(loss.item())
+
+         train_loss = np.mean(train_losses)
+         val_loss = np.mean(val_losses)
+         history["train_loss"].append(train_loss)
+         history["val_loss"].append(val_loss)
+
+         # Early stopping check
+         if val_loss < best_val_loss:
+             best_val_loss = val_loss
+             best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
+             epochs_without_improvement = 0
+         else:
+             epochs_without_improvement += 1
+
+         if (epoch + 1) % 10 == 0:
+             print(f"Epoch {epoch + 1}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
+
+         if epochs_without_improvement >= patience:
+             print(f"Early stopping at epoch {epoch + 1}")
+             break
+
+     # Load best weights
+     if best_state is not None:
+         model.load_state_dict(best_state)
+
+     model = model.to("cpu")
+     return model, history
+
+
+ def predict(
+     model: TabularMLP,
+     x_cont: torch.Tensor,
+     x_cat: Optional[torch.Tensor] = None,
+     device: str = "cpu",
+ ) -> np.ndarray:
+     """Run inference with the model."""
+     model = model.to(device)
+     model.eval()
+
+     with torch.no_grad():
+         x_cont = x_cont.to(device)
+         if x_cat is not None:
+             x_cat = x_cat.to(device)
+         out = model(x_cont, x_cat)
+
+     return out.cpu().numpy()
+
+
+ def save_model(model: TabularMLP, path: str, model_config: dict) -> None:
+     """Save model weights and configuration."""
+     os.makedirs(path, exist_ok=True)
+     torch.save(model.state_dict(), os.path.join(path, "model.pt"))
+     with open(os.path.join(path, "config.json"), "w") as f:
+         json.dump(model_config, f, indent=2)
+
+
+ def load_model(path: str, device: str = "cpu") -> TabularMLP:
+     """Load model from saved weights and configuration."""
+     with open(os.path.join(path, "config.json")) as f:
+         config = json.load(f)
+
+     model = create_model(
+         n_continuous=config["n_continuous"],
+         categorical_cardinalities=config["categorical_cardinalities"],
+         hidden_layers=config["hidden_layers"],
+         n_outputs=config["n_outputs"],
+         task=config["task"],
+         dropout=config.get("dropout", 0.1),
+         use_batch_norm=config.get("use_batch_norm", True),
+     )
+
+     state_dict = torch.load(os.path.join(path, "model.pt"), map_location=device, weights_only=True)
+     model.load_state_dict(state_dict)
+     model.eval()
+
+     return model
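
For orientation, here is a minimal end-to-end sketch of how the new helpers above compose into a train/predict workflow. This sketch is not part of the package: the dataset, column names, and the simple 80/20 split are hypothetical, and it assumes the module is importable as pytorch_utils (the file name added in this release).

# Hypothetical usage sketch for the TabularMLP utilities shown above
import pandas as pd
from pytorch_utils import create_model, predict, prepare_data, train_model

df = pd.read_csv("training_data.csv")     # hypothetical dataset
continuous_cols = ["logp", "mol_weight"]  # hypothetical continuous features
categorical_cols = ["assay"]              # hypothetical categorical feature
target_col = "solubility"                 # hypothetical target

# Simple random split, for illustration only
val_df = df.sample(frac=0.2, random_state=42)
train_df = df.drop(val_df.index)

# Fit the scaler and category mappings on the training split, then reuse them for validation
x_cont, x_cat, y, cat_maps, scaler = prepare_data(train_df, continuous_cols, categorical_cols, target_col)
vx_cont, vx_cat, vy, _, _ = prepare_data(
    val_df, continuous_cols, categorical_cols, target_col, category_mappings=cat_maps, scaler=scaler
)

# Build the MLP (embedding sizes come from compute_embedding_dims), train with early stopping, then predict
cardinalities = [len(cat_maps[col]) for col in categorical_cols]
model = create_model(
    n_continuous=len(continuous_cols),
    categorical_cardinalities=cardinalities,
    hidden_layers=[256, 128, 64],
    n_outputs=1,
    task="regression",
)
model, history = train_model(model, x_cont, x_cat, y, vx_cont, vx_cat, vy, task="regression")
predictions = predict(model, vx_cont, vx_cat)

Persisting the result would go through save_model / load_model, with a config dict carrying the same constructor arguments (n_continuous, categorical_cardinalities, hidden_layers, n_outputs, task) that load_model reads back from config.json.
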
@@ -1,2 +1,2 @@
- # Note: The training and inference images already have torch and pytorch-tabular installed.
+ # Note: The training and inference images already have torch + supporting packages installed.
  # So we only need to install packages that are not already included in the images.
@@ -0,0 +1,277 @@
+ """UQ Harness: Uncertainty Quantification using MAPIE Conformalized Quantile Regression.
+
+ This module provides a reusable UQ harness that can wrap any point predictor model
+ (XGBoost, PyTorch, ChemProp, etc.) to provide calibrated prediction intervals.
+
+ Usage:
+     # Training
+     uq_models, uq_metadata = train_uq_models(X_train, y_train, X_val, y_val)
+     save_uq_models(uq_models, uq_metadata, model_dir)
+
+     # Inference
+     uq_models, uq_metadata = load_uq_models(model_dir)
+     df = predict_intervals(df, X, uq_models, uq_metadata)
+     df = compute_confidence(df, uq_metadata["median_interval_width"])
+ """
+
+ import json
+ import os
+ import numpy as np
+ import pandas as pd
+ import joblib
+ from lightgbm import LGBMRegressor
+ from mapie.regression import ConformalizedQuantileRegressor
+
+ # Default confidence levels for prediction intervals
+ DEFAULT_CONFIDENCE_LEVELS = [0.50, 0.68, 0.80, 0.90, 0.95]
+
+
+ def train_uq_models(
+     X_train: pd.DataFrame | np.ndarray,
+     y_train: pd.Series | np.ndarray,
+     X_val: pd.DataFrame | np.ndarray,
+     y_val: pd.Series | np.ndarray,
+     confidence_levels: list[float] | None = None,
+ ) -> tuple[dict, dict]:
+     """Train MAPIE UQ models for multiple confidence levels.
+
+     Args:
+         X_train: Training features
+         y_train: Training targets
+         X_val: Validation features for conformalization
+         y_val: Validation targets for conformalization
+         confidence_levels: List of confidence levels (default: [0.50, 0.68, 0.80, 0.90, 0.95])
+
+     Returns:
+         Tuple of (uq_models dict, uq_metadata dict)
+     """
+     if confidence_levels is None:
+         confidence_levels = DEFAULT_CONFIDENCE_LEVELS
+
+     mapie_models = {}
+
+     for confidence_level in confidence_levels:
+         alpha = 1 - confidence_level
+         lower_q = alpha / 2
+         upper_q = 1 - alpha / 2
+
+         print(f"\nTraining quantile models for {confidence_level * 100:.0f}% confidence interval...")
+         print(f" Quantiles: {lower_q:.3f}, {upper_q:.3f}, 0.500")
+
+         # Train three LightGBM quantile models for this confidence level
+         quantile_estimators = []
+         for q in [lower_q, upper_q, 0.5]:
+             print(f" Training model for quantile {q:.3f}...")
+             est = LGBMRegressor(
+                 objective="quantile",
+                 alpha=q,
+                 n_estimators=1000,
+                 max_depth=6,
+                 learning_rate=0.01,
+                 num_leaves=31,
+                 min_child_samples=20,
+                 subsample=0.8,
+                 colsample_bytree=0.8,
+                 random_state=42,
+                 verbose=-1,
+                 force_col_wise=True,
+             )
+             est.fit(X_train, y_train)
+             quantile_estimators.append(est)
+
+         # Create MAPIE CQR model for this confidence level
+         print(f" Setting up MAPIE CQR for {confidence_level * 100:.0f}% confidence...")
+         mapie_model = ConformalizedQuantileRegressor(
+             quantile_estimators, confidence_level=confidence_level, prefit=True
+         )
+
+         # Conformalize the model with validation data
+         print(" Conformalizing with validation data...")
+         mapie_model.conformalize(X_val, y_val)
+
+         # Store the model
+         model_name = f"mapie_{confidence_level:.2f}"
+         mapie_models[model_name] = mapie_model
+
+         # Validate coverage for this confidence level
+         y_pred, y_pis = mapie_model.predict_interval(X_val)
+         coverage = np.mean((y_val >= y_pis[:, 0, 0]) & (y_val <= y_pis[:, 1, 0]))
+         print(f" Coverage: Target={confidence_level * 100:.0f}%, Empirical={coverage * 100:.1f}%")
+
+     # Compute median interval width for confidence calculation (using 80% CI = q_10 to q_90)
+     print("\nComputing normalization statistics for confidence scores...")
+     model_80 = mapie_models["mapie_0.80"]
+     _, y_pis_80 = model_80.predict_interval(X_val)
+     interval_width = np.abs(y_pis_80[:, 1, 0] - y_pis_80[:, 0, 0])
+     median_interval_width = float(np.median(interval_width))
+     print(f" Median interval width (q_10-q_90): {median_interval_width:.6f}")
+
+     # Analyze interval widths across confidence levels
+     print("\nInterval Width Analysis:")
+     for conf_level in confidence_levels:
+         model = mapie_models[f"mapie_{conf_level:.2f}"]
+         _, y_pis = model.predict_interval(X_val)
+         widths = y_pis[:, 1, 0] - y_pis[:, 0, 0]
+         print(f" {conf_level * 100:.0f}% CI: Mean width={np.mean(widths):.3f}, Std={np.std(widths):.3f}")
+
+     uq_metadata = {
+         "confidence_levels": confidence_levels,
+         "median_interval_width": median_interval_width,
+     }
+
+     return mapie_models, uq_metadata
+
+
+ def save_uq_models(uq_models: dict, uq_metadata: dict, model_dir: str) -> None:
+     """Save UQ models and metadata to disk.
+
+     Args:
+         uq_models: Dictionary of MAPIE models keyed by name (e.g., "mapie_0.80")
+         uq_metadata: Dictionary with confidence_levels and median_interval_width
+         model_dir: Directory to save models
+     """
+     # Save each MAPIE model
+     for model_name, model in uq_models.items():
+         joblib.dump(model, os.path.join(model_dir, f"{model_name}.joblib"))
+
+     # Save median interval width
+     with open(os.path.join(model_dir, "median_interval_width.json"), "w") as fp:
+         json.dump(uq_metadata["median_interval_width"], fp)
+
+     # Save UQ metadata
+     with open(os.path.join(model_dir, "uq_metadata.json"), "w") as fp:
+         json.dump(uq_metadata, fp, indent=2)
+
+     print(f"Saved {len(uq_models)} UQ models to {model_dir}")
+
+
+ def load_uq_models(model_dir: str) -> tuple[dict, dict]:
+     """Load UQ models and metadata from disk.
+
+     Args:
+         model_dir: Directory containing saved models
+
+     Returns:
+         Tuple of (uq_models dict, uq_metadata dict)
+     """
+     # Load UQ metadata
+     uq_metadata_path = os.path.join(model_dir, "uq_metadata.json")
+     if os.path.exists(uq_metadata_path):
+         with open(uq_metadata_path) as fp:
+             uq_metadata = json.load(fp)
+     else:
+         # Fallback for older models that only have median_interval_width.json
+         uq_metadata = {"confidence_levels": DEFAULT_CONFIDENCE_LEVELS}
+         median_width_path = os.path.join(model_dir, "median_interval_width.json")
+         if os.path.exists(median_width_path):
+             with open(median_width_path) as fp:
+                 uq_metadata["median_interval_width"] = json.load(fp)
+
+     # Load all MAPIE models
+     uq_models = {}
+     for conf_level in uq_metadata["confidence_levels"]:
+         model_name = f"mapie_{conf_level:.2f}"
+         model_path = os.path.join(model_dir, f"{model_name}.joblib")
+         if os.path.exists(model_path):
+             uq_models[model_name] = joblib.load(model_path)
+
+     return uq_models, uq_metadata
+
+
+ def predict_intervals(
+     df: pd.DataFrame,
+     X: pd.DataFrame | np.ndarray,
+     uq_models: dict,
+     uq_metadata: dict,
+ ) -> pd.DataFrame:
+     """Add prediction intervals to a DataFrame.
+
+     Args:
+         df: DataFrame to add interval columns to
+         X: Features for prediction (must match training features)
+         uq_models: Dictionary of MAPIE models
+         uq_metadata: Dictionary with confidence_levels
+
+     Returns:
+         DataFrame with added quantile columns (q_025, q_05, ..., q_975)
+     """
+     confidence_levels = uq_metadata["confidence_levels"]
+
+     for conf_level in confidence_levels:
+         model_name = f"mapie_{conf_level:.2f}"
+         model = uq_models[model_name]
+
+         # Get conformalized predictions
+         y_pred, y_pis = model.predict_interval(X)
+
+         # Map confidence levels to quantile column names
+         if conf_level == 0.50:  # 50% CI
+             df["q_25"] = y_pis[:, 0, 0]
+             df["q_75"] = y_pis[:, 1, 0]
+             df["q_50"] = y_pred  # Median prediction
+         elif conf_level == 0.68:  # 68% CI (~1 std)
+             df["q_16"] = y_pis[:, 0, 0]
+             df["q_84"] = y_pis[:, 1, 0]
+         elif conf_level == 0.80:  # 80% CI
+             df["q_10"] = y_pis[:, 0, 0]
+             df["q_90"] = y_pis[:, 1, 0]
+         elif conf_level == 0.90:  # 90% CI
+             df["q_05"] = y_pis[:, 0, 0]
+             df["q_95"] = y_pis[:, 1, 0]
+         elif conf_level == 0.95:  # 95% CI
+             df["q_025"] = y_pis[:, 0, 0]
+             df["q_975"] = y_pis[:, 1, 0]
+
+     # Calculate pseudo-standard deviation from the 68% interval width
+     if "q_84" in df.columns and "q_16" in df.columns:
+         df["prediction_std"] = (df["q_84"] - df["q_16"]).abs() / 2.0
+
+     # Reorder quantile columns for easier reading
+     quantile_cols = ["q_025", "q_05", "q_10", "q_16", "q_25", "q_50", "q_75", "q_84", "q_90", "q_95", "q_975"]
+     existing_q_cols = [c for c in quantile_cols if c in df.columns]
+     other_cols = [c for c in df.columns if c not in quantile_cols]
+     df = df[other_cols + existing_q_cols]
+
+     return df
+
+
+ def compute_confidence(
+     df: pd.DataFrame,
+     median_interval_width: float,
+     lower_q: str = "q_10",
+     upper_q: str = "q_90",
+     alpha: float = 1.0,
+     beta: float = 1.0,
+ ) -> pd.DataFrame:
+     """Compute confidence scores (0.0 to 1.0) based on prediction interval width.
+
+     Uses exponential decay based on:
+     1. Interval width relative to median (alpha weight)
+     2. Distance from median prediction (beta weight)
+
+     Args:
+         df: DataFrame with 'prediction', 'q_50', and quantile columns
+         median_interval_width: Pre-computed median interval width from training data
+         lower_q: Lower quantile column name (default: 'q_10')
+         upper_q: Upper quantile column name (default: 'q_90')
+         alpha: Weight for interval width term (default: 1.0)
+         beta: Weight for distance from median term (default: 1.0)
+
+     Returns:
+         DataFrame with added 'confidence' column
+     """
+     # Interval width
+     interval_width = (df[upper_q] - df[lower_q]).abs()
+
+     # Distance from median, normalized by interval width
+     distance_from_median = (df["prediction"] - df["q_50"]).abs()
+     normalized_distance = distance_from_median / (interval_width + 1e-6)
+
+     # Cap the distance penalty at 1.0
+     normalized_distance = np.minimum(normalized_distance, 1.0)
+
+     # Confidence using exponential decay
+     interval_term = interval_width / median_interval_width
+     df["confidence"] = np.exp(-(alpha * interval_term + beta * normalized_distance))
+
+     return df
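
For reference, below is a minimal sketch of how this harness wraps a point predictor end to end, mirroring the Usage block in the module docstring. The synthetic data, the stand-in LightGBM point model, and the output directory are illustrative only; the harness itself just adds the calibrated quantile columns and the confidence score to whatever predictions it is given.

# Hypothetical usage sketch for the UQ harness shown above
import os
import numpy as np
import pandas as pd
from lightgbm import LGBMRegressor
from uq_harness import compute_confidence, load_uq_models, predict_intervals, save_uq_models, train_uq_models

# Synthetic data, just to make the sketch self-contained
rng = np.random.default_rng(42)
X = pd.DataFrame(rng.normal(size=(1000, 5)), columns=[f"f{i}" for i in range(5)])
y = 2.0 * X["f0"] + rng.normal(scale=0.5, size=1000)
X_train, y_train = X.iloc[:600], y.iloc[:600]
X_val, y_val = X.iloc[600:800], y.iloc[600:800]
X_test = X.iloc[800:]

# Training side: fit LightGBM quantile models per confidence level, conformalize on the validation split
uq_models, uq_metadata = train_uq_models(X_train, y_train, X_val, y_val)
os.makedirs("uq_model_dir", exist_ok=True)  # hypothetical output directory
save_uq_models(uq_models, uq_metadata, "uq_model_dir")

# Inference side: point predictions from the wrapped model, then intervals and a 0-1 confidence score
point_model = LGBMRegressor(n_estimators=200).fit(X_train, y_train)  # stand-in for the wrapped point predictor
uq_models, uq_metadata = load_uq_models("uq_model_dir")
df = X_test.copy()
df["prediction"] = point_model.predict(X_test)
df = predict_intervals(df, X_test, uq_models, uq_metadata)
df = compute_confidence(df, uq_metadata["median_interval_width"])

The resulting confidence column follows exp(-(alpha * width / median_width + beta * normalized distance from the median)), so narrow, well-centered intervals score close to 1.0 while wide or off-center ones decay toward 0.0.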