ml4t_diagnostic-0.1.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,715 @@
"""SHAP-based feature importance with multi-explainer support.

This module provides SHAP value computation with automatic explainer selection
for tree-based, linear, and model-agnostic approaches.
"""

import warnings
from typing import TYPE_CHECKING, Any, Union

import numpy as np
import pandas as pd
import polars as pl

if TYPE_CHECKING:
    from numpy.typing import NDArray


def _detect_gpu_available() -> bool:
    """Detect if GPU acceleration is available for SHAP computations.

    GPU acceleration is currently supported only for TreeExplainer with:
    - NVIDIA GPU
    - CUDA 11.0+
    - cupy library installed

    Returns
    -------
    bool
        True if GPU is available and cupy is installed, False otherwise

    Notes
    -----
    This function checks for cupy availability as a proxy for GPU support.
    Even if a GPU is present, cupy must be installed for SHAP to use it.

    GPU acceleration provides 10-100x speedup for large datasets (>5K samples)
    but has overhead that makes it slower for small datasets (<5K samples).
    """
    try:
        import cupy as cp

        # Check if GPU is actually accessible
        _ = cp.cuda.Device(0)
        return True
    except (ImportError, RuntimeError):
        # ImportError: cupy not installed
        # RuntimeError: CUDA not available or no GPU found
        return False


def _get_explainer(
    model: Any,
    X_array: "NDArray[Any]",
    explainer_type: str = "auto",
    use_gpu: bool | str = "auto",
    background_data: Union["NDArray[Any]", None] = None,
    **explainer_kwargs: Any,
) -> tuple[Any, str, float]:
    """Select and create appropriate SHAP explainer for the given model.

    Implements automatic explainer selection with try-except cascade:
    1. TreeExplainer (fast, exact, tree models only)
    2. LinearExplainer (fast, exact, linear models only)
    3. KernelExplainer (slow, approximate, model-agnostic fallback)

    DeepExplainer is NOT included in auto-selection because it requires
    explicit background data specification. Use explainer_type="deep" explicitly.

    Parameters
    ----------
    model : Any
        Fitted model to explain
    X_array : np.ndarray
        Feature matrix for SHAP computation
    explainer_type : str, default "auto"
        Explainer type to use:
        - "auto": Try tree -> linear -> kernel (recommended)
        - "tree": TreeExplainer (tree models only)
        - "linear": LinearExplainer (linear models only)
        - "deep": DeepExplainer (neural networks, requires background_data)
        - "kernel": KernelExplainer (model-agnostic, slow)
    use_gpu : bool | str, default "auto"
        GPU acceleration mode (TreeExplainer only):
        - "auto": Use GPU if available and dataset large enough (>5K samples)
        - True: Force GPU usage (raises error if unavailable)
        - False: Force CPU usage
    background_data : np.ndarray | None, default None
        Background dataset for explainers that need it (Kernel, Deep).
        If None, will be auto-sampled from X_array for Kernel.
        Required for Deep explainer.
    **explainer_kwargs : Any
        Additional keyword arguments passed to explainer constructor

    Returns
    -------
    tuple[Any, str, float]
        - explainer: Initialized SHAP explainer instance
        - type_name: Name of explainer type used ("tree", "linear", "kernel", "deep")
        - ms_per_sample: Estimated milliseconds per sample for performance warnings

    Raises
    ------
    ImportError
        If shap library not installed
    ValueError
        If explainer_type is invalid or if auto-selection fails for all explainers
    RuntimeError
        If GPU requested but unavailable
    """
    try:
        import shap
    except ImportError as e:
        raise ImportError(
            "SHAP library is not installed. Install with: pip install ml4t-diagnostic[ml] or: pip install shap>=0.41.0"
        ) from e

    # Validate explainer_type
    valid_types = {"auto", "tree", "linear", "deep", "kernel"}
    if explainer_type not in valid_types:
        raise ValueError(
            f"Invalid explainer_type '{explainer_type}'. Must be one of: {', '.join(sorted(valid_types))}"
        )

    # Handle GPU detection and configuration
    gpu_available = _detect_gpu_available()
    use_gpu_final = False

    if use_gpu == "auto":
        # Auto-detect: Use GPU if available AND dataset large enough
        n_samples = X_array.shape[0]
        use_gpu_final = gpu_available and n_samples >= 5000
    elif use_gpu is True:
        if not gpu_available:
            raise RuntimeError(
                "GPU requested (use_gpu=True) but GPU not available. "
                "Ensure NVIDIA GPU, CUDA 11.0+, and cupy are installed. "
                "Install with: pip install ml4t-diagnostic[gpu]"
            )
        use_gpu_final = True
    else:  # use_gpu is False
        use_gpu_final = False

    # Explicit explainer type requested
    if explainer_type != "auto":
        return _create_explainer_by_type(
            explainer_type=explainer_type,
            model=model,
            X_array=X_array,
            use_gpu=use_gpu_final,
            background_data=background_data,
            shap=shap,
            **explainer_kwargs,
        )

    # Auto-selection cascade: Tree -> Linear -> Kernel
    errors = []

    # Try TreeExplainer first (fastest, most common)
    try:
        tree_kwargs = {"feature_perturbation": "tree_path_dependent"}
        tree_kwargs.update(explainer_kwargs)  # User kwargs override defaults

        explainer = shap.TreeExplainer(model, **tree_kwargs)
        # GPU mode only for tree explainer
        if use_gpu_final and hasattr(explainer, "gpu"):
            setattr(explainer, "gpu", True)  # noqa: B010
        ms_per_sample = 5.0  # ~1-10ms typical
        return (explainer, "tree", ms_per_sample)
    except Exception as e:
        errors.append(f"TreeExplainer: {e}")

    # Try LinearExplainer second (fast, exact for linear models)
    try:
        explainer = shap.LinearExplainer(model, X_array, **explainer_kwargs)
        ms_per_sample = 75.0  # ~50-100ms typical
        return (explainer, "linear", ms_per_sample)
    except Exception as e:
        errors.append(f"LinearExplainer: {e}")

    # Try KernelExplainer as fallback (slow but model-agnostic)
    try:
        # Sample background data if not provided
        if background_data is None:
            background_data = _sample_background(X_array, max_samples=100, method="random")

        # Create prediction function wrapper to avoid LightGBM property issues
        if hasattr(model, "predict_proba"):
            # For binary classification, return probability of positive class
            def predict_fn(X):
                proba = model.predict_proba(X)
                if proba.shape[1] == 2:
                    return proba[:, 1]  # Binary: positive class
                return proba  # Multiclass: all classes
        else:
            predict_fn = model.predict

        explainer = shap.KernelExplainer(predict_fn, background_data, **explainer_kwargs)
        ms_per_sample = 5000.0  # ~1-10 seconds typical
        return (explainer, "kernel", ms_per_sample)
    except Exception as e:
        errors.append(f"KernelExplainer: {e}")

    # All explainers failed
    error_summary = "\n - ".join(errors)
    raise ValueError(
        f"Failed to create explainer for model type {type(model).__name__}. "
        f"Tried tree, linear, and kernel explainers. Errors:\n - {error_summary}\n"
        f"Consider using explainer_type='kernel' explicitly with custom background_data."
    )


def _create_explainer_by_type(
    explainer_type: str,
    model: Any,
    X_array: "NDArray[Any]",
    use_gpu: bool,
    background_data: Union["NDArray[Any]", None],
    shap: Any,
    **explainer_kwargs: Any,
) -> tuple[Any, str, float]:
    """Create specific explainer type (helper for _get_explainer).

    Parameters
    ----------
    explainer_type : str
        One of: "tree", "linear", "deep", "kernel"
    model : Any
        Fitted model
    X_array : np.ndarray
        Feature matrix
    use_gpu : bool
        Whether to use GPU (tree only)
    background_data : np.ndarray | None
        Background data for kernel/deep explainers
    shap : module
        Imported shap module
    **explainer_kwargs : Any
        Additional explainer arguments

    Returns
    -------
    tuple[Any, str, float]
        (explainer, type_name, ms_per_sample)

    Raises
    ------
    ValueError
        If explainer creation fails
    ImportError
        If deep learning dependencies not available
    """
    try:
        if explainer_type == "tree":
            # Set default feature_perturbation unless user overrides
            tree_kwargs = {"feature_perturbation": "tree_path_dependent"}
            tree_kwargs.update(explainer_kwargs)  # User kwargs override defaults

            explainer = shap.TreeExplainer(model, **tree_kwargs)
            if use_gpu and hasattr(explainer, "gpu"):
                explainer.gpu = True
            ms_per_sample = 5.0
            return (explainer, "tree", ms_per_sample)

        elif explainer_type == "linear":
            explainer = shap.LinearExplainer(model, X_array, **explainer_kwargs)
            ms_per_sample = 75.0
            return (explainer, "linear", ms_per_sample)

        elif explainer_type == "deep":
            if background_data is None:
                raise ValueError(
                    "DeepExplainer requires background_data parameter. "
                    "Provide a representative sample of your training data "
                    "(typically 100-1000 samples)."
                )
            try:
                explainer = shap.DeepExplainer(model, background_data, **explainer_kwargs)
            except ImportError as e:
                raise ImportError(
                    "DeepExplainer requires deep learning libraries (TensorFlow or PyTorch). "
                    "Install with: pip install ml4t-diagnostic[deep]"
                ) from e
            ms_per_sample = 500.0  # ~100ms-1s typical
            return (explainer, "deep", ms_per_sample)

        elif explainer_type == "kernel":
            if background_data is None:
                background_data = _sample_background(X_array, max_samples=100, method="random")

            # Create prediction function wrapper to avoid LightGBM property issues
            # For classifiers, use predict_proba if available (more informative)
            if hasattr(model, "predict_proba"):
                # For binary classification, return probability of positive class
                def predict_fn(X):
                    proba = model.predict_proba(X)
                    if proba.shape[1] == 2:
                        return proba[:, 1]  # Binary: positive class
                    return proba  # Multiclass: all classes
            else:
                predict_fn = model.predict

            explainer = shap.KernelExplainer(predict_fn, background_data, **explainer_kwargs)
            ms_per_sample = 5000.0
            return (explainer, "kernel", ms_per_sample)

        else:
            raise ValueError(f"Unknown explainer_type: {explainer_type}")

    except Exception as e:
        raise ValueError(
            f"Failed to create {explainer_type.capitalize()}Explainer for model type {type(model).__name__}. Error: {e}"
        ) from e


def _sample_background(
    X_array: "NDArray[Any]", max_samples: int = 100, method: str = "random"
) -> "NDArray[Any]":
    """Sample background dataset for KernelExplainer.

    Background data represents "typical" feature values used as reference
    for computing SHAP values. Smaller backgrounds = faster computation.

    Parameters
    ----------
    X_array : np.ndarray
        Full feature matrix
    max_samples : int, default 100
        Maximum number of background samples
    method : str, default "random"
        Sampling method: "random" or "kmeans"

    Returns
    -------
    np.ndarray
        Background dataset (max_samples, n_features)

    Notes
    -----
    - Random: Fast, simple, works well for most cases
    - K-means: Better representation of data distribution, slower
    """
    n_samples = X_array.shape[0]

    if n_samples <= max_samples:
        return X_array

    if method == "random":
        rng = np.random.default_rng(42)
        idx = rng.choice(n_samples, size=max_samples, replace=False)
        return X_array[idx]
    elif method == "kmeans":
        # K-means clustering for representative samples
        try:
            from sklearn.cluster import KMeans

            kmeans = KMeans(n_clusters=max_samples, random_state=42, n_init=10)
            kmeans.fit(X_array)
            return kmeans.cluster_centers_
        except ImportError:
            # Fallback to random if sklearn not available
            rng = np.random.default_rng(42)
            idx = rng.choice(n_samples, size=max_samples, replace=False)
            return X_array[idx]
    else:
        raise ValueError(f"Unknown sampling method: {method}. Use 'random' or 'kmeans'.")


def _estimate_computation_time(
    explainer_type: str,
    n_samples: int,
    ms_per_sample: float,
    performance_warning: bool = True,
) -> None:
    """Estimate SHAP computation time and issue warnings for slow explainers.

    Warns users before computationally expensive SHAP calculations to prevent
    unexpected long wait times, especially for KernelExplainer.

    Parameters
    ----------
    explainer_type : str
        Type of explainer being used ("tree", "linear", "kernel", "deep")
    n_samples : int
        Number of samples for SHAP computation
    ms_per_sample : float
        Estimated milliseconds per sample for this explainer type
    performance_warning : bool, default True
        Whether to issue performance warnings. Set to False to disable.
    """
    if not performance_warning:
        return

    # Only warn for KernelExplainer (1-10 seconds per sample)
    if explainer_type != "kernel":
        return

    # Compute estimates
    total_seconds = (n_samples * ms_per_sample) / 1000.0
    threshold_seconds = 10.0  # Warn if >10 seconds

    if total_seconds < threshold_seconds:
        return

    # Issue warning with time estimates
    time_str = _format_time(total_seconds)

    # Suggest max_samples=200 as reasonable default
    recommended_samples = 200
    if n_samples > recommended_samples:
        recommended_seconds = (recommended_samples * ms_per_sample) / 1000.0
        recommended_time_str = _format_time(recommended_seconds)

        warnings.warn(
            f"KernelExplainer is slow (~{int(ms_per_sample)}ms per sample).\n"
            f"Estimated time: ~{time_str} for {n_samples} samples.\n"
            f"Consider using max_samples={recommended_samples} "
            f"(estimated time: ~{recommended_time_str}).\n"
            f"Or use explainer_type='tree' or 'linear' for faster computation if model supports it.",
            UserWarning,
            stacklevel=3,
        )
    else:
        warnings.warn(
            f"KernelExplainer is slow (~{int(ms_per_sample)}ms per sample).\n"
            f"Estimated time: ~{time_str} for {n_samples} samples.\n"
            f"Consider using explainer_type='tree' or 'linear' for faster computation if model supports it.",
            UserWarning,
            stacklevel=3,
        )


def _format_time(seconds: float) -> str:
    """Format seconds into human-readable string.

    Parameters
    ----------
    seconds : float
        Time in seconds

    Returns
    -------
    str
        Human-readable time string (e.g., "2 minutes", "1 hour 15 minutes")

    Examples
    --------
    >>> _format_time(45)
    '45 seconds'
    >>> _format_time(120)
    '2 minutes'
    >>> _format_time(3665)
    '1 hour 1 minute'
    """
    if seconds < 60:
        return f"{int(seconds)} seconds"
    elif seconds < 3600:
        minutes = int(seconds / 60)
        return f"{minutes} minute{'s' if minutes != 1 else ''}"
    else:
        hours = int(seconds / 3600)
        remaining_minutes = int((seconds % 3600) / 60)
        if remaining_minutes == 0:
            return f"{hours} hour{'s' if hours != 1 else ''}"
        else:
            return f"{hours} hour{'s' if hours != 1 else ''} {remaining_minutes} minute{'s' if remaining_minutes != 1 else ''}"


def compute_shap_importance(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    feature_names: list[str] | None = None,
    check_additivity: bool = True,
    max_samples: int | None = None,
    explainer_type: str = "auto",
    use_gpu: bool | str = "auto",
    background_data: Union["NDArray[Any]", None] = None,
    explainer_kwargs: dict | None = None,
    show_progress: bool = False,
    performance_warning: bool = True,
) -> dict[str, Any]:
    """Compute SHAP (SHapley Additive exPlanations) values and aggregate to feature importance.

    SHAP values provide a unified measure of feature importance based on Shapley values
    from cooperative game theory. Each feature's contribution to a prediction is
    calculated by considering all possible feature coalitions, satisfying key
    properties like additivity and consistency.

    **Key advantages over MDI and PFI**:

    - **Theoretically sound**: Based on game theory (Shapley values)
    - **Consistent**: Removing a feature always decreases its importance
    - **Local explanations**: Provides per-prediction feature contributions
    - **Interaction-aware**: Accounts for feature interactions naturally
    - **Unbiased**: No bias toward high-cardinality features (unlike MDI)
    - **Model-agnostic**: Works with ANY sklearn-compatible model (v1.1+)

    **Multi-Explainer Support**:

    This function automatically selects the best SHAP explainer for your model:

    - **TreeExplainer**: Fast, exact computation for tree-based models
    - **LinearExplainer**: Fast, exact computation for linear models
    - **KernelExplainer**: Model-agnostic fallback (slower but universal)
    - **DeepExplainer**: Optimized for neural networks (TensorFlow/PyTorch)

    Parameters
    ----------
    model : Any
        Fitted model compatible with SHAP explainers.
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix for SHAP computation (typically test/validation set)
        Shape: (n_samples, n_features)
    feature_names : list[str] | None, default None
        Feature names for labeling. If None, uses column names from DataFrame
        or generates numeric names for arrays
    check_additivity : bool, default True
        Verify that SHAP values sum to model predictions (sanity check).
        Only supported by TreeExplainer. Disable for speed if you trust the
        implementation.
    max_samples : int | None, default None
        Maximum number of samples to compute SHAP values for.
    explainer_type : str, default 'auto'
        SHAP explainer to use:
        - 'auto': Automatic selection (Tree -> Linear -> Kernel cascade)
        - 'tree': Force TreeExplainer
        - 'linear': Force LinearExplainer
        - 'kernel': Force KernelExplainer
        - 'deep': Force DeepExplainer
    use_gpu : Union[bool, str], default 'auto'
        Enable GPU acceleration for SHAP computation
    background_data : np.ndarray | None, default None
        Background dataset for KernelExplainer
    explainer_kwargs : dict | None, default None
        Additional keyword arguments passed to the explainer constructor
    show_progress : bool, default False
        Show progress bar for SHAP computation (requires tqdm)
    performance_warning : bool, default True
        Issue warning if computation will take >10 seconds

    Returns
    -------
    dict[str, Any]
        Dictionary with SHAP importance results:
        - shap_values: SHAP values array, shape (n_samples, n_features)
        - importances: Mean absolute SHAP values per feature (sorted descending)
        - feature_names: Feature labels (sorted by importance)
        - base_value: Expected model output (average prediction)
        - n_features: Number of features
        - n_samples: Number of samples used for SHAP computation
        - model_type: Type of model used
        - explainer_type: Which explainer was used
        - additivity_verified: Whether additivity check passed

    Raises
    ------
    ImportError
        If shap library not installed
    ValueError
        If model is not supported by specified explainer
    RuntimeError
        If SHAP computation fails
    """
    # Check if shap is installed
    try:
        import shap  # noqa: F401 (availability check)
    except ImportError as e:
        raise ImportError(
            "SHAP library is not installed. Install with: pip install ml4t-diagnostic[ml] or: pip install shap>=0.43.0"
        ) from e

    # Convert X to appropriate format
    if isinstance(X, pl.DataFrame):
        X_array = X.to_numpy()
        if feature_names is None:
            feature_names = X.columns
    elif isinstance(X, pd.DataFrame):
        X_array = X.values
        if feature_names is None:
            feature_names = list(X.columns)
    else:
        X_array = np.asarray(X)

    # Validate shape before accessing shape[1]
    if X_array.ndim != 2:
        raise ValueError(f"X must be 2D array, got shape {X_array.shape}")

    # Set default feature names if needed (after shape validation)
    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(X_array.shape[1])]

    # Ensure feature_names is a list
    if feature_names is not None:
        feature_names = list(feature_names)

    n_samples_full, n_features = X_array.shape

    # Subsample if requested
    if max_samples is not None and n_samples_full > max_samples:
        # Use random sampling for representative subset
        rng = np.random.default_rng(42)
        sample_idx = rng.choice(n_samples_full, size=max_samples, replace=False)
        X_array = X_array[sample_idx]
        n_samples = max_samples
    else:
        n_samples = n_samples_full

    # Validate feature names length
    if len(feature_names) != n_features:
        raise ValueError(
            f"Number of feature names ({len(feature_names)}) does not match number of features in X ({n_features})"
        )

    # Get appropriate explainer (auto-selects or uses explicit type)
    if explainer_kwargs is None:
        explainer_kwargs = {}

    explainer, explainer_type_used, ms_per_sample = _get_explainer(
        model=model,
        X_array=X_array,
        explainer_type=explainer_type,
        use_gpu=use_gpu,
        background_data=background_data,
        **explainer_kwargs,
    )

    # Issue performance warning if needed
    _estimate_computation_time(
        explainer_type=explainer_type_used,
        n_samples=n_samples,
        ms_per_sample=ms_per_sample,
        performance_warning=performance_warning,
    )

    # Compute SHAP values with optional progress bar
    try:
        # Only TreeExplainer supports check_additivity parameter
        shap_kwargs = {}
        if explainer_type_used == "tree":
            shap_kwargs["check_additivity"] = check_additivity

        if show_progress:
            try:
                from tqdm.auto import tqdm

                # Wrap computation with progress bar for slow explainers
                if explainer_type_used == "kernel":
                    # For kernel, show progress
                    with tqdm(total=n_samples, desc="Computing SHAP values") as pbar:
                        shap_values_raw = explainer.shap_values(X_array, **shap_kwargs)
                        pbar.update(n_samples)
                else:
                    # For tree/linear/deep, just compute (fast enough)
                    shap_values_raw = explainer.shap_values(X_array, **shap_kwargs)
            except ImportError:
                # tqdm not available, compute without progress bar
                shap_values_raw = explainer.shap_values(X_array, **shap_kwargs)
        else:
            shap_values_raw = explainer.shap_values(X_array, **shap_kwargs)
    except Exception as e:
        raise RuntimeError(
            f"Failed to compute SHAP values with {explainer_type_used}Explainer. "
            f"Model type: {type(model).__name__}. Error: {e}"
        ) from e

    # Handle binary classification (returns list of arrays OR 3D array)
    if isinstance(shap_values_raw, list):
        if len(shap_values_raw) == 2:
            # Binary classification (older SHAP versions)
            shap_values = shap_values_raw[1]
        else:
            # Multiclass - use first class for importance
            shap_values = shap_values_raw[0]
    else:
        shap_values = shap_values_raw
        # Handle 3D array for binary/multiclass (newer SHAP versions)
        if shap_values.ndim == 3:
            if shap_values.shape[2] == 2:
                # Binary classification: take positive class (index 1)
                shap_values = shap_values[:, :, 1]
            else:
                # Multiclass: aggregate across classes (mean absolute)
                shap_values = np.mean(np.abs(shap_values), axis=2)

    # Validate SHAP values shape
    if shap_values.shape != (n_samples, n_features):
        raise RuntimeError(
            f"Unexpected SHAP values shape: {shap_values.shape}, expected ({n_samples}, {n_features})"
        )

    # Compute feature importance as mean absolute SHAP value
    importances = np.mean(np.abs(shap_values), axis=0)

    # Sort by importance (descending)
    sorted_idx = np.argsort(importances)[::-1]

    # Get base value (expected value)
    base_value = explainer.expected_value
    if isinstance(base_value, list | np.ndarray):
        # For binary/multiclass, take positive class or first class
        base_value = base_value[1] if len(base_value) == 2 else base_value[0]

    # Determine model type
    model_type = f"{type(model).__module__}.{type(model).__name__}"

    return {
        "shap_values": shap_values,
        "importances": importances[sorted_idx],
        "feature_names": [feature_names[i] for i in sorted_idx],
        "base_value": float(base_value),
        "n_features": n_features,
        "n_samples": n_samples,
        "model_type": model_type,
        "explainer_type": explainer_type_used,
        "additivity_verified": check_additivity,
    }
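For readers skimming the diff, the sketch below shows how the compute_shap_importance entry point added above might be called. It is illustrative only and not part of the package: the import path is an assumption based on the file listing (importance_shap.py is the only module with +715 lines, matching this hunk), the toy data and model are invented, and the printed keys come from the function's documented return dictionary.

# Hypothetical usage sketch; import path assumed from the listing above.
import numpy as np
from sklearn.ensemble import RandomForestClassifier

from ml4t.diagnostic.evaluation.metrics.importance_shap import compute_shap_importance

# Toy data: 1,000 samples, 5 features, binary target (illustration only)
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
y = (X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.1, size=1000) > 0).astype(int)

model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

# "auto" follows the cascade documented above: Tree -> Linear -> Kernel
result = compute_shap_importance(
    model,
    X,
    feature_names=[f"f{i}" for i in range(5)],
    explainer_type="auto",
    max_samples=500,           # subsample for speed
    performance_warning=True,  # warn if a slow KernelExplainer run is expected
)

print(result["explainer_type"])  # expected to be "tree" for a random forest
for name, imp in zip(result["feature_names"], result["importances"]):
    print(f"{name}: {imp:.4f}")  # mean |SHAP| per feature, sorted descending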