workbench 0.8.177__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

Files changed (140)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +5 -5
  27. workbench/core/artifacts/df_store_core.py +114 -0
  28. workbench/core/artifacts/endpoint_core.py +319 -204
  29. workbench/core/artifacts/feature_set_core.py +249 -45
  30. workbench/core/artifacts/model_core.py +135 -82
  31. workbench/core/artifacts/parameter_store_core.py +98 -0
  32. workbench/core/cloud_platform/cloud_meta.py +0 -1
  33. workbench/core/pipelines/pipeline_executor.py +1 -1
  34. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  35. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  36. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  37. workbench/core/views/training_view.py +113 -42
  38. workbench/core/views/view.py +53 -3
  39. workbench/core/views/view_utils.py +4 -4
  40. workbench/model_script_utils/model_script_utils.py +339 -0
  41. workbench/model_script_utils/pytorch_utils.py +405 -0
  42. workbench/model_script_utils/uq_harness.py +277 -0
  43. workbench/model_scripts/chemprop/chemprop.template +774 -0
  44. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  45. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  46. workbench/model_scripts/chemprop/requirements.txt +3 -0
  47. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  48. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  49. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  50. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  51. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  52. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  53. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  54. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  55. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  56. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  57. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  58. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  59. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  60. workbench/model_scripts/meta_model/meta_model.template +209 -0
  61. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  62. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  63. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  64. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  65. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  66. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  67. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  68. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  69. workbench/model_scripts/script_generation.py +15 -12
  70. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  71. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  72. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  73. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  74. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  75. workbench/repl/workbench_shell.py +18 -14
  76. workbench/resources/open_source_api.key +1 -1
  77. workbench/scripts/endpoint_test.py +162 -0
  78. workbench/scripts/lambda_test.py +73 -0
  79. workbench/scripts/meta_model_sim.py +35 -0
  80. workbench/scripts/ml_pipeline_sqs.py +122 -6
  81. workbench/scripts/training_test.py +85 -0
  82. workbench/themes/dark/custom.css +59 -0
  83. workbench/themes/dark/plotly.json +5 -5
  84. workbench/themes/light/custom.css +153 -40
  85. workbench/themes/light/plotly.json +9 -9
  86. workbench/themes/midnight_blue/custom.css +59 -0
  87. workbench/utils/aws_utils.py +0 -1
  88. workbench/utils/chem_utils/fingerprints.py +87 -46
  89. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  90. workbench/utils/chem_utils/projections.py +16 -6
  91. workbench/utils/chem_utils/vis.py +25 -27
  92. workbench/utils/chemprop_utils.py +141 -0
  93. workbench/utils/config_manager.py +2 -6
  94. workbench/utils/endpoint_utils.py +5 -7
  95. workbench/utils/license_manager.py +2 -6
  96. workbench/utils/markdown_utils.py +57 -0
  97. workbench/utils/meta_model_simulator.py +499 -0
  98. workbench/utils/metrics_utils.py +256 -0
  99. workbench/utils/model_utils.py +260 -76
  100. workbench/utils/pipeline_utils.py +0 -1
  101. workbench/utils/plot_utils.py +159 -34
  102. workbench/utils/pytorch_utils.py +87 -0
  103. workbench/utils/shap_utils.py +11 -57
  104. workbench/utils/theme_manager.py +95 -30
  105. workbench/utils/xgboost_local_crossfold.py +267 -0
  106. workbench/utils/xgboost_model_utils.py +127 -220
  107. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  108. workbench/web_interface/components/model_plot.py +16 -2
  109. workbench/web_interface/components/plugin_unit_test.py +5 -3
  110. workbench/web_interface/components/plugins/ag_table.py +2 -4
  111. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  112. workbench/web_interface/components/plugins/model_details.py +48 -80
  113. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  114. workbench/web_interface/components/settings_menu.py +184 -0
  115. workbench/web_interface/page_views/main_page.py +0 -1
  116. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  117. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/RECORD +121 -106
  118. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  119. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  120. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  121. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  122. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  123. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  124. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  125. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -494
  126. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -494
  127. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  128. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  129. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  130. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  131. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  132. workbench/themes/quartz/base_css.url +0 -1
  133. workbench/themes/quartz/custom.css +0 -117
  134. workbench/themes/quartz/plotly.json +0 -642
  135. workbench/themes/quartz_dark/base_css.url +0 -1
  136. workbench/themes/quartz_dark/custom.css +0 -131
  137. workbench/themes/quartz_dark/plotly.json +0 -642
  138. workbench/utils/resource_utils.py +0 -39
  139. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  140. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
workbench/utils/meta_model_simulator.py
@@ -0,0 +1,499 @@
+ """MetaModelSimulator: Simulate and analyze ensemble model performance.
+
+ This class helps evaluate whether a meta model (ensemble) would outperform
+ individual child models by analyzing endpoint inference predictions.
+ """
+
+ import pandas as pd
+ import numpy as np
+ from scipy import stats
+ import logging
+
+ from workbench.api import Model
+
+ # Set up the log
+ log = logging.getLogger("workbench")
+
+
+ class MetaModelSimulator:
+     """Simulate meta model performance from child model predictions.
+
+     This class loads cross-validation predictions from multiple models and
+     analyzes how different ensemble strategies would perform compared to
+     the individual models.
+
+     Example:
+         ```python
+         from workbench.utils.meta_model_simulator import MetaModelSimulator
+
+         sim = MetaModelSimulator(["model-a", "model-b", "model-c"])
+         sim.report()  # Print full analysis
+         sim.strategy_comparison()  # Compare ensemble strategies
+         ```
+     """
+
+     def __init__(self, model_names: list[str], id_column: str = "id"):
+         """Initialize the simulator with a list of model names.
+
+         Args:
+             model_names: List of model names to include in the ensemble
+             id_column: Column name to use for row alignment (default: "id")
+         """
+         self.model_names = model_names
+         self.id_column = id_column
+         self._dfs: dict[str, pd.DataFrame] = {}
+         self._target_column: str | None = None
+         self._load_predictions()
+
+     def _load_predictions(self):
+         """Load endpoint inference predictions for all models."""
+         log.info(f"Loading predictions for {len(self.model_names)} models...")
+         for name in self.model_names:
+             model = Model(name)
+             if self._target_column is None:
+                 self._target_column = model.target()
+             df = model.get_inference_predictions("full_inference")
+             if df is None:
+                 raise ValueError(
+                     f"No full_inference predictions found for model '{name}'. Run endpoint inference first."
+                 )
+             df["residual"] = df["prediction"] - df[self._target_column]
+             df["abs_residual"] = df["residual"].abs()
+             self._dfs[name] = df
+
+         # Find common rows across all models
+         id_sets = {name: set(df[self.id_column]) for name, df in self._dfs.items()}
+         common_ids = set.intersection(*id_sets.values())
+         sizes = ", ".join(f"{name}: {len(ids)}" for name, ids in id_sets.items())
+         log.info(f"Row counts before alignment: {sizes} -> common: {len(common_ids)}")
+         self._dfs = {name: df[df[self.id_column].isin(common_ids)] for name, df in self._dfs.items()}
+
+         # Align DataFrames by sorting on id column
+         self._dfs = {name: df.sort_values(self.id_column).reset_index(drop=True) for name, df in self._dfs.items()}
+         log.info(f"Loaded {len(self._dfs)} models, {len(list(self._dfs.values())[0])} samples each")
+
+     def report(self, details: bool = False):
+         """Print a comprehensive analysis report.
+
+         Args:
+             details: Whether to include detailed sections (default: False)
+         """
+         self.model_performance()
+         self.residual_correlations()
+         self.strategy_comparison()
+         self.ensemble_failure_analysis()
+         if details:
+             self.confidence_analysis()
+             self.model_agreement()
+             self.ensemble_weights()
+             self.confidence_weight_distribution()
+
+     def confidence_analysis(self) -> dict[str, dict]:
+         """Analyze how confidence correlates with prediction accuracy.
+
+         Returns:
+             Dict mapping model name to confidence stats
+         """
+         print("=" * 60)
+         print("CONFIDENCE VS RESIDUALS ANALYSIS")
+         print("=" * 60)
+
+         results = {}
+         for name, df in self._dfs.items():
+             print(f"\n{name}:")
+             print("-" * 50)
+
+             conf = df["confidence"]
+             print(
+                 f" Confidence: mean={conf.mean():.3f}, std={conf.std():.3f}, "
+                 f"min={conf.min():.3f}, max={conf.max():.3f}"
+             )
+
+             corr_pearson, p_pearson = stats.pearsonr(df["confidence"], df["abs_residual"])
+             corr_spearman, p_spearman = stats.spearmanr(df["confidence"], df["abs_residual"])
+
+             print(" Confidence vs |residual|:")
+             print(f" Pearson r={corr_pearson:.3f} (p={p_pearson:.2e})")
+             print(f" Spearman r={corr_spearman:.3f} (p={p_spearman:.2e})")
+
+             df["conf_quartile"] = pd.qcut(df["confidence"], q=4, labels=["Q1 (low)", "Q2", "Q3", "Q4 (high)"])
+             quartile_stats = df.groupby("conf_quartile", observed=True)["abs_residual"].agg(
+                 ["mean", "median", "std", "count"]
+             )
+             print(" Error by confidence quartile:")
+             print(quartile_stats.to_string().replace("\n", "\n "))
+
+             results[name] = {
+                 "mean_conf": conf.mean(),
+                 "pearson_r": corr_pearson,
+                 "spearman_r": corr_spearman,
+             }
+
+         return results
+
+     def residual_correlations(self) -> pd.DataFrame:
+         """Analyze correlation of residuals between models.
+
+         Returns:
+             Correlation matrix DataFrame
+         """
+         print("\n" + "=" * 60)
+         print("RESIDUAL CORRELATIONS BETWEEN MODELS")
+         print("=" * 60)
+
+         residual_df = pd.DataFrame({name: df["residual"].values for name, df in self._dfs.items()})
+
+         corr_matrix = residual_df.corr()
+         print("\nPearson correlation of residuals:")
+         print(corr_matrix.to_string())
+
+         spearman_matrix = residual_df.corr(method="spearman")
+         print("\nSpearman correlation of residuals:")
+         print(spearman_matrix.to_string())
+
+         print("\nInterpretation:")
+         print(" - Low correlation = models make different errors (good for ensemble)")
+         print(" - High correlation = models make similar errors (less ensemble benefit)")
+
+         return corr_matrix
+
+     def model_agreement(self) -> dict:
+         """Analyze where models agree/disagree in predictions.
+
+         Returns:
+             Dict with agreement statistics
+         """
+         print("\n" + "=" * 60)
+         print("MODEL AGREEMENT ANALYSIS")
+         print("=" * 60)
+
+         pred_df = pd.DataFrame()
+         for name, df in self._dfs.items():
+             if pred_df.empty:
+                 pred_df[self.id_column] = df[self.id_column]
+                 pred_df["target"] = df[self._target_column]
+             pred_df[f"{name}_pred"] = df["prediction"].values
+
+         pred_cols = [f"{name}_pred" for name in self._dfs.keys()]
+         pred_df["pred_std"] = pred_df[pred_cols].std(axis=1)
+         pred_df["pred_mean"] = pred_df[pred_cols].mean(axis=1)
+         pred_df["ensemble_residual"] = pred_df["pred_mean"] - pred_df["target"]
+         pred_df["ensemble_abs_residual"] = pred_df["ensemble_residual"].abs()
+
+         print("\nPrediction std across models (disagreement):")
+         print(
+             f" mean={pred_df['pred_std'].mean():.3f}, median={pred_df['pred_std'].median():.3f}, "
+             f"max={pred_df['pred_std'].max():.3f}"
+         )
+
+         corr, p = stats.spearmanr(pred_df["pred_std"], pred_df["ensemble_abs_residual"])
+         print(f"\nDisagreement vs ensemble error: Spearman r={corr:.3f} (p={p:.2e})")
+
+         pred_df["disagree_quartile"] = pd.qcut(
+             pred_df["pred_std"], q=4, labels=["Q1 (agree)", "Q2", "Q3", "Q4 (disagree)"]
+         )
+         quartile_stats = pred_df.groupby("disagree_quartile", observed=True)["ensemble_abs_residual"].agg(
+             ["mean", "median", "count"]
+         )
+         print("\nEnsemble error by disagreement quartile:")
+         print(quartile_stats.to_string().replace("\n", "\n "))
+
+         return {
+             "mean_disagreement": pred_df["pred_std"].mean(),
+             "disagreement_error_corr": corr,
+         }
+
+     def model_performance(self) -> pd.DataFrame:
+         """Show per-model performance metrics.
+
+         Returns:
+             DataFrame with performance metrics for each model
+         """
+         print("\n" + "=" * 60)
+         print("PER-MODEL PERFORMANCE SUMMARY")
+         print("=" * 60)
+
+         metrics = []
+         for name, df in self._dfs.items():
+             residuals = df["residual"]
+             target = df[self._target_column]
+             pred = df["prediction"]
+
+             rmse = np.sqrt((residuals**2).mean())
+             mae = residuals.abs().mean()
+             r2 = 1 - (residuals**2).sum() / ((target - target.mean()) ** 2).sum()
+             spearman = stats.spearmanr(target, pred)[0]
+
+             metrics.append(
+                 {
+                     "model": name,
+                     "rmse": rmse,
+                     "mae": mae,
+                     "r2": r2,
+                     "spearman": spearman,
+                     "mean_conf": df["confidence"].mean(),
+                 }
+             )
+
+         metrics_df = pd.DataFrame(metrics).set_index("model")
+         print("\n" + metrics_df.to_string())
+         return metrics_df
+
+     def ensemble_weights(self) -> dict[str, float]:
+         """Calculate suggested ensemble weights based on inverse MAE.
+
+         Returns:
+             Dict mapping model name to suggested weight
+         """
+         print("\n" + "=" * 60)
+         print("SUGGESTED ENSEMBLE WEIGHTS")
+         print("=" * 60)
+
+         mae_scores = {name: df["abs_residual"].mean() for name, df in self._dfs.items()}
+
+         inv_mae = {name: 1.0 / mae for name, mae in mae_scores.items()}
+         total = sum(inv_mae.values())
+         weights = {name: w / total for name, w in inv_mae.items()}
+
+         print("\nWeights based on inverse MAE:")
+         for name, weight in weights.items():
+             print(f" {name}: {weight:.3f} (MAE={mae_scores[name]:.3f})")
+
+         print(f"\nEqual weights would be: {1.0/len(self._dfs):.3f} each")
+
+         return weights
+
+     def strategy_comparison(self) -> pd.DataFrame:
+         """Compare different ensemble strategies.
+
+         Returns:
+             DataFrame with MAE for each strategy, sorted best to worst
+         """
+         print("\n" + "=" * 60)
+         print("ENSEMBLE STRATEGY COMPARISON")
+         print("=" * 60)
+
+         combined = pd.DataFrame()
+         model_names = list(self._dfs.keys())
+
+         for name, df in self._dfs.items():
+             if combined.empty:
+                 combined[self.id_column] = df[self.id_column]
+                 combined["target"] = df[self._target_column]
+             combined[f"{name}_pred"] = df["prediction"].values
+             combined[f"{name}_conf"] = df["confidence"].values
+
+         pred_cols = [f"{name}_pred" for name in model_names]
+         conf_cols = [f"{name}_conf" for name in model_names]
+
+         results = []
+
+         # Strategy 1: Simple mean
+         combined["simple_mean"] = combined[pred_cols].mean(axis=1)
+         mae = (combined["simple_mean"] - combined["target"]).abs().mean()
+         results.append({"strategy": "Simple Mean", "mae": mae})
+
+         # Strategy 2: Confidence-weighted
+         conf_arr = combined[conf_cols].values
+         pred_arr = combined[pred_cols].values
+         conf_sum = conf_arr.sum(axis=1, keepdims=True) + 1e-8
+         weights = conf_arr / conf_sum
+         combined["conf_weighted"] = (pred_arr * weights).sum(axis=1)
+         mae = (combined["conf_weighted"] - combined["target"]).abs().mean()
+         results.append({"strategy": "Confidence-Weighted", "mae": mae})
+
+         # Strategy 3: Inverse-MAE weighted
+         mae_scores = {name: self._dfs[name]["abs_residual"].mean() for name in model_names}
+         inv_mae_weights = np.array([1.0 / mae_scores[name] for name in model_names])
+         inv_mae_weights = inv_mae_weights / inv_mae_weights.sum()
+         combined["inv_mae_weighted"] = (pred_arr * inv_mae_weights).sum(axis=1)
+         mae = (combined["inv_mae_weighted"] - combined["target"]).abs().mean()
+         results.append({"strategy": "Inverse-MAE Weighted", "mae": mae})
+
+         # Strategy 4: Best model only
+         best_model = min(mae_scores, key=mae_scores.get)
+         combined["best_only"] = combined[f"{best_model}_pred"]
+         mae = (combined["best_only"] - combined["target"]).abs().mean()
+         results.append({"strategy": f"Best Model Only ({best_model})", "mae": mae})
+
+         # Strategy 5: Scaled confidence-weighted (confidence * model_weights)
+         scaled_conf = conf_arr * inv_mae_weights
+         scaled_conf_sum = scaled_conf.sum(axis=1, keepdims=True) + 1e-8
+         scaled_weights = scaled_conf / scaled_conf_sum
+         combined["scaled_conf_weighted"] = (pred_arr * scaled_weights).sum(axis=1)
+         mae = (combined["scaled_conf_weighted"] - combined["target"]).abs().mean()
+         results.append({"strategy": "Scaled Conf-Weighted", "mae": mae})
+
+         # Strategy 6: Drop worst model (use simple mean of remaining, or raw prediction if only 1 left)
+         worst_model = max(mae_scores, key=mae_scores.get)
+         remaining = [n for n in model_names if n != worst_model]
+         remaining_pred_cols = [f"{n}_pred" for n in remaining]
+         if len(remaining) == 1:
+             # Single model remaining - use raw prediction (same as "Best Model Only")
+             combined["drop_worst"] = combined[remaining_pred_cols[0]]
+         else:
+             # Multiple models remaining - use simple mean
+             combined["drop_worst"] = combined[remaining_pred_cols].mean(axis=1)
+         mae = (combined["drop_worst"] - combined["target"]).abs().mean()
+         results.append({"strategy": f"Drop Worst ({worst_model})", "mae": mae})
+
+         results_df = pd.DataFrame(results).sort_values("mae")
+         print("\n" + results_df.to_string(index=False))
+
+         print("\nIndividual model MAEs for reference:")
+         for name, mae in sorted(mae_scores.items(), key=lambda x: x[1]):
+             print(f" {name}: {mae:.4f}")
+
+         return results_df
+
+     def confidence_weight_distribution(self) -> pd.DataFrame:
+         """Analyze how confidence weights are distributed across models.
+
+         Returns:
+             DataFrame with weight distribution statistics
+         """
+         print("\n" + "=" * 60)
+         print("CONFIDENCE WEIGHT DISTRIBUTION")
+         print("=" * 60)
+
+         model_names = list(self._dfs.keys())
+         conf_df = pd.DataFrame({name: df["confidence"].values for name, df in self._dfs.items()})
+
+         conf_sum = conf_df.sum(axis=1)
+         weight_df = conf_df.div(conf_sum, axis=0)
+
+         print("\nMean weight per model (from confidence-weighting):")
+         for name in model_names:
+             print(f" {name}: {weight_df[name].mean():.3f}")
+
+         print("\nWeight distribution stats:")
+         print(weight_df.describe().to_string())
+
+         print("\nHow often each model has highest weight:")
+         winner = weight_df.idxmax(axis=1)
+         winner_counts = winner.value_counts()
+         for name in model_names:
+             count = winner_counts.get(name, 0)
+             print(f" {name}: {count} ({100*count/len(weight_df):.1f}%)")
+
+         return weight_df
+
+     def ensemble_failure_analysis(self) -> dict:
+         """Compare best ensemble strategy vs best individual model.
+
+         Returns:
+             Dict with comparison statistics
+         """
+         print("\n" + "=" * 60)
+         print("BEST ENSEMBLE VS BEST MODEL COMPARISON")
+         print("=" * 60)
+
+         model_names = list(self._dfs.keys())
+
+         combined = pd.DataFrame()
+         for name, df in self._dfs.items():
+             if combined.empty:
+                 combined[self.id_column] = df[self.id_column]
+                 combined["target"] = df[self._target_column]
+             combined[f"{name}_pred"] = df["prediction"].values
+             combined[f"{name}_conf"] = df["confidence"].values
+             combined[f"{name}_abs_err"] = df["abs_residual"].values
+
+         pred_cols = [f"{name}_pred" for name in model_names]
+         conf_cols = [f"{name}_conf" for name in model_names]
+         pred_arr = combined[pred_cols].values
+         conf_arr = combined[conf_cols].values
+
+         mae_scores = {name: self._dfs[name]["abs_residual"].mean() for name in model_names}
+         inv_mae_weights = np.array([1.0 / mae_scores[name] for name in model_names])
+         inv_mae_weights = inv_mae_weights / inv_mae_weights.sum()
+
+         # Compute all ensemble strategies (true ensembles that combine multiple models)
+         ensemble_strategies = {}
+         ensemble_strategies["Simple Mean"] = combined[pred_cols].mean(axis=1)
+         conf_sum = conf_arr.sum(axis=1, keepdims=True) + 1e-8
+         ensemble_strategies["Confidence-Weighted"] = (pred_arr * (conf_arr / conf_sum)).sum(axis=1)
+         ensemble_strategies["Inverse-MAE Weighted"] = (pred_arr * inv_mae_weights).sum(axis=1)
+         scaled_conf = conf_arr * inv_mae_weights
+         scaled_conf_sum = scaled_conf.sum(axis=1, keepdims=True) + 1e-8
+         ensemble_strategies["Scaled Conf-Weighted"] = (pred_arr * (scaled_conf / scaled_conf_sum)).sum(axis=1)
+         worst_model = max(mae_scores, key=mae_scores.get)
+         remaining = [n for n in model_names if n != worst_model]
+         remaining_cols = [f"{n}_pred" for n in remaining]
+         # Only add Drop Worst if it still combines multiple models
+         if len(remaining) > 1:
+             ensemble_strategies[f"Drop Worst ({worst_model})"] = combined[remaining_cols].mean(axis=1)
+
+         # Find best individual model
+         best_model = min(mae_scores, key=mae_scores.get)
+         combined["best_model_abs_err"] = combined[f"{best_model}_abs_err"]
+         best_model_mae = mae_scores[best_model]
+
+         # Find best true ensemble strategy
+         strategy_maes = {name: (preds - combined["target"]).abs().mean() for name, preds in ensemble_strategies.items()}
+         best_strategy = min(strategy_maes, key=strategy_maes.get)
+         combined["ensemble_pred"] = ensemble_strategies[best_strategy]
+         combined["ensemble_abs_err"] = (combined["ensemble_pred"] - combined["target"]).abs()
+         ensemble_mae = strategy_maes[best_strategy]
+
+         # Compare
+         combined["ensemble_better"] = combined["ensemble_abs_err"] < combined["best_model_abs_err"]
+         n_better = combined["ensemble_better"].sum()
+         n_total = len(combined)
+
+         print(f"\nBest individual model: {best_model} (MAE={best_model_mae:.4f})")
+         print(f"Best ensemble strategy: {best_strategy} (MAE={ensemble_mae:.4f})")
+         if ensemble_mae < best_model_mae:
+             improvement = (best_model_mae - ensemble_mae) / best_model_mae * 100
+             print(f"Ensemble improves over best model by {improvement:.1f}%")
+         else:
+             degradation = (ensemble_mae - best_model_mae) / best_model_mae * 100
+             print(f"No ensemble benefit: best single model outperforms all ensemble strategies by {degradation:.1f}%")
+
+         print("\nPer-row comparison:")
+         print(f" Ensemble wins: {n_better}/{n_total} ({100*n_better/n_total:.1f}%)")
+         print(f" Best model wins: {n_total - n_better}/{n_total} ({100*(n_total - n_better)/n_total:.1f}%)")
+
+         # When ensemble wins
+         ensemble_wins = combined[combined["ensemble_better"]]
+         if len(ensemble_wins) > 0:
+             print("\nWhen ensemble wins:")
+             print(f" Mean ensemble error: {ensemble_wins['ensemble_abs_err'].mean():.3f}")
+             print(f" Mean best model error: {ensemble_wins['best_model_abs_err'].mean():.3f}")
+
+         # When best model wins
+         best_wins = combined[~combined["ensemble_better"]]
+         if len(best_wins) > 0:
+             print("\nWhen best model wins:")
+             print(f" Mean ensemble error: {best_wins['ensemble_abs_err'].mean():.3f}")
+             print(f" Mean best model error: {best_wins['best_model_abs_err'].mean():.3f}")
+
+         return {
+             "ensemble_mae": ensemble_mae,
+             "best_strategy": best_strategy,
+             "best_model": best_model,
+             "best_model_mae": best_model_mae,
+             "ensemble_win_rate": n_better / n_total,
+         }
+
+
+ if __name__ == "__main__":
+     # Example usage
+
+     print("\n" + "*" * 80)
+     print("Full ensemble analysis: XGB + PyTorch + ChemProp")
+     print("*" * 80)
+     sim = MetaModelSimulator(
+         ["logd-reg-xgb", "logd-reg-pytorch", "logd-reg-chemprop"],
+         id_column="molecule_name",
+     )
+     sim.report(details=True)  # Full analysis
+
+     print("\n" + "*" * 80)
+     print("Two model ensemble analysis: PyTorch + ChemProp")
+     print("*" * 80)
+     sim = MetaModelSimulator(
+         ["logd-reg-pytorch", "logd-reg-chemprop"],
+         id_column="molecule_name",
+     )
+     sim.report(details=True)  # Full analysis
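
The weighting arithmetic that `strategy_comparison()` and `ensemble_failure_analysis()` evaluate is plain NumPy, so it can be sanity-checked without any Workbench models or endpoints. Below is a minimal standalone sketch of three of the strategies; the predictions, confidences, and targets are made-up illustrative numbers, not data from the package:

```python
import numpy as np

# Toy data: 4 samples (rows) x 3 hypothetical child models (columns)
preds = np.array(
    [
        [2.1, 1.9, 2.4],
        [3.0, 3.2, 2.7],
        [1.2, 1.0, 1.5],
        [4.8, 5.1, 4.4],
    ]
)
confs = np.array(
    [
        [0.9, 0.6, 0.3],
        [0.8, 0.7, 0.4],
        [0.7, 0.9, 0.2],
        [0.6, 0.8, 0.5],
    ]
)
target = np.array([2.0, 3.1, 1.1, 5.0])

# Simple mean: every model gets equal weight on every row
simple_mean = preds.mean(axis=1)

# Confidence-weighted: per-row weights from the models' confidences
# (same 1e-8 guard against a zero denominator as the class uses)
conf_sum = confs.sum(axis=1, keepdims=True) + 1e-8
conf_weighted = (preds * (confs / conf_sum)).sum(axis=1)

# Inverse-MAE weighted: one global weight per model, derived from its own MAE
mae_per_model = np.abs(preds - target[:, None]).mean(axis=0)
inv = 1.0 / mae_per_model
inv_mae_weighted = preds @ (inv / inv.sum())

for name, p in [
    ("Simple Mean", simple_mean),
    ("Confidence-Weighted", conf_weighted),
    ("Inverse-MAE Weighted", inv_mae_weighted),
]:
    print(f"{name}: MAE={np.abs(p - target).mean():.3f}")
```

In the class itself, the same arithmetic runs over `full_inference` DataFrames that have been row-aligned on the id column, with the `confidence` values supplied by each child model's endpoint predictions.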