ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242):
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,588 @@
1
+ """Trade-level SHAP diagnostics for ML trading feedback loop.
2
+
3
+ Connects SHAP values to trade outcomes for systematic debugging and improvement.
4
+
5
+ This module is a thin wrapper around the modular trade_shap package.
6
+ Implementation has been refactored into:
7
+ - ml4t.diagnostic.evaluation.trade_shap.models (data models)
8
+ - ml4t.diagnostic.evaluation.trade_shap.pipeline (TradeShapPipeline)
9
+ - ml4t.diagnostic.evaluation.trade_shap.explain (TradeShapExplainer)
10
+ - ml4t.diagnostic.evaluation.trade_shap.cluster (HierarchicalClusterer)
11
+ - ml4t.diagnostic.evaluation.trade_shap.characterize (PatternCharacterizer)
12
+ - ml4t.diagnostic.evaluation.trade_shap.hypotheses (HypothesisGenerator)
13
+
14
+ Example:
15
+ >>> analyzer = TradeShapAnalyzer(model, features_df, shap_values)
16
+ >>> result = analyzer.explain_worst_trades(worst_trades)
17
+ >>> for pattern in result.error_patterns:
18
+ ... print(pattern.hypothesis, pattern.actions)
19
+
20
+ See: docs/trimmed/evaluation/trade_shap_diagnostics.md for full documentation.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from typing import TYPE_CHECKING, Any
26
+
27
+ import numpy as np
28
+ import polars as pl
29
+ from numpy.typing import NDArray
30
+
31
+ # Re-export all models and components from modular package
32
+ from ml4t.diagnostic.evaluation.trade_shap import (
33
+ # Alignment
34
+ AlignmentResult,
35
+ # Characterization
36
+ CharacterizationConfig,
37
+ # Clustering
38
+ ClusteringConfig,
39
+ ClusteringResult,
40
+ # Result models
41
+ ErrorPattern,
42
+ FeatureStatistics,
43
+ HierarchicalClusterer,
44
+ # Hypothesis generation
45
+ HypothesisConfig,
46
+ HypothesisGenerator,
47
+ # Normalization
48
+ NormalizationType,
49
+ PatternCharacterizer,
50
+ Template,
51
+ TemplateMatcher,
52
+ TimestampAligner,
53
+ TradeExplainFailure,
54
+ # Explainer
55
+ TradeShapExplainer,
56
+ TradeShapExplanation,
57
+ # Pipeline
58
+ TradeShapPipeline,
59
+ TradeShapPipelineConfig,
60
+ TradeShapResult,
61
+ benjamini_hochberg,
62
+ compute_centroids,
63
+ compute_cluster_sizes,
64
+ find_optimal_clusters,
65
+ load_templates,
66
+ normalize,
67
+ normalize_l1,
68
+ normalize_l2,
69
+ standardize,
70
+ )
71
+
72
+ if TYPE_CHECKING:
73
+ from ml4t.diagnostic.config import TradeConfig
74
+ from ml4t.diagnostic.evaluation.trade_analysis import TradeMetrics
75
+
76
+
77
class TradeShapAnalyzer:
    """Analyze trade failures using SHAP explanations.

    This class wraps TradeShapPipeline with additional features:
    - On-demand SHAP value computation from a model
    - Pandas/Polars DataFrame conversion
    - GPU acceleration support

    For simpler use cases with pre-computed SHAP values, use TradeShapPipeline
    directly.

    Example:
        >>> analyzer = TradeShapAnalyzer(model, features_df, shap_values)
        >>> result = analyzer.explain_worst_trades(worst_trades)
    """

    def __init__(
        self,
        model: Any,
        features_df: pl.DataFrame | Any,
        shap_values: NDArray[np.floating[Any]] | None = None,
        config: TradeConfig | None = None,
        explainer_type: str = "auto",
        use_gpu: bool | str = "auto",
        background_data: NDArray[Any] | None = None,
        explainer_kwargs: dict | None = None,
        show_progress: bool = False,
        performance_warning: bool = True,
    ):
        """Initialize with model, features DataFrame, and optional SHAP values.

        Args:
            model: Trained model for SHAP computation
            features_df: DataFrame with 'timestamp' column and feature columns
            shap_values: Pre-computed SHAP values (optional, computed if None)
            config: TradeConfig for analysis parameters
            explainer_type: SHAP explainer type ('auto', 'tree', 'kernel', etc.)
            use_gpu: Whether to use GPU acceleration
            background_data: Background data for SHAP computation
            explainer_kwargs: Additional kwargs for SHAP explainer
            show_progress: Show progress bars during computation
            performance_warning: Warn about performance issues

        Raises:
            TypeError: If features_df is neither a Polars nor a pandas DataFrame
            ValueError: If features_df lacks a 'timestamp' column, has no feature
                columns, or shap_values does not match its shape
        """
        self.model = model
        self.features_df = self._validate_and_convert_features(features_df)
        self.shap_values = shap_values
        self.config = config or self._get_default_config()

        # Store API parameters for on-demand SHAP computation
        self._explainer_type = explainer_type
        self._use_gpu = use_gpu
        self._background_data = background_data
        self._explainer_kwargs = explainer_kwargs or {}
        self._show_progress = show_progress
        self._performance_warning = performance_warning

        # Extract feature names
        self.feature_names = self._extract_feature_names()

        # Validate SHAP values if provided
        if self.shap_values is not None:
            self._validate_shap_values()

        # Pipeline created lazily after SHAP values are available
        self._pipeline: TradeShapPipeline | None = None
        self._hypothesis_generator: HypothesisGenerator | None = None

    def _validate_and_convert_features(self, features_df: Any) -> pl.DataFrame:
        """Validate the features DataFrame and convert it to Polars.

        Args:
            features_df: Polars or pandas DataFrame

        Returns:
            The input as a Polars DataFrame

        Raises:
            TypeError: If features_df is neither a Polars nor a pandas DataFrame
            ValueError: If the 'timestamp' column is missing
        """
        if not isinstance(features_df, pl.DataFrame):
            # pandas is only needed for the conversion path; if it is not
            # installed the input cannot be a pandas DataFrame, so report the
            # type error instead of leaking an ImportError.
            try:
                import pandas as pd
            except ImportError:
                pd = None

            if pd is None or not isinstance(features_df, pd.DataFrame):
                raise TypeError(
                    f"features_df must be pl.DataFrame or pd.DataFrame, got {type(features_df)}"
                )
            features_df = pl.from_pandas(features_df)

        if "timestamp" not in features_df.columns:
            raise ValueError(
                "features_df must have 'timestamp' column for SHAP alignment to trades."
            )

        return features_df

    def _extract_feature_names(self) -> list[str]:
        """Return every column of features_df except 'timestamp'.

        Raises:
            ValueError: If no feature columns remain
        """
        feature_names = [col for col in self.features_df.columns if col != "timestamp"]
        if not feature_names:
            raise ValueError("No feature columns found in features_df.")
        return feature_names

    def _validate_shap_values(self) -> None:
        """Check that the SHAP array is shaped (n_rows, n_features).

        Does nothing when no SHAP values are stored.

        Raises:
            ValueError: If the shape does not match features_df
        """
        if self.shap_values is None:
            return

        n_samples = len(self.features_df)
        n_features = len(self.feature_names)

        if self.shap_values.shape != (n_samples, n_features):
            raise ValueError(
                f"SHAP values shape {self.shap_values.shape} doesn't match "
                f"features_df shape ({n_samples}, {n_features})."
            )

    def _get_default_config(self) -> TradeConfig:
        """Return a default-constructed TradeConfig."""
        # Imported lazily; at module level TradeConfig is only imported for typing.
        from ml4t.diagnostic.config import TradeConfig

        return TradeConfig()

    def _compute_shap_values(self) -> None:
        """Compute SHAP values from the model and store them on the instance."""
        from ml4t.diagnostic.evaluation.metrics import compute_shap_importance

        feature_cols = [col for col in self.features_df.columns if col != "timestamp"]
        features_df = self.features_df.select(feature_cols)

        # NOTE(review): self._performance_warning is stored in __init__ but is
        # not forwarded here - confirm whether compute_shap_importance takes it.
        result = compute_shap_importance(
            model=self.model,
            X=features_df,
            feature_names=feature_cols,
            explainer_type=self._explainer_type,
            use_gpu=self._use_gpu,
            background_data=self._background_data,
            show_progress=self._show_progress,
            explainer_kwargs=self._explainer_kwargs,
        )

        self.shap_values = result["shap_values"]

    def _ensure_pipeline(self) -> TradeShapPipeline:
        """Return the pipeline, building it (and SHAP values) on first use."""
        if self._pipeline is not None:
            return self._pipeline

        # Compute SHAP values if not provided
        if self.shap_values is None:
            self._compute_shap_values()

        # Settings are read from a nested `alignment` config when present,
        # otherwise from flat attributes on the config; every lookup has a
        # fallback so partially populated configs still work.
        alignment_cfg = getattr(self.config, "alignment", None)
        if alignment_cfg is not None:
            tolerance = getattr(alignment_cfg, "tolerance", 0.0)
            mode = getattr(alignment_cfg, "mode", "entry")
            missing_strategy = getattr(alignment_cfg, "missing_strategy", "skip")
            top_n = getattr(alignment_cfg, "top_n_features", 10)
            normalization = getattr(alignment_cfg, "normalization", "l2")
        else:
            tolerance = getattr(self.config, "alignment_tolerance_seconds", 0.0)
            mode = getattr(self.config, "alignment_mode", "entry")
            missing_strategy = getattr(self.config, "missing_value_strategy", "skip")
            top_n = getattr(self.config, "top_n_features", 10)
            normalization = getattr(self.config, "normalization", "l2")

        pipeline_config = TradeShapPipelineConfig(
            alignment_tolerance_seconds=tolerance,
            alignment_mode=mode,
            missing_value_strategy=missing_strategy,
            top_n_features=top_n,
            normalization=normalization,
        )

        self._pipeline = TradeShapPipeline(
            features_df=self.features_df,
            shap_values=self.shap_values,
            feature_names=self.feature_names,
            config=pipeline_config,
        )
        return self._pipeline

    def explain_worst_trades(
        self,
        worst_trades: list[TradeMetrics],
        n: int | None = None,
    ) -> TradeShapResult:
        """Explain worst trades with full SHAP analysis pipeline.

        Args:
            worst_trades: List of trades sorted by loss (worst first)
            n: Number of trades to analyze (None = all)

        Returns:
            TradeShapResult with explanations, patterns, and hypotheses
        """
        pipeline = self._ensure_pipeline()
        return pipeline.analyze_worst_trades(worst_trades, n=n)

    def explain_trade(
        self,
        trade: TradeMetrics,
    ) -> TradeShapExplanation | TradeExplainFailure:
        """Explain a single trade via the underlying pipeline."""
        pipeline = self._ensure_pipeline()
        return pipeline.explain_trade(trade)

    def explain_trades(
        self,
        trades: list[TradeMetrics],
    ) -> tuple[list[TradeShapExplanation], list[TradeExplainFailure]]:
        """Explain multiple trades via the underlying pipeline."""
        pipeline = self._ensure_pipeline()
        return pipeline.explain_trades(trades)

    # Sentinel for "use config default"; needed because None is itself a valid
    # value for `normalization` (meaning "do not normalize").
    _UNSET: Any = object()

    def extract_shap_vectors(
        self,
        explanations: list[TradeShapExplanation],
        normalization: str | None | Any = _UNSET,
        top_n_features: int | None = None,
    ) -> NDArray[np.floating[Any]]:
        """Extract SHAP vectors from explanations.

        Args:
            explanations: List of TradeShapExplanation objects
            normalization: Normalization type ('l1', 'l2', 'standardize', None for none,
                or omit to use config default)
            top_n_features: Reduce to top N features (by mean |SHAP|)

        Returns:
            2D array of shape (n_explanations, n_features)

        Raises:
            ValueError: If explanations is empty or top_n_features is out of
                range (an invalid normalization is reported by `normalize`)
        """
        if not explanations:
            raise ValueError("Cannot extract vectors from empty explanations list")

        # Stack SHAP vectors
        vectors = np.vstack([exp.shap_vector for exp in explanations])

        # Handle top_n reduction
        if top_n_features is not None:
            n_features = vectors.shape[1]
            if top_n_features > n_features:
                raise ValueError(
                    f"top_n_features ({top_n_features}) exceeds feature count ({n_features})"
                )
            if top_n_features < 1:
                raise ValueError("top_n_features must be positive")
            # Select top features by mean absolute SHAP; columns come out in
            # ascending order of importance (argsort order).
            importance = np.abs(vectors).mean(axis=0)
            top_idx = np.argsort(importance)[-top_n_features:]
            vectors = vectors[:, top_idx]

        # If normalization is _UNSET, use config default; if None, no normalization
        if normalization is self._UNSET:
            # Use config default if available (check clustering then alignment)
            normalization = getattr(getattr(self.config, "clustering", None), "normalization", None)
            if normalization is None:
                normalization = getattr(
                    getattr(self.config, "alignment", None), "normalization", None
                )

        if normalization is not None:
            vectors = normalize(vectors, normalization)

        return vectors

    def cluster_patterns(
        self,
        shap_vectors: NDArray[np.floating[Any]],
        n_clusters: int | None = None,
    ) -> ClusteringResult:
        """Cluster SHAP vectors to identify error patterns.

        Args:
            shap_vectors: 2D array of SHAP vectors (n_trades, n_features)
            n_clusters: Number of clusters (auto-detected if None)

        Returns:
            ClusteringResult with cluster assignments and metrics

        Raises:
            ValueError: If insufficient trades for clustering or n_clusters is
                out of range
        """
        n_trades = len(shap_vectors)

        if n_trades < 3:
            raise ValueError("Need at least 3 trades for clustering")

        # Check against min_trades_for_clustering config
        min_trades = getattr(self.config, "min_trades_for_clustering", 10)
        if n_trades < min_trades:
            raise ValueError(
                f"Insufficient trades for clustering: {n_trades} < {min_trades} "
                "(set min_trades_for_clustering to lower this threshold)"
            )

        if n_clusters is not None:
            if n_clusters < 1:
                raise ValueError("n_clusters must be positive")
            if n_clusters > n_trades:
                raise ValueError(f"n_clusters ({n_clusters}) exceeds trade count ({n_trades})")

        # Get clustering config
        clustering_cfg = getattr(self.config, "clustering", None)
        if clustering_cfg is not None:
            config = ClusteringConfig(
                min_cluster_size=getattr(clustering_cfg, "min_cluster_size", 3),
                distance_metric=getattr(clustering_cfg, "distance_metric", "euclidean"),
                linkage_method=getattr(clustering_cfg, "linkage_method", "ward"),
            )
        else:
            config = ClusteringConfig()

        clusterer = HierarchicalClusterer(config=config)
        return clusterer.cluster(shap_vectors, n_clusters=n_clusters)

    def characterize_pattern(
        self,
        shap_vectors: NDArray[np.floating[Any]] | None = None,
        clustering_result: ClusteringResult | None = None,
        cluster_id: int | None = None,
        feature_names: list[str] | None = None,
        top_n: int = 5,
        *,
        # Backward-compat kwargs
        cluster_assignments: list[int] | None = None,
    ) -> dict[str, Any]:
        """Characterize a single error pattern.

        Supports both old dict-return API and new object-based API.

        Args:
            shap_vectors: 2D array of SHAP vectors
            clustering_result: Result from cluster_patterns() (new API)
            cluster_id: Which cluster to characterize
            feature_names: Feature names (uses self.feature_names if None)
            top_n: Number of top features to include
            cluster_assignments: Cluster labels (backward compat, use clustering_result
                instead); takes precedence over clustering_result when both are given

        Returns:
            Dict with pattern info (cluster_id, n_trades, top_features, etc.)

        Raises:
            ValueError: If a required argument is missing, shap_vectors is not
                2D, the number of labels does not match the number of rows,
                cluster_id is out of range, or the feature count does not
                match feature_names
        """
        if shap_vectors is None:
            raise ValueError("shap_vectors is required")
        if cluster_id is None:
            raise ValueError("cluster_id is required")

        vectors = np.asarray(shap_vectors)
        if vectors.ndim != 2:
            raise ValueError(
                f"shap_vectors must be 2D (n_trades, n_features), got {vectors.ndim}D array"
            )

        # Handle backward compat: cluster_assignments list vs ClusteringResult
        if cluster_assignments is not None:
            labels = np.asarray(cluster_assignments)
            # Size the id range by the largest label rather than the number of
            # distinct labels, so 1-based or non-contiguous labels keep every
            # id addressable and centroid rows stay indexed by label.
            n_clusters = int(labels.max()) + 1 if labels.size else 0
            centroids = None  # Will compute from shap_vectors
        elif clustering_result is not None:
            labels = np.asarray(clustering_result.cluster_assignments)
            n_clusters = clustering_result.n_clusters
            centroids = clustering_result.centroids
        else:
            raise ValueError("Either clustering_result or cluster_assignments is required")

        # Fail with a clear message instead of a NumPy boolean-index IndexError.
        if labels.ndim != 1 or labels.shape[0] != vectors.shape[0]:
            raise ValueError(
                f"Cluster label count ({labels.shape[0] if labels.ndim else 0}) does not "
                f"match number of SHAP vectors ({vectors.shape[0]})"
            )

        if cluster_id < 0 or cluster_id >= n_clusters:
            raise ValueError(f"cluster_id {cluster_id} out of range [0, {n_clusters})")

        if feature_names is None:
            feature_names = self.feature_names

        # Validate feature count
        if vectors.shape[1] != len(feature_names):
            raise ValueError(
                f"Feature count mismatch: vectors have {vectors.shape[1]} features, "
                f"but got {len(feature_names)} feature names"
            )

        # Split rows into the requested cluster and everything else
        cluster_mask = labels == cluster_id
        cluster_shap = vectors[cluster_mask]
        other_shap = vectors[~cluster_mask]
        n_trades = int(cluster_mask.sum())

        # Compute centroids if not provided; ids with no members keep a zero row
        if centroids is None:
            centroids = np.zeros((n_clusters, vectors.shape[1]))
            for c in range(n_clusters):
                members = vectors[labels == c]
                if len(members) > 0:
                    centroids[c] = members.mean(axis=0)

        # Use characterizer
        char_cfg = getattr(self.config, "characterization", None)
        if char_cfg is not None:
            config = CharacterizationConfig(
                top_n_features=top_n,
                significance_level=getattr(char_cfg, "significance_level", 0.05),
            )
        else:
            config = CharacterizationConfig(top_n_features=top_n)

        characterizer = PatternCharacterizer(
            feature_names=feature_names,
            config=config,
        )
        pattern = characterizer.characterize_cluster(
            cluster_shap=cluster_shap,
            other_shap=other_shap,
            cluster_id=cluster_id,
            centroids=centroids,
        )

        # Return dict for backward compat
        # top_features is list[tuple[str, float, float, float, bool]]
        # (name, mean_shap, p_value_t, p_value_mw, is_significant)
        return {
            "cluster_id": cluster_id,
            "n_trades": n_trades,
            "top_features": [
                {
                    "feature": tf[0],
                    "mean_shap": tf[1],
                    "p_value_t": tf[2],
                    "p_value_mw": tf[3],
                    "significant": tf[4],
                }
                for tf in pattern.top_features
            ],
            "pattern_description": pattern.description,
            "separation_score": pattern.separation_score,
            "distinctiveness": pattern.distinctiveness,
            # Include ErrorPattern object for callers that want it
            "_pattern_object": pattern,
        }

    @property
    def hypothesis_generator(self) -> HypothesisGenerator:
        """Get hypothesis generator for custom hypothesis generation.

        The generator is built on first access and cached on the instance.
        """
        if self._hypothesis_generator is None:
            # Get hypothesis config from TradeConfig
            ext_config = getattr(self.config, "hypothesis", None)

            # A ready-made HypothesisConfig must be checked first: it is
            # constructed with a `min_confidence` keyword below, so it
            # presumably exposes that attribute too and would otherwise be
            # caught by the duck-typed branch and rebuilt, dropping its
            # `max_actions` setting.
            if isinstance(ext_config, HypothesisConfig):
                config = ext_config
            elif ext_config is not None and hasattr(ext_config, "min_confidence"):
                # It's a HypothesisGenerationConfig - convert to HypothesisConfig
                config = HypothesisConfig(
                    template_library=getattr(ext_config, "template_library", "comprehensive"),
                    min_confidence=getattr(ext_config, "min_confidence", 0.5),
                    max_actions=getattr(ext_config, "max_hypotheses_per_cluster", 4),
                )
            else:
                config = HypothesisConfig()

            self._hypothesis_generator = HypothesisGenerator(config=config)
        return self._hypothesis_generator

    def generate_hypothesis(
        self,
        error_pattern: ErrorPattern,
    ) -> ErrorPattern:
        """Generate hypothesis for an error pattern.

        Args:
            error_pattern: Error pattern to analyze

        Returns:
            ErrorPattern with hypothesis, actions, and confidence fields populated
        """
        return self.hypothesis_generator.generate_hypothesis(
            error_pattern,
            feature_names=self.feature_names,
        )
550
+
551
+
552
# Public API of this compatibility module: the analyzer defined here plus the
# names re-exported from the modular trade_shap package.
__all__ = [
    # Analyzer wrapper defined in this module
    "TradeShapAnalyzer",
    # Pipeline (recommended interface for pre-computed SHAP values)
    "TradeShapPipeline",
    "TradeShapPipelineConfig",
    # Result and data models
    "TradeShapResult",
    "TradeShapExplanation",
    "TradeExplainFailure",
    "ErrorPattern",
    "ClusteringResult",
    # Building blocks: explanation, alignment, clustering, characterization,
    # hypothesis generation
    "TradeShapExplainer",
    "TimestampAligner",
    "AlignmentResult",
    "HierarchicalClusterer",
    "ClusteringConfig",
    "PatternCharacterizer",
    "CharacterizationConfig",
    "FeatureStatistics",
    "HypothesisGenerator",
    "HypothesisConfig",
    # Helper functions and types
    "normalize",
    "normalize_l1",
    "normalize_l2",
    "standardize",
    "NormalizationType",
    "benjamini_hochberg",
    "find_optimal_clusters",
    "compute_cluster_sizes",
    "compute_centroids",
    "Template",
    "TemplateMatcher",
    "load_templates",
]