ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,371 @@
1
+ """Mean Decrease in Accuracy (MDA) feature importance by feature removal.
2
+
3
+ This module provides MDA importance which measures performance drop when features
4
+ are neutralized, with support for feature groups.
5
+ """
6
+
7
+ from collections.abc import Callable
8
+ from typing import TYPE_CHECKING, Any, Union
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import polars as pl
13
+
14
+ if TYPE_CHECKING:
15
+ from numpy.typing import NDArray
16
+
17
+
18
def compute_mda_importance(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
    feature_names: list[str] | None = None,
    feature_groups: dict[str, list[str]] | None = None,
    removal_method: str = "mean",
    scoring: str | Callable | None = None,
    _n_jobs: int | None = None,
) -> dict[str, Any]:
    """Compute Mean Decrease in Accuracy (MDA) by feature removal.

    MDA measures the drop in model performance when features are removed or
    neutralized. Unlike Permutation Feature Importance (PFI), which shuffles
    feature values, MDA replaces feature values with a constant (mean, median,
    or zero), simulating complete feature unavailability.

    This approach naturally supports feature groups (e.g., one-hot encoded
    categoricals, related features like lat/lon) by removing multiple features
    simultaneously and measuring the joint importance.

    **Supported Models**:
    - Any fitted sklearn-compatible estimator with a `score()` method (used when
      ``scoring`` is None) or one accepted by the chosen scorer (typically via
      `predict()`)
    - Classification: LogisticRegression, RandomForest, XGBoost, LightGBM, etc.
    - Regression: LinearRegression, Ridge, GradientBoosting, etc.

    Parameters
    ----------
    model : Any
        Fitted sklearn-compatible estimator.
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        2-D feature matrix (n_samples, n_features).
    y : Union[pl.Series, pd.Series, np.ndarray]
        Target values (n_samples,).
    feature_names : list[str] | None, default None
        Feature names for labeling; must have one entry per column of X.
        If None, uses the DataFrame column names, or ``feature_{i}`` for arrays.
    feature_groups : dict[str, list[str]] | None, default None
        Dictionary mapping group names to lists of feature names.
        When provided, computes importance for feature groups instead of
        individual features. Example: {"location": ["lat", "lon"],
        "time": ["hour", "day", "month"]}
    removal_method : str, default "mean"
        How to neutralize features:
        - "mean": Replace with the column mean of the supplied X
        - "median": Replace with the column median of the supplied X
        - "zero": Replace with zero (can distort if zero is out-of-distribution)
        For floating-point X, mean/median ignore NaN entries.
    scoring : str | Callable | None, default None
        Scoring function to evaluate model performance. If None, uses the
        model's ``score`` method. Strings are resolved with
        ``sklearn.metrics.get_scorer``; callables are invoked as
        ``scoring(model, X, y)``. Common options:
        - Classification: 'accuracy', 'roc_auc', 'f1'
        - Regression: 'r2', 'neg_mean_squared_error', 'neg_mean_absolute_error'
    _n_jobs : int | None, default None
        Reserved for parallel scoring. Currently ignored: all evaluations
        run sequentially.

    Returns
    -------
    dict[str, Any]
        Dictionary with MDA importance results:
        - importances: np.ndarray of performance drops
          (baseline score minus neutralized score) per feature/group, sorted
          descending; NaN drops (e.g. from an undefined metric) sort last
        - feature_names: Feature/group labels in the same order
        - baseline_score: Model score on the unmodified X
        - removal_method: Method used to neutralize features
        - scoring: Scoring name ("default", the scorer string, or "custom")
        - n_features: Number of features/groups evaluated

    Raises
    ------
    ValueError
        If removal_method is not one of: "mean", "median", "zero"
    ValueError
        If X is not 2-dimensional
    ValueError
        If feature_names does not have one entry per column of X
    ValueError
        If feature_groups contains unknown feature names
    ValueError
        If X and y have different numbers of samples

    Examples
    --------
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> from sklearn.datasets import make_classification
    >>>
    >>> # Train a simple model
    >>> X, y = make_classification(n_samples=1000, n_features=10, n_informative=3, random_state=42)
    >>> model = RandomForestClassifier(n_estimators=50, random_state=42)
    >>> _ = model.fit(X, y)
    >>>
    >>> # Compute MDA importance
    >>> mda = compute_mda_importance(
    ...     model=model,
    ...     X=X,
    ...     y=y,
    ...     removal_method='mean',
    ...     scoring='accuracy'
    ... )
    >>>
    >>> # Examine results (illustrative output)
    >>> print(f"Baseline score: {mda['baseline_score']:.3f}")
    >>> print(f"Most important feature: {mda['feature_names'][0]}")
    >>> print(f"Importance (accuracy drop): {mda['importances'][0]:.3f}")
    Baseline score: 0.920
    Most important feature: feature_3
    Importance (accuracy drop): 0.124

    **Feature Groups Example**:

    >>> # Group related features (e.g., one-hot encoded categorical)
    >>> feature_groups = {
    ...     "category_A": ["feature_0", "feature_1", "feature_2"],
    ...     "category_B": ["feature_3", "feature_4"],
    ...     "numeric": ["feature_5", "feature_6", "feature_7"]
    ... }
    >>>
    >>> mda_groups = compute_mda_importance(
    ...     model=model,
    ...     X=X,
    ...     y=y,
    ...     feature_groups=feature_groups,
    ...     removal_method='mean'
    ... )
    >>>
    >>> # See which group is most important
    >>> print(f"Most important group: {mda_groups['feature_names'][0]}")
    >>> print(f"Group importance: {mda_groups['importances'][0]:.3f}")

    Notes
    -----
    **MDA vs PFI** (Permutation Feature Importance):

    **MDA Characteristics**:
    - Removes feature completely (sets to constant)
    - Simulates true feature unavailability
    - May show larger importance drops than PFI
    - Naturally supports feature groups
    - Similar computational cost to PFI

    **PFI Characteristics**:
    - Shuffles feature values (breaks feature-target relationship)
    - Preserves feature distribution
    - May show smaller importance drops
    - Requires additional logic for feature groups
    - More commonly used in literature

    **When to use MDA**:
    - Want to simulate complete feature removal
    - Need to evaluate feature groups jointly
    - Comparing "with feature" vs "without feature" scenarios

    **When to use PFI instead**:
    - Want to match published baselines (PFI more common)
    - Need to preserve feature distributions

    **Feature Groups**:
    Feature groups are useful for:
    - One-hot encoded categoricals (remove all dummy variables together)
    - Related features (lat/lon, year/month/day)
    - Multi-dimensional embeddings
    - Polynomial features of same base feature

    Removing feature groups jointly captures their combined importance,
    which can be higher than the sum of individual importances due to
    interactions between features in the group.

    **Removal Methods**:

    - **mean**: Most common choice for continuous features. Replaces the
      feature with its mean over the X passed to this function (so pass a
      validation/test set if you want held-out statistics).

    - **median**: More robust to outliers than mean. Useful for features with
      skewed distributions or outliers.

    - **zero**: Simple but can be problematic if zero is out-of-distribution
      for a feature (e.g., if feature is always positive). Use with caution.

    Integer/boolean columns are promoted to floating point in the neutralized
    copies so that a fractional mean/median (e.g. the mean of a 0/1 dummy) is
    not truncated. The baseline score is always computed on X unchanged.

    **Computational Cost**:
    - Time complexity: O(n_features * prediction_time) or O(n_groups * prediction_time)
    - Same order as PFI (one evaluation per feature/group)
    - Evaluations run sequentially (``_n_jobs`` is currently ignored)
    - Faster than SHAP for large datasets

    **Comparison with Other Methods**:

    | Method | Speed    | Groups | Local | Theory      | Bias |
    |--------|----------|--------|-------|-------------|------|
    | MDI    | Fastest  | No     | No    | Weak        | Yes  |
    | PFI    | Slow     | Hard   | No    | Strong      | No   |
    | MDA    | Slow     | Yes    | No    | Strong      | No   |
    | SHAP   | Medium   | No     | Yes   | Strongest   | No   |

    - **Speed**: MDI instant (from training), PFI/MDA slow (repeated scoring),
      SHAP medium (depends on data size)
    - **Groups**: MDA naturally supports, PFI requires workarounds, MDI/SHAP no
    - **Local**: SHAP provides per-sample importances, others are global only
    - **Theory**: SHAP has strongest game-theoretic foundation, PFI/MDA empirical
    - **Bias**: MDI biased toward high-cardinality features, others unbiased

    **Best Practices**:
    - Use validation/test set (not training data) for unbiased estimates
    - Compare MDA with PFI and SHAP for robustness
    - Use feature groups for one-hot encoded categoricals
    - Choose removal_method based on feature distributions
    - Verify model still makes reasonable predictions after removal

    References
    ----------
    .. [ALT] A. Altmann, L. Toloşi, O. Sander, T. Lengauer,
           "Permutation importance: a corrected feature importance measure",
           Bioinformatics, 26(10), 1340-1347, 2010.
    .. [FIS] A. Fisher, C. Rudin, F. Dominici,
           "All Models are Wrong, but Many are Useful: Learning a Variable's
           Importance by Studying an Entire Class of Prediction Models Simultaneously",
           JMLR, 20(177):1-81, 2019.
    """
    # Validate removal method before doing any (potentially expensive) work.
    valid_methods = ["mean", "median", "zero"]
    if removal_method not in valid_methods:
        raise ValueError(f"removal_method must be one of {valid_methods}, got '{removal_method}'")

    # Convert X to a numpy array, remembering DataFrame column names if any.
    column_names: list[Any] | None
    if isinstance(X, pl.DataFrame):
        X_array = X.to_numpy()
        column_names = list(X.columns)  # Polars columns is already a list of str
    elif isinstance(X, pd.DataFrame):
        X_array = X.to_numpy()
        column_names = X.columns.tolist()
    else:
        X_array = np.asarray(X)
        column_names = None

    if X_array.ndim != 2:
        raise ValueError(
            f"X must be 2-dimensional (n_samples, n_features), got shape {X_array.shape}"
        )
    n_samples, n_features = X_array.shape

    # Resolve feature labels: explicit names > DataFrame columns > generated.
    if feature_names is not None:
        names = list(feature_names)
    elif column_names is not None:
        names = column_names
    else:
        names = [f"feature_{i}" for i in range(n_features)]
    if len(names) != n_features:
        raise ValueError(
            f"feature_names has {len(names)} entries but X has {n_features} columns"
        )

    if isinstance(y, (pl.Series, pd.Series)):
        y_array = y.to_numpy()
    else:
        y_array = np.asarray(y)

    # Validate dimensions
    if len(y_array) != n_samples:
        raise ValueError(
            f"X and y have inconsistent numbers of samples: {n_samples} vs {len(y_array)}"
        )

    # Build the list of column-index sets to neutralize. Individual features
    # are treated as singleton groups so both modes share one scoring loop.
    if feature_groups is not None:
        known_names = set(names)  # O(1) membership instead of list scans
        for group_name, features in feature_groups.items():
            for feat in features:
                if feat not in known_names:
                    raise ValueError(
                        f"Feature '{feat}' in group '{group_name}' not found in feature_names"
                    )
        name_to_idx = {name: idx for idx, name in enumerate(names)}
        eval_feature_names = list(feature_groups)
        index_sets = [[name_to_idx[feat] for feat in feats] for feats in feature_groups.values()]
    else:
        eval_feature_names = names
        index_sets = [[idx] for idx in range(n_features)]

    # Set up scoring function
    if scoring is None:
        scorer = None
        scoring_name = "default"
    else:
        from sklearn.metrics import get_scorer

        scorer = get_scorer(scoring) if isinstance(scoring, str) else scoring
        scoring_name = scoring if isinstance(scoring, str) else "custom"

    def _score(X_eval: Any) -> Any:
        # Same dispatch for baseline and neutralized data.
        if scorer is None:
            return model.score(X_eval, y_array)
        return scorer(model, X_eval, y_array)

    baseline_score = _score(X_array)

    # Replacement constants per column. NaN-aware statistics for float data so
    # a few missing values do not turn the "mean" into NaN.
    is_float = np.issubdtype(X_array.dtype, np.floating)
    if removal_method == "mean":
        replacement_values = np.nanmean(X_array, axis=0) if is_float else np.mean(X_array, axis=0)
    elif removal_method == "median":
        replacement_values = (
            np.nanmedian(X_array, axis=0) if is_float else np.median(X_array, axis=0)
        )
    else:  # removal_method == "zero"
        replacement_values = np.zeros(n_features, dtype=X_array.dtype)

    # Promote e.g. int/bool columns to float so fractional replacements are
    # not truncated when written into the copy.
    work_dtype = np.result_type(X_array.dtype, replacement_values.dtype)

    importances_list = []
    for indices in index_sets:
        cols = np.asarray(indices, dtype=np.intp)
        X_removed = X_array.astype(work_dtype, copy=True)
        X_removed[:, cols] = replacement_values[cols]
        # Importance is the drop in performance
        importances_list.append(baseline_score - _score(X_removed))
    importances = np.asarray(importances_list, dtype=float)

    # Sort descending; stable so ties keep input order, and NaN sorts last
    # (argsort places NaN at the end, and -NaN is still NaN).
    sorted_idx = np.argsort(-importances, kind="stable")

    return {
        "importances": importances[sorted_idx],
        "feature_names": [eval_feature_names[i] for i in sorted_idx],
        "baseline_score": float(baseline_score),
        "removal_method": removal_method,
        "scoring": scoring_name,
        "n_features": len(eval_feature_names),
    }