ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,338 @@
1
+ """Comprehensive ML feature importance analysis comparing multiple methods.
2
+
3
+ This module provides a tear sheet function that runs MDI, PFI, MDA, and SHAP
4
+ importance methods and generates a comparison report with consensus ranking.
5
+ """
6
+
7
+ from collections.abc import Callable
8
+ from typing import TYPE_CHECKING, Any, Union
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import polars as pl
13
+ from scipy.stats import spearmanr
14
+
15
+ from ml4t.diagnostic.evaluation.metrics.importance_classical import (
16
+ compute_mdi_importance,
17
+ compute_permutation_importance,
18
+ )
19
+ from ml4t.diagnostic.evaluation.metrics.importance_mda import compute_mda_importance
20
+ from ml4t.diagnostic.evaluation.metrics.importance_shap import compute_shap_importance
21
+
22
+ if TYPE_CHECKING:
23
+ from numpy.typing import NDArray
24
+
25
+
26
+ def _generate_ml_importance_interpretation(
27
+ top_features: list[str],
28
+ method_agreement: dict[str, float],
29
+ warnings: list[str],
30
+ n_consensus: int,
31
+ ) -> str:
32
+ """Generate human-readable interpretation of ML importance analysis.
33
+
34
+ Parameters
35
+ ----------
36
+ top_features : list[str]
37
+ Top features from consensus ranking
38
+ method_agreement : dict[str, float]
39
+ Pairwise correlations between methods
40
+ warnings : list[str]
41
+ List of potential issues detected
42
+ n_consensus : int
43
+ Number of features in top 10 across all methods
44
+
45
+ Returns
46
+ -------
47
+ str
48
+ Human-readable interpretation summary
49
+ """
50
+ lines = []
51
+
52
+ # Consensus features
53
+ if n_consensus > 0:
54
+ lines.append(f"Strong consensus: {n_consensus} features rank in top 10 across all methods")
55
+ lines.append(f" Top consensus features: {', '.join(top_features[:5])}")
56
+ else:
57
+ lines.append("Weak consensus: Different methods identify different important features")
58
+
59
+ # Method agreement
60
+ if method_agreement:
61
+ avg_agreement = float(np.mean(list(method_agreement.values())))
62
+ if avg_agreement > 0.7:
63
+ lines.append(f"High agreement between methods (avg correlation: {avg_agreement:.2f})")
64
+ elif avg_agreement > 0.5:
65
+ lines.append(
66
+ f"Moderate agreement between methods (avg correlation: {avg_agreement:.2f})"
67
+ )
68
+ else:
69
+ lines.append(
70
+ f"Low agreement between methods (avg correlation: {avg_agreement:.2f}) - investigate further"
71
+ )
72
+
73
+ # Warnings
74
+ if warnings:
75
+ lines.append("\nPotential Issues:")
76
+ for warning in warnings:
77
+ lines.append(f" - {warning}")
78
+
79
+ return "\n".join(lines)
80
+
81
+
82
def _importance_ranks(
    method_name: str,
    result: dict[str, Any],
    feature_names: list[str],
) -> "NDArray[Any] | None":
    """Convert one method's importances into 0-based ranks aligned to ``feature_names``.

    Rank 0 is the most important feature. Features the method did not report
    get importance 0.0. NaN importances are treated as least important.

    Returns None for a method name with no known importance key.
    """
    importance_key = {
        "pfi": "importances_mean",
        "mdi": "importances",
        "mda": "importances",
        "shap": "importances",
    }.get(method_name)
    if importance_key is None:
        return None

    # Map by name so the result's own ordering of features does not matter.
    feature_to_importance = dict(
        zip(result["feature_names"], result[importance_key], strict=False)
    )
    values = np.array(
        [feature_to_importance.get(name, 0.0) for name in feature_names], dtype=float
    )
    # argsort sorts NaN last in ascending order, which becomes *first* after
    # reversing. Map NaN to -inf so such features rank last instead.
    values = np.where(np.isnan(values), -np.inf, values)
    return np.argsort(np.argsort(values)[::-1])


def _top_ranked(ranks: "NDArray[Any]", feature_names: list[str], k: int) -> set[str]:
    """Return the names of the features whose rank is below ``k``."""
    return {feature_names[i] for i in np.flatnonzero(ranks < k)}


