ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
"""Comprehensive ML feature importance analysis comparing multiple methods.
|
|
2
|
+
|
|
3
|
+
This module provides a tear sheet function that runs MDI, PFI, MDA, and SHAP
|
|
4
|
+
importance methods and generates a comparison report with consensus ranking.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import polars as pl
|
|
13
|
+
from scipy.stats import spearmanr
|
|
14
|
+
|
|
15
|
+
from ml4t.diagnostic.evaluation.metrics.importance_classical import (
|
|
16
|
+
compute_mdi_importance,
|
|
17
|
+
compute_permutation_importance,
|
|
18
|
+
)
|
|
19
|
+
from ml4t.diagnostic.evaluation.metrics.importance_mda import compute_mda_importance
|
|
20
|
+
from ml4t.diagnostic.evaluation.metrics.importance_shap import compute_shap_importance
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from numpy.typing import NDArray
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _generate_ml_importance_interpretation(
|
|
27
|
+
top_features: list[str],
|
|
28
|
+
method_agreement: dict[str, float],
|
|
29
|
+
warnings: list[str],
|
|
30
|
+
n_consensus: int,
|
|
31
|
+
) -> str:
|
|
32
|
+
"""Generate human-readable interpretation of ML importance analysis.
|
|
33
|
+
|
|
34
|
+
Parameters
|
|
35
|
+
----------
|
|
36
|
+
top_features : list[str]
|
|
37
|
+
Top features from consensus ranking
|
|
38
|
+
method_agreement : dict[str, float]
|
|
39
|
+
Pairwise correlations between methods
|
|
40
|
+
warnings : list[str]
|
|
41
|
+
List of potential issues detected
|
|
42
|
+
n_consensus : int
|
|
43
|
+
Number of features in top 10 across all methods
|
|
44
|
+
|
|
45
|
+
Returns
|
|
46
|
+
-------
|
|
47
|
+
str
|
|
48
|
+
Human-readable interpretation summary
|
|
49
|
+
"""
|
|
50
|
+
lines = []
|
|
51
|
+
|
|
52
|
+
# Consensus features
|
|
53
|
+
if n_consensus > 0:
|
|
54
|
+
lines.append(f"Strong consensus: {n_consensus} features rank in top 10 across all methods")
|
|
55
|
+
lines.append(f" Top consensus features: {', '.join(top_features[:5])}")
|
|
56
|
+
else:
|
|
57
|
+
lines.append("Weak consensus: Different methods identify different important features")
|
|
58
|
+
|
|
59
|
+
# Method agreement
|
|
60
|
+
if method_agreement:
|
|
61
|
+
avg_agreement = float(np.mean(list(method_agreement.values())))
|
|
62
|
+
if avg_agreement > 0.7:
|
|
63
|
+
lines.append(f"High agreement between methods (avg correlation: {avg_agreement:.2f})")
|
|
64
|
+
elif avg_agreement > 0.5:
|
|
65
|
+
lines.append(
|
|
66
|
+
f"Moderate agreement between methods (avg correlation: {avg_agreement:.2f})"
|
|
67
|
+
)
|
|
68
|
+
else:
|
|
69
|
+
lines.append(
|
|
70
|
+
f"Low agreement between methods (avg correlation: {avg_agreement:.2f}) - investigate further"
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
# Warnings
|
|
74
|
+
if warnings:
|
|
75
|
+
lines.append("\nPotential Issues:")
|
|
76
|
+
for warning in warnings:
|
|
77
|
+
lines.append(f" - {warning}")
|
|
78
|
+
|
|
79
|
+
return "\n".join(lines)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def analyze_ml_importance(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
    feature_names: list[str] | None = None,
    methods: list[str] | None = None,
    scoring: str | Callable | None = None,
    n_repeats: int = 10,
    random_state: int | None = 42,
) -> dict[str, Any]:
    """Run multiple feature-importance methods and report a consensus view.

    **This is a TEAR SHEET function** - it executes the requested importance
    methods (MDI, PFI, MDA, SHAP), ranks features under each, and combines
    the rankings into a consensus ordering with agreement diagnostics and an
    auto-generated interpretation.

    **Use Case**: "Which features does my model rely on? Do different methods agree?"

    Each method measures a different aspect of importance:

    - **MDI** (Mean Decrease Impurity): fast, but biased toward high-cardinality features
    - **PFI** (Permutation): unbiased, measures predictive importance
    - **MDA** (Mean Decrease Accuracy): like PFI but removes features completely
    - **SHAP**: game-theoretic attribution via TreeExplainer

    Strong cross-method consensus indicates robust importance; disagreement
    suggests model-specific artifacts or feature interactions.

    Parameters
    ----------
    model : Any
        Fitted model. Requirements vary by method:
        - MDI: must expose `feature_importances_` (tree-based models)
        - PFI, MDA: must expose `predict()` or `score()`
        - SHAP: must be compatible with TreeExplainer
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix (n_samples, n_features)
    y : Union[pl.Series, pd.Series, np.ndarray]
        Target values (n_samples,)
    feature_names : list[str] | None, default None
        Labels for features. When None, taken from DataFrame columns or
        generated as "f0", "f1", ...
    methods : list[str] | None, default ["mdi", "pfi", "shap"]
        Methods to run. Options: "mdi", "pfi", "mda", "shap"
    scoring : str | Callable | None, default None
        Scoring metric forwarded to PFI and MDA
    n_repeats : int, default 10
        Number of permutations for PFI
    random_state : int | None, default 42
        Random seed for reproducibility

    Returns
    -------
    dict[str, Any]
        - method_results: per-method raw outputs
        - consensus_ranking: features ordered by mean rank across methods
        - method_agreement: pairwise Spearman correlations of rankings
        - top_features_consensus: features in the top 10 of ALL methods
        - warnings: detected issues
        - interpretation: auto-generated summary text
        - methods_run: methods that succeeded
        - methods_failed: (method, error message) pairs for failures

    Raises
    ------
    ValueError
        If ``methods`` is empty, or every requested method fails

    Examples
    --------
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> from sklearn.datasets import make_classification
    >>> X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
    >>> model = RandomForestClassifier(n_estimators=50, random_state=42)
    >>> model.fit(X, y)
    >>> result = analyze_ml_importance(model, X, y, methods=["mdi", "pfi"])
    >>> print(result["interpretation"])
    """
    if methods is None:
        methods = ["mdi", "pfi", "shap"]
    if not methods:
        raise ValueError("At least one method must be specified")

    # Derive canonical feature labels when the caller supplied none.
    if feature_names is None:
        if isinstance(X, (pl.DataFrame, pd.DataFrame)):
            feature_names = list(X.columns)
        else:
            # Fall back to synthetic names sized from the array-like input.
            width = X.shape[1] if hasattr(X, "shape") else len(X[0])
            feature_names = [f"f{i}" for i in range(width)]

    # Step 1: run each requested method, recording failures instead of
    # aborting (methods may depend on optional libraries or model traits).
    per_method: dict[str, Any] = {}
    failures: list[tuple[str, str]] = []

    if "mdi" in methods:
        try:
            per_method["mdi"] = compute_mdi_importance(model, feature_names=feature_names)
        except Exception as exc:
            failures.append(("mdi", str(exc)))

    if "pfi" in methods:
        try:
            per_method["pfi"] = compute_permutation_importance(
                model,
                X,
                y,
                feature_names=feature_names,
                scoring=scoring,
                n_repeats=n_repeats,
                random_state=random_state,
            )
        except Exception as exc:
            failures.append(("pfi", str(exc)))

    if "mda" in methods:
        try:
            per_method["mda"] = compute_mda_importance(
                model, X, y, feature_names=feature_names, scoring=scoring
            )
        except Exception as exc:
            failures.append(("mda", str(exc)))

    if "shap" in methods:
        try:
            per_method["shap"] = compute_shap_importance(model, X, feature_names=feature_names)
        except ImportError:
            # Missing optional dependency gets a targeted install hint.
            failures.append(
                (
                    "shap",
                    "shap library not installed. Install with: pip install ml4t-diagnostic[ml]",
                )
            )
        except Exception as exc:
            failures.append(("shap", str(exc)))

    # Nothing succeeded: surface every failure in one error.
    if not per_method:
        detail = "\n".join(f"  - {method}: {error}" for method, error in failures)
        raise ValueError("All methods failed:\n" + detail)

    # Step 2: convert each method's importances into ranks over the canonical
    # feature list (rank 0 = most important; unknown features score 0.0).
    rank_by_method: dict[str, Any] = {}
    for name, outcome in per_method.items():
        if name == "pfi":
            raw_importances = outcome["importances_mean"]
        elif name in ("shap", "mdi", "mda"):
            raw_importances = outcome["importances"]
        else:
            # Unknown method key; skip defensively.
            continue

        lookup = dict(zip(outcome["feature_names"], raw_importances, strict=False))
        values = np.array([lookup.get(label, 0.0) for label in feature_names])
        # Double argsort yields each feature's position in descending order.
        rank_by_method[name] = np.argsort(np.argsort(values)[::-1])

    # Average rank per feature, then order features by that average.
    mean_rank = np.mean(list(rank_by_method.values()), axis=0)
    consensus_ranking = [feature_names[i] for i in np.argsort(mean_rank)]

    # Step 3: pairwise Spearman correlation between method rankings.
    method_agreement: dict[str, float] = {}
    ordered_methods = list(rank_by_method.keys())
    for idx, first in enumerate(ordered_methods):
        for second in ordered_methods[idx + 1 :]:
            rho, _ = spearmanr(rank_by_method[first], rank_by_method[second])
            method_agreement[f"{first}_vs_{second}"] = float(rho)

    # Step 4: features appearing in every method's top 10 (assumes each
    # method's feature_names are already sorted by its own importance).
    top_sets = {name: set(outcome["feature_names"][:10]) for name, outcome in per_method.items()}
    consensus_top = set.intersection(*top_sets.values()) if top_sets else set()

    # Step 5: assemble warnings.
    warnings: list[str] = []

    # High MDI rank without a matching PFI rank hints at impurity bias.
    if "mdi" in per_method and "pfi" in per_method:
        mdi_leaders = set(per_method["mdi"]["feature_names"][:5])
        pfi_leaders = set(per_method["pfi"]["feature_names"][:5])
        disagreement = mdi_leaders - pfi_leaders
        if disagreement:
            warnings.append(
                f"Features {disagreement} rank high in MDI but not PFI - possible overfitting to tree structure"
            )

    if method_agreement:
        weakest = min(method_agreement.values())
        if weakest < 0.5:
            warnings.append(
                f"Low agreement between methods (min correlation: {weakest:.2f}) - results may be unreliable"
            )

    warnings.extend(f"Method '{method}' failed: {error}" for method, error in failures)

    # Step 6: human-readable summary.
    interpretation = _generate_ml_importance_interpretation(
        consensus_ranking[:10],
        method_agreement,
        warnings,
        len(consensus_top),
    )

    return {
        "method_results": per_method,
        "consensus_ranking": consensus_ranking,
        "method_agreement": method_agreement,
        "top_features_consensus": list(consensus_top),
        "warnings": warnings,
        "interpretation": interpretation,
        "methods_run": list(per_method.keys()),
        "methods_failed": failures,
    }
|