ml4t_diagnostic-0.1.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/metrics/importance_mda.py

@@ -0,0 +1,371 @@
+"""Mean Decrease in Accuracy (MDA) feature importance by feature removal.
+
+This module provides MDA importance, which measures the drop in performance when
+features are neutralized, with support for feature groups.
+"""
+
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, Union
+
+import numpy as np
+import pandas as pd
+import polars as pl
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
+def compute_mda_importance(
+    model: Any,
+    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
+    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
+    feature_names: list[str] | None = None,
+    feature_groups: dict[str, list[str]] | None = None,
+    removal_method: str = "mean",
+    scoring: str | Callable | None = None,
+    _n_jobs: int | None = None,
+) -> dict[str, Any]:
+    """Compute Mean Decrease in Accuracy (MDA) by feature removal.
+
+    MDA measures the drop in model performance when features are removed or
+    neutralized. Unlike Permutation Feature Importance (PFI), which shuffles
+    feature values, MDA replaces feature values with a constant (mean, median,
+    or zero), simulating complete feature unavailability.
+
+    This approach naturally supports feature groups (e.g., one-hot encoded
+    categoricals, related features like lat/lon) by removing multiple features
+    simultaneously and measuring their joint importance.
+
+    **Supported Models**:
+    - Any fitted sklearn-compatible estimator with a `score()` or `predict()` method
+    - Classification: LogisticRegression, RandomForest, XGBoost, LightGBM, etc.
+    - Regression: LinearRegression, Ridge, GradientBoosting, etc.
+
+    Parameters
+    ----------
+    model : Any
+        Fitted sklearn-compatible estimator (must have a `score()` or `predict()` method)
+    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
+        Feature matrix (n_samples, n_features)
+    y : Union[pl.Series, pd.Series, np.ndarray]
+        Target values (n_samples,)
+    feature_names : list[str] | None, default None
+        Feature names for labeling. If None, uses column names from DataFrames
+        or generates numeric names for arrays
+    feature_groups : dict[str, list[str]] | None, default None
+        Dictionary mapping group names to lists of feature names.
+        When provided, computes importance for feature groups instead of
+        individual features. Example: {"location": ["lat", "lon"],
+        "time": ["hour", "day", "month"]}
+    removal_method : str, default "mean"
+        How to neutralize features:
+        - "mean": Replace with feature mean (recommended for continuous features)
+        - "median": Replace with feature median (robust to outliers)
+        - "zero": Replace with zero (can distort if zero is out-of-distribution)
+    scoring : str | Callable | None, default None
+        Scoring function used to evaluate model performance. If None, uses the
+        model's default score method. Common options:
+        - Classification: 'accuracy', 'roc_auc', 'f1'
+        - Regression: 'r2', 'neg_mean_squared_error', 'neg_mean_absolute_error'
+    _n_jobs : int | None, default None
+        Number of parallel jobs for scoring (-1 for all CPUs). Currently unused:
+        each feature/group requires its own modified copy of the data, so scoring
+        runs serially.
+
+    Returns
+    -------
+    dict[str, Any]
+        Dictionary with MDA importance results:
+        - importances: Performance drop per feature/group (sorted descending)
+        - feature_names: Feature/group labels (sorted by importance)
+        - baseline_score: Model score before feature removal
+        - removal_method: Method used to neutralize features
+        - scoring: Scoring function used
+        - n_features: Number of features/groups evaluated
+
+    Raises
+    ------
+    ValueError
+        If removal_method is not one of: "mean", "median", "zero"
+    ValueError
+        If feature_groups contains unknown feature names
+    ValueError
+        If X and y have different numbers of samples
+
+    Examples
+    --------
+    >>> from sklearn.ensemble import RandomForestClassifier
+    >>> from sklearn.datasets import make_classification
+    >>> import numpy as np
+    >>>
+    >>> # Train a simple model
+    >>> X, y = make_classification(n_samples=1000, n_features=10, n_informative=3, random_state=42)
+    >>> model = RandomForestClassifier(n_estimators=50, random_state=42)
+    >>> model.fit(X, y)
+    >>>
+    >>> # Compute MDA importance
+    >>> mda = compute_mda_importance(
+    ...     model=model,
+    ...     X=X,
+    ...     y=y,
+    ...     removal_method='mean',
+    ...     scoring='accuracy'
+    ... )
+    >>>
+    >>> # Examine results
+    >>> print(f"Baseline score: {mda['baseline_score']:.3f}")
+    >>> print(f"Most important feature: {mda['feature_names'][0]}")
+    >>> print(f"Importance (accuracy drop): {mda['importances'][0]:.3f}")
+    Baseline score: 0.920
+    Most important feature: feature_3
+    Importance (accuracy drop): 0.124
+
+    **Feature Groups Example**:
+
+    >>> # Group related features (e.g., one-hot encoded categorical)
+    >>> feature_groups = {
+    ...     "category_A": ["feature_0", "feature_1", "feature_2"],
+    ...     "category_B": ["feature_3", "feature_4"],
+    ...     "numeric": ["feature_5", "feature_6", "feature_7"]
+    ... }
+    >>>
+    >>> mda_groups = compute_mda_importance(
+    ...     model=model,
+    ...     X=X,
+    ...     y=y,
+    ...     feature_groups=feature_groups,
+    ...     removal_method='mean'
+    ... )
+    >>>
+    >>> # See which group is most important
+    >>> print(f"Most important group: {mda_groups['feature_names'][0]}")
+    >>> print(f"Group importance: {mda_groups['importances'][0]:.3f}")
+
+    Notes
+    -----
+    **MDA vs PFI** (Permutation Feature Importance):
+
+    **MDA Characteristics**:
+    - Removes a feature completely (sets it to a constant)
+    - Simulates true feature unavailability
+    - May show larger importance drops than PFI
+    - Naturally supports feature groups
+    - Similar computational cost to PFI
+
+    **PFI Characteristics**:
+    - Shuffles feature values (breaks the feature-target relationship)
+    - Preserves the feature distribution
+    - May show smaller importance drops
+    - Requires additional logic for feature groups
+    - More commonly used in the literature
+
+    **When to use MDA**:
+    - Want to simulate complete feature removal
+    - Need to evaluate feature groups jointly
+    - Want more conservative importance estimates
+    - Comparing "with feature" vs "without feature" scenarios
+
+    **When to use PFI instead**:
+    - Want to match published baselines (PFI is more common)
+    - Need to preserve feature distributions
+    - Want less conservative importance estimates
+
+    **Feature Groups**:
+    Feature groups are useful for:
+    - One-hot encoded categoricals (remove all dummy variables together)
+    - Related features (lat/lon, year/month/day)
+    - Multi-dimensional embeddings
+    - Polynomial features of the same base feature
+
+    Removing feature groups jointly captures their combined importance,
+    which can be higher than the sum of individual importances due to
+    interactions between features in the group.
+
+    **Removal Methods**:
+
+    - **mean**: Most common choice for continuous features. Replaces the feature
+      with its training set mean. This is a "neutral" value that doesn't
+      distort the model's input distribution.
+
+    - **median**: More robust to outliers than the mean. Useful for features with
+      skewed distributions or outliers.
+
+    - **zero**: Simple but can be problematic if zero is out-of-distribution
+      for a feature (e.g., if the feature is always positive). Use with caution.
+
+    **Computational Cost**:
+    - Time complexity: O(n_features * prediction_time) or O(n_groups * prediction_time)
+    - Same order as PFI (one evaluation per feature/group)
+    - Cannot be trivially parallelized (requires data modification)
+    - Faster than SHAP for large datasets
+
+    **Comparison with Other Methods**:
+
+    | Method | Speed   | Groups | Local | Theory    | Bias |
+    |--------|---------|--------|-------|-----------|------|
+    | MDI    | Fastest | No     | No    | Weak      | Yes  |
+    | PFI    | Slow    | Hard   | No    | Strong    | No   |
+    | MDA    | Slow    | Yes    | No    | Strong    | No   |
+    | SHAP   | Medium  | No     | Yes   | Strongest | No   |
+
+    - **Speed**: MDI is instant (from training), PFI/MDA are slow (repeated scoring),
+      SHAP is medium (depends on data size)
+    - **Groups**: MDA supports them naturally, PFI requires workarounds, MDI/SHAP do not
+    - **Local**: SHAP provides per-sample importances, the others are global only
+    - **Theory**: SHAP has the strongest game-theoretic foundation, PFI/MDA are empirical
+    - **Bias**: MDI is biased toward high-cardinality features, the others are unbiased
+
+    **Best Practices**:
+    - Use a validation/test set (not training data) for unbiased estimates
+    - Compare MDA with PFI and SHAP for robustness
+    - Use feature groups for one-hot encoded categoricals
+    - Choose removal_method based on feature distributions
+    - Verify the model still makes reasonable predictions after removal
+
+    References
+    ----------
+    .. [ALT] A. Altmann, L. Toloşi, O. Sander, T. Lengauer,
+        "Permutation importance: a corrected feature importance measure",
+        Bioinformatics, 26(10), 1340-1347, 2010.
+    .. [FIS] A. Fisher, C. Rudin, F. Dominici,
+        "All Models are Wrong, but Many are Useful: Learning a Variable's
+        Importance by Studying an Entire Class of Prediction Models Simultaneously",
+        JMLR, 20(177):1-81, 2019.
+    """
+    # Validate removal method
+    valid_methods = ["mean", "median", "zero"]
+    if removal_method not in valid_methods:
+        raise ValueError(f"removal_method must be one of {valid_methods}, got '{removal_method}'")
+
+    # Convert inputs to numpy arrays and extract feature names
+    if isinstance(X, pl.DataFrame):
+        if feature_names is None:
+            feature_names = list(X.columns)  # Polars columns is already a list
+        X_array = X.to_numpy()
+    elif isinstance(X, pd.DataFrame):
+        if feature_names is None:
+            feature_names = X.columns.tolist()
+        X_array = X.values
+    else:
+        X_array = np.asarray(X)
+        if feature_names is None:
+            feature_names = [f"feature_{i}" for i in range(X_array.shape[1])]
+
+    y_array: NDArray[Any]
+    if isinstance(y, pl.Series):
+        y_array = y.to_numpy()
+    elif isinstance(y, pd.Series):
+        y_array = y.to_numpy()
+    else:
+        y_array = np.asarray(y)
+
+    # Validate dimensions
+    n_samples, n_features = X_array.shape
+    if len(y_array) != n_samples:
+        raise ValueError(
+            f"X and y have inconsistent numbers of samples: {n_samples} vs {len(y_array)}"
+        )
+
+    # Set up scoring function
+    if scoring is None:
+        scorer = None
+        baseline_score = model.score(X_array, y_array)
+        scoring_name = "default"
+    else:
+        from sklearn.metrics import get_scorer
+
+        scorer = get_scorer(scoring) if isinstance(scoring, str) else scoring
+        baseline_score = scorer(model, X_array, y_array)
+        scoring_name = scoring if isinstance(scoring, str) else "custom"
+
+    # Compute feature replacement values based on removal method
+    if removal_method == "mean":
+        replacement_values = np.mean(X_array, axis=0)
+    elif removal_method == "median":
+        replacement_values = np.median(X_array, axis=0)
+    else:  # removal_method == "zero"
+        replacement_values = np.zeros(n_features)
+
+    # Determine whether we're evaluating individual features or groups
+    if feature_groups is not None:
+        # Validate feature groups (feature_names is always set by this point)
+        assert feature_names is not None
+        all_group_features: set[str] = set()
+        for group_name, features in feature_groups.items():
+            for feat in features:
+                if feat not in feature_names:
+                    raise ValueError(
+                        f"Feature '{feat}' in group '{group_name}' not found in feature_names"
+                    )
+                all_group_features.add(feat)
+
+        # Map feature names to indices
+        feature_name_to_idx = {name: idx for idx, name in enumerate(feature_names)}
+
+        # Compute importance for each group
+        importances_list = []
+        group_names = []
+
+        for group_name, features in feature_groups.items():
+            # Get indices for all features in this group
+            feature_indices = [feature_name_to_idx[feat] for feat in features]
+
+            # Create modified data with group features removed
+            X_removed = X_array.copy()
+            for idx in feature_indices:
+                X_removed[:, idx] = replacement_values[idx]
+
+            # Compute score with group removed
+            removed_score = (
+                model.score(X_removed, y_array)
+                if scorer is None
+                else scorer(model, X_removed, y_array)
+            )
+
+            # Importance is the drop in performance
+            importance = baseline_score - removed_score
+            importances_list.append(importance)
+            group_names.append(group_name)
+
+        importances = np.array(importances_list)
+        eval_feature_names = group_names
+        n_eval_features = len(feature_groups)
+
+    else:
+        # Compute importance for individual features
+        importances_list = []
+
+        for feature_idx in range(n_features):
+            # Create modified data with feature removed
+            X_removed = X_array.copy()
+            X_removed[:, feature_idx] = replacement_values[feature_idx]
+
+            # Compute score with feature removed
+            removed_score = (
+                model.score(X_removed, y_array)
+                if scorer is None
+                else scorer(model, X_removed, y_array)
+            )
+
+            # Importance is the drop in performance
+            importance = baseline_score - removed_score
+            importances_list.append(importance)
+
+        importances = np.array(importances_list)
+        eval_feature_names = feature_names
+        n_eval_features = n_features
+
+    # Sort by importance (descending)
+    sorted_idx = np.argsort(importances)[::-1]
+
+    # Type assertion: eval_feature_names is guaranteed to be set
+    assert eval_feature_names is not None, "eval_feature_names should be set by this point"
+
+    return {
+        "importances": importances[sorted_idx],
+        "feature_names": [eval_feature_names[i] for i in sorted_idx],
+        "baseline_score": float(baseline_score),
+        "removal_method": removal_method,
+        "scoring": scoring_name,
+        "n_features": n_eval_features,
+    }
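The grouped-removal path is the part of this new module most callers will exercise, so here is a minimal usage sketch of `compute_mda_importance` on a Polars frame with feature groups. Only the function, its parameters, and the returned keys come from the diff above; the synthetic data, column names, model choice, and group layout are illustrative assumptions, not part of the package.

```python
# Hypothetical usage sketch; assumes the ml4t_diagnostic wheel is installed.
import numpy as np
import polars as pl
from sklearn.ensemble import GradientBoostingRegressor

from ml4t.diagnostic.evaluation.metrics.importance_mda import compute_mda_importance

rng = np.random.default_rng(0)
n = 500

# Toy feature frame: two informative columns plus noise (names are made up).
X = pl.DataFrame(
    {
        "mom_1m": rng.normal(size=n),
        "mom_3m": rng.normal(size=n),
        "vol_21d": rng.normal(size=n),
        "noise_a": rng.normal(size=n),
        "noise_b": rng.normal(size=n),
    }
)
y = pl.Series(
    "target",
    0.8 * X["mom_1m"].to_numpy()
    + 0.3 * X["vol_21d"].to_numpy()
    + rng.normal(scale=0.5, size=n),
)

model = GradientBoostingRegressor(random_state=0).fit(X.to_numpy(), y.to_numpy())

# Grouped MDA: the two momentum columns are neutralized together, so the
# reported drop reflects their joint contribution.
result = compute_mda_importance(
    model=model,
    X=X,
    y=y,
    feature_groups={
        "momentum": ["mom_1m", "mom_3m"],
        "volatility": ["vol_21d"],
        "noise": ["noise_a", "noise_b"],
    },
    removal_method="median",
    scoring="r2",
)

print(f"baseline r2: {result['baseline_score']:.3f}")
for name, drop in zip(result["feature_names"], result["importances"]):
    print(f"{name}: r2 drop = {drop:.3f}")
```

Neutralizing a whole group in one pass is what the per-feature loop cannot express, and it is why the docstring recommends groups for one-hot encoded categoricals and other related-column sets.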