ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
"""Classical feature importance: Permutation (PFI) and Mean Decrease Impurity (MDI).
|
|
2
|
+
|
|
3
|
+
This module provides model-agnostic permutation importance and tree-based MDI
|
|
4
|
+
importance calculations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import polars as pl
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from numpy.typing import NDArray
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def compute_permutation_importance(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
    feature_names: list[str] | None = None,
    scoring: str | Callable | None = None,
    n_repeats: int = 10,
    random_state: int | None = 42,
    n_jobs: int | None = None,
) -> dict[str, Any]:
    """Compute Permutation Feature Importance (PFI) for model-agnostic feature ranking.

    Permutation importance measures the drop in model score when a single
    feature's values are randomly shuffled: features whose permutation hurts
    the score most are the ones the model relies on. Unlike tree-specific
    importances (MDI), this works with any fitted estimator and is not
    inflated by high-cardinality features.

    Parameters
    ----------
    model : Any
        Fitted sklearn-compatible estimator (must support ``predict`` or
        ``predict_proba`` and, when ``scoring`` is None, ``score``).
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix of shape (n_samples, n_features).
    y : Union[pl.Series, pd.Series, np.ndarray]
        Target values of shape (n_samples,).
    feature_names : list[str] | None, default None
        Feature labels. If None, column names are taken from the DataFrame,
        or ``feature_0 .. feature_{k-1}`` are generated for arrays.
    scoring : str | Callable | None, default None
        Scorer used to evaluate the model. Common options:
        - Classification: 'accuracy', 'roc_auc', 'f1'
        - Regression: 'r2', 'neg_mean_squared_error', 'neg_mean_absolute_error'
        None uses the model's default ``score`` method.
    n_repeats : int, default 10
        Number of permutation rounds per feature; more repeats give more
        stable estimates.
    random_state : int | None, default 42
        Random seed for reproducible shuffles.
    n_jobs : int | None, default None
        Number of parallel jobs (-1 for all CPUs).

    Returns
    -------
    dict[str, Any]
        Results sorted by descending mean importance:
        - importances_mean: Mean importance per feature
        - importances_std: Standard deviation of importance per feature
        - importances_raw: All permutation results (n_features, n_repeats)
        - feature_names: Feature labels
        - baseline_score: Model score before permutation
        - n_repeats: Number of permutation rounds
        - scoring: Scoring function used
        - n_features: Number of features

    Raises
    ------
    ValueError
        If ``feature_names`` length does not match the number of columns in X.

    Notes
    -----
    **Interpretation**:
    - Importance ~0: feature not useful
    - Importance > 0: feature contributes to predictions
    - Importance < 0: permuting the feature *helps*, which often signals
      overfitting or model instability

    **Best practices**:
    - Use a hold-out validation set (not training data) for unbiased estimates
    - Increase n_repeats (20-30) for more stable results
    - Compare with other importance methods (SHAP, MDI) for robustness

    **Cost**: O(n_features * n_repeats * prediction_time); use n_jobs=-1
    for parallel computation on large datasets.

    References
    ----------
    .. [BRE] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001.
    """
    from sklearn.inspection import permutation_importance as sklearn_pfi

    # Normalise the feature matrix to numpy, resolving labels from DataFrame
    # columns when the caller did not supply any.
    X_array: NDArray[Any]
    if isinstance(X, pl.DataFrame):
        names = list(X.columns) if feature_names is None else list(feature_names)
        X_array = X.to_numpy()
    elif isinstance(X, pd.DataFrame):
        names = X.columns.tolist() if feature_names is None else list(feature_names)
        X_array = X.to_numpy()
    else:
        X_array = np.asarray(X)
        names = (
            [f"feature_{i}" for i in range(X_array.shape[1])]
            if feature_names is None
            else list(feature_names)
        )

    # Fail fast on a label/column mismatch (mirrors compute_mdi_importance);
    # previously a mismatch surfaced later as a confusing IndexError.
    if len(names) != X_array.shape[1]:
        raise ValueError(
            f"Number of feature names ({len(names)}) does not match number of features ({X_array.shape[1]})"
        )

    # Both polars and pandas Series expose to_numpy().
    y_array: NDArray[Any]
    if isinstance(y, (pl.Series, pd.Series)):
        y_array = y.to_numpy()
    else:
        y_array = np.asarray(y)

    # Baseline score before any permutation, using the same scorer that
    # sklearn's permutation_importance will use internally.
    if scoring is None:
        baseline_score = model.score(X_array, y_array)
    else:
        from sklearn.metrics import get_scorer

        scorer = get_scorer(scoring) if isinstance(scoring, str) else scoring
        baseline_score = scorer(model, X_array, y_array)

    result = sklearn_pfi(
        estimator=model,
        X=X_array,
        y=y_array,
        scoring=scoring,
        n_repeats=n_repeats,
        random_state=random_state,
        n_jobs=n_jobs,
    )

    # Sort features by mean importance, descending.
    sorted_idx = np.argsort(result.importances_mean)[::-1]

    return {
        "importances_mean": result.importances_mean[sorted_idx],
        "importances_std": result.importances_std[sorted_idx],
        # Raw permutation results, shape (n_features, n_repeats).
        "importances_raw": result.importances[sorted_idx],
        "feature_names": [names[i] for i in sorted_idx],
        "baseline_score": float(baseline_score),
        "n_repeats": n_repeats,
        "scoring": scoring if scoring is not None else "default",
        "n_features": len(names),
    }
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def compute_mdi_importance(
|
|
198
|
+
model: Any,
|
|
199
|
+
feature_names: list[str] | None = None,
|
|
200
|
+
normalize: bool = True,
|
|
201
|
+
) -> dict[str, Any]:
|
|
202
|
+
"""Compute Mean Decrease in Impurity (MDI) feature importance from tree-based models.
|
|
203
|
+
|
|
204
|
+
MDI measures how much each feature contributes to decreasing the weighted
|
|
205
|
+
impurity (Gini for classification, MSE/MAE for regression) across all trees.
|
|
206
|
+
This is computed during model training and is available via the model's
|
|
207
|
+
`feature_importances_` attribute.
|
|
208
|
+
|
|
209
|
+
**Supported Models**:
|
|
210
|
+
- LightGBM: `lightgbm.LGBMClassifier`, `lightgbm.LGBMRegressor` (recommended)
|
|
211
|
+
- XGBoost: `xgboost.XGBClassifier`, `xgboost.XGBRegressor` (recommended)
|
|
212
|
+
- sklearn: `RandomForestClassifier`, `RandomForestRegressor` (not recommended - slow)
|
|
213
|
+
- sklearn: `GradientBoostingClassifier`, `GradientBoostingRegressor` (not recommended - slow)
|
|
214
|
+
|
|
215
|
+
**Not supported**:
|
|
216
|
+
- sklearn's HistGradientBoosting* (doesn't expose feature_importances_)
|
|
217
|
+
|
|
218
|
+
Parameters
|
|
219
|
+
----------
|
|
220
|
+
model : Any
|
|
221
|
+
Fitted tree-based model with `feature_importances_` attribute.
|
|
222
|
+
Must be one of: LightGBM, XGBoost, or sklearn tree ensembles.
|
|
223
|
+
feature_names : list[str] | None, default None
|
|
224
|
+
Feature names for labeling. If None, uses feature names from model
|
|
225
|
+
or generates numeric names.
|
|
226
|
+
normalize : bool, default True
|
|
227
|
+
If True, ensures importances sum to 1.0 (some models already normalize).
|
|
228
|
+
|
|
229
|
+
Returns
|
|
230
|
+
-------
|
|
231
|
+
dict[str, Any]
|
|
232
|
+
Dictionary with MDI importance results:
|
|
233
|
+
- importances: Feature importance values (sorted descending)
|
|
234
|
+
- feature_names: Feature labels (sorted by importance)
|
|
235
|
+
- n_features: Number of features
|
|
236
|
+
- normalized: Whether values sum to 1.0
|
|
237
|
+
- model_type: Type of model used
|
|
238
|
+
|
|
239
|
+
Raises
|
|
240
|
+
------
|
|
241
|
+
AttributeError
|
|
242
|
+
If model doesn't have `feature_importances_` attribute
|
|
243
|
+
ImportError
|
|
244
|
+
If LightGBM/XGBoost not installed and trying to use those models
|
|
245
|
+
|
|
246
|
+
Examples
|
|
247
|
+
--------
|
|
248
|
+
>>> import lightgbm as lgb
|
|
249
|
+
>>> from sklearn.datasets import make_classification
|
|
250
|
+
>>>
|
|
251
|
+
>>> # Train LightGBM model
|
|
252
|
+
>>> X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
|
|
253
|
+
>>> model = lgb.LGBMClassifier(n_estimators=100, random_state=42)
|
|
254
|
+
>>> model.fit(X, y)
|
|
255
|
+
>>>
|
|
256
|
+
>>> # Extract MDI importance
|
|
257
|
+
>>> mdi = compute_mdi_importance(
|
|
258
|
+
... model=model,
|
|
259
|
+
... feature_names=[f'feature_{i}' for i in range(10)]
|
|
260
|
+
... )
|
|
261
|
+
>>>
|
|
262
|
+
>>> # Examine results
|
|
263
|
+
>>> print(f"Most important feature: {mdi['feature_names'][0]}")
|
|
264
|
+
>>> print(f"Importance: {mdi['importances'][0]:.3f}")
|
|
265
|
+
>>> print(f"Model type: {mdi['model_type']}")
|
|
266
|
+
Most important feature: feature_3
|
|
267
|
+
Importance: 0.245
|
|
268
|
+
Model type: lightgbm.LGBMClassifier
|
|
269
|
+
|
|
270
|
+
Notes
|
|
271
|
+
-----
|
|
272
|
+
**MDI vs PFI** (Permutation Feature Importance):
|
|
273
|
+
|
|
274
|
+
**MDI Advantages**:
|
|
275
|
+
- Very fast: Computed during training (no additional overhead)
|
|
276
|
+
- No additional data required
|
|
277
|
+
- Deterministic: Same result every time
|
|
278
|
+
|
|
279
|
+
**MDI Disadvantages**:
|
|
280
|
+
- **Biased toward high-cardinality features**: Features with many unique values
|
|
281
|
+
get inflated importance even if not truly predictive
|
|
282
|
+
- **Only for tree-based models**: Not model-agnostic
|
|
283
|
+
- **Train set importance**: May not reflect test set predictive power
|
|
284
|
+
- **Correlated features**: Can split importance between correlated predictors
|
|
285
|
+
|
|
286
|
+
**When to use MDI**:
|
|
287
|
+
- Quick exploratory analysis
|
|
288
|
+
- When computational budget is limited
|
|
289
|
+
- When working with tree-based models exclusively
|
|
290
|
+
|
|
291
|
+
**When to use PFI instead**:
|
|
292
|
+
- Need unbiased importance estimates
|
|
293
|
+
- Have high-cardinality categorical features
|
|
294
|
+
- Want model-agnostic importance
|
|
295
|
+
- Need to validate importance on test set
|
|
296
|
+
|
|
297
|
+
**Comparison workflow**:
|
|
298
|
+
>>> # Compare MDI and PFI
|
|
299
|
+
>>> mdi = compute_mdi_importance(model, feature_names=features)
|
|
300
|
+
>>> pfi = compute_permutation_importance(model, X_test, y_test, feature_names=features)
|
|
301
|
+
>>>
|
|
302
|
+
>>> # Large discrepancies may indicate:
|
|
303
|
+
>>> # - High-cardinality bias in MDI
|
|
304
|
+
>>> # - Correlated features splitting importance
|
|
305
|
+
>>> # - Overfitting (high MDI, low PFI)
|
|
306
|
+
|
|
307
|
+
**Performance notes**:
|
|
308
|
+
- LightGBM and XGBoost: Production-ready speed and accuracy (RECOMMENDED)
|
|
309
|
+
- sklearn RandomForest/GradientBoosting: 10-100x slower, avoid for large datasets
|
|
310
|
+
- sklearn HistGradientBoosting: Fast but doesn't expose feature_importances_ (use PFI instead)
|
|
311
|
+
|
|
312
|
+
References
|
|
313
|
+
----------
|
|
314
|
+
- Breiman, L. (2001). "Random Forests". Machine Learning.
|
|
315
|
+
- Louppe, G. et al. (2013). "Understanding variable importances in forests of
|
|
316
|
+
randomized trees". NeurIPS.
|
|
317
|
+
- Strobl, C. et al. (2007). "Bias in random forest variable importance measures".
|
|
318
|
+
BMC Bioinformatics.
|
|
319
|
+
"""
|
|
320
|
+
# Check if model has feature_importances_
|
|
321
|
+
if not hasattr(model, "feature_importances_"):
|
|
322
|
+
raise AttributeError(
|
|
323
|
+
f"Model of type {type(model).__name__} does not have 'feature_importances_' attribute. "
|
|
324
|
+
"MDI is only available for tree-based models (LightGBM, XGBoost, sklearn tree ensembles)."
|
|
325
|
+
)
|
|
326
|
+
|
|
327
|
+
# Extract raw importances
|
|
328
|
+
importances = model.feature_importances_
|
|
329
|
+
|
|
330
|
+
# Get feature names
|
|
331
|
+
if feature_names is None:
|
|
332
|
+
# Try to get from model
|
|
333
|
+
if hasattr(model, "feature_name_"):
|
|
334
|
+
# LightGBM
|
|
335
|
+
feature_names = model.feature_name_
|
|
336
|
+
elif hasattr(model, "get_booster") and hasattr(model.get_booster(), "feature_names"):
|
|
337
|
+
# XGBoost
|
|
338
|
+
feature_names = model.get_booster().feature_names
|
|
339
|
+
elif hasattr(model, "feature_names_in_"):
|
|
340
|
+
# sklearn
|
|
341
|
+
feature_names = list(model.feature_names_in_)
|
|
342
|
+
else:
|
|
343
|
+
# Fallback to numeric names
|
|
344
|
+
feature_names = [f"feature_{i}" for i in range(len(importances))]
|
|
345
|
+
else:
|
|
346
|
+
feature_names = list(feature_names)
|
|
347
|
+
|
|
348
|
+
# Validate length match
|
|
349
|
+
if len(feature_names) != len(importances):
|
|
350
|
+
raise ValueError(
|
|
351
|
+
f"Number of feature names ({len(feature_names)}) does not match number of importances ({len(importances)})"
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
# Normalize if requested
|
|
355
|
+
if normalize:
|
|
356
|
+
importance_sum = importances.sum()
|
|
357
|
+
if importance_sum > 0:
|
|
358
|
+
importances = importances / importance_sum
|
|
359
|
+
else:
|
|
360
|
+
# All zeros - already normalized
|
|
361
|
+
pass
|
|
362
|
+
|
|
363
|
+
# Sort by importance (descending)
|
|
364
|
+
sorted_idx = np.argsort(importances)[::-1]
|
|
365
|
+
|
|
366
|
+
# Determine model type
|
|
367
|
+
model_type = f"{type(model).__module__}.{type(model).__name__}"
|
|
368
|
+
|
|
369
|
+
return {
|
|
370
|
+
"importances": importances[sorted_idx],
|
|
371
|
+
"feature_names": [feature_names[i] for i in sorted_idx],
|
|
372
|
+
"n_features": len(feature_names),
|
|
373
|
+
"normalized": normalize,
|
|
374
|
+
"model_type": model_type,
|
|
375
|
+
}
|