workbench 0.8.213-py3-none-any.whl → 0.8.217-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
  3. workbench/algorithms/dataframe/projection_2d.py +38 -21
  4. workbench/algorithms/dataframe/proximity.py +75 -150
  5. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  6. workbench/algorithms/models/cleanlab_model.py +382 -0
  7. workbench/algorithms/models/noise_model.py +2 -2
  8. workbench/api/__init__.py +3 -0
  9. workbench/api/endpoint.py +10 -5
  10. workbench/api/feature_set.py +76 -6
  11. workbench/api/meta_model.py +289 -0
  12. workbench/api/model.py +43 -4
  13. workbench/core/artifacts/endpoint_core.py +63 -115
  14. workbench/core/artifacts/feature_set_core.py +1 -1
  15. workbench/core/artifacts/model_core.py +6 -4
  16. workbench/core/pipelines/pipeline_executor.py +1 -1
  17. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
  18. workbench/model_script_utils/pytorch_utils.py +11 -1
  19. workbench/model_scripts/chemprop/chemprop.template +145 -69
  20. workbench/model_scripts/chemprop/generated_model_script.py +147 -71
  21. workbench/model_scripts/custom_models/chem_info/fingerprints.py +7 -3
  22. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  23. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  24. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  25. workbench/model_scripts/custom_models/uq_models/meta_uq.template +6 -6
  26. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  27. workbench/model_scripts/meta_model/meta_model.template +209 -0
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +42 -24
  29. workbench/model_scripts/pytorch_model/pytorch.template +42 -24
  30. workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
  31. workbench/model_scripts/script_generation.py +4 -0
  32. workbench/model_scripts/xgb_model/generated_model_script.py +169 -158
  33. workbench/model_scripts/xgb_model/xgb_model.template +163 -152
  34. workbench/repl/workbench_shell.py +0 -5
  35. workbench/scripts/endpoint_test.py +2 -2
  36. workbench/utils/chem_utils/fingerprints.py +7 -3
  37. workbench/utils/chemprop_utils.py +23 -5
  38. workbench/utils/meta_model_simulator.py +471 -0
  39. workbench/utils/metrics_utils.py +94 -10
  40. workbench/utils/model_utils.py +91 -9
  41. workbench/utils/pytorch_utils.py +1 -1
  42. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  43. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/METADATA +2 -1
  44. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/RECORD +48 -43
  45. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  46. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  47. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/WHEEL +0 -0
  48. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/entry_points.txt +0 -0
  49. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/licenses/LICENSE +0 -0
  50. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/top_level.txt +0 -0
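Among the files above, the largest addition is the new workbench/utils/meta_model_simulator.py (item 38), whose full contents appear in the first hunk below. As a quick orientation, here is a minimal usage sketch based only on the class docstring and constructor visible in that hunk; the model names are placeholders, and each model is assumed to already have "full_inference" endpoint predictions:

```python
from workbench.utils.meta_model_simulator import MetaModelSimulator

# Placeholder model names: each must already have "full_inference" predictions
sim = MetaModelSimulator(["model-a", "model-b", "model-c"], id_column="id")

sim.report(details=True)                # per-model metrics, residual correlations, strategy comparison
strategies = sim.strategy_comparison()  # DataFrame of ensemble strategies sorted by MAE
weights = sim.ensemble_weights()        # suggested inverse-MAE weights per model
```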
workbench/utils/meta_model_simulator.py (new file)
@@ -0,0 +1,471 @@
+"""MetaModelSimulator: Simulate and analyze ensemble model performance.
+
+This class helps evaluate whether a meta model (ensemble) would outperform
+individual child models by analyzing endpoint inference predictions.
+"""
+
+import pandas as pd
+import numpy as np
+from scipy import stats
+import logging
+
+from workbench.api import Model
+
+# Set up the log
+log = logging.getLogger("workbench")
+
+
+class MetaModelSimulator:
+    """Simulate meta model performance from child model predictions.
+
+    This class loads cross-validation predictions from multiple models and
+    analyzes how different ensemble strategies would perform compared to
+    the individual models.
+
+    Example:
+        ```python
+        from workbench.utils.meta_model_simulator import MetaModelSimulator
+
+        sim = MetaModelSimulator(["model-a", "model-b", "model-c"])
+        sim.report()  # Print full analysis
+        sim.strategy_comparison()  # Compare ensemble strategies
+        ```
+    """
+
+    def __init__(self, model_names: list[str], id_column: str = "id"):
+        """Initialize the simulator with a list of model names.
+
+        Args:
+            model_names: List of model names to include in the ensemble
+            id_column: Column name to use for row alignment (default: "id")
+        """
+        self.model_names = model_names
+        self.id_column = id_column
+        self._dfs: dict[str, pd.DataFrame] = {}
+        self._target_column: str | None = None
+        self._load_predictions()
+
+    def _load_predictions(self):
+        """Load endpoint inference predictions for all models."""
+        log.info(f"Loading predictions for {len(self.model_names)} models...")
+        for name in self.model_names:
+            model = Model(name)
+            if self._target_column is None:
+                self._target_column = model.target()
+            df = model.get_inference_predictions("full_inference")
+            if df is None:
+                raise ValueError(
+                    f"No full_inference predictions found for model '{name}'. Run endpoint inference first."
+                )
+            df["residual"] = df["prediction"] - df[self._target_column]
+            df["abs_residual"] = df["residual"].abs()
+            self._dfs[name] = df
+
+        # Align DataFrames by sorting on id column
+        self._dfs = {name: df.sort_values(self.id_column).reset_index(drop=True) for name, df in self._dfs.items()}
+        log.info(f"Loaded {len(self._dfs)} models, {len(list(self._dfs.values())[0])} samples each")
+
+    def report(self, details: bool = False):
+        """Print a comprehensive analysis report.
+
+        Args:
+            details: Whether to include detailed sections (default: False)
+        """
+        self.model_performance()
+        self.residual_correlations()
+        self.strategy_comparison()
+        self.ensemble_failure_analysis()
+        if details:
+            self.confidence_analysis()
+            self.model_agreement()
+            self.ensemble_weights()
+            self.confidence_weight_distribution()
+
+    def confidence_analysis(self) -> dict[str, dict]:
+        """Analyze how confidence correlates with prediction accuracy.
+
+        Returns:
+            Dict mapping model name to confidence stats
+        """
+        print("=" * 60)
+        print("CONFIDENCE VS RESIDUALS ANALYSIS")
+        print("=" * 60)
+
+        results = {}
+        for name, df in self._dfs.items():
+            print(f"\n{name}:")
+            print("-" * 50)
+
+            conf = df["confidence"]
+            print(
+                f" Confidence: mean={conf.mean():.3f}, std={conf.std():.3f}, "
+                f"min={conf.min():.3f}, max={conf.max():.3f}"
+            )
+
+            corr_pearson, p_pearson = stats.pearsonr(df["confidence"], df["abs_residual"])
+            corr_spearman, p_spearman = stats.spearmanr(df["confidence"], df["abs_residual"])
+
+            print(" Confidence vs |residual|:")
+            print(f" Pearson r={corr_pearson:.3f} (p={p_pearson:.2e})")
+            print(f" Spearman r={corr_spearman:.3f} (p={p_spearman:.2e})")
+
+            df["conf_quartile"] = pd.qcut(df["confidence"], q=4, labels=["Q1 (low)", "Q2", "Q3", "Q4 (high)"])
+            quartile_stats = df.groupby("conf_quartile", observed=True)["abs_residual"].agg(
+                ["mean", "median", "std", "count"]
+            )
+            print(" Error by confidence quartile:")
+            print(quartile_stats.to_string().replace("\n", "\n "))
+
+            results[name] = {
+                "mean_conf": conf.mean(),
+                "pearson_r": corr_pearson,
+                "spearman_r": corr_spearman,
+            }
+
+        return results
+
+    def residual_correlations(self) -> pd.DataFrame:
+        """Analyze correlation of residuals between models.
+
+        Returns:
+            Correlation matrix DataFrame
+        """
+        print("\n" + "=" * 60)
+        print("RESIDUAL CORRELATIONS BETWEEN MODELS")
+        print("=" * 60)
+
+        residual_df = pd.DataFrame({name: df["residual"].values for name, df in self._dfs.items()})
+
+        corr_matrix = residual_df.corr()
+        print("\nPearson correlation of residuals:")
+        print(corr_matrix.to_string())
+
+        spearman_matrix = residual_df.corr(method="spearman")
+        print("\nSpearman correlation of residuals:")
+        print(spearman_matrix.to_string())
+
+        print("\nInterpretation:")
+        print(" - Low correlation = models make different errors (good for ensemble)")
+        print(" - High correlation = models make similar errors (less ensemble benefit)")
+
+        return corr_matrix
+
+    def model_agreement(self) -> dict:
+        """Analyze where models agree/disagree in predictions.
+
+        Returns:
+            Dict with agreement statistics
+        """
+        print("\n" + "=" * 60)
+        print("MODEL AGREEMENT ANALYSIS")
+        print("=" * 60)
+
+        pred_df = pd.DataFrame()
+        for name, df in self._dfs.items():
+            if pred_df.empty:
+                pred_df[self.id_column] = df[self.id_column]
+                pred_df["target"] = df[self._target_column]
+            pred_df[f"{name}_pred"] = df["prediction"].values
+
+        pred_cols = [f"{name}_pred" for name in self._dfs.keys()]
+        pred_df["pred_std"] = pred_df[pred_cols].std(axis=1)
+        pred_df["pred_mean"] = pred_df[pred_cols].mean(axis=1)
+        pred_df["ensemble_residual"] = pred_df["pred_mean"] - pred_df["target"]
+        pred_df["ensemble_abs_residual"] = pred_df["ensemble_residual"].abs()
+
+        print("\nPrediction std across models (disagreement):")
+        print(
+            f" mean={pred_df['pred_std'].mean():.3f}, median={pred_df['pred_std'].median():.3f}, "
+            f"max={pred_df['pred_std'].max():.3f}"
+        )

+        corr, p = stats.spearmanr(pred_df["pred_std"], pred_df["ensemble_abs_residual"])
+        print(f"\nDisagreement vs ensemble error: Spearman r={corr:.3f} (p={p:.2e})")
+
+        pred_df["disagree_quartile"] = pd.qcut(
+            pred_df["pred_std"], q=4, labels=["Q1 (agree)", "Q2", "Q3", "Q4 (disagree)"]
+        )
+        quartile_stats = pred_df.groupby("disagree_quartile", observed=True)["ensemble_abs_residual"].agg(
+            ["mean", "median", "count"]
+        )
+        print("\nEnsemble error by disagreement quartile:")
+        print(quartile_stats.to_string().replace("\n", "\n "))
+
+        return {
+            "mean_disagreement": pred_df["pred_std"].mean(),
+            "disagreement_error_corr": corr,
+        }
+
+    def model_performance(self) -> pd.DataFrame:
+        """Show per-model performance metrics.
+
+        Returns:
+            DataFrame with performance metrics for each model
+        """
+        print("\n" + "=" * 60)
+        print("PER-MODEL PERFORMANCE SUMMARY")
+        print("=" * 60)
+
+        metrics = []
+        for name, df in self._dfs.items():
+            residuals = df["residual"]
+            target = df[self._target_column]
+            pred = df["prediction"]
+
+            rmse = np.sqrt((residuals**2).mean())
+            mae = residuals.abs().mean()
+            r2 = 1 - (residuals**2).sum() / ((target - target.mean()) ** 2).sum()
+            spearman = stats.spearmanr(target, pred)[0]
+
+            metrics.append(
+                {
+                    "model": name,
+                    "rmse": rmse,
+                    "mae": mae,
+                    "r2": r2,
+                    "spearman": spearman,
+                    "mean_conf": df["confidence"].mean(),
+                }
+            )
+
+        metrics_df = pd.DataFrame(metrics).set_index("model")
+        print("\n" + metrics_df.to_string())
+        return metrics_df
+
+    def ensemble_weights(self) -> dict[str, float]:
+        """Calculate suggested ensemble weights based on inverse MAE.
+
+        Returns:
+            Dict mapping model name to suggested weight
+        """
+        print("\n" + "=" * 60)
+        print("SUGGESTED ENSEMBLE WEIGHTS")
+        print("=" * 60)
+
+        mae_scores = {name: df["abs_residual"].mean() for name, df in self._dfs.items()}
+
+        inv_mae = {name: 1.0 / mae for name, mae in mae_scores.items()}
+        total = sum(inv_mae.values())
+        weights = {name: w / total for name, w in inv_mae.items()}
+
+        print("\nWeights based on inverse MAE:")
+        for name, weight in weights.items():
+            print(f" {name}: {weight:.3f} (MAE={mae_scores[name]:.3f})")
+
+        print(f"\nEqual weights would be: {1.0/len(self._dfs):.3f} each")
+
+        return weights
+
+    def strategy_comparison(self) -> pd.DataFrame:
+        """Compare different ensemble strategies.
+
+        Returns:
+            DataFrame with MAE for each strategy, sorted best to worst
+        """
+        print("\n" + "=" * 60)
+        print("ENSEMBLE STRATEGY COMPARISON")
+        print("=" * 60)
+
+        combined = pd.DataFrame()
+        model_names = list(self._dfs.keys())
+
+        for name, df in self._dfs.items():
+            if combined.empty:
+                combined[self.id_column] = df[self.id_column]
+                combined["target"] = df[self._target_column]
+            combined[f"{name}_pred"] = df["prediction"].values
+            combined[f"{name}_conf"] = df["confidence"].values
+
+        pred_cols = [f"{name}_pred" for name in model_names]
+        conf_cols = [f"{name}_conf" for name in model_names]
+
+        results = []
+
+        # Strategy 1: Simple mean
+        combined["simple_mean"] = combined[pred_cols].mean(axis=1)
+        mae = (combined["simple_mean"] - combined["target"]).abs().mean()
+        results.append({"strategy": "Simple Mean", "mae": mae})
+
+        # Strategy 2: Confidence-weighted
+        conf_arr = combined[conf_cols].values
+        pred_arr = combined[pred_cols].values
+        conf_sum = conf_arr.sum(axis=1, keepdims=True) + 1e-8
+        weights = conf_arr / conf_sum
+        combined["conf_weighted"] = (pred_arr * weights).sum(axis=1)
+        mae = (combined["conf_weighted"] - combined["target"]).abs().mean()
+        results.append({"strategy": "Confidence-Weighted", "mae": mae})
+
+        # Strategy 3: Inverse-MAE weighted
+        mae_scores = {name: self._dfs[name]["abs_residual"].mean() for name in model_names}
+        inv_mae_weights = np.array([1.0 / mae_scores[name] for name in model_names])
+        inv_mae_weights = inv_mae_weights / inv_mae_weights.sum()
+        combined["inv_mae_weighted"] = (pred_arr * inv_mae_weights).sum(axis=1)
+        mae = (combined["inv_mae_weighted"] - combined["target"]).abs().mean()
+        results.append({"strategy": "Inverse-MAE Weighted", "mae": mae})
+
+        # Strategy 4: Best model only
+        best_model = min(mae_scores, key=mae_scores.get)
+        combined["best_only"] = combined[f"{best_model}_pred"]
+        mae = (combined["best_only"] - combined["target"]).abs().mean()
+        results.append({"strategy": f"Best Model Only ({best_model})", "mae": mae})
+
+        # Strategy 5: Scaled confidence-weighted (confidence * model_weights)
+        scaled_conf = conf_arr * inv_mae_weights
+        scaled_conf_sum = scaled_conf.sum(axis=1, keepdims=True) + 1e-8
+        scaled_weights = scaled_conf / scaled_conf_sum
+        combined["scaled_conf_weighted"] = (pred_arr * scaled_weights).sum(axis=1)
+        mae = (combined["scaled_conf_weighted"] - combined["target"]).abs().mean()
+        results.append({"strategy": "Scaled Conf-Weighted", "mae": mae})
+
+        # Strategy 6: Drop worst model (use simple mean of remaining, or raw prediction if only 1 left)
+        worst_model = max(mae_scores, key=mae_scores.get)
+        remaining = [n for n in model_names if n != worst_model]
+        remaining_pred_cols = [f"{n}_pred" for n in remaining]
+        if len(remaining) == 1:
+            # Single model remaining - use raw prediction (same as "Best Model Only")
+            combined["drop_worst"] = combined[remaining_pred_cols[0]]
+        else:
+            # Multiple models remaining - use simple mean
+            combined["drop_worst"] = combined[remaining_pred_cols].mean(axis=1)
+        mae = (combined["drop_worst"] - combined["target"]).abs().mean()
+        results.append({"strategy": f"Drop Worst ({worst_model})", "mae": mae})
+
+        results_df = pd.DataFrame(results).sort_values("mae")
+        print("\n" + results_df.to_string(index=False))
+
+        print("\nIndividual model MAEs for reference:")
+        for name, mae in sorted(mae_scores.items(), key=lambda x: x[1]):
+            print(f" {name}: {mae:.4f}")
+
+        return results_df
+
+    def confidence_weight_distribution(self) -> pd.DataFrame:
+        """Analyze how confidence weights are distributed across models.
+
+        Returns:
+            DataFrame with weight distribution statistics
+        """
+        print("\n" + "=" * 60)
+        print("CONFIDENCE WEIGHT DISTRIBUTION")
+        print("=" * 60)
+
+        model_names = list(self._dfs.keys())
+        conf_df = pd.DataFrame({name: df["confidence"].values for name, df in self._dfs.items()})
+
+        conf_sum = conf_df.sum(axis=1)
+        weight_df = conf_df.div(conf_sum, axis=0)
+
+        print("\nMean weight per model (from confidence-weighting):")
+        for name in model_names:
+            print(f" {name}: {weight_df[name].mean():.3f}")
+
+        print("\nWeight distribution stats:")
+        print(weight_df.describe().to_string())
+
+        print("\nHow often each model has highest weight:")
+        winner = weight_df.idxmax(axis=1)
+        winner_counts = winner.value_counts()
+        for name in model_names:
+            count = winner_counts.get(name, 0)
+            print(f" {name}: {count} ({100*count/len(weight_df):.1f}%)")
+
+        return weight_df
+
+    def ensemble_failure_analysis(self) -> dict:
+        """Compare ensemble vs best overall model (not per-row oracle).
+
+        Returns:
+            Dict with comparison statistics
+        """
+        print("\n" + "=" * 60)
+        print("ENSEMBLE VS BEST MODEL COMPARISON")
+        print("=" * 60)
+
+        model_names = list(self._dfs.keys())
+
+        combined = pd.DataFrame()
+        for name, df in self._dfs.items():
+            if combined.empty:
+                combined[self.id_column] = df[self.id_column]
+                combined["target"] = df[self._target_column]
+            combined[f"{name}_pred"] = df["prediction"].values
+            combined[f"{name}_conf"] = df["confidence"].values
+            combined[f"{name}_abs_err"] = df["abs_residual"].values
+
+        pred_cols = [f"{name}_pred" for name in model_names]
+
+        # Calculate ensemble prediction (inverse-MAE weighted)
+        mae_scores = {name: self._dfs[name]["abs_residual"].mean() for name in model_names}
+        inv_mae_weights = np.array([1.0 / mae_scores[name] for name in model_names])
+        inv_mae_weights = inv_mae_weights / inv_mae_weights.sum()
+        pred_arr = combined[pred_cols].values
+        combined["ensemble_pred"] = (pred_arr * inv_mae_weights).sum(axis=1)
+        combined["ensemble_abs_err"] = (combined["ensemble_pred"] - combined["target"]).abs()
+
+        # Find best overall model (lowest MAE)
+        best_model = min(mae_scores, key=mae_scores.get)
+        combined["best_model_abs_err"] = combined[f"{best_model}_abs_err"]
+
+        # Compare ensemble vs best model
+        combined["ensemble_better"] = combined["ensemble_abs_err"] < combined["best_model_abs_err"]
+        n_better = combined["ensemble_better"].sum()
+        n_total = len(combined)
+
+        ensemble_mae = combined["ensemble_abs_err"].mean()
+        best_model_mae = mae_scores[best_model]
+
+        print(f"\nBest individual model: {best_model} (MAE={best_model_mae:.4f})")
+        print(f"Ensemble MAE: {ensemble_mae:.4f}")
+        if ensemble_mae < best_model_mae:
+            improvement = (best_model_mae - ensemble_mae) / best_model_mae * 100
+            print(f"Ensemble improves over best model by {improvement:.1f}%")
+        else:
+            degradation = (ensemble_mae - best_model_mae) / best_model_mae * 100
+            print(f"Ensemble is worse than best model by {degradation:.1f}%")
+
+        print("\nPer-row comparison:")
+        print(f" Ensemble wins: {n_better}/{n_total} ({100*n_better/n_total:.1f}%)")
+        print(f" Best model wins: {n_total - n_better}/{n_total} ({100*(n_total - n_better)/n_total:.1f}%)")
+
+        # When ensemble wins
+        ensemble_wins = combined[combined["ensemble_better"]]
+        if len(ensemble_wins) > 0:
+            print("\nWhen ensemble wins:")
+            print(f" Mean ensemble error: {ensemble_wins['ensemble_abs_err'].mean():.3f}")
+            print(f" Mean best model error: {ensemble_wins['best_model_abs_err'].mean():.3f}")
+
+        # When best model wins
+        best_wins = combined[~combined["ensemble_better"]]
+        if len(best_wins) > 0:
+            print("\nWhen best model wins:")
+            print(f" Mean ensemble error: {best_wins['ensemble_abs_err'].mean():.3f}")
+            print(f" Mean best model error: {best_wins['best_model_abs_err'].mean():.3f}")
+
+        return {
+            "ensemble_mae": ensemble_mae,
+            "best_model": best_model,
+            "best_model_mae": best_model_mae,
+            "ensemble_win_rate": n_better / n_total,
+        }
+
+
+if __name__ == "__main__":
+    # Example usage
+
+    print("\n" + "*" * 80)
+    print("Full ensemble analysis: XGB + PyTorch + ChemProp")
+    print("*" * 80)
+    sim = MetaModelSimulator(
+        ["logd-reg-xgb", "logd-reg-pytorch", "logd-reg-chemprop"],
+        id_column="molecule_name",
+    )
+    sim.report(details=True)  # Full analysis
+
+    print("\n" + "*" * 80)
+    print("Two model ensemble analysis: PyTorch + ChemProp")
+    print("*" * 80)
+    sim = MetaModelSimulator(
+        ["logd-reg-pytorch", "logd-reg-chemprop"],
+        id_column="molecule_name",
+    )
+    sim.report(details=True)  # Full analysis
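All of the strategies compared above reduce to weighted averages of the child-model predictions. As a standalone illustration of the two weighting schemes this module leans on (per-model inverse-MAE weights and per-row confidence weights), here is a short NumPy sketch with made-up numbers; it is independent of workbench and only mirrors the arithmetic in strategy_comparison():

```python
import numpy as np

# Toy setup: 4 rows, 3 child models (all values are made up for illustration)
preds = np.array([
    [2.1, 2.0, 2.3],
    [0.9, 1.1, 1.0],
    [3.2, 3.0, 3.4],
    [1.8, 2.2, 2.0],
])                                   # shape (n_rows, n_models)
conf = np.array([
    [0.9, 0.5, 0.7],
    [0.8, 0.6, 0.9],
    [0.4, 0.9, 0.6],
    [0.7, 0.7, 0.7],
])                                   # per-row confidence, same shape
mae = np.array([0.50, 0.40, 0.80])   # per-model MAE from held-out predictions

# Strategy 3: one global weight per model, proportional to 1/MAE
inv_mae_w = (1.0 / mae) / (1.0 / mae).sum()
inv_mae_pred = (preds * inv_mae_w).sum(axis=1)

# Strategy 2: per-row weights from confidence (epsilon guards against divide-by-zero)
row_w = conf / (conf.sum(axis=1, keepdims=True) + 1e-8)
conf_pred = (preds * row_w).sum(axis=1)

print("inverse-MAE weights:", np.round(inv_mae_w, 3))   # [0.348, 0.435, 0.217]
print("inverse-MAE ensemble:", np.round(inv_mae_pred, 3))
print("confidence-weighted ensemble:", np.round(conf_pred, 3))
```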
workbench/utils/metrics_utils.py
@@ -18,10 +18,32 @@ from sklearn.metrics import (
 log = logging.getLogger("workbench")
 
 
+def validate_proba_columns(predictions_df: pd.DataFrame, class_labels: List[str], guessing: bool = False) -> bool:
+    """Validate that probability columns match class labels.
+
+    Args:
+        predictions_df: DataFrame with prediction results
+        class_labels: List of class labels
+        guessing: Whether class labels were guessed from data
+
+    Returns:
+        True if validation passes
+
+    Raises:
+        ValueError: If probability columns don't match class labels
+    """
+    proba_columns = [col.replace("_proba", "") for col in predictions_df.columns if col.endswith("_proba")]
+
+    if sorted(class_labels) != sorted(proba_columns):
+        label_type = "GUESSED class_labels" if guessing else "class_labels"
+        raise ValueError(f"_proba columns {proba_columns} != {label_type} {class_labels}!")
+    return True
+
+
 def compute_classification_metrics(
     predictions_df: pd.DataFrame,
     target_col: str,
-    class_labels: List[str],
+    class_labels: Optional[List[str]] = None,
     prediction_col: str = "prediction",
 ) -> pd.DataFrame:
     """Compute classification metrics from a predictions DataFrame.
@@ -29,26 +51,62 @@ def compute_classification_metrics(
     Args:
         predictions_df: DataFrame with target and prediction columns
         target_col: Name of the target column
-        class_labels: List of class labels in order
+        class_labels: List of class labels in order (if None, inferred from target column)
         prediction_col: Name of the prediction column (default: "prediction")
 
     Returns:
         DataFrame with per-class metrics (precision, recall, f1, roc_auc, support)
-        plus a weighted 'all' row
+        plus a weighted 'all' row. Returns empty DataFrame if validation fails.
     """
-    y_true = predictions_df[target_col]
-    y_pred = predictions_df[prediction_col]
+    # Validate inputs
+    if predictions_df.empty:
+        log.warning("Empty DataFrame provided. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if prediction_col not in predictions_df.columns:
+        log.warning(f"Prediction column '{prediction_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if target_col not in predictions_df.columns:
+        log.warning(f"Target column '{target_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    # Handle NaN predictions
+    df = predictions_df.copy()
+    nan_pred = df[prediction_col].isnull().sum()
+    if nan_pred > 0:
+        log.warning(f"Dropping {nan_pred} rows with NaN predictions.")
+        df = df[~df[prediction_col].isnull()]
+
+    if df.empty:
+        log.warning("No valid rows after dropping NaNs. Returning empty metrics.")
+        return pd.DataFrame()
+
+    # Handle class labels
+    guessing = False
+    if class_labels is None:
+        log.warning("Class labels not provided. Inferring from target column.")
+        class_labels = df[target_col].unique().tolist()
+        guessing = True
+
+    # Validate probability columns if present
+    proba_cols = [col for col in df.columns if col.endswith("_proba")]
+    if proba_cols:
+        validate_proba_columns(df, class_labels, guessing=guessing)
+
+    y_true = df[target_col]
+    y_pred = df[prediction_col]
 
     # Precision, recall, f1, support per class
     prec, rec, f1, support = precision_recall_fscore_support(y_true, y_pred, labels=class_labels, zero_division=0)
 
     # ROC AUC per class (requires probability columns and sorted labels)
-    proba_cols = [f"{label}_proba" for label in class_labels]
-    if all(col in predictions_df.columns for col in proba_cols):
+    proba_col_names = [f"{label}_proba" for label in class_labels]
+    if all(col in df.columns for col in proba_col_names):
         # roc_auc_score requires labels to be sorted, so we sort and reorder results back
         sorted_labels = sorted(class_labels)
         sorted_proba_cols = [f"{label}_proba" for label in sorted_labels]
-        y_score_sorted = predictions_df[sorted_proba_cols].values
+        y_score_sorted = df[sorted_proba_cols].values
         roc_auc_sorted = roc_auc_score(y_true, y_score_sorted, labels=sorted_labels, multi_class="ovr", average=None)
         # Map back to original class_labels order
         label_to_auc = dict(zip(sorted_labels, roc_auc_sorted))
@@ -97,9 +155,35 @@ def compute_regression_metrics(
 
     Returns:
         DataFrame with regression metrics (rmse, mae, medae, r2, spearmanr, support)
+        Returns empty DataFrame if validation fails or no valid data.
     """
-    y_true = predictions_df[target_col].values
-    y_pred = predictions_df[prediction_col].values
+    # Validate inputs
+    if predictions_df.empty:
+        log.warning("Empty DataFrame provided. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if prediction_col not in predictions_df.columns:
+        log.warning(f"Prediction column '{prediction_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if target_col not in predictions_df.columns:
+        log.warning(f"Target column '{target_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    # Handle NaN values
+    df = predictions_df[[target_col, prediction_col]].copy()
+    nan_target = df[target_col].isnull().sum()
+    nan_pred = df[prediction_col].isnull().sum()
+    if nan_target > 0 or nan_pred > 0:
+        log.warning(f"NaNs found: {target_col}={nan_target}, {prediction_col}={nan_pred}. Dropping NaN rows.")
+        df = df.dropna()
+
+    if df.empty:
+        log.warning("No valid rows after dropping NaNs. Returning empty metrics.")
+        return pd.DataFrame()
+
+    y_true = df[target_col].values
+    y_pred = df[prediction_col].values
 
     return pd.DataFrame(
         [
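For context on how the hardened metrics functions are meant to be called, here is a minimal sketch assuming they are importable from the module shown in the file list (workbench/utils/metrics_utils.py); the toy DataFrames and label names are illustrative only:

```python
import numpy as np
import pandas as pd

from workbench.utils.metrics_utils import compute_classification_metrics, compute_regression_metrics

# Toy classification predictions with per-class probability columns ("<label>_proba")
clf_df = pd.DataFrame(
    {
        "solubility": ["low", "high", "low", "medium"],
        "prediction": ["low", "high", "medium", "medium"],
        "low_proba": [0.7, 0.1, 0.4, 0.2],
        "medium_proba": [0.2, 0.2, 0.5, 0.6],
        "high_proba": [0.1, 0.7, 0.1, 0.2],
    }
)
# class_labels is now optional; if omitted, labels are inferred from the target column
# and checked against any *_proba columns (a warning is logged for the guess).
print(compute_classification_metrics(clf_df, target_col="solubility"))

# Toy regression predictions with a NaN row, which is now dropped with a warning
reg_df = pd.DataFrame({"logd": [1.2, 0.8, np.nan], "prediction": [1.1, 0.9, 1.0]})
print(compute_regression_metrics(reg_df, target_col="logd"))
```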