ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,375 @@
1
+ """Classical feature importance: Permutation (PFI) and Mean Decrease Impurity (MDI).
2
+
3
+ This module provides model-agnostic permutation importance and tree-based MDI
4
+ importance calculations.
5
+ """
6
+
7
+ from collections.abc import Callable
8
+ from typing import TYPE_CHECKING, Any, Union
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import polars as pl
13
+
14
+ if TYPE_CHECKING:
15
+ from numpy.typing import NDArray
16
+
17
+
18
def compute_permutation_importance(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
    feature_names: list[str] | None = None,
    scoring: str | Callable | None = None,
    n_repeats: int = 10,
    random_state: int | None = 42,
    n_jobs: int | None = None,
) -> dict[str, Any]:
    """Compute Permutation Feature Importance (PFI) for model-agnostic feature ranking.

    Permutation Feature Importance measures the increase in model error when a
    feature's values are randomly shuffled. Features with high importance cause
    large performance drops when permuted, indicating they are critical for
    the model's predictions.

    This is a model-agnostic method that works with any fitted estimator,
    making it superior to model-specific importance measures (e.g., tree-based
    feature importances) which can be biased toward high-cardinality features.

    Parameters
    ----------
    model : Any
        Fitted sklearn-compatible estimator (must have `predict` or `predict_proba`)
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix (n_samples, n_features)
    y : Union[pl.Series, pd.Series, np.ndarray]
        Target values (n_samples,)
    feature_names : list[str] | None, default None
        Feature names for labeling. If None, uses column names from DataFrame
        or generates numeric names for arrays
    scoring : str | Callable | None, default None
        Scoring function to evaluate model performance. If None, uses model's
        default score method. Common options:
        - Classification: 'accuracy', 'roc_auc', 'f1'
        - Regression: 'r2', 'neg_mean_squared_error', 'neg_mean_absolute_error'
    n_repeats : int, default 10
        Number of times to permute each feature (more repeats = more stable estimates)
    random_state : int | None, default 42
        Random seed for reproducibility
    n_jobs : int | None, default None
        Number of parallel jobs (-1 for all CPUs)

    Returns
    -------
    dict[str, Any]
        Dictionary with permutation importance results, sorted by descending
        mean importance:
        - importances_mean: Mean importance per feature
        - importances_std: Standard deviation of importance per feature
        - importances_raw: All permutation results (n_features, n_repeats)
        - feature_names: Feature labels
        - baseline_score: Model score before permutation
        - n_repeats: Number of permutation rounds
        - scoring: Scoring function used ("default" when None was passed)
        - n_features: Number of features

    Raises
    ------
    ValueError
        If `feature_names` is provided but its length does not match the
        number of columns in `X`.

    Examples
    --------
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> from sklearn.datasets import make_classification
    >>>
    >>> X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
    >>> model = RandomForestClassifier(n_estimators=10, random_state=42)
    >>> model.fit(X, y)
    >>>
    >>> pfi = compute_permutation_importance(
    ...     model=model, X=X, y=y, n_repeats=10, scoring='accuracy'
    ... )
    >>> print(f"Baseline score: {pfi['baseline_score']:.3f}")
    Baseline score: 0.920

    Notes
    -----
    **Interpretation**:
    - Importance = 0: Feature not useful
    - Importance > 0: Feature contributes to predictions
    - Importance < 0: Feature hurts performance (may indicate overfitting)

    **Advantages over MDI** (Mean Decrease in Impurity):
    - Model-agnostic: Works with any estimator
    - Unbiased: Not inflated by high-cardinality features
    - Realistic: Measures actual predictive power, not just tree splits

    **Computational Cost**:
    - Time complexity: O(n_features * n_repeats * prediction_time)
    - Use n_jobs=-1 for parallel computation

    **Best Practices**:
    - Use hold-out validation set (not training data) for unbiased estimates
    - Increase n_repeats (20-30) for more stable results
    - Compare with other importance methods (SHAP, MDI) for robustness

    References
    ----------
    .. [BRE] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001.
    """
    # Imported lazily so the module stays importable without sklearn installed.
    from sklearn.inspection import permutation_importance as sklearn_pfi

    # Normalize the feature matrix to numpy, resolving feature names from
    # DataFrame columns when the caller did not provide them. Each branch
    # sets `resolved_names` explicitly, so no post-hoc assert is needed.
    X_array: NDArray[Any]
    if isinstance(X, pl.DataFrame):
        resolved_names = list(X.columns) if feature_names is None else list(feature_names)
        X_array = X.to_numpy()
    elif isinstance(X, pd.DataFrame):
        resolved_names = X.columns.tolist() if feature_names is None else list(feature_names)
        X_array = X.to_numpy()
    else:
        X_array = np.asarray(X)
        if feature_names is None:
            resolved_names = [f"feature_{i}" for i in range(X_array.shape[1])]
        else:
            resolved_names = list(feature_names)

    # Fail fast with a clear message instead of an IndexError downstream.
    if len(resolved_names) != X_array.shape[1]:
        raise ValueError(
            f"Number of feature names ({len(resolved_names)}) does not match "
            f"number of columns in X ({X_array.shape[1]})"
        )

    # Both polars and pandas Series expose to_numpy(); anything else is
    # coerced through np.asarray.
    y_array: NDArray[Any]
    if isinstance(y, (pl.Series, pd.Series)):
        y_array = y.to_numpy()
    else:
        y_array = np.asarray(y)

    # Baseline score before any permutation, using the same scorer that
    # sklearn_pfi will use so the drop is measured on a consistent scale.
    if scoring is None:
        baseline_score = model.score(X_array, y_array)
    else:
        from sklearn.metrics import get_scorer

        scorer = get_scorer(scoring) if isinstance(scoring, str) else scoring
        baseline_score = scorer(model, X_array, y_array)

    # Delegate the permutation loop to sklearn's implementation.
    result = sklearn_pfi(
        estimator=model,
        X=X_array,
        y=y_array,
        scoring=scoring,
        n_repeats=n_repeats,
        random_state=random_state,
        n_jobs=n_jobs,
    )

    # Sort all outputs by mean importance, descending.
    sorted_idx = np.argsort(result.importances_mean)[::-1]

    return {
        "importances_mean": result.importances_mean[sorted_idx],
        "importances_std": result.importances_std[sorted_idx],
        "importances_raw": result.importances[sorted_idx],  # (n_features, n_repeats)
        "feature_names": [resolved_names[i] for i in sorted_idx],
        "baseline_score": float(baseline_score),
        "n_repeats": n_repeats,
        "scoring": scoring if scoring is not None else "default",
        "n_features": len(resolved_names),
    }