def analyze_ml_importance(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
    feature_names: list[str] | None = None,
    methods: list[str] | None = None,
    scoring: str | Callable | None = None,
    n_repeats: int = 10,
    random_state: int | None = 42,
) -> dict[str, Any]:
    """Comprehensive ML feature importance analysis comparing multiple methods.

    **This is a TEAR SHEET function** - it runs multiple importance methods and
    generates a comparison report with consensus ranking and interpretation.

    **Use Case**: "Which features does my model rely on? Do different methods agree?"

    This function replaces 100+ lines of manual comparison code by providing
    integrated analysis showing:
    - Individual method results (MDI, PFI, MDA, SHAP)
    - Consensus ranking (features important across methods)
    - Method agreement/disagreement analysis
    - Auto-generated insights and warnings

    **Why Compare Methods?**

    Different importance methods measure different aspects:
    - **MDI** (Mean Decrease Impurity): Fast, but biased toward high-cardinality features
    - **PFI** (Permutation): Unbiased, measures predictive importance
    - **MDA** (Mean Decrease Accuracy): Similar to PFI but removes features completely
    - **SHAP**: Theoretically sound, based on game theory

    Strong consensus across methods indicates robust feature importance.
    Disagreement suggests model-specific artifacts or feature interactions.

    Parameters
    ----------
    model : Any
        Fitted model. Requirements vary by method:
        - MDI: Must have `feature_importances_` (tree-based models)
        - PFI, MDA: Must have `predict()` or `score()`
        - SHAP: Must be compatible with TreeExplainer
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix (n_samples, n_features)
    y : Union[pl.Series, pd.Series, np.ndarray]
        Target values (n_samples,)
    feature_names : list[str] | None, default None
        Feature names for labeling. If None, uses column names from DataFrame
        or generates numeric names ("f0", "f1", ...).
    methods : list[str] | None, default ["mdi", "pfi", "shap"]
        Which methods to run. Options: "mdi", "pfi", "mda", "shap".
        A single method name given as a plain string is also accepted.
        Unrecognized names are reported in ``methods_failed`` and ``warnings``.
    scoring : str | Callable | None, default None
        Scoring metric for PFI and MDA
    n_repeats : int, default 10
        Number of permutations for PFI
    random_state : int | None, default 42
        Random seed for reproducibility (passed to PFI)

    Returns
    -------
    dict[str, Any]
        Comprehensive analysis results:
        - method_results: Dict of individual method outputs
        - consensus_ranking: Features ranked by average rank across methods
        - method_agreement: Spearman correlations between method rankings
          (pairs whose correlation is undefined, e.g. < 2 features, are omitted)
        - top_features_consensus: Features in the top 10 of ALL methods,
          ordered by consensus rank
        - warnings: Detected issues
        - interpretation: Auto-generated summary
        - methods_run: Methods successfully executed
        - methods_failed: (method, error message) pairs for failed or unknown methods

    Raises
    ------
    ValueError
        If no methods specified or all methods fail

    Examples
    --------
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> from sklearn.datasets import make_classification
    >>>
    >>> # Create synthetic dataset
    >>> X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
    >>> model = RandomForestClassifier(n_estimators=50, random_state=42)
    >>> model.fit(X, y)
    >>>
    >>> # Comprehensive importance analysis
    >>> result = analyze_ml_importance(model, X, y, methods=["mdi", "pfi"])
    >>>
    >>> # Quick summary
    >>> print(result["interpretation"])
    """
    if methods is None:
        methods = ["mdi", "pfi", "shap"]
    elif isinstance(methods, str):
        # Without this, `"mdi" in "mdi"` works by accident while every other
        # check below would iterate over characters.
        methods = [methods]

    if not methods:
        raise ValueError("At least one method must be specified")

    # Extract feature names if not provided
    if feature_names is None:
        if isinstance(X, pl.DataFrame | pd.DataFrame):
            feature_names = list(X.columns)
        else:
            # Generate numeric feature names
            n_features = X.shape[1] if hasattr(X, "shape") else len(X[0])
            feature_names = [f"f{i}" for i in range(n_features)]

    # 1. Run each method. Errors are collected rather than raised so that a
    #    missing optional dependency does not abort the whole analysis.
    results: dict[str, Any] = {}
    method_failures: list[tuple[str, str]] = []

    # Unknown names used to be dropped silently, hiding typos such as "shapp".
    supported = ("mdi", "pfi", "mda", "shap")
    for unknown in dict.fromkeys(m for m in methods if m not in supported):
        method_failures.append(
            (unknown, f"Unknown method; valid options are: {', '.join(supported)}")
        )

    if "mdi" in methods:
        try:
            results["mdi"] = compute_mdi_importance(model, feature_names=feature_names)
        except Exception as e:
            method_failures.append(("mdi", str(e)))

    if "pfi" in methods:
        try:
            results["pfi"] = compute_permutation_importance(
                model,
                X,
                y,
                feature_names=feature_names,
                scoring=scoring,
                n_repeats=n_repeats,
                random_state=random_state,
            )
        except Exception as e:
            method_failures.append(("pfi", str(e)))

    if "mda" in methods:
        try:
            results["mda"] = compute_mda_importance(
                model, X, y, feature_names=feature_names, scoring=scoring
            )
        except Exception as e:
            method_failures.append(("mda", str(e)))

    if "shap" in methods:
        try:
            results["shap"] = compute_shap_importance(model, X, feature_names=feature_names)
        except ImportError:
            method_failures.append(
                (
                    "shap",
                    "shap library not installed. Install with: pip install ml4t-diagnostic[ml]",
                )
            )
        except Exception as e:
            method_failures.append(("shap", str(e)))

    # Check if at least one method succeeded
    if not results:
        error_msg = "All methods failed:\n" + "\n".join(
            f" - {method}: {error}" for method, error in method_failures
        )
        raise ValueError(error_msg)

    # 2. Consensus ranking: per-method ranks (0 = most important), then the mean.
    rankings: dict[str, Any] = {}
    for method_name, result in results.items():
        ranks = _importance_ranks(method_name, result, feature_names)
        if ranks is not None:
            rankings[method_name] = ranks

    avg_ranks = np.mean(list(rankings.values()), axis=0)
    consensus_order = np.argsort(avg_ranks)
    consensus_ranking = [feature_names[i] for i in consensus_order]

    # 3. Method agreement (Spearman correlation between rankings)
    method_agreement: dict[str, float] = {}
    method_names = list(rankings)
    for i, m1 in enumerate(method_names):
        for m2 in method_names[i + 1 :]:
            corr, _ = spearmanr(rankings[m1], rankings[m2])
            corr = float(corr)
            # Spearman is undefined (NaN) with fewer than two features; a NaN
            # would otherwise poison the mean/min agreement used below.
            if np.isfinite(corr):
                method_agreement[f"{m1}_vs_{m2}"] = corr

    # 4. Consensus top features (top 10 in all methods). This is derived from
    #    the computed ranks rather than from slicing each result's
    #    "feature_names", so it does not assume those lists are pre-sorted.
    top_n = 10
    top_features_by_method = {
        method_name: _top_ranked(ranks, feature_names, top_n)
        for method_name, ranks in rankings.items()
    }
    consensus_top = (
        set.intersection(*top_features_by_method.values()) if top_features_by_method else set()
    )

    # 5. Generate warnings
    warnings: list[str] = []

    # Warning: High MDI but low PFI (possible overfitting)
    if "mdi" in rankings and "pfi" in rankings:
        mdi_top = _top_ranked(rankings["mdi"], feature_names, 5)
        pfi_top = _top_ranked(rankings["pfi"], feature_names, 5)
        # Order by consensus rank so the message is deterministic.
        disagreement = [f for f in consensus_ranking if f in mdi_top and f not in pfi_top]
        if disagreement:
            names = "{" + ", ".join(repr(f) for f in disagreement) + "}"
            warnings.append(
                f"Features {names} rank high in MDI but not PFI - possible overfitting to tree structure"
            )

    # Warning: Low agreement between methods
    if method_agreement:
        min_agreement = min(method_agreement.values())
        if min_agreement < 0.5:
            warnings.append(
                f"Low agreement between methods (min correlation: {min_agreement:.2f}) - results may be unreliable"
            )

    # Add method failures to warnings
    for method, error in method_failures:
        warnings.append(f"Method '{method}' failed: {error}")

    # 6. Generate interpretation
    interpretation = _generate_ml_importance_interpretation(
        consensus_ranking[:10],
        method_agreement,
        warnings,
        len(consensus_top),
    )

    return {
        "method_results": results,
        "consensus_ranking": consensus_ranking,
        "method_agreement": method_agreement,
        "top_features_consensus": [f for f in consensus_ranking if f in consensus_top],
        "warnings": warnings,
        "interpretation": interpretation,
        "methods_run": list(results.keys()),
        "methods_failed": method_failures,
    }