ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,649 @@
+ """Importance data extraction for visualization layer.
+
+ Extracts comprehensive visualization data from feature importance analysis results.
+ """
+
+ from __future__ import annotations
+
+ from datetime import datetime
+ from typing import Any
+
+ import numpy as np
+
+ from .types import (
+     FeatureDetailData,
+     ImportanceVizData,
+     LLMContextData,
+     MethodComparisonData,
+     MethodImportanceData,
+     UncertaintyData,
+ )
+ from .validation import _validate_lengths_match
+
+
+ def extract_importance_viz_data(
+     importance_results: dict[str, Any],
+     include_uncertainty: bool = True,
+     include_distributions: bool = True,
+     include_per_feature: bool = True,
+     include_llm_context: bool = True,
+ ) -> ImportanceVizData:
+     """Extract comprehensive visualization data from importance analysis results.
+
+     This function transforms raw importance analysis results into a structured
+     format optimized for rich interactive visualization. It exposes all details
+     including per-method breakdowns, uncertainty estimates, per-feature views,
+     and auto-generated narratives.
+
+     Parameters
+     ----------
+     importance_results : dict
+         Results from analyze_ml_importance() containing:
+         - 'consensus_ranking': list of features in importance order
+         - 'method_results': dict of {method_name: method_result}
+         - 'method_agreement': dict of pairwise correlations
+         - 'interpretation': analysis interpretation
+         - 'warnings': list of warning messages
+     include_uncertainty : bool, default=True
+         Whether to compute and include uncertainty metrics (stability, CI).
+         Requires bootstrap or repeated analysis data.
+     include_distributions : bool, default=True
+         Whether to include full distributions (per-repeat values for PFI).
+         Useful for detailed uncertainty visualization.
+     include_per_feature : bool, default=True
+         Whether to create per-feature aggregated views.
+         Enables feature drill-down dashboards.
+     include_llm_context : bool, default=True
+         Whether to generate auto-narratives for LLM consumption.
+
+     Returns
+     -------
+     ImportanceVizData
+         Complete structured data package with all visualization details.
+         See ImportanceVizData TypedDict for full structure.
+
+     Examples
+     --------
+     >>> from ml4t.diagnostic.evaluation import analyze_ml_importance
+     >>> from ml4t.diagnostic.visualization.data_extraction import extract_importance_viz_data
+     >>>
+     >>> # Analyze importance
+     >>> results = analyze_ml_importance(model, X, y, methods=['mdi', 'pfi'])
+     >>>
+     >>> # Extract visualization data
+     >>> viz_data = extract_importance_viz_data(results)
+     >>>
+     >>> # Access different views
+     >>> print(viz_data['summary']['n_features'])  # High-level summary
+     >>> print(viz_data['per_method']['mdi']['ranking'][:5])  # Top 5 by MDI
+     >>> print(viz_data['per_feature']['momentum']['method_ranks'])  # Feature detail
+     >>> print(viz_data['llm_context']['key_insights'])  # Auto-generated insights
+
+     Notes
+     -----
+     - The extracted data is designed for both human visualization and LLM interpretation
+     - Per-feature views enable drill-down dashboards
+     - Uncertainty metrics enable confidence visualization
+     - Auto-narratives prepare for future LLM integration
+     """
+     # Extract basic info
+     consensus_ranking = importance_results.get("consensus_ranking", [])
+     method_results = importance_results.get("method_results", {})
+     method_agreement = importance_results.get("method_agreement", {})
+     interpretation = importance_results.get("interpretation", {})
+     warnings = importance_results.get("warnings", [])
+     methods_run = importance_results.get("methods_run", list(method_results.keys()))
+
+     n_features = len(consensus_ranking)
+     n_methods = len(methods_run)
+
+     # Build summary
+     summary = _build_summary(
+         consensus_ranking, method_agreement, methods_run, n_features, n_methods, warnings
+     )
+
+     # Extract per-method details
+     per_method = _extract_per_method_data(
+         method_results, include_distributions=include_distributions
+     )
+
+     # Build per-feature aggregations
+     per_feature = {}
+     if include_per_feature:
+         per_feature = _build_per_feature_data(
+             consensus_ranking, method_results, method_agreement, methods_run
+         )
+
+     # Compute uncertainty metrics
+     uncertainty_data: UncertaintyData = {
+         "method_stability": {},
+         "rank_stability": {},
+         "confidence_intervals": {},
+         "coefficient_of_variation": {},
+     }
+     if include_uncertainty:
+         uncertainty_data = _compute_uncertainty_metrics(method_results, consensus_ranking)
+
+     # Build method comparison data
+     method_comparison = _build_method_comparison(method_agreement, method_results, methods_run)
+
+     # Build metadata
+     metadata = {
+         "n_features": n_features,
+         "n_methods": n_methods,
+         "methods_run": methods_run,
+         "analysis_timestamp": datetime.now().isoformat(),
+         "warnings": warnings,
+         "interpretation": interpretation,
+     }
+
+     # Generate LLM context
+     llm_context: LLMContextData = {
+         "summary_narrative": "",
+         "key_insights": [],
+         "recommendations": [],
+         "caveats": [],
+         "analysis_quality": "medium",
+     }
+     if include_llm_context:
+         llm_context = _generate_llm_context(
+             summary, per_method, method_comparison, uncertainty_data, warnings
+         )
+
+     return ImportanceVizData(
+         summary=summary,
+         per_method=per_method,
+         per_feature=per_feature,
+         uncertainty=uncertainty_data,
+         method_comparison=method_comparison,
+         metadata=metadata,
+         llm_context=llm_context,
+     )
+
+
+ # =============================================================================
+ # Helper Functions
+ # =============================================================================
+
+
+ def _build_summary(
+     consensus_ranking: list[str],
+     method_agreement: dict[str, float],
+     methods_run: list[str],
+     n_features: int,
+     n_methods: int,
+     warnings: list[str],
+ ) -> dict[str, Any]:
+     """Build high-level summary statistics."""
+     # Compute average agreement
+     if method_agreement:
+         avg_agreement = float(np.mean(list(method_agreement.values())))
+     else:
+         avg_agreement = 1.0 if n_methods == 1 else 0.0
+
+     # Determine agreement level
+     if avg_agreement > 0.8:
+         agreement_level = "high"
+     elif avg_agreement > 0.6:
+         agreement_level = "medium"
+     else:
+         agreement_level = "low"
+
+     return {
+         "n_features": n_features,
+         "n_methods": n_methods,
+         "methods_run": methods_run,
+         "top_feature": consensus_ranking[0] if consensus_ranking else None,
+         "consensus_ranking": consensus_ranking,
+         "avg_method_agreement": avg_agreement,
+         "agreement_level": agreement_level,
+         "has_warnings": len(warnings) > 0,
+         "warnings_count": len(warnings),
+     }
+
+
+ def _extract_per_method_data(
+     method_results: dict[str, dict], include_distributions: bool = True
+ ) -> dict[str, MethodImportanceData]:
+     """Extract detailed per-method importance data with normalized values."""
+     per_method: dict[str, MethodImportanceData] = {}
+
+     for method_name, method_result in method_results.items():
+         feature_names = method_result.get("feature_names", [])
+
+         # Get importances based on method type
+         if method_name == "pfi":
+             importances_mean = method_result.get("importances_mean", [])
+             importances_std = method_result.get("importances_std", [])
+             importances_raw = method_result.get("importances_raw", [])
+
+             # Validate length consistency for PFI data
+             _validate_lengths_match(
+                 ("feature_names", feature_names),
+                 ("importances_mean", importances_mean),
+                 ("importances_std", importances_std),
+             )
+
+             # Normalize importances to sum to 1.0 (percentage basis)
+             total = sum(importances_mean)
+             if total > 0:
+                 importances_mean = [imp / total for imp in importances_mean]
+                 importances_std = [std / total for std in importances_std]
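+             # Illustrative numbers: means of [2.0, 1.0, 1.0] normalize to
+             # [0.5, 0.25, 0.25]; the stds are rescaled by the same total.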
+
+             # Convert to dicts (strict=True since we validated above)
+             importances_dict = dict(zip(feature_names, importances_mean, strict=True))
+             std_dict = dict(zip(feature_names, importances_std, strict=True))
+
+             # Compute confidence intervals (95%, assuming normality)
+             # Use the standard error (std / sqrt(n_repeats)) for the CI of the mean
+             n_repeats = method_result.get("n_repeats", 1)
+             sqrt_n = np.sqrt(max(n_repeats, 1))
+             ci_dict = {}
+             for feat, mean, std in zip(
+                 feature_names, importances_mean, importances_std, strict=True
+             ):
+                 se = std / sqrt_n  # Standard error of the mean
+                 ci_dict[feat] = (float(mean - 1.96 * se), float(mean + 1.96 * se))
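+             # Worked example with made-up numbers: mean=0.10, std=0.02,
+             # n_repeats=16 -> se = 0.02 / 4 = 0.005, so the 95% CI is
+             # (0.10 - 1.96 * 0.005, 0.10 + 1.96 * 0.005) = (0.0902, 0.1098).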
+
+             # Get raw values per repeat
+             raw_list = None
+             if include_distributions and importances_raw is not None and len(importances_raw) > 0:
+                 raw_list = []
+                 for repeat_values in importances_raw:
+                     raw_list.append(dict(zip(feature_names, repeat_values, strict=False)))
+
+             per_method[method_name] = MethodImportanceData(
+                 importances=importances_dict,
+                 ranking=sorted(feature_names, key=lambda f: importances_dict[f], reverse=True),
+                 std=std_dict,
+                 confidence_intervals=ci_dict,
+                 raw_values=raw_list,
+                 metadata={
+                     "n_repeats": method_result.get("n_repeats", 1),
+                     "scoring": method_result.get("scoring", "unknown"),
+                 },
+             )
+
+         else:
+             # MDI, MDA, SHAP - single value per feature
+             importances = method_result.get("importances", [])
+
+             # Validate length consistency for non-PFI methods
+             _validate_lengths_match(
+                 ("feature_names", feature_names),
+                 ("importances", importances),
+             )
+
+             # Normalize importances to sum to 1.0 (percentage basis)
+             # MDI is already normalized, but SHAP and others may not be
+             total = sum(importances)
+             if total > 0 and abs(total - 1.0) > 0.01:  # Not already normalized
+                 importances = [imp / total for imp in importances]
+
+             importances_dict = dict(zip(feature_names, importances, strict=True))
+
+             per_method[method_name] = MethodImportanceData(
+                 importances=importances_dict,
+                 ranking=sorted(feature_names, key=lambda f: importances_dict[f], reverse=True),
+                 std=None,
+                 confidence_intervals=None,
+                 raw_values=None,
+                 metadata={},
+             )
+
+     return per_method
+
+
+ def _build_per_feature_data(
+     consensus_ranking: list[str],
+     method_results: dict[str, dict],
+     _method_agreement: dict[str, float],
+     methods_run: list[str],
+ ) -> dict[str, FeatureDetailData]:
+     """Build per-feature aggregated views for drill-down."""
+     per_feature: dict[str, FeatureDetailData] = {}
+
+     # Create importance and ranking dicts per method
+     method_importances: dict[str, dict[str, float]] = {}
+     method_rankings: dict[str, list[str]] = {}
+
+     for method_name, method_result in method_results.items():
+         feature_names = method_result.get("feature_names", [])
+
+         if method_name == "pfi":
+             importances = method_result.get("importances_mean", [])
+         else:
+             importances = method_result.get("importances", [])
+
+         method_importances[method_name] = dict(zip(feature_names, importances, strict=False))
+         method_rankings[method_name] = sorted(
+             feature_names, key=lambda f: method_importances[method_name].get(f, 0), reverse=True
+         )
+
+     # Build per-feature data
+     for consensus_rank, feature_name in enumerate(consensus_ranking, start=1):
+         method_ranks = {}
+         method_scores = {}
+         method_stds = {}
+
+         for method_name in methods_run:
+             # Get rank in this method (with safe index lookup)
+             try:
+                 ranking_list = method_rankings.get(method_name, [])
+                 method_ranks[method_name] = ranking_list.index(feature_name) + 1
+             except ValueError:
+                 # Feature not found in ranking - assign last rank
+                 method_ranks[method_name] = len(method_rankings.get(method_name, [])) + 1
+
+             # Get score in this method
+             method_scores[method_name] = method_importances.get(method_name, {}).get(
+                 feature_name, 0.0
+             )
+
+             # Get std if available (PFI) - with bounds checking
+             if method_name == "pfi":
+                 pfi_result = method_results.get("pfi", {})
+                 feature_names_pfi = pfi_result.get("feature_names", [])
+                 if feature_name in feature_names_pfi:
+                     idx = feature_names_pfi.index(feature_name)
+                     importances_std = pfi_result.get("importances_std", [])
+                     # Check bounds before accessing
+                     if idx < len(importances_std):
+                         method_stds[method_name] = importances_std[idx]
+
+         # Determine agreement level for this feature
+         rank_variance = 0.0  # Initialize before conditional to avoid undefined
+         if len(method_ranks) > 1:
+             rank_variance = float(np.var(list(method_ranks.values())))
+             if rank_variance < 2:
+                 agreement_level = "high"
+             elif rank_variance < 10:
+                 agreement_level = "medium"
+             else:
+                 agreement_level = "low"
+         else:
+             agreement_level = "n/a"
+
+         # Compute stability score (inverse of rank variance, normalized)
+         stability_score = 1.0 / (1.0 + rank_variance) if len(method_ranks) > 1 else 1.0
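+         # Illustrative: ranks {mdi: 2, pfi: 3, shap: 4} have population variance
+         # 2/3 (np.var), so stability_score = 1 / (1 + 2/3) = 0.6.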
+
+         # Generate interpretation
+         interpretation = _generate_feature_interpretation(
+             feature_name, consensus_rank, method_ranks, agreement_level
+         )
+
+         per_feature[feature_name] = FeatureDetailData(
+             consensus_rank=consensus_rank,
+             consensus_score=(
+                 float(np.mean(list(method_scores.values()))) if method_scores else 0.0
+             ),
+             method_ranks=method_ranks,
+             method_scores=method_scores,
+             method_stds=method_stds,
+             agreement_level=agreement_level,
+             stability_score=float(stability_score),
+             interpretation=interpretation,
+         )
+
+     return per_feature
+
+
+ def _compute_uncertainty_metrics(
+     method_results: dict[str, dict], consensus_ranking: list[str]
+ ) -> UncertaintyData:
+     """Compute uncertainty and stability metrics."""
+     # For now, focus on PFI which has repeat data
+     pfi_result = method_results.get("pfi", {})
+     has_pfi = bool(pfi_result)
+
+     method_stability: dict[str, float] = {}
+     confidence_intervals: dict[str, dict[str, tuple[float, float]]] = {}
+     coefficient_of_variation: dict[str, dict[str, float]] = {}
+     rank_stability: dict[str, list[int]] = {}
+
+     if has_pfi:
+         feature_names = pfi_result.get("feature_names", [])
+         importances_mean = pfi_result.get("importances_mean", [])
+         importances_std = pfi_result.get("importances_std", [])
+
+         # Validate length consistency
+         _validate_lengths_match(
+             ("feature_names", feature_names),
+             ("importances_mean", importances_mean),
+             ("importances_std", importances_std),
+         )
+
+         # Method stability: average CV across features
+         cvs = []
+         cv_dict = {}
+         for feat, mean, std in zip(feature_names, importances_mean, importances_std, strict=True):
+             if mean != 0:
+                 cv = std / abs(mean)
+                 cvs.append(cv)
+                 cv_dict[feat] = float(cv)
+             else:
+                 cv_dict[feat] = 0.0
+
+         method_stability["pfi"] = float(1.0 - np.mean(cvs)) if cvs else 1.0
+         coefficient_of_variation["pfi"] = cv_dict
+
+         # Confidence intervals (use standard error for CI of the mean)
+         n_repeats = pfi_result.get("n_repeats", 1)
+         sqrt_n = np.sqrt(max(n_repeats, 1))
+         ci_dict = {}
+         for feat, mean, std in zip(feature_names, importances_mean, importances_std, strict=True):
+             se = std / sqrt_n  # Standard error of the mean
+             ci_dict[feat] = (float(mean - 1.96 * se), float(mean + 1.96 * se))
+         confidence_intervals["pfi"] = ci_dict
+
+     # Rank stability (if we had bootstrap data, we'd track rank distributions)
+     # For now, mark as placeholder
+     for feat in consensus_ranking:
+         rank_stability[feat] = []  # Placeholder for bootstrap ranks
+
+     return UncertaintyData(
+         method_stability=method_stability,
+         rank_stability=rank_stability,
+         confidence_intervals=confidence_intervals,
+         coefficient_of_variation=coefficient_of_variation,
+     )
+
+
+ def _build_method_comparison(
+     method_agreement: dict[str, float], method_results: dict[str, dict], methods_run: list[str]
+ ) -> MethodComparisonData:
+     """Build method comparison metrics."""
+     # Build correlation matrix
+     correlation_matrix = []
+
+     for method1 in methods_run:
+         row = []
+         for method2 in methods_run:
+             if method1 == method2:
+                 row.append(1.0)
+             else:
+                 # Find correlation in method_agreement dict
+                 # (keys follow the "{method1}_vs_{method2}" convention, in either order)
+                 key1 = f"{method1}_vs_{method2}"
+                 key2 = f"{method2}_vs_{method1}"
+                 corr = method_agreement.get(key1, method_agreement.get(key2, 0.0))
+                 row.append(float(corr))
+         correlation_matrix.append(row)
+
+     # Compute rank differences
+     method_rankings: dict[str, list[str]] = {}
+     for method_name, method_result in method_results.items():
+         feature_names = method_result.get("feature_names", [])
+         if method_name == "pfi":
+             importances = method_result.get("importances_mean", [])
+         else:
+             importances = method_result.get("importances", [])
+
+         # Validate length consistency
+         _validate_lengths_match(
+             ("feature_names", feature_names),
+             ("importances", importances),
+         )
+
+         importances_dict = dict(zip(feature_names, importances, strict=True))
+         ranking = sorted(feature_names, key=lambda f: importances_dict[f], reverse=True)
+         method_rankings[method_name] = ranking
+
+     rank_differences: dict[tuple[str, str], dict[str, int]] = {}
+     for i, method1 in enumerate(methods_run):
+         for method2 in methods_run[i + 1 :]:
+             diff_dict = {}
+             ranking1 = method_rankings.get(method1, [])
+             ranking2 = method_rankings.get(method2, [])
+
+             for feat in ranking1:
+                 if feat in ranking2:
+                     rank1 = ranking1.index(feat) + 1
+                     rank2 = ranking2.index(feat) + 1
+                     diff_dict[feat] = abs(rank1 - rank2)
+
+             rank_differences[(method1, method2)] = diff_dict
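+
+     # Illustrative: if mdi ranks a feature #1 and pfi ranks it #4, its entry
+     # in rank_differences[("mdi", "pfi")] is abs(1 - 4) = 3.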
+
+     return MethodComparisonData(
+         correlation_matrix=correlation_matrix,
+         correlation_methods=methods_run,
+         rank_differences=rank_differences,
+         agreement_summary=method_agreement,
+     )
+
+
+ def _generate_feature_interpretation(
+     feature_name: str, consensus_rank: int, method_ranks: dict[str, int], agreement_level: str
+ ) -> str:
+     """Generate auto-interpretation for a single feature."""
+     if agreement_level == "high":
+         return (
+             f"'{feature_name}' ranks #{consensus_rank} with strong consensus across methods. "
+             f"All methods agree on its importance level."
+         )
+     elif agreement_level == "medium":
+         rank_str = ", ".join([f"{m}=#{r}" for m, r in method_ranks.items()])
+         return (
+             f"'{feature_name}' ranks #{consensus_rank} overall but shows moderate variation "
+             f"across methods ({rank_str}). Consider investigating method-specific biases."
+         )
+     else:
+         rank_str = ", ".join([f"{m}=#{r}" for m, r in method_ranks.items()])
+         return (
+             f"'{feature_name}' ranks #{consensus_rank} but shows significant disagreement "
+             f"across methods ({rank_str}). This may indicate interaction effects or "
+             f"method-specific artifacts. Further investigation recommended."
+         )
+
+
+ def _generate_llm_context(
+     summary: dict[str, Any],
+     _per_method: dict[str, MethodImportanceData],
+     _method_comparison: MethodComparisonData,
+     uncertainty: UncertaintyData,
+     warnings: list[str],
+ ) -> LLMContextData:
+     """Generate auto-narratives and insights for LLM consumption."""
+     n_features = summary["n_features"]
+     n_methods = summary["n_methods"]
+     methods_run = summary["methods_run"]
+     top_feature = summary["top_feature"]
+     avg_agreement = summary["avg_method_agreement"]
+     agreement_level = summary["agreement_level"]
+
+     # Build summary narrative
+     summary_narrative = (
+         f"This feature importance analysis examined {n_features} features using "
+         f"{n_methods} method{'s' if n_methods > 1 else ''} ({', '.join(methods_run)}). "
+     )
+
+     if top_feature:
+         summary_narrative += (
+             f"The consensus ranking identified '{top_feature}' as the most important feature. "
+         )
+
+     if n_methods > 1:
+         summary_narrative += (
+             f"Method agreement is {agreement_level} (average correlation: {avg_agreement:.2f}). "
+         )
+
+     # Generate key insights
+     key_insights = []
+
+     # Insight 1: Top features
+     key_insights.append(
+         f"Top consensus feature: '{top_feature}'"
+         if top_feature
+         else "No clear top feature identified"
+     )
+
+     # Insight 2: Method agreement
+     if n_methods > 1:
+         if agreement_level == "high":
+             key_insights.append(
+                 f"Strong consensus across methods (avg correlation: {avg_agreement:.2f})"
+             )
+         elif agreement_level == "medium":
+             key_insights.append(
+                 f"Moderate method agreement (avg correlation: {avg_agreement:.2f}) - some variation expected"
+             )
+         else:
+             key_insights.append(
+                 f"Low method agreement (avg correlation: {avg_agreement:.2f}) - investigate method-specific biases"
+             )
+
+     # Insight 3: Stability (if available)
+     if uncertainty.get("method_stability"):
+         for method, stability in uncertainty["method_stability"].items():
+             if stability < 0.7:
+                 key_insights.append(
+                     f"{method.upper()} shows low stability (score: {stability:.2f}) - "
+                     "importance estimates have high variance"
+                 )
+
+     # Generate recommendations
+     recommendations = []
+
+     # Rec 1: Based on agreement
+     if n_methods > 1 and avg_agreement < 0.6:
+         recommendations.append(
+             "Investigate features with large rank disagreements between methods. "
+             "This may indicate interaction effects or method-specific artifacts."
+         )
+
+     # Rec 2: Based on stability
+     if uncertainty.get("method_stability") and any(
+         s < 0.7 for s in uncertainty["method_stability"].values()
+     ):
+         recommendations.append(
+             "Increase the number of repeats or use cross-validation to improve importance stability estimates."
+         )
+
+     # Rec 3: General best practice
+     recommendations.append(
+         "Focus on top consensus features for model interpretability and feature selection."
+     )
+
+     # Caveats
+     caveats = []
+     if warnings:
+         caveats.append(f"Analysis generated {len(warnings)} warning(s) - review carefully.")
+
+     if n_methods == 1:
+         caveats.append(
+             "Only one method was used. Consider running multiple methods to validate findings."
+         )
+
+     # Determine overall quality
+     if n_methods >= 2 and avg_agreement > 0.7 and len(warnings) == 0:
+         analysis_quality = "high"
+     elif n_methods >= 2 and avg_agreement > 0.5:
+         analysis_quality = "medium"
+     else:
+         analysis_quality = "low"
+
+     return LLMContextData(
+         summary_narrative=summary_narrative,
+         key_insights=key_insights,
+         recommendations=recommendations,
+         caveats=caveats,
+         analysis_quality=analysis_quality,
+     )