ml4t_diagnostic-0.1.0a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/metrics/importance_shap.py
@@ -0,0 +1,715 @@
+ """SHAP-based feature importance with multi-explainer support.
+
+ This module provides SHAP value computation with automatic explainer selection
+ for tree-based, linear, and model-agnostic approaches.
+ """
+
+ import warnings
+ from typing import TYPE_CHECKING, Any, Union
+
+ import numpy as np
+ import pandas as pd
+ import polars as pl
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+
+ def _detect_gpu_available() -> bool:
+     """Detect if GPU acceleration is available for SHAP computations.
+
+     GPU acceleration is currently supported only for TreeExplainer with:
+     - NVIDIA GPU
+     - CUDA 11.0+
+     - cupy library installed
+
+     Returns
+     -------
+     bool
+         True if GPU is available and cupy is installed, False otherwise
+
+     Notes
+     -----
+     This function checks for cupy availability as a proxy for GPU support.
+     Even if a GPU is present, cupy must be installed for SHAP to use it.
+
+     GPU acceleration provides 10-100x speedup for large datasets (>5K samples)
+     but has overhead that makes it slower for small datasets (<5K samples).
+     """
+     try:
+         import cupy as cp
+
+         # Check if GPU is actually accessible
+         _ = cp.cuda.Device(0)
+         return True
+     except (ImportError, RuntimeError):
+         # ImportError: cupy not installed
+         # RuntimeError: CUDA not available or no GPU found
+         return False
+
+
+ def _get_explainer(
+     model: Any,
+     X_array: "NDArray[Any]",
+     explainer_type: str = "auto",
+     use_gpu: bool | str = "auto",
+     background_data: Union["NDArray[Any]", None] = None,
+     **explainer_kwargs: Any,
+ ) -> tuple[Any, str, float]:
+     """Select and create an appropriate SHAP explainer for the given model.
+
+     Implements automatic explainer selection with a try-except cascade:
+     1. TreeExplainer (fast, exact, tree models only)
+     2. LinearExplainer (fast, exact, linear models only)
+     3. KernelExplainer (slow, approximate, model-agnostic fallback)
+
+     DeepExplainer is NOT included in auto-selection because it requires an
+     explicit background dataset. Request it with explainer_type="deep".
+
+     Parameters
+     ----------
+     model : Any
+         Fitted model to explain
+     X_array : np.ndarray
+         Feature matrix for SHAP computation
+     explainer_type : str, default "auto"
+         Explainer type to use:
+         - "auto": Try tree -> linear -> kernel (recommended)
+         - "tree": TreeExplainer (tree models only)
+         - "linear": LinearExplainer (linear models only)
+         - "deep": DeepExplainer (neural networks, requires background_data)
+         - "kernel": KernelExplainer (model-agnostic, slow)
+     use_gpu : bool | str, default "auto"
+         GPU acceleration mode (TreeExplainer only):
+         - "auto": Use GPU if available and dataset large enough (>5K samples)
+         - True: Force GPU usage (raises error if unavailable)
+         - False: Force CPU usage
+     background_data : np.ndarray | None, default None
+         Background dataset for explainers that need it (Kernel, Deep).
+         If None, will be auto-sampled from X_array for Kernel.
+         Required for Deep explainer.
+     **explainer_kwargs : Any
+         Additional keyword arguments passed to explainer constructor
+
+     Returns
+     -------
+     tuple[Any, str, float]
+         - explainer: Initialized SHAP explainer instance
+         - type_name: Name of explainer type used ("tree", "linear", "kernel", "deep")
+         - ms_per_sample: Estimated milliseconds per sample for performance warnings
+
+     Raises
+     ------
+     ImportError
+         If shap library not installed
+     ValueError
+         If explainer_type is invalid or if auto-selection fails for all explainers
+     RuntimeError
+         If GPU requested but unavailable
+     """
+     try:
+         import shap
+     except ImportError as e:
+         raise ImportError(
+             "SHAP library is not installed. Install with: pip install ml4t-diagnostic[ml] or: pip install shap>=0.43.0"
+         ) from e
+
+     # Validate explainer_type
+     valid_types = {"auto", "tree", "linear", "deep", "kernel"}
+     if explainer_type not in valid_types:
+         raise ValueError(
+             f"Invalid explainer_type '{explainer_type}'. Must be one of: {', '.join(sorted(valid_types))}"
+         )
+
+     # Handle GPU detection and configuration
+     gpu_available = _detect_gpu_available()
+     use_gpu_final = False
+
+     if use_gpu == "auto":
+         # Auto-detect: Use GPU if available AND dataset large enough
+         n_samples = X_array.shape[0]
+         use_gpu_final = gpu_available and n_samples >= 5000
+     elif use_gpu is True:
+         if not gpu_available:
+             raise RuntimeError(
+                 "GPU requested (use_gpu=True) but GPU not available. "
+                 "Ensure NVIDIA GPU, CUDA 11.0+, and cupy are installed. "
+                 "Install with: pip install ml4t-diagnostic[gpu]"
+             )
+         use_gpu_final = True
+     else:  # use_gpu is False
+         use_gpu_final = False
+
+     # Explicit explainer type requested
+     if explainer_type != "auto":
+         return _create_explainer_by_type(
+             explainer_type=explainer_type,
+             model=model,
+             X_array=X_array,
+             use_gpu=use_gpu_final,
+             background_data=background_data,
+             shap=shap,
+             **explainer_kwargs,
+         )
+
+     # Auto-selection cascade: Tree -> Linear -> Kernel
+     errors = []
+
+     # Try TreeExplainer first (fastest, most common)
+     try:
+         tree_kwargs = {"feature_perturbation": "tree_path_dependent"}
+         tree_kwargs.update(explainer_kwargs)  # User kwargs override defaults
+
+         explainer = shap.TreeExplainer(model, **tree_kwargs)
+         # GPU mode only for tree explainer
+         if use_gpu_final and hasattr(explainer, "gpu"):
+             setattr(explainer, "gpu", True)  # noqa: B010
+         ms_per_sample = 5.0  # ~1-10ms typical
+         return (explainer, "tree", ms_per_sample)
+     except Exception as e:
+         errors.append(f"TreeExplainer: {e}")
+
+     # Try LinearExplainer second (fast, exact for linear models)
+     try:
+         explainer = shap.LinearExplainer(model, X_array, **explainer_kwargs)
+         ms_per_sample = 75.0  # ~50-100ms typical
+         return (explainer, "linear", ms_per_sample)
+     except Exception as e:
+         errors.append(f"LinearExplainer: {e}")
+
+     # Try KernelExplainer as fallback (slow but model-agnostic)
+     try:
+         # Sample background data if not provided
+         if background_data is None:
+             background_data = _sample_background(X_array, max_samples=100, method="random")
+
+         # Create prediction function wrapper to avoid LightGBM property issues
+         if hasattr(model, "predict_proba"):
+             # For binary classification, return probability of positive class
+             def predict_fn(X):
+                 proba = model.predict_proba(X)
+                 if proba.shape[1] == 2:
+                     return proba[:, 1]  # Binary: positive class
+                 return proba  # Multiclass: all classes
+         else:
+             predict_fn = model.predict
+
+         explainer = shap.KernelExplainer(predict_fn, background_data, **explainer_kwargs)
+         ms_per_sample = 5000.0  # ~1-10 seconds typical
+         return (explainer, "kernel", ms_per_sample)
+     except Exception as e:
+         errors.append(f"KernelExplainer: {e}")
+
+     # All explainers failed
+     error_summary = "\n - ".join(errors)
+     raise ValueError(
+         f"Failed to create explainer for model type {type(model).__name__}. "
+         f"Tried tree, linear, and kernel explainers. Errors:\n - {error_summary}\n"
+         f"Consider using explainer_type='kernel' explicitly with custom background_data."
+     )
+
+
+ def _create_explainer_by_type(
+     explainer_type: str,
+     model: Any,
+     X_array: "NDArray[Any]",
+     use_gpu: bool,
+     background_data: Union["NDArray[Any]", None],
+     shap: Any,
+     **explainer_kwargs: Any,
+ ) -> tuple[Any, str, float]:
+     """Create specific explainer type (helper for _get_explainer).
+
+     Parameters
+     ----------
+     explainer_type : str
+         One of: "tree", "linear", "deep", "kernel"
+     model : Any
+         Fitted model
+     X_array : np.ndarray
+         Feature matrix
+     use_gpu : bool
+         Whether to use GPU (tree only)
+     background_data : np.ndarray | None
+         Background data for kernel/deep explainers
+     shap : module
+         Imported shap module
+     **explainer_kwargs : Any
+         Additional explainer arguments
+
+     Returns
+     -------
+     tuple[Any, str, float]
+         (explainer, type_name, ms_per_sample)
+
+     Raises
+     ------
+     ValueError
+         If explainer creation fails
+     ImportError
+         If deep learning dependencies not available
+     """
+     try:
+         if explainer_type == "tree":
+             # Set default feature_perturbation unless user overrides
+             tree_kwargs = {"feature_perturbation": "tree_path_dependent"}
+             tree_kwargs.update(explainer_kwargs)  # User kwargs override defaults
+
+             explainer = shap.TreeExplainer(model, **tree_kwargs)
+             if use_gpu and hasattr(explainer, "gpu"):
+                 explainer.gpu = True
+             ms_per_sample = 5.0
+             return (explainer, "tree", ms_per_sample)
+
+         elif explainer_type == "linear":
+             explainer = shap.LinearExplainer(model, X_array, **explainer_kwargs)
+             ms_per_sample = 75.0
+             return (explainer, "linear", ms_per_sample)
+
+         elif explainer_type == "deep":
+             if background_data is None:
+                 raise ValueError(
+                     "DeepExplainer requires background_data parameter. "
+                     "Provide a representative sample of your training data "
+                     "(typically 100-1000 samples)."
+                 )
+             try:
+                 explainer = shap.DeepExplainer(model, background_data, **explainer_kwargs)
+             except ImportError as e:
+                 raise ImportError(
+                     "DeepExplainer requires deep learning libraries (TensorFlow or PyTorch). "
+                     "Install with: pip install ml4t-diagnostic[deep]"
+                 ) from e
+             ms_per_sample = 500.0  # ~100ms-1s typical
+             return (explainer, "deep", ms_per_sample)
+
+         elif explainer_type == "kernel":
+             if background_data is None:
+                 background_data = _sample_background(X_array, max_samples=100, method="random")
+
+             # Create prediction function wrapper to avoid LightGBM property issues
+             # For classifiers, use predict_proba if available (more informative)
+             if hasattr(model, "predict_proba"):
+                 # For binary classification, return probability of positive class
+                 def predict_fn(X):
+                     proba = model.predict_proba(X)
+                     if proba.shape[1] == 2:
+                         return proba[:, 1]  # Binary: positive class
+                     return proba  # Multiclass: all classes
+             else:
+                 predict_fn = model.predict
+
+             explainer = shap.KernelExplainer(predict_fn, background_data, **explainer_kwargs)
+             ms_per_sample = 5000.0
+             return (explainer, "kernel", ms_per_sample)
+
+         else:
+             raise ValueError(f"Unknown explainer_type: {explainer_type}")
+
+     except ImportError:
+         # Let ImportError (missing deep learning deps) propagate as documented
+         raise
+     except Exception as e:
+         raise ValueError(
+             f"Failed to create {explainer_type.capitalize()}Explainer for model type {type(model).__name__}. Error: {e}"
+         ) from e
+
+
+ def _sample_background(
+     X_array: "NDArray[Any]", max_samples: int = 100, method: str = "random"
+ ) -> "NDArray[Any]":
+     """Sample background dataset for KernelExplainer.
+
+     Background data represents "typical" feature values used as reference
+     for computing SHAP values. Smaller backgrounds = faster computation.
+
+     Parameters
+     ----------
+     X_array : np.ndarray
+         Full feature matrix
+     max_samples : int, default 100
+         Maximum number of background samples
+     method : str, default "random"
+         Sampling method: "random" or "kmeans"
+
+     Returns
+     -------
+     np.ndarray
+         Background dataset (max_samples, n_features)
+
+     Notes
+     -----
+     - Random: Fast, simple, works well for most cases
+     - K-means: Better representation of data distribution, slower
+     """
+     n_samples = X_array.shape[0]
+
+     if n_samples <= max_samples:
+         return X_array
+
+     if method == "random":
+         rng = np.random.default_rng(42)
+         idx = rng.choice(n_samples, size=max_samples, replace=False)
+         return X_array[idx]
+     elif method == "kmeans":
+         # K-means clustering for representative samples
+         try:
+             from sklearn.cluster import KMeans
+
+             kmeans = KMeans(n_clusters=max_samples, random_state=42, n_init=10)
+             kmeans.fit(X_array)
+             return kmeans.cluster_centers_
+         except ImportError:
+             # Fallback to random if sklearn not available
+             rng = np.random.default_rng(42)
+             idx = rng.choice(n_samples, size=max_samples, replace=False)
+             return X_array[idx]
+     else:
+         raise ValueError(f"Unknown sampling method: {method}. Use 'random' or 'kmeans'.")
+
+
+ def _estimate_computation_time(
+     explainer_type: str,
+     n_samples: int,
+     ms_per_sample: float,
+     performance_warning: bool = True,
+ ) -> None:
+     """Estimate SHAP computation time and issue warnings for slow explainers.
+
+     Warns users before computationally expensive SHAP calculations to prevent
+     unexpected long wait times, especially for KernelExplainer.
+
+     Parameters
+     ----------
+     explainer_type : str
+         Type of explainer being used ("tree", "linear", "kernel", "deep")
+     n_samples : int
+         Number of samples for SHAP computation
+     ms_per_sample : float
+         Estimated milliseconds per sample for this explainer type
+     performance_warning : bool, default True
+         Whether to issue performance warnings. Set to False to disable.
+     """
+     if not performance_warning:
+         return
+
+     # Only warn for KernelExplainer (1-10 seconds per sample)
+     if explainer_type != "kernel":
+         return
+
+     # Compute estimates
+     total_seconds = (n_samples * ms_per_sample) / 1000.0
+     threshold_seconds = 10.0  # Warn if >10 seconds
+
+     if total_seconds < threshold_seconds:
+         return
+
+     # Issue warning with time estimates
+     time_str = _format_time(total_seconds)
+
+     # Suggest max_samples=200 as reasonable default
+     recommended_samples = 200
+     if n_samples > recommended_samples:
+         recommended_seconds = (recommended_samples * ms_per_sample) / 1000.0
+         recommended_time_str = _format_time(recommended_seconds)
+
+         warnings.warn(
+             f"KernelExplainer is slow (~{int(ms_per_sample)}ms per sample).\n"
+             f"Estimated time: ~{time_str} for {n_samples} samples.\n"
+             f"Consider using max_samples={recommended_samples} "
+             f"(estimated time: ~{recommended_time_str}).\n"
+             f"Or use explainer_type='tree' or 'linear' for faster computation if model supports it.",
+             UserWarning,
+             stacklevel=3,
+         )
+     else:
+         warnings.warn(
+             f"KernelExplainer is slow (~{int(ms_per_sample)}ms per sample).\n"
+             f"Estimated time: ~{time_str} for {n_samples} samples.\n"
+             f"Consider using explainer_type='tree' or 'linear' for faster computation if model supports it.",
+             UserWarning,
+             stacklevel=3,
+         )
+
+
+ def _format_time(seconds: float) -> str:
+     """Format seconds into human-readable string.
+
+     Parameters
+     ----------
+     seconds : float
+         Time in seconds
+
+     Returns
+     -------
+     str
+         Human-readable time string (e.g., "2 minutes", "1 hour 15 minutes")
+
+     Examples
+     --------
+     >>> _format_time(45)
+     '45 seconds'
+     >>> _format_time(120)
+     '2 minutes'
+     >>> _format_time(3665)
+     '1 hour 1 minute'
+     """
+     if seconds < 60:
+         return f"{int(seconds)} seconds"
+     elif seconds < 3600:
+         minutes = int(seconds / 60)
+         return f"{minutes} minute{'s' if minutes != 1 else ''}"
+     else:
+         hours = int(seconds / 3600)
+         remaining_minutes = int((seconds % 3600) / 60)
+         if remaining_minutes == 0:
+             return f"{hours} hour{'s' if hours != 1 else ''}"
+         else:
+             return f"{hours} hour{'s' if hours != 1 else ''} {remaining_minutes} minute{'s' if remaining_minutes != 1 else ''}"
+
+
+ def compute_shap_importance(
+     model: Any,
+     X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
+     feature_names: list[str] | None = None,
+     check_additivity: bool = True,
+     max_samples: int | None = None,
+     explainer_type: str = "auto",
+     use_gpu: bool | str = "auto",
+     background_data: Union["NDArray[Any]", None] = None,
+     explainer_kwargs: dict | None = None,
+     show_progress: bool = False,
+     performance_warning: bool = True,
+ ) -> dict[str, Any]:
+     """Compute SHAP (SHapley Additive exPlanations) values and aggregate them into feature importance.
+
+     SHAP values provide a unified measure of feature importance based on Shapley values
+     from cooperative game theory. Each feature's contribution to a prediction is
+     calculated by considering all possible feature coalitions, satisfying key
+     properties like additivity and consistency.
+
+     **Key advantages over MDI and PFI**:
+
+     - **Theoretically sound**: Based on game theory (Shapley values)
+     - **Consistent**: A feature's attributed importance never decreases when the model relies on it more
+     - **Local explanations**: Provides per-prediction feature contributions
+     - **Interaction-aware**: Accounts for feature interactions naturally
+     - **Unbiased**: No bias toward high-cardinality features (unlike MDI)
+     - **Model-agnostic**: Works with ANY sklearn-compatible model (v1.1+)
+
+     **Multi-Explainer Support**:
+
+     This function automatically selects the best SHAP explainer for your model:
+
+     - **TreeExplainer**: Fast, exact computation for tree-based models
+     - **LinearExplainer**: Fast, exact computation for linear models
+     - **KernelExplainer**: Model-agnostic fallback (slower but universal)
+     - **DeepExplainer**: Optimized for neural networks (TensorFlow/PyTorch)
+
+     Parameters
+     ----------
+     model : Any
+         Fitted model compatible with SHAP explainers.
+     X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
+         Feature matrix for SHAP computation (typically test/validation set).
+         Shape: (n_samples, n_features)
+     feature_names : list[str] | None, default None
+         Feature names for labeling. If None, uses column names from DataFrame
+         or generates numeric names for arrays
+     check_additivity : bool, default True
+         Verify that SHAP values sum to model predictions (sanity check).
+         Only supported by TreeExplainer. Disable for speed if you trust the
+         implementation.
+     max_samples : int | None, default None
+         Maximum number of samples to compute SHAP values for. If exceeded,
+         a seeded random subset is drawn.
+     explainer_type : str, default 'auto'
+         SHAP explainer to use:
+         - 'auto': Automatic selection (Tree -> Linear -> Kernel cascade)
+         - 'tree': Force TreeExplainer
+         - 'linear': Force LinearExplainer
+         - 'kernel': Force KernelExplainer
+         - 'deep': Force DeepExplainer
+     use_gpu : bool | str, default 'auto'
+         Enable GPU acceleration for SHAP computation
+     background_data : np.ndarray | None, default None
+         Background dataset for Kernel and Deep explainers
+     explainer_kwargs : dict | None, default None
+         Additional keyword arguments passed to the explainer constructor
+     show_progress : bool, default False
+         Show progress bar for SHAP computation (requires tqdm)
+     performance_warning : bool, default True
+         Issue warning if computation will take >10 seconds
+
+     Returns
+     -------
+     dict[str, Any]
+         Dictionary with SHAP importance results:
+         - shap_values: SHAP values array, shape (n_samples, n_features)
+         - importances: Mean absolute SHAP values per feature (sorted descending)
+         - feature_names: Feature labels (sorted by importance)
+         - base_value: Expected model output (average prediction)
+         - n_features: Number of features
+         - n_samples: Number of samples used for SHAP computation
+         - model_type: Type of model used
+         - explainer_type: Which explainer was used
+         - additivity_verified: Whether the additivity check was applied (TreeExplainer only)
+
+     Raises
+     ------
+     ImportError
+         If shap library not installed
+     ValueError
+         If model is not supported by specified explainer
+     RuntimeError
+         If SHAP computation fails
+     """
+     # Check if shap is installed
+     try:
+         import shap  # noqa: F401 (availability check)
+     except ImportError as e:
+         raise ImportError(
+             "SHAP library is not installed. Install with: pip install ml4t-diagnostic[ml] or: pip install shap>=0.43.0"
+         ) from e
+
+     # Convert X to appropriate format
+     if isinstance(X, pl.DataFrame):
+         X_array = X.to_numpy()
+         if feature_names is None:
+             feature_names = X.columns
+     elif isinstance(X, pd.DataFrame):
+         X_array = X.values
+         if feature_names is None:
+             feature_names = list(X.columns)
+     else:
+         X_array = np.asarray(X)
+
+     # Validate shape before accessing shape[1]
+     if X_array.ndim != 2:
+         raise ValueError(f"X must be 2D array, got shape {X_array.shape}")
+
+     # Set default feature names if needed (after shape validation)
+     if feature_names is None:
+         feature_names = [f"feature_{i}" for i in range(X_array.shape[1])]
+
+     # Ensure feature_names is a list (it is guaranteed non-None here)
+     feature_names = list(feature_names)
+
+     n_samples_full, n_features = X_array.shape
+
+     # Subsample if requested
+     if max_samples is not None and n_samples_full > max_samples:
+         # Use random sampling for representative subset
+         rng = np.random.default_rng(42)
+         sample_idx = rng.choice(n_samples_full, size=max_samples, replace=False)
+         X_array = X_array[sample_idx]
+         n_samples = max_samples
+     else:
+         n_samples = n_samples_full
+
+     # Validate feature names length
+     if len(feature_names) != n_features:
+         raise ValueError(
+             f"Number of feature names ({len(feature_names)}) does not match number of features in X ({n_features})"
+         )
+
+     # Get appropriate explainer (auto-selects or uses explicit type)
+     if explainer_kwargs is None:
+         explainer_kwargs = {}
+
+     explainer, explainer_type_used, ms_per_sample = _get_explainer(
+         model=model,
+         X_array=X_array,
+         explainer_type=explainer_type,
+         use_gpu=use_gpu,
+         background_data=background_data,
+         **explainer_kwargs,
+     )
+
+     # Issue performance warning if needed
+     _estimate_computation_time(
+         explainer_type=explainer_type_used,
+         n_samples=n_samples,
+         ms_per_sample=ms_per_sample,
+         performance_warning=performance_warning,
+     )
+
+     # Compute SHAP values with optional progress bar
+     try:
+         # Only TreeExplainer supports check_additivity parameter
+         shap_kwargs = {}
+         if explainer_type_used == "tree":
+             shap_kwargs["check_additivity"] = check_additivity
+
+         if show_progress:
+             try:
+                 from tqdm.auto import tqdm
+
+                 # Wrap computation with progress bar for slow explainers
+                 if explainer_type_used == "kernel":
+                     # Coarse progress bar: updates once when the batch completes
+                     with tqdm(total=n_samples, desc="Computing SHAP values") as pbar:
+                         shap_values_raw = explainer.shap_values(X_array, **shap_kwargs)
+                         pbar.update(n_samples)
+                 else:
+                     # For tree/linear/deep, just compute (fast enough)
+                     shap_values_raw = explainer.shap_values(X_array, **shap_kwargs)
+             except ImportError:
+                 # tqdm not available, compute without progress bar
+                 shap_values_raw = explainer.shap_values(X_array, **shap_kwargs)
+         else:
+             shap_values_raw = explainer.shap_values(X_array, **shap_kwargs)
+     except Exception as e:
+         raise RuntimeError(
+             f"Failed to compute SHAP values with {explainer_type_used}Explainer. "
+             f"Model type: {type(model).__name__}. Error: {e}"
+         ) from e
+
+     # Handle binary classification (returns list of arrays OR 3D array)
+     if isinstance(shap_values_raw, list):
+         if len(shap_values_raw) == 2:
+             # Binary classification (older SHAP versions)
+             shap_values = shap_values_raw[1]
+         else:
+             # Multiclass - use first class for importance
+             shap_values = shap_values_raw[0]
+     else:
+         shap_values = shap_values_raw
+         # Handle 3D array for binary/multiclass (newer SHAP versions)
+         if shap_values.ndim == 3:
+             if shap_values.shape[2] == 2:
+                 # Binary classification: take positive class (index 1)
+                 shap_values = shap_values[:, :, 1]
+             else:
+                 # Multiclass: aggregate across classes (mean absolute)
+                 shap_values = np.mean(np.abs(shap_values), axis=2)
+
+     # Validate SHAP values shape
+     if shap_values.shape != (n_samples, n_features):
+         raise RuntimeError(
+             f"Unexpected SHAP values shape: {shap_values.shape}, expected ({n_samples}, {n_features})"
+         )
+
+     # Compute feature importance as mean absolute SHAP value
+     importances = np.mean(np.abs(shap_values), axis=0)
+
+     # Sort by importance (descending)
+     sorted_idx = np.argsort(importances)[::-1]
+
+     # Get base value (expected value)
+     base_value = explainer.expected_value
+     if isinstance(base_value, list | np.ndarray):
+         # For binary/multiclass, take positive class or first class
+         base_value = base_value[1] if len(base_value) == 2 else base_value[0]
+
+     # Determine model type
+     model_type = f"{type(model).__module__}.{type(model).__name__}"
+
+     return {
+         "shap_values": shap_values,
+         "importances": importances[sorted_idx],
+         "feature_names": [feature_names[i] for i in sorted_idx],
+         "base_value": float(base_value),
+         "n_features": n_features,
+         "n_samples": n_samples,
+         "model_type": model_type,
+         "explainer_type": explainer_type_used,
+         "additivity_verified": check_additivity and explainer_type_used == "tree",  # only TreeExplainer checks
+     }
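
Editor's note: for orientation, the sketch below exercises the public compute_shap_importance entry point defined in the file above. It is illustrative only — the model, the synthetic data, and the import path (read off the module's location in the file list) are assumptions of this note, not an example shipped with the package.

    import numpy as np
    from sklearn.ensemble import RandomForestRegressor

    from ml4t.diagnostic.evaluation.metrics.importance_shap import compute_shap_importance

    # Synthetic regression data: 1,000 samples, 5 features
    rng = np.random.default_rng(0)
    X = rng.normal(size=(1_000, 5))
    y = 2.0 * X[:, 0] - 1.0 * X[:, 3] + rng.normal(scale=0.1, size=1_000)

    model = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)

    # The "auto" cascade should resolve to TreeExplainer for a random forest
    result = compute_shap_importance(model, X, max_samples=500)

    print(result["explainer_type"])  # expected: "tree"
    for name, imp in zip(result["feature_names"], result["importances"]):
        print(f"{name}: {imp:.4f}")  # mean |SHAP| per feature, sorted descending

For a model that is neither tree-based nor linear, the same call falls through to KernelExplainer; passing explainer_type="kernel" together with a small background_data array, as the error message in _get_explainer suggests, skips the cascade entirely.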