195
+
196
+
197
+ def compute_mdi_importance(
198
+ model: Any,
199
+ feature_names: list[str] | None = None,
200
+ normalize: bool = True,
201
+ ) -> dict[str, Any]:
202
+ """Compute Mean Decrease in Impurity (MDI) feature importance from tree-based models.
203
+
204
+ MDI measures how much each feature contributes to decreasing the weighted
205
+ impurity (Gini for classification, MSE/MAE for regression) across all trees.
206
+ This is computed during model training and is available via the model's
207
+ `feature_importances_` attribute.
208
+
209
+ **Supported Models**:
210
+ - LightGBM: `lightgbm.LGBMClassifier`, `lightgbm.LGBMRegressor` (recommended)
211
+ - XGBoost: `xgboost.XGBClassifier`, `xgboost.XGBRegressor` (recommended)
212
+ - sklearn: `RandomForestClassifier`, `RandomForestRegressor` (not recommended - slow)
213
+ - sklearn: `GradientBoostingClassifier`, `GradientBoostingRegressor` (not recommended - slow)
214
+
215
+ **Not supported**:
216
+ - sklearn's HistGradientBoosting* (doesn't expose feature_importances_)
217
+
218
+ Parameters
219
+ ----------
220
+ model : Any
221
+ Fitted tree-based model with `feature_importances_` attribute.
222
+ Must be one of: LightGBM, XGBoost, or sklearn tree ensembles.
223
+ feature_names : list[str] | None, default None
224
+ Feature names for labeling. If None, uses feature names from model
225
+ or generates numeric names.
226
+ normalize : bool, default True
227
+ If True, ensures importances sum to 1.0 (some models already normalize).
228
+
229
+ Returns
230
+ -------
231
+ dict[str, Any]
232
+ Dictionary with MDI importance results:
233
+ - importances: Feature importance values (sorted descending)
234
+ - feature_names: Feature labels (sorted by importance)
235
+ - n_features: Number of features
236
+ - normalized: Whether values sum to 1.0
237
+ - model_type: Type of model used
238
+
239
+ Raises
240
+ ------
241
+ AttributeError
242
+ If model doesn't have `feature_importances_` attribute
243
+ ImportError
244
+ If LightGBM/XGBoost not installed and trying to use those models
245
+
246
+ Examples
247
+ --------
248
+ >>> import lightgbm as lgb
249
+ >>> from sklearn.datasets import make_classification
250
+ >>>
251
+ >>> # Train LightGBM model
252
+ >>> X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
253
+ >>> model = lgb.LGBMClassifier(n_estimators=100, random_state=42)
254
+ >>> model.fit(X, y)
255
+ >>>
256
+ >>> # Extract MDI importance
257
+ >>> mdi = compute_mdi_importance(
258
+ ... model=model,
259
+ ... feature_names=[f'feature_{i}' for i in range(10)]
260
+ ... )
261
+ >>>
262
+ >>> # Examine results
263
+ >>> print(f"Most important feature: {mdi['feature_names'][0]}")
264
+ >>> print(f"Importance: {mdi['importances'][0]:.3f}")
265
+ >>> print(f"Model type: {mdi['model_type']}")
266
+ Most important feature: feature_3
267
+ Importance: 0.245
268
+ Model type: lightgbm.LGBMClassifier
269
+
270
+ Notes
271
+ -----
272
+ **MDI vs PFI** (Permutation Feature Importance):
273
+
274
+ **MDI Advantages**:
275
+ - Very fast: Computed during training (no additional overhead)
276
+ - No additional data required
277
+ - Deterministic: Same result every time
278
+
279
+ **MDI Disadvantages**:
280
+ - **Biased toward high-cardinality features**: Features with many unique values
281
+ get inflated importance even if not truly predictive
282
+ - **Only for tree-based models**: Not model-agnostic
283
+ - **Train set importance**: May not reflect test set predictive power
284
+ - **Correlated features**: Can split importance between correlated predictors
285
+
286
+ **When to use MDI**:
287
+ - Quick exploratory analysis
288
+ - When computational budget is limited
289
+ - When working with tree-based models exclusively
290
+
291
+ **When to use PFI instead**:
292
+ - Need unbiased importance estimates
293
+ - Have high-cardinality categorical features
294
+ - Want model-agnostic importance
295
+ - Need to validate importance on test set
296
+
297
+ **Comparison workflow**:
298
+ >>> # Compare MDI and PFI
299
+ >>> mdi = compute_mdi_importance(model, feature_names=features)
300
+ >>> pfi = compute_permutation_importance(model, X_test, y_test, feature_names=features)
301
+ >>>
302
+ >>> # Large discrepancies may indicate:
303
+ >>> # - High-cardinality bias in MDI
304
+ >>> # - Correlated features splitting importance
305
+ >>> # - Overfitting (high MDI, low PFI)
306
+
307
+ **Performance notes**:
308
+ - LightGBM and XGBoost: Production-ready speed and accuracy (RECOMMENDED)
309
+ - sklearn RandomForest/GradientBoosting: 10-100x slower, avoid for large datasets
310
+ - sklearn HistGradientBoosting: Fast but doesn't expose feature_importances_ (use PFI instead)
311
+
312
+ References
313
+ ----------
314
+ - Breiman, L. (2001). "Random Forests". Machine Learning.
315
+ - Louppe, G. et al. (2013). "Understanding variable importances in forests of
316
+ randomized trees". NeurIPS.
317
+ - Strobl, C. et al. (2007). "Bias in random forest variable importance measures".
318
+ BMC Bioinformatics.
319
+ """
320
+ # Check if model has feature_importances_
321
+ if not hasattr(model, "feature_importances_"):
322
+ raise AttributeError(
323
+ f"Model of type {type(model).__name__} does not have 'feature_importances_' attribute. "
324
+ "MDI is only available for tree-based models (LightGBM, XGBoost, sklearn tree ensembles)."
325
+ )
326
+
327
+ # Extract raw importances
328
+ importances = model.feature_importances_
329
+
330
+ # Get feature names
331
+ if feature_names is None:
332
+ # Try to get from model
333
+ if hasattr(model, "feature_name_"):
334
+ # LightGBM
335
+ feature_names = model.feature_name_
336
+ elif hasattr(model, "get_booster") and hasattr(model.get_booster(), "feature_names"):
337
+ # XGBoost
338
+ feature_names = model.get_booster().feature_names
339
+ elif hasattr(model, "feature_names_in_"):
340
+ # sklearn
341
+ feature_names = list(model.feature_names_in_)
342
+ else:
343
+ # Fallback to numeric names
344
+ feature_names = [f"feature_{i}" for i in range(len(importances))]
345
+ else:
346
+ feature_names = list(feature_names)
347
+
348
+ # Validate length match
349
+ if len(feature_names) != len(importances):
350
+ raise ValueError(
351
+ f"Number of feature names ({len(feature_names)}) does not match number of importances ({len(importances)})"
352
+ )
353
+
354
+ # Normalize if requested
355
+ if normalize:
356
+ importance_sum = importances.sum()
357
+ if importance_sum > 0:
358
+ importances = importances / importance_sum
359
+ else:
360
+ # All zeros - already normalized
361
+ pass
362
+
363
+ # Sort by importance (descending)
364
+ sorted_idx = np.argsort(importances)[::-1]
365
+
366
+ # Determine model type
367
+ model_type = f"{type(model).__module__}.{type(model).__name__}"
368
+
369
+ return {
370
+ "importances": importances[sorted_idx],
371
+ "feature_names": [feature_names[i] for i in sorted_idx],
372
+ "n_features": len(feature_names),
373
+ "normalized": normalize,
374
+ "model_type": model_type,
375
+ }